phamhieu2001 commited on
Commit
68ea9d7
·
verified ·
1 Parent(s): 1e22c1a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +298 -145
app.py CHANGED
@@ -1,154 +1,307 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
2
  import numpy as np
3
- import random
 
4
 
5
- # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
7
- import torch
 
 
 
 
 
 
 
 
 
 
8
 
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
-
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
- else:
15
- torch_dtype = torch.float32
16
-
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
19
-
20
- MAX_SEED = np.iinfo(np.int32).max
21
- MAX_IMAGE_SIZE = 1024
22
-
23
-
24
- # @spaces.GPU #[uncomment to use ZeroGPU]
25
- def infer(
26
- prompt,
27
- negative_prompt,
28
- seed,
29
- randomize_seed,
30
- width,
31
- height,
32
- guidance_scale,
33
- num_inference_steps,
34
- progress=gr.Progress(track_tqdm=True),
35
- ):
36
- if randomize_seed:
37
- seed = random.randint(0, MAX_SEED)
38
-
39
- generator = torch.Generator().manual_seed(seed)
40
-
41
- image = pipe(
42
- prompt=prompt,
43
- negative_prompt=negative_prompt,
44
- guidance_scale=guidance_scale,
45
- num_inference_steps=num_inference_steps,
46
- width=width,
47
- height=height,
48
- generator=generator,
49
- ).images[0]
50
-
51
- return image, seed
52
-
53
-
54
- examples = [
55
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
- "An astronaut riding a green horse",
57
- "A delicious ceviche cheesecake slice",
58
- ]
59
-
60
- css = """
61
- #col-container {
62
- margin: 0 auto;
63
- max-width: 640px;
64
  }
65
- """
66
-
67
- with gr.Blocks(css=css) as demo:
68
- with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
70
-
71
- with gr.Row():
72
- prompt = gr.Text(
73
- label="Prompt",
74
- show_label=False,
75
- max_lines=1,
76
- placeholder="Enter your prompt",
77
- container=False,
78
- )
79
-
80
- run_button = gr.Button("Run", scale=0, variant="primary")
81
-
82
- result = gr.Image(label="Result", show_label=False)
83
-
84
- with gr.Accordion("Advanced Settings", open=False):
85
- negative_prompt = gr.Text(
86
- label="Negative prompt",
87
- max_lines=1,
88
- placeholder="Enter a negative prompt",
89
- visible=False,
90
- )
91
-
92
- seed = gr.Slider(
93
- label="Seed",
94
- minimum=0,
95
- maximum=MAX_SEED,
96
- step=1,
97
- value=0,
98
- )
99
-
100
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
101
-
102
- with gr.Row():
103
- width = gr.Slider(
104
- label="Width",
105
- minimum=256,
106
- maximum=MAX_IMAGE_SIZE,
107
- step=32,
108
- value=1024, # Replace with defaults that work for your model
109
- )
110
-
111
- height = gr.Slider(
112
- label="Height",
113
- minimum=256,
114
- maximum=MAX_IMAGE_SIZE,
115
- step=32,
116
- value=1024, # Replace with defaults that work for your model
117
- )
118
-
119
- with gr.Row():
120
- guidance_scale = gr.Slider(
121
- label="Guidance scale",
122
- minimum=0.0,
123
- maximum=10.0,
124
- step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
126
- )
127
-
128
- num_inference_steps = gr.Slider(
129
- label="Number of inference steps",
130
- minimum=1,
131
- maximum=50,
132
- step=1,
133
- value=2, # Replace with defaults that work for your model
134
- )
135
-
136
- gr.Examples(examples=examples, inputs=[prompt])
137
- gr.on(
138
- triggers=[run_button.click, prompt.submit],
139
- fn=infer,
140
- inputs=[
141
- prompt,
142
- negative_prompt,
143
- seed,
144
- randomize_seed,
145
- width,
146
- height,
147
- guidance_scale,
148
- num_inference_steps,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  ],
150
- outputs=[result, seed],
 
151
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
  if __name__ == "__main__":
154
- demo.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoModel, AutoTokenizer
3
+ import torch
4
+ import spaces
5
+ import os
6
+ import sys
7
+ import tempfile
8
+ import shutil
9
+ from PIL import Image, ImageDraw, ImageFont, ImageOps
10
+ import fitz
11
+ import re
12
+ import warnings
13
  import numpy as np
14
+ import base64
15
+ from io import StringIO, BytesIO
16
 
17
# Hugging Face Hub id of the OCR model; weights are fetched on first run.
MODEL_NAME = 'deepseek-ai/DeepSeek-OCR'

# trust_remote_code is required: DeepSeek-OCR ships its own modelling code,
# including the custom `infer` entry point called in process_image below.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_NAME, _attn_implementation='flash_attention_2', torch_dtype=torch.bfloat16, trust_remote_code=True, use_safetensors=True)
model = model.eval().cuda()  # inference only; requires a CUDA device

# Resolution presets passed straight to model.infer().
# NOTE(review): presumably base_size is the full-canvas resolution and
# image_size the per-tile resolution when crop_mode is on — confirm against
# the DeepSeek-OCR model card.
MODEL_CONFIGS = {
    "Gundam": {"base_size": 1024, "image_size": 640, "crop_mode": True},
    "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
    "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
    "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
    "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False}
}

# Prompt template per UI task. `has_grounding` marks tasks whose output is
# expected to contain <|ref|>/<|det|> tags that can be drawn as boxes.
TASK_PROMPTS = {
    "📋 Markdown": {"prompt": "<image>\n<|grounding|>Convert the document to markdown.", "has_grounding": True},
    "📝 Free OCR": {"prompt": "<image>\nFree OCR.", "has_grounding": False},
    "📍 Locate": {"prompt": "<image>\nLocate <|ref|>text<|/ref|> in the image.", "has_grounding": True},
    "🔍 Describe": {"prompt": "<image>\nDescribe this image in detail.", "has_grounding": False},
    "✏️ Custom": {"prompt": "", "has_grounding": False}
}
38
+
39
def extract_grounding_references(text):
    """Collect every grounding annotation in *text*.

    Returns a list of (full_match, label, coords_literal) tuples, one per
    <|ref|>label<|/ref|><|det|>coords<|/det|> span.
    """
    grounding_re = re.compile(r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)', re.DOTALL)
    return grounding_re.findall(text)
42
+
43
def draw_bounding_boxes(image, refs, extract_images=False):
    """Draw grounding boxes and labels on a copy of *image*.

    Args:
        image: source PIL image (left untouched).
        refs: tuples from extract_grounding_references().
        extract_images: when True, also crop every region labelled 'image'.

    Returns:
        (annotated_image, crops) where crops is a list of PIL images.
    """
    import ast  # local import: only this function parses coordinate literals

    img_w, img_h = image.size
    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)
    overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
    draw2 = ImageDraw.Draw(overlay)
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 30)
    except OSError:
        # The DejaVu path is distro-specific; never let a missing font
        # crash the whole request.
        font = ImageFont.load_default()
    crops = []

    color_map = {}
    np.random.seed(42)  # fixed seed -> stable label colours across runs

    for ref in refs:
        label = ref[1]
        if label not in color_map:
            color_map[label] = (np.random.randint(50, 255), np.random.randint(50, 255), np.random.randint(50, 255))

        color = color_map[label]
        # SECURITY: the coordinate text comes from model output — use
        # literal_eval, never eval(), and skip anything malformed.
        try:
            coords = ast.literal_eval(ref[2])
        except (ValueError, SyntaxError):
            continue
        color_a = color + (60,)  # translucent fill for the overlay

        for box in coords:
            # Model coordinates are normalised to a 0-999 grid.
            x1, y1, x2, y2 = int(box[0]/999*img_w), int(box[1]/999*img_h), int(box[2]/999*img_w), int(box[3]/999*img_h)

            if extract_images and label == 'image':
                crops.append(image.crop((x1, y1, x2, y2)))

            width = 5 if label == 'title' else 3  # titles get a thicker frame
            draw.rectangle([x1, y1, x2, y2], outline=color, width=width)
            draw2.rectangle([x1, y1, x2, y2], fill=color_a)

            # Label tag just above the box (clamped to the top edge).
            text_bbox = draw.textbbox((0, 0), label, font=font)
            tw, th = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1]
            ty = max(0, y1 - 20)
            draw.rectangle([x1, ty, x1 + tw + 4, ty + th + 4], fill=color)
            draw.text((x1 + 2, ty + 2), label, font=font, fill=(255, 255, 255))

    img_draw.paste(overlay, (0, 0), overlay)
    return img_draw, crops
