File size: 3,096 Bytes
b85866b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# app.py
import os
import time
from typing import Tuple

import gradio as gr
from PIL import Image
import torch

from model import OCRModel
from preprocess import crop_by_region, to_tensor_one_tile  # dùng hàm sẵn có của bạn

MODEL_ID = "5CD-AI/Vintern-1B-v3_5"

# CPU free-tier -> allow_flash_attn=False; GPU A10G có thể bật True
ocr_model = OCRModel(model_id=MODEL_ID, allow_flash_attn=False)

DEFAULT_PROMPT = "Chỉ trả về đúng nội dung văn bản nhìn thấy trong ảnh (không thêm giải thích)."
REGIONS = ["full", "head", "body", "foot"]
PRESETS = ["fast", "quality"]

def ensure_model_loaded():
    if not ocr_model.is_loaded:
        ocr_model.load()

def run_ocr(

    image: Image.Image,

    region: str,

    preset: str,

    prompt: str,

    max_new_tokens: int

):
    if image is None:
        return "⚠️ Chưa chọn ảnh."

    ensure_model_loaded()

    # 1) Cắt vùng theo tham số (giống logic Flask cũ của bạn)
    pil = crop_by_region(image, region=region, head_ratio=0.28, foot_ratio=0.22)

    # 2) Đưa về tensor (1 tile / 448)
    px = to_tensor_one_tile(pil, input_size=448)

    # 3) Đồng bộ device & dtype với model (QUAN TRỌNG để tránh lỗi float/half)
    model_dtype = next(ocr_model.model.parameters()).dtype
    px = px.to(device=ocr_model.device, dtype=model_dtype)

    # 4) Tham số sinh text
    if preset == "fast":
        gen = dict(max_new_tokens=min(512, max_new_tokens),
                   do_sample=False, num_beams=1, repetition_penalty=1.05)
    else:
        gen = dict(max_new_tokens=max_new_tokens,
                   do_sample=False, num_beams=1, repetition_penalty=1.10)

    question = f"<image>\n{(prompt or DEFAULT_PROMPT).strip()}\n"

    t0 = time.time()
    text = ocr_model.chat(px, question, **gen)
    dt = time.time() - t0

    return f"{text}\n\n— elapsed: {dt:.2f}s | device: {ocr_model.device_str}"

with gr.Blocks(title="OCR Demo (Gradio)") as demo:
    gr.Markdown(
        "# OCR Demo (Gradio)\n"
        "Upload ảnh giấy tờ → chọn **vùng** → bấm **Extract**.\n"
        f"Model: `{MODEL_ID}`"
    )

    with gr.Row():
        with gr.Column(scale=1):
            inp_img = gr.Image(type="pil", label="Ảnh", sources=["upload", "clipboard"])
            region = gr.Radio(REGIONS, value="full", label="Vùng cắt")
            preset = gr.Radio(PRESETS, value="fast", label="Chế độ")
        with gr.Column(scale=1):
            prompt = gr.Textbox(value=DEFAULT_PROMPT, label="Prompt", lines=3)
            max_tokens = gr.Slider(16, 512, value=128, step=8, label="max_new_tokens")
            btn = gr.Button("Extract nội dung", variant="primary")
            out = gr.Textbox(label="Kết quả OCR", lines=18)

    btn.click(run_ocr, [inp_img, region, preset, prompt, max_tokens], [out])

if __name__ == "__main__":
    # Local: mở http://127.0.0.1:7860
    # Trên Hugging Face: không cần chỉnh — Spaces sẽ tự bind PORT
    demo.launch()