File size: 8,880 Bytes
bf3e224
17ff84e
54651b2
 
 
 
17ff84e
 
 
bb6d5ae
 
bf3e224
 
 
be2c8ad
 
bf3e224
 
a1d2691
 
 
bb6d5ae
 
 
a1d2691
bb6d5ae
 
 
 
a1d2691
bb6d5ae
 
 
 
 
 
 
a1d2691
 
bb6d5ae
 
 
 
 
a1d2691
bb6d5ae
 
a1d2691
 
 
 
be2c8ad
 
a1d2691
be2c8ad
 
a1d2691
 
 
 
 
 
2706517
a1d2691
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb6d5ae
 
 
a1d2691
 
bb6d5ae
 
a1d2691
bb6d5ae
 
a1d2691
bb6d5ae
a1d2691
bb6d5ae
17ff84e
 
 
 
 
 
 
 
 
 
 
1af77e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb6d5ae
1af77e5
bb6d5ae
a1d2691
 
bb6d5ae
 
 
 
a1d2691
bb6d5ae
a1d2691
bb6d5ae
a1d2691
 
bb6d5ae
1af77e5
 
 
 
7ffaec2
 
 
 
 
 
1af77e5
 
 
 
 
 
 
 
 
 
 
 
 
a1d2691
1af77e5
a1d2691
bb6d5ae
a1d2691
 
bb6d5ae
a8bc4f1
a1d2691
 
17ff84e
1af77e5
17ff84e
 
 
 
 
 
bb6d5ae
a1d2691
bb6d5ae
 
 
1af77e5
 
 
bb6d5ae
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
from PIL import Image
import sys
import os
import re
import random
from typing import Dict, Any

# Checkpoint for the handwriting-transcription model (loaded by MVM2OCREngine).
# NOTE(review): the path is built from os.getcwd() at import time, not from this
# module's own directory, so it only resolves when the process is started from
# the directory containing "handwritten-math-transcription" — confirm intended.
MODEL_PATH = os.path.join(os.getcwd(), "handwritten-math-transcription", "checkpoints", "model_v3_0.pth")

# MVM2 configuration for OCR confidence weights, looked up by get_symbol_weight:
# critical operators weigh 1.5, brackets/limits 1.3, visually ambiguous glyphs 0.7,
# and every other token 1.0.
CRITICAL_OPERATORS = ["\\int", "\\sum", "=", "\\frac", "+", "-", "*", "\\times", "\\div"]
BRACKETS_LIMITS = ["(", ")", "[", "]", "\\{", "\\}", "^", "_"]
AMBIGUOUS_SYMBOLS = ["8", "B", "0", "O", "l", "1", "I", "S", "5", "Z", "2"]
# Characters deleted from OCR output by clean_latex_output:
#   U+4E00-U+9FFF  CJK Unified Ideographs
#   U+3040-U+30FF  Hiragana and Katakana
#   U+AC00-U+D7AF  Hangul syllables
#   U+3000-U+303F  CJK symbols and punctuation
#   U+FF00-U+FFEF  Halfwidth and fullwidth forms
# The individually listed punctuation code points already lie inside those ranges.
# NOTE(review): U+FF00-U+FFEF also contains fullwidth ASCII (e.g. U+FF1D fullwidth
# '=', U+FF0B fullwidth '+', fullwidth digits), so math written with fullwidth
# glyphs is deleted rather than normalized — confirm this is intended.
CJK_PATTERN = re.compile(r'[\u4e00-\u9fff\u3040-\u30ff\uac00-\ud7af\u3000-\u303f\uff00-\uffef\u3001\u3002\uff0c\uff0e\uff1a\uff1b\uff1f\uff01]')

def get_symbol_weight(symbol: str) -> float:
    """Return the confidence weight W_i assigned to a single LaTeX token.

    Groups are checked in priority order: CRITICAL_OPERATORS (1.5), then
    BRACKETS_LIMITS (1.3), then AMBIGUOUS_SYMBOLS (0.7). The first group
    containing the token decides its weight; any other token weighs 1.0.
    """
    weight_table = (
        (CRITICAL_OPERATORS, 1.5),
        (BRACKETS_LIMITS, 1.3),
        (AMBIGUOUS_SYMBOLS, 0.7),
    )
    for group, weight in weight_table:
        if symbol in group:
            return weight
    return 1.0

