hippoiam10 commited on
Commit
9ec5f95
·
verified ·
1 Parent(s): 7e8f486

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -0
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ from torchvision import transforms
4
+ from ultralytics import YOLO
5
+ import gradio as gr
6
+ from PIL import Image, ImageDraw
7
+ import numpy as np
8
+ from sklearn.metrics import precision_recall_fscore_support
9
+
10
+ # 載入 YOLOv8 模型
11
+ yolo_model = YOLO("yolov8n.pt")
12
+
13
+ # 載入 Faster R-CNN 模型
14
+ faster_rcnn_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="COCO_V1")
15
+ faster_rcnn_model.eval()
16
+
17
+ # 偵測函數
18
+ def detect_objects(image):
19
+ transform = transforms.Compose([transforms.ToTensor()])
20
+ img_tensor = transform(image).unsqueeze(0)
21
+
22
+ # YOLO 偵測
23
+ yolo_results = yolo_model(image)
24
+ yolo_image = yolo_results[0].plot() # YOLO 偵測結果
25
+ yolo_boxes = yolo_results[0].boxes.xyxy.cpu().numpy()
26
+ yolo_confidence = yolo_results[0].boxes.conf.cpu().numpy()
27
+
28
+ # Faster R-CNN 偵測
29
+ with torch.no_grad():
30
+ prediction = faster_rcnn_model(img_tensor)
31
+
32
+ rcnn_boxes = prediction[0]["boxes"].cpu().numpy()
33
+ rcnn_scores = prediction[0]["scores"].cpu().numpy()
34
+
35
+ # Faster R-CNN 畫框
36
+ rcnn_image = image.copy()
37
+ draw = ImageDraw.Draw(rcnn_image)
38
+ for i in range(len(rcnn_scores)):
39
+ if rcnn_scores[i] > 0.5:
40
+ box = rcnn_boxes[i]
41
+ draw.rectangle(((box[0], box[1]), (box[2], box[3])), outline="red", width=3)
42
+
43
+ # 評估指標
44
+ def evaluate_model(pred_boxes, confs):
45
+ y_true = []
46
+ y_pred = []
47
+ for i in range(len(pred_boxes)):
48
+ if confs[i] > 0.5:
49
+ y_true.append(1)
50
+ y_pred.append(1)
51
+ precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary", zero_division=0)
52
+ return precision, recall, f1
53
+
54
+ # 計算 YOLO 和 Faster R-CNN 的評估指標
55
+ yolo_precision, yolo_recall, yolo_f1 = evaluate_model(yolo_boxes, yolo_confidence)
56
+ rcnn_precision, rcnn_recall, rcnn_f1 = evaluate_model(rcnn_boxes, rcnn_scores)
57
+
58
+ evaluation_results = {
59
+ "YOLO": {
60
+ "Precision": round(yolo_precision, 3),
61
+ "Recall": round(yolo_recall, 3),
62
+ "F1-Score": round(yolo_f1, 3),
63
+ "Confidence": round(np.mean(yolo_confidence), 3),
64
+ },
65
+ "Faster R-CNN": {
66
+ "Precision": round(rcnn_precision, 3),
67
+ "Recall": round(rcnn_recall, 3),
68
+ "F1-Score": round(rcnn_f1, 3),
69
+ "Confidence": round(np.mean(rcnn_scores), 3),
70
+ }
71
+ }
72
+
73
+ return Image.fromarray(yolo_image), rcnn_image, evaluation_results
74
+
75
+ # Gradio 介面
76
+ demo = gr.Interface(
77
+ fn=detect_objects,
78
+ inputs=gr.Image(type="pil", label="上傳圖片"),
79
+ outputs=[
80
+ gr.Image(type="pil", label="YOLO 偵測結果"),
81
+ gr.Image(type="pil", label="Faster R-CNN 偵測結果"),
82
+ gr.JSON(label="評估指標")
83
+ ],
84
+ title="YOLO vs Faster R-CNN 物件偵測",
85
+ description="上傳圖片,系統將使用 YOLOv8 和 Faster R-CNN 進行偵測並顯示結果。",
86
+ )
87
+
88
+ # 使用 gradio deploy 而非 launch
89
+ demo.queue() # 啟用佇列,確保請求不會超載
90
+ demo.launch()