SissiFeng commited on
Commit
751afab
·
verified ·
1 Parent(s): e1b01c0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -63
app.py CHANGED
@@ -16,7 +16,6 @@ from PIL import ImageDraw
16
  import requests
17
 
18
 
19
- # 设置日志记录
20
  logging.basicConfig(
21
  level=logging.INFO,
22
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
@@ -27,13 +26,12 @@ logging.basicConfig(
27
  )
28
  logger = logging.getLogger("bambu-analysis")
29
 
30
- HOST = "mqtt.bambulab.com" # 默认值
31
- PORT = 8883 # 默认值
32
- USERNAME = "bblp" # 默认值
33
- PASSWORD = "bblp" # 默认值
34
  DEFAULT_SERIAL = "0309CA471800852"
35
 
36
- # 尝试从环境变量获取
37
  if os.environ.get("host"):
38
  HOST = os.environ.get("host")
39
  if os.environ.get("port"):
@@ -56,7 +54,6 @@ client = None
56
  response_topic = None # Will be set dynamically
57
 
58
  def create_client(host, port, username, password):
59
- """完全使用同事的创建客户端函数"""
60
  global client
61
  client = mqtt.Client()
62
  client.username_pw_set(username, password)
@@ -93,7 +90,6 @@ def get_data(serial=DEFAULT_SERIAL):
93
  logger.info(f"Subscribing to {response_topic}")
94
  client.subscribe(response_topic)
95
 
96
- # 发送请求获取数据
97
  logger.info(f"Publishing request to {request_topic}")
98
  client.publish(request_topic, json.dumps("HI"))
99
 
@@ -112,7 +108,6 @@ def get_data(serial=DEFAULT_SERIAL):
112
  )
113
 
114
  def send_print_parameters(nozzle_temp, bed_temp, print_speed, fan_speed):
115
- """发送打印参数到打印机"""
116
  serial = DEFAULT_SERIAL
117
  logger.info(f"Sending parameters to {serial}: nozzle={nozzle_temp}, bed={bed_temp}, speed={print_speed}, fan={fan_speed}")
118
  try:
@@ -140,17 +135,14 @@ def send_print_parameters(nozzle_temp, bed_temp, print_speed, fan_speed):
140
  return f"Error sending parameters: {e}"
141
 
142
  def get_image_base64(image):
143
- """将图像转换为base64用于API传输"""
144
  if image is None:
145
  logger.warning("No image to encode")
146
  return None
147
 
148
  try:
149
- # 转换为PIL图像
150
  if isinstance(image, np.ndarray):
151
  image = Image.fromarray(image)
152
 
153
- # 转换为base64
154
  buffer = io.BytesIO()
155
  image.save(buffer, format="PNG")
156
  img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
@@ -161,7 +153,6 @@ def get_image_base64(image):
161
  return None
162
 
163
  def get_test_image(image_name=None):
164
- """获取测试图片"""
165
  import os
166
  import random
167
 
@@ -171,7 +162,6 @@ def get_test_image(image_name=None):
171
  logger.error(f"Test images directory not found: {test_dir}")
172
  return None
173
 