82
+
83
def clean_output(text, include_images=False, remove_labels=False):
    """Strip DeepSeek grounding tags from *text*.

    Image references become numbered "**[Figure N]**" placeholders when
    include_images is True (otherwise they are dropped).  Text references
    are replaced by their label, or removed entirely when remove_labels
    is True.  Returns the stripped result.
    """
    if not text:
        return ""
    tag_re = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
    figure_count = 0
    for whole, label, _coords in re.findall(tag_re, text, re.DOTALL):
        if '<|ref|>image<|/ref|>' in whole:
            if include_images:
                figure_count += 1
                replacement = f'\n\n**[Figure {figure_count}]**\n\n'
            else:
                replacement = ''
        else:
            replacement = '' if remove_labels else label
        text = text.replace(whole, replacement, 1)
    return text.strip()
104
+
105
def embed_images(markdown, crops):
    """Replace "**[Figure N]**" placeholders with inline base64 PNG images.

    Args:
        markdown: text produced by clean_output(include_images=True).
        crops: PIL images in figure order; may be empty.
    """
    if not crops:
        return markdown
    for idx, crop in enumerate(crops, start=1):
        buffer = BytesIO()
        crop.save(buffer, format="PNG")
        encoded = base64.b64encode(buffer.getvalue()).decode()
        markdown = markdown.replace(
            f'**[Figure {idx}]**',
            f'\n\n![Figure {idx}](data:image/png;base64,{encoded})\n\n',
            1,
        )
    return markdown
