Commit
·
e6c24fd
1
Parent(s):
89faab5
Refactor app.py: Clean up unused prompts and reorganize imports for clarity
Browse files
app.py
CHANGED
|
@@ -1,30 +1,17 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
"""
|
| 3 |
-
Dots.OCR Gradio Demo Application
|
| 4 |
-
|
| 5 |
-
A Gradio-based web interface for demonstrating the Dots.OCR model using Hugging Face transformers.
|
| 6 |
-
This application provides OCR and layout analysis capabilities for documents and images.
|
| 7 |
-
"""
|
| 8 |
-
|
| 9 |
-
import os
|
| 10 |
import json
|
| 11 |
-
import traceback
|
| 12 |
import math
|
|
|
|
|
|
|
| 13 |
from io import BytesIO
|
| 14 |
-
from typing import
|
|
|
|
|
|
|
|
|
|
| 15 |
import requests
|
| 16 |
-
|
| 17 |
-
# Set LOCAL_RANK for transformers
|
| 18 |
-
if "LOCAL_RANK" not in os.environ:
|
| 19 |
-
os.environ["LOCAL_RANK"] = "0"
|
| 20 |
-
|
| 21 |
import torch
|
| 22 |
-
import gradio as gr
|
| 23 |
from PIL import Image, ImageDraw, ImageFont
|
| 24 |
-
from transformers import AutoModelForCausalLM, AutoProcessor
|
| 25 |
from qwen_vl_utils import process_vision_info
|
| 26 |
-
import
|
| 27 |
-
|
| 28 |
|
| 29 |
# Constants
|
| 30 |
MIN_PIXELS = 3136
|
|
@@ -32,8 +19,7 @@ MAX_PIXELS = 11289600
|
|
| 32 |
IMAGE_FACTOR = 28
|
| 33 |
|
| 34 |
# Prompts
|
| 35 |
-
|
| 36 |
-
"prompt_layout_all_en": """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
|
| 37 |
|
| 38 |
1. Bbox format: [x1, y1, x2, y2]
|
| 39 |
|
|
@@ -50,15 +36,7 @@ dict_promptmode_to_prompt = {
|
|
| 50 |
- All layout elements must be sorted according to human reading order.
|
| 51 |
|
| 52 |
5. Final Output: The entire output must be a single JSON object.
|
| 53 |
-
"""
|
| 54 |
-
|
| 55 |
-
"prompt_layout_only_en": """Please output the layout information from this PDF image, including each layout's bbox and its category. The bbox should be in the format [x1, y1, x2, y2]. The layout categories for the PDF document include ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title']. Do not output the corresponding text. The layout result should be in JSON format.""",
|
| 56 |
-
|
| 57 |
-
"prompt_ocr": """Extract the text content from this image.""",
|
| 58 |
-
|
| 59 |
-
"prompt_grounding_ocr": """Extract text from the given bounding box on the image (format: [x1, y1, x2, y2]).\nBounding Box:\n""",
|
| 60 |
-
}
|
| 61 |
-
|
| 62 |
|
| 63 |
# Utility functions
|
| 64 |
def round_by_factor(number: int, factor: int) -> int:
|
|
@@ -263,15 +241,21 @@ def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = '
|
|
| 263 |
|
| 264 |
# Initialize model and processor at script level
|
| 265 |
model_id = "rednote-hilab/dots.ocr"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
model = AutoModelForCausalLM.from_pretrained(
|
| 267 |
-
|
| 268 |
attn_implementation="flash_attention_2",
|
| 269 |
torch_dtype=torch.bfloat16,
|
| 270 |
device_map="auto",
|
| 271 |
trust_remote_code=True
|
| 272 |
)
|
| 273 |
processor = AutoProcessor.from_pretrained(
|
| 274 |
-
|
| 275 |
trust_remote_code=True
|
| 276 |
)
|
| 277 |
|
|
@@ -378,9 +362,6 @@ def process_image(
|
|
| 378 |
if min_pixels is not None or max_pixels is not None:
|
| 379 |
image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
|
| 380 |
|
| 381 |
-
# Get prompt
|
| 382 |
-
prompt = dict_promptmode_to_prompt[prompt_mode]
|
| 383 |
-
|
| 384 |
# Run inference
|
| 385 |
raw_output = inference(image, prompt)
|
| 386 |
|
|
@@ -640,15 +621,7 @@ def create_gradio_interface():
|
|
| 640 |
next_page_btn = gr.Button("Next ▶", size="sm")
|
| 641 |
|
| 642 |
gr.Markdown("### ⚙️ Settings")
|
| 643 |
-
|
| 644 |
-
# Prompt mode selection
|
| 645 |
-
prompt_mode = gr.Dropdown(
|
| 646 |
-
choices=list(dict_promptmode_to_prompt.keys()),
|
| 647 |
-
value="prompt_layout_all_en",
|
| 648 |
-
label="Task Mode",
|
| 649 |
-
info="Choose the type of analysis to perform"
|
| 650 |
-
)
|
| 651 |
-
|
| 652 |
# Advanced settings
|
| 653 |
with gr.Accordion("Advanced Settings", open=False):
|
| 654 |
max_new_tokens = gr.Slider(
|
|
@@ -721,16 +694,6 @@ def create_gradio_interface():
|
|
| 721 |
value=None
|
| 722 |
)
|
| 723 |
|
| 724 |
-
# Prompt display
|
| 725 |
-
gr.Markdown("### 💬 Current Prompt")
|
| 726 |
-
prompt_display = gr.Textbox(
|
| 727 |
-
value=dict_promptmode_to_prompt["prompt_layout_all_en"],
|
| 728 |
-
label="Prompt Text",
|
| 729 |
-
lines=8,
|
| 730 |
-
interactive=False,
|
| 731 |
-
info="This is the prompt that will be sent to the model"
|
| 732 |
-
)
|
| 733 |
-
|
| 734 |
# Event handlers
|
| 735 |
def load_model_on_startup():
|
| 736 |
"""Load model when the interface starts"""
|
|
@@ -839,8 +802,8 @@ def create_gradio_interface():
|
|
| 839 |
|
| 840 |
def update_prompt_display(mode):
|
| 841 |
"""Update the prompt display when mode changes"""
|
| 842 |
-
return
|
| 843 |
-
|
| 844 |
def handle_file_upload(file_path):
|
| 845 |
"""Handle file upload and show preview"""
|
| 846 |
if not file_path:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import json
|
|
|
|
| 2 |
import math
|
| 3 |
+
import os
|
| 4 |
+
import traceback
|
| 5 |
from io import BytesIO
|
| 6 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 7 |
+
from huggingface_hub import snapshot_download
|
| 8 |
+
import fitz # PyMuPDF
|
| 9 |
+
import gradio as gr
|
| 10 |
import requests
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
import torch
|
|
|
|
| 12 |
from PIL import Image, ImageDraw, ImageFont
|
|
|
|
| 13 |
from qwen_vl_utils import process_vision_info
|
| 14 |
+
from transformers import AutoModelForCausalLM, AutoProcessor
|
|
|
|
| 15 |
|
| 16 |
# Constants
|
| 17 |
MIN_PIXELS = 3136
|
|
|
|
| 19 |
IMAGE_FACTOR = 28
|
| 20 |
|
| 21 |
# Prompts
|
| 22 |
+
prompt = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
|
|
|
|
| 23 |
|
| 24 |
1. Bbox format: [x1, y1, x2, y2]
|
| 25 |
|
|
|
|
| 36 |
- All layout elements must be sorted according to human reading order.
|
| 37 |
|
| 38 |
5. Final Output: The entire output must be a single JSON object.
|
| 39 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
# Utility functions
|
| 42 |
def round_by_factor(number: int, factor: int) -> int:
|
|
|
|
| 241 |
|
| 242 |
# Initialize model and processor at script level
|
| 243 |
model_id = "rednote-hilab/dots.ocr"
|
| 244 |
+
model_path = "./models/dots-ocr-local"
|
| 245 |
+
snapshot_download(
|
| 246 |
+
repo_id=model_id,
|
| 247 |
+
local_dir=model_path,
|
| 248 |
+
local_dir_use_symlinks=False, # Recommended to set to False to avoid symlink issues
|
| 249 |
+
)
|
| 250 |
model = AutoModelForCausalLM.from_pretrained(
|
| 251 |
+
model_path,
|
| 252 |
attn_implementation="flash_attention_2",
|
| 253 |
torch_dtype=torch.bfloat16,
|
| 254 |
device_map="auto",
|
| 255 |
trust_remote_code=True
|
| 256 |
)
|
| 257 |
processor = AutoProcessor.from_pretrained(
|
| 258 |
+
model_path,
|
| 259 |
trust_remote_code=True
|
| 260 |
)
|
| 261 |
|
|
|
|
| 362 |
if min_pixels is not None or max_pixels is not None:
|
| 363 |
image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
|
| 364 |
|
|
|
|
|
|
|
|
|
|
| 365 |
# Run inference
|
| 366 |
raw_output = inference(image, prompt)
|
| 367 |
|
|
|
|
| 621 |
next_page_btn = gr.Button("Next ▶", size="sm")
|
| 622 |
|
| 623 |
gr.Markdown("### ⚙️ Settings")
|
| 624 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 625 |
# Advanced settings
|
| 626 |
with gr.Accordion("Advanced Settings", open=False):
|
| 627 |
max_new_tokens = gr.Slider(
|
|
|
|
| 694 |
value=None
|
| 695 |
)
|
| 696 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 697 |
# Event handlers
|
| 698 |
def load_model_on_startup():
|
| 699 |
"""Load model when the interface starts"""
|
|
|
|
| 802 |
|
| 803 |
def update_prompt_display(mode):
|
| 804 |
"""Update the prompt display when mode changes"""
|
| 805 |
+
return prompt
|
| 806 |
+
|
| 807 |
def handle_file_upload(file_path):
|
| 808 |
"""Handle file upload and show preview"""
|
| 809 |
if not file_path:
|