ocr / app.py
vithacocf's picture
Update app.py
506e1a2 verified
raw
history blame
4.83 kB
# Code anh Thang
# import gradio as gr
# from transformers import AutoProcessor, AutoModelForVision2Seq
# from PIL import Image
# import torch
# device = "cuda" if torch.cuda.is_available() else "cpu"
# torch.cuda.empty_cache()
# model_id = "prithivMLmods/Camel-Doc-OCR-062825"
# processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
# model = AutoModelForVision2Seq.from_pretrained(
# model_id,
# torch_dtype=torch.float16 if device == "cuda" else torch.float32,
# trust_remote_code=True
# ).to(device)
# def predict(image, prompt=None):
# image = image.convert("RGB")
# # Cực kỳ quan trọng: text="" bắt buộc phải có
# inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
# # In debug để kiểm tra input_ids
# print(">>> input_ids shape:", inputs.input_ids.shape)
# generated_ids = model.generate(
# **inputs,
# max_new_tokens=512,
# do_sample=False,
# use_cache=False, # ✅ Thêm dòng này để fix lỗi cache_position
# eos_token_id=processor.tokenizer.eos_token_id,
# pad_token_id=processor.tokenizer.pad_token_id
# )
# result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
# return result
# demo = gr.Interface(
# fn=predict,
# inputs=[
# gr.Image(type="pil", label="Tải ảnh tài liệu lên"),
# gr.Textbox(label="Gợi ý (tuỳ chọn)", placeholder="VD: Trích số hóa đơn")
# ],
# outputs="text",
# title="Camel-Doc OCR - Trích xuất văn bản từ ảnh"
# )
# if __name__ == "__main__":
# demo.launch()
# Code fix
import gradio as gr
from PIL import Image, UnidentifiedImageError
from transformers import AutoProcessor, BitsAndBytesConfig, TextIteratorStreamer
from transformers.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
import torch
from threading import Thread
import time
# Cấu hình thiết bị
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()
# Load mô hình Qwen2.5-VL với quantization 4-bit
model_id = "prithivMLmods/Camel-Doc-OCR-062825"
# bnb_config = BitsAndBytesConfig(
# load_in_4bit=True,
# bnb_4bit_use_double_quant=True,
# bnb_4bit_quant_type="nf4",
# bnb_4bit_compute_dtype=torch.float16
# )
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_id,
# quantization_config=bnb_config, Quantization
device_map="auto",
trust_remote_code=True
).eval()
def convert_png_to_jpg(image):
if image.mode in ["RGBA", "LA"]:
converted = Image.new("RGB", image.size, (255, 255, 255))
converted.paste(image, mask=image.split()[-1])
return converted
return image.convert("RGB")
# Hàm dự đoán
def predict(image, prompt=""):
if image is None:
return "=Vui lòng tải lên ảnh hợp lệ."
try:
image = convert_png_to_jpg(image)
prompt = prompt.strip() if prompt else "Please describe the document."
# Xây dựng prompt theo định dạng Qwen2.5-VL
messages = [{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": prompt}
]
}]
text_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(
text=[text_prompt],
images=[image],
return_tensors="pt",
padding=True
).to(model.device)
# Dùng streamer để sinh kết quả mượt hơn
streamer = TextIteratorStreamer(processor.tokenizer, skip_special_tokens=True, skip_prompt=True)
generation_kwargs = {
**inputs,
"streamer": streamer,
"max_new_tokens": 512,
"do_sample": False,
"use_cache": True
}
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
buffer = ""
for new_text in streamer:
buffer += new_text
time.sleep(0.01)
return buffer
except UnidentifiedImageError:
return "Không thể đọc ảnh. Ảnh có thể bị hỏng hoặc sai định dạng."
except Exception as e:
return f"Lỗi khi xử lý ảnh: {str(e)}"
demo = gr.Interface(
fn=predict,
inputs=[
gr.Image(type="pil", label="Tải ảnh tài liệu lên"),
gr.Textbox(label="Gợi ý (tuỳ chọn)", placeholder="VD: Trích số hóa đơn")
],
outputs="text",
title="Camel-Doc OCR - Trích xuất văn bản từ ảnh"
)
if __name__ == "__main__":
demo.launch()