114
+
115
@spaces.GPU(duration=60)
def process_image(image, mode, task, custom_prompt):
    """Run DeepSeek-OCR on one PIL image.

    Args:
        image: PIL image, or None (returns an error tuple).
        mode: key into MODEL_CONFIGS selecting the resolution preset.
        task: key into TASK_PROMPTS, or the Custom/Locate pseudo-tasks.
        custom_prompt: user text used by the Custom and Locate tasks.

    Returns:
        (cleaned_text, markdown, raw_model_output, annotated_image_or_None, crops)
    """
    if image is None:
        return " Error Upload image", "", "", None, []
    if task in ["✏️ Custom", "📍 Locate"] and not custom_prompt.strip():
        return "Enter prompt", "", "", None, []

    # Normalise to RGB and honour EXIF rotation before saving as JPEG.
    if image.mode in ('RGBA', 'LA', 'P'):
        image = image.convert('RGB')
    image = ImageOps.exif_transpose(image)

    config = MODEL_CONFIGS[mode]

    if task == "✏️ Custom":
        prompt = f"<image>\n{custom_prompt.strip()}"
        has_grounding = '<|grounding|>' in custom_prompt
    elif task == "📍 Locate":
        prompt = f"<image>\nLocate <|ref|>{custom_prompt.strip()}<|/ref|> in the image."
        has_grounding = True
    else:
        prompt = TASK_PROMPTS[task]["prompt"]
        has_grounding = TASK_PROMPTS[task]["has_grounding"]

    # model.infer() wants a file path, so stage the image on disk.
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
    image.save(tmp.name, 'JPEG', quality=95)
    tmp.close()
    out_dir = tempfile.mkdtemp()

    # The model prints its result to stdout; capture it.
    stdout = sys.stdout
    sys.stdout = StringIO()
    try:
        model.infer(tokenizer=tokenizer, prompt=prompt, image_file=tmp.name, output_path=out_dir,
                    base_size=config["base_size"], image_size=config["image_size"], crop_mode=config["crop_mode"])
        captured = sys.stdout.getvalue()
    finally:
        # FIX: always restore stdout and remove scratch files, even when
        # infer() raises — previously an exception left stdout hijacked
        # and leaked the temp JPEG and output directory.
        sys.stdout = stdout
        os.unlink(tmp.name)
        shutil.rmtree(out_dir, ignore_errors=True)

    # Drop the model's progress/debug chatter, keep the payload lines.
    result = '\n'.join([line for line in captured.split('\n')
                        if not any(s in line for s in ['image:', 'other:', 'PATCHES', '====', 'BASE:', '%|', 'torch.Size'])]).strip()

    if not result:
        return "No text", "", "", None, []

    cleaned = clean_output(result, False, False)
    markdown = clean_output(result, True, True)

    img_out = None
    crops = []

    # Only attempt box drawing when the prompt could produce grounding tags.
    if has_grounding and '<|ref|>' in result:
        refs = extract_grounding_references(result)
        if refs:
            img_out, crops = draw_bounding_boxes(image, refs, True)

    markdown = embed_images(markdown, crops)

    return cleaned, markdown, result, img_out, crops