174
- # 获取所有图片文件
175
  image_files = [f for f in os.listdir(test_dir)
176
  if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp'))]
177
 
@@ -179,11 +169,9 @@ def get_test_image(image_name=None):
179
  logger.error("No test images found")
180
  return None
181
 
182
- # 如果指定了图片名称,使用指定的图片
183
  if image_name and image_name in image_files:
184
  image_path = os.path.join(test_dir, image_name)
185
  else:
186
- # 否则随机选择一张图片
187
  image_path = os.path.join(test_dir, random.choice(image_files))
188
 
189
  logger.info(f"Using test image: {image_path}")
@@ -195,8 +183,7 @@ def get_test_image(image_name=None):
195
  return None
196
 
197
  def capture_image(url=None, use_test_image=False, test_image_name=None):
198
- """从URL或测试文件夹获取图像"""
199
- # 优先使用测试图像(如果指定)
200
  if use_test_image:
201
  logger.info("Using test image instead of URL")
202
  test_img = get_test_image(test_image_name)
@@ -205,7 +192,7 @@ def capture_image(url=None, use_test_image=False, test_image_name=None):
205
  else:
206
  logger.warning("Failed to get specified test image, trying URL")
207
 
208
- # 尝试从URL获取图像
209
  if url:
210
  try:
211
  logger.info(f"Capturing image from URL: {url}")
@@ -217,12 +204,10 @@ def capture_image(url=None, use_test_image=False, test_image_name=None):
217
  except Exception as e:
218
  logger.error(f"Error capturing image from URL: {e}")
219
 
220
- # 如果URL获取失败,尝试使用随机测试图片
221
  logger.info("URL capture failed or not provided, using random test image")
222
  return get_test_image()
223
 
224
  def health_check():
225
- """健康检查端点"""
226
  status = {
227
  "app": "running",
228
  "time": time.strftime("%Y-%m-%d %H:%M:%S"),
@@ -232,10 +217,8 @@ def health_check():
232
  logger.info(f"Health check: {status}")
233
  return status
234
 
235
- # 创建 Gradio 应用
236
  demo = gr.Blocks(title="Bambu A1 Mini Print Control")
237
 
238
- # 在 Blocks 上下文中注册所有内容
239
  with demo:
240
  gr.Markdown("# Bambu A1 Mini Print Control")
241
 
@@ -254,7 +237,6 @@ with demo:
254
 
255
  with gr.Row():
256
  with gr.Column():
257
- # 打印参数输入
258
  nozzle_temp = gr.Slider(minimum=180, maximum=250, step=1, value=200, label="Nozzle Temperature (°C)")
259
  bed_temp = gr.Slider(minimum=40, maximum=100, step=1, value=60, label="Bed Temperature (°C)")
260
  print_speed = gr.Slider(minimum=20, maximum=150, step=1, value=60, label="Print Speed (mm/s)")
@@ -262,7 +244,6 @@ with demo:
262
 
263
  send_params_btn = gr.Button("Send Print Parameters")
264
 
265
- # 连接按钮
266
  refresh_btn.click(
267
  fn=get_data,
268
  outputs=[current_status, current_bed_temp, current_nozzle_temp, last_update]
@@ -279,24 +260,19 @@ with demo:
279
  outputs=[current_status]
280
  )
281
 
282
- # API端点
283
  def api_get_data():
284
- """API端点:获取状态"""
285
  logger.info("API call: get_data")
286
  return get_data()
287
 
288
  def api_capture_frame(url=None, use_test_image=False, test_image_name=None):
289
- """API端点:捕获图像帧"""
290
  logger.info(f"API call: capture_frame with URL: {url}, use_test_image: {use_test_image}")
291
 
292
  try:
293
  img = capture_image(url, use_test_image, test_image_name)
294
  if img:
295
- # 确保图像是 RGB 格式(去除 Alpha 通道)
296
  if img.mode == 'RGBA':
297
  img = img.convert('RGB')
298
 
299
- # 转换为base64
300
  buffered = io.BytesIO()
301
  img.save(buffered, format="JPEG")
302
  img_str = base64.b64encode(buffered.getvalue()).decode()
@@ -318,23 +294,18 @@ with demo:
318
  }
319
 
320
  def api_lambda(img_data=None, param_1=200, param_2=60, param_3=60, param_4=100, use_test_image=False, test_image_name=None):
321
- """API端点:分析图像"""
322
  logger.info(f"API call: lambda with params: {param_1}, {param_2}, {param_3}, {param_4}, use_test_image: {use_test_image}, test_image_name: {test_image_name}")
323
  try:
324
- # 获取图像
325
  img = None
326
 
327
- # 使用测试图片
328
  if use_test_image:
329
  logger.info(f"Lambda using test image: {test_image_name}")
330
  img = get_test_image(test_image_name)
331
 
332
- # 检查输入是否是 URL
333
  elif img_data and isinstance(img_data, str) and (img_data.startswith('http://') or img_data.startswith('https://')):
334
  logger.info(f"Lambda received image URL: {img_data}")
335
- img = capture_image(img_data) # 从 URL 获取图像
336
 
337
- # 检查输入是否是 base64 数据
338
  elif img_data and isinstance(img_data, str):
339
  try:
340
  logger.info("Lambda received base64 image data")
@@ -343,28 +314,19 @@ with demo:
343
  except Exception as e:
344
  logger.error(f"Failed to decode base64 image: {e}")
345
 
346
- # 如果没有图像,捕获一个默认图像
347
  if img is None:
348
  logger.info("No valid image data received, using default test image")
349
  img = get_test_image()
350
 
351
- # 分析图像
352
  if img:
353
- # 转换为numpy数组
354
  img_array = np.array(img)
355
-
356
- # 进行图像分析
357
- # 这里应该是您的图像分析代码
358
- # 为了测试,我们使用模拟的分析结果
359
-
360
- # 根据参数调整分析结果
361
  quality_level = 'low'
362
  if 190 <= param_1 <= 210 and param_3 <= 50 and param_4 >= 80:
363
  quality_level = 'high'
364
  elif 185 <= param_1 <= 215 and param_3 <= 70 and param_4 >= 60:
365
  quality_level = 'medium'
366
 
367
- # 根据质量级别设置指标
368
  if quality_level == 'high':
369
  missing_rate = 0.02
370
  excess_rate = 0.01
@@ -378,10 +340,8 @@ with demo:
378
  excess_rate = 0.07
379
  stringing_rate = 0.05
380
 
381
- # 计算均匀性
382
  uniformity_score = 1.0 - (missing_rate + excess_rate + stringing_rate)
383
 
384
- # 计算性能分数
385
  print_quality_score = 1.0 - (missing_rate * 2.0 + excess_rate * 1.5 + stringing_rate * 1.0)
386
  print_quality_score = max(0, min(1, print_quality_score))
387
 
@@ -397,9 +357,7 @@ with demo:
397
  0.2 * material_efficiency_score
398
  )
