Spaces:
Running on Zero
| import gradio as gr | |
| from transformers import AutoModel, AutoTokenizer | |
| import torch | |
| import spaces | |
| import os | |
| import sys | |
| import tempfile | |
| import shutil | |
| import inspect | |
| from PIL import Image, ImageDraw, ImageFont, ImageOps | |
| import fitz | |
| import re | |
| import ast | |
| import numpy as np | |
| import base64 | |
| import html as html_lib | |
| import markdown as md_lib | |
| import latex2mathml.converter | |
| from collections import deque | |
| from io import StringIO, BytesIO | |
# Feature detection for optional Gradio components: older Gradio releases lack
# ImageEditor/Paint/Brush/Eraser, so the UI degrades gracefully when absent.
HAS_IMAGE_EDITOR = hasattr(gr, "ImageEditor")
HAS_PAINT = hasattr(gr, "Paint")
HAS_BRUSH = hasattr(gr, "Brush")
HAS_ERASER = hasattr(gr, "Eraser")
# The region-selection workspace needs at least one drawable image component.
HAS_REGION_WORKSPACE = HAS_PAINT or HAS_IMAGE_EDITOR
# Model options — swap MODEL_NAME to reduce VRAM usage on GPUs with <= 8GB
#
# Full precision BF16 (~8GB VRAM) — original, highest accuracy
MODEL_NAME = 'deepseek-ai/DeepSeek-OCR-2'
#
# FP8 dynamic quantization (~3.5GB VRAM) — ~50% VRAM reduction, 3750 downloads/mo
# Requires Ampere GPU or newer (RTX 3070 is supported)
# MODEL_NAME = 'richarddavison/DeepSeek-OCR-2-FP8'
#
# 8-bit quantization (~4GB VRAM) — same stack (torch 2.6, flash-attn 2.7.3, py3.12)
# Explicitly supports dynamic resolution (0-6 patches), 140 downloads/mo
# MODEL_NAME = 'mzbac/DeepSeek-OCR-2-8bit'
# NOTE(review): trust_remote_code executes code from the model repo at load
# time — keep MODEL_NAME pinned to trusted repos only.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# flash_attention_2 requires a CUDA device at init time — not available on ZeroGPU at
# module load. DeepseekOCR2 only supports 'flash_attention_2' and 'eager'; sdpa is not
# implemented for this model class. Fall back to 'eager' when no GPU is present.
# Locally with CUDA, flash_attention_2 is used for maximum throughput.
_attn_impl = 'flash_attention_2' if torch.cuda.is_available() else 'eager'
model = AutoModel.from_pretrained(MODEL_NAME, _attn_implementation=_attn_impl, torch_dtype=torch.bfloat16, trust_remote_code=True, use_safetensors=True).eval()
# .cuda() is NOT called here — on ZeroGPU, GPU is only available inside @spaces.GPU
# functions. Locally, model.cuda() is called inside process_image on first run.
# Inference resolution: base canvas size, per-patch size, and whether large
# pages may be tiled ("crop mode") into patches.
BASE_SIZE = 1024
IMAGE_SIZE = 768
CROP_MODE = True
# Region-selection workspace layout defaults (pixels; scale is a percentage).
WORKSPACE_EDITOR_HEIGHT = 640
WORKSPACE_EDITOR_WIDTH_EST = 980
WORKSPACE_DEFAULT_SCALE = 89
# Matches well-formed <|ref|>label<|/ref|><|det|>coords<|/det|> tag pairs in raw output.
GROUNDING_PATTERN = re.compile(r'<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>', re.DOTALL)
# Substrings that identify model.infer() debug/progress lines to drop from captured stdout.
INFER_DEBUG_FILTERS = ['PATCHES', '====', 'BASE:', 'directly resize', 'NO PATCHES', 'torch.Size', '%|']
# Prompts and thresholds for the two-pass equation zoom pipeline. All box
# coordinates live on the model's 0-999 normalized grid; areas/dimensions
# below are fractions of the full page.
EQUATION_ZOOM_PROMPT = "<image>\n<|grounding|>Locate each individual equation or math line."
EQUATION_LINE_OCR_PROMPT = "<image>\nRead the math expression exactly as written. Return only the equation text."
EQUATION_ZOOM_MAX_CANDIDATES = 6
EQUATION_ZOOM_MIN_AREA = 0.05
EQUATION_ZOOM_MIN_DIM = 0.24
EQUATION_ZOOM_PADDING = 0.025
EQUATION_ZOOM_MAX_ASPECT = 12.0
EQUATION_DETAIL_MAX_BOXES = 24
EQUATION_DETAIL_IOU_DEDUPE = 0.7
EQUATION_LINE_IOU_DEDUPE = 0.55
EQUATION_LINE_MIN_AREA = 0.0008
EQUATION_LINE_MIN_W = 0.03
EQUATION_LINE_MIN_H = 0.01
EQUATION_LINE_MAX_ASPECT = 30.0
# Label substrings and LaTeX markers used by the math-likelihood heuristics.
MATH_LABEL_HINTS = ("formula", "equation", "math")
MATH_STRONG_MARKERS = ("\\(", "\\[", "\\frac", "\\sum", "\\int", "\\sqrt", "\\lim", "\\begin{")
MATH_WEAK_MARKERS = ("^", "_", "=", "+", "\\cdot", "\\times")
# Built-in task presets; has_grounding marks prompts whose output carries
# <|ref|>/<|det|> tags that downstream code must parse or strip.
TASK_PROMPTS = {
    "📋 Markdown": {"prompt": "<image>\n<|grounding|>Convert the document to markdown.", "has_grounding": True},
    "📝 Free OCR": {"prompt": "<image>\nFree OCR.", "has_grounding": False},
    "📍 Locate": {"prompt": "<image>\nLocate <|ref|>text<|/ref|> in the image.", "has_grounding": True},
    "🔍 Describe": {"prompt": "<image>\nDescribe this image in detail.", "has_grounding": False},
    "✏️ Custom": {"prompt": "", "has_grounding": False}
}
def extract_grounding_references(text):
    """Rebuild deduplicated (raw_tag, label, coord_repr) triples from raw output.

    Duplicates are detected on a key of the lowercased label plus the box
    coordinates rounded to one decimal, so near-identical re-detections of the
    same region collapse to a single reference.
    """
    results = []
    seen_keys = set()
    for entry in _extract_grounding_entries(text):
        rounded_boxes = tuple(
            tuple(round(v, 1) for v in box)
            for box in entry["coords"]
        )
        key = (entry["label"].strip().lower(), rounded_boxes)
        if key in seen_keys:
            continue
        seen_keys.add(key)
        coord_text = repr(entry["coords"])
        raw_tag = f'<|ref|>{entry["label"]}<|/ref|><|det|>{coord_text}<|/det|>'
        results.append((raw_tag, entry["label"], coord_text))
    return results
| def _parse_coord_payload(payload): | |
| if isinstance(payload, str): | |
| try: | |
| coords = ast.literal_eval(payload.strip()) | |
| except (SyntaxError, ValueError): | |
| return [] | |
| else: | |
| coords = payload | |
| if isinstance(coords, (tuple, list)) and coords and isinstance(coords[0], (int, float)): | |
| coords = [coords] | |
| if not isinstance(coords, list): | |
| return [] | |
| out = [] | |
| for c in coords: | |
| if not isinstance(c, (list, tuple)) or len(c) < 4: | |
| continue | |
| x1, y1, x2, y2 = [float(v) for v in c[:4]] | |
| x1, x2 = sorted((max(0.0, min(999.0, x1)), max(0.0, min(999.0, x2)))) | |
| y1, y2 = sorted((max(0.0, min(999.0, y1)), max(0.0, min(999.0, y2)))) | |
| if x2 <= x1 or y2 <= y1: | |
| continue | |
| out.append([x1, y1, x2, y2]) | |
| return out | |
def _extract_grounding_entries(raw_text: str):
    """Walk every <|ref|>…<|det|> pair in *raw_text*.

    Returns dicts with ``label`` (defaults to "text" when blank), ``coords``
    (parsed boxes) and ``text`` (the raw text since the previous kept tag).
    Tags whose coordinates fail to parse are skipped without advancing the
    text cursor, matching the original accumulation behavior.
    """
    if not raw_text:
        return []
    results = []
    cursor = 0
    for match in GROUNDING_PATTERN.finditer(raw_text):
        boxes = _parse_coord_payload(match.group(2))
        if not boxes:
            continue
        results.append({
            "label": match.group(1).strip() or "text",
            "coords": boxes,
            "text": raw_text[cursor:match.start()].strip(),
        })
        cursor = match.end()
    return results
def _math_marker_score(text_chunk: str) -> int:
    """Heuristic 'math-likeness' score for a text chunk.

    Each strong LaTeX marker present contributes 3 points; each weak marker
    (bare operators that also occur in prose/code) contributes 1.
    """
    strong = sum(3 for marker in MATH_STRONG_MARKERS if marker in text_chunk)
    weak = sum(1 for marker in MATH_WEAK_MARKERS if marker in text_chunk)
    return strong + weak
| def _box_iou(a, b): | |
| ax1, ay1, ax2, ay2 = a | |
| bx1, by1, bx2, by2 = b | |
| inter_x1 = max(ax1, bx1) | |
| inter_y1 = max(ay1, by1) | |
| inter_x2 = min(ax2, bx2) | |
| inter_y2 = min(ay2, by2) | |
| if inter_x2 <= inter_x1 or inter_y2 <= inter_y1: | |
| return 0.0 | |
| inter = (inter_x2 - inter_x1) * (inter_y2 - inter_y1) | |
| area_a = max(1e-9, (ax2 - ax1) * (ay2 - ay1)) | |
| area_b = max(1e-9, (bx2 - bx1) * (by2 - by1)) | |
| union = area_a + area_b - inter | |
| return inter / union if union > 0 else 0.0 | |
def _dedupe_boxes(boxes, iou_threshold):
    """Greedy IoU dedupe, visiting boxes smallest-area first so a larger box
    overlapping an already-kept small one gets dropped."""
    survivors = []
    for candidate in sorted(boxes, key=lambda b: (b[2] - b[0]) * (b[3] - b[1])):
        overlaps_kept = any(
            _box_iou(candidate, kept) >= iou_threshold for kept in survivors
        )
        if not overlaps_kept:
            survivors.append(candidate)
    return survivors
def _is_math_candidate(label: str, text_chunk: str, box):
    """Decide whether a grounded region is worth a zoomed equation pass.

    A candidate must look mathematical (by label hint or strong text markers),
    be large enough on the 0-999 grid, and not be an extreme sliver.
    """
    w = (box[2] - box[0]) / 999.0
    h = (box[3] - box[1]) / 999.0
    aspect = max(w / max(1e-9, h), h / max(1e-9, w))
    if aspect > EQUATION_ZOOM_MAX_ASPECT:
        return False
    labeled_math = any(hint in label.lower() for hint in MATH_LABEL_HINTS)
    textual_math = _math_marker_score(text_chunk) >= 3
    big_enough = (
        w * h >= EQUATION_ZOOM_MIN_AREA
        or w >= EQUATION_ZOOM_MIN_DIM
        or h >= EQUATION_ZOOM_MIN_DIM
    )
    return (labeled_math or textual_math) and big_enough