173
+
174
@spaces.GPU(duration=60)
def process_pdf(path, mode, task, custom_prompt, page_num):
    """Render one page of the PDF at *path* to an image and OCR it.

    page_num is 1-based; an out-of-range value returns an error tuple.
    """
    pdf = fitz.open(path)
    page_count = len(pdf)
    if not 1 <= page_num <= page_count:
        pdf.close()
        return f"Invalid page number. PDF has {page_count} pages.", "", "", None, []
    # Render at 300 DPI (PDF user space is 72 DPI).
    pixmap = pdf.load_page(page_num - 1).get_pixmap(matrix=fitz.Matrix(300/72, 300/72), alpha=False)
    rendered = Image.open(BytesIO(pixmap.tobytes("png")))
    pdf.close()
    return process_image(rendered, mode, task, custom_prompt)
187
+
188
def process_file(path, mode, task, custom_prompt, page_num):
    """Dispatch an uploaded file to the PDF or plain-image pipeline."""
    if not path:
        return "Error Upload file", "", "", None, []
    if path.lower().endswith('.pdf'):
        return process_pdf(path, mode, task, custom_prompt, page_num)
    return process_image(Image.open(path), mode, task, custom_prompt)
195
+
196
def toggle_prompt(task):
    """Show/hide the free-text prompt box depending on the selected task."""
    if task == "✏️ Custom":
        return gr.update(visible=True, label="Custom Prompt", placeholder="Add <|grounding|> for boxes")
    if task == "📍 Locate":
        return gr.update(visible=True, label="Text to Locate", placeholder="Enter text")
    return gr.update(visible=False)
202
+
203
def get_pdf_page_count(file_path):
    """Return the page count for a PDF path; 1 for anything else (or no file)."""
    if not file_path or not file_path.lower().endswith('.pdf'):
        return 1
    pdf = fitz.open(file_path)
    try:
        return len(pdf)
    finally:
        pdf.close()
210
+
211
def load_image(file_path, page_num=1):
    """Return a PIL image for an upload: the chosen PDF page, or the file itself.

    Returns None when no file is given.  page_num is 1-based and clamped
    to the PDF's valid range.
    """
    if not file_path:
        return None
    if not file_path.lower().endswith('.pdf'):
        return Image.open(file_path)
    pdf = fitz.open(file_path)
    page_idx = max(0, min(int(page_num) - 1, len(pdf) - 1))
    pixmap = pdf.load_page(page_idx).get_pixmap(matrix=fitz.Matrix(300/72, 300/72), alpha=False)
    rendered = Image.open(BytesIO(pixmap.tobytes("png")))
    pdf.close()
    return rendered
224
+
225
def update_page_selector(file_path):
    """Configure the page-number widget: visible with PDF bounds, else hidden."""
    if file_path and file_path.lower().endswith('.pdf'):
        page_count = get_pdf_page_count(file_path)
        return gr.update(visible=True, maximum=page_count, value=1, minimum=1,
                         label=f"Select Page (1-{page_count})")
    return gr.update(visible=False)
