"""MVM2 OCR engine: formula-first LaTeX extraction with weighted confidence.

Raster images are sent through Pix2Text (when it can be imported); InkML
stroke files are sent to the optional HandwritingTranscriber. Either backend
may be missing, in which case the engine returns placeholder output instead
of raising.
"""
import os
import random
import re
import sys
from typing import Any, Dict, List, Optional

try:
    from PIL import Image
except ImportError:  # Pillow is only needed to validate raster images; keep the text helpers importable.
    Image = None

# Handwriting transcription model checkpoint.
# NOTE: resolved against the current working directory (os.getcwd()), not
# against this file's location, so the process must be started from the
# directory that contains "handwritten-math-transcription/".
MODEL_PATH = os.path.join(os.getcwd(), "handwritten-math-transcription", "checkpoints", "model_v3_0.pth")

# MVM2 configuration: per-symbol weights used by the OCR confidence score.
CRITICAL_OPERATORS = ["\\int", "\\sum", "=", "\\frac", "+", "-", "*", "\\times", "\\div"]
BRACKETS_LIMITS = ["(", ")", "[", "]", "\\{", "\\}", "^", "_"]
AMBIGUOUS_SYMBOLS = ["8", "B", "0", "O", "l", "1", "I", "S", "5", "Z", "2"]

# CJK character ranges (Chinese, Japanese, Korean) including punctuation
CJK_PATTERN = re.compile(r'[\u4e00-\u9fff\u3040-\u30ff\uac00-\ud7af\u3000-\u303f\uff00-\uffef\u3001\u3002\uff0c\uff0e\uff1a\uff1b\uff1f\uff01]')

# Instruction words picked up from problem statements ("Solve ...", "Find ...").
_NOISE_WORD_PATTERN = re.compile(r'(?i)\b(solve|find|evaluate|simplify)\b')
_MULTI_SPACE_PATTERN = re.compile(r'\s{2,}')

# Substrings showing that formula mode hallucinated LaTeX preamble macro definitions.
_PREAMBLE_MACROS = ("\\newcommand", "\\renewcommand", "\\providecommand", "\\def")

# Characters whose presence makes a bare string look like a formula.
_FORMULA_HINTS = ("\\", "^", "_", "=", "+", "-")

# Pix2Text region types whose text is formula LaTeX.
_FORMULA_REGION_TYPES = frozenset({"isolated", "isolated_equation", "embedding", "formula", "math"})

# Math embedded in prose regions: $$..$$ (tried first so it is not eaten as an
# empty $..$), $..$, \(..\) and \[..\]. DOTALL lets formulas span lines.
_INLINE_MATH_PATTERN = re.compile(r'\$\$(.*?)\$\$|\$(.*?)\$|\\\((.*?)\\\)|\\\[(.*?)\\\]', re.DOTALL)

# Placeholder outputs returned instead of LaTeX; they are given zero confidence.
NO_FORMULA_MSG = "No mathematical formula detected."
MODEL_NOT_LOADED_MSG = "No math detected (OCR model not loaded)."
OCR_ERROR_PREFIX = "OCR Error: "


def get_symbol_weight(symbol: str) -> float:
    """Return the confidence weight W_i for a single LaTeX token.

    Critical operators weigh 1.5, brackets/limits 1.3, visually ambiguous
    glyphs 0.7, and everything else 1.0.
    """
    if symbol in CRITICAL_OPERATORS:
        return 1.5
    elif symbol in BRACKETS_LIMITS:
        return 1.3
    elif symbol in AMBIGUOUS_SYMBOLS:
        return 0.7
    return 1.0


def tokenize_latex(latex_string: str) -> List[str]:
    """Split a LaTeX string into tokens for confidence weighting.

    Control words (``\\frac``) and control symbols (``\\{``, ``\\}``, ``\\\\``)
    become single tokens; every other non-whitespace character is its own
    token. Whitespace is dropped.
    """
    tokens: List[str] = []
    command = ""  # a control word being accumulated; always starts with "\\"
    for char in latex_string:
        if command == "\\" and not char.isalnum() and not char.isspace():
            # Backslash + one non-letter is a complete control symbol, e.g. \{ or \\.
            tokens.append(command + char)
            command = ""
        elif char == "\\":
            if command:
                tokens.append(command)
            command = char
        elif char.isalnum() and command:
            command += char
        else:
            if command:
                tokens.append(command)
                command = ""
            if not char.isspace():
                tokens.append(char)
    if command:
        tokens.append(command)
    return tokens


