ocr / app.py
vithacocf's picture
Update app.py
76a5fff verified
raw
history blame
4.14 kB
# Code anh Thang
# import gradio as gr
# from transformers import AutoProcessor, AutoModelForVision2Seq
# from PIL import Image
# import torch
# device = "cuda" if torch.cuda.is_available() else "cpu"
# torch.cuda.empty_cache()
# model_id = "prithivMLmods/Camel-Doc-OCR-062825"
# processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
# model = AutoModelForVision2Seq.from_pretrained(
# model_id,
# torch_dtype=torch.float16 if device == "cuda" else torch.float32,
# trust_remote_code=True
# ).to(device)
# def predict(image, prompt=None):
# image = image.convert("RGB")
# # Cực kỳ quan trọng: text="" bắt buộc phải có
# inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
# # In debug để kiểm tra input_ids
# print(">>> input_ids shape:", inputs.input_ids.shape)
# generated_ids = model.generate(
# **inputs,
# max_new_tokens=512,
# do_sample=False,
# use_cache=False, # ✅ Thêm dòng này để fix lỗi cache_position
# eos_token_id=processor.tokenizer.eos_token_id,
# pad_token_id=processor.tokenizer.pad_token_id
# )
# result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
# return result
# demo = gr.Interface(
# fn=predict,
# inputs=[
# gr.Image(type="pil", label="Tải ảnh tài liệu lên"),
# gr.Textbox(label="Gợi ý (tuỳ chọn)", placeholder="VD: Trích số hóa đơn")
# ],
# outputs="text",
# title="Camel-Doc OCR - Trích xuất văn bản từ ảnh"
# )
# if __name__ == "__main__":
# demo.launch()
# Code fix
import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq
from PIL import Image, UnidentifiedImageError
import torch
import os
# Cấu hình thiết bị
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.cuda.empty_cache()
# Load mô hình
model_id = "prithivMLmods/Camel-Doc-OCR-062825"
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForVision2Seq.from_pretrained(
model_id,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
trust_remote_code=True
).to(device)
# Hỗ trợ định dạng ảnh
def is_supported_image(image):
return isinstance(image, Image.Image)
# Chuyển PNG sang JPG
def convert_png_to_jpg(image):
converted = Image.new("RGB", image.size, (255, 255, 255))
converted.paste(image)
return converted
# Hàm chính
def predict(image, prompt=None):
# Kiểm tra ảnh hợp lệ
if not is_supported_image(image):
return "Không hỗ trợ định dạng file này. Vui lòng tải ảnh đúng."
# Prompt rỗng
if prompt is None or prompt.strip() == "":
return "Vui lòng nhập prompt để trích xuất dữ liệu từ ảnh."
try:
# Nếu ảnh là PNG có alpha, convert sang RGB
if image.mode == "RGBA" or image.mode == "LA":
image = convert_png_to_jpg(image)
image = image.convert("RGB")
except UnidentifiedImageError:
return "Không thể đọc ảnh. Vui lòng kiểm tra lại định dạng hoặc ảnh bị lỗi."
except Exception as e:
return f"Lỗi khi xử lý ảnh: {str(e)}"
# Inference
inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
generated_ids = model.generate(
**inputs,
max_new_tokens=512,
do_sample=False,
use_cache=False, # fix cache_position
eos_token_id=processor.tokenizer.eos_token_id,
pad_token_id=processor.tokenizer.pad_token_id
)
result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
return result
demo = gr.Interface(
fn=predict,
inputs=[
gr.Image(type="pil", label="Tải ảnh tài liệu lên"),
gr.Textbox(label="Gợi ý (tuỳ chọn)", placeholder="VD: Trích số hóa đơn")
],
outputs="text",
title="Camel-Doc OCR - Trích xuất văn bản từ ảnh"
)
if __name__ == "__main__":
demo.launch()