Irfaniiioo committed on
Commit
c8b6d1e
·
verified ·
1 Parent(s): d5b0f6f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +311 -0
app.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import ast
import base64
import contextlib
import os
import re
import shutil
import sys
import tempfile
import warnings
from io import BytesIO, StringIO

import fitz
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image, ImageDraw, ImageFont, ImageOps
from transformers import AutoModel, AutoTokenizer
16
+
17
+ MODEL_NAME = 'deepseek-ai/DeepSeek-OCR'
18
+
19
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
20
+ model = AutoModel.from_pretrained(MODEL_NAME, _attn_implementation='flash_attention_2', torch_dtype=torch.bfloat16, trust_remote_code=True, use_safetensors=True)
21
+ model = model.eval().cuda()
22
+
23
+ MODEL_CONFIGS = {
24
+ "Gundam": {"base_size": 1024, "image_size": 640, "crop_mode": True},
25
+ "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
26
+ "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
27
+ "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
28
+ "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False}
29
+ }
30
+
31
+ TASK_PROMPTS = {
32
+ "πŸ“‹ Markdown": {"prompt": "<image>\n<|grounding|>Convert the document to markdown.", "has_grounding": True},
33
+ "πŸ“ Free OCR": {"prompt": "<image>\nFree OCR.", "has_grounding": False},
34
+ "πŸ“ Locate": {"prompt": "<image>\nLocate <|ref|>text<|/ref|> in the image.", "has_grounding": True},
35
+ "πŸ” Describe": {"prompt": "<image>\nDescribe this image in detail.", "has_grounding": False},
36
+ "✏️ Custom": {"prompt": "", "has_grounding": False}
37
+ }
38
+
39
+ def extract_grounding_references(text):
40
+ pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
41
+ return re.findall(pattern, text, re.DOTALL)
42
+
43
+ def draw_bounding_boxes(image, refs, extract_images=False):
44
+ img_w, img_h = image.size
45
+ img_draw = image.copy()
46
+ draw = ImageDraw.Draw(img_draw)
47
+ overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
48
+ draw2 = ImageDraw.Draw(overlay)
49
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 30)
50
+ crops = []
51
+
52
+ color_map = {}
53
+ np.random.seed(42)
54
+
55
+ for ref in refs:
56
+ label = ref[1]
57
+ if label not in color_map:
58
+ color_map[label] = (np.random.randint(50, 255), np.random.randint(50, 255), np.random.randint(50, 255))
59
+
60
+ color = color_map[label]
61
+ coords = eval(ref[2])
62
+ color_a = color + (60,)
63
+
64
+ for box in coords:
65
+ x1, y1, x2, y2 = int(box[0]/999*img_w), int(box[1]/999*img_h), int(box[2]/999*img_w), int(box[3]/999*img_h)
66
+
67
+ if extract_images and label == 'image':
68
+ crops.append(image.crop((x1, y1, x2, y2)))
69
+
70
+ width = 5 if label == 'title' else 3
71
+ draw.rectangle([x1, y1, x2, y2], outline=color, width=width)
72
+ draw2.rectangle([x1, y1, x2, y2], fill=color_a)
73
+
74
+ text_bbox = draw.textbbox((0, 0), label, font=font)
75
+ tw, th = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1]
76
+ ty = max(0, y1 - 20)
77
+ draw.rectangle([x1, ty, x1 + tw + 4, ty + th + 4], fill=color)
78
+ draw.text((x1 + 2, ty + 2), label, font=font, fill=(255, 255, 255))
79
+
80
+ img_draw.paste(overlay, (0, 0), overlay)
81
+ return img_draw, crops
82
+
83
+ def clean_output(text, include_images=False):
84
+ if not text:
85
+ return ""
86
+ pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
87
+ matches = re.findall(pattern, text, re.DOTALL)
88
+ img_num = 0
89
+
90
+ for match in matches:
91
+ if '<|ref|>image<|/ref|>' in match[0]:
92
+ if include_images:
93
+ text = text.replace(match[0], f'\n\n**[Figure {img_num + 1}]**\n\n', 1)
94
+ img_num += 1
95
+ else:
96
+ text = text.replace(match[0], '', 1)
97
+ else:
98
+ text = re.sub(rf'(?m)^[^\n]*{re.escape(match[0])}[^\n]*\n?', '', text)
99
+
100
+ return text.strip()
101
+
102
+ def embed_images(markdown, crops):
103
+ if not crops:
104
+ return markdown
105
+ for i, img in enumerate(crops):
106
+ buf = BytesIO()
107
+ img.save(buf, format="PNG")
108
+ b64 = base64.b64encode(buf.getvalue()).decode()
109
+ markdown = markdown.replace(f'**[Figure {i + 1}]**', f'\n\n![Figure {i + 1}](data:image/png;base64,{b64})\n\n', 1)
110
+ return markdown
111
+
112
+ @spaces.GPU(duration=60)
113
+ def process_image(image, mode, task, custom_prompt):
114
+ if image is None:
115
+ return " Error Upload image", "", "", None, []
116
+ if task in ["✏️ Custom", "πŸ“ Locate"] and not custom_prompt.strip():
117
+ return "Enter prompt", "", "", None, []
118
+
119
+ if image.mode in ('RGBA', 'LA', 'P'):
120
+ image = image.convert('RGB')
121
+ image = ImageOps.exif_transpose(image)
122
+
123
+ config = MODEL_CONFIGS[mode]
124
+
125
+ if task == "✏️ Custom":
126
+ prompt = f"<image>\n{custom_prompt.strip()}"
127
+ has_grounding = '<|grounding|>' in custom_prompt
128
+ elif task == "πŸ“ Locate":
129
+ prompt = f"<image>\nLocate <|ref|>{custom_prompt.strip()}<|/ref|> in the image."
130
+ has_grounding = True
131
+ else:
132
+ prompt = TASK_PROMPTS[task]["prompt"]
133
+ has_grounding = TASK_PROMPTS[task]["has_grounding"]
134
+
135
+ tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
136
+ image.save(tmp.name, 'JPEG', quality=95)
137
+ tmp.close()
138
+ out_dir = tempfile.mkdtemp()
139
+
140
+ stdout = sys.stdout
141
+ sys.stdout = StringIO()
142
+
143
+ model.infer(tokenizer=tokenizer, prompt=prompt, image_file=tmp.name, output_path=out_dir,
144
+ base_size=config["base_size"], image_size=config["image_size"], crop_mode=config["crop_mode"])
145
+
146
+ result = '\n'.join([l for l in sys.stdout.getvalue().split('\n')
147
+ if not any(s in l for s in ['image:', 'other:', 'PATCHES', '====', 'BASE:', '%|', 'torch.Size'])]).strip()
148
+ sys.stdout = stdout
149
+
150
+ os.unlink(tmp.name)
151
+ shutil.rmtree(out_dir, ignore_errors=True)
152
+
153
+ if not result:
154
+ return "No text", "", "", None, []
155
+
156
+ cleaned = clean_output(result, False)
157
+ markdown = clean_output(result, True)
158
+
159
+ img_out = None
160
+ crops = []
161
+
162
+ if has_grounding and '<|ref|>' in result:
163
+ refs = extract_grounding_references(result)
164
+ if refs:
165
+ img_out, crops = draw_bounding_boxes(image, refs, True)
166
+
167
+ markdown = embed_images(markdown, crops)
168
+
169
+ return cleaned, markdown, result, img_out, crops
170
+
171
+ @spaces.GPU(duration=60)
172
+ def process_pdf(path, mode, task, custom_prompt, page_num):
173
+ doc = fitz.open(path)
174
+ total_pages = len(doc)
175
+ if page_num < 1 or page_num > total_pages:
176
+ doc.close()
177
+ return f"Invalid page number. PDF has {total_pages} pages.", "", "", None, []
178
+ page = doc.load_page(page_num - 1)
179
+ pix = page.get_pixmap(matrix=fitz.Matrix(300/72, 300/72), alpha=False)
180
+ img = Image.open(BytesIO(pix.tobytes("png")))
181
+ doc.close()
182
+
183
+ return process_image(img, mode, task, custom_prompt)
184
+
185
+ def process_file(path, mode, task, custom_prompt, page_num):
186
+ if not path:
187
+ return "Error Upload file", "", "", None, []
188
+ if path.lower().endswith('.pdf'):
189
+ return process_pdf(path, mode, task, custom_prompt, page_num)
190
+ else:
191
+ return process_image(Image.open(path), mode, task, custom_prompt)
192
+
193
+ def toggle_prompt(task):
194
+ if task == "✏️ Custom":
195
+ return gr.update(visible=True, label="Custom Prompt", placeholder="Add <|grounding|> for boxes")
196
+ elif task == "πŸ“ Locate":
197
+ return gr.update(visible=True, label="Text to Locate", placeholder="Enter text")
198
+ return gr.update(visible=False)
199
+
200
+ def select_boxes(task):
201
+ if task == "πŸ“ Locate":
202
+ return gr.update(selected="tab_boxes")
203
+ return gr.update()
204
+
205
+ def get_pdf_page_count(file_path):
206
+ if not file_path or not file_path.lower().endswith('.pdf'):
207
+ return 1
208
+ doc = fitz.open(file_path)
209
+ count = len(doc)
210
+ doc.close()
211
+ return count
212
+
213
+ def load_image(file_path, page_num=1):
214
+ if not file_path:
215
+ return None
216
+ if file_path.lower().endswith('.pdf'):
217
+ doc = fitz.open(file_path)
218
+ page_idx = max(0, min(int(page_num) - 1, len(doc) - 1))
219
+ page = doc.load_page(page_idx)
220
+ pix = page.get_pixmap(matrix=fitz.Matrix(300/72, 300/72), alpha=False)
221
+ img = Image.open(BytesIO(pix.tobytes("png")))
222
+ doc.close()
223
+ return img
224
+ else:
225
+ return Image.open(file_path)
226
+
227
+ def update_page_selector(file_path):
228
+ if not file_path:
229
+ return gr.update(visible=False)
230
+ if file_path.lower().endswith('.pdf'):
231
+ page_count = get_pdf_page_count(file_path)
232
+ return gr.update(visible=True, maximum=page_count, value=1, minimum=1,
233
+ label=f"Select Page (1-{page_count})")
234
+ return gr.update(visible=False)
235
+
236
+ with gr.Blocks(theme=gr.themes.Soft(), title="DeepSeek-OCR") as demo:
237
+ gr.Markdown("""
238
+ # πŸš€ DeepSeek-OCR Demo
239
+ **Convert documents to markdown, extract raw text, and locate specific content with bounding boxes. It takes 20~ sec for markdown and 3~ sec for locate task examples. Check the info at the bottom of the page for more information.**
240
+
241
+ **Hope this tool was helpful! If so, a quick like ❀️ would mean a lot :)**
242
+ """)
243
+
244
+ with gr.Row():
245
+ with gr.Column(scale=1):
246
+ file_in = gr.File(label="Upload Image or PDF", file_types=["image", ".pdf"], type="filepath")
247
+ input_img = gr.Image(label="Input Image", type="pil", height=300)
248
+ page_selector = gr.Number(label="Select Page", value=1, minimum=1, step=1, visible=False)
249
+ mode = gr.Dropdown(list(MODEL_CONFIGS.keys()), value="Gundam", label="Mode")
250
+ task = gr.Dropdown(list(TASK_PROMPTS.keys()), value="πŸ“‹ Markdown", label="Task")
251
+ prompt = gr.Textbox(label="Prompt", lines=2, visible=False)
252
+ btn = gr.Button("Extract", variant="primary", size="lg")
253
+
254
+ with gr.Column(scale=2):
255
+ with gr.Tabs() as tabs:
256
+ with gr.Tab("Text", id="tab_text"):
257
+ text_out = gr.Textbox(lines=20, show_copy_button=True, show_label=False)
258
+ with gr.Tab("Markdown Preview", id="tab_markdown"):
259
+ md_out = gr.Markdown("")
260
+ with gr.Tab("Boxes", id="tab_boxes"):
261
+ img_out = gr.Image(type="pil", height=500, show_label=False)
262
+ with gr.Tab("Cropped Images", id="tab_crops"):
263
+ gallery = gr.Gallery(show_label=False, columns=3, height=400)
264
+ with gr.Tab("Raw Text", id="tab_raw"):
265
+ raw_out = gr.Textbox(lines=20, show_copy_button=True, show_label=False)
266
+
267
+ gr.Examples(
268
+ examples=[
269
+ ["examples/ocr.jpg", "Gundam", "πŸ“‹ Markdown", ""],
270
+ ["examples/reachy-mini.jpg", "Gundam", "πŸ“ Locate", "Robot"]
271
+ ],
272
+ inputs=[input_img, mode, task, prompt],
273
+ cache_examples=False
274
+ )
275
+
276
+ with gr.Accordion("ℹ️ Info", open=False):
277
+ gr.Markdown("""
278
+ ### Modes
279
+ - **Gundam**: 1024 base + 640 tiles with cropping - Best balance
280
+ - **Tiny**: 512Γ—512, no crop - Fastest
281
+ - **Small**: 640Γ—640, no crop - Quick
282
+ - **Base**: 1024Γ—1024, no crop - Standard
283
+ - **Large**: 1280Γ—1280, no crop - Highest quality
284
+
285
+ ### Tasks
286
+ - **Markdown**: Convert document to structured markdown (grounding βœ…)
287
+ - **Free OCR**: Simple text extraction
288
+ - **Locate**: Find specific things in image (grounding βœ…)
289
+ - **Describe**: General image description
290
+ - **Custom**: Your own prompt (add `<|grounding|>` for boxes)
291
+ """)
292
+
293
+ file_in.change(load_image, [file_in, page_selector], [input_img])
294
+ file_in.change(update_page_selector, [file_in], [page_selector])
295
+ page_selector.change(load_image, [file_in, page_selector], [input_img])
296
+ task.change(toggle_prompt, [task], [prompt])
297
+ task.change(select_boxes, [task], [tabs])
298
+
299
+ def run(image, file_path, mode, task, custom_prompt, page_num):
300
+ if file_path:
301
+ return process_file(file_path, mode, task, custom_prompt, int(page_num))
302
+ if image is not None:
303
+ return process_image(image, mode, task, custom_prompt)
304
+ return "Error uploading file or image", "", "", None, []
305
+
306
+ submit_event = btn.click(run, [input_img, file_in, mode, task, prompt, page_selector],
307
+ [text_out, md_out, raw_out, img_out, gallery])
308
+ submit_event.then(select_boxes, [task], [tabs])
309
+
310
+ if __name__ == "__main__":
311
+ demo.queue(max_size=20).launch()