def _map_crop_box_to_page(sub_box, crop_px, img_w, img_h):
    """Project a 0-999 box detected inside a pixel crop back onto the full
    page's 0-999 coordinate grid."""
    cx1, cy1, cx2, cy2 = crop_px
    cw = max(1, cx2 - cx1)
    ch = max(1, cy2 - cy1)

    def to_page(v, origin, span, total):
        # crop-normalized -> crop pixels -> page pixels -> page-normalized
        return ((origin + (v / 999.0) * span) / total) * 999.0

    page_box = [
        to_page(sub_box[0], cx1, cw, img_w),
        to_page(sub_box[1], cy1, ch, img_h),
        to_page(sub_box[2], cx1, cw, img_w),
        to_page(sub_box[3], cy1, ch, img_h),
    ]
    # Round-trip through the payload parser so clamping/sorting matches every
    # other box in the pipeline.
    return _parse_coord_payload([page_box])[0]
def draw_bounding_boxes(image, refs, extract_images=False):
    """Draw labeled grounding boxes on a copy of *image*.

    Args:
        image: PIL image the refs' 0-999 coordinates refer to.
        refs: (raw, label, coord_text) triples from extract_grounding_references.
        extract_images: when True, also collect crops of boxes labeled 'image'.
    Returns:
        (annotated_image, crops) — crops is empty unless extract_images is True.
    """
    img_w, img_h = image.size
    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)
    overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
    draw2 = ImageDraw.Draw(overlay)
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 15)
    except OSError:
        # Bug fix: the hard-coded DejaVu path exists on HF Spaces images but
        # not on every host; fall back to PIL's bundled default font instead
        # of crashing the whole render.
        font = ImageFont.load_default()
    crops = []
    color_map = {}
    # Fixed seed so each label keeps the same color across runs.
    np.random.seed(42)
    for ref in refs:
        label = ref[1]
        if label not in color_map:
            color_map[label] = (np.random.randint(50, 255), np.random.randint(50, 255), np.random.randint(50, 255))
        color = color_map[label]
        coords = _parse_coord_payload(ref[2])
        color_a = color + (60,)  # translucent fill for the overlay layer
        for box in coords:
            x1, y1, x2, y2 = int(box[0]/999*img_w), int(box[1]/999*img_h), int(box[2]/999*img_w), int(box[3]/999*img_h)
            if extract_images and label == 'image':
                crops.append(image.crop((x1, y1, x2, y2)))
            width = 5 if label == 'title' else 3  # titles get a heavier outline
            draw.rectangle([x1, y1, x2, y2], outline=color, width=width)
            draw2.rectangle([x1, y1, x2, y2], fill=color_a)
            text_bbox = draw.textbbox((0, 0), label, font=font)
            tw, th = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1]
            ty = max(0, y1 - 20)  # keep the label tag inside the image at the top edge
            draw.rectangle([x1, ty, x1 + tw + 4, ty + th + 4], fill=color)
            draw.text((x1 + 2, ty + 2), label, font=font, fill=(255, 255, 255))
    img_draw.paste(overlay, (0, 0), overlay)
    return img_draw, crops
def _extract_labeled_crops_from_refs(image, refs, max_items=24):
    """Crop every grounded region out of *image*.

    Skips boxes narrower/shorter than 8px and exact duplicates (same label and
    pixel box). Returns at most *max_items* (crop, caption) pairs, where the
    caption is the label plus the crop's pixel dimensions.
    """
    img_w, img_h = image.size
    collected = []
    visited = set()
    for ref in refs:
        label = str(ref[1])
        for box in _parse_coord_payload(ref[2]):
            px = (
                int(box[0] / 999.0 * img_w),
                int(box[1] / 999.0 * img_h),
                int(box[2] / 999.0 * img_w),
                int(box[3] / 999.0 * img_h),
            )
            x1, y1, x2, y2 = px
            if x2 - x1 < 8 or y2 - y1 < 8:
                continue
            key = (label.lower(),) + px
            if key in visited:
                continue
            visited.add(key)
            crop = image.crop(px)
            collected.append((crop, f"{label} ({crop.width}x{crop.height})"))
            if len(collected) >= max_items:
                return collected
    return collected
def clean_output(text, include_images=False):
    """Strip grounding tags from raw model output for display.

    Image refs become '**[Figure N]**' placeholders (later swapped for real
    images by embed_images) when include_images is True, otherwise they are
    removed. Every other grounded line is deleted wholesale, then leftover
    malformed tags and repeated math blocks are cleaned up.
    """
    if not text:
        return ""
    pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
    matches = re.findall(pattern, text, re.DOTALL)
    img_num = 0
    for match in matches:
        if '<|ref|>image<|/ref|>' in match[0]:
            if include_images:
                text = text.replace(match[0], f'\n\n**[Figure {img_num + 1}]**\n\n', 1)
                img_num += 1
            else:
                text = text.replace(match[0], '', 1)
        else:
            # Non-image refs: drop the entire line containing the tag.
            text = re.sub(rf'(?m)^[^\n]*{re.escape(match[0])}[^\n]*\n?', '', text)
    text = _strip_malformed_grounding(text)
    text = _dedupe_repeated_math_blocks(text)
    return text.strip()
| def _strip_malformed_grounding(text: str) -> str: | |
| """Remove incomplete grounding tags that can leak into OCR markdown/text.""" | |
| if not text: | |
| return "" | |
| line_patterns = [ | |
| r'(?m)^[^\n]*<\|ref\|>.*?<\|/ref\|><\|det\|>.*?(?:<\|/det\|>)?[^\n]*\n?', | |
| r'(?m)^[^\n]*<\|det\|>.*?(?:<\|/det\|>)?[^\n]*\n?', | |
| r'(?m)^[^\n]*<\|/?ref\|>[^\n]*\n?', | |
| ] | |
| for p in line_patterns: | |
| text = re.sub(p, '', text) | |
| text = re.sub(r'<\|/?ref\|>', '', text) | |
| text = re.sub(r'<\|/?det\|>', '', text) | |
| return text | |
| def _equation_text_key(text: str) -> str: | |
| if not text: | |
| return "" | |
| key = text.strip() | |
| key = re.sub(r'\\\[(.+?)\\\]', r'\1', key, flags=re.DOTALL) | |
| key = re.sub(r'\\\((.+?)\\\)', r'\1', key, flags=re.DOTALL) | |
| key = re.sub(r'\$\$(.+?)\$\$', r'\1', key, flags=re.DOTALL) | |
| key = re.sub(r'\^\{([A-Za-z0-9])\}', r'^\1', key) | |
| key = re.sub(r'_\{([A-Za-z0-9])\}', r'_\1', key) | |
| key = re.sub(r'\s+', '', key) | |
| return key.lower() | |
def _dedupe_repeated_math_blocks(text: str) -> str:
    """Drop repeated display/inline math blocks (the model sometimes emits the
    same equation twice) and collapse the blank-line runs left behind."""
    if not text:
        return ""
    math_re = re.compile(r'\\\[(.+?)\\\]|\\\((.+?)\\\)|\$\$(.+?)\$\$', re.DOTALL)
    kept_keys = set()
    pieces = []
    cursor = 0
    dropped = False
    for m in math_re.finditer(text):
        pieces.append(text[cursor:m.start()])
        cursor = m.end()
        key = _equation_text_key(m.group(1) or m.group(2) or m.group(3) or "")
        if key and key in kept_keys:
            dropped = True
            continue
        if key:
            kept_keys.add(key)
        pieces.append(m.group(0))
    pieces.append(text[cursor:])
    result = ''.join(pieces)
    if dropped:
        # Only normalize whitespace when something was actually removed.
        result = re.sub(r'\n{3,}', '\n\n', result)
    return result
