File size: 8,880 Bytes
bf3e224 17ff84e 54651b2 17ff84e bb6d5ae bf3e224 be2c8ad bf3e224 a1d2691 bb6d5ae a1d2691 bb6d5ae a1d2691 bb6d5ae a1d2691 bb6d5ae a1d2691 bb6d5ae a1d2691 be2c8ad a1d2691 be2c8ad a1d2691 2706517 a1d2691 bb6d5ae a1d2691 bb6d5ae a1d2691 bb6d5ae a1d2691 bb6d5ae a1d2691 bb6d5ae 17ff84e 1af77e5 bb6d5ae 1af77e5 bb6d5ae a1d2691 bb6d5ae a1d2691 bb6d5ae a1d2691 bb6d5ae a1d2691 bb6d5ae 1af77e5 7ffaec2 1af77e5 a1d2691 1af77e5 a1d2691 bb6d5ae a1d2691 bb6d5ae a8bc4f1 a1d2691 17ff84e 1af77e5 17ff84e bb6d5ae a1d2691 bb6d5ae 1af77e5 bb6d5ae | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 | from PIL import Image
import sys
import os
import re
import random
from typing import Dict, Any
# Checkpoint for the handwriting (InkML) transcription model.
# NOTE(review): the path is resolved against the current working directory
# (os.getcwd()), not this file's directory, so it only points at the model when
# the process is started from the folder that contains
# "handwritten-math-transcription" -- confirm this is intended.
MODEL_PATH = os.path.join(os.getcwd(), "handwritten-math-transcription", "checkpoints", "model_v3_0.pth")
# MVM2 token categories used to weight per-token OCR confidences (looked up by
# get_symbol_weight): critical operators weigh most, then brackets and
# sub/superscript markers, then ordinary tokens; the glyphs listed as
# ambiguous (e.g. 8/B, 0/O, l/1/I, S/5, Z/2) are down-weighted.
CRITICAL_OPERATORS = ["\\int", "\\sum", "=", "\\frac", "+", "-", "*", "\\times", "\\div"]
BRACKETS_LIMITS = ["(", ")", "[", "]", "\\{", "\\}", "^", "_"]
AMBIGUOUS_SYMBOLS = ["8", "B", "0", "O", "l", "1", "I", "S", "5", "Z", "2"]
# Characters stripped from OCR output: CJK Unified Ideographs (U+4E00-9FFF),
# Hiragana/Katakana (U+3040-30FF), Hangul syllables (U+AC00-D7AF), CJK symbols
# and punctuation (U+3000-303F) and the Halfwidth/Fullwidth Forms block
# (U+FF00-FFEF, which also covers full-width ASCII such as full-width digits
# and brackets). The single code points listed at the end already fall inside
# those ranges.
CJK_PATTERN = re.compile(r'[\u4e00-\u9fff\u3040-\u30ff\uac00-\ud7af\u3000-\u303f\uff00-\uffef\u3001\u3002\uff0c\uff0e\uff1a\uff1b\uff1f\uff01]')
def get_symbol_weight(symbol: str) -> float:
    """Return the confidence weight assigned to a single LaTeX token.

    Categories are checked in priority order and the first match wins:
    critical operators (1.5), brackets/limit markers (1.3), visually
    ambiguous glyphs (0.7). Any other token weighs 1.0.
    """
    weight_table = (
        (CRITICAL_OPERATORS, 1.5),
        (BRACKETS_LIMITS, 1.3),
        (AMBIGUOUS_SYMBOLS, 0.7),
    )
    for members, weight in weight_table:
        if symbol in members:
            return weight
    return 1.0
def _tokenize_latex(latex_string: str) -> list:
    """Split a LaTeX string into control sequences and single characters.

    Follows TeX's rules for control sequences:
    * a backslash followed by ASCII letters is a control word (``\\frac``);
      a digit ends it, so ``\\frac12`` yields ``\\frac``, ``1``, ``2``;
    * a backslash followed by any other single character is a control
      symbol (``\\{``, ``\\}``, ``\\,``, ``\\\\``);
    * every other non-whitespace character is its own token.
    Whitespace only separates tokens and is dropped. A trailing lone
    backslash becomes the token ``\\``.
    """
    tokens = []
    i, n = 0, len(latex_string)
    while i < n:
        ch = latex_string[i]
        if ch == '\\':
            j = i + 1
            while j < n and latex_string[j].isascii() and latex_string[j].isalpha():
                j += 1
            if j == i + 1 and j < n:
                # No letters followed: control symbol = backslash + one char.
                j += 1
            tokens.append(latex_string[i:j])
            i = j
        else:
            if not ch.isspace():
                tokens.append(ch)
            i += 1
    return tokens


