File size: 2,351 Bytes
9cefd5f
 
 
 
 
 
 
ad08f40
 
 
 
 
 
bdb9039
9cefd5f
 
 
 
 
ad08f40
 
 
 
 
 
 
 
 
 
 
 
 
9cefd5f
ad08f40
 
 
 
 
 
 
9cefd5f
ad08f40
 
9cefd5f
ad08f40
 
 
 
 
 
 
 
 
 
9cefd5f
ad08f40
9cefd5f
 
 
ad08f40
9cefd5f
ad08f40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cefd5f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import gradio as gr
import torch
from PIL import Image
from ultralytics import YOLO
import matplotlib.pyplot as plt
import io
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

# 修复代码:显式允许 YOLO 的自定义类
from ultralytics.nn.tasks import DetectionModel
torch.serialization.add_safe_globals([DetectionModel])

# 加载模型(现在可以安全使用 weights_only=True)
model = YOLO('detect-best.pt', weights_only=True)

def predict(img, conf, iou):
    results = model.predict(img, conf=conf, iou=iou)
    name = results[0].names
    cls = results[0].boxes.cls
    
    # 初始化计数器
    counters = {
        0: 'crazing',
        1: 'inclusion',
        2: 'patches',
        3: 'pitted_surface',
        4: 'rolled_inscale',
        5: 'scratches'
    }
    counts = {v: 0 for v in counters.values()}
    
    # 统计类别
    for i in cls:
        counts[counters[int(i)]] += 1
    
    # 绘制柱状图
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.bar(counts.keys(), counts.values())
    ax.set_title('Defect Category Distribution')
    ax.set_ylim(0, max(counts.values()) + 1)
    plt.xticks(rotation=45, ha="right")
    
    # 转换为图像
    buf = io.BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight')
    plt.close()
    chart_img = Image.open(buf)
    
    # 处理检测结果
    im_bgr = results[0].plot()
    det_img = Image.fromarray(im_bgr[..., ::-1])
    
    # 返回检测结果和统计图
    return [det_img, chart_img]

# 界面设置
base_conf, base_iou = 0.25, 0.45
title = "基于改进YOLOv8算法的工业瑕疵辅助检测系统"
des = "鼠标点击上传图片即可检测缺陷,可通过鼠标调整预测置信度,还可点击网页最下方示例图片进行预测"

interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="输入图片"),
        gr.Slider(0, 1, value=base_conf, label="置信度阈值"),
        gr.Slider(0, 1, value=base_iou, label="IoU阈值")
    ],
    outputs=[
        gr.Image(label="检测结果"),
        gr.Image(label="缺陷统计")
    ],
    title=title,
    description=des,
    examples=[
        ["example1.jpg", base_conf, base_iou],
        ["example2.jpg", base_conf, base_iou],
        ["example3.jpg", base_conf, base_iou]
    ]
)

interface.launch()