Sensei13k commited on
Commit
ecb5889
·
verified ·
1 Parent(s): 3db735f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +197 -517
app.py CHANGED
@@ -1,517 +1,197 @@
1
- import base64
2
- import os
3
- import re
4
- import shutil
5
- import time
6
- import uuid
7
- from pathlib import Path
8
-
9
- import cv2
10
- import gradio as gr
11
- import numpy as np
12
- import spaces
13
- import torch
14
- from globe import description, title
15
- from PIL import Image
16
- from render import render_ocr_text
17
-
18
- from transformers import AutoModelForImageTextToText, AutoProcessor
19
- from transformers.image_utils import load_image
20
-
21
# Hugging Face model id for the GOT-OCR 2.0 transformers port.
model_name = "stepfun-ai/GOT-OCR-2.0-hf"

# Prefer GPU when available; the model is moved to this device below.
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForImageTextToText.from_pretrained(
    model_name, low_cpu_mem_usage=True, device_map=device
)
# Inference only: eval mode (disables dropout etc.) and pin to `device`.
model = model.eval().to(device)

# Scratch folders for uploaded images and rendered HTML results.
UPLOAD_FOLDER = "./uploads"
RESULTS_FOLDER = "./results"
# Stop marker passed to model.generate via `stop_strings`.
stop_str = "<|im_end|>"
for folder in [UPLOAD_FOLDER, RESULTS_FOLDER]:
    if not os.path.exists(folder):
        os.makedirs(folder)

