Agent_OCR / app.py
Shadow0704's picture
Upload 5 files
b85866b verified
# app.py
import os
import time
from typing import Tuple
import gradio as gr
from PIL import Image
import torch
from model import OCRModel
from preprocess import crop_by_region, to_tensor_one_tile # dùng hàm sẵn có của bạn
MODEL_ID = "5CD-AI/Vintern-1B-v3_5"
# CPU free-tier -> allow_flash_attn=False; GPU A10G có thể bật True
ocr_model = OCRModel(model_id=MODEL_ID, allow_flash_attn=False)
DEFAULT_PROMPT = "Chỉ trả về đúng nội dung văn bản nhìn thấy trong ảnh (không thêm giải thích)."
REGIONS = ["full", "head", "body", "foot"]
PRESETS = ["fast", "quality"]
def ensure_model_loaded():
if not ocr_model.is_loaded:
ocr_model.load()
def run_ocr(
image: Image.Image,
region: str,
preset: str,
prompt: str,
max_new_tokens: int
):
if image is None:
return "⚠️ Chưa chọn ảnh."
ensure_model_loaded()
# 1) Cắt vùng theo tham số (giống logic Flask cũ của bạn)
pil = crop_by_region(image, region=region, head_ratio=0.28, foot_ratio=0.22)
# 2) Đưa về tensor (1 tile / 448)
px = to_tensor_one_tile(pil, input_size=448)
# 3) Đồng bộ device & dtype với model (QUAN TRỌNG để tránh lỗi float/half)
model_dtype = next(ocr_model.model.parameters()).dtype
px = px.to(device=ocr_model.device, dtype=model_dtype)
# 4) Tham số sinh text
if preset == "fast":
gen = dict(max_new_tokens=min(512, max_new_tokens),
do_sample=False, num_beams=1, repetition_penalty=1.05)
else:
gen = dict(max_new_tokens=max_new_tokens,
do_sample=False, num_beams=1, repetition_penalty=1.10)
question = f"<image>\n{(prompt or DEFAULT_PROMPT).strip()}\n"
t0 = time.time()
text = ocr_model.chat(px, question, **gen)
dt = time.time() - t0
return f"{text}\n\n— elapsed: {dt:.2f}s | device: {ocr_model.device_str}"
with gr.Blocks(title="OCR Demo (Gradio)") as demo:
gr.Markdown(
"# OCR Demo (Gradio)\n"
"Upload ảnh giấy tờ → chọn **vùng** → bấm **Extract**.\n"
f"Model: `{MODEL_ID}`"
)
with gr.Row():
with gr.Column(scale=1):
inp_img = gr.Image(type="pil", label="Ảnh", sources=["upload", "clipboard"])
region = gr.Radio(REGIONS, value="full", label="Vùng cắt")
preset = gr.Radio(PRESETS, value="fast", label="Chế độ")
with gr.Column(scale=1):
prompt = gr.Textbox(value=DEFAULT_PROMPT, label="Prompt", lines=3)
max_tokens = gr.Slider(16, 512, value=128, step=8, label="max_new_tokens")
btn = gr.Button("Extract nội dung", variant="primary")
out = gr.Textbox(label="Kết quả OCR", lines=18)
btn.click(run_ocr, [inp_img, region, preset, prompt, max_tokens], [out])
if __name__ == "__main__":
# Local: mở http://127.0.0.1:7860
# Trên Hugging Face: không cần chỉnh — Spaces sẽ tự bind PORT
demo.launch()