Dibiddo commited on
Commit
5bdb8ec
·
verified ·
1 Parent(s): 1f1debf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -7
app.py CHANGED
@@ -18,16 +18,23 @@ def detect_objects(image):
18
  image_tensor = image_tensor.float() / 255.0 # 将像素值归一化到 [0, 1]
19
 
20
  # 使用模型进行推理
21
- outputs = torch_model(image_tensor.unsqueeze(0)) # 添加批次维
 
 
 
 
22
 
23
  # 格式化输出结果
24
  detections = []
25
- for output in outputs:
26
- detections.append({
27
- "label": output['label'],
28
- "confidence": round(output['score'], 3),
29
- "box": output['box'].tolist() # 转换为列表格式以便 JSON 输出
30
- })
 
 
 
31
 
32
  return detections
33
 
 
18
  image_tensor = image_tensor.float() / 255.0 # 将像素值归一化到 [0, 1]
19
 
20
  # 使用模型进行推理
21
+ with torch.no_grad(): # 禁用梯计算以提高性能
22
+ outputs = torch_model(image_tensor.unsqueeze(0)) # 添加批次维度
23
+
24
+ # 获取预测结果(根据具体的输出格式进行调整)
25
+ predictions = outputs['logits'] if 'logits' in outputs else outputs[0] # 确保获取正确的输出
26
 
27
  # 格式化输出结果
28
  detections = []
29
+ for i in range(predictions.shape[1]): # 遍历每个预测
30
+ score = predictions[0, i, -1].item() # 假设最后一个维度是分数
31
+ if score > 0.5: # 设置阈值过滤低置信度的预测
32
+ box = predictions[0, i, :-1].tolist() # 获取边界框坐标(假设在前面)
33
+ detections.append({
34
+ "label": f"Object {i}", # 根据实际情况替换标签
35
+ "confidence": round(score, 3),
36
+ "box": box,
37
+ })
38
 
39
  return detections
40