Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| from ultralytics import YOLO | |
| import matplotlib.pyplot as plt | |
| import io | |
| from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas | |
| # 修复代码:显式允许 YOLO 的自定义类 | |
| from ultralytics.nn.tasks import DetectionModel | |
| torch.serialization.add_safe_globals([DetectionModel]) | |
| # 加载模型(现在可以安全使用 weights_only=True) | |
| model = YOLO('detect-best.pt', weights_only=True) | |
| def predict(img, conf, iou): | |
| results = model.predict(img, conf=conf, iou=iou) | |
| name = results[0].names | |
| cls = results[0].boxes.cls | |
| # 初始化计数器 | |
| counters = { | |
| 0: 'crazing', | |
| 1: 'inclusion', | |
| 2: 'patches', | |
| 3: 'pitted_surface', | |
| 4: 'rolled_inscale', | |
| 5: 'scratches' | |
| } | |
| counts = {v: 0 for v in counters.values()} | |
| # 统计类别 | |
| for i in cls: | |
| counts[counters[int(i)]] += 1 | |
| # 绘制柱状图 | |
| fig, ax = plt.subplots(figsize=(10, 5)) | |
| ax.bar(counts.keys(), counts.values()) | |
| ax.set_title('Defect Category Distribution') | |
| ax.set_ylim(0, max(counts.values()) + 1) | |
| plt.xticks(rotation=45, ha="right") | |
| # 转换为图像 | |
| buf = io.BytesIO() | |
| plt.savefig(buf, format='png', bbox_inches='tight') | |
| plt.close() | |
| chart_img = Image.open(buf) | |
| # 处理检测结果 | |
| im_bgr = results[0].plot() | |
| det_img = Image.fromarray(im_bgr[..., ::-1]) | |
| # 返回检测结果和统计图 | |
| return [det_img, chart_img] | |
| # 界面设置 | |
| base_conf, base_iou = 0.25, 0.45 | |
| title = "基于改进YOLOv8算法的工业瑕疵辅助检测系统" | |
| des = "鼠标点击上传图片即可检测缺陷,可通过鼠标调整预测置信度,还可点击网页最下方示例图片进行预测" | |
| interface = gr.Interface( | |
| fn=predict, | |
| inputs=[ | |
| gr.Image(type="pil", label="输入图片"), | |
| gr.Slider(0, 1, value=base_conf, label="置信度阈值"), | |
| gr.Slider(0, 1, value=base_iou, label="IoU阈值") | |
| ], | |
| outputs=[ | |
| gr.Image(label="检测结果"), | |
| gr.Image(label="缺陷统计") | |
| ], | |
| title=title, | |
| description=des, | |
| examples=[ | |
| ["example1.jpg", base_conf, base_iou], | |
| ["example2.jpg", base_conf, base_iou], | |
| ["example3.jpg", base_conf, base_iou] | |
| ] | |
| ) | |
| interface.launch() | |