Dibiddo's picture
Update app.py
0168843 verified
import gradio as gr
import torch
from qai_hub_models.models.detr_resnet50 import Model
from PIL import Image, ImageDraw
import numpy as np
# Register AVIF decoding with Pillow using whichever plugin is installed.
def _register_avif_support():
    """Enable AVIF decoding in Pillow via pillow-avif-plugin or pillow-heif.

    Returns:
        bool: True if one of the two plugins could be imported, False if
        neither is installed (a hint is printed in that case).
    """
    try:
        # pillow-avif-plugin registers its codec as a side effect of being
        # imported; importing a registration function by name from it fails
        # with ImportError even when the plugin is installed.
        import pillow_avif
    except ImportError:
        pass
    else:
        # Still call an explicit registration hook if a build provides one.
        register = getattr(pillow_avif, "register_avif_opener", None)
        if register is not None:
            register()
        return True

    try:
        import pillow_heif
    except ImportError:
        print("AVIF support not available. Please install 'pillow-avif-plugin' or 'pillow-heif'.")
        return False

    pillow_heif.register_heif_opener()
    # NOTE(review): only some pillow-heif releases ship an AVIF opener;
    # register it when present so the fallback really covers AVIF.
    register = getattr(pillow_heif, "register_avif_opener", None)
    if register is not None:
        register()
    return True


_register_avif_support()

# Load the pretrained DETR-ResNet50 model once, at import time.
torch_model = Model.from_pretrained()
def _find_logits_and_boxes(outputs):
    """Return ``(logits, boxes)`` if ``outputs`` looks like a DETR output pair, else None.

    A pair is accepted when both entries are 3-D tensors with matching leading
    dimensions and the second one has four values per prediction.
    """
    if isinstance(outputs, dict):
        logits, boxes = outputs.get("logits"), outputs.get("pred_boxes")
    elif isinstance(outputs, (tuple, list)) and len(outputs) >= 2:
        logits, boxes = outputs[0], outputs[1]
    else:
        return None

    if (
        isinstance(logits, torch.Tensor)
        and isinstance(boxes, torch.Tensor)
        and logits.dim() == 3
        and boxes.dim() == 3
        and boxes.shape[-1] == 4
        and logits.shape[:2] == boxes.shape[:2]
    ):
        return logits, boxes
    return None


def _decode_detr_outputs(logits, boxes):
    """Turn DETR ``(logits, boxes)`` into scores, class ids and normalized corner boxes.

    Only the first batch item is decoded. The last logit column is treated as
    DETR's "no object" class and excluded from the score; boxes are converted
    from ``(center_x, center_y, width, height)`` to
    ``[x_min, y_min, x_max, y_max]`` and stay normalized.
    """
    probabilities = logits[0].softmax(-1)[:, :-1]
    scores, class_ids = probabilities.max(-1)
    center_x, center_y, box_w, box_h = boxes[0].unbind(-1)
    half_w, half_h = box_w / 2, box_h / 2
    corners = torch.stack(
        [center_x - half_w, center_y - half_h, center_x + half_w, center_y + half_h],
        dim=-1,
    )
    return scores.tolist(), class_ids.tolist(), corners.tolist()


def detect_objects(image, confidence_threshold=0.8):
    """Run the detection model on an image and draw the confident detections.

    Args:
        image: Image as a numpy array accepted by ``PIL.Image.fromarray``.
        confidence_threshold: Predictions whose score is not strictly greater
            than this value are discarded. Defaults to 0.8.

    Returns:
        tuple: ``(annotated_image, detections)`` where ``annotated_image`` is
        an RGB PIL image with red boxes and labels drawn on it, and
        ``detections`` is a list of dicts with ``"label"``, ``"confidence"``
        and ``"box"`` keys (plus ``"class_id"`` when the class is decoded).

    Raises:
        ValueError: If ``image`` is None.
    """
    if image is None:
        raise ValueError("No image uploaded!")

    # Keep a full-resolution RGB copy for drawing; the model gets a resized one.
    original_image = Image.fromarray(image).convert("RGB")
    width, height = original_image.size
    resized = original_image.resize((800, 800))

    # HWC uint8 -> CHW float scaled to [0, 1].
    image_tensor = torch.tensor(np.array(resized)).permute(2, 0, 1).float() / 255.0

    with torch.no_grad():  # inference only, no gradients needed
        outputs = torch_model(image_tensor.unsqueeze(0))  # add the batch dimension

    # NOTE(review): assumes the wrapped model returns DETR-style
    # (class logits, normalized cxcywh boxes); confirm against the installed
    # qai_hub_models version. Any other layout falls back to the previous
    # "[box..., score]" interpretation of the first output.
    detr_pair = _find_logits_and_boxes(outputs)
    if detr_pair is not None:
        scores, class_ids, boxes = _decode_detr_outputs(*detr_pair)
        # Boxes are normalized, so scaling by the image size yields pixels.
        x_scale, y_scale = width, height
    else:
        if isinstance(outputs, torch.Tensor):
            predictions = outputs
        elif isinstance(outputs, dict) and "logits" in outputs:
            predictions = outputs["logits"]
        else:
            predictions = outputs[0]
        rows = predictions[0]
        scores = rows[:, -1].tolist()
        boxes = rows[:, :-1].tolist()
        class_ids = [None] * len(scores)
        # Coordinates are taken to be in the 800x800 model input space.
        x_scale, y_scale = width / 800, height / 800

    detections = []
    draw = ImageDraw.Draw(original_image)  # one drawing context for all boxes
    for index, (score, class_id, box) in enumerate(zip(scores, class_ids, boxes)):
        if not score > confidence_threshold:
            continue

        box[0] *= x_scale
        box[1] *= y_scale
        box[2] *= x_scale
        box[3] *= y_scale

        detection = {
            "label": f"Object {index}",
            "confidence": round(score, 3),
            "box": box,
        }
        if class_id is not None:
            detection["class_id"] = class_id
        detections.append(detection)

        # Pillow rejects rectangles whose corners are out of order, so skip
        # drawing (but still report) boxes that are inverted on either axis.
        if box[0] <= box[2] and box[1] < box[3]:
            draw.rectangle(box[:4], outline="red", width=3)
            draw.text(
                (box[0], box[1]),
                f"{detection['label']} ({detection['confidence']})",
                fill="red",
            )

    return original_image, detections
# Event handlers for the Gradio UI defined below.
def on_submit(image):
    """Run detection on the uploaded image.

    Returns the annotated image and the detection list, or ``None`` and an
    ``{"error": ...}`` dict if detection raised.
    """
    try:
        annotated, found = detect_objects(image)
    except Exception as exc:
        return None, {"error": str(exc)}
    else:
        return annotated, found


def on_clear():
    """Return one ``None`` per component to reset the input and both outputs."""
    return (None,) * 3


# Build the Gradio interface: upload and buttons on the left, results on the right.
with gr.Blocks() as iface:
    gr.Markdown("# Object Detection with DETR-ResNet50")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="numpy",
                label="Upload Image (supports PNG, JPEG, AVIF...)",
            )
            submit_button = gr.Button("Submit")
            clear_button = gr.Button("Clear")

        with gr.Column(scale=1):
            output_image = gr.Image(label="Detected Image")
            output_json = gr.JSON(label="Detections")

    submit_button.click(
        on_submit,
        inputs=image_input,
        outputs=[output_image, output_json],
    )
    # Clearing resets the upload widget as well as both result widgets.
    clear_button.click(
        on_clear,
        inputs=None,
        outputs=[image_input, output_image, output_json],
    )

# Start the app when run as a script.
if __name__ == "__main__":
    iface.launch()