def calculate_weighted_confidence(latex_string: str, mock_logits: bool = True) -> float:
    """Return OCR.conf = sum(W_i * c_i) / sum(W_i) over the tokens of *latex_string*.

    W_i comes from get_symbol_weight. No real per-token logits are available
    here: with ``mock_logits=True`` each c_i is drawn from
    ``random.uniform(0.85, 0.99)`` (non-deterministic); otherwise c_i = 0.95.

    Returns 0.0 for a string with no tokens, else the score rounded to 4 places.
    """
    total_weighted_ci = 0.0
    total_weights = 0.0
    for token in tokenize_latex(latex_string):
        w_i = get_symbol_weight(token)
        c_i = random.uniform(0.85, 0.99) if mock_logits else 0.95
        total_weighted_ci += w_i * c_i
        total_weights += w_i
    if total_weights == 0:
        return 0.0
    return round(total_weighted_ci / total_weights, 4)


def clean_latex_output(text: str) -> str:
    """Aggressively remove CJK characters and punctuation from OCR output.

    Also strips instruction words (solve/find/evaluate/simplify), collapses
    runs of two or more whitespace characters into one space, and trims the
    result. Empty or falsy input returns "".
    """
    if not text:
        return ""
    cleaned = CJK_PATTERN.sub('', text)
    # Remove common conversational noise
    cleaned = _NOISE_WORD_PATTERN.sub('', cleaned)
    return _MULTI_SPACE_PATTERN.sub(' ', cleaned).strip()


def _clean_fragment(value: Any) -> str:
    """Clean and strip one OCR fragment; None yields "" instead of the text "None"."""
    if value is None:
        return ""
    return clean_latex_output(str(value)).strip()


def extract_latex_from_pix2text(out: Any) -> str:
    """Safely extract LaTeX text from Pix2Text output regardless of return type.

    Handles None, a plain string, a dict with 'text'/'latex', a list of
    strings / dicts / objects with a ``.text`` attribute (joined with spaces),
    and objects exposing ``to_markdown()``. Anything else is stringified.
    Every fragment goes through clean_latex_output.
    """
    if out is None:
        return ""
    if isinstance(out, str):
        return clean_latex_output(out)
    if isinstance(out, dict):
        return _clean_fragment(out.get('text') or out.get('latex'))
    if isinstance(out, list):
        parts = []
        for item in out:
            if isinstance(item, str):
                fragment = _clean_fragment(item)
            elif isinstance(item, dict):
                fragment = _clean_fragment(item.get('text') or item.get('latex'))
            elif hasattr(item, 'text'):
                fragment = _clean_fragment(item.text)
            else:
                continue
            if fragment:
                parts.append(fragment)
        return ' '.join(parts)
    if hasattr(out, 'to_markdown'):
        return clean_latex_output(out.to_markdown())
    return clean_latex_output(str(out))


def _error_result(message: str) -> Dict[str, Any]:
    """Build the failure payload returned by MVM2OCREngine.process_image."""
    return {"error": message, "latex_output": "", "weighted_confidence": 0.0}


def _is_placeholder_output(text: str) -> bool:
    """True when *text* is a status/error message rather than recognized LaTeX."""
    return text in (NO_FORMULA_MSG, MODEL_NOT_LOADED_MSG) or text.startswith(OCR_ERROR_PREFIX)


def _validate_raster_image(image_path: str) -> Optional[str]:
    """Return an error message if *image_path* is not a readable, non-empty image, else None."""
    if Image is None:
        return "Invalid image: Pillow (PIL) is not installed"
    try:
        with Image.open(image_path) as img:
            width, height = img.size
    except Exception as e:
        return f"Invalid image: {e}"
    if width == 0 or height == 0:
        return "Zero-size image"
    return None