# Shared <style>/<script> payload injected alongside preview HTML:
#  - CSS for the markdown/math preview panes
#  - a MathJax loader plus a MutationObserver that re-typesets new previews
#  - an auto-zoom helper that nudges the Gradio image-editor workspace toward
#    ~88% zoom whenever a freshly loaded image fits too small
PREVIEW_CSS = """
<style>
.math-preview {
padding: 1.5em;
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
font-size: 15px;
line-height: 1.8;
color: #1a1a1a;
max-width: 100%;
overflow-x: auto;
}
.math-display {
text-align: center;
overflow-x: auto;
margin: 1em 0;
padding: 0.5em 0;
}
math[display="block"] { display: block; overflow-x: auto; max-width: 100%; }
.math-preview h1 { font-size: 1.8em; font-weight: 700; margin: 1em 0 0.4em; border-bottom: 2px solid #e0e0e0; padding-bottom: 0.3em; }
.math-preview h2 { font-size: 1.4em; font-weight: 600; margin: 1em 0 0.4em; border-bottom: 1px solid #e0e0e0; padding-bottom: 0.2em; }
.math-preview h3 { font-size: 1.15em; font-weight: 600; margin: 0.9em 0 0.3em; }
.math-preview p { margin: 0.6em 0; }
.math-preview ul, .math-preview ol { padding-left: 1.8em; margin: 0.5em 0; }
.math-preview li { margin: 0.25em 0; }
.math-preview table { border-collapse: collapse; width: 100%; margin: 1em 0; font-size: 0.95em; }
.math-preview th, .math-preview td { border: 1px solid #ccc; padding: 0.45em 0.75em; text-align: left; }
.math-preview th { background: #f2f2f2; font-weight: 600; }
.math-preview tr:nth-child(even) { background: #fafafa; }
.math-preview code { background: #f4f4f4; padding: 0.15em 0.4em; border-radius: 3px; font-family: 'Courier New', monospace; font-size: 0.88em; }
.math-preview pre { background: #f4f4f4; padding: 1em; border-radius: 5px; overflow-x: auto; margin: 0.8em 0; }
.math-preview pre code { background: none; padding: 0; }
.math-preview blockquote { border-left: 4px solid #ccc; margin: 0.8em 0; padding: 0.4em 1em; color: #555; background: #fafafa; }
.math-preview img { max-width: 100%; height: auto; display: block; margin: 0.8em 0; }
.math-preview .ocr-gap, .mathjax-preview .ocr-gap { width: 100%; }
.math-fallback { color: #888; font-style: italic; }
</style>
<script>
(() => {
if (window.__ocrMathJaxInit) return;
window.__ocrMathJaxInit = true;
if (!window.MathJax) {
window.MathJax = {
tex: {
inlineMath: [['\\\\(', '\\\\)'], ['$', '$']],
displayMath: [['\\\\[', '\\\\]'], ['$$', '$$']]
},
options: {
skipHtmlTags: ['script', 'noscript', 'style', 'textarea', 'pre', 'code']
}
};
}
const typeset = () => {
if (window.MathJax?.typesetPromise) {
const nodes = Array.from(document.querySelectorAll('.mathjax-preview, .spatial-preview'));
if (nodes.length) window.MathJax.typesetPromise(nodes).catch(() => {});
}
};
window.__typesetOcrMath = typeset;
const ensureScript = () => {
if (document.getElementById('mathjax-ocr-preview')) return;
const script = document.createElement('script');
script.id = 'mathjax-ocr-preview';
script.async = true;
script.src = 'https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js';
script.onload = () => setTimeout(typeset, 20);
document.head.appendChild(script);
};
ensureScript();
setTimeout(typeset, 100);
const observer = new MutationObserver((mutations) => {
for (const m of mutations) {
for (const n of m.addedNodes) {
if (n.nodeType !== 1) continue;
if (n.matches?.('.mathjax-preview, .spatial-preview') || n.querySelector?.('.mathjax-preview, .spatial-preview')) {
setTimeout(typeset, 30);
return;
}
}
}
});
observer.observe(document.body, { childList: true, subtree: true });
})();
(() => {
if (window.__ocrWorkspaceZoomInit) return;
window.__ocrWorkspaceZoomInit = true;
const stateByRoot = new WeakMap();
const targetZoomPct = 88;
const nearTargetTolerancePct = 3;
const tinyFitThresholdPct = 45;
const getState = (root) => {
let state = stateByRoot.get(root);
if (!state) {
state = { busy: false, applied: false, lastSeenZoom: null, lastAutoAt: 0 };
stateByRoot.set(root, state);
}
return state;
};
const parseZoomPct = (root) => {
const zoomNode = root.querySelector(".zoom-number span[role='button']");
if (!zoomNode) return null;
const m = (zoomNode.textContent || "").match(/([0-9]+(?:\\.[0-9]+)?)\\s*%/);
return m ? parseFloat(m[1]) : null;
};
const getZoomInBtn = (root) =>
root.querySelector("button[aria-label='Zoom in'], button[title='Zoom in']");
const isWorkspaceRoot = (root) =>
!!root.querySelector(".pixi-target") && !!root.querySelector(".zoom-number");
const maybeAutoZoom = (root) => {
if (!isWorkspaceRoot(root)) return;
const state = getState(root);
const now = Date.now();
const zoomPct = parseZoomPct(root);
if (zoomPct == null) return;
// A drop from high zoom to low zoom usually means a new image was loaded.
if (state.lastSeenZoom != null && state.lastSeenZoom > 70 && zoomPct < 35) {
state.applied = false;
}
state.lastSeenZoom = zoomPct;
if (state.busy || state.applied) return;
if (zoomPct > tinyFitThresholdPct) return;
if (now - state.lastAutoAt < 1200) return;
const zoomInBtn = getZoomInBtn(root);
if (!zoomInBtn) return;
state.busy = true;
state.lastAutoAt = now;
let steps = 0;
const step = () => {
const current = parseZoomPct(root);
if (
current == null ||
current >= (targetZoomPct - nearTargetTolerancePct) ||
steps >= 20
) {
state.busy = false;
state.applied = true;
return;
}
zoomInBtn.click();
steps += 1;
setTimeout(step, 80);
};
setTimeout(step, 90);
};
const attachRootObserver = (root) => {
if (root.dataset.ocrZoomObserved === "1") return;
root.dataset.ocrZoomObserved = "1";
const obs = new MutationObserver(() => maybeAutoZoom(root));
obs.observe(root, { childList: true, subtree: true, characterData: true });
setTimeout(() => maybeAutoZoom(root), 200);
setTimeout(() => maybeAutoZoom(root), 800);
};
const scan = () => {
document.querySelectorAll("[data-testid='image']").forEach((root) => {
if (isWorkspaceRoot(root)) attachRootObserver(root);
});
};
scan();
const pageObs = new MutationObserver(scan);
pageObs.observe(document.body, { childList: true, subtree: true });
})();
</script>
"""
| def _inject_spatial_gap_placeholders(text: str): | |
| """Preserve runs of blank lines so OCR spacing is visible in preview.""" | |
| gaps: dict[str, int] = {} | |
| counter = [0] | |
| def repl(m): | |
| key = f'ZZOCRGAP{counter[0]}ZZ' | |
| counter[0] += 1 | |
| # Two newlines are a normal paragraph break; extras represent vertical spacing. | |
| gaps[key] = max(1, len(m.group(0)) - 2) | |
| return f'\n\n{key}\n\n' | |
| return re.sub(r'\n{3,}', repl, text), gaps | |
| def _restore_spatial_gap_placeholders(html: str, gaps: dict[str, int]) -> str: | |
| if not gaps: | |
| return html | |
| for key, extra_lines in gaps.items(): | |
| gap_em = min(10.0, 0.9 * extra_lines) | |
| block = f'<div class="ocr-gap" style="height:{gap_em:.2f}em"></div>' | |
| html = html.replace(f'<p>{key}</p>', block) | |
| html = html.replace(key, block) | |
| return html | |
def _to_mathml(latex: str, display: bool) -> str:
    """Convert a LaTeX string to MathML. Falls back to a code block on error."""
    # Fix OCR error: \frac{n/m} (single-argument fraction) -> \frac{n}{m}
    latex = re.sub(r'\\frac\{(\d+)/(\d+)\}(?!\s*\{)', r'\\frac{\1}{\2}', latex)
    try:
        rendered = latex2mathml.converter.convert(latex)
    except Exception:
        # Conversion failed — show the raw LaTeX, HTML-escaped, instead.
        escaped = html_lib.escape(latex)
        if display:
            return f'<pre class="math-fallback"><code>{escaped}</code></pre>'
        return f'<code class="math-fallback">{escaped}</code>'
    if display:
        rendered = re.sub(r'<math\b', '<math display="block"', rendered, count=1)
    return rendered
def to_math_html(text: str) -> str:
    """Convert model markdown output to HTML with server-side MathML rendering.
    Uses a placeholder approach: math is extracted and replaced with unique
    tokens before the markdown pass, then swapped back afterwards. This avoids
    Python-Markdown mishandling multi-line <div> blocks that contain blank lines.
    """
    if not text:
        return ""
    # blocks: token -> rendered MathML; literals: token -> original LaTeX
    # (used to restore math that ends up inside code/pre verbatim).
    blocks: dict[str, str] = {}
    literals: dict[str, str] = {}
    counter = [0]
    def display_block(m):
        key = f'ZZDISPLAYMATH{counter[0]}ZZ'
        counter[0] += 1
        expr = m.group(1).strip()
        blocks[key] = f'<div class="math-display">{_to_mathml(expr, display=True)}</div>'
        literals[key] = f'\\[{expr}\\]'
        return f'\n\n{key}\n\n'
    def inline_math(m):
        key = f'ZZINLINEMATH{counter[0]}ZZ'
        counter[0] += 1
        expr = m.group(1).strip()
        blocks[key] = _to_mathml(expr, display=False)
        literals[key] = f'\\({expr}\\)'
        return key
    # Replace display math \[...\] with placeholder tokens
    text = re.sub(r'\\\[(.+?)\\\]', display_block, text, flags=re.DOTALL)
    # Remove orphaned \[ with no matching \] (truncated model output)
    text = re.sub(r'\\\[.*', '', text, flags=re.DOTALL)
    # Replace inline math \(...\) with placeholder tokens
    text = re.sub(r'\\\((.+?)\\\)', inline_math, text)
    text, gaps = _inject_spatial_gap_placeholders(text)
    # Run markdown on text that now contains only safe placeholder tokens
    html = md_lib.markdown(text, extensions=['tables', 'fenced_code', 'sane_lists', 'nl2br'])
    # Protect rendered code/pre blocks so placeholder swap never mutates literal code.
    protected_blocks: dict[str, str] = {}
    protected_counter = [0]
    def _protect_code_html(m):
        token = f'ZZCODEHTML{protected_counter[0]}ZZ'
        protected_counter[0] += 1
        protected_blocks[token] = m.group(0)
        return token
    html = re.sub(r'<pre\b[^>]*>.*?</pre>', _protect_code_html, html, flags=re.DOTALL)
    html = re.sub(r'<code\b[^>]*>.*?</code>', _protect_code_html, html, flags=re.DOTALL)
    # Swap placeholders back for MathML/HTML (handle <p>KEY</p> wrapping too)
    for key, value in blocks.items():
        html = html.replace(f'<p>{key}</p>', value)
        html = html.replace(key, value)
    # Restore protected literal code/pre blocks unchanged.
    for token, original in protected_blocks.items():
        html = html.replace(token, original)
    # Placeholders left at this stage occur inside code/pre; keep them literal.
    for key, literal in literals.items():
        html = html.replace(key, html_lib.escape(literal))
    html = _restore_spatial_gap_placeholders(html, gaps)
    return f'<div class="math-preview">{html}</div>'
def to_mathjax_html(text: str) -> str:
    """Render markdown to HTML and typeset math client-side with MathJax."""
    if not text:
        return ""
    tokenized, gaps = _inject_spatial_gap_placeholders(text)
    rendered = md_lib.markdown(
        tokenized, extensions=['tables', 'fenced_code', 'sane_lists', 'nl2br']
    )
    rendered = _restore_spatial_gap_placeholders(rendered, gaps)
    # The .mathjax-preview class is what the injected MathJax observer targets.
    return f'<div class="mathjax-preview">{rendered}</div>'
def _grounding_blocks_from_raw(raw_text: str):
    """Flatten grounding entries into one block dict per box.

    The entry's text is attached only to its first box so multi-box entries
    don't duplicate content in the spatial rendering.
    """
    flattened = []
    for entry in _extract_grounding_entries(raw_text):
        stripped = entry["text"].strip()
        for idx, (x1, y1, x2, y2) in enumerate(entry["coords"]):
            flattened.append({
                "label": entry["label"],
                "text": stripped if idx == 0 else "",
                "x1": x1,
                "y1": y1,
                "x2": x2,
                "y2": y2,
            })
    return flattened
def to_spatial_html(raw_text: str, markdown_text: str) -> str:
    """Render OCR content using grounding boxes for spatially-positioned blocks.

    Each grounded block becomes an absolutely-positioned <article> whose
    left/top/width are percentages derived from the 0-999 grid. Falls back to
    plain linear rendering when no grounding is present, and appends a
    collapsible linear view when the grounded text covers too little of the
    markdown (under 40% of its length, with a 120-char floor).
    """
    blocks = _grounding_blocks_from_raw(raw_text)
    if not blocks:
        return to_mathjax_html(markdown_text)
    used_text = 0
    rendered = []
    palette = {
        "title": "#8b5cf6",
        "text": "#2563eb",
        "image": "#059669",
        "table": "#d97706",
        "formula": "#dc2626",
    }
    # Reading order: top-to-bottom, then left-to-right.
    for i, b in enumerate(sorted(blocks, key=lambda x: (x["y1"], x["x1"]))):
        label = b["label"]
        color = palette.get(label.lower(), "#4b5563")
        body = b["text"].strip()
        if body:
            used_text += len(body)
            body_text, gaps = _inject_spatial_gap_placeholders(body)
            body_html = md_lib.markdown(body_text, extensions=['tables', 'fenced_code', 'sane_lists', 'nl2br'])
            body_html = _restore_spatial_gap_placeholders(body_html, gaps)
        else:
            body_html = ""
        if not body_html:
            # Text-less blocks (e.g. images) still show their label.
            body_html = f"<p><em>{html_lib.escape(label)}</em></p>"
        left = b["x1"] / 999.0 * 100.0
        top = b["y1"] / 999.0 * 100.0
        width = max(1.0, (b["x2"] - b["x1"]) / 999.0 * 100.0)
        height = max(1.2, (b["y2"] - b["y1"]) / 999.0 * 100.0)
        rendered.append(
            f"""
<article class="spatial-block" style="left:{left:.2f}%; top:{top:.2f}%; width:{width:.2f}%; min-height:{height:.2f}%; --block-color:{color};">
<header>{html_lib.escape(label)}</header>
<section>{body_html}</section>
</article>
"""
        )
    fallback = ""
    if markdown_text and used_text < max(120, int(len(markdown_text) * 0.4)):
        fallback_html = to_mathjax_html(markdown_text)
        fallback = f"""
<details class="spatial-fallback">
<summary>Show full linear markdown rendering</summary>
{fallback_html}
</details>
"""
    return f"""
<style>
.spatial-preview {{
padding: 1rem;
}}
.spatial-canvas {{
position: relative;
width: 100%;
min-height: 72vh;
aspect-ratio: 1 / 1.35;
background: linear-gradient(180deg, #fcfdff 0%, #f7f9fc 100%);
border: 1px solid #d8dee9;
border-radius: 8px;
overflow: auto;
}}
.spatial-block {{
position: absolute;
box-sizing: border-box;
border: 1px solid var(--block-color);
background: color-mix(in srgb, var(--block-color) 7%, white);
border-radius: 6px;
padding: 0.35rem 0.5rem;
overflow: hidden;
}}
.spatial-block > header {{
font-size: 11px;
font-weight: 700;
letter-spacing: 0.03em;
text-transform: uppercase;
color: var(--block-color);
margin-bottom: 0.25rem;
}}
.spatial-block > section {{
font-size: 13px;
line-height: 1.35;
}}
.spatial-block p {{ margin: 0.2rem 0; }}
.spatial-fallback {{
margin-top: 1rem;
padding-top: 0.5rem;
border-top: 1px solid #d8dee9;
}}
</style>
<div class="spatial-preview mathjax-preview">
<div class="spatial-canvas">
{''.join(rendered)}
</div>
{fallback}
</div>
"""
def embed_images(markdown, crops):
    """Replace each '**[Figure N]**' placeholder (emitted by clean_output)
    with the matching crop inlined as a base64 PNG data-URI markdown image.

    Args:
        markdown: markdown text containing figure placeholders.
        crops: image objects exposing .save(buf, format=...); crops[i] maps
            to '**[Figure i+1]**'.
    Returns:
        The markdown with images embedded; unchanged when crops is empty.
    """
    if not crops:
        return markdown
    for i, img in enumerate(crops):
        buf = BytesIO()
        img.save(buf, format="PNG")
        b64 = base64.b64encode(buf.getvalue()).decode()
        # Bug fix: the encoded image was previously discarded and the
        # placeholder replaced with blank lines only — embed it instead.
        markdown = markdown.replace(
            f'**[Figure {i + 1}]**',
            f'\n\n![Figure {i + 1}](data:image/png;base64,{b64})\n\n',
            1,
        )
    return markdown
