# ocr/app.py — Hugging Face Space "upprize/ocr" (commit 700ddbf)
import gradio as gr
import torch
import json
import spaces
import os
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.processing_utils import ProcessorMixin
from qwen_vl_utils import process_vision_info
from huggingface_hub import login
# Model configuration: Hub repository id of the dots.ocr checkpoint
MODEL_PATH = "rednote-hilab/dots.ocr"

# Optional authentication (required if the repository is gated)
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    print("Authenticating with Hugging Face token...")
    # add_to_git_credential=False: keep the token out of the git credential store
    login(token=HF_TOKEN, add_to_git_credential=False)

# Model and processor will be loaded on GPU when decorated function is called
# (lazy-loaded singletons cached in these module globals)
model = None
processor = None
def load_model():
    """Load the dots.ocr model and processor, caching them in module globals.

    Returns:
        tuple: ``(model, processor)`` ready for inference.
    """
    global model, processor
    # Guard on BOTH globals: if a previous call loaded the model but then
    # failed while loading the processor, checking only `model is None`
    # would return (model, None) forever. Retry until both are loaded.
    if model is None or processor is None:
        print(f"Loading model weights from {MODEL_PATH}...")
        # Try to use FlashAttention2 if available, otherwise use default attention
        try:
            import flash_attn  # noqa: F401 -- presence check only
            attn_implementation = "flash_attention_2"
            print("Using FlashAttention2 for faster inference")
        except ImportError:
            attn_implementation = "eager"
            print("FlashAttention2 not available, using default attention")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_PATH,
            dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
            token=HF_TOKEN,
            attn_implementation=attn_implementation
        )
        print("Model loaded successfully.")
        print(f"Loading processor from {MODEL_PATH}...")
        # Patch check_argument_for_proper_class to allow None for video_processor
        # (the dots.ocr processor ships without a video processor, which the
        # installed transformers validation would otherwise reject).
        _original_check = ProcessorMixin.check_argument_for_proper_class

        def _patched_check(self, attribute_name, value):
            # Skip validation only for the None video_processor case;
            # everything else goes through the original check.
            if attribute_name == "video_processor" and value is None:
                return
            return _original_check(self, attribute_name, value)

        ProcessorMixin.check_argument_for_proper_class = _patched_check
        try:
            processor = AutoProcessor.from_pretrained(
                MODEL_PATH,
                trust_remote_code=True,
                token=HF_TOKEN
            )
            print("Processor loaded successfully.")
        finally:
            # Always restore the original validation method, even on failure
            ProcessorMixin.check_argument_for_proper_class = _original_check
    return model, processor