def calculate_weighted_confidence(latex_string: str, mock_logits: bool = True) -> float:
    r"""Return the symbol-weighted OCR confidence of a LaTeX string.

    OCR.conf = sum(W_i * c_i) / sum(W_i), where W_i = get_symbol_weight(token_i).

    Tokens follow TeX's rules for control sequences:
      * control word   -- a backslash followed by ASCII letters (\frac, \int);
        a digit ends the name, so "\frac12" yields "\frac", "1", "2";
      * control symbol -- a backslash followed by exactly one non-letter
        (\{, \}, \,), so bracket entries such as "\{" can actually match;
      * every other non-whitespace character is a token on its own.

    Args:
        latex_string: LaTeX text to score.
        mock_logits: When True (the default), each c_i is sampled from
            random.uniform(0.85, 0.99), so the result is not deterministic;
            when False every c_i is 0.95.

    Returns:
        The weighted mean rounded to 4 decimals, or 0.0 when there are no
        tokens (empty, whitespace-only or None input).
    """
    if not latex_string:
        return 0.0
    # DOTALL so a backslash followed by a newline still forms one control symbol;
    # a lone trailing backslash falls through to \S and is its own token.
    tokens = re.findall(r'\\[A-Za-z]+|\\.|\S', latex_string, flags=re.DOTALL)

    total_weighted_ci = 0.0
    total_weights = 0.0
    for token in tokens:
        w_i = get_symbol_weight(token)
        # Mock per-token confidence: real model logits are not wired in here.
        c_i = random.uniform(0.85, 0.99) if mock_logits else 0.95
        total_weighted_ci += w_i * c_i
        total_weights += w_i
    if total_weights == 0:
        return 0.0
    return round(total_weighted_ci / total_weights, 4)

def clean_latex_output(text: str) -> str:
    """Scrub OCR output down to its mathematical content.

    Steps, in order:
      1. delete every character matched by CJK_PATTERN;
      2. delete the whole words solve/find/evaluate/simplify (case-insensitive);
      3. replace each run of two or more whitespace characters with one space,
         then trim both ends.
    Falsy input (None or "") yields "".
    """
    if not text:
        return ""
    without_cjk = CJK_PATTERN.sub('', text)
    # Instruction verbs that OCR picks up from problem statements.
    without_verbs = re.sub(r'(?i)\b(solve|find|evaluate|simplify)\b', '', without_cjk)
    return re.sub(r'\s{2,}', ' ', without_verbs).strip()

def extract_latex_from_pix2text(out) -> str:
    """Flatten Pix2Text output of any supported shape into one cleaned string.

    Every piece passes through clean_latex_output. Handled shapes:
      * None   -> "" (never the literal text "None");
      * str    -> the cleaned string;
      * list   -> each dict item contributes its 'text' value (or 'latex' if
                  'text' is falsy), each other item with a ``.text`` attribute
                  contributes that attribute; empty or None pieces are dropped
                  and the rest are joined with single spaces;
      * object with ``to_markdown`` -> its cleaned markdown;
      * anything else -> cleaned ``str(out)``.
    Text blocks containing CJK are not skipped; their CJK characters are
    removed by clean_latex_output and any remaining text is kept.
    """
    if out is None:
        return ""
    if isinstance(out, str):
        return clean_latex_output(out)
    if isinstance(out, list):
        parts = []
        for item in out:
            if isinstance(item, dict):
                # `or ''` so a None value is not stringified into "None".
                raw = item.get('text') or item.get('latex') or ''
            elif hasattr(item, 'text'):
                raw = item.text or ''
            else:
                continue
            text = clean_latex_output(str(raw)).strip()
            if text:
                parts.append(text)
        return ' '.join(parts)
    if hasattr(out, 'to_markdown'):
        # NOTE(review): called with no arguments; confirm the installed Pix2Text
        # version does not require an output directory here.
        return clean_latex_output(out.to_markdown())
    return clean_latex_output(str(out))