# NOTE(review): appears unused in this module — verify before removing.
input_index = 0
39
-
40
-
41
@spaces.GPU()
def process_image(image, task, ocr_type=None, ocr_box=None, ocr_color=None):
    """Run one GOT-OCR task over the given image(s).

    Args:
        image: a single image (filepath str / ndarray / ImageEditor dict) or
            a gallery list of such items.
        task: one of the task names offered by the UI dropdown.
        ocr_type: "ocr" or "format" (used by the fine-grained tasks).
        ocr_box: "[x1,y1,x2,y2]" box string for "Fine-grained OCR (Box)".
        ocr_color: "red"/"green"/"blue" for "Fine-grained OCR (Color)".

    Returns:
        (text, html_content_or_None, unique_id_or_None); on failure the text
        starts with "Error:" and the other two fields are None.
    """
    if image is None:
        return "Error: No image provided", None, None

    unique_id = str(uuid.uuid4())
    image_path = os.path.join(UPLOAD_FOLDER, f"{unique_id}.png")
    result_path = os.path.join(RESULTS_FOLDER, f"{unique_id}.html")

    def _generate(inputs):
        # Shared greedy-decode path used by every task branch (previously
        # copy-pasted six times).
        generate_ids = model.generate(
            **inputs,
            do_sample=False,
            tokenizer=processor.tokenizer,
            stop_strings=stop_str,
            max_new_tokens=4096,
        )
        return processor.decode(
            generate_ids[0, inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )

    try:
        # Normalize to a list (gallery items arrive as (path, caption) pairs).
        if not isinstance(image, (tuple, list)):
            image = [image]
        else:
            image = [img[0] for img in image]

        for i, img in enumerate(image):
            if isinstance(img, dict):
                # gr.ImageEditor output: use the flattened "composite" layer.
                composite_image = img.get("composite")
                if composite_image is None:
                    return (
                        "Error: No composite image found in ImageEditor output",
                        None,
                        None,
                    )
                if isinstance(composite_image, np.ndarray):
                    cv2.imwrite(
                        image_path, cv2.cvtColor(composite_image, cv2.COLOR_RGB2BGR)
                    )
                elif isinstance(composite_image, Image.Image):
                    composite_image.save(image_path)
                else:
                    return (
                        "Error: Unsupported image format from ImageEditor",
                        None,
                        None,
                    )
            elif isinstance(img, np.ndarray):
                cv2.imwrite(image_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
            elif isinstance(img, str):
                shutil.copy(img, image_path)
            else:
                return "Error: Unsupported image format", None, None

            image[i] = load_image(image_path)

        # BUGFIX: tensors were previously moved to the hard-coded "cuda"
        # device, which crashed on CPU-only hosts even though a `device`
        # variable is selected at startup; use `device` everywhere.
        if task == "Plain Text OCR":
            inputs = processor(image, return_tensors="pt").to(device)
            return _generate(inputs), None, unique_id

        if task == "Format Text OCR":
            inputs = processor(image, return_tensors="pt", format=True).to(device)
            res = _generate(inputs)
            ocr_type = "format"
        elif task == "Fine-grained OCR (Box)":
            inputs = processor(image, return_tensors="pt", box=ocr_box).to(device)
            res = _generate(inputs)
        elif task == "Fine-grained OCR (Color)":
            inputs = processor(image, return_tensors="pt", color=ocr_color).to(device)
            res = _generate(inputs)
        elif task == "Multi-crop OCR":
            inputs = processor(
                image,
                return_tensors="pt",
                format=True,
                crop_to_patches=True,
                max_patches=5,
            ).to(device)
            res = _generate(inputs)
            ocr_type = "format"
        elif task == "Multi-page OCR":
            inputs = processor(
                image, return_tensors="pt", multi_page=True, format=True
            ).to(device)
            res = _generate(inputs)
            ocr_type = "format"
        else:
            # Previously an unknown task fell through to an UnboundLocalError;
            # report it explicitly instead.
            return f"Error: Unknown task {task!r}", None, None

        render_ocr_text(res, result_path, format_text=ocr_type == "format")
        if os.path.exists(result_path):
            with open(result_path, "r") as f:
                html_content = f.read()
            return res, html_content, unique_id
        return res, None, unique_id
    except Exception as e:
        return f"Error: {str(e)}", None, None
    finally:
        # Always remove the temporary upload, success or failure.
        if os.path.exists(image_path):
            os.remove(image_path)
191
-
192
-
193
def update_image_input(task):
    """Toggle which image-input widgets (and their submit buttons) are shown.

    Returns visibility updates for, in order: image_input, image_editor,
    editor_submit_button, gallery_input, gallery_submit_button.
    """
    # Visibility flags per widget for the tasks that need a special input;
    # every other task uses the plain single-image input.
    special = {
        "Fine-grained OCR (Color)": (False, True, True, False, False),
        "Multi-page OCR": (False, False, False, True, True),
    }
    flags = special.get(task, (True, False, False, False, False))
    return tuple(gr.update(visible=flag) for flag in flags)
218
-
219
-
220
def update_inputs(task):
    """Show/hide the nine task-dependent controls for the selected task.

    Returns updates for, in order: ocr_type_dropdown, ocr_box_input,
    ocr_color_dropdown, image_input, image_editor, submit_button,
    editor_submit_button, gallery_input, gallery_submit_button.
    """
    if task == "Fine-grained OCR (Box)":
        return [
            gr.update(visible=True, choices=["ocr", "format"]),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
        ]
    if task == "Fine-grained OCR (Color)":
        return [
            gr.update(visible=True, choices=["ocr", "format"]),
            gr.update(visible=False),
            gr.update(visible=True, choices=["red", "green", "blue"]),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
        ]
    if task == "Multi-page OCR":
        return [
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=True),
        ]
    # Default layout for "Plain Text OCR", "Format Text OCR" and
    # "Multi-crop OCR".  BUGFIX: previously any task outside the four
    # explicit branches fell through and returned None, which breaks the
    # Gradio change handler; the default layout is now the fallback.
    return [
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=True),
        gr.update(visible=False),
        gr.update(visible=True),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
    ]
273
-
274
-
275
def parse_latex_output(res):
    """Wrap LaTeX-looking lines of *res* in ``$$ ... $$`` display-math fences.

    The text is first split on existing ``$$...$$`` spans (kept intact);
    any remaining line that contains a LaTeX marker character is folded
    into a ``$$``-delimited buffer, while plain lines pass through.
    """
    # Split the input, preserving existing display-math spans and newlines.
    lines = re.split(r"(\$\$.*?\$\$)", res, flags=re.DOTALL)
    parsed_lines = []
    in_latex = False
    latex_buffer = []

    for line in lines:
        if line == "\n":
            if in_latex:
                latex_buffer.append(line)
            else:
                parsed_lines.append(line)
            continue

        line = line.strip()

        # BUGFIX: the caret must be escaped — a bare r"^" is a zero-width
        # anchor that re.search matches at the start of *every* string, so
        # every line (including plain prose) was classified as LaTeX.
        latex_patterns = [r"\{", r"\}", r"\[", r"\]", r"\\", r"\$", r"_", r"\^", r'"']
        contains_latex = any(re.search(pattern, line) for pattern in latex_patterns)

        if contains_latex:
            if not in_latex:
                in_latex = True
                latex_buffer = ["$$"]
            latex_buffer.append(line)
        else:
            if in_latex:
                # Close the open math block before emitting the plain line.
                latex_buffer.append("$$")
                parsed_lines.extend(latex_buffer)
                in_latex = False
                latex_buffer = []
            parsed_lines.append(line)

    if in_latex:
        latex_buffer.append("$$")
        parsed_lines.extend(latex_buffer)

    return "$$\\$$\n".join(parsed_lines)
