""" EquationModule: Specialized processing for mathematical equations and LaTeX. Detects equation spans, applies equation-specific attention, and learns structural representations of mathematical expressions. """ import torch import torch.nn as nn import torch.nn.functional as F import re from typing import Optional, Tuple, List class EquationModule(nn.Module): """ Specialized processing for mathematical equations and LaTeX. - Detects equation spans in input (between $ $ or \[ \] delimiters) - Applies equation-specific attention patterns within equation spans - Learns structural representations of mathematical expressions - Tree-aware: understands operator precedence and nesting """ def __init__(self, d_model: int, num_heads: int = 8): """ Initialize EquationModule. Args: d_model: Model dimension num_heads: Number of heads for equation-specific attention """ super().__init__() self.d_model = d_model # Equation span detector (lightweight linear classifier) self.span_detector = nn.Linear(d_model, 1) # Equation-specific transformer (shallow, 2 layers) encoder_layer = nn.TransformerEncoderLayer( d_model=d_model, nhead=num_heads, dim_feedforward=d_model * 4, activation=F.silu, batch_first=True, dropout=0.1, ) self.equation_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2) # Merge equation representations back into main stream self.merge = nn.Linear(d_model * 2, d_model) # LaTeX structure awareness (simple positional encoding for tree depth) self.depth_embedding = nn.Embedding(10, d_model) # Max depth 10 # Initialize weights self._initialize_weights() def _initialize_weights(self): """Initialize weights.""" for module in [self.span_detector, self.merge, self.depth_embedding]: if hasattr(module, 'weight'): nn.init.normal_(module.weight, mean=0.0, std=0.02) if hasattr(module, 'bias') and module.bias is not None: nn.init.zeros_(module.bias) def detect_equation_spans( self, text: str, token_ids: Optional[torch.Tensor] = None, ) -> List[Tuple[int, int]]: """ Detect equation spans in text using delimiters. Supports: $...$, $$...$$, \[...\], \(...\) Args: text: Input text string token_ids: Optional token IDs for alignment Returns: List of (start_char, end_char) spans """ spans = [] # Pattern 1: $...$ (inline math) for match in re.finditer(r'\$(.+?)\$', text, re.DOTALL): spans.append((match.start(), match.end())) # Pattern 2: $$...$$ (display math) for match in re.finditer(r'\$\$(.+?)\$\$', text, re.DOTALL): spans.append((match.start(), match.end())) # Pattern 3: \[...\] (LaTeX display math) for match in re.finditer(r'\\\[(.+?)\\\]', text, re.DOTALL): spans.append((match.start(), match.end())) # Pattern 4: \(...\) (LaTeX inline math) for match in re.finditer(r'\\\((.+?)\\\)', text, re.DOTALL): spans.append((match.start(), match.end())) return spans def forward( self, x: torch.Tensor, text: Optional[List[str]] = None, token_spans: Optional[List[List[Tuple[int, int]]]] = None, ) -> torch.Tensor: """ Forward pass through the equation module. Args: x: Input tensor (batch, seq_len, d_model) text: Optional original text strings (for delimiter-based detection) token_spans: Optional pre-computed token-level equation spans Each element: list of (start_token, end_token) for that batch item Returns: Equation-enhanced representation (batch, seq_len, d_model) """ batch, seq_len, d_model = x.shape # Detect equation spans if token_spans is None and text is not None: # Use delimiter-based detection (requires text) token_spans = [] for b in range(batch): char_spans = self.detect_equation_spans(text[b]) # Convert char spans to token spans (simplified - assumes 1 char ≈ 1 token) # In practice, would need proper tokenization alignment token_spans_b = [] for start_char, end_char in char_spans: # Rough approximation: divide by average chars per token (~4) start_token = max(0, start_char // 4) end_token = min(seq_len, end_char // 4 + 1) token_spans_b.append((start_token, end_token)) token_spans.append(token_spans_b) elif token_spans is None: # Fallback: use learned detector token_spans = self._learned_span_detection(x) # Process each batch item output = x.clone() for b in range(batch): spans_b = token_spans[b] if b < len(token_spans) else [] for start_tok, end_tok in spans_b: if end_tok <= start_tok: continue # Extract equation segment eq_segment = x[b:b+1, start_tok:end_tok, :] # (1, seg_len, d_model) # Apply equation-specific transformer eq_encoded = self.equation_encoder(eq_segment) # Merge with original merged = torch.cat([eq_segment, eq_encoded], dim=-1) merged = self.merge(merged) # Place back in output output[b:b+1, start_tok:end_tok, :] = merged return output def _learned_span_detection( self, x: torch.Tensor, ) -> List[List[Tuple[int, int]]]: """ Use learned detector to find equation spans when delimiters missing. Simple thresholding on span_detector output. Args: x: Input tensor (batch, seq_len, d_model) Returns: List of token spans per batch item """ batch, seq_len, _ = x.shape # Compute equation probability per token eq_probs = torch.sigmoid(self.span_detector(x)) # (batch, seq_len, 1) eq_probs = eq_probs.squeeze(-1) # (batch, seq_len) # Threshold threshold = 0.5 spans = [] for b in range(batch): probs = eq_probs[b] is_equation = (probs > threshold).cpu().numpy() # Find contiguous spans span_list = [] in_span = False start = 0 for t in range(seq_len): if is_equation[t] and not in_span: start = t in_span = True elif not is_equation[t] and in_span: span_list.append((start, t)) in_span = False if in_span: span_list.append((start, seq_len)) spans.append(span_list) return spans def compute_equation_loss( self, x: torch.Tensor, equation_mask: torch.Tensor, ) -> torch.Tensor: """ Compute auxiliary loss for equation detection training. Args: x: Input tensor (batch, seq_len, d_model) equation_mask: Ground truth equation mask (batch, seq_len), 1 if token is in equation Returns: Binary cross-entropy loss for equation detection """ logits = self.span_detector(x).squeeze(-1) # (batch, seq_len) loss = F.binary_cross_entropy_with_logits( logits, equation_mask.float(), ) return loss def test_equation_module(): """Test EquationModule.""" d_model = 512 batch_size = 2 seq_len = 128 module = EquationModule(d_model) x = torch.randn(batch_size, seq_len, d_model) text = [ "The energy is $E = mc^2$ and momentum is $p = mv$.", "Equation: \[ F = ma \] and also $a^2 + b^2 = c^2$." ] output = module(x, text=text) print(f"Input shape: {x.shape}") print(f"Output shape: {output.shape}") assert output.shape == x.shape # Test equation loss equation_mask = torch.zeros(batch_size, seq_len) equation_mask[0, 10:15] = 1.0 # Simulate equation span equation_mask[1, 5:12] = 1.0 loss = module.compute_equation_loss(x, equation_mask) print(f"Equation loss: {loss.item():.4f}") print("EquationModule test passed!") if __name__ == "__main__": test_equation_module()