"""
NumericalReasoningModule: Handles scientific numerical reasoning.

Digit-level number encoding, scientific notation, unit awareness.
"""

import re
import zlib
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

# A number span accepted by ``NumericalReasoningModule.forward``:
# (start_token, end_token, number_str) or
# (start_token, end_token, number_str, unit_str).
NumberSpan = Union[Tuple[int, int, str], Tuple[int, int, str, Optional[str]]]

# Unicode superscript digits/signs are neither matched by ``\d`` nor parseable
# by ``int``, so exponents written as e.g. "10²³" are translated to ASCII first.
_SUPERSCRIPT_TO_ASCII = str.maketrans("⁰¹²³⁴⁵⁶⁷⁸⁹⁺⁻", "0123456789+-")

# Number with optional decimal part, optional exponent (1.2e-4, 2.998×10^8,
# 6.02 × 10²³) and an optional trailing alphabetic unit (m, J, mol, °, %).
_NUMBER_PATTERN = re.compile(
    r"(\d+(?:\.\d+)?"
    r"(?:[eE][+-]?\d+)?"
    r"(?:\s*×\s*10(?:\^?[+-]?\d+|[⁺⁻]?[⁰¹²³⁴⁵⁶⁷⁸⁹]+))?)"
    r"(?:\s*([a-zA-Z°%]+))?"
)

# Exponent in e-notation, ×10^n notation, or ×10ⁿ (superscript) notation.
_EXPONENT_PATTERN = re.compile(
    r"[eE]([+-]?\d+)"
    r"|×\s*10\^?([+-]?\d+)"
    r"|×\s*10([⁺⁻]?[⁰¹²³⁴⁵⁶⁷⁸⁹]+)"
)

# Splits a number string at the first exponent marker ("e"/"E" or "×").
_MANTISSA_SPLIT = re.compile(r"[eE]|\s*×")

_DIGIT_PATTERN = re.compile(r"\d")