313
-
314
-
315
def ocr_demo(image, task, ocr_type, ocr_box, ocr_color):
    """Gradio callback: run OCR and build the (text, html) output pair."""
    res, html_content, unique_id = process_image(
        image, task, ocr_type, ocr_box, ocr_color
    )

    # Processing failures go straight to the text panel, no HTML.
    if isinstance(res, str) and res.startswith("Error:"):
        return res, None

    # Keep "\title" from gluing onto the following word in the output.
    formatted_res = res.replace("\\title", "\\title ")
    # formatted_res = parse_latex_output(res)

    if not html_content:
        return formatted_res, None

    # Embed the rendered HTML inline via a base64 data URI, plus a
    # download link for the full result file.
    encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
    data_uri = f"data:text/html;base64,{encoded_html}"
    iframe = f'<iframe src="{data_uri}" width="100%" height="600px"></iframe>'
    download_link = (
        f'<a href="{data_uri}" download="result_{unique_id}.html">'
        "Download Full Result</a>"
    )
    return formatted_res, f"{download_link}<br>{iframe}"
334
-
335
-
336
def cleanup_old_files():
    """Delete uploaded images and rendered results older than one hour."""
    cutoff = time.time() - 3600  # files last modified before this are stale
    for folder in (UPLOAD_FOLDER, RESULTS_FOLDER):
        for entry in Path(folder).glob("*"):
            if entry.stat().st_mtime < cutoff:
                entry.unlink()
