import traceback
from collections.abc import Mapping

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
from qai_hub_models.models.detr_resnet50 import Model

# Side length (pixels) of the square image that is fed to the model.
MODEL_INPUT_SIZE = 800
# Predictions whose score is at or below this value are discarded.
DEFAULT_CONFIDENCE_THRESHOLD = 0.8


def _register_avif_support():
    """Register an AVIF decoder with Pillow if a supporting plugin is installed.

    Tries pillow-avif-plugin first, then pillow-heif. Prints an install hint
    when neither package can be imported.
    """
    try:
        import pillow_avif  # importing pillow-avif-plugin is what registers its opener
    except ImportError:
        pass
    else:
        # Call an explicit registration hook only if this version exposes one.
        register = getattr(pillow_avif, "register_avif_opener", None)
        if register is not None:
            register()
        return

    try:
        import pillow_heif
    except ImportError:
        print("AVIF support not available. Please install 'pillow-avif-plugin' or 'pillow-heif'.")
        return

    # register_heif_opener() only adds HEIF/HEIC; AVIF needs register_avif_opener(),
    # which only some pillow-heif versions provide.
    register_avif = getattr(pillow_heif, "register_avif_opener", None)
    if register_avif is not None:
        register_avif()
    pillow_heif.register_heif_opener()


_register_avif_support()

# Load the model once at import time so every request reuses it.
torch_model = Model.from_pretrained()


def _to_model_input(image):
    """Resize an RGB PIL image and return a (1, 3, H, W) float tensor scaled to [0, 1]."""
    # NOTE(review): the square resize ignores aspect ratio; confirm this matches the
    # model's expected preprocessing.
    resized = image.resize((MODEL_INPUT_SIZE, MODEL_INPUT_SIZE))
    chw = torch.tensor(np.array(resized)).permute(2, 0, 1)  # HWC uint8 -> CHW
    return (chw.float() / 255.0).unsqueeze(0)  # normalize and add a batch dimension


def _extract_predictions(outputs):
    """Return the prediction tensor from the raw model output.

    Mapping-like outputs with a "logits" entry yield that entry; a bare tensor
    is returned as-is; any other sequence yields its first element.
    """
    # NOTE(review): the exact output format of the model is not visible here;
    # TODO confirm which entry holds boxes and scores.
    if isinstance(outputs, Mapping) and "logits" in outputs:
        return outputs["logits"]
    if isinstance(outputs, torch.Tensor):
        # `"logits" in tensor` would raise, so handle tensors explicitly.
        return outputs
    return outputs[0]


def _draw_detection(draw, box, caption):
    """Draw a red rectangle for box (x0, y0, x1, y1, ...) and a caption at its top-left corner."""
    # Pillow's rectangle() rejects boxes with x1 < x0 or y1 < y0, so order both axes.
    x0, x1 = sorted((box[0], box[2]))
    y0, y1 = sorted((box[1], box[3]))
    draw.rectangle((x0, y0, x1, y1), outline="red", width=3)
    draw.text((x0, y0), caption, fill="red")


def detect_objects(image, confidence_threshold=DEFAULT_CONFIDENCE_THRESHOLD):
    """Run the detector on an image and draw the kept boxes on a copy of it.

    Args:
        image: The uploaded image, either as delivered by gr.Image(type="numpy")
            (an array accepted by PIL.Image.fromarray) or a PIL.Image.Image.
        confidence_threshold: Predictions with a score at or below this value
            are dropped. Defaults to 0.8.

    Returns:
        A tuple (annotated_image, detections). annotated_image is an RGB PIL
        image at the original resolution with red boxes and labels drawn on it.
        detections is a list of dicts with "label", "confidence" and "box" keys.

    Raises:
        ValueError: if image is None, or the prediction tensor is not
            (batch, num_predictions, >=5).
    """
    if image is None:
        raise ValueError("No image uploaded!")

    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
    # convert() always returns a new image, so drawing never touches the caller's data.
    annotated = pil_image.convert("RGB")

    with torch.no_grad():  # inference only; gradients are not needed
        outputs = torch_model(_to_model_input(annotated))
    predictions = _extract_predictions(outputs)

    # NOTE(review): assumes each prediction row is [x0, y0, x1, y1, ..., score], with
    # coordinates in the MODEL_INPUT_SIZE x MODEL_INPUT_SIZE input space.
    # TODO confirm against the model's documented output format.
    if predictions.ndim != 3 or predictions.shape[-1] < 5:
        raise ValueError(
            f"Unexpected prediction shape {tuple(predictions.shape)}; "
            "expected (batch, num_predictions, >=5)."
        )

    # Factors that map model-space coordinates back to the original resolution.
    scale_x = annotated.width / MODEL_INPUT_SIZE
    scale_y = annotated.height / MODEL_INPUT_SIZE
    draw = ImageDraw.Draw(annotated)  # one drawing context shared by all boxes

    detections = []
    for i, row in enumerate(predictions[0].tolist()):
        score = row[-1]
        if score <= confidence_threshold:
            continue
        box = row[:-1]
        box[0] *= scale_x
        box[1] *= scale_y
        box[2] *= scale_x
        box[3] *= scale_y
        label = f"Object {i}"
        confidence = round(score, 3)
        detections.append({
            "label": label,
            "confidence": confidence,
            "box": box,
        })
        _draw_detection(draw, box, f"{label} ({confidence})")

    return annotated, detections


# Gradio UI: image upload and buttons on the left; annotated image and detection JSON on the right.
with gr.Blocks() as iface:
    gr.Markdown("# Object Detection with DETR-ResNet50")
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="numpy", label="Upload Image (supports PNG, JPEG, AVIF...)")
            submit_button = gr.Button("Submit")
            clear_button = gr.Button("Clear")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Detected Image")
            output_json = gr.JSON(label="Detections")

    def on_submit(image):
        """Run detection, reporting any failure in the JSON panel instead of raising."""
        try:
            return detect_objects(image)
        except Exception as e:  # UI boundary: surface the error to the user
            traceback.print_exc()  # keep the full traceback in the server log
            return None, {"error": str(e)}

    def on_clear():
        """Reset the input image and both outputs."""
        return None, None, None

    submit_button.click(on_submit, inputs=image_input, outputs=[output_image, output_json])
    clear_button.click(on_clear, inputs=None, outputs=[image_input, output_image, output_json])


if __name__ == "__main__":
    iface.launch()