File size: 8,556 Bytes
7876f9c
 
21ac63b
 
4faf4e5
21ac63b
 
8a06d9f
21ac63b
4faf4e5
21ac63b
 
037bbf2
21ac63b
037bbf2
 
4faf4e5
 
037bbf2
4faf4e5
8a06d9f
21ac63b
8a06d9f
21ac63b
 
8a06d9f
 
21ac63b
037bbf2
700ddbf
 
 
 
 
 
 
 
 
 
21ac63b
 
700ddbf
21ac63b
4faf4e5
700ddbf
 
21ac63b
037bbf2
6767963
 
8a06d9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6767963
21ac63b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7876f9c
21ac63b
 
 
 
8a06d9f
 
21ac63b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a06d9f
21ac63b
 
 
 
 
8a06d9f
21ac63b
 
 
 
 
 
 
 
 
 
 
 
 
14d9fc6
21ac63b
 
 
 
 
 
 
 
 
8a06d9f
21ac63b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7876f9c
21ac63b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7876f9c
21ac63b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
import gradio as gr
import torch
import json
import spaces
import os
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.processing_utils import ProcessorMixin
from qwen_vl_utils import process_vision_info
from huggingface_hub import login

# Model configuration: Hugging Face Hub repo id of the dots.ocr checkpoint.
MODEL_PATH = "rednote-hilab/dots.ocr"

# Optional authentication (required if the repository is gated).
# Read from the HF_TOKEN environment variable; skipped when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    print("Authenticating with Hugging Face token...")
    login(token=HF_TOKEN, add_to_git_credential=False)

# Model and processor will be loaded on GPU when decorated function is called
# (lazy singletons populated by load_model() on first use).
model = None
processor = None

def load_model():
    """Lazily load the OCR model and processor (module-level singletons).

    Returns:
        tuple: ``(model, processor)``. Both are cached in module globals
        after the first call, so repeated calls are cheap no-ops.
    """
    global model, processor
    if model is None:
        print(f"Loading model weights from {MODEL_PATH}...")
        
        # Try to use FlashAttention2 if available, otherwise use default attention.
        # The import is only a presence probe; the module itself is not used here.
        try:
            import flash_attn
            attn_implementation = "flash_attention_2"
            print("Using FlashAttention2 for faster inference")
        except ImportError:
            attn_implementation = "eager"
            print("FlashAttention2 not available, using default attention")
        
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_PATH,
            dtype=torch.bfloat16,
            device_map="auto",  # let accelerate place weights on available devices
            trust_remote_code=True,  # dots.ocr ships custom modeling code
            token=HF_TOKEN,
            attn_implementation=attn_implementation
        )
        print("Model loaded successfully.")

        print(f"Loading processor from {MODEL_PATH}...")
        
        # Patch check_argument_for_proper_class to allow None for video_processor.
        # NOTE(review): works around the repo's processor config declaring a
        # video_processor slot that resolves to None, which the stock
        # validation rejects. Original method is restored in the finally block.
        _original_check = ProcessorMixin.check_argument_for_proper_class
        
        def _patched_check(self, attribute_name, value):
            if attribute_name == "video_processor" and value is None:
                return  # Skip validation for None video_processor
            return _original_check(self, attribute_name, value)
        
        ProcessorMixin.check_argument_for_proper_class = _patched_check
        
        try:
            processor = AutoProcessor.from_pretrained(
                MODEL_PATH,
                trust_remote_code=True,
                token=HF_TOKEN
            )
            print("Processor loaded successfully.")
        finally:
            # Restore original validation method even if loading fails,
            # so the patch never leaks to other ProcessorMixin users.
            ProcessorMixin.check_argument_for_proper_class = _original_check
    
    return model, processor

# Predefined prompts, keyed by the labels shown in the UI dropdown.
# "Custom" maps to an empty string; process_image() substitutes the
# user-typed prompt for it at call time.
PROMPTS = {
    "Full Layout + OCR (English)": """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.

1. Bbox format: [x1, y1, x2, y2]

2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].

3. Text Extraction & Formatting Rules:
    - Picture: For the 'Picture' category, the text field should be omitted.
    - Formula: Format its text as LaTeX.
    - Table: Format its text as HTML.
    - All Others (Text, Title, etc.): Format their text as Markdown.

4. Constraints:
    - The output text must be the original text from the image, with no translation.
    - All layout elements must be sorted according to human reading order.

5. Final Output: The entire output must be a single JSON object.""",

    "OCR Only": """Please extract all text from the image in reading order. Format the output as plain text, preserving the original structure as much as possible.""",
    
    "Layout Detection Only": """Please detect all layout elements in the image and output their bounding boxes and categories. Format: [{"bbox": [x1, y1, x2, y2], "category": "category_name"}]""",
    
    "Custom": ""
}

