# mvm2-math-verification / ocr_module.py
# Uploaded by Varshithdharmajv with huggingface_hub (commit 7ffaec2, verified).
from PIL import Image
import sys
import os
import re
import random
from typing import Dict, Any
# Handwriting-transcription checkpoint bundled with the project.
# The copy that sits next to this module is preferred so that the lookup does
# not depend on the directory the process was started from; when no such file
# exists, the historical working-directory-relative path is kept.
_MODEL_RELPATH = ("handwritten-math-transcription", "checkpoints", "model_v3_0.pth")
try:
    _MODULE_DIR = os.path.dirname(os.path.abspath(__file__))
except NameError:  # no __file__, e.g. code pasted into an interactive session
    _MODULE_DIR = os.getcwd()
MODEL_PATH = os.path.join(_MODULE_DIR, *_MODEL_RELPATH)
if not os.path.exists(MODEL_PATH):
    MODEL_PATH = os.path.join(os.getcwd(), *_MODEL_RELPATH)

# MVM2 configuration for OCR confidence weights: symbol classes that
# get_symbol_weight() maps to a per-token weight.
CRITICAL_OPERATORS = ["\\int", "\\sum", "=", "\\frac", "+", "-", "*", "\\times", "\\div"]
BRACKETS_LIMITS = ["(", ")", "[", "]", "\\{", "\\}", "^", "_"]
AMBIGUOUS_SYMBOLS = ["8", "B", "0", "O", "l", "1", "I", "S", "5", "Z", "2"]

# CJK character ranges (Chinese, Japanese, Korean), including CJK punctuation
# and full-width forms; matched characters are stripped from OCR output.
CJK_PATTERN = re.compile(r'[\u4e00-\u9fff\u3040-\u30ff\uac00-\ud7af\u3000-\u303f\uff00-\uffef\u3001\u3002\uff0c\uff0e\uff1a\uff1b\uff1f\uff01]')
def get_symbol_weight(symbol: str) -> float:
    """Return the confidence weight assigned to one LaTeX token.

    The token is looked up in the module's symbol classes in priority order:
    critical operators weigh 1.5, brackets/limits 1.3, visually ambiguous
    symbols 0.7. Any token that belongs to none of them weighs 1.0.
    """
    weight_classes = (
        (CRITICAL_OPERATORS, 1.5),
        (BRACKETS_LIMITS, 1.3),
        (AMBIGUOUS_SYMBOLS, 0.7),
    )
    for members, weight in weight_classes:
        if symbol in members:
            return weight
    return 1.0
def calculate_weighted_confidence(latex_string: str, mock_logits: bool = True) -> float:
    r"""Compute OCR.conf = sum(W_i * c_i) / sum(W_i) over the LaTeX tokens.

    Tokenisation follows TeX's rules so that every weighted symbol can match:
      * a control word is a backslash followed by letters only (``\frac``),
        so ``\frac12`` yields ``\frac``, ``1``, ``2``;
      * a control symbol is a backslash followed by one non-letter
        (``\{``, ``\}``, ``\\``), kept as a single token;
      * every other non-whitespace character is its own token.

    Args:
        latex_string: LaTeX text to score. ``None`` or an empty/blank string
            scores 0.0.
        mock_logits: when True each token confidence ``c_i`` is drawn from
            ``random.uniform(0.85, 0.99)`` (the result is therefore not
            deterministic); when False every ``c_i`` is the constant 0.95.

    Returns:
        The weighted mean confidence rounded to 4 decimals, or 0.0 when the
        input contains no tokens.
    """
    if not latex_string:
        return 0.0

    # Alternatives are ordered: control word, control symbol, single character
    # (the last one also catches a lone trailing backslash).
    tokens = re.findall(r'\\[A-Za-z]+|\\\S|\S', latex_string)

    weighted_sum = 0.0
    weight_total = 0.0
    for token in tokens:
        weight = get_symbol_weight(token)
        confidence = random.uniform(0.85, 0.99) if mock_logits else 0.95
        weighted_sum += weight * confidence
        weight_total += weight

    if weight_total == 0:
        return 0.0
    return round(weighted_sum / weight_total, 4)
def clean_latex_output(text: str) -> str:
    """Strip CJK characters and instruction words from OCR text.

    Removes every character matched by CJK_PATTERN, deletes the standalone
    words solve/find/evaluate/simplify (case-insensitive), collapses runs of
    two or more whitespace characters into one space and trims the ends.
    Falsy input yields an empty string.
    """
    if not text:
        return ""
    without_cjk = CJK_PATTERN.sub('', text)
    # Remove common conversational noise
    without_noise = re.sub(r'(?i)\b(solve|find|evaluate|simplify)\b', '', without_cjk)
    collapsed = re.sub(r'\s{2,}', ' ', without_noise)
    return collapsed.strip()
