|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import gradio as gr |
|
|
from PIL import Image, UnidentifiedImageError |
|
|
from transformers import AutoProcessor, BitsAndBytesConfig, TextIteratorStreamer |
|
|
from transformers.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration |
|
|
import torch |
|
|
from threading import Thread |
|
|
import time |
|
|
|
|
|
|
|
|
# Select the compute device: CUDA when PyTorch reports it usable, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Release any memory currently cached by PyTorch's CUDA allocator.
torch.cuda.empty_cache()

# Checkpoint identifier handed to the `from_pretrained` loaders.
model_id = "prithivMLmods/Camel-Doc-OCR-062825"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# NOTE: trust_remote_code=True lets code shipped with the checkpoint repository
# run locally; keep `model_id` pointed at a source you trust.
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# device_map="auto" delegates weight placement to the loader.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id,
    trust_remote_code=True,
    device_map="auto",
)
# Switch to evaluation mode; `eval()` returns the module itself.
model = model.eval()
|
|
|
|
|
def convert_png_to_jpg(image):
    """Return an RGB version of *image*, flattening transparency onto white.

    Images that carry an alpha channel (modes ``RGBA``, ``LA``, ``PA``) or a
    palette with a ``transparency`` entry are composited over a white
    background, so transparent regions become white. Every other image is
    converted to ``RGB`` directly.

    Args:
        image: A PIL image in any mode.

    Returns:
        A PIL image in mode ``RGB``.
    """
    has_alpha = image.mode in ("RGBA", "LA", "PA") or (
        image.mode == "P" and "transparency" in image.info
    )
    if not has_alpha:
        return image.convert("RGB")

    # Normalise every alpha-bearing mode to RGBA so the last band is always
    # the alpha channel used as the paste mask.
    rgba = image.convert("RGBA")
    flattened = Image.new("RGB", rgba.size, (255, 255, 255))
    flattened.paste(rgba, mask=rgba.split()[-1])
    return flattened
|
|
|
|
|
|
|
|
def predict(image, prompt=""):
    """Run the OCR model on one image and return the generated text.

    Args:
        image: PIL image to read, or ``None`` when nothing was uploaded.
        prompt: Optional instruction for the model. ``None``, empty and
            whitespace-only values fall back to a generic default prompt.

    Returns:
        The generated text, or a Vietnamese error message when the image is
        missing or unreadable, or when processing fails. Exceptions are
        reported through the returned string rather than raised.
    """
    if image is None:
        return "Vui lòng tải lên ảnh hợp lệ."

    try:
        image = convert_png_to_jpg(image)
        # The trailing `or` also catches prompts that are only whitespace.
        prompt = (prompt or "").strip() or "Please describe the document."

        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }]
        text_prompt = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        inputs = processor(
            text=[text_prompt],
            images=[image],
            return_tensors="pt",
            padding=True,
        ).to(model.device)

        streamer = TextIteratorStreamer(
            processor.tokenizer, skip_special_tokens=True, skip_prompt=True
        )
        generation_kwargs = {
            **inputs,
            "streamer": streamer,
            "max_new_tokens": 512,
            "do_sample": False,
            "use_cache": True,
        }

        generation_errors = []

        def _generate():
            # If generation fails, the streamer never receives its stop
            # signal and the consumer below would block forever. Record the
            # error and close the stream so the caller can report it.
            try:
                model.generate(**generation_kwargs)
            except Exception as exc:
                generation_errors.append(exc)
                streamer.end()

        thread = Thread(target=_generate)
        thread.start()

        # Drain the stream; iteration stops once the stop signal arrives.
        generated_text = "".join(streamer)
        thread.join()

        if generation_errors:
            # Re-raise here so the generic handler below formats the message.
            raise generation_errors[0]

        return generated_text

    except UnidentifiedImageError:
        return "Không thể đọc ảnh. Ảnh có thể bị hỏng hoặc sai định dạng."
    except Exception as e:
        return f"Lỗi khi xử lý ảnh: {e}"
|
|
|
|
|
# Gradio UI: an image upload plus an optional text hint, wired to `predict`,
# with the result shown as plain text.
demo = gr.Interface(
    title="Camel-Doc OCR - Trích xuất văn bản từ ảnh",
    fn=predict,
    inputs=[
        gr.Image(label="Tải ảnh tài liệu lên", type="pil"),
        gr.Textbox(placeholder="VD: Trích số hóa đơn", label="Gợi ý (tuỳ chọn)"),
    ],
    outputs="text",
)


# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()