Dibiddo commited on
Commit
1f1debf
·
verified ·
1 Parent(s): 85a1045

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -2,16 +2,23 @@ import gradio as gr
2
  import torch
3
  import qai_hub as hub
4
  from qai_hub_models.models.detr_resnet50 import Model
 
 
5
 
6
  # 加载模型
7
  torch_model = Model.from_pretrained()
8
 
9
  def detect_objects(image):
10
- # 预处理图像
11
- image_tensor = torch.tensor(image).permute(2, 0, 1).unsqueeze(0) # 转换为 (C, H, W) 格式
 
12
 
 
 
 
 
13
  # 使用模型进行推理
14
- outputs = torch_model(image_tensor)
15
 
16
  # 格式化输出结果
17
  detections = []
 
2
  import torch
3
  import qai_hub as hub
4
  from qai_hub_models.models.detr_resnet50 import Model
5
+ from PIL import Image
6
+ import numpy as np
7
 
8
  # 加载模型
9
  torch_model = Model.from_pretrained()
10
 
11
  def detect_objects(image):
12
+ # 将图像转换为 RGB 格式并调整大小
13
+ image = Image.fromarray(image).convert("RGB")
14
+ image = image.resize((800, 800)) # 根据模型要求调整图像大小
15
 
16
+ # 转换为张量并进行标准化
17
+ image_tensor = torch.tensor(np.array(image)).permute(2, 0, 1) # 转换为 (C, H, W) 格式
18
+ image_tensor = image_tensor.float() / 255.0 # 将像素值归一化到 [0, 1]
19
+
20
  # 使用模型进行推理
21
+ outputs = torch_model(image_tensor.unsqueeze(0)) # 添加批次维度
22
 
23
  # 格式化输出结果
24
  detections = []