"""
EquationModule: Specialized processing for mathematical equations and LaTeX.

Detects equation spans, applies equation-specific attention, and learns
structural representations of mathematical expressions.
"""
|
| |
|
import torch
import torch.nn as nn
import torch.nn.functional as F
import re
from typing import Optional, Tuple, List
|
| |
|
| |
|
class EquationModule(nn.Module):
    r"""
    Specialized processing for mathematical equations and LaTeX.

    - Detects equation spans in input (between $ $, $$ $$, \[ \] or \( \)
      delimiters), or via a learned token-level detector when no raw text
      is available.
    - Applies equation-specific attention patterns within equation spans.
    - Learns structural representations of mathematical expressions.
    - Tree-aware: understands operator precedence and nesting.
    """

    # Delimiter patterns in priority order.  Multi-character delimiters are
    # scanned FIRST so a display equation $$...$$ is claimed before the inline
    # $...$ pattern can mis-parse its leading "$$" as an inline opener.  The
    # inline pattern forbids "$" in the body and uses lookarounds so it never
    # latches onto half of a "$$" delimiter.
    _EQUATION_PATTERNS = (
        re.compile(r'\$\$(.+?)\$\$', re.DOTALL),
        re.compile(r'\\\[(.+?)\\\]', re.DOTALL),
        re.compile(r'\\\((.+?)\\\)', re.DOTALL),
        re.compile(r'(?<!\$)\$([^$]+?)\$(?!\$)'),
    )

    def __init__(self, d_model: int, num_heads: int = 8):
        """
        Initialize EquationModule.

        Args:
            d_model: Model dimension (must be divisible by ``num_heads``).
            num_heads: Number of heads for equation-specific attention.
        """
        super().__init__()
        self.d_model = d_model

        # Per-token binary classifier: "is this token inside an equation?"
        self.span_detector = nn.Linear(d_model, 1)

        # Small transformer applied only to the tokens inside each detected
        # equation span.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=d_model * 4,
            activation=F.silu,
            batch_first=True,
            dropout=0.1,
        )
        self.equation_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

        # Fuses the original representation with the equation-encoded one.
        self.merge = nn.Linear(d_model * 2, d_model)

        # Embedding for expression nesting depth (up to 10 levels).
        # NOTE(review): currently unused in forward(); kept for interface
        # compatibility with checkpoints / future tree-aware processing.
        self.depth_embedding = nn.Embedding(10, d_model)

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize linear/embedding weights with N(0, 0.02), biases with 0."""
        for module in [self.span_detector, self.merge, self.depth_embedding]:
            if hasattr(module, 'weight'):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if hasattr(module, 'bias') and module.bias is not None:
                nn.init.zeros_(module.bias)

    def detect_equation_spans(
        self,
        text: str,
        token_ids: Optional[torch.Tensor] = None,
    ) -> List[Tuple[int, int]]:
        r"""
        Detect equation spans in text using delimiters.
        Supports: $...$, $$...$$, \[...\], \(...\)

        Spans are non-overlapping: a region claimed by a higher-priority
        delimiter (e.g. the interior of $$...$$) is not reported again by a
        lower-priority pattern.

        Args:
            text: Input text string.
            token_ids: Optional token IDs for alignment (currently unused).

        Returns:
            Sorted list of (start_char, end_char) spans, end exclusive,
            delimiters included.
        """
        spans: List[Tuple[int, int]] = []
        for pattern in self._EQUATION_PATTERNS:
            for match in pattern.finditer(text):
                start, end = match.start(), match.end()
                # Reject any candidate overlapping an already-claimed span.
                if any(s < end and start < e for s, e in spans):
                    continue
                spans.append((start, end))
        spans.sort()
        return spans

    def forward(
        self,
        x: torch.Tensor,
        text: Optional[List[str]] = None,
        token_spans: Optional[List[List[Tuple[int, int]]]] = None,
    ) -> torch.Tensor:
        """
        Forward pass through the equation module.

        Args:
            x: Input tensor (batch, seq_len, d_model).
            text: Optional original text strings (for delimiter-based detection).
            token_spans: Optional pre-computed token-level equation spans.
                Each element: list of (start_token, end_token) for that batch item.

        Returns:
            Equation-enhanced representation (batch, seq_len, d_model).
            Tokens outside every equation span are returned unchanged.
        """
        batch, seq_len, _ = x.shape

        if token_spans is None and text is not None:
            # Derive token spans from character spans using a crude
            # chars-per-token heuristic (~4 chars/token).
            # TODO(review): replace with real tokenizer offset mapping.
            token_spans = []
            for b in range(batch):
                char_spans = self.detect_equation_spans(text[b])
                token_spans_b = []
                for start_char, end_char in char_spans:
                    start_token = max(0, start_char // 4)
                    end_token = min(seq_len, end_char // 4 + 1)
                    token_spans_b.append((start_token, end_token))
                token_spans.append(token_spans_b)
        elif token_spans is None:
            # No text either: fall back to the learned per-token detector.
            token_spans = self._learned_span_detection(x)

        output = x.clone()

        for b in range(batch):
            spans_b = token_spans[b] if b < len(token_spans) else []

            for start_tok, end_tok in spans_b:
                if end_tok <= start_tok:
                    continue  # degenerate/empty span

                # Re-encode only the equation segment, then fuse it with the
                # original representation via a learned projection.
                eq_segment = x[b:b + 1, start_tok:end_tok, :]
                eq_encoded = self.equation_encoder(eq_segment)
                merged = self.merge(torch.cat([eq_segment, eq_encoded], dim=-1))
                output[b:b + 1, start_tok:end_tok, :] = merged

        return output

    def _learned_span_detection(
        self,
        x: torch.Tensor,
    ) -> List[List[Tuple[int, int]]]:
        """
        Use learned detector to find equation spans when delimiters missing.
        Simple thresholding on span_detector output.

        Args:
            x: Input tensor (batch, seq_len, d_model).

        Returns:
            List of token spans per batch item.
        """
        batch, seq_len, _ = x.shape

        # Pure detection: no gradients needed, spans are plain Python ints.
        with torch.no_grad():
            eq_probs = torch.sigmoid(self.span_detector(x)).squeeze(-1)

        threshold = 0.5
        spans: List[List[Tuple[int, int]]] = []

        for b in range(batch):
            # .tolist() avoids an implicit NumPy dependency and works on any device.
            is_equation = (eq_probs[b] > threshold).tolist()

            span_list: List[Tuple[int, int]] = []
            in_span = False
            start = 0

            # Collapse the boolean mask into contiguous (start, end) runs.
            for t in range(seq_len):
                if is_equation[t] and not in_span:
                    start = t
                    in_span = True
                elif not is_equation[t] and in_span:
                    span_list.append((start, t))
                    in_span = False

            if in_span:  # close a run extending to the end of the sequence
                span_list.append((start, seq_len))

            spans.append(span_list)

        return spans

    def compute_equation_loss(
        self,
        x: torch.Tensor,
        equation_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute auxiliary loss for equation detection training.

        Args:
            x: Input tensor (batch, seq_len, d_model).
            equation_mask: Ground truth equation mask (batch, seq_len),
                1 if token is in equation.

        Returns:
            Scalar binary cross-entropy loss for equation detection.
        """
        logits = self.span_detector(x).squeeze(-1)
        return F.binary_cross_entropy_with_logits(logits, equation_mask.float())
|
| |
|
| |
|
def test_equation_module():
    """Smoke-test EquationModule: forward shape and auxiliary loss."""
    d_model = 512
    batch_size = 2
    seq_len = 128

    module = EquationModule(d_model)

    x = torch.randn(batch_size, seq_len, d_model)
    # Raw strings so the \[ \] LaTeX delimiters survive literally; a plain
    # "\[" relies on invalid-escape passthrough and warns on modern Python.
    text = [
        "The energy is $E = mc^2$ and momentum is $p = mv$.",
        r"Equation: \[ F = ma \] and also $a^2 + b^2 = c^2$.",
    ]

    output = module(x, text=text)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    assert output.shape == x.shape

    # Auxiliary detection loss on a synthetic ground-truth mask.
    equation_mask = torch.zeros(batch_size, seq_len)
    equation_mask[0, 10:15] = 1.0
    equation_mask[1, 5:12] = 1.0
    loss = module.compute_equation_loss(x, equation_mask)
    print(f"Equation loss: {loss.item():.4f}")

    print("EquationModule test passed!")
|
| |
|
| |
|
# Run the smoke test when executed as a script (no effect on import).
if __name__ == "__main__":
    test_equation_module()
|
| |
|