def calculate_weighted_confidence(latex_string: str, mock_logits: bool = True) -> float:
    """Compute the weighted OCR confidence of a LaTeX string.

    OCR.conf = sum(W_i * c_i) / sum(W_i), where W_i comes from
    ``get_symbol_weight`` for each token produced by ``_tokenize_latex`` and
    c_i is the token's confidence.

    Args:
        latex_string: LaTeX text to score. Empty or ``None`` scores 0.0.
        mock_logits: No real per-token logits are available. When True (the
            default) each c_i is drawn from ``random.uniform(0.85, 0.99)``,
            so the result is nondeterministic. When False every c_i is 0.95,
            so any non-empty input scores exactly 0.95.

    Returns:
        The confidence rounded to 4 decimals, or 0.0 when there are no tokens.
    """
    if not latex_string:
        return 0.0
    total_weighted_ci = 0.0
    total_weights = 0.0
    for token in _tokenize_latex(latex_string):
        w_i = get_symbol_weight(token)
        c_i = random.uniform(0.85, 0.99) if mock_logits else 0.95
        total_weighted_ci += w_i * c_i
        total_weights += w_i
    if total_weights == 0:
        return 0.0
    return round(total_weighted_ci / total_weights, 4)
def clean_latex_output(text: str) -> str:
    """Strip CJK characters/punctuation and instruction words from OCR text.

    Every match of ``CJK_PATTERN`` is removed, then the case-insensitive whole
    words "solve", "find", "evaluate" and "simplify" are deleted. Runs of two
    or more whitespace characters collapse to one space and the ends are
    trimmed. Falsy input (``""``/``None``) returns ``""``.
    """
    if not text:
        return ""
    without_cjk = CJK_PATTERN.sub('', text)
    # Imperative words from the problem statement are noise, not math.
    without_noise = re.sub(r'(?i)\b(solve|find|evaluate|simplify)\b', '', without_cjk)
    return re.sub(r'\s{2,}', ' ', without_noise).strip()
def extract_latex_from_pix2text(out) -> str:
    """Flatten a pix2text result of any supported shape into cleaned text.

    Accepted shapes:
    * ``str``: cleaned and returned.
    * ``list``: each element may be a dict (its ``'text'`` value, else its
      ``'latex'`` value), a plain string, or an object with a ``.text``
      attribute. Missing/``None`` values and other element types are
      skipped. The cleaned, non-empty pieces are joined with single spaces.
    * an object with ``to_markdown()``: its markdown is cleaned.
    * anything else: ``str(out)`` is cleaned.

    All text passes through ``clean_latex_output``. CJK characters are
    therefore removed from each piece; CJK regions are not dropped as a whole.
    """
    if isinstance(out, str):
        return clean_latex_output(out)
    if isinstance(out, list):
        parts = []
        for item in out:
            if isinstance(item, dict):
                text = item.get('text') or item.get('latex')
            elif isinstance(item, str):
                text = item
            elif hasattr(item, 'text'):
                text = item.text
            else:
                continue
            # A missing value must not turn into the literal string "None".
            if text is None:
                continue
            text = clean_latex_output(str(text)).strip()
            if text:
                parts.append(text)
        return ' '.join(parts)
    if hasattr(out, 'to_markdown'):
        return clean_latex_output(out.to_markdown())
    return clean_latex_output(str(out))