# Predefined prompts: maps the UI dropdown label to the instruction text sent
# to the model. "Custom" is a sentinel whose empty value is replaced by the
# user-typed prompt inside process_image().
PROMPTS = {
    "Full Layout + OCR (English)": """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
1. Bbox format: [x1, y1, x2, y2]
2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
3. Text Extraction & Formatting Rules:
- Picture: For the 'Picture' category, the text field should be omitted.
- Formula: Format its text as LaTeX.
- Table: Format its text as HTML.
- All Others (Text, Title, etc.): Format their text as Markdown.
4. Constraints:
- The output text must be the original text from the image, with no translation.
- All layout elements must be sorted according to human reading order.
5. Final Output: The entire output must be a single JSON object.""",
    "OCR Only": """Please extract all text from the image in reading order. Format the output as plain text, preserving the original structure as much as possible.""",
    "Layout Detection Only": """Please detect all layout elements in the image and output their bounding boxes and categories. Format: [{"bbox": [x1, y1, x2, y2], "category": "category_name"}]""",
    "Custom": ""
}
@spaces.GPU(duration=120)
def process_image(image, prompt_type, custom_prompt):
    """Run the OCR model on a document image.

    Args:
        image: PIL.Image uploaded through the Gradio UI.
        prompt_type: Key into PROMPTS selecting the processing mode.
        custom_prompt: Free-form prompt, used only when prompt_type == "Custom".

    Returns:
        str: Model output -- pretty-printed JSON when the output parses as
        JSON, otherwise plain text; an "Error: ..." string on failure.
    """
    try:
        # Load (or reuse cached) model and processor on the GPU
        current_model, current_processor = load_model()
        # Fall back to the predefined prompt when the custom one is blank
        if prompt_type == "Custom" and custom_prompt.strip():
            prompt = custom_prompt
        else:
            prompt = PROMPTS[prompt_type]
        # Chat-style message: the image followed by the instruction text
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt}
                ]
            }
        ]
        # Render the chat template and extract the vision inputs
        text = current_processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = current_processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        # Move tensors to wherever the model actually lives instead of
        # assuming a literal "cuda" device string (device_map="auto" may
        # shard or place the model differently).
        inputs = inputs.to(current_model.device)
        # Generate output; no gradients are needed for inference
        with torch.no_grad():
            generated_ids = current_model.generate(
                **inputs,
                max_new_tokens=24000,
                temperature=0.1,
                top_p=0.9,
            )
        # Strip the prompt tokens so only newly generated text is decoded
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = current_processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]
        # Pretty-print when the model returned valid JSON. Catch only
        # JSONDecodeError -- the original bare `except:` would also swallow
        # KeyboardInterrupt/SystemExit and hide unrelated bugs.
        try:
            parsed_json = json.loads(output_text)
            output_text = json.dumps(parsed_json, ensure_ascii=False, indent=2)
        except json.JSONDecodeError:
            pass  # Keep as plain text if not valid JSON
        return output_text
    except Exception as e:
        # Surface the failure in the UI textbox rather than crashing the Space
        return f"Error: {str(e)}"
# Create Gradio interface: two-column layout -- inputs (image, prompt
# selection, custom prompt, submit button) on the left, OCR result on the right.
with gr.Blocks(title="dots.ocr - Multilingual Document OCR") as demo:
    # Header / feature summary shown above the controls
    gr.Markdown("""
# ๐Ÿ” dots.ocr - Multilingual Document Layout Parsing
Upload a document image and get OCR results with layout detection.
This space uses the [dots.ocr](https://github.com/rednote-hilab/dots.ocr) model.
**Features:**
- Multilingual support
- Layout detection (tables, formulas, text, etc.)
- Reading order preservation
- Formula extraction (LaTeX format)
- Table extraction (HTML format)
""")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(
                type="pil",
                label="Upload Document Image",
                height=400
            )
            prompt_type = gr.Dropdown(
                choices=list(PROMPTS.keys()),
                value="Full Layout + OCR (English)",
                label="Prompt Type",
                info="Select the type of processing you want"
            )
            # Hidden unless "Custom" is selected in the dropdown (see
            # toggle_custom_prompt below)
            custom_prompt = gr.Textbox(
                label="Custom Prompt (used when 'Custom' is selected)",
                placeholder="Enter your custom prompt here...",
                lines=5,
                visible=False
            )
            submit_btn = gr.Button("Process Document", variant="primary", size="lg")
        with gr.Column():
            output_text = gr.Textbox(
                label="OCR Result",
                lines=25,
                show_copy_button=True
            )

    # Show/hide custom prompt based on selection
    def toggle_custom_prompt(choice):
        # Reveal the free-form textbox only when "Custom" is chosen
        return gr.update(visible=(choice == "Custom"))

    prompt_type.change(
        fn=toggle_custom_prompt,
        inputs=[prompt_type],
        outputs=[custom_prompt]
    )
    submit_btn.click(
        fn=process_image,
        inputs=[image_input, prompt_type, custom_prompt],
        outputs=[output_text]
    )

    # Examples (paths are relative to the Space's working directory)
    gr.Markdown("## ๐Ÿ“ Examples")
    gr.Examples(
        examples=[
            ["examples/example1.jpg", "Full Layout + OCR (English)", ""],
            ["examples/example2.jpg", "OCR Only", ""],
        ],
        inputs=[image_input, prompt_type, custom_prompt],
        outputs=[output_text],
        fn=process_image,
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()