class MVM2OCREngine:
    """Formula-first OCR engine with optional Pix2Text and handwriting backends."""

    def __init__(self):
        self.model_loaded = False  # True only once Pix2Text was constructed
        self.p2t = None
        try:
            from pix2text import Pix2Text
            # Use mixed mode: recognizes both formula and text regions
            self.p2t = Pix2Text.from_config()
            self.model_loaded = True
            print("[OCR] Pix2Text loaded successfully.")
        except Exception as e:
            print(f"[OCR] Warning: Pix2Text unavailable ({e}). Using simulation mode.")

        self.transcriber = None
        try:
            from handwriting_transcriber import HandwritingTranscriber
            if os.path.exists(MODEL_PATH):
                self.transcriber = HandwritingTranscriber(model_path=MODEL_PATH)
                print(f"[OCR] HandwritingTranscriber loaded with model: {MODEL_PATH}")
            else:
                print(f"[OCR] Warning: Handwriting model not found at {MODEL_PATH}")
        except Exception as e:
            print(f"[OCR] Warning: HandwritingTranscriber unavailable ({e})")

    def _extract_formulas_only(self, pix2text_output) -> str:
        """Extract ONLY math formula regions, discarding prose text regions.

        A plain string is kept only if it contains a formula hint character.
        For a list of region dicts, formula-typed regions are kept whole and
        text regions contribute only their delimited inline/display math.
        Fragments are joined with newlines; any other input returns "".
        """
        if isinstance(pix2text_output, str):
            if any(hint in pix2text_output for hint in _FORMULA_HINTS):
                return clean_latex_output(pix2text_output)
            return ""

        if not isinstance(pix2text_output, list):
            return ""

        formula_parts = []
        for item in pix2text_output:
            if not isinstance(item, dict):
                continue
            item_type = item.get('type', 'text')
            if item_type in _FORMULA_REGION_TYPES:
                text = _clean_fragment(item.get('text') or item.get('latex'))
                if text:
                    formula_parts.append(text)
            elif item_type == 'text':
                # `or ''` guards against regions whose text is None.
                raw = str(item.get('text') or '')
                for match in _INLINE_MATH_PATTERN.finditer(raw):
                    part = next((group for group in match.groups() if group), '')
                    if part.strip():
                        formula_parts.append(clean_latex_output(part))
        return '\n'.join(formula_parts)

    def _recognize_formula_pass(self, image_path: str) -> str:
        """Pass 1: whole-image formula recognition; "" on failure or hallucinated macros."""
        try:
            formula_out = self.p2t.recognize_formula(image_path)
            latex = extract_latex_from_pix2text(formula_out).strip()
        except Exception as e1:
            print(f"[OCR] Pass 1 formula mode failed: {e1}")
            return ""
        if any(macro in latex for macro in _PREAMBLE_MACROS):
            print("[OCR] Pass 1 hallucinated preamble macros. Rejecting output.")
            return ""
        print(f"[OCR] Pass 1 (formula mode): {latex[:80]}")
        return latex

    def _run_pix2text(self, image_path: str) -> str:
        """Run the three Pix2Text passes and return LaTeX or a placeholder/error message."""
        try:
            first_pass = self._recognize_formula_pass(image_path)
            raw_latex = first_pass
            # --- PASS 2: general recognize(), formula regions only ---
            if len(raw_latex) < 3:
                region_output = self.p2t.recognize(image_path)
                raw_latex = self._extract_formulas_only(region_output)
                print(f"[OCR] Pass 2 (formula extraction): {raw_latex[:80]}")
                # --- PASS 3: full text fallback on the same recognize() output ---
                if not raw_latex.strip():
                    raw_latex = extract_latex_from_pix2text(region_output)
        except Exception as e:
            print(f"[OCR] Inference error: {e}")
            return f"{OCR_ERROR_PREFIX}{e}"
        # A short Pass-1 result (e.g. "7") beats reporting nothing at all.
        result = clean_latex_output(raw_latex) or first_pass
        return result or NO_FORMULA_MSG

    def _process_inkml(self, inkml_path: str) -> Dict[str, Any]:
        """Transcribe an InkML stroke file with the handwriting model."""
        if self.transcriber is None:
            return _error_result(f"InkML input requires HandwritingTranscriber (not loaded): {inkml_path}")
        try:
            latex, _ = self.transcriber.transcribe_inkml(inkml_path)
        except Exception as e:
            print(f"[OCR] HandwritingTranscriber error: {e}")
            return _error_result(f"Handwriting transcription failed: {e}")
        latex = "" if latex is None else str(latex)
        print(f"[OCR] Used HandwritingTranscriber for InkML: {latex}")
        return {
            "latex_output": latex,
            "weighted_confidence": calculate_weighted_confidence(latex),
            "backend": "handwriting",
        }

    def process_image(self, image_path: str) -> Dict[str, Any]:
        """Full OCR pipeline: formula-first mode with prose filtering and confidence scoring.

        Returns ``{"latex_output", "weighted_confidence", "backend"}`` on
        success. Missing, unreadable, or zero-size inputs return
        ``{"error", "latex_output": "", "weighted_confidence": 0.0}``.
        Placeholder/error text in ``latex_output`` is scored 0.0.
        """
        if not os.path.exists(image_path):
            return _error_result(f"Image not found: {image_path}")

        # InkML is XML stroke data, not a raster image: neither Pillow nor
        # Pix2Text can read it, so it goes straight to the handwriting model.
        if image_path.lower().endswith('.inkml'):
            return self._process_inkml(image_path)

        error = _validate_raster_image(image_path)
        if error is not None:
            return _error_result(error)

        if self.model_loaded and self.p2t:
            raw_latex = self._run_pix2text(image_path)
        else:
            raw_latex = MODEL_NOT_LOADED_MSG

        confidence = 0.0 if _is_placeholder_output(raw_latex) else calculate_weighted_confidence(raw_latex)
        return {
            "latex_output": raw_latex,
            "weighted_confidence": confidence,
            "backend": "pix2text-formula" if self.model_loaded else "simulation",
        }