def extract_latex_from_pix2text(out) -> str:
    """Safely extract cleaned text from a pix2text result of any supported shape.

    Handled shapes:
      * ``None``: returns "" (instead of the literal string "None");
      * ``str``: cleaned and returned;
      * ``list``: each item contributes its text; items may be plain strings,
        dicts (``'text'`` key, falling back to ``'latex'``) or objects with a
        ``text`` attribute. Missing/``None``/empty values are skipped and the
        cleaned pieces are joined with single spaces;
      * objects exposing ``to_markdown()``: the markdown is cleaned;
      * anything else: ``str(out)`` is cleaned.

    All text goes through clean_latex_output().
    """
    if out is None:
        return ""
    if isinstance(out, str):
        return clean_latex_output(out)
    if isinstance(out, list):
        parts = []
        for item in out:
            if isinstance(item, str):
                raw = item
            elif isinstance(item, dict):
                raw = item.get('text') or item.get('latex')
            elif hasattr(item, 'text'):
                raw = item.text
            else:
                continue
            # A None value must not be turned into the text "None".
            if raw is None:
                continue
            raw = str(raw)
            if not raw:
                continue
            text = clean_latex_output(raw).strip()
            if text:
                parts.append(text)
        return ' '.join(parts)
    if hasattr(out, 'to_markdown'):
        return clean_latex_output(out.to_markdown())
    return clean_latex_output(str(out))
