import gradio as gr from transformers import AutoModel, AutoTokenizer import torch import spaces import os import sys import tempfile import shutil import inspect from PIL import Image, ImageDraw, ImageFont, ImageOps import fitz import re import ast import numpy as np import base64 import html as html_lib import markdown as md_lib import latex2mathml.converter from collections import deque from io import StringIO, BytesIO HAS_IMAGE_EDITOR = hasattr(gr, "ImageEditor") HAS_PAINT = hasattr(gr, "Paint") HAS_BRUSH = hasattr(gr, "Brush") HAS_ERASER = hasattr(gr, "Eraser") HAS_REGION_WORKSPACE = HAS_PAINT or HAS_IMAGE_EDITOR # Model options — swap MODEL_NAME to reduce VRAM usage on GPUs with <= 8GB # # Full precision BF16 (~8GB VRAM) — original, highest accuracy MODEL_NAME = 'deepseek-ai/DeepSeek-OCR-2' # # FP8 dynamic quantization (~3.5GB VRAM) — ~50% VRAM reduction, 3750 downloads/mo # Requires Ampere GPU or newer (RTX 3070 is supported) # MODEL_NAME = 'richarddavison/DeepSeek-OCR-2-FP8' # # 8-bit quantization (~4GB VRAM) — same stack (torch 2.6, flash-attn 2.7.3, py3.12) # Explicitly supports dynamic resolution (0-6 patches), 140 downloads/mo # MODEL_NAME = 'mzbac/DeepSeek-OCR-2-8bit' tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) # flash_attention_2 requires a CUDA device at init time — not available on ZeroGPU at # module load. DeepseekOCR2 only supports 'flash_attention_2' and 'eager'; sdpa is not # implemented for this model class. Fall back to 'eager' when no GPU is present. # Locally with CUDA, flash_attention_2 is used for maximum throughput. _attn_impl = 'flash_attention_2' if torch.cuda.is_available() else 'eager' model = AutoModel.from_pretrained(MODEL_NAME, _attn_implementation=_attn_impl, torch_dtype=torch.bfloat16, trust_remote_code=True, use_safetensors=True).eval() # .cuda() is NOT called here — on ZeroGPU, GPU is only available inside @spaces.GPU # functions. Locally, model.cuda() is called inside process_image on first run. BASE_SIZE = 1024 IMAGE_SIZE = 768 CROP_MODE = True WORKSPACE_EDITOR_HEIGHT = 640 WORKSPACE_EDITOR_WIDTH_EST = 980 WORKSPACE_DEFAULT_SCALE = 89 GROUNDING_PATTERN = re.compile(r'<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>', re.DOTALL) INFER_DEBUG_FILTERS = ['PATCHES', '====', 'BASE:', 'directly resize', 'NO PATCHES', 'torch.Size', '%|'] EQUATION_ZOOM_PROMPT = "\n<|grounding|>Locate each individual equation or math line." EQUATION_LINE_OCR_PROMPT = "\nRead the math expression exactly as written. Return only the equation text." EQUATION_ZOOM_MAX_CANDIDATES = 6 EQUATION_ZOOM_MIN_AREA = 0.05 EQUATION_ZOOM_MIN_DIM = 0.24 EQUATION_ZOOM_PADDING = 0.025 EQUATION_ZOOM_MAX_ASPECT = 12.0 EQUATION_DETAIL_MAX_BOXES = 24 EQUATION_DETAIL_IOU_DEDUPE = 0.7 EQUATION_LINE_IOU_DEDUPE = 0.55 EQUATION_LINE_MIN_AREA = 0.0008 EQUATION_LINE_MIN_W = 0.03 EQUATION_LINE_MIN_H = 0.01 EQUATION_LINE_MAX_ASPECT = 30.0 MATH_LABEL_HINTS = ("formula", "equation", "math") MATH_STRONG_MARKERS = ("\\(", "\\[", "\\frac", "\\sum", "\\int", "\\sqrt", "\\lim", "\\begin{") MATH_WEAK_MARKERS = ("^", "_", "=", "+", "\\cdot", "\\times") TASK_PROMPTS = { "📋 Markdown": {"prompt": "\n<|grounding|>Convert the document to markdown.", "has_grounding": True}, "📝 Free OCR": {"prompt": "\nFree OCR.", "has_grounding": False}, "📍 Locate": {"prompt": "\nLocate <|ref|>text<|/ref|> in the image.", "has_grounding": True}, "🔍 Describe": {"prompt": "\nDescribe this image in detail.", "has_grounding": False}, "✏️ Custom": {"prompt": "", "has_grounding": False} } def extract_grounding_references(text): refs = [] seen = set() for entry in _extract_grounding_entries(text): coord_text = repr(entry["coords"]) key = ( entry["label"].strip().lower(), tuple( (round(c[0], 1), round(c[1], 1), round(c[2], 1), round(c[3], 1)) for c in entry["coords"] ), ) if key in seen: continue seen.add(key) raw = f'<|ref|>{entry["label"]}<|/ref|><|det|>{coord_text}<|/det|>' refs.append((raw, entry["label"], coord_text)) return refs def _parse_coord_payload(payload): if isinstance(payload, str): try: coords = ast.literal_eval(payload.strip()) except (SyntaxError, ValueError): return [] else: coords = payload if isinstance(coords, (tuple, list)) and coords and isinstance(coords[0], (int, float)): coords = [coords] if not isinstance(coords, list): return [] out = [] for c in coords: if not isinstance(c, (list, tuple)) or len(c) < 4: continue x1, y1, x2, y2 = [float(v) for v in c[:4]] x1, x2 = sorted((max(0.0, min(999.0, x1)), max(0.0, min(999.0, x2)))) y1, y2 = sorted((max(0.0, min(999.0, y1)), max(0.0, min(999.0, y2)))) if x2 <= x1 or y2 <= y1: continue out.append([x1, y1, x2, y2]) return out def _extract_grounding_entries(raw_text: str): if not raw_text: return [] entries = [] last_end = 0 for m in GROUNDING_PATTERN.finditer(raw_text): label = m.group(1).strip() or "text" coords = _parse_coord_payload(m.group(2)) if not coords: continue text_chunk = raw_text[last_end:m.start()].strip() entries.append({ "label": label, "coords": coords, "text": text_chunk, }) last_end = m.end() return entries def _math_marker_score(text_chunk: str) -> int: score = 0 for marker in MATH_STRONG_MARKERS: if marker in text_chunk: score += 3 for marker in MATH_WEAK_MARKERS: if marker in text_chunk: score += 1 return score def _box_iou(a, b): ax1, ay1, ax2, ay2 = a bx1, by1, bx2, by2 = b inter_x1 = max(ax1, bx1) inter_y1 = max(ay1, by1) inter_x2 = min(ax2, bx2) inter_y2 = min(ay2, by2) if inter_x2 <= inter_x1 or inter_y2 <= inter_y1: return 0.0 inter = (inter_x2 - inter_x1) * (inter_y2 - inter_y1) area_a = max(1e-9, (ax2 - ax1) * (ay2 - ay1)) area_b = max(1e-9, (bx2 - bx1) * (by2 - by1)) union = area_a + area_b - inter return inter / union if union > 0 else 0.0 def _dedupe_boxes(boxes, iou_threshold): kept = [] for box in sorted(boxes, key=lambda b: ((b[2] - b[0]) * (b[3] - b[1]))): if any(_box_iou(box, other) >= iou_threshold for other in kept): continue kept.append(box) return kept def _is_math_candidate(label: str, text_chunk: str, box): label_l = label.lower() box_w = (box[2] - box[0]) / 999.0 box_h = (box[3] - box[1]) / 999.0 area = box_w * box_h aspect = max(box_w / max(1e-9, box_h), box_h / max(1e-9, box_w)) has_math_label = any(hint in label_l for hint in MATH_LABEL_HINTS) has_math_text = _math_marker_score(text_chunk) >= 3 is_large = area >= EQUATION_ZOOM_MIN_AREA or box_w >= EQUATION_ZOOM_MIN_DIM or box_h >= EQUATION_ZOOM_MIN_DIM return (has_math_label or has_math_text) and is_large and aspect <= EQUATION_ZOOM_MAX_ASPECT def _map_crop_box_to_page(sub_box, crop_px, img_w, img_h): crop_x1, crop_y1, crop_x2, crop_y2 = crop_px crop_w = max(1, crop_x2 - crop_x1) crop_h = max(1, crop_y2 - crop_y1) page_x1 = ((crop_x1 + (sub_box[0] / 999.0) * crop_w) / img_w) * 999.0 page_y1 = ((crop_y1 + (sub_box[1] / 999.0) * crop_h) / img_h) * 999.0 page_x2 = ((crop_x1 + (sub_box[2] / 999.0) * crop_w) / img_w) * 999.0 page_y2 = ((crop_y1 + (sub_box[3] / 999.0) * crop_h) / img_h) * 999.0 return _parse_coord_payload([[page_x1, page_y1, page_x2, page_y2]])[0] def draw_bounding_boxes(image, refs, extract_images=False): img_w, img_h = image.size img_draw = image.copy() draw = ImageDraw.Draw(img_draw) overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0)) draw2 = ImageDraw.Draw(overlay) font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 15) crops = [] color_map = {} np.random.seed(42) for ref in refs: label = ref[1] if label not in color_map: color_map[label] = (np.random.randint(50, 255), np.random.randint(50, 255), np.random.randint(50, 255)) color = color_map[label] coords = _parse_coord_payload(ref[2]) color_a = color + (60,) for box in coords: x1, y1, x2, y2 = int(box[0]/999*img_w), int(box[1]/999*img_h), int(box[2]/999*img_w), int(box[3]/999*img_h) if extract_images and label == 'image': crops.append(image.crop((x1, y1, x2, y2))) width = 5 if label == 'title' else 3 draw.rectangle([x1, y1, x2, y2], outline=color, width=width) draw2.rectangle([x1, y1, x2, y2], fill=color_a) text_bbox = draw.textbbox((0, 0), label, font=font) tw, th = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1] ty = max(0, y1 - 20) draw.rectangle([x1, ty, x1 + tw + 4, ty + th + 4], fill=color) draw.text((x1 + 2, ty + 2), label, font=font, fill=(255, 255, 255)) img_draw.paste(overlay, (0, 0), overlay) return img_draw, crops def _extract_labeled_crops_from_refs(image, refs, max_items=24): img_w, img_h = image.size items = [] seen = set() for ref in refs: label = str(ref[1]) coords = _parse_coord_payload(ref[2]) for box in coords: x1 = int(box[0] / 999.0 * img_w) y1 = int(box[1] / 999.0 * img_h) x2 = int(box[2] / 999.0 * img_w) y2 = int(box[3] / 999.0 * img_h) if x2 - x1 < 8 or y2 - y1 < 8: continue key = (label.lower(), x1, y1, x2, y2) if key in seen: continue seen.add(key) crop = image.crop((x1, y1, x2, y2)) caption = f"{label} ({crop.width}x{crop.height})" items.append((crop, caption)) if len(items) >= max_items: return items return items def clean_output(text, include_images=False): if not text: return "" pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)' matches = re.findall(pattern, text, re.DOTALL) img_num = 0 for match in matches: if '<|ref|>image<|/ref|>' in match[0]: if include_images: text = text.replace(match[0], f'\n\n**[Figure {img_num + 1}]**\n\n', 1) img_num += 1 else: text = text.replace(match[0], '', 1) else: text = re.sub(rf'(?m)^[^\n]*{re.escape(match[0])}[^\n]*\n?', '', text) text = _strip_malformed_grounding(text) text = _dedupe_repeated_math_blocks(text) return text.strip() def _strip_malformed_grounding(text: str) -> str: """Remove incomplete grounding tags that can leak into OCR markdown/text.""" if not text: return "" line_patterns = [ r'(?m)^[^\n]*<\|ref\|>.*?<\|/ref\|><\|det\|>.*?(?:<\|/det\|>)?[^\n]*\n?', r'(?m)^[^\n]*<\|det\|>.*?(?:<\|/det\|>)?[^\n]*\n?', r'(?m)^[^\n]*<\|/?ref\|>[^\n]*\n?', ] for p in line_patterns: text = re.sub(p, '', text) text = re.sub(r'<\|/?ref\|>', '', text) text = re.sub(r'<\|/?det\|>', '', text) return text def _equation_text_key(text: str) -> str: if not text: return "" key = text.strip() key = re.sub(r'\\\[(.+?)\\\]', r'\1', key, flags=re.DOTALL) key = re.sub(r'\\\((.+?)\\\)', r'\1', key, flags=re.DOTALL) key = re.sub(r'\$\$(.+?)\$\$', r'\1', key, flags=re.DOTALL) key = re.sub(r'\^\{([A-Za-z0-9])\}', r'^\1', key) key = re.sub(r'_\{([A-Za-z0-9])\}', r'_\1', key) key = re.sub(r'\s+', '', key) return key.lower() def _dedupe_repeated_math_blocks(text: str) -> str: if not text: return "" pattern = re.compile(r'\\\[(.+?)\\\]|\\\((.+?)\\\)|\$\$(.+?)\$\$', re.DOTALL) seen = set() out = [] last = 0 removed_any = False for m in pattern.finditer(text): out.append(text[last:m.start()]) expr = m.group(1) or m.group(2) or m.group(3) or "" key = _equation_text_key(expr) if key and key in seen: removed_any = True else: if key: seen.add(key) out.append(m.group(0)) last = m.end() out.append(text[last:]) merged = ''.join(out) if removed_any: merged = re.sub(r'\n{3,}', '\n\n', merged) return merged PREVIEW_CSS = """ """ def _inject_spatial_gap_placeholders(text: str): """Preserve runs of blank lines so OCR spacing is visible in preview.""" gaps: dict[str, int] = {} counter = [0] def repl(m): key = f'ZZOCRGAP{counter[0]}ZZ' counter[0] += 1 # Two newlines are a normal paragraph break; extras represent vertical spacing. gaps[key] = max(1, len(m.group(0)) - 2) return f'\n\n{key}\n\n' return re.sub(r'\n{3,}', repl, text), gaps def _restore_spatial_gap_placeholders(html: str, gaps: dict[str, int]) -> str: if not gaps: return html for key, extra_lines in gaps.items(): gap_em = min(10.0, 0.9 * extra_lines) block = f'
' html = html.replace(f'

