| import json |
| import os |
| from collections import Counter |
| from functools import lru_cache |
| from time import perf_counter |
|
|
| os.environ.setdefault("YOLO_CONFIG_DIR", "/tmp/Ultralytics") |
|
|
| import gradio as gr |
| import numpy as np |
| from huggingface_hub import hf_hub_download |
| from PIL import Image |
| from ultralytics import YOLO |
|
|
|
|
| MODEL_REPO_ID = "liamxdev/vtsr" |
| MODEL_FILENAME = "vtsr_int8.onnx" |
| LABEL_MAPPING_FILENAME = "label-mapping.json" |
| MODEL_DISPLAY_NAME = "YOLOv8n · ONNX INT8 · CPU" |
|
|
|
|
| @lru_cache(maxsize=1) |
| def load_label_mapping() -> dict[str, str]: |
| mapping_path = hf_hub_download( |
| repo_id=MODEL_REPO_ID, |
| filename=LABEL_MAPPING_FILENAME, |
| ) |
| with open(mapping_path, encoding="utf-8") as mapping_file: |
| return json.load(mapping_file) |
|
|
|
|
| @lru_cache(maxsize=1) |
| def load_model() -> YOLO: |
| model_path = hf_hub_download( |
| repo_id=MODEL_REPO_ID, |
| filename=MODEL_FILENAME, |
| ) |
| return YOLO(model_path, task="detect") |
|
|
|
|
| def detect( |
| image: Image.Image | np.ndarray | None, |
| confidence: float, |
| iou: float, |
| ) -> tuple[Image.Image | None, list[list], list[list], dict]: |
| if image is None: |
| raise gr.Error("Vui lòng tải lên một ảnh.") |
|
|
| if isinstance(image, np.ndarray): |
| image = Image.fromarray(image) |
| image = image.convert("RGB") |
|
|
| model = load_model() |
| label_mapping = load_label_mapping() |
|
|
| started_at = perf_counter() |
| result = model.predict( |
| source=image, |
| imgsz=640, |
| conf=confidence, |
| iou=iou, |
| verbose=False, |
| )[0] |
| total_ms = (perf_counter() - started_at) * 1000 |
|
|
| detections = [] |
| if result.boxes is not None: |
| for class_id, score, xyxy in zip( |
| result.boxes.cls.tolist(), |
| result.boxes.conf.tolist(), |
| result.boxes.xyxy.tolist(), |
| ): |
| code = str(result.names[int(class_id)]) |
| detections.append( |
| { |
| "code": code, |
| "meaning": label_mapping.get(code, "Chưa có mô tả"), |
| "confidence": round(float(score) * 100, 2), |
| "box": [round(float(value), 1) for value in xyxy], |
| } |
| ) |
|
|
| annotated_image = Image.fromarray(result.plot()[..., ::-1]) |
| detail_rows = [ |
| [ |
| item["code"], |
| item["meaning"], |
| item["confidence"], |
| *item["box"], |
| ] |
| for item in detections |
| ] |
|
|
| counts = Counter(item["code"] for item in detections) |
| meanings = {item["code"]: item["meaning"] for item in detections} |
| summary_rows = [ |
| [code, meanings[code], count] |
| for code, count in counts.most_common() |
| ] |
|
|
| speed = result.speed or {} |
| metadata = { |
| "model": MODEL_DISPLAY_NAME, |
| "artifact": MODEL_FILENAME, |
| "image_size": { |
| "width": image.width, |
| "height": image.height, |
| }, |
| "detection_count": len(detections), |
| "timing_ms": { |
| "preprocess": round(float(speed.get("preprocess", 0)), 2), |
| "inference": round(float(speed.get("inference", 0)), 2), |
| "postprocess": round(float(speed.get("postprocess", 0)), 2), |
| "total": round(total_ms, 2), |
| }, |
| } |
|
|
| return annotated_image, summary_rows, detail_rows, metadata |
|
|
|
|
| with gr.Blocks(title="VTSR INT8 ONNX") as demo: |
| gr.Markdown( |
| """ |
| # 🚦 Nhận diện biển báo giao thông Việt Nam |
| |
| Demo YOLOv8n ONNX INT8 nhận diện 56 loại biển báo giao thông Việt Nam. |
| Model trả về mã biển báo, ý nghĩa tiếng Việt, độ tin cậy và bounding |
| box. |
| """ |
| ) |
|
|
| with gr.Row(): |
| with gr.Column(scale=1): |
| input_image = gr.Image( |
| type="pil", |
| label="Ảnh đầu vào", |
| sources=["upload", "clipboard", "webcam"], |
| ) |
| gr.Markdown("**Model:** YOLOv8n · ONNX INT8 · CPU") |
| confidence_slider = gr.Slider( |
| minimum=0.05, |
| maximum=0.95, |
| value=0.25, |
| step=0.05, |
| label="Ngưỡng tin cậy", |
| ) |
| iou_slider = gr.Slider( |
| minimum=0.1, |
| maximum=0.9, |
| value=0.7, |
| step=0.05, |
| label="Ngưỡng IoU", |
| ) |
| detect_button = gr.Button("Nhận diện", variant="primary") |
|
|
| with gr.Column(scale=1): |
| output_image = gr.Image( |
| type="pil", |
| label="Kết quả nhận diện", |
| ) |
| metadata_output = gr.JSON(label="Thông tin xử lý") |
|
|
| gr.Markdown("## Thống kê theo loại biển báo") |
| summary_output = gr.Dataframe( |
| headers=["Mã biển báo", "Ý nghĩa", "Số lượng"], |
| datatype=["str", "str", "number"], |
| interactive=False, |
| ) |
|
|
| gr.Markdown("## Chi tiết detection") |
| detail_output = gr.Dataframe( |
| headers=[ |
| "Mã biển báo", |
| "Ý nghĩa", |
| "Độ tin cậy (%)", |
| "x1", |
| "y1", |
| "x2", |
| "y2", |
| ], |
| datatype=[ |
| "str", |
| "str", |
| "number", |
| "number", |
| "number", |
| "number", |
| "number", |
| ], |
| interactive=False, |
| ) |
|
|
| detect_button.click( |
| fn=detect, |
| inputs=[ |
| input_image, |
| confidence_slider, |
| iou_slider, |
| ], |
| outputs=[ |
| output_image, |
| summary_output, |
| detail_output, |
| metadata_output, |
| ], |
| ) |
|
|
| gr.Markdown( |
| """ |
| **Lưu ý:** Demo chỉ mang tính giáo dục và nghiên cứu. Kết quả không |
| thay thế việc diễn giải biển báo chính thức hoặc quyết định an toàn |
| khi tham gia giao thông. |
| |
| Model: [liamxdev/vtsr](https://huggingface.co/liamxdev/vtsr) |
| """ |
| ) |
|
|
|
|
| if __name__ == "__main__": |
| demo.queue(default_concurrency_limit=2).launch( |
| server_name="0.0.0.0", |
| server_port=int(os.getenv("PORT", "7860")), |
| show_error=True, |
| ssr_mode=False, |
| ) |
|
|