File size: 4,122 Bytes
045db4b
d957ec2
c2bf829
045db4b
 
 
 
32d997a
045db4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6adecd
045db4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2bf829
045db4b
 
c2bf829
045db4b
 
 
 
 
 
 
 
c2bf829
045db4b
 
2de9bec
045db4b
 
d667def
e81c2c5
 
f028118
045db4b
 
 
 
32d997a
f028118
 
 
 
045db4b
f028118
045db4b
d4bebfb
c8907bd
 
045db4b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import cv2
import torch
import onnx
import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download
from gradio_webrtc import WebRTC
# --- YOLOv5n ONNX inference class (minimal implementation) ---
import onnxruntime

class YOLOv5nONNX:
    """Minimal CPU-only YOLOv5 ONNX inference wrapper.

    Assumes a standard YOLOv5 export: one input named 'images' of shape
    (1, 3, H, W) and raw output rows of
    [cx, cy, w, h, obj_conf, class_score_0, ...].
    """

    def __init__(self, onnx_path):
        # CPU provider only — keeps the demo runnable without CUDA.
        self.session = onnxruntime.InferenceSession(
            onnx_path, providers=['CPUExecutionProvider'])
        input_shape = self.session.get_inputs()[0].shape  # e.g. [1, 3, 640, 640]
        self.input_height, self.input_width = input_shape[2], input_shape[3]

    def preprocess(self, image):
        """Resize and normalize a frame into a (1, 3, H, W) float32 tensor.

        NOTE(review): assumes the incoming frame is BGR (OpenCV convention);
        confirm against what the WebRTC component actually delivers.
        """
        img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (self.input_width, self.input_height))
        img = img.astype(np.float32) / 255.0
        img = np.transpose(img, (2, 0, 1))  # HWC -> CHW
        return np.expand_dims(img, axis=0)  # add batch dimension

    def postprocess(self, outputs, conf_threshold=0.3):
        """Filter raw predictions and return an (N, 6) float32 array.

        Each row is [x1, y1, x2, y2, obj_conf, class_id] in model-input
        pixel coordinates. No NMS is applied (demo simplification).

        Bug fixed: the raw output is (1, N, 85); indexing it with
        preds[:, 4] selected a whole column plane instead of the
        objectness scores, so the confidence filter never worked. The
        batch dimension is now stripped first, and boxes are converted
        from center-xywh to corner-xyxy so they can be drawn.
        """
        preds = np.asarray(outputs[0])
        if preds.ndim == 3:  # (1, N, 85) -> (N, 85)
            preds = preds[0]
        preds = preds[preds[:, 4] > conf_threshold]  # objectness filter
        if preds.size == 0:
            return np.zeros((0, 6), dtype=np.float32)

        # center-xywh -> corner-xyxy
        cx, cy, w, h = preds[:, 0], preds[:, 1], preds[:, 2], preds[:, 3]
        boxes = np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=1)
        if preds.shape[1] > 5:
            class_ids = np.argmax(preds[:, 5:], axis=1).astype(np.float32)
        else:
            # Degenerate single-class export without class scores.
            class_ids = np.zeros(len(preds), dtype=np.float32)
        return np.concatenate(
            [boxes, preds[:, 4:5], class_ids[:, None]], axis=1).astype(np.float32)

    def detect_objects(self, image, conf_threshold=0.3):
        """Run inference on a frame and draw detections in place; returns the frame."""
        input_tensor = self.preprocess(image)
        outputs = self.session.run(None, {'images': input_tensor})
        dets = self.postprocess(outputs, conf_threshold)

        # Boxes are in model-input coordinates; rescale to the original frame
        # size because preprocess() resized the image.
        frame_h, frame_w = image.shape[:2]
        sx = frame_w / self.input_width
        sy = frame_h / self.input_height
        for x1, y1, x2, y2, conf, cls in dets:
            p1 = (int(x1 * sx), int(y1 * sy))
            p2 = (int(x2 * sx), int(y2 * sy))
            cv2.rectangle(image, p1, p2, (0, 255, 0), 2)
            cv2.putText(image, f'{int(cls)}:{conf:.2f}', (p1[0], p1[1] - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
        return image

# --- PT β†’ ONNX λ³€ν™˜ ν•¨μˆ˜ ---
def convert_pt_to_onnx(pt_path, onnx_path, input_size=(640, 640)):
    model = torch.hub.load('ultralytics/yolov5', 'custom', path=pt_path)
    model.eval()
    dummy_input = torch.randn(1, 3, *input_size)
    torch.onnx.export(
        model.model,
        dummy_input,
        onnx_path,
        input_names=['images'],
        output_names=['output'],
        opset_version=12,
        dynamic_axes={'images': {0: 'batch'}, 'output': {0: 'batch'}},
    )
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    print(f"Converted {pt_path} to {onnx_path}")

# --- λͺ¨λΈ λ‘œλ”© ν•¨μˆ˜ ---
model = None
def load_model(weight_path):
    global model
    ext = os.path.splitext(weight_path)[1].lower()
    if ext == '.pt':
        onnx_path = weight_path.replace('.pt', '.onnx')
        if not os.path.exists(onnx_path):
            convert_pt_to_onnx(weight_path, onnx_path)
        model = YOLOv5nONNX(onnx_path)
    elif ext == '.onnx':
        model = YOLOv5nONNX(weight_path)
    else:
        raise ValueError("μ§€μ›ν•˜μ§€ μ•ŠλŠ” λͺ¨λΈ 파일 ν˜•μ‹μž…λ‹ˆλ‹€. .pt λ˜λŠ” .onnx만 κ°€λŠ₯")
    print(f"Model loaded from {weight_path}")

# --- Detection function ---
def detection(image, weight_file, conf_threshold=0.3):
    """Run detection on one frame, (re)loading the model when the uploaded
    weight file changes; returns the (possibly annotated) frame."""
    global model

    if weight_file is not None:
        # Reload only when no model exists or a different file was uploaded.
        current = getattr(model, 'weight_path', None)
        if model is None or weight_file.name != current:
            load_model(weight_file.name)
            model.weight_path = weight_file.name  # remember source to skip reloads

    # No weights uploaded yet: pass the frame through untouched.
    if model is None:
        return image

    return model.detect_objects(image, conf_threshold)

# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("# YOLOv5n ONNX Detection with Weight Upload (WebRTC)")

    # User-supplied weights; detection() lazily converts/loads them on the
    # first streamed frame (and again whenever a different file is uploaded).
    weight_file = gr.File(label="Upload YOLOv5n weights (.pt or .onnx)", file_types=['.pt', '.onnx'])
    conf_threshold = gr.Slider(0.0, 1.0, value=0.3, step=0.05, label="Confidence Threshold")

    # Live webcam feed delivered over WebRTC; each frame is sent to detection().
    rtc_stream = WebRTC(label="WebRTC Webcam")

    output_image = gr.Image(label="Detection Output")

    # Wire the stream: every incoming frame plus the current weight file and
    # threshold value are passed to detection(); the annotated frame is shown.
    rtc_stream.stream(
        fn=detection,
        inputs=[rtc_stream, weight_file, conf_threshold],
        outputs=output_image,
    )

if __name__ == "__main__":
    demo.launch()