# Vortex-13b-V1/models/science_modules/equation_module.py
# Provenance: uploaded by Zandy-Wandy, commit 5c43f61 ("Upload Vortex model")
"""
EquationModule: Specialized processing for mathematical equations and LaTeX.
Detects equation spans, applies equation-specific attention, and learns
structural representations of mathematical expressions.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import re
from typing import Optional, Tuple, List
class EquationModule(nn.Module):
    r"""
    Specialized processing for mathematical equations and LaTeX.

    - Detects equation spans in input (between ``$ $`` or ``\[ \]`` delimiters)
    - Applies equation-specific attention patterns within equation spans
    - Learns structural representations of mathematical expressions
    - Tree-aware: understands operator precedence and nesting
    """

    # Single combined delimiter pattern, compiled once at class definition.
    # Display forms ($$...$$ and \[...\]) are listed BEFORE inline forms
    # ($...$ and \(...\)) so the alternation consumes a whole display block
    # as one span. The previous implementation ran four independent scans,
    # which reported a $$...$$ block two or three times with overlapping
    # spans and let the inline $-pattern bite into the $$ delimiters.
    _DELIMITER_PATTERN = re.compile(
        r'\$\$(.+?)\$\$|\\\[(.+?)\\\]|\$(.+?)\$|\\\((.+?)\\\)',
        re.DOTALL,
    )

    def __init__(self, d_model: int, num_heads: int = 8):
        """
        Initialize EquationModule.

        Args:
            d_model: Model dimension
            num_heads: Number of heads for equation-specific attention
        """
        super().__init__()
        self.d_model = d_model
        # Equation span detector (lightweight linear classifier, one logit
        # per token; trained via compute_equation_loss).
        self.span_detector = nn.Linear(d_model, 1)
        # Equation-specific transformer (shallow, 2 layers).
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=d_model * 4,
            activation=F.silu,
            batch_first=True,
            dropout=0.1,
        )
        self.equation_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
        # Merge equation representations back into main stream
        # (concat of [original, encoded] -> d_model).
        self.merge = nn.Linear(d_model * 2, d_model)
        # LaTeX structure awareness (simple positional encoding for tree depth).
        # NOTE(review): initialized but not consumed anywhere in this class yet;
        # kept so checkpoints and the public attribute surface stay compatible.
        self.depth_embedding = nn.Embedding(10, d_model)  # Max depth 10
        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize linear/embedding weights (N(0, 0.02), zero bias)."""
        for module in [self.span_detector, self.merge, self.depth_embedding]:
            if hasattr(module, 'weight'):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if hasattr(module, 'bias') and module.bias is not None:
                nn.init.zeros_(module.bias)

    def detect_equation_spans(
        self,
        text: str,
        token_ids: Optional[torch.Tensor] = None,
    ) -> List[Tuple[int, int]]:
        r"""
        Detect equation spans in text using delimiters.

        Supports: ``$...$``, ``$$...$$``, ``\[...\]``, ``\(...\)``.
        Display delimiters take precedence over inline ones, so a
        ``$$...$$`` block yields exactly one span instead of the
        overlapping duplicates the per-pattern scan used to produce.

        Args:
            text: Input text string
            token_ids: Optional token IDs for alignment (currently unused)

        Returns:
            List of (start_char, end_char) spans, non-overlapping and in
            positional order.
        """
        return [
            (match.start(), match.end())
            for match in self._DELIMITER_PATTERN.finditer(text)
        ]

    def forward(
        self,
        x: torch.Tensor,
        text: Optional[List[str]] = None,
        token_spans: Optional[List[List[Tuple[int, int]]]] = None,
    ) -> torch.Tensor:
        """
        Forward pass through the equation module.

        Span source priority: explicit ``token_spans`` > delimiter detection
        on ``text`` > learned per-token detector.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            text: Optional original text strings (for delimiter-based detection)
            token_spans: Optional pre-computed token-level equation spans.
                Each element: list of (start_token, end_token) for that batch item

        Returns:
            Equation-enhanced representation (batch, seq_len, d_model);
            tokens outside every span are passed through unchanged.
        """
        batch, seq_len, d_model = x.shape
        # Detect equation spans
        if token_spans is None and text is not None:
            # Use delimiter-based detection (requires text).
            token_spans = []
            for b in range(batch):
                # Tolerate a text list shorter than the batch.
                text_b = text[b] if b < len(text) else ""
                char_spans = self.detect_equation_spans(text_b)
                # Convert char spans to token spans (simplified - assumes
                # ~4 chars per token on average). In practice this would
                # need proper tokenizer offset alignment.
                token_spans_b = []
                for start_char, end_char in char_spans:
                    start_token = max(0, start_char // 4)
                    end_token = min(seq_len, end_char // 4 + 1)
                    token_spans_b.append((start_token, end_token))
                token_spans.append(token_spans_b)
        elif token_spans is None:
            # No text either: fall back to the learned detector.
            token_spans = self._learned_span_detection(x)
        # Process each batch item; clone so untouched tokens pass through.
        output = x.clone()
        for b in range(batch):
            spans_b = token_spans[b] if b < len(token_spans) else []
            for start_tok, end_tok in spans_b:
                if end_tok <= start_tok:
                    continue  # degenerate/clipped span
                # Extract equation segment: (1, seg_len, d_model)
                eq_segment = x[b:b+1, start_tok:end_tok, :]
                # Apply equation-specific transformer
                eq_encoded = self.equation_encoder(eq_segment)
                # Merge encoded segment with the original and write back.
                merged = self.merge(torch.cat([eq_segment, eq_encoded], dim=-1))
                output[b:b+1, start_tok:end_tok, :] = merged
        return output

    def _learned_span_detection(
        self,
        x: torch.Tensor,
    ) -> List[List[Tuple[int, int]]]:
        """
        Use learned detector to find equation spans when delimiters missing.

        Simple thresholding on span_detector output. Uses ``detach()`` so
        this works on tensors that require grad (the previous
        ``.cpu().numpy()`` raised a RuntimeError during training).

        Args:
            x: Input tensor (batch, seq_len, d_model)

        Returns:
            List of token spans per batch item
        """
        batch, seq_len, _ = x.shape
        # Compute equation probability per token: (batch, seq_len)
        eq_probs = torch.sigmoid(self.span_detector(x)).squeeze(-1)
        # Threshold
        threshold = 0.5
        # detach() before host transfer: span detection is a hard decision,
        # so no gradient should (or can) flow through it.
        flags = (eq_probs > threshold).detach().cpu().tolist()
        spans = []
        for b in range(batch):
            is_equation = flags[b]
            # Find contiguous runs of above-threshold tokens.
            span_list = []
            in_span = False
            start = 0
            for t in range(seq_len):
                if is_equation[t] and not in_span:
                    start = t
                    in_span = True
                elif not is_equation[t] and in_span:
                    span_list.append((start, t))
                    in_span = False
            if in_span:
                # Span runs to the end of the sequence.
                span_list.append((start, seq_len))
            spans.append(span_list)
        return spans

    def compute_equation_loss(
        self,
        x: torch.Tensor,
        equation_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute auxiliary loss for equation detection training.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            equation_mask: Ground truth equation mask (batch, seq_len),
                1 if token is in equation

        Returns:
            Binary cross-entropy loss (scalar) for equation detection
        """
        logits = self.span_detector(x).squeeze(-1)  # (batch, seq_len)
        loss = F.binary_cross_entropy_with_logits(
            logits,
            equation_mask.float(),
        )
        return loss
def test_equation_module():
    """Smoke-test EquationModule: shape preservation and auxiliary loss."""
    d_model = 512
    batch_size = 2
    seq_len = 128
    module = EquationModule(d_model)
    x = torch.randn(batch_size, seq_len, d_model)
    # Raw string keeps the LaTeX backslashes literal: "\[" in a normal
    # string literal is an invalid escape sequence (SyntaxWarning on
    # Python 3.12+, slated to become a SyntaxError).
    text = [
        "The energy is $E = mc^2$ and momentum is $p = mv$.",
        r"Equation: \[ F = ma \] and also $a^2 + b^2 = c^2$."
    ]
    output = module(x, text=text)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    # Equation processing must not change the sequence geometry.
    assert output.shape == x.shape
    # Exercise the auxiliary detection loss on a synthetic ground-truth mask.
    equation_mask = torch.zeros(batch_size, seq_len)
    equation_mask[0, 10:15] = 1.0  # Simulate equation span
    equation_mask[1, 5:12] = 1.0
    loss = module.compute_equation_loss(x, equation_mask)
    print(f"Equation loss: {loss.item():.4f}")
    print("EquationModule test passed!")


if __name__ == "__main__":
    test_equation_module()