import json import os from collections import Counter from functools import lru_cache from time import perf_counter os.environ.setdefault("YOLO_CONFIG_DIR", "/tmp/Ultralytics") import gradio as gr import numpy as np from huggingface_hub import hf_hub_download from PIL import Image from ultralytics import YOLO MODEL_REPO_ID = "liamxdev/vtsr" MODEL_FILENAME = "vtsr_int8.onnx" LABEL_MAPPING_FILENAME = "label-mapping.json" MODEL_DISPLAY_NAME = "YOLOv8n · ONNX INT8 · CPU" @lru_cache(maxsize=1) def load_label_mapping() -> dict[str, str]: mapping_path = hf_hub_download( repo_id=MODEL_REPO_ID, filename=LABEL_MAPPING_FILENAME, ) with open(mapping_path, encoding="utf-8") as mapping_file: return json.load(mapping_file) @lru_cache(maxsize=1) def load_model() -> YOLO: model_path = hf_hub_download( repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME, ) return YOLO(model_path, task="detect") def detect( image: Image.Image | np.ndarray | None, confidence: float, iou: float, ) -> tuple[Image.Image | None, list[list], list[list], dict]: if image is None: raise gr.Error("Vui lòng tải lên một ảnh.") if isinstance(image, np.ndarray): image = Image.fromarray(image) image = image.convert("RGB") model = load_model() label_mapping = load_label_mapping() started_at = perf_counter() result = model.predict( source=image, imgsz=640, conf=confidence, iou=iou, verbose=False, )[0] total_ms = (perf_counter() - started_at) * 1000 detections = [] if result.boxes is not None: for class_id, score, xyxy in zip( result.boxes.cls.tolist(), result.boxes.conf.tolist(), result.boxes.xyxy.tolist(), ): code = str(result.names[int(class_id)]) detections.append( { "code": code, "meaning": label_mapping.get(code, "Chưa có mô tả"), "confidence": round(float(score) * 100, 2), "box": [round(float(value), 1) for value in xyxy], } ) annotated_image = Image.fromarray(result.plot()[..., ::-1]) detail_rows = [ [ item["code"], item["meaning"], item["confidence"], *item["box"], ] for item in detections ] counts = Counter(item["code"] for item in detections) meanings = {item["code"]: item["meaning"] for item in detections} summary_rows = [ [code, meanings[code], count] for code, count in counts.most_common() ] speed = result.speed or {} metadata = { "model": MODEL_DISPLAY_NAME, "artifact": MODEL_FILENAME, "image_size": { "width": image.width, "height": image.height, }, "detection_count": len(detections), "timing_ms": { "preprocess": round(float(speed.get("preprocess", 0)), 2), "inference": round(float(speed.get("inference", 0)), 2), "postprocess": round(float(speed.get("postprocess", 0)), 2), "total": round(total_ms, 2), }, } return annotated_image, summary_rows, detail_rows, metadata with gr.Blocks(title="VTSR INT8 ONNX") as demo: gr.Markdown( """ # 🚦 Nhận diện biển báo giao thông Việt Nam Demo YOLOv8n ONNX INT8 nhận diện 56 loại biển báo giao thông Việt Nam. Model trả về mã biển báo, ý nghĩa tiếng Việt, độ tin cậy và bounding box. """ ) with gr.Row(): with gr.Column(scale=1): input_image = gr.Image( type="pil", label="Ảnh đầu vào", sources=["upload", "clipboard", "webcam"], ) gr.Markdown("**Model:** YOLOv8n · ONNX INT8 · CPU") confidence_slider = gr.Slider( minimum=0.05, maximum=0.95, value=0.25, step=0.05, label="Ngưỡng tin cậy", ) iou_slider = gr.Slider( minimum=0.1, maximum=0.9, value=0.7, step=0.05, label="Ngưỡng IoU", ) detect_button = gr.Button("Nhận diện", variant="primary") with gr.Column(scale=1): output_image = gr.Image( type="pil", label="Kết quả nhận diện", ) metadata_output = gr.JSON(label="Thông tin xử lý") gr.Markdown("## Thống kê theo loại biển báo") summary_output = gr.Dataframe( headers=["Mã biển báo", "Ý nghĩa", "Số lượng"], datatype=["str", "str", "number"], interactive=False, ) gr.Markdown("## Chi tiết detection") detail_output = gr.Dataframe( headers=[ "Mã biển báo", "Ý nghĩa", "Độ tin cậy (%)", "x1", "y1", "x2", "y2", ], datatype=[ "str", "str", "number", "number", "number", "number", "number", ], interactive=False, ) detect_button.click( fn=detect, inputs=[ input_image, confidence_slider, iou_slider, ], outputs=[ output_image, summary_output, detail_output, metadata_output, ], ) gr.Markdown( """ **Lưu ý:** Demo chỉ mang tính giáo dục và nghiên cứu. Kết quả không thay thế việc diễn giải biển báo chính thức hoặc quyết định an toàn khi tham gia giao thông. Model: [liamxdev/vtsr](https://huggingface.co/liamxdev/vtsr) """ ) if __name__ == "__main__": demo.queue(default_concurrency_limit=2).launch( server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")), show_error=True, ssr_mode=False, )