def _infer_with_prompt(image, prompt, crop_mode=None):
    """Run model.infer on a single PIL image and return its cleaned stdout.

    model.infer prints its result to stdout rather than returning it, so
    stdout is temporarily redirected to a StringIO and debug/progress lines
    (matched via INFER_DEBUG_FILTERS) are filtered out afterwards. The temp
    JPEG and output dir are always removed, even if inference raises.

    Args:
        image: PIL image to OCR (saved as JPEG quality 95 for the model).
        prompt: full prompt string including the <image> token.
        crop_mode: tiling override; defaults to the module-level CROP_MODE.
    Returns:
        The filtered, newline-joined inference output.
    """
    if crop_mode is None:
        crop_mode = CROP_MODE
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
    image.save(tmp.name, 'JPEG', quality=95)
    tmp.close()
    out_dir = tempfile.mkdtemp()
    stdout = sys.stdout
    capture = StringIO()
    sys.stdout = capture
    try:
        model.infer(
            tokenizer=tokenizer,
            prompt=prompt,
            image_file=tmp.name,
            output_path=out_dir,
            base_size=BASE_SIZE,
            image_size=IMAGE_SIZE,
            crop_mode=crop_mode,
            save_results=False
        )
    finally:
        # Restore stdout first so any cleanup errors are visible.
        sys.stdout = stdout
        os.unlink(tmp.name)
        shutil.rmtree(out_dir, ignore_errors=True)
    lines = [
        l for l in capture.getvalue().split('\n')
        if l.strip() and not any(s in l for s in INFER_DEBUG_FILTERS)
    ]
    return '\n'.join(lines).strip()
def _refine_equation_refs(image, raw_text):
    """Second-pass equation refinement.

    Crops each large math-looking region found in *raw_text*'s grounding and
    re-runs the equation-grounding prompt on the crop so individual equation
    lines can be recovered. Sub-boxes are mapped back onto the page's 0-999
    grid, deduped, and returned as refs shaped like
    extract_grounding_references output, labeled 'equation_detail'. Crops
    that yield fewer than two sub-boxes add no detail and are skipped.
    """
    entries = _extract_grounding_entries(raw_text)
    if not entries:
        return []
    img_w, img_h = image.size
    candidates = []
    for entry in entries:
        for box in entry["coords"]:
            if _is_math_candidate(entry["label"], entry["text"], box):
                area = (box[2] - box[0]) * (box[3] - box[1])
                candidates.append((area, entry, box))
    if not candidates:
        return []
    # Largest regions first; cap the number of expensive zoom inferences.
    candidates.sort(key=lambda x: x[0], reverse=True)
    refined_refs = []
    for _, entry, box in candidates[:EQUATION_ZOOM_MAX_CANDIDATES]:
        x1 = int(box[0] / 999.0 * img_w)
        y1 = int(box[1] / 999.0 * img_h)
        x2 = int(box[2] / 999.0 * img_w)
        y2 = int(box[3] / 999.0 * img_h)
        box_w = max(1, x2 - x1)
        box_h = max(1, y2 - y1)
        # Pad the crop slightly so glyphs at the region edge aren't clipped.
        pad_x = max(8, int(box_w * EQUATION_ZOOM_PADDING))
        pad_y = max(8, int(box_h * EQUATION_ZOOM_PADDING))
        crop_x1 = max(0, x1 - pad_x)
        crop_y1 = max(0, y1 - pad_y)
        crop_x2 = min(img_w, x2 + pad_x)
        crop_y2 = min(img_h, y2 + pad_y)
        if crop_x2 - crop_x1 < 32 or crop_y2 - crop_y1 < 32:
            continue
        crop = image.crop((crop_x1, crop_y1, crop_x2, crop_y2))
        sub_result = _infer_with_prompt(crop, EQUATION_ZOOM_PROMPT)
        sub_entries = _extract_grounding_entries(sub_result)
        if not sub_entries:
            continue
        mapped_boxes = []
        for sub in sub_entries:
            sub_label = sub["label"].lower()
            sub_text = sub["text"]
            is_math_sub = any(hint in sub_label for hint in MATH_LABEL_HINTS) or _math_marker_score(sub_text) >= 3
            if sub_label in ("image", "table") or not is_math_sub:
                continue
            for sub_box in sub["coords"]:
                mapped = _map_crop_box_to_page(sub_box, (crop_x1, crop_y1, crop_x2, crop_y2), img_w, img_h)
                w = (mapped[2] - mapped[0]) / 999.0
                h = (mapped[3] - mapped[1]) / 999.0
                # Discard specks (< 0.04% of the page).
                if w * h < 0.0004:
                    continue
                mapped_boxes.append(mapped)
        if not mapped_boxes:
            continue
        mapped_boxes = _dedupe_boxes(mapped_boxes, EQUATION_DETAIL_IOU_DEDUPE)
        mapped_boxes = sorted(mapped_boxes, key=lambda b: (b[1], b[0]))[:EQUATION_DETAIL_MAX_BOXES]
        if len(mapped_boxes) < 2:
            continue
        merged_text = repr(mapped_boxes)
        label = "equation_detail"
        raw = f'<|ref|>{label}<|/ref|><|det|>{merged_text}<|/det|>'
        refined_refs.append((raw, label, merged_text))
    return refined_refs
| def _norm_box_to_pixels(box, img_w, img_h, pad_ratio=0.0): | |
| x1 = int(box[0] / 999.0 * img_w) | |
| y1 = int(box[1] / 999.0 * img_h) | |
| x2 = int(box[2] / 999.0 * img_w) | |
| y2 = int(box[3] / 999.0 * img_h) | |
| if pad_ratio > 0: | |
| pad_x = max(1, int((x2 - x1) * pad_ratio)) | |
| pad_y = max(1, int((y2 - y1) * pad_ratio)) | |
| x1 -= pad_x | |
| y1 -= pad_y | |
| x2 += pad_x | |
| y2 += pad_y | |
| x1 = max(0, min(img_w - 1, x1)) | |
| y1 = max(0, min(img_h - 1, y1)) | |
| x2 = max(x1 + 1, min(img_w, x2)) | |
| y2 = max(y1 + 1, min(img_h, y2)) | |
| return (x1, y1, x2, y2) | |
def _detect_equation_line_boxes(image, infer_crop_mode=None):
    """Run a grounding detection pass and return candidate equation-line boxes.

    Returns (boxes, detect_raw): boxes are 0-999 normalized [x1, y1, x2, y2]
    coords, deduped by IoU and sorted top-to-bottom; detect_raw is the raw
    model output of the detection pass for downstream display.
    """
    detect_raw = _infer_with_prompt(image, EQUATION_ZOOM_PROMPT, crop_mode=infer_crop_mode)
    entries = _extract_grounding_entries(detect_raw)
    if not entries:
        return [], detect_raw
    boxes = []
    for entry in entries:
        label_l = entry["label"].lower()
        text_chunk = entry["text"]
        # Figures and tables are never equation lines.
        if label_l in ("image", "table"):
            continue
        for box in entry["coords"]:
            # Width/height as fractions of the page (model space is 0-999).
            w = (box[2] - box[0]) / 999.0
            h = (box[3] - box[1]) / 999.0
            area = w * h
            aspect = max(w / max(1e-9, h), h / max(1e-9, w))
            # "Math" either via the grounding label or LaTeX-like markers in text.
            looks_math = any(hint in label_l for hint in MATH_LABEL_HINTS) or _math_marker_score(text_chunk) >= 2
            if area < EQUATION_LINE_MIN_AREA or w < EQUATION_LINE_MIN_W or h < EQUATION_LINE_MIN_H:
                continue
            if aspect > EQUATION_LINE_MAX_ASPECT:
                continue
            # Boxes that don't look like math must be fairly large to survive.
            if not looks_math and area < 0.004:
                continue
            boxes.append(box)
    boxes = _dedupe_boxes(boxes, EQUATION_LINE_IOU_DEDUPE)
    # Sort by row (y rounded to absorb jitter), then by column.
    boxes = sorted(boxes, key=lambda b: (round(b[1], 3), b[0]))
    return boxes, detect_raw
def _process_equation_lines_separately(image, infer_crop_mode=None):
    """OCR each detected equation line as its own tight crop.

    Returns (cleaned, markdown, raw, annotated_image, crops) matching the
    process_image output shape, or None when no usable lines were found.
    """
    boxes, detect_raw = _detect_equation_line_boxes(image, infer_crop_mode=infer_crop_mode)
    if not boxes:
        return None
    img_w, img_h = image.size
    cleaned_parts = []
    markdown_parts = []
    raw_parts = [f"## Detection\n\n{detect_raw}".strip()]
    refs = []
    crops = []
    # Normalized equation texts already emitted; used to drop duplicates.
    seen_line_keys = set()
    for i, box in enumerate(boxes, 1):
        x1, y1, x2, y2 = _norm_box_to_pixels(box, img_w, img_h, pad_ratio=0.01)
        crop = image.crop((x1, y1, x2, y2))
        # Per-line OCR pass; crop_mode=False because the crop is already tight.
        line_raw = _infer_with_prompt(crop, EQUATION_LINE_OCR_PROMPT, crop_mode=False)
        line_clean = clean_output(line_raw, False).strip()
        if not line_clean:
            continue
        line_key = _equation_text_key(line_clean)
        if line_key and line_key in seen_line_keys:
            continue
        if line_key:
            seen_line_keys.add(line_key)
        line_label = f"Eq {i}"
        line_markdown = line_clean
        # Wrap bare LaTeX in display-math delimiters if the model emitted none.
        if "$$" not in line_markdown and "\\[" not in line_markdown and "\\(" not in line_markdown:
            line_markdown = f"$$\n{line_markdown}\n$$"
        cleaned_parts.append(f"{line_label}: {line_clean}")
        markdown_parts.append(f"### {line_label}\n\n{line_markdown}")
        raw_parts.append(f"## {line_label}\n\n{line_raw}")
        # Synthetic grounding ref so the line also shows in the Boxes tab.
        coord_text = repr([box])
        raw_ref = f'<|ref|>eq_line_{i}<|/ref|><|det|>{coord_text}<|/det|>'
        refs.append((raw_ref, line_label, coord_text))
        crops.append((crop, line_label))
    if not cleaned_parts:
        return None
    img_out, _ = draw_bounding_boxes(image, refs, extract_images=False)
    cleaned = "\n".join(cleaned_parts).strip()
    markdown = "\n\n".join(markdown_parts).strip()
    raw = "\n\n".join(raw_parts).strip()
    return cleaned, markdown, raw, img_out, crops