342
-
343
-
344
# Gradio UI: multi-task layout with three alternative input widgets whose
# visibility is driven by the selected task.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                # Three alternative inputs; only one is visible at a time.
                image_input = gr.Image(type="filepath", label="Input Image")
                gallery_input = gr.Gallery(
                    type="filepath", label="Input images", visible=False
                )
                image_editor = gr.ImageEditor(
                    label="Image Editor", type="pil", visible=False
                )
                task_dropdown = gr.Dropdown(
                    choices=[
                        "Plain Text OCR",
                        "Format Text OCR",
                        "Fine-grained OCR (Box)",
                        "Fine-grained OCR (Color)",
                        "Multi-crop OCR",
                        "Multi-page OCR",
                    ],
                    label="Select Task",
                    value="Plain Text OCR",
                )
                # Task-specific controls, hidden until a task needs them.
                ocr_type_dropdown = gr.Dropdown(
                    choices=["ocr", "format"], label="OCR Type", visible=False
                )
                ocr_box_input = gr.Textbox(
                    label="OCR Box (x1,y1,x2,y2)",
                    placeholder="[100,100,200,200]",
                    visible=False,
                )
                ocr_color_dropdown = gr.Dropdown(
                    choices=["red", "green", "blue"], label="OCR Color", visible=False
                )
                # with gr.Row():
                #     max_new_tokens_slider = gr.Slider(50, 500, step=10, value=150, label="Max New Tokens")
                #     no_repeat_ngram_size_slider = gr.Slider(1, 10, step=1, value=2, label="No Repeat N-gram Size")

                # One submit button per input widget.
                submit_button = gr.Button("Process", variant="primary")
                editor_submit_button = gr.Button("Process Edited Image", visible=False, variant="primary")
                gallery_submit_button = gr.Button(
                    "Process Multiple Images", visible=False, variant="primary"
                )

        with gr.Column(scale=1):
            with gr.Group():
                output_markdown = gr.Textbox(label="Text output")
                output_html = gr.HTML(label="HTML output")

    # NOTE(review): appears unused — verify before removing.
    input_types = [
        image_input,
        image_editor,
        gallery_input,
    ]

    # Both change handlers fire on every task switch: update_inputs drives
    # the task-specific controls, update_image_input the input widgets.
    task_dropdown.change(
        update_inputs,
        inputs=[task_dropdown],
        outputs=[
            ocr_type_dropdown,
            ocr_box_input,
            ocr_color_dropdown,
            image_input,
            image_editor,
            submit_button,
            editor_submit_button,
            gallery_input,
            gallery_submit_button,
        ],
    )

    task_dropdown.change(
        update_image_input,
        inputs=[task_dropdown],
        outputs=[
            image_input,
            image_editor,
            editor_submit_button,
            gallery_input,
            gallery_submit_button,
        ],
    )

    # All three submit buttons call the same ocr_demo callback, differing
    # only in which input widget they read from.
    submit_button.click(
        ocr_demo,
        inputs=[
            image_input,
            task_dropdown,
            ocr_type_dropdown,
            ocr_box_input,
            ocr_color_dropdown,
        ],
        outputs=[output_markdown, output_html],
    )
    editor_submit_button.click(
        ocr_demo,
        inputs=[
            image_editor,
            task_dropdown,
            ocr_type_dropdown,
            ocr_box_input,
            ocr_color_dropdown,
        ],
        outputs=[output_markdown, output_html],
    )
    gallery_submit_button.click(
        ocr_demo,
        inputs=[
            gallery_input,
            task_dropdown,
            ocr_type_dropdown,
            ocr_box_input,
            ocr_color_dropdown,
        ],
        outputs=[output_markdown, output_html],
    )
    example = gr.Examples(
        examples=[
            [
                "./sheet_music.png",
                "Format Text OCR",
                "format",
                None,
                None,
            ],
            [
                "./latex.png",
                "Format Text OCR",
                "format",
                None,
                None,
            ],
        ],
        inputs=[
            image_input,
            task_dropdown,
            ocr_type_dropdown,
            ocr_box_input,
            ocr_color_dropdown,
        ],
        outputs=[output_markdown, output_html],
    )
    example_finegrained = gr.Examples(
        examples=[
            [
                "./multi_box.png",
                "Fine-grained OCR (Color)",
                "ocr",
                None,
                "red",
            ]
        ],
        inputs=[
            image_editor,
            task_dropdown,
            ocr_type_dropdown,
            ocr_box_input,
            ocr_color_dropdown,
        ],
        outputs=[output_markdown, output_html],
    )

    gr.Markdown(
        "Space based on [Tonic's GOT-OCR](https://huggingface.co/spaces/Tonic/GOT-OCR)"
    )
513
-
514
-
515
if __name__ == "__main__":
    # Purge stale uploads/results once at startup, then serve the UI.
    cleanup_old_files()
    demo.launch()
 
1
+ import os
2
+ import uuid
3
+ import shutil
4
+
5
+ import cv2
6
+ import gradio as gr
7
+ import numpy as np
8
+ import spaces
9
+ import torch
10
+ from PIL import Image
11
+
12
+ from transformers import AutoModelForImageTextToText, AutoProcessor
13
+ from transformers.image_utils import load_image
14
+
15
# Hugging Face model id for the GOT-OCR 2.0 transformers port.
model_name = "stepfun-ai/GOT-OCR-2.0-hf"
# Prefer GPU when available; the model is moved to this device below.
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForImageTextToText.from_pretrained(
    model_name, low_cpu_mem_usage=True, device_map=device
)
# Inference only: eval mode and pinned to the chosen device.
model = model.eval().to(device)

# Scratch folder for uploaded images; each upload is deleted after OCR.
UPLOAD_FOLDER = "./uploads"
# Stop marker passed to model.generate via `stop_strings`.
stop_str = "<|im_end|>"

if not os.path.exists(UPLOAD_FOLDER):
    os.makedirs(UPLOAD_FOLDER)
