File size: 9,075 Bytes
bf64b03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
"""

EquationModule: Specialized processing for mathematical equations and LaTeX.

Detects equation spans, applies equation-specific attention, and learns

structural representations of mathematical expressions.

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import re
from typing import Optional, Tuple, List


class EquationModule(nn.Module):
    r"""
    Specialized processing for mathematical equations and LaTeX.

    - Detects equation spans in input (between $ $ or \[ \] delimiters)
    - Applies equation-specific attention patterns within equation spans
    - Learns structural representations of mathematical expressions
    - Tree-aware: understands operator precedence and nesting
    """

    def __init__(self, d_model: int, num_heads: int = 8):
        """
        Initialize EquationModule.

        Args:
            d_model: Model dimension.
            num_heads: Number of heads for equation-specific attention
                (d_model must be divisible by num_heads).
        """
        super().__init__()
        self.d_model = d_model

        # Equation span detector (lightweight per-token linear classifier;
        # trained via compute_equation_loss, used by _learned_span_detection).
        self.span_detector = nn.Linear(d_model, 1)

        # Equation-specific transformer (shallow, 2 layers) applied only
        # inside detected equation spans.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=d_model * 4,
            activation=F.silu,
            batch_first=True,
            dropout=0.1,
        )
        self.equation_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

        # Merge equation representations back into the main stream:
        # concat(original, encoded) -> d_model.
        self.merge = nn.Linear(d_model * 2, d_model)

        # LaTeX structure awareness (simple positional encoding for tree depth).
        # NOTE(review): not yet consumed by forward(); kept for interface
        # compatibility and future tree-depth features.
        self.depth_embedding = nn.Embedding(10, d_model)  # Max depth 10

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize detector/merge/embedding weights (N(0, 0.02), zero bias)."""
        for module in [self.span_detector, self.merge, self.depth_embedding]:
            if hasattr(module, 'weight'):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            # nn.Embedding has no bias attribute; guard keeps the loop uniform.
            if hasattr(module, 'bias') and module.bias is not None:
                nn.init.zeros_(module.bias)

    def detect_equation_spans(
        self,
        text: str,
        token_ids: Optional[torch.Tensor] = None,
    ) -> List[Tuple[int, int]]:
        r"""
        Detect equation spans in text using delimiters.

        Supports: $...$, $$...$$, \[...\], \(...\)

        Args:
            text: Input text string.
            token_ids: Optional token IDs for alignment (currently unused).

        Returns:
            List of (start_char, end_char) spans, sorted by start position.
        """
        spans = []

        # Pattern 1: $$...$$ (display math). Matched BEFORE inline math so
        # that single-$ matching cannot produce a corrupted partial span
        # (the old order let r'\$(.+?)\$' match "$$x$" out of "$$x$$").
        for match in re.finditer(r'\$\$(.+?)\$\$', text, re.DOTALL):
            spans.append((match.start(), match.end()))

        # Pattern 2: $...$ (inline math). Negative lookarounds require both
        # delimiters to be single dollars, so display-math delimiters are
        # never consumed as inline boundaries.
        for match in re.finditer(r'(?<!\$)\$(?!\$)(.+?)(?<!\$)\$(?!\$)',
                                 text, re.DOTALL):
            spans.append((match.start(), match.end()))

        # Pattern 3: \[...\] (LaTeX display math)
        for match in re.finditer(r'\\\[(.+?)\\\]', text, re.DOTALL):
            spans.append((match.start(), match.end()))

        # Pattern 4: \(...\) (LaTeX inline math)
        for match in re.finditer(r'\\\((.+?)\\\)', text, re.DOTALL):
            spans.append((match.start(), match.end()))

        # Document order makes downstream span processing deterministic.
        spans.sort()
        return spans

    def forward(
        self,
        x: torch.Tensor,
        text: Optional[List[str]] = None,
        token_spans: Optional[List[List[Tuple[int, int]]]] = None,
    ) -> torch.Tensor:
        """
        Forward pass through the equation module.

        Args:
            x: Input tensor (batch, seq_len, d_model).
            text: Optional original text strings (for delimiter-based detection).
            token_spans: Optional pre-computed token-level equation spans.
                Each element: list of (start_token, end_token) for that batch item.

        Returns:
            Equation-enhanced representation (batch, seq_len, d_model). Tokens
            outside every equation span are returned unchanged.
        """
        batch, seq_len, d_model = x.shape

        # Resolve equation spans: explicit > delimiter-based > learned detector.
        if token_spans is None and text is not None:
            token_spans = []
            for b in range(batch):
                char_spans = self.detect_equation_spans(text[b])
                # Convert char spans to token spans (simplified - assumes
                # ~4 chars per token). In practice this needs a proper
                # tokenizer offset mapping — TODO confirm alignment upstream.
                token_spans_b = []
                for start_char, end_char in char_spans:
                    start_token = max(0, start_char // 4)
                    end_token = min(seq_len, end_char // 4 + 1)
                    token_spans_b.append((start_token, end_token))
                token_spans.append(token_spans_b)
        elif token_spans is None:
            # Fallback: threshold the learned per-token detector.
            token_spans = self._learned_span_detection(x)

        # Clone so untouched positions keep the original representation.
        output = x.clone()

        for b in range(batch):
            spans_b = token_spans[b] if b < len(token_spans) else []

            for start_tok, end_tok in spans_b:
                if end_tok <= start_tok:
                    continue  # skip empty/degenerate spans

                # Extract equation segment: (1, seg_len, d_model)
                eq_segment = x[b:b+1, start_tok:end_tok, :]

                # Apply equation-specific transformer
                eq_encoded = self.equation_encoder(eq_segment)

                # Merge encoded segment with original via learned projection
                merged = torch.cat([eq_segment, eq_encoded], dim=-1)
                merged = self.merge(merged)

                # Place back in output
                output[b:b+1, start_tok:end_tok, :] = merged

        return output

    def _learned_span_detection(
        self,
        x: torch.Tensor,
    ) -> List[List[Tuple[int, int]]]:
        """
        Use the learned detector to find equation spans when delimiters are
        unavailable. Simple 0.5-thresholding on span_detector output.

        Args:
            x: Input tensor (batch, seq_len, d_model).

        Returns:
            List of token spans per batch item, each (start_token, end_token)
            half-open on the right.
        """
        batch, seq_len, _ = x.shape

        # Per-token equation probability.
        eq_probs = torch.sigmoid(self.span_detector(x))  # (batch, seq_len, 1)
        eq_probs = eq_probs.squeeze(-1)  # (batch, seq_len)

        threshold = 0.5
        spans = []

        for b in range(batch):
            # tolist() avoids the implicit numpy dependency of .cpu().numpy()
            # and yields plain Python bools for the scan below.
            is_equation = (eq_probs[b] > threshold).tolist()

            # Scan for maximal contiguous True runs.
            span_list = []
            in_span = False
            start = 0

            for t in range(seq_len):
                if is_equation[t] and not in_span:
                    start = t
                    in_span = True
                elif not is_equation[t] and in_span:
                    span_list.append((start, t))
                    in_span = False

            # Close a span that runs to the end of the sequence.
            if in_span:
                span_list.append((start, seq_len))

            spans.append(span_list)

        return spans

    def compute_equation_loss(
        self,
        x: torch.Tensor,
        equation_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute auxiliary loss for equation detection training.

        Args:
            x: Input tensor (batch, seq_len, d_model).
            equation_mask: Ground truth equation mask (batch, seq_len),
                1 if token is in an equation.

        Returns:
            Scalar binary cross-entropy loss for equation detection.
        """
        logits = self.span_detector(x).squeeze(-1)  # (batch, seq_len)
        loss = F.binary_cross_entropy_with_logits(
            logits,
            equation_mask.float(),
        )
        return loss


def test_equation_module():
    """Smoke-test EquationModule: shape preservation and detection loss."""
    d_model = 512
    batch_size = 2
    seq_len = 128

    module = EquationModule(d_model)

    x = torch.randn(batch_size, seq_len, d_model)
    # Raw string for the second entry: "\[" in a plain literal is an invalid
    # escape sequence (SyntaxWarning now, SyntaxError in future Python).
    # The runtime value is byte-identical to the original.
    text = [
        "The energy is $E = mc^2$ and momentum is $p = mv$.",
        r"Equation: \[ F = ma \] and also $a^2 + b^2 = c^2$."
    ]

    output = module(x, text=text)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    assert output.shape == x.shape

    # Test equation loss on simulated ground-truth spans.
    equation_mask = torch.zeros(batch_size, seq_len)
    equation_mask[0, 10:15] = 1.0  # Simulate equation span
    equation_mask[1, 5:12] = 1.0
    loss = module.compute_equation_loss(x, equation_mask)
    print(f"Equation loss: {loss.item():.4f}")

    print("EquationModule test passed!")


if __name__ == "__main__":
    test_equation_module()