class MVM2OCREngine:
    """OCR front end for MVM2: turns an image or InkML file into LaTeX text.

    Backends are loaded best-effort in ``__init__``:
      * Pix2Text (``self.p2t``) for raster images. If it cannot be imported or
        configured, ``model_loaded`` stays False and ``process_image`` returns a
        placeholder message instead of running inference.
      * HandwritingTranscriber (``self.transcriber``) for ``.inkml`` files. It is
        loaded only if the checkpoint at MODEL_PATH exists.
    """

    # pix2text list-item types whose text is a formula rather than prose.
    # NOTE(review): 'isolated' is included because some pix2text releases appear
    # to label display formulas that way; confirm against the installed version.
    _FORMULA_ITEM_TYPES = ('isolated_equation', 'isolated', 'embedding', 'formula', 'math')
    # Math embedded in prose: $$...$$ (tried first, so it is not read as an empty
    # $...$ pair), $...$, \(...\) and \[...\]. Exactly one group is non-empty per match.
    _INLINE_MATH_RE = re.compile(r'\$\$(.+?)\$\$|\$([^$]+)\$|\\\((.+?)\\\)|\\\[(.+?)\\\]')

    def __init__(self):
        self.model_loaded = False  # True only once Pix2Text was constructed
        self.p2t = None
        try:
            from pix2text import Pix2Text
            # Default configuration (per the original note: recognizes both
            # formula and text regions).
            self.p2t = Pix2Text.from_config()
            self.model_loaded = True
            print("[OCR] Pix2Text loaded successfully.")
        except Exception as e:
            # Best effort: any import/config failure falls back to simulation mode.
            print(f"[OCR] Warning: Pix2Text unavailable ({e}). Using simulation mode.")

        self.transcriber = None
        try:
            from handwriting_transcriber import HandwritingTranscriber
            if os.path.exists(MODEL_PATH):
                self.transcriber = HandwritingTranscriber(model_path=MODEL_PATH)
                print(f"[OCR] HandwritingTranscriber loaded with model: {MODEL_PATH}")
            else:
                print(f"[OCR] Warning: Handwriting model not found at {MODEL_PATH}")
        except Exception as e:
            print(f"[OCR] Warning: HandwritingTranscriber unavailable ({e})")

    def _extract_formulas_only(self, pix2text_output) -> str:
        """Extract only math-formula content from Pix2Text output, dropping prose.

        * str: returned cleaned only if it contains a LaTeX-ish marker
          (backslash, ^, _, =, + or -); otherwise "".
        * list: dict items of a formula type contribute their 'text' (or
          'latex') value; 'text' items contribute only the math delimited
          inside them (see _INLINE_MATH_RE). Non-dict items are ignored.
          Pieces are joined with newlines.
        * anything else (including None): "".
        """
        if isinstance(pix2text_output, str):
            if any(op in pix2text_output for op in ['\\', '^', '_', '=', '+', '-']):
                return clean_latex_output(pix2text_output)
            return ""
        if not isinstance(pix2text_output, list):
            return ""

        formula_parts = []
        for item in pix2text_output:
            if not isinstance(item, dict):
                continue
            item_type = item.get('type', 'text')
            if item_type in self._FORMULA_ITEM_TYPES:
                # `or ''` so a None value is not stringified into "None".
                raw = item.get('text') or item.get('latex') or ''
                text = clean_latex_output(str(raw)).strip()
                if text:
                    formula_parts.append(text)
            elif item_type == 'text':
                # A None 'text' would otherwise crash re.findall.
                raw = item.get('text') or ''
                for groups in self._INLINE_MATH_RE.findall(str(raw)):
                    part = next((g for g in groups if g), '')
                    if part.strip():
                        formula_parts.append(clean_latex_output(part))
        return '\n'.join(formula_parts)

    def process_image(self, image_path: str) -> Dict[str, Any]:
        """Run the full OCR pipeline on one input file.

        Args:
            image_path: Path to a raster image, or to an ``.inkml`` handwriting
                file (the suffix check is case-insensitive), which is handled by
                the HandwritingTranscriber when it is loaded.

        Returns:
            On success, a dict with ``latex_output`` (str), ``weighted_confidence``
            (float) and ``backend`` ("handwriting", "pix2text-formula" or
            "simulation"). On failure, a dict with ``error`` (str), an empty
            ``latex_output`` and a ``weighted_confidence`` of 0.0.
        """
        if not os.path.exists(image_path):
            return {"error": f"Image not found: {image_path}", "latex_output": "", "weighted_confidence": 0.0}

        # InkML is XML stroke data that PIL cannot open. Route it to the
        # transcriber before image validation; otherwise validation rejects it
        # and the transcriber is never reached.
        if self.transcriber and image_path.lower().endswith('.inkml'):
            return self._process_inkml(image_path)

        error = self._validate_image(image_path)
        if error is not None:
            return {"error": error, "latex_output": "", "weighted_confidence": 0.0}

        if self.model_loaded and self.p2t:
            raw_latex = self._run_pix2text(image_path)
        else:
            raw_latex = "No math detected (OCR model not loaded)."

        raw_latex = clean_latex_output(raw_latex)
        # NOTE(review): calculate_weighted_confidence runs with its default
        # mock_logits=True, so this score is randomized. Placeholder messages
        # such as "No mathematical formula detected." are scored like real LaTeX.
        ocr_conf = calculate_weighted_confidence(raw_latex)

        return {
            "latex_output": raw_latex,
            "weighted_confidence": ocr_conf,
            "backend": "pix2text-formula" if self.model_loaded else "simulation",
        }

    def _process_inkml(self, inkml_path: str) -> Dict[str, Any]:
        """Transcribe an InkML file; returns the same dict shapes as process_image."""
        try:
            latex, _ = self.transcriber.transcribe_inkml(inkml_path)
        except Exception as e:
            print(f"[OCR] HandwritingTranscriber error: {e}")
            return {"error": f"HandwritingTranscriber error: {e}", "latex_output": "", "weighted_confidence": 0.0}
        latex = "" if latex is None else str(latex)
        print(f"[OCR] Used HandwritingTranscriber for InkML: {latex}")
        return {
            "latex_output": latex,
            "weighted_confidence": calculate_weighted_confidence(latex),
            "backend": "handwriting",
        }

    @staticmethod
    def _validate_image(image_path: str) -> "str | None":
        """Return an error message if PIL cannot open the image or it has zero size, else None."""
        try:
            with Image.open(image_path) as img:
                width, height = img.size
        except Exception as e:
            return f"Invalid image: {e}"
        if width == 0 or height == 0:
            return "Zero-size image"
        return None

    def _formula_pass(self, image_path: str) -> str:
        """Pass 1: formula-only recognition.

        Returns the cleaned, stripped LaTeX, or "" when recognition raises,
        returns None, or produces preamble macros (\\newcommand / \\def), which
        are treated as hallucinations.
        """
        try:
            formula_out = self.p2t.recognize_formula(image_path)
        except Exception as e1:
            print(f"[OCR] Pass 1 formula mode failed: {e1}")
            return ""
        if formula_out is None:
            return ""
        latex = clean_latex_output(str(formula_out)).strip()
        if "\\newcommand" in latex or "\\def" in latex:
            print("[OCR] Pass 1 hallucinated preamble macros. Rejecting output.")
            return ""
        print(f"[OCR] Pass 1 (formula mode): {latex[:80]}")
        return latex

    def _run_pix2text(self, image_path: str) -> str:
        """Run the Pix2Text passes and return the best raw text found.

        Pass 1: formula-only recognition.
        Pass 2 (only if Pass 1 gave fewer than 3 characters): general
        recognition, keeping formula regions only.
        Pass 3 (only if still empty): all text from the Pass 2 output.
        Returns "No mathematical formula detected." if every pass is empty, or
        "OCR Error: ..." if Pass 2 or Pass 3 raises.
        """
        # Pass 2 output, initialized explicitly instead of probing `'out2' in dir()`.
        general_out = None
        try:
            raw_latex = self._formula_pass(image_path)

            if len(raw_latex) < 3:
                general_out = self.p2t.recognize(image_path)
                raw_latex = self._extract_formulas_only(general_out)
                print(f"[OCR] Pass 2 (formula extraction): {raw_latex[:80]}")

            if not raw_latex.strip():
                raw_latex = extract_latex_from_pix2text(general_out if general_out is not None else "")
            if not raw_latex.strip():
                raw_latex = "No mathematical formula detected."
        except Exception as e:
            print(f"[OCR] Inference error: {e}")
            raw_latex = f"OCR Error: {str(e)}"
        return raw_latex