UnMelow committed on
Commit 4ce4fa4 · verified · 1 Parent(s): 9609f9b

Update app.py

Files changed (1)
  1. app.py +435 -119
app.py CHANGED

@@ -1,256 +1,539 @@
- import gradio as gr
- from transformers import AutoModel, AutoTokenizer
- import torch
- import spaces
  import os
  import sys
- import tempfile
- import shutil
- from PIL import Image, ImageDraw, ImageFont, ImageOps
- import fitz
  import re
  import warnings
- import numpy as np
  import base64
  from io import StringIO, BytesIO

- MODEL_NAME = 'deepseek-ai/DeepSeek-OCR'

- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
- model = AutoModel.from_pretrained(MODEL_NAME, _attn_implementation='flash_attention_2', torch_dtype=torch.bfloat16, trust_remote_code=True, use_safetensors=True)
- model = model.eval().cuda()

  MODEL_CONFIGS = {
      "Gundam": {"base_size": 1024, "image_size": 640, "crop_mode": True},
      "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
      "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
      "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
-     "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False}
  }

  TASK_PROMPTS = {
-     "📋 Markdown": {"prompt": "<image>\n<|grounding|>Convert the document to markdown.", "has_grounding": True},
-     "📝 Free OCR": {"prompt": "<image>\nFree OCR.", "has_grounding": False},
-     "📍 Locate": {"prompt": "<image>\nLocate <|ref|>text<|/ref|> in the image.", "has_grounding": True},
-     "🔍 Describe": {"prompt": "<image>\nDescribe this image in detail.", "has_grounding": False},
-     "✏️ Custom": {"prompt": "", "has_grounding": False}
  }

- def extract_grounding_references(text):
-     pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
      return re.findall(pattern, text, re.DOTALL)

- def draw_bounding_boxes(image, refs, extract_images=False):
      img_w, img_h = image.size
      img_draw = image.copy()
      draw = ImageDraw.Draw(img_draw)
-     overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
      draw2 = ImageDraw.Draw(overlay)
-     font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 30)
      crops = []
-
      color_map = {}
      np.random.seed(42)

      for ref in refs:
          label = ref[1]
          if label not in color_map:
-             color_map[label] = (np.random.randint(50, 255), np.random.randint(50, 255), np.random.randint(50, 255))

          color = color_map[label]
-         coords = eval(ref[2])
          color_a = color + (60,)
-
          for box in coords:
-             x1, y1, x2, y2 = int(box[0]/999*img_w), int(box[1]/999*img_h), int(box[2]/999*img_w), int(box[3]/999*img_h)
-
-             if extract_images and label == 'image':
                  crops.append(image.crop((x1, y1, x2, y2)))
-
-             width = 5 if label == 'title' else 3
              draw.rectangle([x1, y1, x2, y2], outline=color, width=width)
              draw2.rectangle([x1, y1, x2, y2], fill=color_a)
-
              text_bbox = draw.textbbox((0, 0), label, font=font)
              tw, th = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1]
              ty = max(0, y1 - 20)
              draw.rectangle([x1, ty, x1 + tw + 4, ty + th + 4], fill=color)
              draw.text((x1 + 2, ty + 2), label, font=font, fill=(255, 255, 255))
-
      img_draw.paste(overlay, (0, 0), overlay)
      return img_draw, crops

- def clean_output(text, include_images=False):
      if not text:
          return ""
-     pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
      matches = re.findall(pattern, text, re.DOTALL)
      img_num = 0
-
      for match in matches:
-         if '<|ref|>image<|/ref|>' in match[0]:
              if include_images:
-                 text = text.replace(match[0], f'\n\n**[Figure {img_num + 1}]**\n\n', 1)
                  img_num += 1
              else:
-                 text = text.replace(match[0], '', 1)
          else:
-             text = re.sub(rf'(?m)^[^\n]*{re.escape(match[0])}[^\n]*\n?', '', text)
-
      return text.strip()

- def embed_images(markdown, crops):
      if not crops:
          return markdown
      for i, img in enumerate(crops):
          buf = BytesIO()
          img.save(buf, format="PNG")
          b64 = base64.b64encode(buf.getvalue()).decode()
-         markdown = markdown.replace(f'**[Figure {i + 1}]**', f'\n\n![Figure {i + 1}](data:image/png;base64,{b64})\n\n', 1)
      return markdown

- @spaces.GPU(duration=60)
- def process_image(image, mode, task, custom_prompt):
      if image is None:
-         return " Error Upload image", "", "", None, []
      if task in ["✏️ Custom", "📍 Locate"] and not custom_prompt.strip():
-         return "Enter prompt", "", "", None, []
-
-     if image.mode in ('RGBA', 'LA', 'P'):
-         image = image.convert('RGB')
      image = ImageOps.exif_transpose(image)
-
      config = MODEL_CONFIGS[mode]
-
      if task == "✏️ Custom":
          prompt = f"<image>\n{custom_prompt.strip()}"
-         has_grounding = '<|grounding|>' in custom_prompt
      elif task == "📍 Locate":
          prompt = f"<image>\nLocate <|ref|>{custom_prompt.strip()}<|/ref|> in the image."
          has_grounding = True
      else:
          prompt = TASK_PROMPTS[task]["prompt"]
          has_grounding = TASK_PROMPTS[task]["has_grounding"]
-
-     tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
-     image.save(tmp.name, 'JPEG', quality=95)
      tmp.close()
      out_dir = tempfile.mkdtemp()
-
      stdout = sys.stdout
      sys.stdout = StringIO()
-
-     model.infer(tokenizer=tokenizer, prompt=prompt, image_file=tmp.name, output_path=out_dir,
-                 base_size=config["base_size"], image_size=config["image_size"], crop_mode=config["crop_mode"])
-
-     result = '\n'.join([l for l in sys.stdout.getvalue().split('\n')
-                         if not any(s in l for s in ['image:', 'other:', 'PATCHES', '====', 'BASE:', '%|', 'torch.Size'])]).strip()
-     sys.stdout = stdout
-
-     os.unlink(tmp.name)
-     shutil.rmtree(out_dir, ignore_errors=True)
-
      if not result:
          return "No text", "", "", None, []
-
-     cleaned = clean_output(result, False)
-     markdown = clean_output(result, True)
-
      img_out = None
      crops = []
-
-     if has_grounding and '<|ref|>' in result:
          refs = extract_grounding_references(result)
          if refs:
-             img_out, crops = draw_bounding_boxes(image, refs, True)
-
      markdown = embed_images(markdown, crops)
-
      return cleaned, markdown, result, img_out, crops

- @spaces.GPU(duration=60)
- def process_pdf(path, mode, task, custom_prompt, page_num):
      doc = fitz.open(path)
      total_pages = len(doc)
      if page_num < 1 or page_num > total_pages:
          doc.close()
          return f"Invalid page number. PDF has {total_pages} pages.", "", "", None, []
      page = doc.load_page(page_num - 1)
-     pix = page.get_pixmap(matrix=fitz.Matrix(300/72, 300/72), alpha=False)
      img = Image.open(BytesIO(pix.tobytes("png")))
      doc.close()
-
      return process_image(img, mode, task, custom_prompt)

- def process_file(path, mode, task, custom_prompt, page_num):
      if not path:
-         return "Error Upload file", "", "", None, []
-     if path.lower().endswith('.pdf'):
          return process_pdf(path, mode, task, custom_prompt, page_num)
-     else:
-         return process_image(Image.open(path), mode, task, custom_prompt)

- def toggle_prompt(task):
      if task == "✏️ Custom":
          return gr.update(visible=True, label="Custom Prompt", placeholder="Add <|grounding|> for boxes")
-     elif task == "📍 Locate":
          return gr.update(visible=True, label="Text to Locate", placeholder="Enter text")
      return gr.update(visible=False)

- def select_boxes(task):
      if task == "📍 Locate":
          return gr.update(selected="tab_boxes")
      return gr.update()

- def get_pdf_page_count(file_path):
-     if not file_path or not file_path.lower().endswith('.pdf'):
          return 1
      doc = fitz.open(file_path)
      count = len(doc)
      doc.close()
      return count

- def load_image(file_path, page_num=1):
      if not file_path:
          return None
-     if file_path.lower().endswith('.pdf'):
          doc = fitz.open(file_path)
          page_idx = max(0, min(int(page_num) - 1, len(doc) - 1))
          page = doc.load_page(page_idx)
-         pix = page.get_pixmap(matrix=fitz.Matrix(300/72, 300/72), alpha=False)
          img = Image.open(BytesIO(pix.tobytes("png")))
          doc.close()
          return img
-     else:
-         return Image.open(file_path)

- def update_page_selector(file_path):
      if not file_path:
          return gr.update(visible=False)
-     if file_path.lower().endswith('.pdf'):
          page_count = get_pdf_page_count(file_path)
-         return gr.update(visible=True, maximum=page_count, value=1, minimum=1,
-                          label=f"Select Page (1-{page_count})")
      return gr.update(visible=False)

- with gr.Blocks(theme=gr.themes.Soft(), title="DeepSeek-OCR") as demo:
-     gr.Markdown("""
-     # 🚀 DeepSeek-OCR Demo
-     **Convert documents to markdown, extract raw text, and locate specific content with bounding boxes. It takes 20~ sec for markdown and 3~ sec for locate task examples. Check the info at the bottom of the page for more information.**
-
-     **Hope this tool was helpful! If so, a quick like ❤️ would mean a lot :)**
-     """)
-
      with gr.Row():
          with gr.Column(scale=1):
              file_in = gr.File(label="Upload Image or PDF", file_types=["image", ".pdf"], type="filepath")
              input_img = gr.Image(label="Input Image", type="pil", height=300)
              page_selector = gr.Number(label="Select Page", value=1, minimum=1, step=1, visible=False)
              mode = gr.Dropdown(list(MODEL_CONFIGS.keys()), value="Gundam", label="Mode")
              task = gr.Dropdown(list(TASK_PROMPTS.keys()), value="📋 Markdown", label="Task")
              prompt = gr.Textbox(label="Prompt", lines=2, visible=False)
              btn = gr.Button("Extract", variant="primary", size="lg")
-
          with gr.Column(scale=2):
              with gr.Tabs() as tabs:
                  with gr.Tab("Text", id="tab_text"):

  import os
  import sys
  import re
+ import shutil
+ import tempfile
  import warnings
  import base64
  from io import StringIO, BytesIO
+ from typing import List, Tuple
+
+ import gradio as gr
+ import torch
+ import numpy as np
+ from PIL import Image, ImageDraw, ImageFont, ImageOps
+ import fitz  # PyMuPDF
+
+ from transformers import (
+     AutoModel,
+     AutoTokenizer,
+     AutoProcessor,
+     VisionEncoderDecoderModel,
+     BlipProcessor,
+     BlipForConditionalGeneration,
+ )
+
+ # --- Optional HF Spaces GPU decorator (safe fallback for local runs) ---
+ try:
+     import spaces  # type: ignore
+
+     gpu_decorator = spaces.GPU
+ except Exception:
+     def gpu_decorator(*args, **kwargs):
+         def wrap(fn):
+             return fn
+         return wrap
+
+
+ # =========================
+ # Device / dtype utilities
+ # =========================
+ def get_device() -> str:
+     return "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ def get_cuda_dtype() -> torch.dtype:
+     # bf16 only on supported GPUs (Ampere+). Otherwise fp16.
+     try:
+         if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
+             return torch.bfloat16
+     except Exception:
+         pass
+     return torch.float16
+
+
+ DEVICE = get_device()
+ CUDA_DTYPE = get_cuda_dtype() if DEVICE == "cuda" else torch.float32
+
+
+ # =========================
+ # Model names
+ # =========================
+ DEEPSEEK_OCR_NAME = os.getenv("DEEPSEEK_OCR_MODEL", "deepseek-ai/DeepSeek-OCR")
+ # Optional pin to a specific revision/commit to avoid auto-updating remote code.
+ DEEPSEEK_OCR_REVISION = os.getenv("DEEPSEEK_OCR_REVISION", None)
+
+ TROCR_NAME = os.getenv("TROCR_MODEL", "microsoft/trocr-base-printed")
+ BLIP_NAME = os.getenv("BLIP_MODEL", "Salesforce/blip-image-captioning-base")


+ # =========================
+ # Load DeepSeek-OCR safely
+ # =========================
+ def load_deepseek_ocr():
+     tokenizer = AutoTokenizer.from_pretrained(
+         DEEPSEEK_OCR_NAME,
+         trust_remote_code=True,
+         revision=DEEPSEEK_OCR_REVISION,
+     )

+     base_kwargs = dict(
+         trust_remote_code=True,
+         use_safetensors=True,
+         revision=DEEPSEEK_OCR_REVISION,
+     )

+     # IMPORTANT:
+     # - Do NOT force flash_attention_2 on CPU.
+     # - On CUDA: try flash_attention_2, but gracefully fall back if unavailable.
+     if DEVICE == "cuda":
+         # Try FlashAttention2 first
+         try:
+             model = AutoModel.from_pretrained(
+                 DEEPSEEK_OCR_NAME,
+                 torch_dtype=CUDA_DTYPE,
+                 _attn_implementation="flash_attention_2",
+                 **base_kwargs,
+             )
+         except Exception as e:
+             warnings.warn(
+                 f"FlashAttention2 unavailable or failed ({e}). Falling back to SDPA/eager."
+             )
+             # Try SDPA
+             try:
+                 model = AutoModel.from_pretrained(
+                     DEEPSEEK_OCR_NAME,
+                     torch_dtype=CUDA_DTYPE,
+                     _attn_implementation="sdpa",
+                     **base_kwargs,
+                 )
+             except Exception:
+                 # Final fallback
+                 model = AutoModel.from_pretrained(
+                     DEEPSEEK_OCR_NAME,
+                     torch_dtype=CUDA_DTYPE,
+                     _attn_implementation="eager",
+                     **base_kwargs,
+                 )
+
+         model = model.eval().to(DEVICE)
+
+     else:
+         # CPU path: no flash attention, use float32 for stability
+         model = AutoModel.from_pretrained(
+             DEEPSEEK_OCR_NAME,
+             torch_dtype=torch.float32,
+             _attn_implementation="eager",
+             **base_kwargs,
+         )
+         model = model.eval().to(DEVICE)
+
+     return tokenizer, model
+
+
+ tokenizer, deepseek_model = load_deepseek_ocr()
+
+
+ # =========================
+ # Load TrOCR and BLIP
+ # =========================
+ def load_trocr():
+     processor = AutoProcessor.from_pretrained(TROCR_NAME)
+     model = VisionEncoderDecoderModel.from_pretrained(TROCR_NAME).eval()
+     if DEVICE == "cuda":
+         model = model.to(DEVICE).to(dtype=CUDA_DTYPE)
+     else:
+         model = model.to(DEVICE)
+     return processor, model
+
+
+ def load_blip():
+     processor = BlipProcessor.from_pretrained(BLIP_NAME)
+     model = BlipForConditionalGeneration.from_pretrained(BLIP_NAME).eval()
+     if DEVICE == "cuda":
+         model = model.to(DEVICE).to(dtype=CUDA_DTYPE)
+     else:
+         model = model.to(DEVICE)
+     return processor, model
+
+
+ trocr_processor, trocr_model = load_trocr()
+ blip_processor, blip_model = load_blip()
+
+
+ # =========================
+ # App configs
+ # =========================
  MODEL_CONFIGS = {
      "Gundam": {"base_size": 1024, "image_size": 640, "crop_mode": True},
      "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
      "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
      "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
+     "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
  }

  TASK_PROMPTS = {
+     "📋 Markdown": {
+         "prompt": "<image>\n<|grounding|>Convert the document to markdown.",
+         "has_grounding": True,
+     },
+     # NOTE: Free OCR now goes through TrOCR (fast, text-only)
+     "📝 Free OCR": {"prompt": "", "has_grounding": False},
+     # Locate stays on DeepSeek-OCR (grounding)
+     "📍 Locate": {
+         "prompt": "<image>\nLocate <|ref|>text<|/ref|> in the image.",
+         "has_grounding": True,
+     },
+     # Describe now goes through BLIP
+     "🔍 Describe": {"prompt": "", "has_grounding": False},
+     "✏️ Custom": {"prompt": "", "has_grounding": False},
  }

+
+ # =========================
+ # Helpers
+ # =========================
+ def safe_load_font(size: int = 30) -> ImageFont.FreeTypeFont:
+     candidates = [
+         "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
+         "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
+     ]
+     for p in candidates:
+         try:
+             if os.path.exists(p):
+                 return ImageFont.truetype(p, size)
+         except Exception:
+             continue
+     return ImageFont.load_default()
+
+
+ def extract_grounding_references(text: str):
+     pattern = r"(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)"
      return re.findall(pattern, text, re.DOTALL)

+
+ def draw_bounding_boxes(image: Image.Image, refs, extract_images: bool = False):
      img_w, img_h = image.size
      img_draw = image.copy()
      draw = ImageDraw.Draw(img_draw)
+     overlay = Image.new("RGBA", img_draw.size, (0, 0, 0, 0))
      draw2 = ImageDraw.Draw(overlay)
+     font = safe_load_font(30)
      crops = []
+
      color_map = {}
      np.random.seed(42)

      for ref in refs:
          label = ref[1]
          if label not in color_map:
+             color_map[label] = (
+                 int(np.random.randint(50, 255)),
+                 int(np.random.randint(50, 255)),
+                 int(np.random.randint(50, 255)),
+             )

          color = color_map[label]
+         try:
+             coords = eval(ref[2])
+         except Exception:
+             continue
+
          color_a = color + (60,)
+
          for box in coords:
+             x1, y1, x2, y2 = (
+                 int(box[0] / 999 * img_w),
+                 int(box[1] / 999 * img_h),
+                 int(box[2] / 999 * img_w),
+                 int(box[3] / 999 * img_h),
+             )
+
+             if extract_images and label == "image":
                  crops.append(image.crop((x1, y1, x2, y2)))
+
+             width = 5 if label == "title" else 3
              draw.rectangle([x1, y1, x2, y2], outline=color, width=width)
              draw2.rectangle([x1, y1, x2, y2], fill=color_a)
+
              text_bbox = draw.textbbox((0, 0), label, font=font)
              tw, th = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1]
              ty = max(0, y1 - 20)
              draw.rectangle([x1, ty, x1 + tw + 4, ty + th + 4], fill=color)
              draw.text((x1 + 2, ty + 2), label, font=font, fill=(255, 255, 255))
+
      img_draw.paste(overlay, (0, 0), overlay)
      return img_draw, crops

+
+ def clean_output(text: str, include_images: bool = False) -> str:
      if not text:
          return ""
+     pattern = r"(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)"
      matches = re.findall(pattern, text, re.DOTALL)
      img_num = 0
+
      for match in matches:
+         if "<|ref|>image<|/ref|>" in match[0]:
              if include_images:
+                 text = text.replace(match[0], f"\n\n**[Figure {img_num + 1}]**\n\n", 1)
                  img_num += 1
              else:
+                 text = text.replace(match[0], "", 1)
          else:
+             text = re.sub(rf"(?m)^[^\n]*{re.escape(match[0])}[^\n]*\n?", "", text)
+
      return text.strip()

+
+ def embed_images(markdown: str, crops: List[Image.Image]) -> str:
      if not crops:
          return markdown
      for i, img in enumerate(crops):
          buf = BytesIO()
          img.save(buf, format="PNG")
          b64 = base64.b64encode(buf.getvalue()).decode()
+         markdown = markdown.replace(
+             f"**[Figure {i + 1}]**",
+             f"\n\n![Figure {i + 1}](data:image/png;base64,{b64})\n\n",
+             1,
+         )
      return markdown

+ def trocr_ocr(image: Image.Image) -> str:
+     if image.mode != "RGB":
+         image = image.convert("RGB")
+     pixel_values = trocr_processor(images=image, return_tensors="pt").pixel_values
+     # Match the model's device and dtype (fp16/bf16 on CUDA); the processor returns fp32.
+     pixel_values = pixel_values.to(DEVICE, dtype=trocr_model.dtype)
+     with torch.no_grad():
+         # Keep generation modest (faster)
+         generated_ids = trocr_model.generate(pixel_values, max_new_tokens=256)
+     text = trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+     return text.strip()
+
+
+ def blip_describe(image: Image.Image) -> str:
+     if image.mode != "RGB":
+         image = image.convert("RGB")
+     inputs = blip_processor(images=image, return_tensors="pt").to(DEVICE)
+     # Cast floating-point inputs to the model's dtype (fp16/bf16 on CUDA).
+     inputs = {k: v.to(blip_model.dtype) if torch.is_floating_point(v) else v
+               for k, v in inputs.items()}
+     with torch.no_grad():
+         out = blip_model.generate(**inputs, max_new_tokens=80)
+     caption = blip_processor.decode(out[0], skip_special_tokens=True)
+     return caption.strip()
+
+
+ # =========================
+ # Core processing
+ # =========================
+ @gpu_decorator(duration=60)
+ def process_image(image: Image.Image, mode: str, task: str, custom_prompt: str):
      if image is None:
+         return "Error: upload image", "", "", None, []
+
      if task in ["✏️ Custom", "📍 Locate"] and not custom_prompt.strip():
+         return "Error: enter prompt", "", "", None, []
+
+     if image.mode in ("RGBA", "LA", "P"):
+         image = image.convert("RGB")
      image = ImageOps.exif_transpose(image)
+
+     # --- Route tasks to the best backend ---
+     if task == "📝 Free OCR":
+         text = trocr_ocr(image)
+         if not text:
+             return "No text", "", "", None, []
+         md = "```text\n" + text + "\n```"
+         return text, md, text, None, []
+
+     if task == "🔍 Describe":
+         desc = blip_describe(image)
+         if not desc:
+             return "No description", "", "", None, []
+         md = f"**Description:** {desc}"
+         return desc, md, desc, None, []
+
+     # --- DeepSeek-OCR for Markdown / Locate / Custom ---
      config = MODEL_CONFIGS[mode]
+
      if task == "✏️ Custom":
          prompt = f"<image>\n{custom_prompt.strip()}"
+         has_grounding = "<|grounding|>" in custom_prompt
      elif task == "📍 Locate":
          prompt = f"<image>\nLocate <|ref|>{custom_prompt.strip()}<|/ref|> in the image."
          has_grounding = True
      else:
          prompt = TASK_PROMPTS[task]["prompt"]
          has_grounding = TASK_PROMPTS[task]["has_grounding"]
+
+     tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg")
+     image.save(tmp.name, "JPEG", quality=95)
      tmp.close()
      out_dir = tempfile.mkdtemp()
+
      stdout = sys.stdout
      sys.stdout = StringIO()
+
+     try:
+         deepseek_model.infer(
+             tokenizer=tokenizer,
+             prompt=prompt,
+             image_file=tmp.name,
+             output_path=out_dir,
+             base_size=config["base_size"],
+             image_size=config["image_size"],
+             crop_mode=config["crop_mode"],
+         )
+
+         result = "\n".join(
+             [
+                 l
+                 for l in sys.stdout.getvalue().split("\n")
+                 if not any(
+                     s in l
+                     for s in [
+                         "image:",
+                         "other:",
+                         "PATCHES",
+                         "====",
+                         "BASE:",
+                         "%|",
+                         "torch.Size",
+                     ]
+                 )
+             ]
+         ).strip()
+
+     finally:
+         sys.stdout = stdout
+         try:
+             os.unlink(tmp.name)
+         except Exception:
+             pass
+         shutil.rmtree(out_dir, ignore_errors=True)
+
      if not result:
          return "No text", "", "", None, []
+
+     cleaned = clean_output(result, include_images=False)
+     markdown = clean_output(result, include_images=True)
+
      img_out = None
      crops = []
+
+     if has_grounding and "<|ref|>" in result:
          refs = extract_grounding_references(result)
          if refs:
+             img_out, crops = draw_bounding_boxes(image, refs, extract_images=True)
+
      markdown = embed_images(markdown, crops)
+
      return cleaned, markdown, result, img_out, crops

+
+ @gpu_decorator(duration=60)
+ def process_pdf(path: str, mode: str, task: str, custom_prompt: str, page_num: int):
      doc = fitz.open(path)
      total_pages = len(doc)
      if page_num < 1 or page_num > total_pages:
          doc.close()
          return f"Invalid page number. PDF has {total_pages} pages.", "", "", None, []
      page = doc.load_page(page_num - 1)
+     pix = page.get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72), alpha=False)
      img = Image.open(BytesIO(pix.tobytes("png")))
      doc.close()
      return process_image(img, mode, task, custom_prompt)

+
+ def process_file(path: str, mode: str, task: str, custom_prompt: str, page_num: int):
      if not path:
+         return "Error: upload file", "", "", None, []
+     if path.lower().endswith(".pdf"):
          return process_pdf(path, mode, task, custom_prompt, page_num)
+     return process_image(Image.open(path), mode, task, custom_prompt)
+

+ def toggle_prompt(task: str):
      if task == "✏️ Custom":
          return gr.update(visible=True, label="Custom Prompt", placeholder="Add <|grounding|> for boxes")
+     if task == "📍 Locate":
          return gr.update(visible=True, label="Text to Locate", placeholder="Enter text")
      return gr.update(visible=False)

+
+ def select_boxes(task: str):
      if task == "📍 Locate":
          return gr.update(selected="tab_boxes")
      return gr.update()

+
+ def get_pdf_page_count(file_path: str) -> int:
+     if not file_path or not file_path.lower().endswith(".pdf"):
          return 1
      doc = fitz.open(file_path)
      count = len(doc)
      doc.close()
      return count

+
+ def load_image(file_path: str, page_num: int = 1):
      if not file_path:
          return None
+     if file_path.lower().endswith(".pdf"):
          doc = fitz.open(file_path)
          page_idx = max(0, min(int(page_num) - 1, len(doc) - 1))
          page = doc.load_page(page_idx)
+         pix = page.get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72), alpha=False)
          img = Image.open(BytesIO(pix.tobytes("png")))
          doc.close()
          return img
+     return Image.open(file_path)
+

+ def update_page_selector(file_path: str):
      if not file_path:
          return gr.update(visible=False)
+     if file_path.lower().endswith(".pdf"):
          page_count = get_pdf_page_count(file_path)
+         return gr.update(
+             visible=True,
+             maximum=page_count,
+             value=1,
+             minimum=1,
+             label=f"Select Page (1-{page_count})",
+         )
      return gr.update(visible=False)

+
+ # =========================
+ # UI
+ # =========================
+ with gr.Blocks(theme=gr.themes.Soft(), title="DeepSeek-OCR + TrOCR + BLIP") as demo:
+     gr.Markdown(
+         f"""
+ # DeepSeek-OCR Demo (with TrOCR + BLIP)
+
+ This app supports:
+ - **Markdown**: DeepSeek-OCR (structured markdown + optional grounding boxes)
+ - **Free OCR**: TrOCR (fast text-only OCR)
+ - **Locate**: DeepSeek-OCR (grounding boxes)
+ - **Describe**: BLIP (image captioning)
+
+ Runtime device: **{DEVICE}**
+ """
+     )
+
      with gr.Row():
          with gr.Column(scale=1):
              file_in = gr.File(label="Upload Image or PDF", file_types=["image", ".pdf"], type="filepath")
              input_img = gr.Image(label="Input Image", type="pil", height=300)
              page_selector = gr.Number(label="Select Page", value=1, minimum=1, step=1, visible=False)
+
              mode = gr.Dropdown(list(MODEL_CONFIGS.keys()), value="Gundam", label="Mode")
              task = gr.Dropdown(list(TASK_PROMPTS.keys()), value="📋 Markdown", label="Task")
              prompt = gr.Textbox(label="Prompt", lines=2, visible=False)
+
              btn = gr.Button("Extract", variant="primary", size="lg")
+
          with gr.Column(scale=2):
              with gr.Tabs() as tabs:
                  with gr.Tab("Text", id="tab_text"):

@@ -263,25 +546,58 @@ with gr.Blocks(theme=gr.themes.Soft(), title="DeepSeek-OCR") as demo:
                  gallery = gr.Gallery(show_label=False, columns=3, height=400)
              with gr.Tab("Raw Text", id="tab_raw"):
                  raw_out = gr.Textbox(lines=20, show_copy_button=True, show_label=False)
-
-
-
      file_in.change(load_image, [file_in, page_selector], [input_img])
      file_in.change(update_page_selector, [file_in], [page_selector])
      page_selector.change(load_image, [file_in, page_selector], [input_img])
      task.change(toggle_prompt, [task], [prompt])
      task.change(select_boxes, [task], [tabs])
-
      def run(image, file_path, mode, task, custom_prompt, page_num):
          if file_path:
              return process_file(file_path, mode, task, custom_prompt, int(page_num))
          if image is not None:
              return process_image(image, mode, task, custom_prompt)
-         return "Error uploading file or image", "", "", None, []

-     submit_event = btn.click(run, [input_img, file_in, mode, task, prompt, page_selector],
-                              [text_out, md_out, raw_out, img_out, gallery])
      submit_event.then(select_boxes, [task], [tabs])

  if __name__ == "__main__":
-     demo.queue(max_size=20).launch()

                  gallery = gr.Gallery(show_label=False, columns=3, height=400)
              with gr.Tab("Raw Text", id="tab_raw"):
                  raw_out = gr.Textbox(lines=20, show_copy_button=True, show_label=False)
+
+     # Better examples: populate File input (works for both image/pdf paths inside repo)
+     gr.Examples(
+         examples=[
+             ["examples/ocr.jpg", "Gundam", "📋 Markdown", "", 1],
+             ["examples/reachy-mini.jpg", "Gundam", "📍 Locate", "Robot", 1],
+         ],
+         inputs=[file_in, mode, task, prompt, page_selector],
+         cache_examples=False,
+     )
+
+     with gr.Accordion("ℹ️ Info", open=False):
+         gr.Markdown(
+             """
+ ### Modes
+ - **Gundam**: 1024 base + 640 tiles with cropping - Best balance
+ - **Tiny**: 512×512, no crop - Fastest
+ - **Small**: 640×640, no crop - Quick
+ - **Base**: 1024×1024, no crop - Standard
+ - **Large**: 1280×1280, no crop - Highest quality
+
+ ### Tasks
+ - **📋 Markdown**: DeepSeek-OCR → structured markdown (grounding ✅)
+ - **📝 Free OCR**: TrOCR → fast text-only OCR
+ - **📍 Locate**: DeepSeek-OCR → bounding boxes (grounding ✅)
+ - **🔍 Describe**: BLIP → short image description
+ - **✏️ Custom**: DeepSeek-OCR prompt (add `<|grounding|>` for boxes)
+ """
+         )
+
+     # File / PDF page handling
      file_in.change(load_image, [file_in, page_selector], [input_img])
      file_in.change(update_page_selector, [file_in], [page_selector])
      page_selector.change(load_image, [file_in, page_selector], [input_img])
+
+     # Prompt visibility and tab switch
      task.change(toggle_prompt, [task], [prompt])
      task.change(select_boxes, [task], [tabs])
+
      def run(image, file_path, mode, task, custom_prompt, page_num):
          if file_path:
              return process_file(file_path, mode, task, custom_prompt, int(page_num))
          if image is not None:
              return process_image(image, mode, task, custom_prompt)
+         return "Error: upload file or image", "", "", None, []

+     submit_event = btn.click(
+         run,
+         [input_img, file_in, mode, task, prompt, page_selector],
+         [text_out, md_out, raw_out, img_out, gallery],
+     )
      submit_event.then(select_boxes, [task], [tabs])

  if __name__ == "__main__":
+     demo.queue(max_size=20).launch()
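
The heart of this commit is the defensive model loading: instead of unconditionally requesting `flash_attention_2` and calling `.cuda()`, `load_deepseek_ocr` now steps down a ladder of attention implementations and keeps the first one that loads. A minimal standalone sketch of that pattern, assuming only `transformers` and `torch` (the helper name `load_with_attn_fallback` and its `model_id` argument are illustrative, not part of the commit):

```python
# Sketch: try FlashAttention2, then SDPA, then eager, keeping the first
# implementation that loads; mirrors the fallback in load_deepseek_ocr.
import warnings

import torch
from transformers import AutoModel


def load_with_attn_fallback(model_id: str, dtype: torch.dtype = torch.float16):
    last_err = None
    for impl in ("flash_attention_2", "sdpa", "eager"):
        try:
            return AutoModel.from_pretrained(
                model_id,
                torch_dtype=dtype,
                _attn_implementation=impl,
                trust_remote_code=True,
                use_safetensors=True,
            )
        except Exception as err:  # e.g. flash-attn is not installed on this box
            warnings.warn(f"{impl} failed ({err}); trying the next implementation.")
            last_err = err
    raise RuntimeError("No attention implementation could be loaded") from last_err
```

The same laddering, together with the `spaces.GPU` try/except shim at the top of the file, is what lets the Space start on CPU-only hardware, where the old unconditional `flash_attention_2` + `.cuda()` pair failed at startup.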