"""YOLOv5n object-detection demo.

Upload .pt or .onnx weights, then run CPU ONNX inference on frames from a
WebRTC webcam stream inside a Gradio UI. A .pt upload is exported to ONNX
once (via torch.hub) and cached next to the original file.
"""

import os

import cv2
import gradio as gr
import numpy as np
import onnx
import onnxruntime
import torch
from gradio_webrtc import WebRTC
from huggingface_hub import hf_hub_download  # noqa: F401  (kept from original; unused here)


class YOLOv5nONNX:
    """Minimal CPU inference wrapper around a YOLOv5 ONNX export."""

    def __init__(self, onnx_path):
        self.session = onnxruntime.InferenceSession(
            onnx_path, providers=["CPUExecutionProvider"]
        )
        # Model input shape is NCHW, e.g. [1, 3, 640, 640].
        input_shape = self.session.get_inputs()[0].shape
        self.input_height, self.input_width = input_shape[2], input_shape[3]

    def preprocess(self, image):
        """Convert a BGR uint8 frame to a normalized NCHW float32 tensor."""
        img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (self.input_width, self.input_height))
        img = img.astype(np.float32) / 255.0
        img = np.transpose(img, (2, 0, 1))  # HWC -> CHW
        return np.expand_dims(img, axis=0)

    def postprocess(self, outputs, conf_threshold=0.3, iou_threshold=0.45):
        """Decode raw YOLOv5 output into (x1, y1, x2, y2, conf, cls) rows.

        A standard YOLOv5 export produces predictions shaped
        [1, N, 5 + num_classes], each row (cx, cy, w, h, objectness,
        class scores...). Coordinates are in model-input pixels; the caller
        rescales them to the original frame size.
        """
        preds = np.asarray(outputs[0])
        if preds.ndim == 3:  # drop the batch dimension
            preds = preds[0]

        # Cheap objectness filter first to shrink the remaining work.
        preds = preds[preds[:, 4] > conf_threshold]
        if preds.size == 0:
            return np.empty((0, 6), dtype=np.float32)

        # Final confidence = objectness * best class score.
        class_scores = preds[:, 5:]
        cls_ids = np.argmax(class_scores, axis=1)
        confs = preds[:, 4] * class_scores[np.arange(len(preds)), cls_ids]
        keep = confs > conf_threshold
        preds, confs, cls_ids = preds[keep], confs[keep], cls_ids[keep]
        if preds.size == 0:
            return np.empty((0, 6), dtype=np.float32)

        # (cx, cy, w, h) -> corner coordinates.
        cx, cy, w, h = preds[:, 0], preds[:, 1], preds[:, 2], preds[:, 3]
        x1, y1 = cx - w / 2.0, cy - h / 2.0

        # Class-agnostic NMS; cv2.dnn.NMSBoxes expects [x, y, w, h] boxes.
        boxes_xywh = np.stack([x1, y1, w, h], axis=1)
        idx = cv2.dnn.NMSBoxes(
            boxes_xywh.tolist(), confs.tolist(), conf_threshold, iou_threshold
        )
        idx = np.asarray(idx).reshape(-1)

        result = np.stack(
            [x1, y1, x1 + w, y1 + h, confs, cls_ids.astype(np.float32)], axis=1
        ).astype(np.float32)
        return result[idx]

    def detect_objects(self, image, conf_threshold=0.3):
        """Run detection on a BGR frame, draw boxes in place, return the frame."""
        input_tensor = self.preprocess(image)
        outputs = self.session.run(None, {"images": input_tensor})
        preds = self.postprocess(outputs, conf_threshold)

        # Rescale from model-input coordinates back to the original frame.
        scale_x = image.shape[1] / self.input_width
        scale_y = image.shape[0] / self.input_height
        for x1, y1, x2, y2, conf, cls in preds:
            p1 = (int(x1 * scale_x), int(y1 * scale_y))
            p2 = (int(x2 * scale_x), int(y2 * scale_y))
            cv2.rectangle(image, p1, p2, (0, 255, 0), 2)
            cv2.putText(
                image,
                f"{int(cls)}:{conf:.2f}",
                (p1[0], p1[1] - 10),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                (0, 255, 0),
                2,
            )
        return image


def convert_pt_to_onnx(pt_path, onnx_path, input_size=(640, 640)):
    """Export a YOLOv5 .pt checkpoint to ONNX and validate the result.

    Loads the checkpoint through torch.hub ('ultralytics/yolov5'), exports
    the underlying module with a dynamic batch axis, then runs the ONNX
    checker on the written file.
    """
    model = torch.hub.load("ultralytics/yolov5", "custom", path=pt_path)
    model.eval()
    dummy_input = torch.randn(1, 3, *input_size)
    torch.onnx.export(
        model.model,  # export the raw module, not the AutoShape wrapper
        dummy_input,
        onnx_path,
        input_names=["images"],
        output_names=["output"],
        opset_version=12,
        dynamic_axes={"images": {0: "batch"}, "output": {0: "batch"}},
    )
    onnx.checker.check_model(onnx.load(onnx_path))
    # Original print string was split across source lines (syntax error);
    # rejoined onto one line.
    print(f"Converted {pt_path} to {onnx_path}")


# Lazily (re)loaded detector shared across stream callbacks.
model = None


def load_model(weight_path):
    """Load .pt (converting to ONNX first, with caching) or .onnx weights.

    Sets the module-level `model`. Raises ValueError for any other extension.
    """
    global model
    ext = os.path.splitext(weight_path)[1].lower()
    if ext == ".pt":
        # splitext instead of str.replace: a path like "/tmp/.pt_models/x.pt"
        # would be corrupted by replace('.pt', '.onnx').
        onnx_path = os.path.splitext(weight_path)[0] + ".onnx"
        if not os.path.exists(onnx_path):
            convert_pt_to_onnx(weight_path, onnx_path)
        model = YOLOv5nONNX(onnx_path)
    elif ext == ".onnx":
        model = YOLOv5nONNX(weight_path)
    else:
        raise ValueError("지원하지 않는 모델 파일 형식입니다. .pt 또는 .onnx만 가능")
    print(f"Model loaded from {weight_path}")


def detection(image, weight_file, conf_threshold=0.3):
    """Per-frame WebRTC callback.

    (Re)loads the model when a new weight file is uploaded, then runs
    detection. If no weights have been provided yet, passes the frame
    through unchanged.
    """
    global model
    if weight_file is not None:
        # NOTE(review): depending on the Gradio version, gr.File hands the
        # callback either a tempfile-like object (with .name) or a plain
        # path string — accept both.
        path = getattr(weight_file, "name", weight_file)
        if model is None or path != getattr(model, "weight_path", None):
            load_model(path)
            model.weight_path = path  # remember the source for the cache check
    if model is None:
        return image
    return model.detect_objects(image, conf_threshold)


with gr.Blocks() as demo:
    gr.Markdown("# YOLOv5n ONNX Detection with Weight Upload (WebRTC)")
    weight_file = gr.File(
        label="Upload YOLOv5n weights (.pt or .onnx)", file_types=[".pt", ".onnx"]
    )
    conf_threshold = gr.Slider(
        0.0, 1.0, value=0.3, step=0.05, label="Confidence Threshold"
    )
    rtc_stream = WebRTC(label="WebRTC Webcam")
    output_image = gr.Image(label="Detection Output")
    rtc_stream.stream(
        fn=detection,
        inputs=[rtc_stream, weight_file, conf_threshold],
        outputs=output_image,
    )

if __name__ == "__main__":
    demo.launch()