File size: 4,473 Bytes
1813932
 
4fef2dd
 
 
 
a4853dc
1813932
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4853dc
1813932
a4853dc
 
 
 
 
1813932
a4853dc
 
 
 
1813932
a4853dc
 
 
 
 
1813932
 
 
 
a4853dc
1813932
a4853dc
1813932
a4853dc
1813932
a4853dc
 
 
1813932
a4853dc
 
 
1813932
 
 
 
a4853dc
4fef2dd
a4853dc
 
1813932
a4853dc
 
1813932
a4853dc
 
 
 
 
 
 
1813932
a4853dc
 
 
 
1813932
a4853dc
 
 
 
 
 
 
1813932
a4853dc
 
 
1813932
 
a4853dc
1813932
 
 
 
 
 
a4853dc
4fef2dd
1813932
 
 
a4853dc
1813932
 
a4853dc
dcf3b1f
1813932
 
 
 
 
 
 
 
 
 
 
 
a4853dc
 
1813932
 
 
4fef2dd
 
a4853dc
1813932
 
 
 
 
a4853dc
4fef2dd
a4853dc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# app.py
import os, sys, json, tempfile
import gradio as gr
import cv2
import numpy as np
from PIL import Image

# ── Auto-download detection weights từ HuggingFace Hub ──────
CHECKPOINT = "best.pt"
HF_REPO    = "phamha/drawing-model-weights"  # ← sửa thành username của bạn

def ensure_weights():
    if not os.path.exists(CHECKPOINT):
        print("[INFO] Downloading model weights...")
        from huggingface_hub import hf_hub_download
        hf_hub_download(
            repo_id=HF_REPO,
            filename="best.pt",
            local_dir=".",
            local_dir_use_symlinks=False,
        )
        print("[INFO] Weights ready.")

ensure_weights()

sys.path.insert(0, ".")
from src.inference import run_pipeline


# ── Gradio handler ───────────────────────────────────────────
def process(image: Image.Image):
    if image is None:
        return None, "{}", "Chưa có ảnh."

    tmp_dir  = tempfile.mkdtemp()
    tmp_path = os.path.join(tmp_dir, "input.jpg")
    image.save(tmp_path, quality=95)

    try:
        result, vis_path = run_pipeline(
            image_path  = tmp_path,
            output_dir  = tmp_dir,
            checkpoint  = CHECKPOINT,
            conf_thresh = 0.3,
        )
    except Exception as e:
        import traceback
        return None, "{}", f"Lỗi:\n{traceback.format_exc()}"

    # Ảnh visualize
    vis_bgr = cv2.imread(vis_path)
    vis_rgb = cv2.cvtColor(vis_bgr, cv2.COLOR_BGR2RGB)

    # JSON sạch
    clean_objs = []
    for obj in result["objects"]:
        clean_objs.append({
            "id":          obj["id"],
            "class":       obj["class"],
            "confidence":  obj["confidence"],
            "bbox":        obj["bbox"],
            "ocr_content": obj["ocr_content"],
        })
    json_str = json.dumps(
        {"image": result["image"], "objects": clean_objs},
        ensure_ascii=False, indent=2,
    )

    # OCR panel
    ocr_parts = []
    for obj in result["objects"]:
        content = obj.get("ocr_content")
        if not content:
            continue
        if isinstance(content, dict):
            content = content.get("text", "")
        if not str(content).strip():
            continue
        sep = "─" * 46
        ocr_parts.append(
            f"{sep}\n"
            f"[{obj['class']} #{obj['id']}]  conf={obj['confidence']}\n"
            f"{sep}\n{content}"
        )
    ocr_text = "\n\n".join(ocr_parts) or "Không phát hiện Note / Table."

    return vis_rgb, json_str, ocr_text


# ── UI ───────────────────────────────────────────────────────
with gr.Blocks(title="Engineering Drawing Analyzer", theme=gr.themes.Soft()) as demo:

    gr.Markdown("""
    # 🔧 Engineering Drawing Analyzer
    Tự động phát hiện và trích xuất văn bản từ bản vẽ kỹ thuật (tiếng Việt & tiếng Anh).

    | Class | Màu | Mô tả |
    |-------|-----|-------|
    | 🟢 PartDrawing | Xanh lá | Vùng bản vẽ chi tiết |
    | 🟠 Note | Cam | Ghi chú, chú thích |
    | 🔴 Table | Đỏ | Bảng dữ liệu kỹ thuật |
    """)

    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.Image(type="pil", label="📁 Upload bản vẽ kỹ thuật")
            btn = gr.Button("🔍 Detect & OCR", variant="primary", size="lg")

        with gr.Column(scale=1):
            out_img = gr.Image(label="✅ Kết quả detection")

    with gr.Row():
        with gr.Column(scale=1):
            out_json = gr.Code(
                language="json",
                label="📋 JSON output",
                lines=25,
            )
        with gr.Column(scale=1):
            out_ocr = gr.Textbox(
                label="📝 OCR content (Note & Table)",
                lines=25,
                max_lines=60,
            )

    btn.click(
        fn      = process,
        inputs  = [inp],
        outputs = [out_img, out_json, out_ocr],
    )

    gr.Markdown("""
    ---
    **Detection:** RT-DETR-L · mAP50 = 0.942  
    **OCR:** TrOCR (microsoft/trocr-large-handwritten) + EasyOCR fallback  
    **Hỗ trợ:** Tiếng Việt · Tiếng Anh · Chữ viết tay · Chữ in
    """)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)