29
+
30
+
31
@spaces.GPU()
def process_ocr(image):
    """Run plain-text GOT-OCR on one uploaded image and return the text.

    Accepts an RGB ndarray or a filepath string; returns either the
    extracted text or a user-facing warning/error message.
    """
    if image is None:
        return "⚠️ Please upload an image first"

    unique_id = str(uuid.uuid4())
    image_path = os.path.join(UPLOAD_FOLDER, f"{unique_id}.png")

    try:
        # Persist the input as a PNG so load_image sees a uniform format.
        if isinstance(image, np.ndarray):
            cv2.imwrite(image_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
        elif isinstance(image, str):
            shutil.copy(image, image_path)
        else:
            return "⚠️ Unsupported image format"

        loaded = load_image(image_path)

        # Tokenize, generate greedily, then decode only the new tokens.
        model_inputs = processor([loaded], return_tensors="pt").to(device)
        prompt_len = model_inputs["input_ids"].shape[1]
        output_ids = model.generate(
            **model_inputs,
            do_sample=False,
            tokenizer=processor.tokenizer,
            stop_strings=stop_str,
            max_new_tokens=4096,
        )
        return processor.decode(
            output_ids[0, prompt_len:],
            skip_special_tokens=True,
        )

    except Exception as e:
        return f"❌ Error: {str(e)}"
    finally:
        # Always remove the temporary upload, success or failure.
        if os.path.exists(image_path):
            os.remove(image_path)
72
+
73
+
74
# Custom CSS for the single-task layout: gradient hero header, dashed
# upload box, gradient action button, and card-style input/output panels.
custom_css = """
#header {
    text-align: center;
    padding: 2rem 0;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
    border-radius: 12px;
    margin-bottom: 2rem;
}

#header h1 {
    margin: 0;
    font-size: 2.5rem;
    font-weight: 700;
    letter-spacing: -0.5px;
}

#header p {
    margin: 0.5rem 0 0 0;
    font-size: 1.1rem;
    opacity: 0.95;
}

.main-container {
    max-width: 1200px;
    margin: 0 auto;
}

#image_input {
    border: 2px dashed #667eea !important;
    border-radius: 12px !important;
    transition: all 0.3s ease;
}

#image_input:hover {
    border-color: #764ba2 !important;
    box-shadow: 0 4px 12px rgba(102, 126, 234, 0.15);
}

#process_btn {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    border: none !important;
    font-size: 1.1rem !important;
    font-weight: 600 !important;
    padding: 0.75rem 2rem !important;
    border-radius: 8px !important;
    transition: all 0.3s ease !important;
}

#process_btn:hover {
    transform: translateY(-2px);
    box-shadow: 0 6px 20px rgba(102, 126, 234, 0.3) !important;
}

#output_text {
    border-radius: 12px !important;
    font-family: 'Monaco', 'Courier New', monospace !important;
    font-size: 0.95rem !important;
    line-height: 1.6 !important;
}

.input-section, .output-section {
    background: white;
    padding: 1.5rem;
    border-radius: 12px;
    box-shadow: 0 2px 8px rgba(0,0,0,0.05);
}

footer {
    text-align: center;
    padding: 2rem 0;
    color: #666;
    font-size: 0.9rem;
}
"""
150
+
151
# Gradio UI: minimal two-column layout — upload on the left, extracted
# text on the right — styled by `custom_css`.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    # Gradient hero header (styled via the #header CSS rules).
    gr.HTML("""
    <div id="header">
        <h1>✨ GOT-OCR 2.0</h1>
        <p>Extract text from images with AI-powered OCR</p>
    </div>
    """)

    with gr.Row(elem_classes="main-container"):
        # Left card: image upload plus the action button.
        with gr.Column(scale=1, elem_classes="input-section"):
            image_input = gr.Image(
                type="filepath",
                label="📸 Upload Image",
                elem_id="image_input",
                height=400
            )
            process_btn = gr.Button(
                "🚀 Extract Text",
                elem_id="process_btn",
                size="lg"
            )

        # Right card: extracted text with a built-in copy button.
        with gr.Column(scale=1, elem_classes="output-section"):
            output_text = gr.Textbox(
                label="📝 Extracted Text",
                elem_id="output_text",
                lines=20,
                placeholder="Your extracted text will appear here...",
                show_copy_button=True
            )

    gr.HTML("""
    <footer>
        <p>Powered by GOT-OCR-2.0-hf | Built with Gradio</p>
    </footer>
    """)

    # Connect the button to the processing function
    process_btn.click(
        fn=process_ocr,
        inputs=[image_input],
        outputs=[output_text]
    )
194
+
195
+
196
if __name__ == "__main__":
    # Serve the UI; per-request temp files are removed inside process_ocr.
    demo.launch()