class NumericalReasoningModule(nn.Module):
    """
    Handles scientific numerical reasoning.

    - Digit-level number encoding: each mantissa digit gets a digit embedding
      plus an embedding of its index counted from the most significant digit.
    - Scientific notation (1.2e-4, 2.998×10^8, 6.02×10²³): the exponent is
      added via a magnitude embedding, saturating at 10^±10.
    - Unit awareness: a trailing unit string (m, J, mol, ...) is mapped to a
      unit embedding via a deterministic hash.
    - Trailing zeros in the mantissa change the encoding ("1.2" vs "1.20"),
      but there is no explicit significant-figure tracking yet.
    """

    # Heuristic used to map character offsets in raw text to token indices
    # when only ``text`` is given to ``forward``. It is not tied to any real
    # tokenizer; pass ``number_positions`` for exact token alignment.
    APPROX_CHARS_PER_TOKEN = 4

    def __init__(
        self,
        d_model: int,
        max_digits: int = 20,
        num_units: int = 256,
    ):
        """
        Initialize NumericalReasoningModule.

        Args:
            d_model: Model dimension
            max_digits: Maximum number of mantissa digits encoded per number
                (extra digits are truncated)
            num_units: Number of unit embedding rows (unit strings are hashed
                into this many buckets)
        """
        super().__init__()
        self.d_model = d_model
        self.max_digits = max_digits

        # Digit embeddings (0-9)
        self.digit_embed = nn.Embedding(10, 64)

        # Digit-index embeddings (index 0 = most significant digit)
        self.position_embed = nn.Embedding(max_digits, 64)

        # Project digit+position to model dimension
        self.number_proj = nn.Linear(128, d_model)

        # Unit embedding (SI units + common scientific units)
        self.unit_embed = nn.Embedding(num_units, d_model)

        # Scientific notation handler (currently not used in forward)
        self.sci_notation = nn.Linear(d_model * 2, d_model)

        # Magnitude embedding (powers of 10: -10 to +10)
        self.magnitude_embed = nn.Embedding(21, d_model)

        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize weights with N(0, 0.02) and zero biases."""
        for module in [
            self.digit_embed,
            self.position_embed,
            self.number_proj,
            self.unit_embed,
            self.sci_notation,
            self.magnitude_embed,
        ]:
            if hasattr(module, "weight"):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if hasattr(module, "bias") and module.bias is not None:
                nn.init.zeros_(module.bias)

    @staticmethod
    def _mantissa(number_str: str) -> str:
        """Return the part of ``number_str`` before any exponent marker."""
        return _MANTISSA_SPLIT.split(number_str, maxsplit=1)[0]

    @staticmethod
    def _parse_exponent(number_str: str) -> Optional[int]:
        """
        Return the base-10 exponent written in ``number_str``, or None.

        Understands "1.2e-4", "2.998×10^8", "6.626×10^-34" and superscript
        forms such as "6.02×10²³" / "6.02×10⁻²³".
        """
        match = _EXPONENT_PATTERN.search(number_str)
        if match is None:
            return None
        raw = next(group for group in match.groups() if group is not None)
        return int(raw.translate(_SUPERSCRIPT_TO_ASCII))

    def _unit_id(self, unit_str: str) -> int:
        """
        Map a unit string to a unit-embedding row.

        Uses CRC32 rather than ``hash()``: Python's string hash is salted per
        process, which would assign a different embedding to the same unit in
        every run.
        """
        return zlib.crc32(unit_str.encode("utf-8")) % self.unit_embed.num_embeddings

    def encode_number(
        self,
        number_str: str,
        device: torch.device,
    ) -> torch.Tensor:
        """
        Encode a number string using digit-level encoding.

        Only the mantissa digits are encoded (decimal point and sign are
        ignored; the exponent is handled separately by the magnitude
        embedding). Pooling runs over the real digits only, so "12" and
        "120" produce different embeddings.

        Args:
            number_str: String representation of number (e.g., "123.45e-6")
            device: Torch device

        Returns:
            Number embedding (d_model,)
        """
        mantissa = self._mantissa(number_str)
        digits = [int(d) for d in _DIGIT_PATTERN.findall(mantissa)]
        # Truncate to the embedding capacity; a digit-less string encodes as "0".
        digits = digits[: self.max_digits] or [0]

        digits_tensor = torch.tensor(digits, dtype=torch.long, device=device)
        positions = torch.arange(len(digits), device=device)

        # (num_digits, 128) -> (num_digits, d_model)
        combined = torch.cat(
            [self.digit_embed(digits_tensor), self.position_embed(positions)],
            dim=-1,
        )
        number_emb = self.number_proj(combined)

        # Mean pool over the real digits (no padding involved)
        return number_emb.mean(dim=0)  # (d_model,)

    def detect_numbers(
        self,
        text: str,
    ) -> List[Tuple[str, int, int, Optional[str]]]:
        """
        Detect numbers in text with optional units and scientific notation.

        Matches e.g. 123, 123.45, 1.23e-4, 2.998×10^8, 6.02×10²³, 100 m,
        5.0 J. Any alphabetic word directly following a number is treated
        as its unit (e.g. "5 apples" -> unit "apples").

        Returns:
            List of (number_str, start_char, end_char, unit_str); end_char
            covers the unit when one is present.
        """
        return [
            (match.group(1), match.start(), match.end(), match.group(2) or None)
            for match in _NUMBER_PATTERN.finditer(text)
        ]

    def _text_to_spans(self, text: str, seq_len: int) -> List[NumberSpan]:
        """Detect numbers in ``text`` and map them to approximate token spans."""
        spans = []
        for num_str, start_char, end_char, unit_str in self.detect_numbers(text):
            start_tok = max(0, start_char // self.APPROX_CHARS_PER_TOKEN)
            end_tok = min(seq_len, end_char // self.APPROX_CHARS_PER_TOKEN + 1)
            spans.append((start_tok, end_tok, num_str, unit_str))
        return spans

    def _embed_number_span(
        self,
        num_str: str,
        unit_str: Optional[str],
        device: torch.device,
    ) -> torch.Tensor:
        """Combine digit, unit and magnitude embeddings for one number (d_model,)."""
        number_emb = self.encode_number(num_str, device)

        if unit_str:
            unit_id = torch.tensor(self._unit_id(unit_str), device=device)
            number_emb = number_emb + self.unit_embed(unit_id)

        exponent = self._parse_exponent(num_str)
        if exponent is not None:
            # Exponents outside the embedding range saturate at the ends.
            max_exp = (self.magnitude_embed.num_embeddings - 1) // 2
            exponent = max(-max_exp, min(max_exp, exponent))
            magnitude_id = torch.tensor(exponent + max_exp, device=device)
            number_emb = number_emb + self.magnitude_embed(magnitude_id)

        return number_emb

    def forward(
        self,
        x: torch.Tensor,
        text: Optional[List[str]] = None,
        number_positions: Optional[List[List[NumberSpan]]] = None,
    ) -> torch.Tensor:
        """
        Forward pass through numerical reasoning module.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            text: Optional original text strings, one per batch element; used
                only when ``number_positions`` is None (token positions are
                then approximated from character offsets)
            number_positions: Optional per-batch list of spans, each
                (start_token, end_token, number_str) or
                (start_token, end_token, number_str, unit_str)

        Returns:
            Numerical-enhanced representation (batch, seq_len, d_model). The
            number embedding is added to the first token of each span; spans
            are clipped to [0, seq_len) and empty spans are skipped.
        """
        batch, seq_len, _ = x.shape
        device = x.device

        # Detect numbers if text provided
        if number_positions is None and text is not None:
            number_positions = [self._text_to_spans(t, seq_len) for t in text[:batch]]

        output = x.clone()
        if not number_positions:
            return output

        for b, spans in enumerate(number_positions[:batch]):
            for span in spans:
                start_tok, end_tok, num_str = span[0], span[1], span[2]
                unit_str = span[3] if len(span) > 3 else None

                # Clip to sequence bounds; negative starts would otherwise
                # wrap around to the end of the sequence.
                start_tok = max(start_tok, 0)
                end_tok = min(end_tok, seq_len)
                if end_tok <= start_tok:
                    continue

                output[b, start_tok, :] += self._embed_number_span(num_str, unit_str, device)

        return output

    def compute_numerical_loss(
        self,
        x: torch.Tensor,
        number_mask: torch.Tensor,
        target_values: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        Compute auxiliary loss for numerical reasoning.

        Placeholder: there is no value-prediction head yet, so this always
        returns a zero scalar tensor (same dtype/device as ``x``) that can be
        added to the main loss.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            number_mask: Mask for number tokens (batch, seq_len); unused
            target_values: Target numeric values (batch, seq_len) or None; unused

        Returns:
            0-dim zero tensor
        """
        return x.new_zeros(())


def test_numerical_module():
    """Test NumericalReasoningModule."""
    d_model = 512
    batch_size = 2
    seq_len = 128

    module = NumericalReasoningModule(d_model)

    x = torch.randn(batch_size, seq_len, d_model)
    text = [
        "The speed of light is 2.998×10^8 m/s and Planck's constant is 6.626×10^-34 J·s.",
        "Calculate: 123.45 + 67.89 = ? The answer is 191.34.",
    ]

    output = module(x, text=text)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    assert output.shape == x.shape

    print("NumericalReasoningModule test passed!")


if __name__ == "__main__":
    test_numerical_module()