Spaces:
Sleeping
Sleeping
| # app.py | |
| import os | |
| import time | |
| from typing import Tuple | |
| import gradio as gr | |
| from PIL import Image | |
| import torch | |
| from model import OCRModel | |
| from preprocess import crop_by_region, to_tensor_one_tile # dùng hàm sẵn có của bạn | |
| MODEL_ID = "5CD-AI/Vintern-1B-v3_5" | |
| # CPU free-tier -> allow_flash_attn=False; GPU A10G có thể bật True | |
| ocr_model = OCRModel(model_id=MODEL_ID, allow_flash_attn=False) | |
| DEFAULT_PROMPT = "Chỉ trả về đúng nội dung văn bản nhìn thấy trong ảnh (không thêm giải thích)." | |
| REGIONS = ["full", "head", "body", "foot"] | |
| PRESETS = ["fast", "quality"] | |
| def ensure_model_loaded(): | |
| if not ocr_model.is_loaded: | |
| ocr_model.load() | |
| def run_ocr( | |
| image: Image.Image, | |
| region: str, | |
| preset: str, | |
| prompt: str, | |
| max_new_tokens: int | |
| ): | |
| if image is None: | |
| return "⚠️ Chưa chọn ảnh." | |
| ensure_model_loaded() | |
| # 1) Cắt vùng theo tham số (giống logic Flask cũ của bạn) | |
| pil = crop_by_region(image, region=region, head_ratio=0.28, foot_ratio=0.22) | |
| # 2) Đưa về tensor (1 tile / 448) | |
| px = to_tensor_one_tile(pil, input_size=448) | |
| # 3) Đồng bộ device & dtype với model (QUAN TRỌNG để tránh lỗi float/half) | |
| model_dtype = next(ocr_model.model.parameters()).dtype | |
| px = px.to(device=ocr_model.device, dtype=model_dtype) | |
| # 4) Tham số sinh text | |
| if preset == "fast": | |
| gen = dict(max_new_tokens=min(512, max_new_tokens), | |
| do_sample=False, num_beams=1, repetition_penalty=1.05) | |
| else: | |
| gen = dict(max_new_tokens=max_new_tokens, | |
| do_sample=False, num_beams=1, repetition_penalty=1.10) | |
| question = f"<image>\n{(prompt or DEFAULT_PROMPT).strip()}\n" | |
| t0 = time.time() | |
| text = ocr_model.chat(px, question, **gen) | |
| dt = time.time() - t0 | |
| return f"{text}\n\n— elapsed: {dt:.2f}s | device: {ocr_model.device_str}" | |
| with gr.Blocks(title="OCR Demo (Gradio)") as demo: | |
| gr.Markdown( | |
| "# OCR Demo (Gradio)\n" | |
| "Upload ảnh giấy tờ → chọn **vùng** → bấm **Extract**.\n" | |
| f"Model: `{MODEL_ID}`" | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| inp_img = gr.Image(type="pil", label="Ảnh", sources=["upload", "clipboard"]) | |
| region = gr.Radio(REGIONS, value="full", label="Vùng cắt") | |
| preset = gr.Radio(PRESETS, value="fast", label="Chế độ") | |
| with gr.Column(scale=1): | |
| prompt = gr.Textbox(value=DEFAULT_PROMPT, label="Prompt", lines=3) | |
| max_tokens = gr.Slider(16, 512, value=128, step=8, label="max_new_tokens") | |
| btn = gr.Button("Extract nội dung", variant="primary") | |
| out = gr.Textbox(label="Kết quả OCR", lines=18) | |
| btn.click(run_ocr, [inp_img, region, preset, prompt, max_tokens], [out]) | |
| if __name__ == "__main__": | |
| # Local: mở http://127.0.0.1:7860 | |
| # Trên Hugging Face: không cần chỉnh — Spaces sẽ tự bind PORT | |
| demo.launch() | |