ocr / app.py
vithacocf's picture
Update app.py
0fb18ff verified
raw
history blame
4.16 kB
# Code anh Thang
# import gradio as gr
# from transformers import AutoProcessor, AutoModelForVision2Seq
# from PIL import Image
# import torch
# device = "cuda" if torch.cuda.is_available() else "cpu"
# torch.cuda.empty_cache()
# model_id = "prithivMLmods/Camel-Doc-OCR-062825"
# processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
# model = AutoModelForVision2Seq.from_pretrained(
# model_id,
# torch_dtype=torch.float16 if device == "cuda" else torch.float32,
# trust_remote_code=True
# ).to(device)
# def predict(image, prompt=None):
# image = image.convert("RGB")
# # Cực kỳ quan trọng: text="" bắt buộc phải có
# inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
# # In debug để kiểm tra input_ids
# print(">>> input_ids shape:", inputs.input_ids.shape)
# generated_ids = model.generate(
# **inputs,
# max_new_tokens=512,
# do_sample=False,
# use_cache=False, # ✅ Thêm dòng này để fix lỗi cache_position
# eos_token_id=processor.tokenizer.eos_token_id,
# pad_token_id=processor.tokenizer.pad_token_id
# )
# result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
# return result
# demo = gr.Interface(
# fn=predict,
# inputs=[
# gr.Image(type="pil", label="Tải ảnh tài liệu lên"),
# gr.Textbox(label="Gợi ý (tuỳ chọn)", placeholder="VD: Trích số hóa đơn")
# ],
# outputs="text",
# title="Camel-Doc OCR - Trích xuất văn bản từ ảnh"
# )
# if __name__ == "__main__":
# demo.launch()
# Code fix
import gradio as gr
# from transformers import AutoProcessor, AutoModelForVision2Seq
from PIL import Image, UnidentifiedImageError
# import torch
import os
# # Cấu hình thiết bị
# device = "cuda" if torch.cuda.is_available() else "cpu"
# torch.cuda.empty_cache()
# # Load mô hình
# model_id = "prithivMLmods/Camel-Doc-OCR-062825"
# processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
# model = AutoModelForVision2Seq.from_pretrained(
# model_id,
# torch_dtype=torch.float16 if device == "cuda" else torch.float32,
# trust_remote_code=True
# ).to(device)
# Hỗ trợ định dạng ảnh
def is_supported_image(image):
return isinstance(image, Image.Image)
# Chuyển PNG sang JPG
def convert_png_to_jpg(image):
converted = Image.new("RGB", image.size, (255, 255, 255))
converted.paste(image)
return converted
# Hàm chính
def predict(image_path, prompt=None):
if not isinstance(image_path, str) or not os.path.exists(image_path):
return "=Không tìm thấy ảnh. Vui lòng thử lại sau khi upload thành công."
if prompt is None or prompt.strip() == "":
return "=Vui lòng nhập prompt để trích xuất dữ liệu."
try:
image = Image.open(image_path).convert("RGB")
if image.mode in ["RGBA", "LA"]:
new_img = Image.new("RGB", image.size, (255, 255, 255))
new_img.paste(image)
image = new_img
except UnidentifiedImageError:
return "=Không thể đọc ảnh. Ảnh có thể bị hỏng hoặc sai định dạng."
except Exception as e:
return f"=Lỗi khi xử lý ảnh: {str(e)}"
# inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
# generated_ids = model.generate(
# **inputs,
# max_new_tokens=512,
# do_sample=False,
# use_cache=False,
# eos_token_id=processor.tokenizer.eos_token_id,
# pad_token_id=processor.tokenizer.pad_token_id
# )
# result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
result = "aaa"
return result
demo = gr.Interface(
fn=predict,
inputs=[
gr.Image(type="pil", label="Tải ảnh tài liệu lên"),
gr.Textbox(label="Gợi ý (tuỳ chọn)", placeholder="VD: Trích số hóa đơn")
],
outputs="text",
title="Camel-Doc OCR - Trích xuất văn bản từ ảnh"
)
if __name__ == "__main__":
demo.launch()