@spaces.GPU(duration=120)
def process_image(image, prompt_type, custom_prompt):
    """Run the dots.ocr model on a document image.

    Args:
        image: PIL image supplied by the Gradio image widget.
        prompt_type: Key into PROMPTS selecting the task.
        custom_prompt: Free-form prompt text; used only when
            prompt_type == "Custom" and the text is non-blank.

    Returns:
        str: Model output — pretty-printed JSON when the output parses
        as JSON, otherwise the raw decoded text; on any failure, an
        "Error: ..." string so the UI never crashes.
    """
    try:
        # Load (or reuse the cached) model and processor on the GPU worker.
        current_model, current_processor = load_model()

        # Pick the prompt; guard against a None or blank custom prompt.
        if prompt_type == "Custom" and custom_prompt and custom_prompt.strip():
            prompt = custom_prompt
        else:
            prompt = PROMPTS[prompt_type]

        # Chat-style message carrying the image plus the instruction text.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt}
                ]
            }
        ]

        # Render the chat template and gather the vision inputs.
        text = current_processor.apply_chat_template(
            messages, 
            tokenize=False, 
            add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = current_processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        
        inputs = inputs.to("cuda")
        
        # Generate output. do_sample=True is required for temperature/top_p
        # to take effect; without it transformers decodes greedily and
        # silently ignores both sampling parameters.
        with torch.no_grad():
            generated_ids = current_model.generate(
                **inputs, 
                max_new_tokens=24000,
                do_sample=True,
                temperature=0.1,
                top_p=0.9,
            )
        
        # Strip the prompt tokens so only newly generated text is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] 
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = current_processor.batch_decode(
            generated_ids_trimmed, 
            skip_special_tokens=True, 
            clean_up_tokenization_spaces=False
        )[0]
        
        # Pretty-print when the model returned valid JSON. Narrow catch:
        # a bare except here would also swallow KeyboardInterrupt/SystemExit.
        try:
            parsed_json = json.loads(output_text)
            output_text = json.dumps(parsed_json, ensure_ascii=False, indent=2)
        except json.JSONDecodeError:
            pass  # Keep as plain text if not valid JSON
        
        return output_text
    
    except Exception as e:
        # Top-level boundary for the Gradio callback: surface the error
        # as text in the output box instead of crashing the request.
        return f"Error: {str(e)}"

# Create Gradio interface: input column (image + prompt controls) on the
# left, OCR result textbox on the right.
with gr.Blocks(title="dots.ocr - Multilingual Document OCR") as demo:
    gr.Markdown("""
    # ๐Ÿ” dots.ocr - Multilingual Document Layout Parsing
    
    Upload a document image and get OCR results with layout detection.
    This space uses the [dots.ocr](https://github.com/rednote-hilab/dots.ocr) model.
    
    **Features:**
    - Multilingual support
    - Layout detection (tables, formulas, text, etc.)
    - Reading order preservation
    - Formula extraction (LaTeX format)
    - Table extraction (HTML format)
    """)
    
    with gr.Row():
        with gr.Column():
            # Document upload; type="pil" hands process_image a PIL image.
            image_input = gr.Image(
                type="pil", 
                label="Upload Document Image",
                height=400
            )
            
            # Task selector; choices mirror the PROMPTS dict keys.
            prompt_type = gr.Dropdown(
                choices=list(PROMPTS.keys()),
                value="Full Layout + OCR (English)",
                label="Prompt Type",
                info="Select the type of processing you want"
            )
            
            # Hidden unless "Custom" is selected (see toggle_custom_prompt).
            custom_prompt = gr.Textbox(
                label="Custom Prompt (used when 'Custom' is selected)",
                placeholder="Enter your custom prompt here...",
                lines=5,
                visible=False
            )
            
            submit_btn = gr.Button("Process Document", variant="primary", size="lg")
        
        with gr.Column():
            output_text = gr.Textbox(
                label="OCR Result", 
                lines=25,
                show_copy_button=True
            )
    
    # Show/hide custom prompt based on selection
    def toggle_custom_prompt(choice):
        # Only reveal the free-form textbox for the "Custom" prompt type.
        return gr.update(visible=(choice == "Custom"))
    
    prompt_type.change(
        fn=toggle_custom_prompt,
        inputs=[prompt_type],
        outputs=[custom_prompt]
    )
    
    submit_btn.click(
        fn=process_image,
        inputs=[image_input, prompt_type, custom_prompt],
        outputs=[output_text]
    )
    
    # Examples: bundled sample images; cache_examples=False avoids running
    # the GPU-decorated function at build time.
    gr.Markdown("## ๐Ÿ“ Examples")
    gr.Examples(
        examples=[
            ["examples/example1.jpg", "Full Layout + OCR (English)", ""],
            ["examples/example2.jpg", "OCR Only", ""],
        ],
        inputs=[image_input, prompt_type, custom_prompt],
        outputs=[output_text],
        fn=process_image,
        cache_examples=False,
    )

# Script entry point: start the Gradio server when run directly.
if __name__ == "__main__":
    demo.launch()