| | import gradio as gr |
| | import torch |
| | from PIL import Image |
| | from transformers import AutoModelForCausalLM, AutoProcessor |
| | import spaces |
| |
|
| | |
# Hugging Face Hub identifier of the checkpoint served by this app.
MODEL_PATH = "PaddlePaddle/PaddleOCR-VL"

# Run on the GPU whenever PyTorch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
| |
|
| | |
# Maps each UI task name to the text prompt sent to the model. In every case
# the prompt is the task name followed by a colon (presumably the instruction
# format the checkpoint was trained on; confirm against the model card).
# Insertion order matters: the UI builds its radio choices from these keys.
PROMPTS = {
    task_name: task_name + ":"
    for task_name in (
        "OCR",
        "Table Recognition",
        "Formula Recognition",
        "Chart Recognition",
    )
}
| |
|
| | |
# Load the model and processor once at import time so every request reuses them.
print(f"Loading model on {DEVICE}...")

# bfloat16 halves GPU memory. It is primarily a GPU format, though, and on
# CPU it is often slow, so use full float32 precision there instead.
_MODEL_DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32

model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,  # the checkpoint ships its own modeling code
    torch_dtype=_MODEL_DTYPE,
).to(DEVICE).eval()  # eval() disables training-only behavior such as dropout

processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
print("Model loaded successfully!")
| |
|
@spaces.GPU
def process_image(image, task):
    """
    Run PaddleOCR-VL on an image for the selected recognition task.

    Args:
        image: A PIL Image, or a path / file object accepted by
            ``PIL.Image.open``. ``None`` yields a hint message instead.
        task: A key of ``PROMPTS`` (e.g. "OCR", "Table Recognition").
            Unknown values fall back to the "OCR" prompt.

    Returns:
        str: Only the text newly generated by the model (the prompt is not
        echoed back), or a hint message when no image was supplied.
    """
    if image is None:
        return "Please upload an image first."

    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    else:
        # Open from a path/file inside ``with`` so the file handle is
        # released. ``convert`` loads the pixels into a new, independent
        # image, so it stays valid after the file is closed.
        with Image.open(image) as opened:
            image = opened.convert("RGB")

    prompt = PROMPTS.get(task, PROMPTS["OCR"])

    # Single-turn chat: one user message with the image followed by the task prompt.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(DEVICE)

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=1024)

    # For a causal LM, ``generate`` returns the prompt tokens followed by the
    # completion. Drop the prompt prefix so that the chat template and task
    # prompt are not decoded into the user-facing result.
    prompt_length = inputs["input_ids"].shape[-1]
    generated_tokens = outputs[:, prompt_length:]

    result = processor.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    return result
| |
|
| | |
# Gradio UI: an image upload plus a task selector, wired to ``process_image``.
# The radio choices come straight from ``PROMPTS`` so the two stay in sync.
demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Radio(
            choices=list(PROMPTS.keys()),
            value="OCR",
            label="Task Type",
        ),
    ],
    outputs=gr.Textbox(label="Result", lines=10),
    title="PaddleOCR-VL: Multilingual Document Parsing",
    description=(
        "Upload an image and select a task. "
        "This model supports OCR in 109 languages, table recognition, "
        "formula recognition, and chart recognition."
    ),
    # No examples are shipped. To add one, include the image alongside the
    # app and pass e.g. ``examples=[["example.png", "OCR"]]``.
)
| |
|
# Start the web server only when this file is executed directly. Importing
# the module just builds ``demo`` without launching it.
if __name__ == "__main__":
    demo.launch()