Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from qai_hub_models.models.detr_resnet50 import Model | |
| from PIL import Image, ImageDraw | |
| import numpy as np | |
# Register AVIF decoding with Pillow, using whichever optional plugin is
# installed. pillow-avif-plugin is tried first, pillow-heif second.
try:
    import pillow_avif
except ImportError:
    try:
        import pillow_heif
    except ImportError:
        print("AVIF support not available. Please install 'pillow-avif-plugin' or 'pillow-heif'.")
    else:
        # NOTE(review): whether the HEIF opener also covers AVIF depends on
        # the installed pillow-heif version -- confirm.
        pillow_heif.register_heif_opener()
else:
    # pillow-avif-plugin is documented as registering itself on import, so a
    # successful import is treated as sufficient. An explicit registration
    # hook is called only when the installed package actually exposes one;
    # importing the name directly would raise ImportError otherwise and
    # wrongly report the plugin as missing.
    _register_avif = getattr(pillow_avif, "register_avif_opener", None)
    if callable(_register_avif):
        _register_avif()

# Load the pretrained DETR-ResNet50 model once, at import time.
torch_model = Model.from_pretrained()
def detect_objects(image, confidence_threshold=0.8):
    """Run the detector on an uploaded image and draw the kept predictions.

    Args:
        image: The uploaded image as a NumPy array, or ``None`` when nothing
            was uploaded.
        confidence_threshold: Predictions whose score is not strictly greater
            than this value are discarded. Defaults to 0.8.

    Returns:
        A ``(annotated_image, detections)`` tuple. ``annotated_image`` is a
        PIL image at the original resolution with red boxes and labels drawn
        on it. ``detections`` is a list of dicts with the keys ``"label"``,
        ``"confidence"`` and ``"box"``.

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    if image is None:
        raise ValueError("No image uploaded!")

    # Side length of the square image fed to the model.
    input_size = 800

    # Convert to RGB; keep an untouched copy to draw on, resize the other.
    image = Image.fromarray(image).convert("RGB")
    original_image = image.copy()
    image = image.resize((input_size, input_size))

    # HWC uint8 array -> CHW float tensor scaled to [0, 1].
    image_tensor = torch.tensor(np.array(image)).permute(2, 0, 1)
    image_tensor = image_tensor.float() / 255.0

    # Inference only: no gradients needed. unsqueeze adds the batch dimension.
    with torch.no_grad():
        outputs = torch_model(image_tensor.unsqueeze(0))

    # Accept either a mapping with a 'logits' entry or a sequence whose first
    # element holds the predictions.
    predictions = outputs['logits'] if 'logits' in outputs else outputs[0]

    # Factors that map coordinates on the resized input back to the original.
    scale_x = original_image.width / input_size
    scale_y = original_image.height / input_size

    # One drawing context is enough for all boxes.
    draw = ImageDraw.Draw(original_image)

    detections = []
    for i in range(predictions.shape[1]):
        # NOTE(review): this treats the last channel as a ready-to-use score
        # and the leading channels as pixel box coordinates on the resized
        # image. Confirm against the model's documented output format; no
        # softmax or box decoding is applied here.
        score = predictions[0, i, -1].item()
        if score <= confidence_threshold:
            continue

        box = predictions[0, i, :-1].tolist()
        box[0] *= scale_x
        box[1] *= scale_y
        box[2] *= scale_x
        box[3] *= scale_y

        detection = {
            "label": f"Object {i}",
            "confidence": round(score, 3),
            "box": box,
        }
        detections.append(detection)

        # Draw only well-formed boxes. Both axes are checked because Pillow
        # can reject a rectangle whose second corner precedes the first, and
        # one such box would otherwise abort the whole request.
        if box[0] < box[2] and box[1] < box[3]:
            draw.rectangle(box[:4], outline="red", width=3)
            draw.text(
                (box[0], box[1]),
                f"{detection['label']} ({detection['confidence']})",
                fill="red",
            )

    return original_image, detections
# Build the Gradio interface: an upload column with Submit/Clear buttons next
# to a results column showing the annotated image and the raw detections.
with gr.Blocks() as iface:
    gr.Markdown("# Object Detection with DETR-ResNet50")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="numpy", label="Upload Image (supports PNG, JPEG, AVIF...)")
            submit_button = gr.Button("Submit")
            clear_button = gr.Button("Clear")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Detected Image")
            output_json = gr.JSON(label="Detections")

    def on_submit(image):
        """Run detection; on any failure return no image and an error payload."""
        try:
            annotated, results = detect_objects(image)
        except Exception as e:
            return None, {"error": str(e)}
        else:
            return annotated, results

    def on_clear():
        """Return one ``None`` per component wired to the Clear button."""
        return (None,) * 3

    # Wire the buttons to their handlers.
    submit_button.click(
        on_submit,
        inputs=image_input,
        outputs=[output_image, output_json],
    )
    clear_button.click(
        on_clear,
        inputs=None,
        outputs=[image_input, output_image, output_json],
    )

# Launch the app when this file is run as a script.
if __name__ == "__main__":
    iface.launch()