class MVM2OCREngine:
    """Convert a math image (or an InkML ink file) to LaTeX with a confidence.

    Backends are probed once in ``__init__``. A backend that fails to import
    or load is logged and left unset, so construction never raises:

    * ``self.p2t``: a pix2text ``Pix2Text`` instance used for raster images
      (``self.model_loaded`` is True when it loaded).
    * ``self.transcriber``: a ``HandwritingTranscriber`` used for ``.inkml``
      files. It is created only when the checkpoint at ``MODEL_PATH`` exists.
    """

    # pix2text region types treated as formulas. pix2text's mixed recognition
    # labels display formulas 'isolated' and inline ones 'embedding'; the other
    # names are accepted for compatibility with other output variants.
    _FORMULA_TYPES = ('isolated', 'isolated_equation', 'embedding', 'formula', 'math')
    # Inline math inside prose regions. $$...$$ is tried first; otherwise it is
    # consumed as two empty $...$ matches and its content is lost.
    _INLINE_MATH_RE = re.compile(r'\$\$(.*?)\$\$|\$(.*?)\$|\\\((.*?)\\\)|\\\[(.*?)\\\]')
    # Outputs that mean "nothing was recognized" rather than real LaTeX.
    _NO_FORMULA_MSG = "No mathematical formula detected."
    _NOT_LOADED_MSG = "No math detected (OCR model not loaded)."
    _ERROR_PREFIX = "OCR Error:"

    def __init__(self):
        """Probe the optional OCR backends; failures are logged, not raised."""
        self.model_loaded = False
        self.p2t = None
        try:
            from pix2text import Pix2Text
            # Default configuration; recognize() is later used for mixed
            # text + formula pages and recognize_formula() for formula-only.
            self.p2t = Pix2Text.from_config()
            self.model_loaded = True
            print("[OCR] Pix2Text loaded successfully.")
        except Exception as e:
            print(f"[OCR] Warning: Pix2Text unavailable ({e}). Using simulation mode.")
        self.transcriber = None
        try:
            from handwriting_transcriber import HandwritingTranscriber
            if os.path.exists(MODEL_PATH):
                self.transcriber = HandwritingTranscriber(model_path=MODEL_PATH)
                print(f"[OCR] HandwritingTranscriber loaded with model: {MODEL_PATH}")
            else:
                print(f"[OCR] Warning: Handwriting model not found at {MODEL_PATH}")
        except Exception as e:
            print(f"[OCR] Warning: HandwritingTranscriber unavailable ({e})")

    @staticmethod
    def _error_result(message: str) -> Dict[str, Any]:
        """Build the result dict returned for unusable input."""
        return {"error": message, "latex_output": "", "weighted_confidence": 0.0}

    def _score(self, latex: str) -> float:
        """Return the weighted confidence for real output, 0.0 for placeholders."""
        # Placeholder and error messages are not OCR output. Scoring them as
        # LaTeX would report a high confidence for a failed recognition.
        if (not latex.strip()
                or latex in (self._NO_FORMULA_MSG, self._NOT_LOADED_MSG)
                or latex.startswith(self._ERROR_PREFIX)):
            return 0.0
        return calculate_weighted_confidence(latex)

    def _extract_formulas_only(self, pix2text_output) -> str:
        """Extract ONLY math from a pix2text result, discarding prose.

        A plain string is kept (after cleaning) only if it contains a math
        marker (backslash, ^, _, =, + or -). In a list of region dicts,
        formula regions are kept whole. Prose regions contribute only their
        delimited inline math. Fragments are joined with newlines. Any other
        input yields "".
        """
        if isinstance(pix2text_output, str):
            if any(op in pix2text_output for op in ('\\', '^', '_', '=', '+', '-')):
                return clean_latex_output(pix2text_output)
            return ""
        if not isinstance(pix2text_output, list):
            return ""
        formula_parts = []
        for item in pix2text_output:
            if not isinstance(item, dict):
                continue
            item_type = item.get('type', 'text')
            if item_type in self._FORMULA_TYPES:
                # `or ''` also covers keys that are present with a None value.
                text = item.get('text') or item.get('latex') or ''
                text = clean_latex_output(str(text)).strip()
                if text:
                    formula_parts.append(text)
            elif item_type == 'text':
                raw = item.get('text') or ''
                for groups in self._INLINE_MATH_RE.findall(str(raw)):
                    # Exactly one alternative matched; take its non-empty group.
                    part = clean_latex_output(next((g for g in groups if g), ''))
                    if part:
                        formula_parts.append(part)
        return '\n'.join(formula_parts)

    def _run_pix2text(self, image_path: str) -> str:
        """Recognize LaTeX with pix2text using up to three looser passes.

        1. ``recognize_formula`` on the whole image (cleanest LaTeX). Output
           containing ``\\newcommand``/``\\def`` is rejected as hallucinated
           preamble.
        2. If pass 1 gave fewer than 3 characters: ``recognize`` (mixed
           text/formula) and keep only the formula regions.
        3. If that is empty: keep all text from the pass-2 result.

        Returns ``_NO_FORMULA_MSG`` when nothing is found, or an
        ``"OCR Error: ..."`` string if passes 2/3 raise.
        """
        raw_latex = ""
        try:
            formula_out = self.p2t.recognize_formula(image_path)
            raw_latex = clean_latex_output(str(formula_out)).strip()
        except Exception as e1:
            print(f"[OCR] Pass 1 formula mode failed: {e1}")
            raw_latex = ""
        else:
            if "\\newcommand" in raw_latex or "\\def" in raw_latex:
                print("[OCR] Pass 1 hallucinated preamble macros. Rejecting output.")
                raw_latex = ""
            else:
                print(f"[OCR] Pass 1 (formula mode): {raw_latex[:80]}")
        # Results shorter than 3 characters are treated as a failed pass 1.
        if len(raw_latex) >= 3:
            return raw_latex
        try:
            out2 = self.p2t.recognize(image_path)
            raw_latex = self._extract_formulas_only(out2)
            print(f"[OCR] Pass 2 (formula extraction): {raw_latex[:80]}")
            if not raw_latex.strip():
                raw_latex = extract_latex_from_pix2text(out2)
        except Exception as e:
            print(f"[OCR] Inference error: {e}")
            return f"OCR Error: {str(e)}"
        if not raw_latex.strip():
            return self._NO_FORMULA_MSG
        return raw_latex

    def _process_inkml(self, inkml_path: str) -> Dict[str, Any]:
        """Transcribe an InkML stroke file with the handwriting transcriber."""
        if self.transcriber is None:
            return self._error_result("InkML input requires HandwritingTranscriber, which is not loaded")
        try:
            raw_latex, _ = self.transcriber.transcribe_inkml(inkml_path)
        except Exception as e:
            print(f"[OCR] HandwritingTranscriber error: {e}")
            return self._error_result(f"HandwritingTranscriber error: {e}")
        print(f"[OCR] Used HandwritingTranscriber for InkML: {raw_latex}")
        raw_latex = "" if raw_latex is None else str(raw_latex)
        return {
            "latex_output": raw_latex,
            "weighted_confidence": self._score(raw_latex),
            "backend": "handwriting",
        }

    def process_image(self, image_path: str) -> Dict[str, Any]:
        """Full OCR pipeline: formula-first mode with prose filtering and confidence scoring.

        ``.inkml`` files (any case) go to the handwriting transcriber. Any
        other path must be a readable raster image with non-zero size and is
        sent through the pix2text passes.

        Returns:
            On success: ``{"latex_output", "weighted_confidence", "backend"}``,
            where backend is "handwriting", "pix2text-formula" or "simulation".
            On unusable input: ``{"error", "latex_output": "",
            "weighted_confidence": 0.0}``. ``weighted_confidence`` is also 0.0
            when the output is a "nothing recognized" or error message.
        """
        if not os.path.exists(image_path):
            return self._error_result(f"Image not found: {image_path}")
        # InkML is stroke data, not a raster image. PIL cannot open it, so it
        # must be routed before the image validation below.
        if image_path.lower().endswith('.inkml'):
            return self._process_inkml(image_path)
        try:
            with Image.open(image_path) as img:
                width, height = img.size
        except Exception as e:
            return self._error_result(f"Invalid image: {e}")
        if width == 0 or height == 0:
            return self._error_result("Zero-size image")
        if self.model_loaded and self.p2t:
            raw_latex = self._run_pix2text(image_path)
            backend = "pix2text-formula"
        else:
            raw_latex = self._NOT_LOADED_MSG
            backend = "simulation"
        raw_latex = clean_latex_output(raw_latex)
        return {
            "latex_output": raw_latex,
            "weighted_confidence": self._score(raw_latex),
            "backend": backend,
        }
|