Spaces: maiz (Sleeping)
Browse files
app.py CHANGED
@@ -1,135 +1,111 @@
-import
 import cv2
-import numpy as np
-import onnxruntime as ort
 import torch
-import
-import
-[old lines 8-54 not rendered in the diff view]
     else:
-        raise ValueError("Unsupported file format. Only .pt or .onnx supported.")
 
-[old lines 58-61 not rendered in the diff view]
 
-[old lines 63-64 not rendered in the diff view]
-    if frame is None:
-        return None
 
-[old lines 68-69 not rendered in the diff view]
-    img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-    img = cv2.resize(img, (640, 640))
-    img = img.astype(np.float32) / 255.0
-    img = img.transpose(2, 0, 1)
-    img = np.expand_dims(img, axis=0)
-
-    outputs = session.run(output_names, {input_name: img})
-    preds = outputs[0].squeeze()
-
-    boxes = []
-    for det in preds:
-        conf = det[4]
-        if conf < conf_thresh:
-            continue
-        scores = det[5:]
-        class_id = np.argmax(scores)
-        score = scores[class_id]
-        if score * conf < conf_thresh:
-            continue
-        cx, cy, w, h = det[:4]
-        x1 = int((cx - w/2) * w0 / 640)
-        y1 = int((cy - h/2) * h0 / 640)
-        x2 = int((cx + w/2) * w0 / 640)
-        y2 = int((cy + h/2) * h0 / 640)
-        boxes.append((x1, y1, x2, y2, conf*score, class_id))
-
-    for x1, y1, x2, y2, conf, class_id in boxes:
-        label = f"{COCO_CLASSES[class_id]} {conf:.2f}"
-        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
-        cv2.putText(frame, label, (x1, y1 - 10),
-                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
-
-    return frame
-
-rtc_configuration = {
-    "iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]
-}
 
 # --- Gradio UI ---
 with gr.Blocks() as demo:
-    gr.
-[old lines 111-121 not rendered in the diff view]
-        key="webcam"
     )
-    conf_slider = gr.Slider(label="Confidence", minimum=0.1, maximum=1.0, value=0.3)
-
-    output_img = gr.Image(label="Detection Result")
-
-    model_file.change(fn=on_model_upload, inputs=model_file, outputs=model_status)
-
-    gr.Live(fn=detect_live, inputs=[webcam, conf_slider], outputs=output_img)
-
 
-# --- Run ---
 if __name__ == "__main__":
-    demo.launch(
+import os
 import cv2
 import torch
+import onnx
+import numpy as np
+import gradio as gr
+from huggingface_hub import hf_hub_download
+
+# --- YOLOv5n ONNX inference class (simple implementation) ---
+import onnxruntime
+
+class YOLOv5nONNX:
+    def __init__(self, onnx_path):
+        self.session = onnxruntime.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
+        input_shape = self.session.get_inputs()[0].shape  # e.g. [1, 3, 640, 640]
+        self.input_height, self.input_width = input_shape[2], input_shape[3]
+
+    def preprocess(self, image):
+        # BGR to RGB, resize, normalize, transpose
+        img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        img = cv2.resize(img, (self.input_width, self.input_height))
+        img = img.astype(np.float32) / 255.0
+        img = np.transpose(img, (2, 0, 1))  # HWC to CHW
+        img = np.expand_dims(img, axis=0)
+        return img
+
+    def postprocess(self, outputs, conf_threshold=0.3):
+        # Only simple confidence filtering here; normally NMS is also needed
+        preds = outputs[0]
+        preds = preds[preds[:, 4] > conf_threshold]  # confidence filtering
+        return preds
+
+    def detect_objects(self, image, conf_threshold=0.3):
+        input_tensor = self.preprocess(image)
+        outputs = self.session.run(None, {'images': input_tensor})
+        preds = self.postprocess(outputs, conf_threshold)
+
+        # Simple box-drawing example (bbox corner coordinates are assumed here; adjust to the actual yolov5 onnx output format)
+        for *box, conf, cls in preds:
+            x1, y1, x2, y2 = map(int, box)
+            cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
+            cv2.putText(image, f'{int(cls)}:{conf:.2f}', (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
+        return image
+
+# --- PT → ONNX conversion function ---
+def convert_pt_to_onnx(pt_path, onnx_path, input_size=(640, 640)):
+    model = torch.hub.load('ultralytics/yolov5', 'custom', path=pt_path, source='local', force_reload=True)
+    model.eval()
+    dummy_input = torch.randn(1, 3, *input_size)
+    torch.onnx.export(
+        model.model,
+        dummy_input,
+        onnx_path,
+        input_names=['images'],
+        output_names=['output'],
+        opset_version=12,
+        dynamic_axes={'images': {0: 'batch'}, 'output': {0: 'batch'}},
+    )
+    onnx_model = onnx.load(onnx_path)
+    onnx.checker.check_model(onnx_model)
+    print(f"Converted {pt_path} to {onnx_path}")
+
+# --- Model loading function ---
+model = None
+def load_model(weight_path):
+    global model
+    ext = os.path.splitext(weight_path)[1].lower()
+    if ext == '.pt':
+        onnx_path = weight_path.replace('.pt', '.onnx')
+        if not os.path.exists(onnx_path):
+            convert_pt_to_onnx(weight_path, onnx_path)
+        model = YOLOv5nONNX(onnx_path)
+    elif ext == '.onnx':
+        model = YOLOv5nONNX(weight_path)
     else:
+        raise ValueError("Unsupported model file format. Only .pt or .onnx allowed.")
+    print(f"Model loaded from {weight_path}")
 
+# --- Detection function ---
+def detection(image, weight_file, conf_threshold=0.3):
+    global model
+    if weight_file is not None:
+        if model is None or weight_file.name != getattr(model, 'weight_path', None):
+            # reload the model
+            load_model(weight_file.name)
+            model.weight_path = weight_file.name
 
+    if model is None:
+        return image
 
+    result_img = model.detect_objects(image, conf_threshold)
+    return result_img
 
 # --- Gradio UI ---
 with gr.Blocks() as demo:
+    gr.Markdown("# YOLOv5n ONNX Detection with Weight Upload")
+
+    weight_file = gr.File(label="Upload YOLOv5n weights (.pt or .onnx)", file_types=['.pt', '.onnx'])
+    conf_threshold = gr.Slider(0.0, 1.0, value=0.3, step=0.05, label="Confidence Threshold")
+    input_image = gr.Image(source="webcam", streaming=True)
+    output_image = gr.Image()
+
+    input_image.stream(
+        fn=detection,
+        inputs=[input_image, weight_file, conf_threshold],
+        outputs=output_image,
+        every=0.1,
+    )
 
 if __name__ == "__main__":
+    demo.launch()
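
Note: the postprocess comment in the new code points out that confidence filtering alone is not enough and that NMS is normally required. A minimal greedy NMS sketch in NumPy (the nms name, the [x1, y1, x2, y2] corner layout, and the 0.45 IoU default are illustrative assumptions, not part of this commit):

import numpy as np

def nms(boxes, scores, iou_threshold=0.45):
    # Greedy non-maximum suppression over an (N, 4) array of corner boxes.
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    areas = (x2 - x1) * (y2 - y1)
    order = scores.argsort()[::-1]  # indices sorted by descending score
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        # IoU of the current best box against all remaining boxes
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
        iou = inter / (areas[i] + areas[order[1:]] - inter)
        order = order[1:][iou < iou_threshold]  # drop boxes that overlap too much
    return keep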
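
The box-drawing comment in detect_objects likewise flags that the bbox layout is assumed. For a stock YOLOv5 export, each raw output row is [cx, cy, w, h, objectness, per-class scores] rather than corner coordinates, which is what the removed code decoded. A decode sketch in that spirit (decode_yolov5 is a hypothetical helper name):

import numpy as np

def decode_yolov5(preds, conf_threshold=0.3):
    # preds: (N, 5 + num_classes) rows of cx, cy, w, h, objectness, class scores.
    boxes = []
    for det in preds:
        scores = det[5:]
        class_id = int(np.argmax(scores))
        score = float(det[4] * scores[class_id])  # objectness * class score
        if score < conf_threshold:
            continue
        cx, cy, w, h = det[:4]
        boxes.append((cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2, score, class_id))
    return boxes

As in the removed code, the resulting corners would still need rescaling from the 640x640 network input back to the original frame size before drawing.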