def process_image(image, task, custom_prompt, enable_equation_zoom=True, infer_crop_mode=None, separate_equation_lines=False):
    """Run the selected OCR task on a PIL image.

    Returns a 5-tuple (cleaned_text, markdown, raw_model_output,
    annotated_image, crops); every error/early-exit path returns the
    same shape so UI bindings stay valid.
    """
    model.cuda()  # GPU is available here — works on ZeroGPU and locally
    if image is None:
        return "Error: Upload an image", "", "", None, []
    # Custom/Locate tasks need user text unless equation-line mode bypasses them.
    if not separate_equation_lines and task in ["✏️ Custom", "📍 Locate"] and not custom_prompt.strip():
        return "Please enter a prompt", "", "", None, []
    if image.mode in ('RGBA', 'LA', 'P'):
        image = image.convert('RGB')
    # Honor camera EXIF orientation before inference.
    image = ImageOps.exif_transpose(image)
    if separate_equation_lines:
        # Dedicated per-line equation pipeline; falls through only on failure.
        separate_result = _process_equation_lines_separately(image, infer_crop_mode=infer_crop_mode)
        if separate_result is not None:
            return separate_result
        msg = "No separate equation lines detected. Try Selected Region + freehand highlight around the equation steps."
        return msg, msg, msg, None, []
    if task == "✏️ Custom":
        prompt = f"<image>\n{custom_prompt.strip()}"
        has_grounding = '<|grounding|>' in custom_prompt
    elif task == "📍 Locate":
        prompt = f"<image>\nLocate <|ref|>{custom_prompt.strip()}<|/ref|> in the image."
        has_grounding = True
    else:
        prompt = TASK_PROMPTS[task]["prompt"]
        has_grounding = TASK_PROMPTS[task]["has_grounding"]
    result = _infer_with_prompt(image, prompt, crop_mode=infer_crop_mode)
    if not result:
        return "No text detected", "", "", None, []
    cleaned = clean_output(result, False)
    markdown = clean_output(result, True)
    img_out = None
    crops = []
    figure_crops = []
    result_for_layout = result
    if has_grounding and '<|ref|>' in result:
        refs = extract_grounding_references(result)
        # Optional second pass that re-OCRs equation regions at higher zoom.
        if task == "📋 Markdown" and enable_equation_zoom:
            refs.extend(_refine_equation_refs(image, result))
        if refs:
            img_out, figure_crops = draw_bounding_boxes(image, refs, True)
            crops = _extract_labeled_crops_from_refs(image, refs)
            # Append synthetic equation_detail refs to the raw output so the
            # layout view can display them alongside the model's own refs.
            synthetic = [r[0] for r in refs if r[1] == "equation_detail"]
            if synthetic:
                result_for_layout = result + "\n" + "\n".join(synthetic)
        markdown = embed_images(markdown, figure_crops)
        if not crops and figure_crops:
            crops = _label_gallery_items(figure_crops, prefix="Figure")
    return cleaned, markdown, result_for_layout, img_out, crops
def process_pdf(path, task, custom_prompt, page_num, enable_equation_zoom=True, infer_crop_mode=None, separate_equation_lines=False):
    """Render one PDF page at 300 DPI and feed it through process_image."""
    doc = fitz.open(path)
    total_pages = len(doc)
    if not (1 <= page_num <= total_pages):
        doc.close()
        return f"Invalid page number. PDF has {total_pages} pages.", "", "", None, []
    # 300/72 scales from PDF points (72 per inch) to a 300 DPI raster.
    pix = doc.load_page(page_num - 1).get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72), alpha=False)
    page_image = Image.open(BytesIO(pix.tobytes("png")))
    doc.close()
    return process_image(
        page_image,
        task,
        custom_prompt,
        enable_equation_zoom=enable_equation_zoom,
        infer_crop_mode=infer_crop_mode,
        separate_equation_lines=separate_equation_lines,
    )
def process_file(path, task, custom_prompt, page_num, enable_equation_zoom=True, infer_crop_mode=None, separate_equation_lines=False):
    """Dispatch an uploaded file to the PDF or image pipeline by extension."""
    if not path:
        return "Error: Upload a file", "", "", None, []
    shared_kwargs = dict(
        enable_equation_zoom=enable_equation_zoom,
        infer_crop_mode=infer_crop_mode,
        separate_equation_lines=separate_equation_lines,
    )
    if path.lower().endswith('.pdf'):
        return process_pdf(path, task, custom_prompt, page_num, **shared_kwargs)
    return process_image(Image.open(path), task, custom_prompt, **shared_kwargs)
| def _extract_editor_background(editor_value): | |
| if editor_value is None: | |
| return None | |
| if isinstance(editor_value, Image.Image): | |
| return editor_value | |
| if isinstance(editor_value, dict): | |
| background = editor_value.get("background") | |
| if isinstance(background, Image.Image): | |
| return background | |
| composite = editor_value.get("composite") | |
| if isinstance(composite, Image.Image): | |
| return composite | |
| return None | |
| def _to_rgba_image(obj): | |
| if isinstance(obj, dict): | |
| for k in ("image", "layer", "composite", "background", "mask"): | |
| if k in obj: | |
| return _to_rgba_image(obj[k]) | |
| return None | |
| if isinstance(obj, Image.Image): | |
| return obj.convert("RGBA") | |
| if isinstance(obj, np.ndarray): | |
| arr = obj | |
| if arr.ndim == 2: | |
| arr = np.stack([arr, arr, arr, np.full_like(arr, 255)], axis=-1) | |
| elif arr.ndim == 3 and arr.shape[2] == 3: | |
| alpha = np.full((arr.shape[0], arr.shape[1], 1), 255, dtype=arr.dtype) | |
| arr = np.concatenate([arr, alpha], axis=2) | |
| elif arr.ndim != 3 or arr.shape[2] != 4: | |
| return None | |
| return Image.fromarray(arr.astype(np.uint8), mode="RGBA") | |
| return None | |
| def _to_mask_array(obj): | |
| if obj is None: | |
| return None | |
| if isinstance(obj, dict): | |
| for k in ("mask", "image", "layer", "composite", "background"): | |
| if k in obj: | |
| arr = _to_mask_array(obj[k]) | |
| if arr is not None: | |
| return arr | |
| return None | |
| if isinstance(obj, Image.Image): | |
| arr = np.asarray(obj) | |
| elif isinstance(obj, np.ndarray): | |
| arr = obj | |
| else: | |
| return None | |
| if arr.ndim == 2: | |
| return arr > 0 | |
| if arr.ndim == 3: | |
| if arr.shape[2] >= 4: | |
| return arr[:, :, 3] > 0 | |
| return np.max(arr[:, :, :3], axis=2) > 0 | |
| return None | |
def _locate_patch_bbox(base_image: Image.Image, patch_image: Image.Image):
    """Approximate patch location in base image using downscaled SSD search."""
    if base_image is None or patch_image is None:
        return None
    # Grayscale float32 copies of both images for the SSD comparison.
    base = np.asarray(base_image.convert("L"), dtype=np.float32)
    patch = np.asarray(patch_image.convert("L"), dtype=np.float32)
    bh, bw = base.shape[:2]
    ph, pw = patch.shape[:2]
    # The patch must fit inside the base image.
    if ph <= 0 or pw <= 0 or ph > bh or pw > bw:
        return None
    # Downscale both so the longest base side is <= 320px, keeping search cheap.
    max_dim = max(bh, bw)
    scale = min(1.0, 320.0 / max_dim) if max_dim > 0 else 1.0
    if scale < 1.0:
        new_bw = max(1, int(round(bw * scale)))
        new_bh = max(1, int(round(bh * scale)))
        new_pw = max(1, int(round(pw * scale)))
        new_ph = max(1, int(round(ph * scale)))
        base_small = np.asarray(Image.fromarray(base.astype(np.uint8)).resize((new_bw, new_bh), Image.Resampling.BILINEAR), dtype=np.float32)
        patch_small = np.asarray(Image.fromarray(patch.astype(np.uint8)).resize((new_pw, new_ph), Image.Resampling.BILINEAR), dtype=np.float32)
    else:
        base_small = base
        patch_small = patch
    sbh, sbw = base_small.shape
    sph, spw = patch_small.shape
    # Independent rounding can make the scaled patch exceed the scaled base.
    if sph > sbh or spw > sbw:
        return None
    best_score = float("inf")
    best_x = 0
    best_y = 0
    # Exhaustive search, one row of vertical offsets at a time; horizontal
    # offsets are vectorized via a sliding window over each row strip.
    for y in range(sbh - sph + 1):
        row = base_small[y:y + sph, :]
        windows = np.lib.stride_tricks.sliding_window_view(row, spw, axis=1)
        # windows: (sph, sbw-spw+1, spw)
        diff = windows - patch_small[:, None, :]
        scores = np.mean(diff * diff, axis=(0, 2))
        x = int(np.argmin(scores))
        score = float(scores[x])
        if score < best_score:
            best_score = score
            best_x = x
            best_y = y
    # Map the best match back to full-resolution pixel coordinates.
    if scale < 1.0:
        x1 = int(round(best_x / scale))
        y1 = int(round(best_y / scale))
        x2 = int(round((best_x + spw) / scale))
        y2 = int(round((best_y + sph) / scale))
    else:
        x1, y1, x2, y2 = best_x, best_y, best_x + spw, best_y + sph
    # Clamp to the base image, guaranteeing a box of at least 1x1.
    x1 = max(0, min(bw - 1, x1))
    y1 = max(0, min(bh - 1, y1))
    x2 = max(x1 + 1, min(bw, x2))
    y2 = max(y1 + 1, min(bh, y2))
    return (x1, y1, x2, y2)
| def _component_boxes(binary_mask, min_pixels=24): | |
| h, w = binary_mask.shape | |
| visited = np.zeros((h, w), dtype=bool) | |
| boxes = [] | |
| neighbors = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)] | |
| ys, xs = np.where(binary_mask) | |
| for sy, sx in zip(ys.tolist(), xs.tolist()): | |
| if visited[sy, sx]: | |
| continue | |
| q = deque([(sy, sx)]) | |
| visited[sy, sx] = True | |
| min_x = max_x = sx | |
| min_y = max_y = sy | |
| count = 0 | |
| while q: | |
| y, x = q.popleft() | |
| count += 1 | |
| if x < min_x: | |
| min_x = x | |
| if x > max_x: | |
| max_x = x | |
| if y < min_y: | |
| min_y = y | |
| if y > max_y: | |
| max_y = y | |
| for dy, dx in neighbors: | |
| ny, nx = y + dy, x + dx | |
| if ny < 0 or ny >= h or nx < 0 or nx >= w: | |
| continue | |
| if visited[ny, nx] or not binary_mask[ny, nx]: | |
| continue | |
| visited[ny, nx] = True | |
| q.append((ny, nx)) | |
| if count >= min_pixels: | |
| boxes.append((min_x, min_y, max_x + 1, max_y + 1, count)) | |
| return boxes | |
def _extract_regions_from_mask(background, mask):
    """Crop a padded RGB region for every sizable painted component in mask.

    Returns [(crop, (x1, y1, x2, y2)), ...] sorted by box area, largest first.
    """
    regions = []
    for x1, y1, x2, y2, _count in _component_boxes(mask, min_pixels=24):
        # ~2% padding per side (at least 2px), clamped to the image bounds.
        pad_x = max(2, int((x2 - x1) * 0.02))
        pad_y = max(2, int((y2 - y1) * 0.02))
        bx1 = max(0, x1 - pad_x)
        by1 = max(0, y1 - pad_y)
        bx2 = min(background.width, x2 + pad_x)
        by2 = min(background.height, y2 + pad_y)
        if bx2 > bx1 and by2 > by1:
            crop = background.crop((bx1, by1, bx2, by2)).convert("RGB")
            regions.append((crop, (bx1, by1, bx2, by2)))
    regions.sort(
        key=lambda item: (item[1][2] - item[1][0]) * (item[1][3] - item[1][1]),
        reverse=True,
    )
    return regions
