|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import gradio as gr |
|
|
from transformers import AutoProcessor, AutoModelForVision2Seq |
|
|
from PIL import Image, UnidentifiedImageError |
|
|
import torch |
|
|
import os |
|
|
|
|
|
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
|
|
|
model_id = "prithivMLmods/Camel-Doc-OCR-062825" |
|
|
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) |
|
|
model = AutoModelForVision2Seq.from_pretrained( |
|
|
model_id, |
|
|
torch_dtype=torch.float16 if device == "cuda" else torch.float32, |
|
|
trust_remote_code=True |
|
|
).to(device) |
|
|
|
|
|
|
|
|
def is_supported_image(image): |
|
|
return isinstance(image, Image.Image) |
|
|
|
|
|
|
|
|
def convert_png_to_jpg(image): |
|
|
converted = Image.new("RGB", image.size, (255, 255, 255)) |
|
|
converted.paste(image) |
|
|
return converted |
|
|
|
|
|
|
|
|
def predict(image, prompt=None): |
|
|
|
|
|
if not is_supported_image(image): |
|
|
return "Không hỗ trợ định dạng file này. Vui lòng tải ảnh đúng." |
|
|
|
|
|
|
|
|
if prompt is None or prompt.strip() == "": |
|
|
return "Vui lòng nhập prompt để trích xuất dữ liệu từ ảnh." |
|
|
|
|
|
try: |
|
|
|
|
|
if image.mode == "RGBA" or image.mode == "LA": |
|
|
image = convert_png_to_jpg(image) |
|
|
|
|
|
image = image.convert("RGB") |
|
|
|
|
|
except UnidentifiedImageError: |
|
|
return "Không thể đọc ảnh. Vui lòng kiểm tra lại định dạng hoặc ảnh bị lỗi." |
|
|
except Exception as e: |
|
|
return f"Lỗi khi xử lý ảnh: {str(e)}" |
|
|
|
|
|
|
|
|
inputs = processor(images=image, text=prompt, return_tensors="pt").to(device) |
|
|
|
|
|
generated_ids = model.generate( |
|
|
**inputs, |
|
|
max_new_tokens=512, |
|
|
do_sample=False, |
|
|
use_cache=False, |
|
|
eos_token_id=processor.tokenizer.eos_token_id, |
|
|
pad_token_id=processor.tokenizer.pad_token_id |
|
|
) |
|
|
|
|
|
result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] |
|
|
return result |
|
|
|
|
|
|
|
|
demo = gr.Interface( |
|
|
fn=predict, |
|
|
inputs=[ |
|
|
gr.Image(type="pil", label="Tải ảnh tài liệu lên"), |
|
|
gr.Textbox(label="Gợi ý (tuỳ chọn)", placeholder="VD: Trích số hóa đơn") |
|
|
], |
|
|
outputs="text", |
|
|
title="Camel-Doc OCR - Trích xuất văn bản từ ảnh" |
|
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
|
demo.launch() |