399
 
400
- # 创建可视化
401
- # 这里应该是您的可视化代码
402
- # 为了测试,我们简单地在图像上添加一些文本
403
  img_draw = img.copy()
404
  draw = ImageDraw.Draw(img_draw)
405
  draw.text((10, 10), f"Quality: {quality_level.upper()}", fill=(255, 0, 0))
@@ -407,7 +365,6 @@ with demo:
407
  draw.text((10, 50), f"Excess: {excess_rate:.2f}", fill=(255, 0, 0))
408
  draw.text((10, 70), f"Stringing: {stringing_rate:.2f}", fill=(255, 0, 0))
409
 
410
- # 返回结果
411
  result = {
412
  "success": True,
413
  "missing_rate": missing_rate,
@@ -420,7 +377,7 @@ with demo:
420
  "total_performance_score": total_performance_score
421
  }
422
 
423
- # 将图像转换为base64并添加到结果中
424
  if img_draw.mode == 'RGBA':
425
  img_draw = img_draw.convert('RGB')
426
 
@@ -442,15 +399,12 @@ with demo:
442
  }
443
 
444
  def api_send_print_parameters(nozzle_temp=200, bed_temp=60, print_speed=60, fan_speed=100):
445
- """API端点:发送打印参数"""
446
  logger.info(f"API call: send_print_parameters with nozzle={nozzle_temp}, bed={bed_temp}, speed={print_speed}, fan={fan_speed}")
447
  return send_print_parameters(nozzle_temp, bed_temp, print_speed, fan_speed)
448
 
449
- # 创建用于 API 输出的 JSON 组件
450
  api_json_output = gr.JSON()
451
  api_text_output = gr.Textbox()
452
 
