ricklon committed on
Commit
f76cb58
·
1 Parent(s): 474fd39

Refine equation grounding with zoom-in pass and per-box spatial blocks

Browse files
Files changed (2) hide show
  1. app.py +241 -72
  2. tests/test_spatial_blocks.py +56 -0
app.py CHANGED
@@ -44,6 +44,19 @@ model = AutoModel.from_pretrained(MODEL_NAME, _attn_implementation=_attn_impl, t
44
  BASE_SIZE = 1024
45
  IMAGE_SIZE = 768
46
  CROP_MODE = True
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  TASK_PROMPTS = {
49
  "📋 Markdown": {"prompt": "<image>\n<|grounding|>Convert the document to markdown.", "has_grounding": True},
@@ -54,8 +67,112 @@ TASK_PROMPTS = {
54
  }
55
 
56
  def extract_grounding_references(text):
57
- pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
58
- return re.findall(pattern, text, re.DOTALL)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  def draw_bounding_boxes(image, refs, extract_images=False):
61
  img_w, img_h = image.size
@@ -75,7 +192,7 @@ def draw_bounding_boxes(image, refs, extract_images=False):
75
  color_map[label] = (np.random.randint(50, 255), np.random.randint(50, 255), np.random.randint(50, 255))
76
 
77
  color = color_map[label]
78
- coords = eval(ref[2])
79
  color_a = color + (60,)
80
 
81
  for box in coords:
@@ -326,48 +443,20 @@ def to_mathjax_html(text: str) -> str:
326
  return f'<div class="mathjax-preview">{html}</div>'
327
 
328
  def _grounding_blocks_from_raw(raw_text: str):
329
- if not raw_text:
330
- return []
331
-
332
- pattern = re.compile(r'<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>', re.DOTALL)
333
  blocks = []
334
- last_end = 0
335
-
336
- for m in pattern.finditer(raw_text):
337
- label = m.group(1).strip() or "text"
338
- coord_text = m.group(2).strip()
339
- text_chunk = raw_text[last_end:m.start()].strip()
340
- last_end = m.end()
341
-
342
- try:
343
- coords = ast.literal_eval(coord_text)
344
- except (SyntaxError, ValueError):
345
- continue
346
-
347
- if isinstance(coords, (tuple, list)) and coords and isinstance(coords[0], (int, float)):
348
- coords = [coords]
349
- if not isinstance(coords, list):
350
- continue
351
-
352
- boxes = [c for c in coords if isinstance(c, (list, tuple)) and len(c) >= 4]
353
- if not boxes:
354
- continue
355
-
356
- x1 = max(0.0, min(float(c[0]) for c in boxes))
357
- y1 = max(0.0, min(float(c[1]) for c in boxes))
358
- x2 = min(999.0, max(float(c[2]) for c in boxes))
359
- y2 = min(999.0, max(float(c[3]) for c in boxes))
360
- if x2 <= x1 or y2 <= y1:
361
- continue
362
-
363
- blocks.append({
364
- "label": label,
365
- "text": text_chunk,
366
- "x1": x1,
367
- "y1": y1,
368
- "x2": x2,
369
- "y2": y2,
370
- })
371
 
372
  return blocks
373
 
@@ -487,6 +576,106 @@ def embed_images(markdown, crops):
487
  markdown = markdown.replace(f'**[Figure {i + 1}]**', f'\n\n![Figure {i + 1}](data:image/png;base64,{b64})\n\n', 1)
488
  return markdown
489
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
490
  @spaces.GPU(duration=90)
491
  def process_image(image, task, custom_prompt):
492
  model.cuda() # GPU is available here — works on ZeroGPU and locally
@@ -508,33 +697,7 @@ def process_image(image, task, custom_prompt):
508
  else:
509
  prompt = TASK_PROMPTS[task]["prompt"]
510
  has_grounding = TASK_PROMPTS[task]["has_grounding"]
511
-
512
- tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
513
- image.save(tmp.name, 'JPEG', quality=95)
514
- tmp.close()
515
- out_dir = tempfile.mkdtemp()
516
-
517
- stdout = sys.stdout
518
- sys.stdout = StringIO()
519
-
520
- model.infer(
521
- tokenizer=tokenizer,
522
- prompt=prompt,
523
- image_file=tmp.name,
524
- output_path=out_dir,
525
- base_size=BASE_SIZE,
526
- image_size=IMAGE_SIZE,
527
- crop_mode=CROP_MODE,
528
- save_results=False
529
- )
530
-
531
- debug_filters = ['PATCHES', '====', 'BASE:', 'directly resize', 'NO PATCHES', 'torch.Size', '%|']
532
- result = '\n'.join([l for l in sys.stdout.getvalue().split('\n')
533
- if l.strip() and not any(s in l for s in debug_filters)]).strip()
534
- sys.stdout = stdout
535
-
536
- os.unlink(tmp.name)
537
- shutil.rmtree(out_dir, ignore_errors=True)
538
 
539
  if not result:
540
  return "No text detected", "", "", None, []
@@ -544,15 +707,21 @@ def process_image(image, task, custom_prompt):
544
 
545
  img_out = None
546
  crops = []
 
547
 
548
  if has_grounding and '<|ref|>' in result:
549
  refs = extract_grounding_references(result)
 
 
550
  if refs:
551
  img_out, crops = draw_bounding_boxes(image, refs, True)
 
 
 
552
 
553
  markdown = embed_images(markdown, crops)
554
 
555
- return cleaned, markdown, result, img_out, crops
556
 
557
  @spaces.GPU(duration=90)
558
  def process_pdf(path, task, custom_prompt, page_num):
 
44
  BASE_SIZE = 1024
45
  IMAGE_SIZE = 768
46
  CROP_MODE = True
47
+ GROUNDING_PATTERN = re.compile(r'<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>', re.DOTALL)
48
+ INFER_DEBUG_FILTERS = ['PATCHES', '====', 'BASE:', 'directly resize', 'NO PATCHES', 'torch.Size', '%|']
49
+ EQUATION_ZOOM_PROMPT = "<image>\n<|grounding|>Locate each individual equation or math line."
50
+ EQUATION_ZOOM_MAX_CANDIDATES = 6
51
+ EQUATION_ZOOM_MIN_AREA = 0.05
52
+ EQUATION_ZOOM_MIN_DIM = 0.24
53
+ EQUATION_ZOOM_PADDING = 0.025
54
+ EQUATION_ZOOM_MAX_ASPECT = 12.0
55
+ EQUATION_DETAIL_MAX_BOXES = 24
56
+ EQUATION_DETAIL_IOU_DEDUPE = 0.7
57
+ MATH_LABEL_HINTS = ("formula", "equation", "math")
58
+ MATH_STRONG_MARKERS = ("\\(", "\\[", "\\frac", "\\sum", "\\int", "\\sqrt", "\\lim", "\\begin{")
59
+ MATH_WEAK_MARKERS = ("^", "_", "=", "+", "\\cdot", "\\times")
60
 
61
  TASK_PROMPTS = {
62
  "📋 Markdown": {"prompt": "<image>\n<|grounding|>Convert the document to markdown.", "has_grounding": True},
 
67
  }
68
 
69
  def extract_grounding_references(text):
70
+ refs = []
71
+ for entry in _extract_grounding_entries(text):
72
+ coord_text = repr(entry["coords"])
73
+ raw = f'<|ref|>{entry["label"]}<|/ref|><|det|>{coord_text}<|/det|>'
74
+ refs.append((raw, entry["label"], coord_text))
75
+ return refs
76
+
77
+ def _parse_coord_payload(payload):
78
+ if isinstance(payload, str):
79
+ try:
80
+ coords = ast.literal_eval(payload.strip())
81
+ except (SyntaxError, ValueError):
82
+ return []
83
+ else:
84
+ coords = payload
85
+
86
+ if isinstance(coords, (tuple, list)) and coords and isinstance(coords[0], (int, float)):
87
+ coords = [coords]
88
+ if not isinstance(coords, list):
89
+ return []
90
+
91
+ out = []
92
+ for c in coords:
93
+ if not isinstance(c, (list, tuple)) or len(c) < 4:
94
+ continue
95
+ x1, y1, x2, y2 = [float(v) for v in c[:4]]
96
+ x1, x2 = sorted((max(0.0, min(999.0, x1)), max(0.0, min(999.0, x2))))
97
+ y1, y2 = sorted((max(0.0, min(999.0, y1)), max(0.0, min(999.0, y2))))
98
+ if x2 <= x1 or y2 <= y1:
99
+ continue
100
+ out.append([x1, y1, x2, y2])
101
+ return out
102
+
103
def _extract_grounding_entries(raw_text: str):
    """Parse <|ref|>…<|/ref|><|det|>…<|/det|> tags from raw model output.

    Returns a list of dicts with keys "label", "coords" (list of
    [x1, y1, x2, y2] boxes) and "text" (the raw text preceding the tag).
    Tags whose coordinate payload cannot be parsed are skipped entirely.
    """
    if not raw_text:
        return []

    collected = []
    cursor = 0
    for match in GROUNDING_PATTERN.finditer(raw_text):
        name = match.group(1).strip() or "text"
        boxes = _parse_coord_payload(match.group(2))
        if not boxes:
            # Unparseable payload: leave the cursor alone so the bad tag's
            # text stays attached to the next entry's preceding chunk,
            # matching the original skip behavior.
            continue
        collected.append({
            "label": name,
            "coords": boxes,
            "text": raw_text[cursor:match.start()].strip(),
        })
        cursor = match.end()
    return collected
122
+
123
def _math_marker_score(text_chunk: str) -> int:
    """Heuristic "math-ness" score: strong LaTeX markers count 3, weak ones 1."""
    strong = sum(3 for marker in MATH_STRONG_MARKERS if marker in text_chunk)
    weak = sum(1 for marker in MATH_WEAK_MARKERS if marker in text_chunk)
    return strong + weak
132
+
133
+ def _box_iou(a, b):
134
+ ax1, ay1, ax2, ay2 = a
135
+ bx1, by1, bx2, by2 = b
136
+ inter_x1 = max(ax1, bx1)
137
+ inter_y1 = max(ay1, by1)
138
+ inter_x2 = min(ax2, bx2)
139
+ inter_y2 = min(ay2, by2)
140
+ if inter_x2 <= inter_x1 or inter_y2 <= inter_y1:
141
+ return 0.0
142
+ inter = (inter_x2 - inter_x1) * (inter_y2 - inter_y1)
143
+ area_a = max(1e-9, (ax2 - ax1) * (ay2 - ay1))
144
+ area_b = max(1e-9, (bx2 - bx1) * (by2 - by1))
145
+ union = area_a + area_b - inter
146
+ return inter / union if union > 0 else 0.0
147
+
148
def _dedupe_boxes(boxes, iou_threshold):
    """Greedy IoU de-duplication, considering smaller boxes first."""
    def area(b):
        return (b[2] - b[0]) * (b[3] - b[1])

    survivors = []
    for candidate in sorted(boxes, key=area):
        overlaps = any(_box_iou(candidate, kept) >= iou_threshold for kept in survivors)
        if not overlaps:
            survivors.append(candidate)
    return survivors
155
+
156
def _is_math_candidate(label: str, text_chunk: str, box):
    """Decide whether a grounded box looks like a math region worth zooming into.

    Accepts a box when it is not absurdly elongated, is large enough (by area
    or by either dimension), and either carries a math-like label or its
    preceding text scores as math-heavy.
    """
    # Coordinates live on the model's 0-999 grid; normalize to fractions.
    w = (box[2] - box[0]) / 999.0
    h = (box[3] - box[1]) / 999.0
    aspect = max(w / max(1e-9, h), h / max(1e-9, w))
    if aspect > EQUATION_ZOOM_MAX_ASPECT:
        return False
    big_enough = (w * h >= EQUATION_ZOOM_MIN_AREA
                  or w >= EQUATION_ZOOM_MIN_DIM
                  or h >= EQUATION_ZOOM_MIN_DIM)
    if not big_enough:
        return False
    lowered = label.lower()
    if any(hint in lowered for hint in MATH_LABEL_HINTS):
        return True
    return _math_marker_score(text_chunk) >= 3
166
+
167
def _map_crop_box_to_page(sub_box, crop_px, img_w, img_h):
    """Map a 0-999 box detected inside a crop back to 0-999 page coordinates.

    sub_box: [x1, y1, x2, y2] on the crop's 0-999 grid.
    crop_px: (x1, y1, x2, y2) of the crop in full-page pixels.
    img_w, img_h: full page size in pixels.

    Returns a clamped [x1, y1, x2, y2] on the page's 0-999 grid. When the
    mapped box degenerates (zero width or height after clamping), a zero-area
    box is returned so callers that filter by area drop it — previously
    `_parse_coord_payload(...)[0]` raised IndexError on that empty result.
    """
    crop_x1, crop_y1, crop_x2, crop_y2 = crop_px
    crop_w = max(1, crop_x2 - crop_x1)
    crop_h = max(1, crop_y2 - crop_y1)
    page_x1 = ((crop_x1 + (sub_box[0] / 999.0) * crop_w) / img_w) * 999.0
    page_y1 = ((crop_y1 + (sub_box[1] / 999.0) * crop_h) / img_h) * 999.0
    page_x2 = ((crop_x1 + (sub_box[2] / 999.0) * crop_w) / img_w) * 999.0
    page_y2 = ((crop_y1 + (sub_box[3] / 999.0) * crop_h) / img_h) * 999.0
    parsed = _parse_coord_payload([[page_x1, page_y1, page_x2, page_y2]])
    if parsed:
        return parsed[0]
    # Degenerate mapping: clamp/sort manually and return a zero-area box.
    x1, x2 = sorted((max(0.0, min(999.0, page_x1)), max(0.0, min(999.0, page_x2))))
    y1, y2 = sorted((max(0.0, min(999.0, page_y1)), max(0.0, min(999.0, page_y2))))
    return [x1, y1, x2, y2]
176
 
177
  def draw_bounding_boxes(image, refs, extract_images=False):
178
  img_w, img_h = image.size
 
192
  color_map[label] = (np.random.randint(50, 255), np.random.randint(50, 255), np.random.randint(50, 255))
193
 
194
  color = color_map[label]
195
+ coords = _parse_coord_payload(ref[2])
196
  color_a = color + (60,)
197
 
198
  for box in coords:
 
443
  return f'<div class="mathjax-preview">{html}</div>'
444
 
445
def _grounding_blocks_from_raw(raw_text: str):
    """Flatten grounding entries into one spatial block per bounding box.

    The text chunk preceding an entry is attached only to that entry's first
    box; subsequent boxes get an empty text so the chunk is not duplicated.
    """
    spatial_blocks = []
    for entry in _extract_grounding_entries(raw_text):
        chunk = entry["text"].strip()
        for idx, (x1, y1, x2, y2) in enumerate(entry["coords"]):
            spatial_blocks.append({
                "label": entry["label"],
                "text": chunk if idx == 0 else "",
                "x1": x1,
                "y1": y1,
                "x2": x2,
                "y2": y2,
            })

    return spatial_blocks
462
 
 
576
  markdown = markdown.replace(f'**[Figure {i + 1}]**', f'\n\n![Figure {i + 1}](data:image/png;base64,{b64})\n\n', 1)
577
  return markdown
578
 
579
def _infer_with_prompt(image, prompt):
    """Run model.infer on a PIL image with *prompt* and return the text output.

    The model's infer() writes its result to stdout, so stdout is captured
    into a StringIO for the duration of the call and then filtered: blank
    lines and known debug/progress chatter (INFER_DEBUG_FILTERS) are dropped.
    The image is staged as a temp JPEG because infer() takes a file path.
    """
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
    image.save(tmp.name, 'JPEG', quality=95)
    tmp.close()
    out_dir = tempfile.mkdtemp()

    stdout = sys.stdout
    capture = StringIO()
    sys.stdout = capture
    try:
        model.infer(
            tokenizer=tokenizer,
            prompt=prompt,
            image_file=tmp.name,
            output_path=out_dir,
            base_size=BASE_SIZE,
            image_size=IMAGE_SIZE,
            crop_mode=CROP_MODE,
            save_results=False
        )
    finally:
        # Always restore stdout and remove the temp artifacts, even when
        # infer() raises — otherwise later prints would vanish into capture.
        sys.stdout = stdout
        os.unlink(tmp.name)
        shutil.rmtree(out_dir, ignore_errors=True)

    # Keep only non-empty lines that contain no known debug marker.
    lines = [
        l for l in capture.getvalue().split('\n')
        if l.strip() and not any(s in l for s in INFER_DEBUG_FILTERS)
    ]
    return '\n'.join(lines).strip()
609
+
610
def _refine_equation_refs(image, raw_text):
    """Second-pass equation grounding: zoom into large math regions and re-OCR.

    For each grounded box in *raw_text* that looks like a math region
    (_is_math_candidate), crop that region (with padding) out of *image*,
    run the model again with EQUATION_ZOOM_PROMPT, map the sub-boxes back to
    page coordinates, and emit synthetic "equation_detail" refs in the same
    (raw_tag, label, coord_text) shape as extract_grounding_references.
    Returns [] when nothing qualifies.
    """
    entries = _extract_grounding_entries(raw_text)
    if not entries:
        return []

    img_w, img_h = image.size
    candidates = []
    for entry in entries:
        for box in entry["coords"]:
            if _is_math_candidate(entry["label"], entry["text"], box):
                # Area on the 0-999 grid; used only for ranking below.
                area = (box[2] - box[0]) * (box[3] - box[1])
                candidates.append((area, entry, box))

    if not candidates:
        return []

    # Largest regions first; cap how many zoom passes we pay for.
    candidates.sort(key=lambda x: x[0], reverse=True)
    refined_refs = []
    for _, entry, box in candidates[:EQUATION_ZOOM_MAX_CANDIDATES]:
        # Convert the 0-999 grid box to pixel coordinates on the page.
        x1 = int(box[0] / 999.0 * img_w)
        y1 = int(box[1] / 999.0 * img_h)
        x2 = int(box[2] / 999.0 * img_w)
        y2 = int(box[3] / 999.0 * img_h)
        box_w = max(1, x2 - x1)
        box_h = max(1, y2 - y1)
        # Pad the crop a little (min 8 px) so edge glyphs are not cut off.
        pad_x = max(8, int(box_w * EQUATION_ZOOM_PADDING))
        pad_y = max(8, int(box_h * EQUATION_ZOOM_PADDING))
        crop_x1 = max(0, x1 - pad_x)
        crop_y1 = max(0, y1 - pad_y)
        crop_x2 = min(img_w, x2 + pad_x)
        crop_y2 = min(img_h, y2 + pad_y)
        # Skip crops too small for the model to say anything useful about.
        if crop_x2 - crop_x1 < 32 or crop_y2 - crop_y1 < 32:
            continue

        crop = image.crop((crop_x1, crop_y1, crop_x2, crop_y2))
        sub_result = _infer_with_prompt(crop, EQUATION_ZOOM_PROMPT)
        sub_entries = _extract_grounding_entries(sub_result)
        if not sub_entries:
            continue

        mapped_boxes = []
        for sub in sub_entries:
            sub_label = sub["label"].lower()
            sub_text = sub["text"]
            # Keep only math-looking sub-detections; drop figures/tables.
            is_math_sub = any(hint in sub_label for hint in MATH_LABEL_HINTS) or _math_marker_score(sub_text) >= 3
            if sub_label in ("image", "table") or not is_math_sub:
                continue
            for sub_box in sub["coords"]:
                mapped = _map_crop_box_to_page(sub_box, (crop_x1, crop_y1, crop_x2, crop_y2), img_w, img_h)
                w = (mapped[2] - mapped[0]) / 999.0
                h = (mapped[3] - mapped[1]) / 999.0
                # Discard specks (< 0.04% of the page area).
                if w * h < 0.0004:
                    continue
                mapped_boxes.append(mapped)

        if not mapped_boxes:
            continue
        mapped_boxes = _dedupe_boxes(mapped_boxes, EQUATION_DETAIL_IOU_DEDUPE)
        # Reading order: top-to-bottom, then left-to-right; cap the count.
        mapped_boxes = sorted(mapped_boxes, key=lambda b: (b[1], b[0]))[:EQUATION_DETAIL_MAX_BOXES]
        # A single surviving box adds nothing over the original region.
        if len(mapped_boxes) < 2:
            continue

        merged_text = repr(mapped_boxes)
        label = "equation_detail"
        raw = f'<|ref|>{label}<|/ref|><|det|>{merged_text}<|/det|>'
        refined_refs.append((raw, label, merged_text))

    return refined_refs
678
+
679
  @spaces.GPU(duration=90)
680
  def process_image(image, task, custom_prompt):
681
  model.cuda() # GPU is available here — works on ZeroGPU and locally
 
697
  else:
698
  prompt = TASK_PROMPTS[task]["prompt"]
699
  has_grounding = TASK_PROMPTS[task]["has_grounding"]
700
+ result = _infer_with_prompt(image, prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
701
 
702
  if not result:
703
  return "No text detected", "", "", None, []
 
707
 
708
  img_out = None
709
  crops = []
710
+ result_for_layout = result
711
 
712
  if has_grounding and '<|ref|>' in result:
713
  refs = extract_grounding_references(result)
714
+ if task == "📋 Markdown":
715
+ refs.extend(_refine_equation_refs(image, result))
716
  if refs:
717
  img_out, crops = draw_bounding_boxes(image, refs, True)
718
+ synthetic = [r[0] for r in refs if r[1] == "equation_detail"]
719
+ if synthetic:
720
+ result_for_layout = result + "\n" + "\n".join(synthetic)
721
 
722
  markdown = embed_images(markdown, crops)
723
 
724
+ return cleaned, markdown, result_for_layout, img_out, crops
725
 
726
  @spaces.GPU(duration=90)
727
  def process_pdf(path, task, custom_prompt, page_num):
tests/test_spatial_blocks.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import pathlib
3
+ import re
4
+ import unittest
5
+
6
+
7
def _load_grounding_blocks():
    """Compile just the grounding helpers out of app.py and return the entry point.

    app.py loads a model at import time, so instead of importing it the three
    pure functions under test are extracted from its AST and executed in a
    minimal namespace that supplies their module-level dependencies.
    """
    app_file = pathlib.Path(__file__).resolve().parents[1] / "app.py"
    tree = ast.parse(app_file.read_text(encoding="utf-8"), filename=str(app_file))

    targets = {
        "_parse_coord_payload",
        "_extract_grounding_entries",
        "_grounding_blocks_from_raw",
    }
    selected = sorted(
        (node for node in tree.body
         if isinstance(node, ast.FunctionDef) and node.name in targets),
        key=lambda node: node.lineno,
    )

    compiled = compile(
        ast.Module(body=selected, type_ignores=[]),
        filename=str(app_file),
        mode="exec",
    )

    namespace = {
        "ast": ast,
        "re": re,
        "GROUNDING_PATTERN": re.compile(r"<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>", re.DOTALL),
    }
    exec(compiled, namespace)
    return namespace["_grounding_blocks_from_raw"]
30
+
31
+
32
class SpatialBlockTests(unittest.TestCase):
    """Behavioral checks for app.py's _grounding_blocks_from_raw helper."""

    def test_multi_coord_refs_render_as_multiple_blocks(self):
        to_blocks = _load_grounding_blocks()
        raw_output = (
            "Equation cluster\n"
            "<|ref|>formula<|/ref|><|det|>[[100,100,600,220],[100,240,600,360]]<|/det|>\n"
            "Trailing text\n"
            "<|ref|>text<|/ref|><|det|>[[40,400,700,520]]<|/det|>"
        )

        result = to_blocks(raw_output)
        self.assertEqual(3, len(result))

        formulas = [block for block in result if block["label"] == "formula"]
        self.assertEqual(2, len(formulas))
        for block in formulas:
            self.assertEqual(100.0, block["x1"])
            self.assertEqual(600.0, block["x2"])
        # Only the first box of a multi-box ref carries the preceding text.
        self.assertEqual("Equation cluster", formulas[0]["text"])
        self.assertEqual("", formulas[1]["text"])
53
+
54
+
55
# Allow running this test file directly (outside a pytest/unittest runner).
if __name__ == "__main__":
    unittest.main()