def _editor_background_and_mask(editor_value):
    """Resolve an ImageEditor-style dict into (RGBA background, boolean mask).

    The mask marks drawn pixels. An explicit 'mask' entry wins; otherwise the
    union of the layers' alpha channels is used. Returns (None, None) when no
    usable background exists, and (background, None) when there is a
    background but no drawing data.
    """
    if not isinstance(editor_value, dict):
        return None, None
    background = _to_rgba_image(editor_value.get("background"))
    if background is None:
        background = _to_rgba_image(editor_value.get("image"))
    composite = _to_rgba_image(editor_value.get("composite"))
    layers = editor_value.get("layers") or []
    if background is None:
        if composite is None:
            return None, None
        # No raw background available; fall back to the flattened composite.
        background = composite
    mask = _to_mask_array(editor_value.get("mask"))
    if mask is not None:
        if mask.shape[:2] != (background.height, background.width):
            # Resize to the background with nearest-neighbor so boolean values
            # survive without interpolation artifacts.
            mask_img = Image.fromarray(mask.astype(np.uint8) * 255, mode="L")
            nearest = Image.Resampling.NEAREST if hasattr(Image, "Resampling") else Image.NEAREST
            mask = np.asarray(mask_img.resize((background.width, background.height), nearest)) > 0
        return background, mask
    if not isinstance(layers, list) or not layers:
        return background, None
    # No explicit mask: union the alpha channel of every drawable layer.
    alpha_acc = np.zeros((background.height, background.width), dtype=np.uint8)
    for layer in layers:
        layer_img = _to_rgba_image(layer)
        if layer_img is None:
            continue
        if layer_img.size != background.size:
            nearest = Image.Resampling.NEAREST if hasattr(Image, "Resampling") else Image.NEAREST
            layer_img = layer_img.resize(background.size, nearest)
        layer_alpha = np.asarray(layer_img, dtype=np.uint8)[:, :, 3]
        alpha_acc = np.maximum(alpha_acc, layer_alpha)
    return background, (alpha_acc > 0)
| def _extract_selected_regions(editor_value, base_size=None, base_image=None): | |
| if editor_value is None: | |
| return [] | |
| if isinstance(editor_value, Image.Image): | |
| if base_size and tuple(editor_value.size) == tuple(base_size): | |
| return [] | |
| bbox = _locate_patch_bbox(base_image, editor_value) if base_image is not None else None | |
| return [(editor_value, bbox)] | |
| if not isinstance(editor_value, dict): | |
| return [] | |
| background, mask = _editor_background_and_mask(editor_value) | |
| layers = editor_value.get("layers") or [] | |
| if background is None: | |
| return [] | |
| if not isinstance(layers, list) or not layers: | |
| # No annotation layers; treat as explicit crop only if size changed from base. | |
| if base_size and tuple(background.size) == tuple(base_size): | |
| return [] | |
| patch = background.convert("RGB") | |
| bbox = _locate_patch_bbox(base_image, patch) if base_image is not None else None | |
| return [(patch, bbox)] | |
| if mask is None: | |
| return [] | |
| return _extract_regions_from_mask(background, mask) | |
def _extract_new_drawn_regions(editor_value, base_size=None, base_image=None, consumed_mask=None):
    """Like _extract_selected_regions, but only for strokes drawn since
    consumed_mask was captured; returns (regions, updated_consumed_mask)."""
    if isinstance(editor_value, Image.Image):
        # Explicit cropped image: classic extraction, mask bookkeeping unchanged.
        regions = _extract_selected_regions(editor_value, base_size=base_size, base_image=base_image)
        return regions, consumed_mask
    if not isinstance(editor_value, dict):
        return [], consumed_mask
    background, mask = _editor_background_and_mask(editor_value)
    if background is None:
        return [], consumed_mask
    layers = editor_value.get("layers") or []
    has_draw_data = mask is not None or (isinstance(layers, list) and len(layers) > 0)
    if not has_draw_data:
        # Nothing drawn: fall back to explicit-crop handling.
        regions = _extract_selected_regions(editor_value, base_size=base_size, base_image=base_image)
        return regions, consumed_mask
    if mask is None:
        return [], consumed_mask
    # Only process pixels not already consumed by a previous Add Region click.
    if isinstance(consumed_mask, np.ndarray) and consumed_mask.shape == mask.shape:
        fresh = np.logical_and(mask, np.logical_not(consumed_mask))
    else:
        fresh = mask
    return _extract_regions_from_mask(background, fresh), mask
def _extract_selected_region(editor_value, base_size=None, base_image=None):
    """Return the first (largest) selected region as (crop, bbox), or (None, None)."""
    regions = _extract_selected_regions(editor_value, base_size=base_size, base_image=base_image)
    return regions[0] if regions else (None, None)
| def _bbox_overlap_ratio(a, b): | |
| ax1, ay1, ax2, ay2 = a | |
| bx1, by1, bx2, by2 = b | |
| ix1 = max(ax1, bx1) | |
| iy1 = max(ay1, by1) | |
| ix2 = min(ax2, bx2) | |
| iy2 = min(ay2, by2) | |
| if ix2 <= ix1 or iy2 <= iy1: | |
| return 0.0, 0.0 | |
| inter = float((ix2 - ix1) * (iy2 - iy1)) | |
| area_a = float(max(1, (ax2 - ax1) * (ay2 - ay1))) | |
| area_b = float(max(1, (bx2 - bx1) * (by2 - by1))) | |
| return inter / area_a, inter / area_b | |
def _is_duplicate_bbox(candidate_bbox, existing_bbox):
    """True when two boxes are near-identical or one nearly covers the other."""
    if _box_iou(candidate_bbox, existing_bbox) >= 0.85:
        return True
    cand_cover, exist_cover = _bbox_overlap_ratio(candidate_bbox, existing_bbox)
    return cand_cover >= 0.92 or exist_cover >= 0.97
| def _draw_selected_region_boxes(image, boxes): | |
| if image is None or not boxes: | |
| return None | |
| refs = [] | |
| w, h = image.size | |
| for i, b in enumerate(boxes, 1): | |
| x1, y1, x2, y2 = b | |
| nx1 = max(0.0, min(999.0, x1 / max(1, w) * 999.0)) | |
| ny1 = max(0.0, min(999.0, y1 / max(1, h) * 999.0)) | |
| nx2 = max(0.0, min(999.0, x2 / max(1, w) * 999.0)) | |
| ny2 = max(0.0, min(999.0, y2 / max(1, h) * 999.0)) | |
| label = f"Region {i}" | |
| coord_text = repr([[nx1, ny1, nx2, ny2]]) | |
| raw = f'<|ref|>region_{i}<|/ref|><|det|>{coord_text}<|/det|>' | |
| refs.append((raw, label, coord_text)) | |
| img_out, _ = draw_bounding_boxes(image, refs, extract_images=False) | |
| return img_out | |
| def _region_gallery_items(regions): | |
| items = [] | |
| for i, r in enumerate(regions, 1): | |
| img = r["image"] | |
| label = f"Region {i}" | |
| if isinstance(img, Image.Image): | |
| label = f"{label} ({img.width}x{img.height})" | |
| items.append((img, label)) | |
| return items | |
| def _label_gallery_items(items, prefix=None): | |
| labeled = [] | |
| for i, item in enumerate(items, 1): | |
| if isinstance(item, tuple) and len(item) >= 2: | |
| img, label = item[0], str(item[1]) | |
| else: | |
| img, label = item, f"Item {i}" | |
| if prefix: | |
| label = f"{prefix} - {label}" | |
| if isinstance(img, Image.Image): | |
| label = f"{label} ({img.width}x{img.height})" | |
| labeled.append((img, label)) | |
| return labeled | |
| def _reset_selected_regions(): | |
| return [], [], "No saved regions." | |
| def _reset_drawn_mask(): | |
| return None | |
def add_selected_region(editor_value, base_size, base_image, selected_regions, consumed_mask):
    """Append newly drawn regions to the saved list, skipping near-duplicates.

    Returns (regions, gallery_items, status_message, updated_consumed_mask).
    """
    candidates, updated_mask = _extract_new_drawn_regions(
        editor_value,
        base_size=base_size,
        base_image=base_image,
        consumed_mask=consumed_mask,
    )
    regions = list(selected_regions or [])
    if not candidates:
        status = "No region detected. Use Crop or draw/highlight a region first."
        return regions, _region_gallery_items(regions), status, updated_mask
    known_boxes = [r.get("bbox") for r in regions if r.get("bbox") is not None]
    added = 0
    for region_img, bbox in candidates:
        if bbox is not None and any(_is_duplicate_bbox(bbox, kb) for kb in known_boxes):
            continue  # Near-duplicate of an already-saved region.
        regions.append({"image": region_img, "bbox": bbox})
        if bbox is not None:
            known_boxes.append(bbox)
        added += 1
    if added:
        status = f"Added {added} region(s). {len(regions)} total. Zoom/pan is preserved."
    else:
        status = "No new region added. Draw one region, click Add Region, then draw the next region."
    return regions, _region_gallery_items(regions), status, updated_mask
def clear_selected_regions():
    """Reset saved regions, gallery, and status text to their initial state."""
    return _reset_selected_regions()
def clear_regions_preserve_view(editor_value):
    """Clear saved regions but mark current strokes as consumed so the
    existing drawing (and the user's zoom/pan) does not re-trigger adds."""
    regions, gallery_items, status = _reset_selected_regions()
    _background, mask = _editor_background_and_mask(editor_value)
    return regions, gallery_items, status, mask
def _compose_ui_outputs(cleaned, markdown, raw, img_out, gallery_items):
    """Map pipeline outputs onto the UI widgets.

    Converts LaTeX delimiters for the text tab (\\[...\\] -> $$...$$,
    \\(...\\) -> $...$), writes the cleaned text to a temp .md file for the
    download button, and renders the markdown to math-aware HTML.
    """
    text_display = re.sub(
        r'\\\[(.+?)\\\]',
        lambda m: f'\n$$\n{m.group(1).strip()}\n$$\n',
        cleaned,
        flags=re.DOTALL,
    )
    text_display = re.sub(r'\\\((.+?)\\\)', lambda m: f'${m.group(1).strip()}$', text_display)
    # delete=False: Gradio serves the file after this function returns.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.md', mode='w', encoding='utf-8') as dl_tmp:
        dl_tmp.write(cleaned)
    return (
        text_display,
        cleaned,
        to_math_html(markdown),
        raw,
        img_out,
        gallery_items,
        gr.DownloadButton(value=dl_tmp.name, visible=True),
    )
def toggle_prompt(task):
    """Show the prompt textbox only for tasks that need free-text input."""
    if task == "✏️ Custom":
        return gr.update(visible=True, label="Custom Prompt", placeholder="Add <|grounding|> for bounding boxes")
    if task == "📍 Locate":
        return gr.update(visible=True, label="Text to Locate", placeholder="Enter text to locate")
    return gr.update(visible=False)
