vtsr / app.py
liamxdev's picture
Upload folder using huggingface_hub
6e9374a verified
Raw
History Blame Contribute Delete
6.3 kB
import json
import os
from collections import Counter
from functools import lru_cache
from time import perf_counter
os.environ.setdefault("YOLO_CONFIG_DIR", "/tmp/Ultralytics")
import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download
from PIL import Image
from ultralytics import YOLO
MODEL_REPO_ID = "liamxdev/vtsr"
MODEL_FILENAME = "vtsr_int8.onnx"
LABEL_MAPPING_FILENAME = "label-mapping.json"
MODEL_DISPLAY_NAME = "YOLOv8n · ONNX INT8 · CPU"
@lru_cache(maxsize=1)
def load_label_mapping() -> dict[str, str]:
mapping_path = hf_hub_download(
repo_id=MODEL_REPO_ID,
filename=LABEL_MAPPING_FILENAME,
)
with open(mapping_path, encoding="utf-8") as mapping_file:
return json.load(mapping_file)
@lru_cache(maxsize=1)
def load_model() -> YOLO:
model_path = hf_hub_download(
repo_id=MODEL_REPO_ID,
filename=MODEL_FILENAME,
)
return YOLO(model_path, task="detect")
def detect(
image: Image.Image | np.ndarray | None,
confidence: float,
iou: float,
) -> tuple[Image.Image | None, list[list], list[list], dict]:
if image is None:
raise gr.Error("Vui lòng tải lên một ảnh.")
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
image = image.convert("RGB")
model = load_model()
label_mapping = load_label_mapping()
started_at = perf_counter()
result = model.predict(
source=image,
imgsz=640,
conf=confidence,
iou=iou,
verbose=False,
)[0]
total_ms = (perf_counter() - started_at) * 1000
detections = []
if result.boxes is not None:
for class_id, score, xyxy in zip(
result.boxes.cls.tolist(),
result.boxes.conf.tolist(),
result.boxes.xyxy.tolist(),
):
code = str(result.names[int(class_id)])
detections.append(
{
"code": code,
"meaning": label_mapping.get(code, "Chưa có mô tả"),
"confidence": round(float(score) * 100, 2),
"box": [round(float(value), 1) for value in xyxy],
}
)
annotated_image = Image.fromarray(result.plot()[..., ::-1])
detail_rows = [
[
item["code"],
item["meaning"],
item["confidence"],
*item["box"],
]
for item in detections
]
counts = Counter(item["code"] for item in detections)
meanings = {item["code"]: item["meaning"] for item in detections}
summary_rows = [
[code, meanings[code], count]
for code, count in counts.most_common()
]
speed = result.speed or {}
metadata = {
"model": MODEL_DISPLAY_NAME,
"artifact": MODEL_FILENAME,
"image_size": {
"width": image.width,
"height": image.height,
},
"detection_count": len(detections),
"timing_ms": {
"preprocess": round(float(speed.get("preprocess", 0)), 2),
"inference": round(float(speed.get("inference", 0)), 2),
"postprocess": round(float(speed.get("postprocess", 0)), 2),
"total": round(total_ms, 2),
},
}
return annotated_image, summary_rows, detail_rows, metadata
with gr.Blocks(title="VTSR INT8 ONNX") as demo:
gr.Markdown(
"""
# 🚦 Nhận diện biển báo giao thông Việt Nam
Demo YOLOv8n ONNX INT8 nhận diện 56 loại biển báo giao thông Việt Nam.
Model trả về mã biển báo, ý nghĩa tiếng Việt, độ tin cậy và bounding
box.
"""
)
with gr.Row():
with gr.Column(scale=1):
input_image = gr.Image(
type="pil",
label="Ảnh đầu vào",
sources=["upload", "clipboard", "webcam"],
)
gr.Markdown("**Model:** YOLOv8n · ONNX INT8 · CPU")
confidence_slider = gr.Slider(
minimum=0.05,
maximum=0.95,
value=0.25,
step=0.05,
label="Ngưỡng tin cậy",
)
iou_slider = gr.Slider(
minimum=0.1,
maximum=0.9,
value=0.7,
step=0.05,
label="Ngưỡng IoU",
)
detect_button = gr.Button("Nhận diện", variant="primary")
with gr.Column(scale=1):
output_image = gr.Image(
type="pil",
label="Kết quả nhận diện",
)
metadata_output = gr.JSON(label="Thông tin xử lý")
gr.Markdown("## Thống kê theo loại biển báo")
summary_output = gr.Dataframe(
headers=["Mã biển báo", "Ý nghĩa", "Số lượng"],
datatype=["str", "str", "number"],
interactive=False,
)
gr.Markdown("## Chi tiết detection")
detail_output = gr.Dataframe(
headers=[
"Mã biển báo",
"Ý nghĩa",
"Độ tin cậy (%)",
"x1",
"y1",
"x2",
"y2",
],
datatype=[
"str",
"str",
"number",
"number",
"number",
"number",
"number",
],
interactive=False,
)
detect_button.click(
fn=detect,
inputs=[
input_image,
confidence_slider,
iou_slider,
],
outputs=[
output_image,
summary_output,
detail_output,
metadata_output,
],
)
gr.Markdown(
"""
**Lưu ý:** Demo chỉ mang tính giáo dục và nghiên cứu. Kết quả không
thay thế việc diễn giải biển báo chính thức hoặc quyết định an toàn
khi tham gia giao thông.
Model: [liamxdev/vtsr](https://huggingface.co/liamxdev/vtsr)
"""
)
if __name__ == "__main__":
demo.queue(default_concurrency_limit=2).launch(
server_name="0.0.0.0",
server_port=int(os.getenv("PORT", "7860")),
show_error=True,
ssr_mode=False,
)