453
- # 注册 API 端点,使用 Gradio 组件作为输出
454
  capture_frame_api = demo.load(
455
  fn=api_capture_frame,
456
  inputs=[
@@ -458,7 +412,7 @@ with demo:
458
  gr.Checkbox(label="Use Test Image", value=False),
459
  gr.Textbox(label="Test Image Name", value="")
460
  ],
461
- outputs=api_json_output, # 使用 JSON 组件而不是字符串 "json"
462
  api_name="capture_frame"
463
  )
464
 
@@ -473,14 +427,14 @@ with demo:
473
  gr.Checkbox(label="Use Test Image", value=False),
474
  gr.Textbox(label="Test Image Name", value="")
475
  ],
476
- outputs=api_json_output, # 使用 JSON 组件而不是字符串 "json"
477
  api_name="lambda"
478
  )
479
 
480
  get_data_api = demo.load(
481
  fn=api_get_data,
482
  inputs=None,
483
- outputs=api_json_output, # 使用 JSON 组件而不是字符串 "json"
484
  api_name="get_data"
485
  )
486
 
@@ -492,22 +446,19 @@ with demo:
492
  gr.Number(label="Print Speed", value=60),
493
  gr.Number(label="Fan Speed", value=100)
494
  ],
495
- outputs=api_text_output, # 使用 Textbox 组件而不是字符串 "text"
496
  api_name="send_print_parameters"
497
  )
498
 
499
- # 启动应用
500
  if __name__ == "__main__":
501
  logger.info("Starting Bambu A1 Mini Print Control application")
502
 
503
- # 尝试初始化MQTT连接
504
  try:
505
  logger.info("Initializing MQTT client")
506
  create_client(HOST, PORT, USERNAME, PASSWORD)
507
  except Exception as e:
508
  logger.error(f"Failed to initialize MQTT: {e}")
509
 
510
- # 启动应用
511
  demo.queue().launch(
512
  show_error=True,
513
  share=False,
 
16
  import requests
17
 
18
 
 
19
  logging.basicConfig(
20
  level=logging.INFO,
21
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
 
26
  )
27
  logger = logging.getLogger("bambu-analysis")
28
 
29
+ HOST = "mqtt.bambulab.com"
30
+ PORT = 8883
31
+ USERNAME = "bblp"
32
+ PASSWORD = "bblp"
33
  DEFAULT_SERIAL = "0309CA471800852"
34
 
 
35
  if os.environ.get("host"):
36
  HOST = os.environ.get("host")
37
  if os.environ.get("port"):
 
54
  response_topic = None # Will be set dynamically
55
 
56
  def create_client(host, port, username, password):
 
57
  global client
58
  client = mqtt.Client()
59
  client.username_pw_set(username, password)
 
90
  logger.info(f"Subscribing to {response_topic}")
91
  client.subscribe(response_topic)
92
 
 
93
  logger.info(f"Publishing request to {request_topic}")
94
  client.publish(request_topic, json.dumps("HI"))
95
 
 
108
  )
109
 
110
  def send_print_parameters(nozzle_temp, bed_temp, print_speed, fan_speed):
 
111
  serial = DEFAULT_SERIAL
112
  logger.info(f"Sending parameters to {serial}: nozzle={nozzle_temp}, bed={bed_temp}, speed={print_speed}, fan={fan_speed}")
113
  try:
 
135
  return f"Error sending parameters: {e}"
136
 
137
  def get_image_base64(image):
 
138
  if image is None:
139
  logger.warning("No image to encode")
140
  return None
141
 
142
  try:
 
143
  if isinstance(image, np.ndarray):
144
  image = Image.fromarray(image)
145
 
 
146
  buffer = io.BytesIO()
147
  image.save(buffer, format="PNG")
148
  img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
 
153
  return None
154
 
155
  def get_test_image(image_name=None):
 
156
  import os
157
  import random
158
 
 
162
  logger.error(f"Test images directory not found: {test_dir}")