{key}

', block) html = html.replace(key, block) return html def _to_mathml(latex: str, display: bool) -> str: """Convert a LaTeX string to MathML. Falls back to a code block on error.""" # Fix OCR error: \frac{n/m} (single-argument fraction) → \frac{n}{m} latex = re.sub(r'\\frac\{(\d+)/(\d+)\}(?!\s*\{)', r'\\frac{\1}{\2}', latex) try: mathml = latex2mathml.converter.convert(latex) if display: mathml = re.sub(r'{escaped}' return f'{escaped}' def to_math_html(text: str) -> str: """Convert model markdown output to HTML with server-side MathML rendering. Uses a placeholder approach: math is extracted and replaced with unique tokens before the markdown pass, then swapped back afterwards. This avoids Python-Markdown mishandling multi-line
blocks that contain blank lines. """ if not text: return "" blocks: dict[str, str] = {} literals: dict[str, str] = {} counter = [0] def display_block(m): key = f'ZZDISPLAYMATH{counter[0]}ZZ' counter[0] += 1 expr = m.group(1).strip() blocks[key] = f'
{_to_mathml(expr, display=True)}
' literals[key] = f'\\[{expr}\\]' return f'\n\n{key}\n\n' def inline_math(m): key = f'ZZINLINEMATH{counter[0]}ZZ' counter[0] += 1 expr = m.group(1).strip() blocks[key] = _to_mathml(expr, display=False) literals[key] = f'\\({expr}\\)' return key # Replace display math \[...\] with placeholder tokens text = re.sub(r'\\\[(.+?)\\\]', display_block, text, flags=re.DOTALL) # Remove orphaned \[ with no matching \] (truncated model output) text = re.sub(r'\\\[.*', '', text, flags=re.DOTALL) # Replace inline math \(...\) with placeholder tokens text = re.sub(r'\\\((.+?)\\\)', inline_math, text) text, gaps = _inject_spatial_gap_placeholders(text) # Run markdown on text that now contains only safe placeholder tokens html = md_lib.markdown(text, extensions=['tables', 'fenced_code', 'sane_lists', 'nl2br']) # Protect rendered code/pre blocks so placeholder swap never mutates literal code. protected_blocks: dict[str, str] = {} protected_counter = [0] def _protect_code_html(m): token = f'ZZCODEHTML{protected_counter[0]}ZZ' protected_counter[0] += 1 protected_blocks[token] = m.group(0) return token html = re.sub(r']*>.*?', _protect_code_html, html, flags=re.DOTALL) html = re.sub(r']*>.*?', _protect_code_html, html, flags=re.DOTALL) # Swap placeholders back for MathML/HTML (handle

KEY

wrapping too) for key, value in blocks.items(): html = html.replace(f'

{key}

', value) html = html.replace(key, value) # Restore protected literal code/pre blocks unchanged. for token, original in protected_blocks.items(): html = html.replace(token, original) # Placeholders left at this stage occur inside code/pre; keep them literal. for key, literal in literals.items(): html = html.replace(key, html_lib.escape(literal)) html = _restore_spatial_gap_placeholders(html, gaps) return f'
{html}
' def to_mathjax_html(text: str) -> str: """Render markdown to HTML and typeset math client-side with MathJax.""" if not text: return "" text, gaps = _inject_spatial_gap_placeholders(text) html = md_lib.markdown(text, extensions=['tables', 'fenced_code', 'sane_lists', 'nl2br']) html = _restore_spatial_gap_placeholders(html, gaps) return f'
{html}
' def _grounding_blocks_from_raw(raw_text: str): blocks = [] for entry in _extract_grounding_entries(raw_text): label = entry["label"] text = entry["text"].strip() coords = entry["coords"] for idx, c in enumerate(coords): blocks.append({ "label": label, "text": text if idx == 0 else "", "x1": c[0], "y1": c[1], "x2": c[2], "y2": c[3], }) return blocks def to_spatial_html(raw_text: str, markdown_text: str) -> str: """Render OCR content using grounding boxes for spatially-positioned blocks.""" blocks = _grounding_blocks_from_raw(raw_text) if not blocks: return to_mathjax_html(markdown_text) used_text = 0 rendered = [] palette = { "title": "#8b5cf6", "text": "#2563eb", "image": "#059669", "table": "#d97706", "formula": "#dc2626", } for i, b in enumerate(sorted(blocks, key=lambda x: (x["y1"], x["x1"]))): label = b["label"] color = palette.get(label.lower(), "#4b5563") body = b["text"].strip() if body: used_text += len(body) body_text, gaps = _inject_spatial_gap_placeholders(body) body_html = md_lib.markdown(body_text, extensions=['tables', 'fenced_code', 'sane_lists', 'nl2br']) body_html = _restore_spatial_gap_placeholders(body_html, gaps) else: body_html = "" if not body_html: body_html = f"

{html_lib.escape(label)}

" left = b["x1"] / 999.0 * 100.0 top = b["y1"] / 999.0 * 100.0 width = max(1.0, (b["x2"] - b["x1"]) / 999.0 * 100.0) height = max(1.2, (b["y2"] - b["y1"]) / 999.0 * 100.0) rendered.append( f"""
{html_lib.escape(label)}
{body_html}
""" ) fallback = "" if markdown_text and used_text < max(120, int(len(markdown_text) * 0.4)): fallback_html = to_mathjax_html(markdown_text) fallback = f"""
Show full linear markdown rendering {fallback_html}
""" return f"""
{''.join(rendered)}
{fallback}
""" def embed_images(markdown, crops): if not crops: return markdown for i, img in enumerate(crops): buf = BytesIO() img.save(buf, format="PNG") b64 = base64.b64encode(buf.getvalue()).decode() markdown = markdown.replace(f'**[Figure {i + 1}]**', f'\n\n![Figure {i + 1}](data:image/png;base64,{b64})\n\n', 1) return markdown def _infer_with_prompt(image, prompt, crop_mode=None): if crop_mode is None: crop_mode = CROP_MODE tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') image.save(tmp.name, 'JPEG', quality=95) tmp.close() out_dir = tempfile.mkdtemp() stdout = sys.stdout capture = StringIO() sys.stdout = capture try: model.infer( tokenizer=tokenizer, prompt=prompt, image_file=tmp.name, output_path=out_dir, base_size=BASE_SIZE, image_size=IMAGE_SIZE, crop_mode=crop_mode, save_results=False ) finally: sys.stdout = stdout os.unlink(tmp.name) shutil.rmtree(out_dir, ignore_errors=True) lines = [ l for l in capture.getvalue().split('\n') if l.strip() and not any(s in l for s in INFER_DEBUG_FILTERS) ] return '\n'.join(lines).strip() def _refine_equation_refs(image, raw_text): entries = _extract_grounding_entries(raw_text) if not entries: return [] img_w, img_h = image.size candidates = [] for entry in entries: for box in entry["coords"]: if _is_math_candidate(entry["label"], entry["text"], box): area = (box[2] - box[0]) * (box[3] - box[1]) candidates.append((area, entry, box)) if not candidates: return [] candidates.sort(key=lambda x: x[0], reverse=True) refined_refs = [] for _, entry, box in candidates[:EQUATION_ZOOM_MAX_CANDIDATES]: x1 = int(box[0] / 999.0 * img_w) y1 = int(box[1] / 999.0 * img_h) x2 = int(box[2] / 999.0 * img_w) y2 = int(box[3] / 999.0 * img_h) box_w = max(1, x2 - x1) box_h = max(1, y2 - y1) pad_x = max(8, int(box_w * EQUATION_ZOOM_PADDING)) pad_y = max(8, int(box_h * EQUATION_ZOOM_PADDING)) crop_x1 = max(0, x1 - pad_x) crop_y1 = max(0, y1 - pad_y) crop_x2 = min(img_w, x2 + pad_x) crop_y2 = min(img_h, y2 + pad_y) if crop_x2 - crop_x1 < 32 or crop_y2 - crop_y1 < 32: continue crop = image.crop((crop_x1, crop_y1, crop_x2, crop_y2)) sub_result = _infer_with_prompt(crop, EQUATION_ZOOM_PROMPT) sub_entries = _extract_grounding_entries(sub_result) if not sub_entries: continue mapped_boxes = [] for sub in sub_entries: sub_label = sub["label"].lower() sub_text = sub["text"] is_math_sub = any(hint in sub_label for hint in MATH_LABEL_HINTS) or _math_marker_score(sub_text) >= 3 if sub_label in ("image", "table") or not is_math_sub: continue for sub_box in sub["coords"]: mapped = _map_crop_box_to_page(sub_box, (crop_x1, crop_y1, crop_x2, crop_y2), img_w, img_h) w = (mapped[2] - mapped[0]) / 999.0 h = (mapped[3] - mapped[1]) / 999.0 if w * h < 0.0004: continue mapped_boxes.append(mapped) if not mapped_boxes: continue mapped_boxes = _dedupe_boxes(mapped_boxes, EQUATION_DETAIL_IOU_DEDUPE) mapped_boxes = sorted(mapped_boxes, key=lambda b: (b[1], b[0]))[:EQUATION_DETAIL_MAX_BOXES] if len(mapped_boxes) < 2: continue merged_text = repr(mapped_boxes) label = "equation_detail" raw = f'<|ref|>{label}<|/ref|><|det|>{merged_text}<|/det|>' refined_refs.append((raw, label, merged_text)) return refined_refs def _norm_box_to_pixels(box, img_w, img_h, pad_ratio=0.0): x1 = int(box[0] / 999.0 * img_w) y1 = int(box[1] / 999.0 * img_h) x2 = int(box[2] / 999.0 * img_w) y2 = int(box[3] / 999.0 * img_h) if pad_ratio > 0: pad_x = max(1, int((x2 - x1) * pad_ratio)) pad_y = max(1, int((y2 - y1) * pad_ratio)) x1 -= pad_x y1 -= pad_y x2 += pad_x y2 += pad_y x1 = max(0, min(img_w - 1, x1)) y1 = max(0, min(img_h - 1, y1)) x2 = max(x1 + 1, min(img_w, x2)) y2 = max(y1 + 1, min(img_h, y2)) return (x1, y1, x2, y2) def _detect_equation_line_boxes(image, infer_crop_mode=None): detect_raw = _infer_with_prompt(image, EQUATION_ZOOM_PROMPT, crop_mode=infer_crop_mode) entries = _extract_grounding_entries(detect_raw) if not entries: return [], detect_raw boxes = [] for entry in entries: label_l = entry["label"].lower() text_chunk = entry["text"] if label_l in ("image", "table"): continue for box in entry["coords"]: w = (box[2] - box[0]) / 999.0 h = (box[3] - box[1]) / 999.0 area = w * h aspect = max(w / max(1e-9, h), h / max(1e-9, w)) looks_math = any(hint in label_l for hint in MATH_LABEL_HINTS) or _math_marker_score(text_chunk) >= 2 if area < EQUATION_LINE_MIN_AREA or w < EQUATION_LINE_MIN_W or h < EQUATION_LINE_MIN_H: continue if aspect > EQUATION_LINE_MAX_ASPECT: continue if not looks_math and area < 0.004: continue boxes.append(box) boxes = _dedupe_boxes(boxes, EQUATION_LINE_IOU_DEDUPE) boxes = sorted(boxes, key=lambda b: (round(b[1], 3), b[0])) return boxes, detect_raw def _process_equation_lines_separately(image, infer_crop_mode=None): boxes, detect_raw = _detect_equation_line_boxes(image, infer_crop_mode=infer_crop_mode) if not boxes: return None img_w, img_h = image.size cleaned_parts = [] markdown_parts = [] raw_parts = [f"## Detection\n\n{detect_raw}".strip()] refs = [] crops = [] seen_line_keys = set() for i, box in enumerate(boxes, 1): x1, y1, x2, y2 = _norm_box_to_pixels(box, img_w, img_h, pad_ratio=0.01) crop = image.crop((x1, y1, x2, y2)) line_raw = _infer_with_prompt(crop, EQUATION_LINE_OCR_PROMPT, crop_mode=False) line_clean = clean_output(line_raw, False).strip() if not line_clean: continue line_key = _equation_text_key(line_clean) if line_key and line_key in seen_line_keys: continue if line_key: seen_line_keys.add(line_key) line_label = f"Eq {i}" line_markdown = line_clean if "$$" not in line_markdown and "\\[" not in line_markdown and "\\(" not in line_markdown: line_markdown = f"$$\n{line_markdown}\n$$" cleaned_parts.append(f"{line_label}: {line_clean}") markdown_parts.append(f"### {line_label}\n\n{line_markdown}") raw_parts.append(f"## {line_label}\n\n{line_raw}") coord_text = repr([box]) raw_ref = f'<|ref|>eq_line_{i}<|/ref|><|det|>{coord_text}<|/det|>' refs.append((raw_ref, line_label, coord_text)) crops.append((crop, line_label)) if not cleaned_parts: return None img_out, _ = draw_bounding_boxes(image, refs, extract_images=False) cleaned = "\n".join(cleaned_parts).strip() markdown = "\n\n".join(markdown_parts).strip() raw = "\n\n".join(raw_parts).strip() return cleaned, markdown, raw, img_out, crops @spaces.GPU(duration=90) def process_image(image, task, custom_prompt, enable_equation_zoom=True, infer_crop_mode=None, separate_equation_lines=False): model.cuda() # GPU is available here — works on ZeroGPU and locally if image is None: return "Error: Upload an image", "", "", None, [] if not separate_equation_lines and task in ["✏️ Custom", "📍 Locate"] and not custom_prompt.strip(): return "Please enter a prompt", "", "", None, [] if image.mode in ('RGBA', 'LA', 'P'): image = image.convert('RGB') image = ImageOps.exif_transpose(image) if separate_equation_lines: separate_result = _process_equation_lines_separately(image, infer_crop_mode=infer_crop_mode) if separate_result is not None: return separate_result msg = "No separate equation lines detected. Try Selected Region + freehand highlight around the equation steps." return msg, msg, msg, None, [] if task == "✏️ Custom": prompt = f"\n{custom_prompt.strip()}" has_grounding = '<|grounding|>' in custom_prompt elif task == "📍 Locate": prompt = f"\nLocate <|ref|>{custom_prompt.strip()}<|/ref|> in the image." has_grounding = True else: prompt = TASK_PROMPTS[task]["prompt"] has_grounding = TASK_PROMPTS[task]["has_grounding"] result = _infer_with_prompt(image, prompt, crop_mode=infer_crop_mode) if not result: return "No text detected", "", "", None, [] cleaned = clean_output(result, False) markdown = clean_output(result, True) img_out = None crops = [] figure_crops = [] result_for_layout = result if has_grounding and '<|ref|>' in result: refs = extract_grounding_references(result) if task == "📋 Markdown" and enable_equation_zoom: refs.extend(_refine_equation_refs(image, result)) if refs: img_out, figure_crops = draw_bounding_boxes(image, refs, True) crops = _extract_labeled_crops_from_refs(image, refs) synthetic = [r[0] for r in refs if r[1] == "equation_detail"] if synthetic: result_for_layout = result + "\n" + "\n".join(synthetic) markdown = embed_images(markdown, figure_crops) if not crops and figure_crops: crops = _label_gallery_items(figure_crops, prefix="Figure") return cleaned, markdown, result_for_layout, img_out, crops @spaces.GPU(duration=90) def process_pdf(path, task, custom_prompt, page_num, enable_equation_zoom=True, infer_crop_mode=None, separate_equation_lines=False): doc = fitz.open(path) total_pages = len(doc) if page_num < 1 or page_num > total_pages: doc.close() return f"Invalid page number. PDF has {total_pages} pages.", "", "", None, [] page = doc.load_page(page_num - 1) pix = page.get_pixmap(matrix=fitz.Matrix(300/72, 300/72), alpha=False) img = Image.open(BytesIO(pix.tobytes("png"))) doc.close() return process_image( img, task, custom_prompt, enable_equation_zoom=enable_equation_zoom, infer_crop_mode=infer_crop_mode, separate_equation_lines=separate_equation_lines, ) def process_file(path, task, custom_prompt, page_num, enable_equation_zoom=True, infer_crop_mode=None, separate_equation_lines=False): if not path: return "Error: Upload a file", "", "", None, [] if path.lower().endswith('.pdf'): return process_pdf( path, task, custom_prompt, page_num, enable_equation_zoom=enable_equation_zoom, infer_crop_mode=infer_crop_mode, separate_equation_lines=separate_equation_lines, ) else: return process_image( Image.open(path), task, custom_prompt, enable_equation_zoom=enable_equation_zoom, infer_crop_mode=infer_crop_mode, separate_equation_lines=separate_equation_lines, ) def _extract_editor_background(editor_value): if editor_value is None: return None if isinstance(editor_value, Image.Image): return editor_value if isinstance(editor_value, dict): background = editor_value.get("background") if isinstance(background, Image.Image): return background composite = editor_value.get("composite") if isinstance(composite, Image.Image): return composite return None def _to_rgba_image(obj): if isinstance(obj, dict): for k in ("image", "layer", "composite", "background", "mask"): if k in obj: return _to_rgba_image(obj[k]) return None if isinstance(obj, Image.Image): return obj.convert("RGBA") if isinstance(obj, np.ndarray): arr = obj if arr.ndim == 2: arr = np.stack([arr, arr, arr, np.full_like(arr, 255)], axis=-1) elif arr.ndim == 3 and arr.shape[2] == 3: alpha = np.full((arr.shape[0], arr.shape[1], 1), 255, dtype=arr.dtype) arr = np.concatenate([arr, alpha], axis=2) elif arr.ndim != 3 or arr.shape[2] != 4: return None return Image.fromarray(arr.astype(np.uint8), mode="RGBA") return None def _to_mask_array(obj): if obj is None: return None if isinstance(obj, dict): for k in ("mask", "image", "layer", "composite", "background"): if k in obj: arr = _to_mask_array(obj[k]) if arr is not None: return arr return None if isinstance(obj, Image.Image): arr = np.asarray(obj) elif isinstance(obj, np.ndarray): arr = obj else: return None if arr.ndim == 2: return arr > 0 if arr.ndim == 3: if arr.shape[2] >= 4: return arr[:, :, 3] > 0 return np.max(arr[:, :, :3], axis=2) > 0 return None def _locate_patch_bbox(base_image: Image.Image, patch_image: Image.Image): """Approximate patch location in base image using downscaled SSD search.""" if base_image is None or patch_image is None: return None base = np.asarray(base_image.convert("L"), dtype=np.float32) patch = np.asarray(patch_image.convert("L"), dtype=np.float32) bh, bw = base.shape[:2] ph, pw = patch.shape[:2] if ph <= 0 or pw <= 0 or ph > bh or pw > bw: return None max_dim = max(bh, bw) scale = min(1.0, 320.0 / max_dim) if max_dim > 0 else 1.0 if scale < 1.0: new_bw = max(1, int(round(bw * scale))) new_bh = max(1, int(round(bh * scale))) new_pw = max(1, int(round(pw * scale))) new_ph = max(1, int(round(ph * scale))) base_small = np.asarray(Image.fromarray(base.astype(np.uint8)).resize((new_bw, new_bh), Image.Resampling.BILINEAR), dtype=np.float32) patch_small = np.asarray(Image.fromarray(patch.astype(np.uint8)).resize((new_pw, new_ph), Image.Resampling.BILINEAR), dtype=np.float32) else: base_small = base patch_small = patch sbh, sbw = base_small.shape sph, spw = patch_small.shape if sph > sbh or spw > sbw: return None best_score = float("inf") best_x = 0 best_y = 0 for y in range(sbh - sph + 1): row = base_small[y:y + sph, :] windows = np.lib.stride_tricks.sliding_window_view(row, spw, axis=1) # windows: (sph, sbw-spw+1, spw) diff = windows - patch_small[:, None, :] scores = np.mean(diff * diff, axis=(0, 2)) x = int(np.argmin(scores)) score = float(scores[x]) if score < best_score: best_score = score best_x = x best_y = y if scale < 1.0: x1 = int(round(best_x / scale)) y1 = int(round(best_y / scale)) x2 = int(round((best_x + spw) / scale)) y2 = int(round((best_y + sph) / scale)) else: x1, y1, x2, y2 = best_x, best_y, best_x + spw, best_y + sph x1 = max(0, min(bw - 1, x1)) y1 = max(0, min(bh - 1, y1)) x2 = max(x1 + 1, min(bw, x2)) y2 = max(y1 + 1, min(bh, y2)) return (x1, y1, x2, y2) def _component_boxes(binary_mask, min_pixels=24): h, w = binary_mask.shape visited = np.zeros((h, w), dtype=bool) boxes = [] neighbors = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)] ys, xs = np.where(binary_mask) for sy, sx in zip(ys.tolist(), xs.tolist()): if visited[sy, sx]: continue q = deque([(sy, sx)]) visited[sy, sx] = True min_x = max_x = sx min_y = max_y = sy count = 0 while q: y, x = q.popleft() count += 1 if x < min_x: min_x = x if x > max_x: max_x = x if y < min_y: min_y = y if y > max_y: max_y = y for dy, dx in neighbors: ny, nx = y + dy, x + dx if ny < 0 or ny >= h or nx < 0 or nx >= w: continue if visited[ny, nx] or not binary_mask[ny, nx]: continue visited[ny, nx] = True q.append((ny, nx)) if count >= min_pixels: boxes.append((min_x, min_y, max_x + 1, max_y + 1, count)) return boxes def _extract_regions_from_mask(background, mask): components = _component_boxes(mask, min_pixels=24) if not components: return [] regions = [] for x1, y1, x2, y2, _ in components: pad_x = max(2, int((x2 - x1) * 0.02)) pad_y = max(2, int((y2 - y1) * 0.02)) px1 = max(0, x1 - pad_x) py1 = max(0, y1 - pad_y) px2 = min(background.width, x2 + pad_x) py2 = min(background.height, y2 + pad_y) if px2 <= px1 or py2 <= py1: continue crop = background.crop((px1, py1, px2, py2)).convert("RGB") regions.append((crop, (px1, py1, px2, py2))) regions.sort( key=lambda item: (item[1][2] - item[1][0]) * (item[1][3] - item[1][1]), reverse=True, ) return regions def _editor_background_and_mask(editor_value): if not isinstance(editor_value, dict): return None, None background = _to_rgba_image(editor_value.get("background")) if background is None: background = _to_rgba_image(editor_value.get("image")) composite = _to_rgba_image(editor_value.get("composite")) layers = editor_value.get("layers") or [] if background is None: if composite is None: return None, None background = composite mask = _to_mask_array(editor_value.get("mask")) if mask is not None: if mask.shape[:2] != (background.height, background.width): mask_img = Image.fromarray(mask.astype(np.uint8) * 255, mode="L") nearest = Image.Resampling.NEAREST if hasattr(Image, "Resampling") else Image.NEAREST mask = np.asarray(mask_img.resize((background.width, background.height), nearest)) > 0 return background, mask if not isinstance(layers, list) or not layers: return background, None alpha_acc = np.zeros((background.height, background.width), dtype=np.uint8) for layer in layers: layer_img = _to_rgba_image(layer) if layer_img is None: continue if layer_img.size != background.size: nearest = Image.Resampling.NEAREST if hasattr(Image, "Resampling") else Image.NEAREST layer_img = layer_img.resize(background.size, nearest) layer_alpha = np.asarray(layer_img, dtype=np.uint8)[:, :, 3] alpha_acc = np.maximum(alpha_acc, layer_alpha) return background, (alpha_acc > 0) def _extract_selected_regions(editor_value, base_size=None, base_image=None): if editor_value is None: return [] if isinstance(editor_value, Image.Image): if base_size and tuple(editor_value.size) == tuple(base_size): return [] bbox = _locate_patch_bbox(base_image, editor_value) if base_image is not None else None return [(editor_value, bbox)] if not isinstance(editor_value, dict): return [] background, mask = _editor_background_and_mask(editor_value) layers = editor_value.get("layers") or [] if background is None: return [] if not isinstance(layers, list) or not layers: # No annotation layers; treat as explicit crop only if size changed from base. if base_size and tuple(background.size) == tuple(base_size): return [] patch = background.convert("RGB") bbox = _locate_patch_bbox(base_image, patch) if base_image is not None else None return [(patch, bbox)] if mask is None: return [] return _extract_regions_from_mask(background, mask) def _extract_new_drawn_regions(editor_value, base_size=None, base_image=None, consumed_mask=None): # For crop mode / explicit cropped image, fall back to classic extraction. if isinstance(editor_value, Image.Image): regions = _extract_selected_regions(editor_value, base_size=base_size, base_image=base_image) return regions, consumed_mask if not isinstance(editor_value, dict): return [], consumed_mask background, mask = _editor_background_and_mask(editor_value) layers = editor_value.get("layers") or [] if background is None: return [], consumed_mask has_layer_data = isinstance(layers, list) and len(layers) > 0 has_draw_data = (mask is not None) or has_layer_data # If there are no draw layers/mask, treat as explicit crop mode. if not has_draw_data: regions = _extract_selected_regions(editor_value, base_size=base_size, base_image=base_image) return regions, consumed_mask if mask is None: return [], consumed_mask if consumed_mask is None or not isinstance(consumed_mask, np.ndarray) or consumed_mask.shape != mask.shape: delta_mask = mask else: delta_mask = np.logical_and(mask, np.logical_not(consumed_mask)) regions = _extract_regions_from_mask(background, delta_mask) return regions, mask def _extract_selected_region(editor_value, base_size=None, base_image=None): regions = _extract_selected_regions(editor_value, base_size=base_size, base_image=base_image) if not regions: return None, None return regions[0] def _bbox_overlap_ratio(a, b): ax1, ay1, ax2, ay2 = a bx1, by1, bx2, by2 = b ix1 = max(ax1, bx1) iy1 = max(ay1, by1) ix2 = min(ax2, bx2) iy2 = min(ay2, by2) if ix2 <= ix1 or iy2 <= iy1: return 0.0, 0.0 inter = float((ix2 - ix1) * (iy2 - iy1)) area_a = float(max(1, (ax2 - ax1) * (ay2 - ay1))) area_b = float(max(1, (bx2 - bx1) * (by2 - by1))) return inter / area_a, inter / area_b def _is_duplicate_bbox(candidate_bbox, existing_bbox): iou = _box_iou(candidate_bbox, existing_bbox) cover_cand, cover_exist = _bbox_overlap_ratio(candidate_bbox, existing_bbox) return iou >= 0.85 or cover_cand >= 0.92 or cover_exist >= 0.97 def _draw_selected_region_boxes(image, boxes): if image is None or not boxes: return None refs = [] w, h = image.size for i, b in enumerate(boxes, 1): x1, y1, x2, y2 = b nx1 = max(0.0, min(999.0, x1 / max(1, w) * 999.0)) ny1 = max(0.0, min(999.0, y1 / max(1, h) * 999.0)) nx2 = max(0.0, min(999.0, x2 / max(1, w) * 999.0)) ny2 = max(0.0, min(999.0, y2 / max(1, h) * 999.0)) label = f"Region {i}" coord_text = repr([[nx1, ny1, nx2, ny2]]) raw = f'<|ref|>region_{i}<|/ref|><|det|>{coord_text}<|/det|>' refs.append((raw, label, coord_text)) img_out, _ = draw_bounding_boxes(image, refs, extract_images=False) return img_out def _region_gallery_items(regions): items = [] for i, r in enumerate(regions, 1): img = r["image"] label = f"Region {i}" if isinstance(img, Image.Image): label = f"{label} ({img.width}x{img.height})" items.append((img, label)) return items def _label_gallery_items(items, prefix=None): labeled = [] for i, item in enumerate(items, 1): if isinstance(item, tuple) and len(item) >= 2: img, label = item[0], str(item[1]) else: img, label = item, f"Item {i}" if prefix: label = f"{prefix} - {label}" if isinstance(img, Image.Image): label = f"{label} ({img.width}x{img.height})" labeled.append((img, label)) return labeled def _reset_selected_regions(): return [], [], "No saved regions." def _reset_drawn_mask(): return None def add_selected_region(editor_value, base_size, base_image, selected_regions, consumed_mask): candidates, updated_mask = _extract_new_drawn_regions( editor_value, base_size=base_size, base_image=base_image, consumed_mask=consumed_mask, ) regions = list(selected_regions or []) if not candidates: msg = "No region detected. Use Crop or draw/highlight a region first." return regions, _region_gallery_items(regions), msg, updated_mask existing_boxes = [r.get("bbox") for r in regions if r.get("bbox") is not None] added = 0 for region_img, bbox in candidates: if bbox is not None and any(_is_duplicate_bbox(bbox, eb) for eb in existing_boxes): continue regions.append({"image": region_img, "bbox": bbox}) if bbox is not None: existing_boxes.append(bbox) added += 1 if added == 0: msg = "No new region added. Draw one region, click Add Region, then draw the next region." return regions, _region_gallery_items(regions), msg, updated_mask msg = f"Added {added} region(s). {len(regions)} total. Zoom/pan is preserved." return regions, _region_gallery_items(regions), msg, updated_mask def clear_selected_regions(): return _reset_selected_regions() def clear_regions_preserve_view(editor_value): regions, gallery_items, msg = _reset_selected_regions() _, mask = _editor_background_and_mask(editor_value) return regions, gallery_items, msg, mask def _compose_ui_outputs(cleaned, markdown, raw, img_out, gallery_items): text_display = re.sub( r'\\\[(.+?)\\\]', lambda m: f'\n$$\n{m.group(1).strip()}\n$$\n', cleaned, flags=re.DOTALL ) text_display = re.sub(r'\\\((.+?)\\\)', lambda m: f'${m.group(1).strip()}$', text_display) dl_tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.md', mode='w', encoding='utf-8') dl_tmp.write(cleaned) dl_tmp.close() markdown_html = to_math_html(markdown) return ( text_display, cleaned, markdown_html, raw, img_out, gallery_items, gr.DownloadButton(value=dl_tmp.name, visible=True), ) def toggle_prompt(task): if task == "✏️ Custom": return gr.update(visible=True, label="Custom Prompt", placeholder="Add <|grounding|> for bounding boxes") elif task == "📍 Locate": return gr.update(visible=True, label="Text to Locate", placeholder="Enter text to locate") return gr.update(visible=False) def select_boxes(task): if task == "📍 Locate": return gr.update(selected="tab_boxes") return gr.update() def toggle_scope_ui(scope): if scope == "Selected Region": hint = ( "**Selected Region mode:** Draw/highlight on the workspace, click **Add Region** " "for each target area, then click **Extract**." ) return ( gr.update(value=hint), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), ) hint = "**Entire Page mode:** No drawing needed. Click **Extract** to process the full page." return ( gr.update(value=hint), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), ) def select_post_extract_tab(task, scope): if scope == "Selected Region" or task == "📍 Locate": return gr.update(selected="tab_boxes") return gr.update(selected="tab_text") def get_pdf_page_count(file_path): if not file_path or not file_path.lower().endswith('.pdf'): return 1 doc = fitz.open(file_path) count = len(doc) doc.close() return count def load_image(file_path, page_num=1): if not file_path: return None if file_path.lower().endswith('.pdf'): doc = fitz.open(file_path) page_idx = max(0, min(int(page_num) - 1, len(doc) - 1)) page = doc.load_page(page_idx) pix = page.get_pixmap(matrix=fitz.Matrix(300/72, 300/72), alpha=False) img = Image.open(BytesIO(pix.tobytes("png"))) doc.close() return img else: return Image.open(file_path) def _scale_workspace_image(img, workspace_scale): if img is None: return None # Keep native pixels for workspace quality. Gradio's in-canvas zoom controls # visual scale; pre-resampling here causes blurry math when users zoom in. return img def _prepare_workspace_image(img, workspace_scale=WORKSPACE_DEFAULT_SCALE): if img is None: return None, None, None display_img = _scale_workspace_image(img, workspace_scale) return display_img, (int(display_img.width), int(display_img.height)), display_img def load_image_with_size(file_path, page_num=1, workspace_scale=WORKSPACE_DEFAULT_SCALE): img = load_image(file_path, page_num) return _prepare_workspace_image(img, workspace_scale) def load_example_into_workspace(example_value): if example_value is None: return None, None, None file_path = None if isinstance(example_value, os.PathLike): file_path = os.fspath(example_value) elif isinstance(example_value, str): file_path = example_value elif isinstance(example_value, dict): path_candidate = example_value.get("path") or example_value.get("name") if isinstance(path_candidate, os.PathLike): file_path = os.fspath(path_candidate) elif isinstance(path_candidate, str): file_path = path_candidate elif isinstance(example_value, (list, tuple)) and example_value: first = example_value[0] if isinstance(first, os.PathLike): file_path = os.fspath(first) elif isinstance(first, str): file_path = first if file_path: img = load_image(file_path, 1) return _prepare_workspace_image(img, WORKSPACE_DEFAULT_SCALE) if isinstance(example_value, Image.Image): img = example_value else: maybe_rgba = _to_rgba_image(example_value) if maybe_rgba is None: return None, None, None img = maybe_rgba.convert("RGB") return _prepare_workspace_image(img, WORKSPACE_DEFAULT_SCALE) def load_example_into_workspace_and_reset(example_value): display_img, base_size, base_img = load_example_into_workspace(example_value) return display_img, base_size, base_img, [], [], "No saved regions.", None def sync_workspace_state(editor_value, current_base_image): background = _extract_editor_background(editor_value) if isinstance(background, Image.Image): return (int(background.width), int(background.height)), background if isinstance(current_base_image, Image.Image): return (int(current_base_image.width), int(current_base_image.height)), current_base_image return None, None def update_page_selector(file_path): if not file_path: return gr.update(visible=False) if file_path.lower().endswith('.pdf'): page_count = get_pdf_page_count(file_path) return gr.update(visible=True, maximum=page_count, value=1, minimum=1, label=f"Select Page (1-{page_count})") return gr.update(visible=False) blocks_kwargs = {"title": "DeepSeek-OCR-2"} if hasattr(gr, "themes") and hasattr(gr.themes, "Soft"): try: blocks_kwargs["theme"] = gr.themes.Soft() except Exception: pass with gr.Blocks(**blocks_kwargs) as demo: gr.Markdown(""" # 🧮 DeepSeek-OCR-2 — Math Rendering Edition **Convert documents to markdown, extract text, parse figures, and locate specific content with bounding boxes.** **Model uses DeepEncoder v2 and achieves 91.09% on OmniDocBench (+3.73% over v1).** Built on the original [DeepSeek-OCR-2 Demo](https://huggingface.co/spaces/merterbak/DeepSeek-OCR-2) by **Mert Erbak** — thank you for the excellent foundation. This fork adds **math rendering** in the Markdown Preview tab so that equations from scanned papers and textbooks display as proper math notation. """) region_editor = None workspace_base_size = gr.State(None) workspace_base_image = gr.State(None) selected_regions_state = gr.State([]) drawn_mask_state = gr.State(None) with gr.Row(): with gr.Column(scale=3): file_in = gr.File(label="Upload Image or PDF", file_types=["image", ".pdf"], type="filepath") with gr.Column(scale=1): page_selector = gr.Number(label="Select Page", value=1, minimum=1, step=1, visible=False) with gr.Row(): with gr.Column(scale=3): workspace_hint = gr.Markdown("**Entire Page mode:** No drawing needed. Click **Extract** to process the full page.") gr.Markdown("**Image Workspace (full page + region selection)**") if HAS_REGION_WORKSPACE: editor_kwargs = {} if HAS_BRUSH: try: highlight = ("#2563eb", 0.35) editor_kwargs["brush"] = gr.Brush( colors=[highlight], default_color=highlight, color_mode="fixed", default_size=22, ) except TypeError: try: editor_kwargs["brush"] = gr.Brush( colors=["rgba(37,99,235,0.35)"], default_color="rgba(37,99,235,0.35)", color_mode="fixed", default_size=22, ) except TypeError: editor_kwargs["brush"] = gr.Brush() if HAS_ERASER: try: editor_kwargs["eraser"] = gr.Eraser(default_size=26) except TypeError: editor_kwargs["eraser"] = gr.Eraser() if HAS_IMAGE_EDITOR: try: region_editor = gr.ImageEditor( label="Image Workspace", show_label=False, type="pil", height=WORKSPACE_EDITOR_HEIGHT, **editor_kwargs, ) except TypeError: try: region_editor = gr.ImageEditor( label="Image Workspace", show_label=False, height=WORKSPACE_EDITOR_HEIGHT, **editor_kwargs, ) except TypeError: region_editor = gr.ImageEditor( label="Image Workspace", show_label=False, height=WORKSPACE_EDITOR_HEIGHT, ) else: region_editor = gr.Paint( label="Image Workspace", show_label=False, type="pil", height=WORKSPACE_EDITOR_HEIGHT, **editor_kwargs, ) else: gr.Markdown("Region drawing requires a newer Gradio version with `Paint` or `ImageEditor` support.") region_editor = gr.State(None) with gr.Column(scale=1): gr.Markdown("### OCR Workflow") task = gr.Dropdown(list(TASK_PROMPTS.keys()), value="📋 Markdown", label="Task") input_scope = gr.Radio(["Entire Page", "Selected Region"], value="Entire Page", label="Input Scope") selection_controls = gr.Row(visible=False) with selection_controls: add_region_btn = gr.Button("Add Region", variant="secondary") clear_regions_btn = gr.Button("Clear Regions") selection_status = gr.Textbox(label="Region Selection Status", value="No saved regions.", interactive=False, visible=False) selected_regions_gallery = gr.Gallery( label="Selected Regions", show_label=True, columns=2, height=190, visible=False, object_fit="contain", ) with gr.Accordion("Advanced Options", open=False): equation_zoom = gr.Checkbox(label="Equation Zoom (multipass)", value=False) separate_eq_lines = gr.Checkbox(label="Detect Equation Lines Separately", value=False) prompt = gr.Textbox(label="Prompt", lines=2, visible=False) btn = gr.Button("Extract", variant="primary", size="lg") with gr.Row(): with gr.Column(scale=1): with gr.Tabs() as tabs: with gr.Tab("Text", id="tab_text"): text_out = gr.Textbox(lines=20, show_label=False) with gr.Tab("LaTeX", id="tab_text_latex"): latex_out = gr.Textbox(lines=20, show_label=False) with gr.Tab("Preview", id="tab_markdown"): md_out = gr.HTML("") with gr.Tab("Boxes", id="tab_boxes"): img_out = gr.Image(type="pil", height=560, show_label=False) with gr.Tab("Crops", id="tab_crops"): gallery = gr.Gallery(show_label=False, columns=3, height=420, object_fit="contain") with gr.Tab("Raw", id="tab_raw"): raw_out = gr.Textbox(lines=20, show_label=False) download_btn = gr.DownloadButton("Download Markdown", visible=False, variant="secondary") gr.Markdown("### Examples") with gr.Row(): with gr.Column(scale=2): image_examples = [ "examples/2022-0922 Section 13 Notes.png", "examples/2022-0922 Section 14 Notes.png", "examples/2022-0922 Section 15 Notes.png", ] if HAS_REGION_WORKSPACE and region_editor is not None: image_examples_input = gr.Image( label="Example Loader", type="filepath", visible=False, show_label=False, ) gr.Examples( label="Image Examples (click thumbnail to load into workspace)", examples=image_examples, inputs=[image_examples_input], outputs=[region_editor, workspace_base_size, workspace_base_image, selected_regions_state, selected_regions_gallery, selection_status, drawn_mask_state], fn=load_example_into_workspace_and_reset, run_on_click=True, cache_examples=False, ) else: gr.Examples( label="Image Examples", examples=[[p] for p in image_examples], inputs=[file_in], cache_examples=False, ) with gr.Column(scale=1): gr.Examples( label="PDF Examples", examples=[["examples/Gursoy Class Notes_ Accessibility Sandbox.pdf"]], inputs=[file_in], cache_examples=False, ) with gr.Accordion("ℹ️ Info", open=False): gr.Markdown(""" ### Configuration 1024 base + 768 patches with dynamic cropping (2-6 patches). 144 tokens per patch + 256 base tokens. ### Faculty Quick Workflow 1. Upload a page/image, then confirm **Task**. 2. Choose **Input Scope**: - `Entire Page` for the full page. - `Selected Region` for a specific area. 2a. Workspace keeps native image resolution for clarity. For very tall pages, it auto-boosts from tiny fit view toward ~88% width-friendly zoom. 3. For `Selected Region`, use the **Image Workspace**: - Recommended: freehand selection (draw/highlight target); app uses an automatic bounding box around your marks. - Optional rectangle selection: use the **Crop** tool. - Freehand/highlight ink is semi-transparent so underlying content stays visible. - Current known behavior: after zooming in/out, freehand stroke display may appear fully on mouse release (selection is still captured correctly). - Optional multi-select: click **Add Region** after each selection. - **Add Region** snapshots only newly drawn pixels so zoom/pan stays in place while you continue selecting. Then click **Extract**. 4. Use **Clear Regions** to reset multi-select state. 5. Review **Cropped Images** and **Boxes**: both are labeled `Region 1`, `Region 2`, etc. 6. Use **Advanced Options** only when needed (Equation Zoom / line-by-line equation OCR). ### Tasks - **Markdown**: Convert document to structured markdown with layout detection (grounding ✅) - **Free OCR**: Read all visible text from the full page/image (no boxes, no targeting) - **Locate**: Find and highlight where specific text appears (grounding ✅) - **Describe**: General image description - **Custom**: Your own prompt - **Region selection**: Use **Input Scope=Selected Region**, draw/crop in the Image Workspace, then click **Extract** - **Input Scope**: `Entire Page` or `Selected Region` (Selected Region uses the workspace crop as main input) - **Equation Zoom (multipass)**: Optional nested equation refinement for Markdown. Off by default for speed/stability. - **Detect Equation Lines Separately**: Detects likely equation-line boxes and OCRs each line independently to reduce merged multi-step equations. ### Free OCR vs Locate (important) - **Free OCR does not take a selected region**. It runs OCR on the whole image/page. - If you want OCR for one area only, crop that area first, then run **Free OCR** on the cropped image. - If you want to keep the full page but highlight where text appears, use **Locate** and enter the text to search. - For advanced region workflows, use **Custom** with `<|grounding|>` in the prompt. ### Special Tokens - `` - Placeholder where visual tokens are inserted - `<|grounding|>` - Enables layout detection with bounding boxes - `<|ref|>text<|/ref|>` - Reference text to locate in the image """) file_in.change(update_page_selector, [file_in], [page_selector]) task.change(toggle_prompt, [task], [prompt]) task.change(select_boxes, [task], [tabs]) input_scope.change(toggle_scope_ui, [input_scope], [workspace_hint, selection_controls, selection_status, selected_regions_gallery]) if HAS_REGION_WORKSPACE and region_editor is not None: file_in.change(load_image_with_size, [file_in, page_selector], [region_editor, workspace_base_size, workspace_base_image]) page_selector.change(load_image_with_size, [file_in, page_selector], [region_editor, workspace_base_size, workspace_base_image]) region_editor.change(sync_workspace_state, [region_editor, workspace_base_image], [workspace_base_size, workspace_base_image]) file_in.change(_reset_selected_regions, outputs=[selected_regions_state, selected_regions_gallery, selection_status]) page_selector.change(_reset_selected_regions, outputs=[selected_regions_state, selected_regions_gallery, selection_status]) file_in.change(_reset_drawn_mask, outputs=[drawn_mask_state]) page_selector.change(_reset_drawn_mask, outputs=[drawn_mask_state]) add_region_btn.click( add_selected_region, [region_editor, workspace_base_size, workspace_base_image, selected_regions_state, drawn_mask_state], [selected_regions_state, selected_regions_gallery, selection_status, drawn_mask_state], ) clear_regions_btn.click( clear_regions_preserve_view, inputs=[region_editor], outputs=[selected_regions_state, selected_regions_gallery, selection_status, drawn_mask_state], ) def run(file_path, task, custom_prompt, page_num, enable_equation_zoom, detect_eq_lines, scope, region_value, base_size, base_image, selected_regions): if scope == "Selected Region": regions = list(selected_regions or []) if not regions: selected_region, selected_bbox = _extract_selected_region(region_value, base_size=base_size, base_image=base_image) if selected_region is None: msg = "Select Input Scope=Selected Region, then crop or annotate a target area in the Image Workspace first." return (msg, "", "", "", None, [], gr.DownloadButton(visible=False)) regions = [{"image": selected_region, "bbox": selected_bbox}] cleaned_parts = [] markdown_parts = [] raw_parts = [] line_crops = [] for i, r in enumerate(regions, 1): cleaned_i, markdown_i, raw_i, _, crops_i = process_image( r["image"], task, custom_prompt, enable_equation_zoom=enable_equation_zoom, infer_crop_mode=False, separate_equation_lines=detect_eq_lines, ) if len(regions) > 1: cleaned_parts.append(f"## Region {i}\n\n{cleaned_i}") markdown_parts.append(f"## Region {i}\n\n{markdown_i}") raw_parts.append(f"## Region {i}\n\n{raw_i}") else: cleaned_parts.append(cleaned_i) markdown_parts.append(markdown_i) raw_parts.append(raw_i) if detect_eq_lines and crops_i: line_crops.extend(_label_gallery_items(crops_i, prefix=f"Region {i}" if len(regions) > 1 else None)) cleaned = "\n\n".join(cleaned_parts).strip() markdown = "\n\n".join(markdown_parts).strip() raw = "\n\n".join(raw_parts).strip() crops = line_crops if line_crops else _region_gallery_items(regions) full_img = base_image if isinstance(base_image, Image.Image) else _extract_editor_background(region_value) region_boxes = [r["bbox"] for r in regions if r.get("bbox") is not None] img_out = _draw_selected_region_boxes(full_img, region_boxes) elif file_path: cleaned, markdown, raw, img_out, crops = process_file( file_path, task, custom_prompt, int(page_num), enable_equation_zoom=enable_equation_zoom, separate_equation_lines=detect_eq_lines, ) elif (full_image := _extract_editor_background(region_value)) is not None: cleaned, markdown, raw, img_out, crops = process_image( full_image, task, custom_prompt, enable_equation_zoom=enable_equation_zoom, separate_equation_lines=detect_eq_lines, ) elif isinstance(base_image, Image.Image): # Example clicks can briefly race editor-value hydration on first load. cleaned, markdown, raw, img_out, crops = process_image( base_image, task, custom_prompt, enable_equation_zoom=enable_equation_zoom, separate_equation_lines=detect_eq_lines, ) else: msg = "Error: Upload a file or image" return (msg, "", "", "", None, [], gr.DownloadButton(visible=False)) return _compose_ui_outputs(cleaned, markdown, raw, img_out, crops) submit_event = btn.click( run, [file_in, task, prompt, page_selector, equation_zoom, separate_eq_lines, input_scope, region_editor, workspace_base_size, workspace_base_image, selected_regions_state], [text_out, latex_out, md_out, raw_out, img_out, gallery, download_btn] ) submit_event.then(select_post_extract_tab, [task, input_scope], [tabs]) if __name__ == "__main__": # server_name="0.0.0.0" is needed locally (WSL2 → Windows access) # On HuggingFace Spaces, SPACE_ID is set and Gradio handles binding automatically local = not os.environ.get("SPACE_ID") queued = demo.queue(max_size=20) launch_sig = inspect.signature(queued.launch) launch_kwargs = {} if "server_name" in launch_sig.parameters: launch_kwargs["server_name"] = "0.0.0.0" if local else None if "head" in launch_sig.parameters: launch_kwargs["head"] = PREVIEW_CSS if "ssr_mode" in launch_sig.parameters: launch_kwargs["ssr_mode"] = False # SSR breaks HF Spaces routing in Gradio 6 queued.launch(**launch_kwargs)