233
+
234
# ---------------------------------------------------------------------------
# Gradio UI: upload column on the left, tabbed results on the right.
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(), title="DeepSeek-OCR") as demo:
    gr.Markdown("""
    # 🚀 DeepSeek-OCR Demo
    **Convert documents to markdown, extract raw text, and locate specific content with bounding boxes. It takes 20~ sec for markdown and 3~ sec for locate task examples. Check the info at the bottom of the page for more information.**

    **Hope this tool was helpful! If so, a quick like ❤️ would mean a lot :)**
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # Input side: file upload, image preview, PDF page picker,
            # resolution preset, task selector and optional prompt.
            file_in = gr.File(label="Upload Image or PDF", file_types=["image", ".pdf"], type="filepath")
            input_img = gr.Image(label="Input Image", type="pil", height=300)
            # Hidden unless a PDF is uploaded (see update_page_selector).
            page_selector = gr.Number(label="Select Page", value=1, minimum=1, step=1, visible=False)
            mode = gr.Dropdown(list(MODEL_CONFIGS.keys()), value="Gundam", label="Mode")
            task = gr.Dropdown(list(TASK_PROMPTS.keys()), value="📋 Markdown", label="Task")
            # Hidden unless task is Custom or Locate (see toggle_prompt).
            prompt = gr.Textbox(label="Prompt", lines=2, visible=False)
            btn = gr.Button("Extract", variant="primary", size="lg")

        with gr.Column(scale=2):
            # Output side: the five views of one process_image() result.
            with gr.Tabs():
                with gr.Tab("📝 Text"):
                    text_out = gr.Textbox(lines=20, show_copy_button=True, show_label=False)
                with gr.Tab("🎨 Markdown"):
                    md_out = gr.Markdown("")
                with gr.Tab("🖼️ Boxes"):
                    img_out = gr.Image(type="pil", height=500, show_label=False)
                with gr.Tab("🖼️ Cropped Images"):
                    gallery = gr.Gallery(show_label=False, columns=3, height=400)
                with gr.Tab("🔍 Raw"):
                    raw_out = gr.Textbox(lines=20, show_copy_button=True, show_label=False)

    # Example inputs fill the image preview directly (not the file widget).
    gr.Examples(
        examples=[
            ["examples/ocr.jpg", "Gundam", "📋 Markdown", ""],
            ["examples/reachy-mini.jpg", "Gundam", "📍 Locate", "Robot"]
        ],
        inputs=[input_img, mode, task, prompt],
        cache_examples=False
    )

    with gr.Accordion("ℹ️ Info", open=False):
        gr.Markdown("""
        ### Modes
        - **Gundam**: 1024 base + 640 tiles with cropping - Best balance
        - **Tiny**: 512×512, no crop - Fastest
        - **Small**: 640×640, no crop - Quick
        - **Base**: 1024×1024, no crop - Standard
        - **Large**: 1280×1280, no crop - Highest quality

        ### Tasks
        - **Markdown**: Convert document to structured markdown (grounding ✅)
        - **Free OCR**: Simple text extraction
        - **Locate**: Find specific things in image (grounding ✅)
        - **Describe**: General image description
        - **Custom**: Your own prompt (add `<|grounding|>` for boxes)
        """)

    # Event wiring: keep the preview and page selector in sync with the upload.
    file_in.change(load_image, [file_in, page_selector], [input_img])
    file_in.change(update_page_selector, [file_in], [page_selector])
    page_selector.change(load_image, [file_in, page_selector], [input_img])
    task.change(toggle_prompt, [task], [prompt])

    def run(image, file_path, mode, task, custom_prompt, page_num):
        """Click handler: prefer the uploaded file, fall back to the preview image."""
        if file_path:
            return process_file(file_path, mode, task, custom_prompt, int(page_num))
        if image is not None:
            return process_image(image, mode, task, custom_prompt)
        return "Error uploading file or image", "", "", None, []

    btn.click(run, [input_img, file_in, mode, task, prompt, page_selector],
              [text_out, md_out, raw_out, img_out, gallery])

if __name__ == "__main__":
    # queue() bounds concurrent GPU work; max_size caps waiting requests.
    demo.queue(max_size=20).launch()