oilbread committed on
Commit 045db4b · 1 Parent(s): ed0e36e
Files changed (1)
  1. app.py +99 -123
app.py CHANGED
@@ -1,135 +1,111 @@
- import gradio as gr
  import cv2
- import numpy as np
- import onnxruntime as ort
  import torch
- import subprocess
- import os
- from gradio_webrtc import WebRTC
- # --- Model loading ---
-
- def convert_pt_to_onnx(pt_path="yolov5n.pt", onnx_path="yolov5n.onnx"):
-     if not os.path.exists(pt_path):
-         raise FileNotFoundError(f"{pt_path} does not exist.")
-     if os.path.exists(onnx_path):
-         print("The ONNX model already exists.")
-         return onnx_path
-
-     # PyTorch YOLOv5 export via subprocess (recommended)
-     subprocess.run([
-         "python", "export.py",
-         "--weights", pt_path,
-         "--img", "640",
-         "--batch", "1",
-         "--device", "cpu",
-         "--include", "onnx",
-         "--simplify"
-     ], cwd="yolov5", check=True)
-
-     # Check that the output file exists
-     if not os.path.exists(onnx_path):
-         raise RuntimeError("ONNX conversion failed. Check the output of export.py.")
-     return onnx_path
- session = None
- input_name = None
- output_names = []
-
- # session = ort.InferenceSession("yolov5n.onnx")
- # input_name = session.get_inputs()[0].name
- # output_names = [output.name for output in session.get_outputs()]
- COCO_CLASSES = ["person", "bicycle", "car", "motorbike", "bus", "truck", "traffic light", "stop sign"]  # example
-
- def on_model_upload(weight_file):
-     global session, input_name, output_names
-     session, input_name, output_names = load_model(weight_file)
-     return "Model loaded successfully."
-
- def load_model(weight_file):
-     ext = os.path.splitext(weight_file.name)[1]
-     if ext == ".pt":
-         pt_path = weight_file.name
-         onnx_path = "yolov5n.onnx"
-         convert_pt_to_onnx(pt_path, onnx_path)
-     elif ext == ".onnx":
-         onnx_path = weight_file.name
      else:
-         raise ValueError("Unsupported file format. Only .pt or .onnx is allowed.")
-
-     session = ort.InferenceSession(onnx_path)
-     input_name = session.get_inputs()[0].name
-     output_names = [output.name for output in session.get_outputs()]
-     return session, input_name, output_names
-
- # --- Object detection function ---
- def detect_live(frame, conf_thresh=0.3):
-     if frame is None:
-         return None
-
-     h0, w0 = frame.shape[:2]
-
-     img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-     img = cv2.resize(img, (640, 640))
-     img = img.astype(np.float32) / 255.0
-     img = img.transpose(2, 0, 1)
-     img = np.expand_dims(img, axis=0)
-
-     outputs = session.run(output_names, {input_name: img})
-     preds = outputs[0].squeeze()
-
-     boxes = []
-     for det in preds:
-         conf = det[4]
-         if conf < conf_thresh:
-             continue
-         scores = det[5:]
-         class_id = np.argmax(scores)
-         score = scores[class_id]
-         if score * conf < conf_thresh:
-             continue
-         cx, cy, w, h = det[:4]
-         x1 = int((cx - w/2) * w0 / 640)
-         y1 = int((cy - h/2) * h0 / 640)
-         x2 = int((cx + w/2) * w0 / 640)
-         y2 = int((cy + h/2) * h0 / 640)
-         boxes.append((x1, y1, x2, y2, conf*score, class_id))
-
-     for x1, y1, x2, y2, conf, class_id in boxes:
-         label = f"{COCO_CLASSES[class_id]} {conf:.2f}"
-         cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
-         cv2.putText(frame, label, (x1, y1 - 10),
-                     cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
-
-     return frame
-
- rtc_configuration = {
-     "iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]
- }
-
  # --- Gradio UI ---
  with gr.Blocks() as demo:
-     gr.HTML("<h1>YOLOv5n real-time detection (ONNX or PT upload supported)</h1>")
-
-     model_file = gr.File(label="Upload YOLO model (.pt or .onnx)", file_types=[".pt", ".onnx"])
-     model_status = gr.Textbox(label="Model status", interactive=False)
-
-     with gr.Row():
-         webrtc_stream = WebRTC(
-             label="Webcam stream",
-             mode="sendrecv",
-             video_frame_callback=lambda frame: detect(frame, conf_slider.value),
-             rtc_configuration=rtc_configuration,
-             media_stream_constraints={"video": True, "audio": False},
-             key="webcam"
          )
-         conf_slider = gr.Slider(label="Confidence", minimum=0.1, maximum=1.0, value=0.3)
-
-     output_img = gr.Image(label="Detection result")
-
-     model_file.change(fn=on_model_upload, inputs=model_file, outputs=model_status)
-
-     gr.Live(fn=detect_live, inputs=[webcam, conf_slider], outputs=output_img)
-
-
- # --- Run ---
  if __name__ == "__main__":
-     demo.launch(share=True)
 
+ import os
  import cv2
  import torch
+ import onnx
+ import numpy as np
+ import gradio as gr
+ from huggingface_hub import hf_hub_download
+
+ # --- YOLOv5n ONNX inference class (simple implementation) ---
+ import onnxruntime
+
+ class YOLOv5nONNX:
+     def __init__(self, onnx_path):
+         self.session = onnxruntime.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
+         input_shape = self.session.get_inputs()[0].shape  # e.g. [1, 3, 640, 640]
+         self.input_height, self.input_width = input_shape[2], input_shape[3]
+
+     def preprocess(self, image):
+         # BGR to RGB, resize, normalize, transpose
+         img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         img = cv2.resize(img, (self.input_width, self.input_height))
+         img = img.astype(np.float32) / 255.0
+         img = np.transpose(img, (2, 0, 1))  # HWC to CHW
+         img = np.expand_dims(img, axis=0)
+         return img
+
+     def postprocess(self, outputs, conf_threshold=0.3):
+         # Only simple confidence filtering here; NMS is normally needed as well
+         preds = outputs[0]
+         preds = preds[preds[:, 4] > conf_threshold]  # confidence filtering
+         return preds
+
+     def detect_objects(self, image, conf_threshold=0.3):
+         input_tensor = self.preprocess(image)
+         outputs = self.session.run(None, {'images': input_tensor})
+         preds = self.postprocess(outputs, conf_threshold)
+
+         # Simple box-drawing example (the bbox layout is assumed here; adjust to the actual YOLOv5 ONNX output format)
+         for *box, conf, cls in preds:
+             x1, y1, x2, y2 = map(int, box)
+             cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
+             cv2.putText(image, f'{int(cls)}:{conf:.2f}', (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
+         return image
+
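The comments in `postprocess` and `detect_objects` above note that only confidence filtering is applied, that NMS is normally required, and that the box layout still has to be matched to the real YOLOv5 ONNX output. As a minimal sketch of that missing decoding step (not part of this commit), assuming the usual YOLOv5 export layout of shape `(1, N, 5 + num_classes)` with rows `[cx, cy, w, h, obj_conf, class scores...]`, a hypothetical helper `decode_yolov5_output` could combine score filtering with OpenCV's built-in NMS:

```python
import cv2
import numpy as np

def decode_yolov5_output(pred, conf_threshold=0.3, iou_threshold=0.45):
    # pred: (1, N, 5 + num_classes); rows assumed to be [cx, cy, w, h, obj_conf, class scores...]
    pred = np.squeeze(pred, axis=0)
    obj_conf = pred[:, 4]
    cls_scores = pred[:, 5:]
    cls_ids = np.argmax(cls_scores, axis=1)
    scores = obj_conf * cls_scores[np.arange(len(pred)), cls_ids]

    # Keep only detections above the confidence threshold
    keep = scores > conf_threshold
    pred, scores, cls_ids = pred[keep], scores[keep], cls_ids[keep]

    # Convert centre-format boxes to top-left (x, y, w, h) for cv2.dnn.NMSBoxes
    boxes = pred[:, :4].copy()
    boxes[:, 0] -= boxes[:, 2] / 2
    boxes[:, 1] -= boxes[:, 3] / 2

    # Non-maximum suppression; indices refer to the filtered arrays
    idxs = cv2.dnn.NMSBoxes(boxes.tolist(), scores.tolist(),
                            conf_threshold, iou_threshold)
    results = []
    for i in np.array(idxs).flatten():
        x, y, w, h = boxes[i]
        results.append((int(x), int(y), int(x + w), int(y + h),
                        float(scores[i]), int(cls_ids[i])))
    return results  # [(x1, y1, x2, y2, score, class_id), ...]
```

The returned boxes are still in network-input coordinates (e.g. 640×640); scaling them back to the original frame would mirror the resize done in `preprocess`.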
+ # --- PT → ONNX conversion function ---
+ def convert_pt_to_onnx(pt_path, onnx_path, input_size=(640, 640)):
+     model = torch.hub.load('ultralytics/yolov5', 'custom', path=pt_path, source='local', force_reload=True)
+     model.eval()
+     dummy_input = torch.randn(1, 3, *input_size)
+     torch.onnx.export(
+         model.model,
+         dummy_input,
+         onnx_path,
+         input_names=['images'],
+         output_names=['output'],
+         opset_version=12,
+         dynamic_axes={'images': {0: 'batch'}, 'output': {0: 'batch'}},
+     )
+     onnx_model = onnx.load(onnx_path)
+     onnx.checker.check_model(onnx_model)
+     print(f"Converted {pt_path} to {onnx_path}")
+
+ # --- Model loading function ---
+ model = None
+ def load_model(weight_path):
+     global model
+     ext = os.path.splitext(weight_path)[1].lower()
+     if ext == '.pt':
+         onnx_path = weight_path.replace('.pt', '.onnx')
+         if not os.path.exists(onnx_path):
+             convert_pt_to_onnx(weight_path, onnx_path)
+         model = YOLOv5nONNX(onnx_path)
+     elif ext == '.onnx':
+         model = YOLOv5nONNX(weight_path)
      else:
+         raise ValueError("Unsupported model file format. Only .pt or .onnx is allowed.")
+     print(f"Model loaded from {weight_path}")

+ # --- Detection function ---
+ def detection(image, weight_file, conf_threshold=0.3):
+     global model
+     if weight_file is not None:
+         if model is None or weight_file.name != getattr(model, 'weight_path', None):
+             # Reload the model with the newly uploaded weights
+             load_model(weight_file.name)
+             model.weight_path = weight_file.name

+     if model is None:
+         return image

+     result_img = model.detect_objects(image, conf_threshold)
+     return result_img
  # --- Gradio UI ---
  with gr.Blocks() as demo:
+     gr.Markdown("# YOLOv5n ONNX Detection with Weight Upload")
+
+     weight_file = gr.File(label="Upload YOLOv5n weights (.pt or .onnx)", file_types=['.pt', '.onnx'])
+     conf_threshold = gr.Slider(0.0, 1.0, value=0.3, step=0.05, label="Confidence Threshold")
+     input_image = gr.Image(source="webcam", streaming=True)
+     output_image = gr.Image()
+
+     input_image.stream(
+         fn=detection,
+         inputs=[input_image, weight_file, conf_threshold],
+         outputs=output_image,
+         every=0.1,
      )
  if __name__ == "__main__":
+     demo.launch()
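Since the streaming callback is the only place the model is exercised in this commit, a quick offline sanity check can be handy before launching the Gradio demo. The lines below are a sketch only, run in a context where the definitions above (`load_model`, `model`) are available; the weight and image file names are hypothetical placeholders:

```python
# Hypothetical offline smoke test for the detection path (not part of app.py).
import cv2

load_model("yolov5n.onnx")            # or a .pt file, which is converted first
frame = cv2.imread("sample.jpg")      # BGR frame, as detect_objects expects
if frame is not None and model is not None:
    annotated = model.detect_objects(frame, conf_threshold=0.3)
    cv2.imwrite("sample_out.jpg", annotated)
```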