def select_boxes(task):
    """Jump to the Boxes tab when the task produces bounding boxes by default."""
    return gr.update(selected="tab_boxes") if task == "📍 Locate" else gr.update()
def toggle_scope_ui(scope):
    """Update the workspace hint and show/hide the region-selection controls."""
    region_mode = scope == "Selected Region"
    if region_mode:
        hint = (
            "**Selected Region mode:** Draw/highlight on the workspace, click **Add Region** "
            "for each target area, then click **Extract**."
        )
    else:
        hint = "**Entire Page mode:** No drawing needed. Click **Extract** to process the full page."
    # (hint markdown, selection_controls row, status box, regions gallery)
    return (
        gr.update(value=hint),
        gr.update(visible=region_mode),
        gr.update(visible=region_mode),
        gr.update(visible=region_mode),
    )
def select_post_extract_tab(task, scope):
    """After extraction, land on Boxes for region/locate flows, else Text."""
    wants_boxes = scope == "Selected Region" or task == "📍 Locate"
    return gr.update(selected="tab_boxes" if wants_boxes else "tab_text")
def get_pdf_page_count(file_path):
    """Number of pages in a PDF; 1 for anything that is not a PDF path."""
    if not file_path or not file_path.lower().endswith('.pdf'):
        return 1
    doc = fitz.open(file_path)
    page_count = len(doc)
    doc.close()
    return page_count
def load_image(file_path, page_num=1):
    """Open an image file, or render one page of a PDF at 300 DPI.

    Returns a PIL image, or None when no path is given.
    """
    if not file_path:
        return None
    if not file_path.lower().endswith('.pdf'):
        return Image.open(file_path)
    doc = fitz.open(file_path)
    # Clamp the requested 1-based page into the document's valid range.
    page_idx = max(0, min(int(page_num) - 1, len(doc) - 1))
    pix = doc.load_page(page_idx).get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72), alpha=False)
    img = Image.open(BytesIO(pix.tobytes("png")))
    doc.close()
    return img
| def _scale_workspace_image(img, workspace_scale): | |
| if img is None: | |
| return None | |
| # Keep native pixels for workspace quality. Gradio's in-canvas zoom controls | |
| # visual scale; pre-resampling here causes blurry math when users zoom in. | |
| return img | |
def _prepare_workspace_image(img, workspace_scale=WORKSPACE_DEFAULT_SCALE):
    """Package an image for the workspace editor.

    Returns (display_image, (width, height), base_image) — the display and
    base image are the same object since no pre-scaling is applied.
    """
    if img is None:
        return None, None, None
    display_img = _scale_workspace_image(img, workspace_scale)
    size = (int(display_img.width), int(display_img.height))
    return display_img, size, display_img
def load_image_with_size(file_path, page_num=1, workspace_scale=WORKSPACE_DEFAULT_SCALE):
    """Load a file/page and package it for the workspace editor."""
    return _prepare_workspace_image(load_image(file_path, page_num), workspace_scale)
def _example_path_value(value):
    """Best-effort extraction of a filesystem path from one payload element."""
    if isinstance(value, os.PathLike):
        return os.fspath(value)
    if isinstance(value, str):
        return value
    return None

def load_example_into_workspace(example_value):
    """Resolve a gr.Examples payload (path, dict, list/tuple, or image) and
    load it into the workspace; returns (display_img, base_size, base_img)."""
    if example_value is None:
        return None, None, None
    file_path = _example_path_value(example_value)
    if file_path is None and isinstance(example_value, dict):
        file_path = _example_path_value(example_value.get("path") or example_value.get("name"))
    if file_path is None and isinstance(example_value, (list, tuple)) and example_value:
        file_path = _example_path_value(example_value[0])
    if file_path:
        return _prepare_workspace_image(load_image(file_path, 1), WORKSPACE_DEFAULT_SCALE)
    # No usable path: try to interpret the payload as image data directly.
    if isinstance(example_value, Image.Image):
        img = example_value
    else:
        maybe_rgba = _to_rgba_image(example_value)
        if maybe_rgba is None:
            return None, None, None
        img = maybe_rgba.convert("RGB")
    return _prepare_workspace_image(img, WORKSPACE_DEFAULT_SCALE)
def load_example_into_workspace_and_reset(example_value):
    """Load an example into the workspace and reset region-selection state
    (regions list, gallery, status text, consumed mask)."""
    workspace = load_example_into_workspace(example_value)
    return (*workspace, [], [], "No saved regions.", None)
def sync_workspace_state(editor_value, current_base_image):
    """Refresh the (base_size, base_image) state from the editor, falling back
    to the previously stored base image when the editor has none."""
    for candidate in (_extract_editor_background(editor_value), current_base_image):
        if isinstance(candidate, Image.Image):
            return (int(candidate.width), int(candidate.height)), candidate
    return None, None
def update_page_selector(file_path):
    """Show a 1..N page picker for PDFs; hide it for images or no file."""
    if file_path and file_path.lower().endswith('.pdf'):
        page_count = get_pdf_page_count(file_path)
        return gr.update(
            visible=True,
            maximum=page_count,
            value=1,
            minimum=1,
            label=f"Select Page (1-{page_count})",
        )
    return gr.update(visible=False)
# gr.Blocks() kwargs: apply the Soft theme only when this Gradio version
# exposes it, so the app still launches on older Gradio builds.
blocks_kwargs = {"title": "DeepSeek-OCR-2"}
if hasattr(gr, "themes") and hasattr(gr.themes, "Soft"):
    try:
        blocks_kwargs["theme"] = gr.themes.Soft()
    except Exception:
        # Best-effort theming: fall back to the default theme on any failure.
        pass
