yikesongcai commited on
Commit
ad08f40
·
verified ·
1 Parent(s): bdb9039

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -48
app.py CHANGED
@@ -5,66 +5,77 @@ from ultralytics import YOLO
5
  import matplotlib.pyplot as plt
6
  import io
7
  from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
 
 
 
 
 
 
8
  model = YOLO('detect-best.pt', weights_only=True)
9
 
10
  def predict(img, conf, iou):
11
  results = model.predict(img, conf=conf, iou=iou)
12
  name = results[0].names
13
  cls = results[0].boxes.cls
14
- crazing = 0
15
- inclusion = 0
16
- patches = 0
17
- pitted_surface = 0
18
- rolled_inscale = 0
19
- scratches = 0
 
 
 
 
 
 
 
20
  for i in cls:
21
- if i == 0:
22
- crazing += 1
23
- elif i == 1:
24
- inclusion += 1
25
- elif i == 2:
26
- patches += 1
27
- elif i == 3:
28
- pitted_surface += 1
29
- elif i == 4:
30
- rolled_inscale += 1
31
- elif i == 5:
32
- scratches += 1
33
- # 绘制柱状图
34
- fig, ax = plt.subplots()
35
- categories = ['crazing','inclusion', 'patches' ,'pitted_surface', 'rolled_inscale' ,'scratches']
36
- counts = [crazing,inclusion, patches ,pitted_surface, rolled_inscale ,scratches]
37
- ax.bar(categories, counts)
38
- ax.set_title('Category-Count')
39
- plt.ylim(0,5)
40
  plt.xticks(rotation=45, ha="right")
41
- ax.set_xlabel('Category')
42
- ax.set_ylabel('Count')
43
- # 将图表保存为字节流
44
  buf = io.BytesIO()
45
- canvas = FigureCanvas(fig)
46
- canvas.print_png(buf)
47
- plt.close(fig) # 关闭图形,释放资源
48
-
49
- # 将字节流转换为PIL Image
50
- image_png = Image.open(buf)
51
- # 绘制并返回结果图片和类别计数图表
52
-
53
- for i, r in enumerate(results):
54
- # Plot results image
55
- im_bgr = r.plot() # BGR-order numpy array
56
- im_rgb = Image.fromarray(im_bgr[..., ::-1]) # RGB-order PIL image
57
-
58
- # Show results to screen (in supported environments)
59
- return im_rgb
60
 
 
61
  base_conf, base_iou = 0.25, 0.45
62
  title = "基于改进YOLOv8算法的工业瑕疵辅助检测系统"
63
  des = "鼠标点击上传图片即可检测缺陷,可通过鼠标调整预测置信度,还可点击网页最下方示例图片进行预测"
 
64
  interface = gr.Interface(
65
- inputs=['image', gr.Slider(maximum=1, minimum=0, value=base_conf), gr.Slider(maximum=1, minimum=0, value=base_iou)],
66
- outputs=["image"], fn=predict, title=title, description=des,
67
- examples=[["example1.jpg", base_conf, base_iou],
68
- ["example2.jpg", base_conf, base_iou],
69
- ["example3.jpg", base_conf, base_iou]])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  interface.launch()
 
5
  import matplotlib.pyplot as plt
6
  import io
7
  from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
8
+
9
+ # 修复代码:显式允许 YOLO 的自定义类
10
+ from ultralytics.nn.tasks import DetectionModel
11
+ torch.serialization.add_safe_globals([DetectionModel])
12
+
13
+ # 加载模型(现在可以安全使用 weights_only=True)
14
  model = YOLO('detect-best.pt', weights_only=True)
15
 
16
  def predict(img, conf, iou):
17
  results = model.predict(img, conf=conf, iou=iou)
18
  name = results[0].names
19
  cls = results[0].boxes.cls
20
+
21
+ # 初始化计数器
22
+ counters = {
23
+ 0: 'crazing',
24
+ 1: 'inclusion',
25
+ 2: 'patches',
26
+ 3: 'pitted_surface',
27
+ 4: 'rolled_inscale',
28
+ 5: 'scratches'
29
+ }
30
+ counts = {v: 0 for v in counters.values()}
31
+
32
+ # 统计类别
33
  for i in cls:
34
+ counts[counters[int(i)]] += 1
35
+
36
+ # 绘制柱状图
37
+ fig, ax = plt.subplots(figsize=(10, 5))
38
+ ax.bar(counts.keys(), counts.values())
39
+ ax.set_title('Defect Category Distribution')
40
+ ax.set_ylim(0, max(counts.values()) + 1)
 
 
 
 
 
 
 
 
 
 
 
 
41
  plt.xticks(rotation=45, ha="right")
42
+
43
+ # 转换为图像
 
44
  buf = io.BytesIO()
45
+ plt.savefig(buf, format='png', bbox_inches='tight')
46
+ plt.close()
47
+ chart_img = Image.open(buf)
48
+
49
+ # 处理检测结果
50
+ im_bgr = results[0].plot()
51
+ det_img = Image.fromarray(im_bgr[..., ::-1])
52
+
53
+ # 返回检测结果和统计图
54
+ return [det_img, chart_img]
 
 
 
 
 
55
 
56
+ # 界面设置
57
  base_conf, base_iou = 0.25, 0.45
58
  title = "基于改进YOLOv8算法的工业瑕疵辅助检测系统"
59
  des = "鼠标点击上传图片即可检测缺陷,可通过鼠标调整预测置信度,还可点击网页最下方示例图片进行预测"
60
+
61
  interface = gr.Interface(
62
+ fn=predict,
63
+ inputs=[
64
+ gr.Image(type="pil", label="输入图片"),
65
+ gr.Slider(0, 1, value=base_conf, label="置信度阈值"),
66
+ gr.Slider(0, 1, value=base_iou, label="IoU阈值")
67
+ ],
68
+ outputs=[
69
+ gr.Image(label="检测结果"),
70
+ gr.Image(label="缺陷统计")
71
+ ],
72
+ title=title,
73
+ description=des,
74
+ examples=[
75
+ ["example1.jpg", base_conf, base_iou],
76
+ ["example2.jpg", base_conf, base_iou],
77
+ ["example3.jpg", base_conf, base_iou]
78
+ ]
79
+ )
80
+
81
  interface.launch()