163
  return None
164
 
 
165
  image_files = [f for f in os.listdir(test_dir)
166
  if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp'))]
167
 
 
169
  logger.error("No test images found")
170
  return None
171
 
 
172
  if image_name and image_name in image_files:
173
  image_path = os.path.join(test_dir, image_name)
174
  else:
 
175
  image_path = os.path.join(test_dir, random.choice(image_files))
176
 
177
  logger.info(f"Using test image: {image_path}")
 
183
  return None
184
 
185
  def capture_image(url=None, use_test_image=False, test_image_name=None):
186
+
 
187
  if use_test_image:
188
  logger.info("Using test image instead of URL")
189
  test_img = get_test_image(test_image_name)
 
192
  else:
193
  logger.warning("Failed to get specified test image, trying URL")
194
 
195
+
196
  if url:
197
  try:
198
  logger.info(f"Capturing image from URL: {url}")
 
204
  except Exception as e:
205
  logger.error(f"Error capturing image from URL: {e}")
206
 
 
207
  logger.info("URL capture failed or not provided, using random test image")
208
  return get_test_image()
209
 
210
  def health_check():
 
211
  status = {
212
  "app": "running",
213
  "time": time.strftime("%Y-%m-%d %H:%M:%S"),
 
217
  logger.info(f"Health check: {status}")
218
  return status
219
 
 
220
  demo = gr.Blocks(title="Bambu A1 Mini Print Control")
221
 
 
222
  with demo:
223
  gr.Markdown("# Bambu A1 Mini Print Control")
224
 
 
237
 
238
  with gr.Row():
239
  with gr.Column():
 
240
  nozzle_temp = gr.Slider(minimum=180, maximum=250, step=1, value=200, label="Nozzle Temperature (°C)")
241
  bed_temp = gr.Slider(minimum=40, maximum=100, step=1, value=60, label="Bed Temperature (°C)")
242
  print_speed = gr.Slider(minimum=20, maximum=150, step=1, value=60, label="Print Speed (mm/s)")
 
244
 
245
  send_params_btn = gr.Button("Send Print Parameters")
246
 
 
247
  refresh_btn.click(
248
  fn=get_data,
249
  outputs=[current_status, current_bed_temp, current_nozzle_temp, last_update]
 
260
  outputs=[current_status]
261
  )
262
 
 
263
  def api_get_data():
 
264
  logger.info("API call: get_data")
265
  return get_data()
266
 
267
  def api_capture_frame(url=None, use_test_image=False, test_image_name=None):
 
268
  logger.info(f"API call: capture_frame with URL: {url}, use_test_image: {use_test_image}")
269
 
270
  try:
271
  img = capture_image(url, use_test_image, test_image_name)
272
  if img:
 
273
  if img.mode == 'RGBA':
274
  img = img.convert('RGB')
275
 
 
276
  buffered = io.BytesIO()
277
  img.save(buffered, format="JPEG")
278
  img_str = base64.b64encode(buffered.getvalue()).decode()
 
294
  }
295
 
296
  def api_lambda(img_data=None, param_1=200, param_2=60, param_3=60, param_4=100, use_test_image=False, test_image_name=None):
 
297
  logger.info(f"API call: lambda with params: {param_1}, {param_2}, {param_3}, {param_4}, use_test_image: {use_test_image}, test_image_name: {test_image_name}")
298
  try:
 
299
  img = None
300
 
 
301
  if use_test_image:
302
  logger.info(f"Lambda using test image: {test_image_name}")
303
  img = get_test_image(test_image_name)
304
 
 
305
  elif img_data and isinstance(img_data, str) and (img_data.startswith('http://') or img_data.startswith('https://')):
306
  logger.info(f"Lambda received image URL: {img_data}")
307
+ img = capture_image(img_data)
308
 
 
309
  elif img_data and isinstance(img_data, str):
310
  try:
311
  logger.info("Lambda received base64 image data")
 
314
  except Exception as e:
315
  logger.error(f"Failed to decode base64 image: {e}")
316
 
 
317
  if img is None:
318
  logger.info("No valid image data received, using default test image")
319
  img = get_test_image()
320
 
 
321
  if img:
 
322
  img_array = np.array(img)
323
+
 
 
 
 
 
324
  quality_level = 'low'
325
  if 190 <= param_1 <= 210 and param_3 <= 50 and param_4 >= 80:
326
  quality_level = 'high'
327
  elif 185 <= param_1 <= 215 and param_3 <= 70 and param_4 >= 60:
328
  quality_level = 'medium'
329
 
 
330
  if quality_level == 'high':
331
  missing_rate = 0.02
332
  excess_rate = 0.01
 
340
  excess_rate = 0.07
341
  stringing_rate = 0.05
342
 
 
343
  uniformity_score = 1.0 - (missing_rate + excess_rate + stringing_rate)
344
 
 
345
  print_quality_score = 1.0 - (missing_rate * 2.0 + excess_rate * 1.5 + stringing_rate * 1.0)
346
  print_quality_score = max(0, min(1, print_quality_score))
347
 
 
357
  0.2 * material_efficiency_score
358
  )
