"""
EquationModule: Specialized processing for mathematical equations and LaTeX.
Detects equation spans, applies equation-specific attention, and learns
structural representations of mathematical expressions.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import re
from typing import Optional, Tuple, List
class EquationModule(nn.Module):
    r"""
    Specialized processing for mathematical equations and LaTeX.

    - Detects equation spans in input (between ``$ $`` or ``\[ \]`` delimiters)
    - Applies equation-specific attention patterns within equation spans
    - Learns structural representations of mathematical expressions
    - Tree-aware: understands operator precedence and nesting
    """

    # Single combined pattern, alternatives tried left-to-right at each
    # position: display math ($$...$$ and \[...\]) is listed BEFORE inline
    # math ($...$ and \(...\)) so a $$ pair is never mis-parsed as two
    # overlapping inline $...$ matches. finditer() yields non-overlapping,
    # left-to-right matches, so the returned spans never overlap.
    _EQ_PATTERN = re.compile(
        r'\$\$(.+?)\$\$'    # $$...$$ (display math)
        r'|\\\[(.+?)\\\]'   # \[...\] (LaTeX display math)
        r'|\$(.+?)\$'       # $...$   (inline math)
        r'|\\\((.+?)\\\)',  # \(...\) (LaTeX inline math)
        re.DOTALL,
    )

    def __init__(self, d_model: int, num_heads: int = 8):
        """
        Initialize EquationModule.

        Args:
            d_model: Model dimension (must be divisible by ``num_heads``).
            num_heads: Number of heads for equation-specific attention.
        """
        super().__init__()
        self.d_model = d_model
        # Equation span detector: lightweight linear classifier producing a
        # per-token "is inside an equation" logit.
        self.span_detector = nn.Linear(d_model, 1)
        # Equation-specific transformer (shallow, 2 layers).
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=d_model * 4,
            activation=F.silu,
            batch_first=True,
            dropout=0.1,
        )
        self.equation_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
        # Merge equation representations back into the main stream.
        self.merge = nn.Linear(d_model * 2, d_model)
        # LaTeX structure awareness (positional encoding for tree depth).
        # NOTE(review): not yet consumed by forward(); kept for checkpoint
        # compatibility and future tree-aware processing.
        self.depth_embedding = nn.Embedding(10, d_model)  # Max depth 10
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize weights with a small normal; zero the biases."""
        for module in [self.span_detector, self.merge, self.depth_embedding]:
            if hasattr(module, 'weight'):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if hasattr(module, 'bias') and module.bias is not None:
                nn.init.zeros_(module.bias)

    def detect_equation_spans(
        self,
        text: str,
        token_ids: Optional[torch.Tensor] = None,
    ) -> List[Tuple[int, int]]:
        r"""
        Detect equation spans in text using delimiters.

        Supports: ``$...$``, ``$$...$$``, ``\[...\]``, ``\(...\)``.
        Display-math delimiters take precedence, so ``$$...$$`` is reported
        once rather than as overlapping inline matches (bug fix: the previous
        per-pattern scan returned duplicate/overlapping spans for ``$$``).

        Args:
            text: Input text string
            token_ids: Optional token IDs for alignment (currently unused)

        Returns:
            Sorted list of non-overlapping (start_char, end_char) spans,
            delimiters included.
        """
        return sorted(m.span() for m in self._EQ_PATTERN.finditer(text))

    def forward(
        self,
        x: torch.Tensor,
        text: Optional[List[str]] = None,
        token_spans: Optional[List[List[Tuple[int, int]]]] = None,
    ) -> torch.Tensor:
        """
        Forward pass through the equation module.

        Equation segments are re-encoded by the shallow transformer, fused
        with the original representation via a learned linear merge, and
        written back in place; tokens outside equations pass through
        unchanged.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            text: Optional original text strings (for delimiter-based detection)
            token_spans: Optional pre-computed token-level equation spans.
                Each element: list of (start_token, end_token) for that batch item.

        Returns:
            Equation-enhanced representation (batch, seq_len, d_model)
        """
        batch, seq_len, d_model = x.shape
        # Detect equation spans.
        if token_spans is None and text is not None:
            # Use delimiter-based detection (requires text).
            token_spans = []
            for b in range(batch):
                char_spans = self.detect_equation_spans(text[b])
                # Convert char spans to token spans.
                # NOTE(review): rough heuristic — assumes ~4 chars per token;
                # a real implementation needs tokenizer offset alignment.
                token_spans_b = []
                for start_char, end_char in char_spans:
                    start_token = max(0, start_char // 4)
                    end_token = min(seq_len, end_char // 4 + 1)
                    token_spans_b.append((start_token, end_token))
                token_spans.append(token_spans_b)
        elif token_spans is None:
            # Fallback: use the learned detector.
            token_spans = self._learned_span_detection(x)
        # Process each batch item; start from a copy so non-equation tokens
        # are returned untouched.
        output = x.clone()
        for b in range(batch):
            spans_b = token_spans[b] if b < len(token_spans) else []
            for start_tok, end_tok in spans_b:
                if end_tok <= start_tok:
                    continue  # degenerate/out-of-range span
                # Extract equation segment: (1, seg_len, d_model).
                eq_segment = x[b:b + 1, start_tok:end_tok, :]
                # Apply the equation-specific transformer.
                eq_encoded = self.equation_encoder(eq_segment)
                # Fuse original and equation-encoded views.
                merged = self.merge(torch.cat([eq_segment, eq_encoded], dim=-1))
                # Place back in output.
                output[b:b + 1, start_tok:end_tok, :] = merged
        return output

    def _learned_span_detection(
        self,
        x: torch.Tensor,
    ) -> List[List[Tuple[int, int]]]:
        """
        Use the learned detector to find equation spans when delimiters
        are missing. Simple thresholding on span_detector output.

        Args:
            x: Input tensor (batch, seq_len, d_model)

        Returns:
            List of token spans per batch item
        """
        batch, seq_len, _ = x.shape
        # Per-token equation probability: (batch, seq_len).
        eq_probs = torch.sigmoid(self.span_detector(x)).squeeze(-1)
        threshold = 0.5
        # Hard, non-differentiable decision; tolist() avoids a NumPy
        # round-trip and any device issues.
        flags = (eq_probs > threshold).cpu().tolist()
        spans = []
        for b in range(batch):
            is_equation = flags[b]
            # Scan for maximal contiguous runs of True.
            span_list = []
            in_span = False
            start = 0
            for t in range(seq_len):
                if is_equation[t] and not in_span:
                    start = t
                    in_span = True
                elif not is_equation[t] and in_span:
                    span_list.append((start, t))
                    in_span = False
            if in_span:
                # Span runs to the end of the sequence.
                span_list.append((start, seq_len))
            spans.append(span_list)
        return spans

    def compute_equation_loss(
        self,
        x: torch.Tensor,
        equation_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute auxiliary loss for equation detection training.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            equation_mask: Ground-truth equation mask (batch, seq_len),
                1 if the token is inside an equation.

        Returns:
            Binary cross-entropy loss (scalar) for equation detection.
        """
        logits = self.span_detector(x).squeeze(-1)  # (batch, seq_len)
        return F.binary_cross_entropy_with_logits(logits, equation_mask.float())
def test_equation_module():
    """Smoke-test EquationModule: output shape, span detection, auxiliary loss."""
    d_model = 512
    batch_size = 2
    seq_len = 128
    module = EquationModule(d_model)
    x = torch.randn(batch_size, seq_len, d_model)
    # Raw string for the LaTeX sample: "\[" is an invalid escape sequence in a
    # normal literal (SyntaxWarning on Python 3.12+, a future SyntaxError).
    # The runtime value is unchanged.
    text = [
        "The energy is $E = mc^2$ and momentum is $p = mv$.",
        r"Equation: \[ F = ma \] and also $a^2 + b^2 = c^2$.",
    ]
    output = module(x, text=text)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    assert output.shape == x.shape
    # Test equation loss.
    equation_mask = torch.zeros(batch_size, seq_len)
    equation_mask[0, 10:15] = 1.0  # Simulate equation span
    equation_mask[1, 5:12] = 1.0
    loss = module.compute_equation_loss(x, equation_mask)
    print(f"Equation loss: {loss.item():.4f}")
    print("EquationModule test passed!")
# Run the smoke test when executed as a script (no-op on import).
if __name__ == "__main__":
    test_equation_module()
|