class MVM2OCREngine:
    """OCR front end that turns an image (or InkML file) into LaTeX plus a confidence.

    Two optional backends are loaded in __init__ on a best-effort basis:
    Pix2Text (``self.p2t``) for raster images and HandwritingTranscriber
    (``self.transcriber``) for ``.inkml`` files. A backend that fails to
    import or initialise is left as ``None`` and a warning is printed.
    """

    # Math spans embedded in prose: $$...$$, $...$, \(...\) and \[...\].
    # '$$' is tried before '$' so display math is not read as two empty spans.
    _MATH_SPAN = re.compile(r'\$\$(.+?)\$\$|\$(.+?)\$|\\\((.+?)\\\)|\\\[(.+?)\\\]')

    def __init__(self):
        self.model_loaded = False
        self.p2t = None
        self.transcriber = None

        try:
            from pix2text import Pix2Text
            # Default Pix2Text configuration.
            self.p2t = Pix2Text.from_config()
            self.model_loaded = True
            print("[OCR] Pix2Text loaded successfully.")
        except Exception as e:
            # Best effort: a missing package or model only disables this backend.
            print(f"[OCR] Warning: Pix2Text unavailable ({e}). Using simulation mode.")

        try:
            from handwriting_transcriber import HandwritingTranscriber
            if os.path.exists(MODEL_PATH):
                self.transcriber = HandwritingTranscriber(model_path=MODEL_PATH)
                print(f"[OCR] HandwritingTranscriber loaded with model: {MODEL_PATH}")
            else:
                print(f"[OCR] Warning: Handwriting model not found at {MODEL_PATH}")
        except Exception as e:
            print(f"[OCR] Warning: HandwritingTranscriber unavailable ({e})")

    def _extract_formulas_only(self, pix2text_output) -> str:
        """Extract ONLY math formula regions, discarding prose text regions.

        * ``str`` input is returned (cleaned) only if it contains at least one
          of the characters ``\\ ^ _ = + -``; otherwise "" is returned.
        * ``list`` input is scanned for dict items. Items typed as a formula
          contribute their ``'text'`` (or ``'latex'``) value; items typed
          ``'text'`` contribute only the math spans found inside the prose.
          Non-dict items, other types and empty/``None`` values are skipped.
          The cleaned pieces are joined with newlines.
        * Any other input yields "".
        """
        if isinstance(pix2text_output, str):
            if any(marker in pix2text_output for marker in ('\\', '^', '_', '=', '+', '-')):
                return clean_latex_output(pix2text_output)
            return ""
        if not isinstance(pix2text_output, list):
            return ""

        formula_parts = []
        for item in pix2text_output:
            if not isinstance(item, dict):
                continue
            item_type = item.get('type', 'text')
            if item_type in ('isolated_equation', 'embedding', 'formula', 'math'):
                candidates = [item.get('text') or item.get('latex') or '']
            elif item_type == 'text':
                prose = item.get('text') or ''
                # Each findall() row holds one group per delimiter style;
                # exactly one of them is non-empty.
                candidates = [
                    next((group for group in groups if group), '')
                    for groups in self._MATH_SPAN.findall(str(prose))
                ]
            else:
                continue

            for candidate in candidates:
                candidate = str(candidate)
                if not candidate.strip():
                    continue
                cleaned = clean_latex_output(candidate).strip()
                # Cleaning may remove everything (e.g. a CJK-only span).
                if cleaned:
                    formula_parts.append(cleaned)
        return '\n'.join(formula_parts)

    def _run_pix2text(self, image_path: str):
        """Run the Pix2Text passes and return ``(text, recognized)``.

        ``recognized`` is False when ``text`` is a placeholder or an error
        message rather than OCR output.
        """
        try:
            # --- PASS 1: Formula-only mode ---
            first_pass = ""
            try:
                formula_out = self.p2t.recognize_formula(image_path)
                first_pass = clean_latex_output(str(formula_out)).strip()
                if "\\newcommand" in first_pass or "\\def" in first_pass:
                    print(f"[OCR] Pass 1 hallucinated preamble macros. Rejecting output.")
                    first_pass = ""
                else:
                    print(f"[OCR] Pass 1 (formula mode): {first_pass[:80]}")
            except Exception as e1:
                print(f"[OCR] Pass 1 formula mode failed: {e1}")
                first_pass = ""
            if len(first_pass) >= 3:
                return first_pass, True

            # --- PASS 2: General recognize(), extract formula regions only ---
            page_out = self.p2t.recognize(image_path)
            latex = self._extract_formulas_only(page_out)
            print(f"[OCR] Pass 2 (formula extraction): {latex[:80]}")

            # --- PASS 3: Full text fallback ---
            if not latex.strip():
                latex = extract_latex_from_pix2text(page_out)
            # A very short Pass 1 result is still better than nothing.
            if not latex.strip():
                latex = first_pass
            if not latex.strip():
                return "No mathematical formula detected.", False
            return latex, True
        except Exception as e:
            print(f"[OCR] Inference error: {e}")
            return f"OCR Error: {str(e)}", False

    def process_image(self, image_path: str) -> Dict[str, Any]:
        """Full OCR pipeline: formula-first mode with prose filtering and confidence scoring.

        Args:
            image_path: path of a raster image, or of an ``.inkml`` file when
                the handwriting transcriber is available.

        Returns:
            On success or soft failure, a dict with ``latex_output``,
            ``weighted_confidence`` and ``backend`` ("handwriting",
            "pix2text-formula" or "simulation"). When nothing was recognised,
            ``latex_output`` carries a placeholder or "OCR Error: ..." message
            and ``weighted_confidence`` is 0.0, so failures are never scored
            as if they were LaTeX.
            For a missing, unreadable or zero-size image, a dict with
            ``error``, an empty ``latex_output`` and confidence 0.0.
        """
        image_path = os.fspath(image_path)
        if not os.path.exists(image_path):
            return {"error": f"Image not found: {image_path}", "latex_output": "", "weighted_confidence": 0.0}

        # InkML is not a raster format, so PIL cannot open it: route it to the
        # handwriting transcriber before the image validation below rejects it.
        use_transcriber = bool(self.transcriber) and image_path.lower().endswith('.inkml')

        recognized = False
        if use_transcriber:
            raw_latex = "No mathematical formula detected."
            try:
                transcript, _ = self.transcriber.transcribe_inkml(image_path)
                print(f"[OCR] Used HandwritingTranscriber for InkML: {transcript}")
                if transcript and str(transcript).strip():
                    raw_latex = transcript
                    recognized = True
            except Exception as e:
                print(f"[OCR] HandwritingTranscriber error: {e}")
        else:
            try:
                with Image.open(image_path) as img:
                    width, height = img.size
            except Exception as e:
                return {"error": f"Invalid image: {e}", "latex_output": "", "weighted_confidence": 0.0}
            if width == 0 or height == 0:
                return {"error": "Zero-size image", "latex_output": "", "weighted_confidence": 0.0}

            if self.model_loaded and self.p2t:
                raw_latex, recognized = self._run_pix2text(image_path)
                if recognized:
                    # Only real OCR text is cleaned; error messages stay verbatim.
                    raw_latex = clean_latex_output(raw_latex)
            else:
                raw_latex = "No math detected (OCR model not loaded)."

        ocr_conf = calculate_weighted_confidence(raw_latex) if recognized else 0.0

        if use_transcriber:
            backend = "handwriting"
        elif self.model_loaded:
            backend = "pix2text-formula"
        else:
            backend = "simulation"

        return {
            "latex_output": raw_latex,
            "weighted_confidence": ocr_conf,
            "backend": backend,
        }