359
 
360
+
 
 
361
  img_draw = img.copy()
362
  draw = ImageDraw.Draw(img_draw)
363
  draw.text((10, 10), f"Quality: {quality_level.upper()}", fill=(255, 0, 0))
 
365
  draw.text((10, 50), f"Excess: {excess_rate:.2f}", fill=(255, 0, 0))
366
  draw.text((10, 70), f"Stringing: {stringing_rate:.2f}", fill=(255, 0, 0))
367
 
 
368
  result = {
369
  "success": True,
370
  "missing_rate": missing_rate,
 
377
  "total_performance_score": total_performance_score
378
  }
379
 
380
+
381
  if img_draw.mode == 'RGBA':
382
  img_draw = img_draw.convert('RGB')
383
 
 
399
  }
400
 
401
  def api_send_print_parameters(nozzle_temp=200, bed_temp=60, print_speed=60, fan_speed=100):
 
402
  logger.info(f"API call: send_print_parameters with nozzle={nozzle_temp}, bed={bed_temp}, speed={print_speed}, fan={fan_speed}")
403
  return send_print_parameters(nozzle_temp, bed_temp, print_speed, fan_speed)
404
 
 
405
  api_json_output = gr.JSON()
406
  api_text_output = gr.Textbox()
407
 
 
408
  capture_frame_api = demo.load(
409
  fn=api_capture_frame,
410
  inputs=[
 
412
  gr.Checkbox(label="Use Test Image", value=False),
413
  gr.Textbox(label="Test Image Name", value="")
414
  ],
415
+ outputs=api_json_output,
416
  api_name="capture_frame"
417
  )
418
 
 
427
  gr.Checkbox(label="Use Test Image", value=False),
428
  gr.Textbox(label="Test Image Name", value="")
429
  ],
430
+ outputs=api_json_output,
431
  api_name="lambda"
432
  )
433
 
434
  get_data_api = demo.load(
435
  fn=api_get_data,
436
  inputs=None,
437
+ outputs=api_json_output,
438
  api_name="get_data"
439
  )
440
 
 
446
  gr.Number(label="Print Speed", value=60),
447
  gr.Number(label="Fan Speed", value=100)
448
  ],
449
+ outputs=api_text_output,
450
  api_name="send_print_parameters"
451
  )
452
 
 
453
  if __name__ == "__main__":
454
  logger.info("Starting Bambu A1 Mini Print Control application")
455
 
 
456
  try:
457
  logger.info("Initializing MQTT client")
458
  create_client(HOST, PORT, USERNAME, PASSWORD)
459
  except Exception as e:
460
  logger.error(f"Failed to initialize MQTT: {e}")
461
 
 
462
  demo.queue().launch(
463
  show_error=True,
464
  share=False,