|
|
import gradio as gr |
|
|
import torch |
|
|
from PIL import Image |
|
|
from transformers import AutoModelForCausalLM, AutoProcessor |
|
|
import spaces |
|
|
|
|
|
|
|
|
# Hugging Face model repository for the PaddleOCR vision-language model.
MODEL_PATH = "PaddlePaddle/PaddleOCR-VL"


# Prefer GPU when available; falls back to CPU (slow for VLM inference).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# Task name (shown in the UI radio) -> text prompt sent to the model.
# NOTE: the model expects these exact prompt strings; the dict's insertion
# order also determines the order of choices in the Gradio radio widget.
PROMPTS = {


    "OCR": "OCR:",


    "Table Recognition": "Table Recognition:",


    "Formula Recognition": "Formula Recognition:",


    "Chart Recognition": "Chart Recognition:",


}
|
|
|
|
|
|
|
|
# Load model + processor once at import time (standard pattern for HF Spaces:
# the heavy download/initialization happens before the UI starts serving).
print(f"Loading model on {DEVICE}...")


# trust_remote_code is required: PaddleOCR-VL ships custom modeling code.
# bfloat16 halves memory vs float32; eval() disables dropout etc.
model = AutoModelForCausalLM.from_pretrained(


    MODEL_PATH,


    trust_remote_code=True,


    torch_dtype=torch.bfloat16


).to(DEVICE).eval()




# Processor handles both image preprocessing and text tokenization.
processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)


print("Model loaded successfully!")
|
|
|
|
|
@spaces.GPU
def process_image(image, task):
    """
    Run PaddleOCR-VL on an image for the selected recognition task.

    Args:
        image: PIL Image, or a path/file-like object accepted by
            ``PIL.Image.open``.
        task: Task name; one of the keys of ``PROMPTS`` (OCR, Table
            Recognition, Formula Recognition, Chart Recognition).
            Unknown values fall back to plain OCR.

    Returns:
        str: The decoded recognition result (model output only, without
        the prompt), or a user-facing message if no image was provided.
    """
    if image is None:
        return "Please upload an image first."

    # Gradio normally hands us a PIL Image, but accept paths too.
    if not isinstance(image, Image.Image):
        image = Image.open(image)

    # The vision encoder expects 3-channel RGB (handles RGBA/grayscale/palette).
    image = image.convert("RGB")

    prompt = PROMPTS.get(task, PROMPTS["OCR"])

    # Standard HF chat-template message: one image plus the task prompt.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ]
        }
    ]

    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt"
    ).to(DEVICE)

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=1024)

    # generate() returns the prompt tokens followed by the new tokens.
    # Slice off the prompt so the user sees only the model's answer
    # (previously the chat-template prompt leaked into the result).
    prompt_len = inputs["input_ids"].shape[1]
    result = processor.batch_decode(
        outputs[:, prompt_len:], skip_special_tokens=True
    )[0].strip()

    return result
|
|
|
|
|
|
|
|
# Gradio UI: image + task selector in, plain-text recognition result out.
demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Radio(
            # Radio options mirror the supported prompts, defaulting to OCR.
            choices=list(PROMPTS.keys()),
            value="OCR",
            label="Task Type"
        )
    ],
    outputs=gr.Textbox(label="Result", lines=10),
    title="PaddleOCR-VL: Multilingual Document Parsing",
    description="Upload an image and select a task. This model supports OCR in 109 languages, table recognition, formula recognition, and chart recognition.",
    # No bundled example assets yet; previously this was dead code
    # (`[...] if False else None`) that always evaluated to None.
    examples=None,
)
|
|
|
|
|
# Start the Gradio server only when run as a script (not when imported,
# e.g. by the Spaces runtime or tests).
if __name__ == "__main__":


    demo.launch()