prithivMLmods committed on
Commit
95d9832
·
verified ·
1 Parent(s): dd959ff

update app

Browse files
Files changed (1) hide show
  1. app.py +264 -0
app.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import spaces
4
+ from PIL import Image
5
+ from transformers import AutoProcessor, AutoModelForImageTextToText
6
+ from gradio.themes import Soft
7
+ from gradio.themes.utils import colors, fonts, sizes
8
+ from typing import Iterable
9
+
10
# Register a custom "orange_red" palette on gradio's color registry so the
# theme below can reference it as a hue (c50 = lightest ... c950 = darkest).
colors.orange_red = colors.Color(
    name="orange_red",
    c50="#FFF0E5",
    c100="#FFE0CC",
    c200="#FFC299",
    c300="#FFA366",
    c400="#FF8533",
    c500="#FF4500",
    c600="#E63E00",
    c700="#CC3700",
    c800="#B33000",
    c900="#992900",
    c950="#802200",
)
24
+
25
class OrangeRedTheme(Soft):
    """Soft-derived gradio theme using the custom ``orange_red`` palette.

    Keeps Soft's layout but swaps in a gray primary / orange-red secondary
    hue pair, the Outfit font, and gradient-styled primary buttons.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.orange_red,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        # Gathered into one dict so the full override table reads at a glance.
        overrides = dict(
            background_fill_primary="*primary_50",
            background_fill_primary_dark="*primary_900",
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
            button_primary_text_color="white",
            button_primary_text_color_hover="white",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
            slider_color="*secondary_500",
            block_title_text_weight="600",
            block_border_width="0px",
            block_shadow="*shadow_drop_lg",
            button_large_padding="12px 24px",
            color_accent_soft="*primary_100",
        )
        super().set(**overrides)


# Single shared theme instance used when building the Blocks app.
orange_red_theme = OrangeRedTheme()
68
+
69
# Hugging Face Hub repo id of the OCR checkpoint to load.
MODEL_PATH = "zai-org/GLM-OCR"

# Used only for the startup log line below; actual weight placement is
# handled by device_map="auto" at load time.
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading {MODEL_PATH} on {device}...")
74
+
75
# Load the processor outside the try: the fallback path below only retries
# the model load, so a processor failure inside the try would leave
# `processor` undefined and crash later with a NameError at inference time.
processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)

try:
    # First attempt: bf16 + flash-attention when a GPU is present.
    model = AutoModelForImageTextToText.from_pretrained(
        pretrained_model_name_or_path=MODEL_PATH,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
        trust_remote_code=True,
        attn_implementation="flash_attention_2" if torch.cuda.is_available() else "eager"
    )
except Exception as e:
    print(f"Error loading model: {e}")
    # Fallback for CPU/No-Flash-Attn environments if necessary:
    # let transformers pick the dtype and default attention implementation.
    model = AutoModelForImageTextToText.from_pretrained(
        pretrained_model_name_or_path=MODEL_PATH,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True
    )
93
+
94
class GlmOcr(gr.HTML):
    """
    Custom Header Component for the minimalistic UI.

    A display-only gr.HTML banner: gradient title, subtitle, and three
    feature badges. It takes no inputs and emits no events.
    """
    def __init__(self):
        # Static HTML; styling is inline so the banner renders the same
        # regardless of the page-level CSS.
        content = """
        <div style="text-align: center; margin-bottom: 2rem; padding: 2rem 1rem;">
            <h1 style="font-size: 3rem; font-weight: 800; margin: 0;
                background: linear-gradient(90deg, #FF4500, #E63E00);
                -webkit-background-clip: text; -webkit-text-fill-color: transparent;">
                GLM-OCR
            </h1>
            <p style="font-size: 1.2rem; margin-top: 0.5rem; opacity: 0.8; font-weight: 300;">
                High-precision Document, Formula, and Table Recognition
            </p>
            <div style="display: flex; justify-content: center; gap: 10px; margin-top: 15px;">
                <span style="background: rgba(255, 69, 0, 0.1); color: #E63E00; padding: 4px 12px; border-radius: 20px; font-size: 0.9rem; font-weight: 600;">Text</span>
                <span style="background: rgba(255, 69, 0, 0.1); color: #E63E00; padding: 4px 12px; border-radius: 20px; font-size: 0.9rem; font-weight: 600;">LaTeX Formulas</span>
                <span style="background: rgba(255, 69, 0, 0.1); color: #E63E00; padding: 4px 12px; border-radius: 20px; font-size: 0.9rem; font-weight: 600;">Tables</span>
            </div>
        </div>
        """
        super().__init__(value=content)
117
+
118
# UI dropdown label -> prompt prefix sent to the model. The value becomes the
# text part of the chat message built in run_ocr.
TASK_MAPPING = {
    "Text Parsing": "Text Recognition:",
    "Formula/LaTeX": "Formula Recognition:",
    "Table Extraction": "Table Recognition:"
}
123
+
124
@spaces.GPU
def run_ocr(image, task_key):
    """Run the OCR model on an uploaded image for the selected task.

    Args:
        image: PIL image from the UI, or None if nothing was uploaded.
        task_key: A TASK_MAPPING key; unknown keys fall back to plain text
            recognition.

    Returns:
        A (markdown, raw) pair — the same decoded string twice, feeding both
        the rendered-Markdown and raw-text output widgets.
    """
    if image is None:
        return None, "Please upload an image."

    prompt_text = TASK_MAPPING.get(task_key, "Text Recognition:")

    # Chat-style message: image first, then the task prompt.
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image,  # Passing PIL image directly
                },
                {
                    "type": "text",
                    "text": prompt_text
                }
            ],
        }
    ]

    # apply_chat_template with return_tensors="pt" handles image
    # preprocessing when the processor is multimodal-aware.
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device)

    # Some tokenizers emit token_type_ids that generate() does not accept.
    inputs.pop("token_type_ids", None)

    with torch.no_grad():
        # Greedy decoding (do_sample=False) keeps OCR output deterministic.
        # Note: no temperature here — it is ignored when sampling is off and
        # only triggers a transformers warning.
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=8192,
            do_sample=False
        )

    # Slice off the prompt tokens so only the newly generated text is decoded.
    output_text = processor.decode(
        generated_ids[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True
    )

    return output_text, output_text
178
+
179
# Page-level CSS: center the app at a max width and round/shadow the image
# preview container.
css = """
.gradio-container {
    max-width: 1200px !important;
    margin: 0 auto;
}
.image-container {
    border-radius: 12px;
    overflow: hidden;
    box-shadow: 0 4px 6px rgba(0,0,0,0.1);
}
"""
190
+
191
# Build the UI. NOTE: `theme` and `css` are gr.Blocks constructor parameters,
# not Blocks.launch() parameters — passing them to launch() (as the original
# did) raises a TypeError, so they are applied here instead.
with gr.Blocks(title="GLM-OCR", theme=orange_red_theme, css=css) as demo:

    # Custom Header
    GlmOcr()

    with gr.Row():
        # Left Column: Inputs
        with gr.Column(scale=1):
            with gr.Group():
                image_input = gr.Image(
                    type="pil",
                    label="Document Image",
                    elem_classes="image-container",
                    height=400
                )

            with gr.Row():
                task_select = gr.Dropdown(
                    choices=list(TASK_MAPPING.keys()),
                    value="Text Parsing",
                    label="Extraction Mode",
                    interactive=True,
                    scale=2
                )
                submit_btn = gr.Button(
                    "Process",
                    variant="primary",
                    scale=1,
                    size="lg"
                )

            with gr.Accordion("Tips", open=True):
                gr.Markdown("""
                - **Text Parsing**: Extracts all text and layout structure.
                - **Formula/LaTeX**: Optimized for scientific papers and math.
                - **Table Extraction**: Converts tables directly to Markdown/Structure.
                """)

        # Right Column: Outputs
        with gr.Column(scale=1):
            with gr.Tabs():
                with gr.Tab("Rendered Output"):
                    md_output = gr.Markdown(
                        label="Result",
                        value="_Output will appear here..._",
                        latex_delimiters=[
                            {"left": "$$", "right": "$$", "display": True},
                            {"left": "$", "right": "$", "display": False},
                            {"left": "\\(", "right": "\\)", "display": False},
                            {"left": "\\[", "right": "\\]", "display": True}
                        ]
                    )
                with gr.Tab("Raw Source"):
                    raw_output = gr.Textbox(
                        label="Raw Text/LaTeX",
                        lines=20,
                        show_copy_button=True,
                        interactive=False
                    )

    # Event Wiring: one click handler drives both output widgets.
    submit_btn.click(
        fn=run_ocr,
        inputs=[image_input, task_select],
        outputs=[md_output, raw_output]
    )

if __name__ == "__main__":
    # queue() enables request queuing (required for @spaces.GPU workloads);
    # theme/css were moved to the gr.Blocks() constructor above.
    demo.queue().launch(
        ssr_mode=False,
        show_error=True
    )