import gradio as gr import torch from PIL import Image from ultralytics import YOLO import matplotlib.pyplot as plt import io from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas # 修复代码:显式允许 YOLO 的自定义类 from ultralytics.nn.tasks import DetectionModel torch.serialization.add_safe_globals([DetectionModel]) # 加载模型(现在可以安全使用 weights_only=True) model = YOLO('detect-best.pt', weights_only=True) def predict(img, conf, iou): results = model.predict(img, conf=conf, iou=iou) name = results[0].names cls = results[0].boxes.cls # 初始化计数器 counters = { 0: 'crazing', 1: 'inclusion', 2: 'patches', 3: 'pitted_surface', 4: 'rolled_inscale', 5: 'scratches' } counts = {v: 0 for v in counters.values()} # 统计类别 for i in cls: counts[counters[int(i)]] += 1 # 绘制柱状图 fig, ax = plt.subplots(figsize=(10, 5)) ax.bar(counts.keys(), counts.values()) ax.set_title('Defect Category Distribution') ax.set_ylim(0, max(counts.values()) + 1) plt.xticks(rotation=45, ha="right") # 转换为图像 buf = io.BytesIO() plt.savefig(buf, format='png', bbox_inches='tight') plt.close() chart_img = Image.open(buf) # 处理检测结果 im_bgr = results[0].plot() det_img = Image.fromarray(im_bgr[..., ::-1]) # 返回检测结果和统计图 return [det_img, chart_img] # 界面设置 base_conf, base_iou = 0.25, 0.45 title = "基于改进YOLOv8算法的工业瑕疵辅助检测系统" des = "鼠标点击上传图片即可检测缺陷,可通过鼠标调整预测置信度,还可点击网页最下方示例图片进行预测" interface = gr.Interface( fn=predict, inputs=[ gr.Image(type="pil", label="输入图片"), gr.Slider(0, 1, value=base_conf, label="置信度阈值"), gr.Slider(0, 1, value=base_iou, label="IoU阈值") ], outputs=[ gr.Image(label="检测结果"), gr.Image(label="缺陷统计") ], title=title, description=des, examples=[ ["example1.jpg", base_conf, base_iou], ["example2.jpg", base_conf, base_iou], ["example3.jpg", base_conf, base_iou] ] ) interface.launch()