| with gr.Blocks(**blocks_kwargs) as demo: | |
| gr.Markdown(""" | |
| # 🧮 DeepSeek-OCR-2 — Math Rendering Edition | |
| **Convert documents to markdown, extract text, parse figures, and locate specific content with bounding boxes.** | |
| **Model uses DeepEncoder v2 and achieves 91.09% on OmniDocBench (+3.73% over v1).** | |
| Built on the original [DeepSeek-OCR-2 Demo](https://huggingface.co/spaces/merterbak/DeepSeek-OCR-2) by **Mert Erbak** — thank you for the excellent foundation. | |
| This fork adds **math rendering** in the Markdown Preview tab so that equations from scanned papers and textbooks display as proper math notation. | |
| """) | |
| region_editor = None | |
| workspace_base_size = gr.State(None) | |
| workspace_base_image = gr.State(None) | |
| selected_regions_state = gr.State([]) | |
| drawn_mask_state = gr.State(None) | |
| with gr.Row(): | |
| with gr.Column(scale=3): | |
| file_in = gr.File(label="Upload Image or PDF", file_types=["image", ".pdf"], type="filepath") | |
| with gr.Column(scale=1): | |
| page_selector = gr.Number(label="Select Page", value=1, minimum=1, step=1, visible=False) | |
| with gr.Row(): | |
| with gr.Column(scale=3): | |
| workspace_hint = gr.Markdown("**Entire Page mode:** No drawing needed. Click **Extract** to process the full page.") | |
| gr.Markdown("**Image Workspace (full page + region selection)**") | |
| if HAS_REGION_WORKSPACE: | |
| editor_kwargs = {} | |
| if HAS_BRUSH: | |
| try: | |
| highlight = ("#2563eb", 0.35) | |
| editor_kwargs["brush"] = gr.Brush( | |
| colors=[highlight], | |
| default_color=highlight, | |
| color_mode="fixed", | |
| default_size=22, | |
| ) | |
| except TypeError: | |
| try: | |
| editor_kwargs["brush"] = gr.Brush( | |
| colors=["rgba(37,99,235,0.35)"], | |
| default_color="rgba(37,99,235,0.35)", | |
| color_mode="fixed", | |
| default_size=22, | |
| ) | |
| except TypeError: | |
| editor_kwargs["brush"] = gr.Brush() | |
| if HAS_ERASER: | |
| try: | |
| editor_kwargs["eraser"] = gr.Eraser(default_size=26) | |
| except TypeError: | |
| editor_kwargs["eraser"] = gr.Eraser() | |
| if HAS_IMAGE_EDITOR: | |
| try: | |
| region_editor = gr.ImageEditor( | |
| label="Image Workspace", | |
| show_label=False, | |
| type="pil", | |
| height=WORKSPACE_EDITOR_HEIGHT, | |
| **editor_kwargs, | |
| ) | |
| except TypeError: | |
| try: | |
| region_editor = gr.ImageEditor( | |
| label="Image Workspace", | |
| show_label=False, | |
| height=WORKSPACE_EDITOR_HEIGHT, | |
| **editor_kwargs, | |
| ) | |
| except TypeError: | |
| region_editor = gr.ImageEditor( | |
| label="Image Workspace", | |
| show_label=False, | |
| height=WORKSPACE_EDITOR_HEIGHT, | |
| ) | |
| else: | |
| region_editor = gr.Paint( | |
| label="Image Workspace", | |
| show_label=False, | |
| type="pil", | |
| height=WORKSPACE_EDITOR_HEIGHT, | |
| **editor_kwargs, | |
| ) | |
| else: | |
| gr.Markdown("Region drawing requires a newer Gradio version with `Paint` or `ImageEditor` support.") | |
| region_editor = gr.State(None) | |
| with gr.Column(scale=1): | |
| gr.Markdown("### OCR Workflow") | |
| task = gr.Dropdown(list(TASK_PROMPTS.keys()), value="📋 Markdown", label="Task") | |
| input_scope = gr.Radio(["Entire Page", "Selected Region"], value="Entire Page", label="Input Scope") | |
| selection_controls = gr.Row(visible=False) | |
| with selection_controls: | |
| add_region_btn = gr.Button("Add Region", variant="secondary") | |
| clear_regions_btn = gr.Button("Clear Regions") | |
| selection_status = gr.Textbox(label="Region Selection Status", value="No saved regions.", interactive=False, visible=False) | |
| selected_regions_gallery = gr.Gallery( | |
| label="Selected Regions", | |
| show_label=True, | |
| columns=2, | |
| height=190, | |
| visible=False, | |
| object_fit="contain", | |
| ) | |
| with gr.Accordion("Advanced Options", open=False): | |
| equation_zoom = gr.Checkbox(label="Equation Zoom (multipass)", value=False) | |
| separate_eq_lines = gr.Checkbox(label="Detect Equation Lines Separately", value=False) | |
| prompt = gr.Textbox(label="Prompt", lines=2, visible=False) | |
| btn = gr.Button("Extract", variant="primary", size="lg") | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| with gr.Tabs() as tabs: | |
| with gr.Tab("Text", id="tab_text"): | |
| text_out = gr.Textbox(lines=20, show_label=False) | |
| with gr.Tab("LaTeX", id="tab_text_latex"): | |
| latex_out = gr.Textbox(lines=20, show_label=False) | |
| with gr.Tab("Preview", id="tab_markdown"): | |
| md_out = gr.HTML("") | |
| with gr.Tab("Boxes", id="tab_boxes"): | |
| img_out = gr.Image(type="pil", height=560, show_label=False) | |
| with gr.Tab("Crops", id="tab_crops"): | |
| gallery = gr.Gallery(show_label=False, columns=3, height=420, object_fit="contain") | |
| with gr.Tab("Raw", id="tab_raw"): | |
| raw_out = gr.Textbox(lines=20, show_label=False) | |
| download_btn = gr.DownloadButton("Download Markdown", visible=False, variant="secondary") | |
# Example pickers. Image examples route through different components depending
# on whether the region workspace (Paint/ImageEditor) is available.
gr.Markdown("### Examples")
with gr.Row():
with gr.Column(scale=2):
image_examples = [
"examples/2022-0922 Section 13 Notes.png",
"examples/2022-0922 Section 14 Notes.png",
"examples/2022-0922 Section 15 Notes.png",
]
if HAS_REGION_WORKSPACE and region_editor is not None:
# Hidden Image serves as the Examples input so a thumbnail click runs
# load_example_into_workspace_and_reset, which populates the editor and
# resets saved regions / drawn-mask state (see outputs list below).
image_examples_input = gr.Image(
label="Example Loader",
type="filepath",
visible=False,
show_label=False,
)
gr.Examples(
label="Image Examples (click thumbnail to load into workspace)",
examples=image_examples,
inputs=[image_examples_input],
outputs=[region_editor, workspace_base_size, workspace_base_image, selected_regions_state, selected_regions_gallery, selection_status, drawn_mask_state],
fn=load_example_into_workspace_and_reset,
run_on_click=True,
cache_examples=False,
)
else:
# No workspace: example clicks just populate the plain file input.
gr.Examples(
label="Image Examples",
examples=[[p] for p in image_examples],
inputs=[file_in],
cache_examples=False,
)
with gr.Column(scale=1):
gr.Examples(
label="PDF Examples",
examples=[["examples/Gursoy Class Notes_ Accessibility Sandbox.pdf"]],
inputs=[file_in],
cache_examples=False,
)
# Collapsible end-user help. The triple-quoted text below is user-facing
# markdown rendered at runtime — it is data, not a docstring; do not edit it
# as if it were a comment.
with gr.Accordion("ℹ️ Info", open=False):
gr.Markdown("""
### Configuration
1024 base + 768 patches with dynamic cropping (2-6 patches). 144 tokens per patch + 256 base tokens.
### Faculty Quick Workflow
1. Upload a page/image, then confirm **Task**.
2. Choose **Input Scope**:
- `Entire Page` for the full page.
- `Selected Region` for a specific area.
2a. Workspace keeps native image resolution for clarity. For very tall pages, it auto-boosts from tiny fit view toward ~88% width-friendly zoom.
3. For `Selected Region`, use the **Image Workspace**:
- Recommended: freehand selection (draw/highlight target); app uses an automatic bounding box around your marks.
- Optional rectangle selection: use the **Crop** tool.
- Freehand/highlight ink is semi-transparent so underlying content stays visible.
- Current known behavior: after zooming in/out, freehand stroke display may appear fully on mouse release (selection is still captured correctly).
- Optional multi-select: click **Add Region** after each selection.
- **Add Region** snapshots only newly drawn pixels so zoom/pan stays in place while you continue selecting.
Then click **Extract**.
4. Use **Clear Regions** to reset multi-select state.
5. Review **Cropped Images** and **Boxes**: both are labeled `Region 1`, `Region 2`, etc.
6. Use **Advanced Options** only when needed (Equation Zoom / line-by-line equation OCR).
### Tasks
- **Markdown**: Convert document to structured markdown with layout detection (grounding ✅)
- **Free OCR**: Read all visible text from the full page/image (no boxes, no targeting)
- **Locate**: Find and highlight where specific text appears (grounding ✅)
- **Describe**: General image description
- **Custom**: Your own prompt
- **Region selection**: Use **Input Scope=Selected Region**, draw/crop in the Image Workspace, then click **Extract**
- **Input Scope**: `Entire Page` or `Selected Region` (Selected Region uses the workspace crop as main input)
- **Equation Zoom (multipass)**: Optional nested equation refinement for Markdown. Off by default for speed/stability.
- **Detect Equation Lines Separately**: Detects likely equation-line boxes and OCRs each line independently to reduce merged multi-step equations.
### Free OCR vs Locate (important)
- **Free OCR does not take a selected region**. It runs OCR on the whole image/page.
- If you want OCR for one area only, crop that area first, then run **Free OCR** on the cropped image.
- If you want to keep the full page but highlight where text appears, use **Locate** and enter the text to search.
- For advanced region workflows, use **Custom** with `<|grounding|>` in the prompt.
### Special Tokens
- `<image>` - Placeholder where visual tokens are inserted
- `<|grounding|>` - Enables layout detection with bounding boxes
- `<|ref|>text<|/ref|>` - Reference text to locate in the image
""")
# --- Event wiring -----------------------------------------------------------
# UI-only reactions: refresh the page selector on new files, and show/hide the
# prompt box, box tab, and region controls as task/scope change.
file_in.change(update_page_selector, [file_in], [page_selector])
task.change(toggle_prompt, [task], [prompt])
task.change(select_boxes, [task], [tabs])
input_scope.change(toggle_scope_ui, [input_scope], [workspace_hint, selection_controls, selection_status, selected_regions_gallery])
# Region-workspace wiring only exists when an editor component was created.
if HAS_REGION_WORKSPACE and region_editor is not None:
# Load the chosen file/page into the editor and cache its base size/image.
file_in.change(load_image_with_size, [file_in, page_selector], [region_editor, workspace_base_size, workspace_base_image])
page_selector.change(load_image_with_size, [file_in, page_selector], [region_editor, workspace_base_size, workspace_base_image])
# Keep the cached base size/image in sync with edits made in the workspace.
region_editor.change(sync_workspace_state, [region_editor, workspace_base_image], [workspace_base_size, workspace_base_image])
# A new file or page invalidates previously saved regions and the drawn mask.
file_in.change(_reset_selected_regions, outputs=[selected_regions_state, selected_regions_gallery, selection_status])
page_selector.change(_reset_selected_regions, outputs=[selected_regions_state, selected_regions_gallery, selection_status])
file_in.change(_reset_drawn_mask, outputs=[drawn_mask_state])
page_selector.change(_reset_drawn_mask, outputs=[drawn_mask_state])
# Snapshot the current drawing as a saved region / clear all saved regions.
add_region_btn.click(
add_selected_region,
[region_editor, workspace_base_size, workspace_base_image, selected_regions_state, drawn_mask_state],
[selected_regions_state, selected_regions_gallery, selection_status, drawn_mask_state],
)
clear_regions_btn.click(
clear_regions_preserve_view,
inputs=[region_editor],
outputs=[selected_regions_state, selected_regions_gallery, selection_status, drawn_mask_state],
)
def run(file_path, task, custom_prompt, page_num, enable_equation_zoom, detect_eq_lines, scope, region_value, base_size, base_image, selected_regions):
    """Handle a click on the Extract button.

    Picks an input source in priority order — saved/drawn regions when the
    scope is "Selected Region", otherwise the uploaded file, then the editor
    background, then the cached base image — runs OCR on it, and returns the
    7-tuple consumed by (text_out, latex_out, md_out, raw_out, img_out,
    gallery, download_btn).
    """
    if scope == "Selected Region":
        region_list = list(selected_regions or [])
        if not region_list:
            # No saved multi-select regions: fall back to whatever is
            # currently cropped/annotated in the workspace editor.
            crop_img, crop_bbox = _extract_selected_region(region_value, base_size=base_size, base_image=base_image)
            if crop_img is None:
                msg = "Select Input Scope=Selected Region, then crop or annotate a target area in the Image Workspace first."
                return (msg, "", "", "", None, [], gr.DownloadButton(visible=False))
            region_list = [{"image": crop_img, "bbox": crop_bbox}]
        multi = len(region_list) > 1
        cleaned_chunks = []
        markdown_chunks = []
        raw_chunks = []
        eq_line_crops = []
        for idx, region in enumerate(region_list, 1):
            cleaned_i, markdown_i, raw_i, _, crops_i = process_image(
                region["image"],
                task,
                custom_prompt,
                enable_equation_zoom=enable_equation_zoom,
                infer_crop_mode=False,
                separate_equation_lines=detect_eq_lines,
            )
            if multi:
                # Prefix per-region output so the merged document stays navigable.
                cleaned_i = f"## Region {idx}\n\n{cleaned_i}"
                markdown_i = f"## Region {idx}\n\n{markdown_i}"
                raw_i = f"## Region {idx}\n\n{raw_i}"
            cleaned_chunks.append(cleaned_i)
            markdown_chunks.append(markdown_i)
            raw_chunks.append(raw_i)
            if detect_eq_lines and crops_i:
                eq_line_crops.extend(_label_gallery_items(crops_i, prefix=f"Region {idx}" if multi else None))
        cleaned = "\n\n".join(cleaned_chunks).strip()
        markdown = "\n\n".join(markdown_chunks).strip()
        raw = "\n\n".join(raw_chunks).strip()
        crops = eq_line_crops if eq_line_crops else _region_gallery_items(region_list)
        # Draw every region's bounding box on the full page for the Boxes tab.
        full_img = base_image if isinstance(base_image, Image.Image) else _extract_editor_background(region_value)
        region_boxes = [r["bbox"] for r in region_list if r.get("bbox") is not None]
        boxes_img = _draw_selected_region_boxes(full_img, region_boxes)
        return _compose_ui_outputs(cleaned, markdown, raw, boxes_img, crops)
    if file_path:
        cleaned, markdown, raw, boxes_img, crops = process_file(
            file_path,
            task,
            custom_prompt,
            int(page_num),
            enable_equation_zoom=enable_equation_zoom,
            separate_equation_lines=detect_eq_lines,
        )
        return _compose_ui_outputs(cleaned, markdown, raw, boxes_img, crops)
    editor_bg = _extract_editor_background(region_value)
    if editor_bg is not None:
        cleaned, markdown, raw, boxes_img, crops = process_image(
            editor_bg,
            task,
            custom_prompt,
            enable_equation_zoom=enable_equation_zoom,
            separate_equation_lines=detect_eq_lines,
        )
        return _compose_ui_outputs(cleaned, markdown, raw, boxes_img, crops)
    if isinstance(base_image, Image.Image):
        # Example clicks can briefly race editor-value hydration on first load.
        cleaned, markdown, raw, boxes_img, crops = process_image(
            base_image,
            task,
            custom_prompt,
            enable_equation_zoom=enable_equation_zoom,
            separate_equation_lines=detect_eq_lines,
        )
        return _compose_ui_outputs(cleaned, markdown, raw, boxes_img, crops)
    return ("Error: Upload a file or image", "", "", "", None, [], gr.DownloadButton(visible=False))
# Main pipeline: run OCR on Extract, then switch the results Tabs to the most
# relevant tab for the chosen task/scope once outputs are in place.
submit_event = btn.click(
run,
[file_in, task, prompt, page_selector, equation_zoom, separate_eq_lines, input_scope, region_editor, workspace_base_size, workspace_base_image, selected_regions_state],
[text_out, latex_out, md_out, raw_out, img_out, gallery, download_btn]
)
submit_event.then(select_post_extract_tab, [task, input_scope], [tabs])
if __name__ == "__main__":
    # server_name="0.0.0.0" is needed locally (WSL2 → Windows access)
    # On HuggingFace Spaces, SPACE_ID is set and Gradio handles binding automatically
    running_locally = not os.environ.get("SPACE_ID")
    app = demo.queue(max_size=20)
    # Feature-detect launch() kwargs so the same script runs across Gradio versions.
    supported_params = inspect.signature(app.launch).parameters
    launch_kwargs = {}
    if "server_name" in supported_params:
        launch_kwargs["server_name"] = "0.0.0.0" if running_locally else None
    if "head" in supported_params:
        launch_kwargs["head"] = PREVIEW_CSS
    if "ssr_mode" in supported_params:
        # SSR breaks HF Spaces routing in Gradio 6
        launch_kwargs["ssr_mode"] = False
    app.launch(**launch_kwargs)