| | """
|
| | NumericalReasoningModule: Handles scientific numerical reasoning.
|
| | Digit-level number encoding, scientific notation, unit awareness.
|
| | """
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | import re
|
| | from typing import Optional, Tuple, List
|
| |
|
| |
|
class NumericalReasoningModule(nn.Module):
    """
    Handles scientific numerical reasoning.

    - Digit-level number encoding (each digit gets position-aware embedding)
    - Scientific notation understanding (6.02 × 10²³)
    - Unit awareness (meters, joules, moles, kelvin)
    - Order of magnitude reasoning
    - Significant figures tracking
    """

    def __init__(
        self,
        d_model: int,
        max_digits: int = 20,
        num_units: int = 256,
    ):
        """
        Initialize NumericalReasoningModule.

        Args:
            d_model: Model dimension
            max_digits: Maximum number of digits to encode
            num_units: Number of unit types to embed (hash buckets)
        """
        super().__init__()
        self.d_model = d_model
        self.max_digits = max_digits

        # Per-digit (0-9) embedding; combined with a digit-position embedding
        # so "12" and "21" encode differently.
        self.digit_embed = nn.Embedding(10, 64)
        self.position_embed = nn.Embedding(max_digits, 64)

        # Projects the concatenated [digit_emb | pos_emb] (64 + 64 = 128)
        # into model space.
        self.number_proj = nn.Linear(128, d_model)

        # Hashed-bucket embedding for unit strings ("m", "kg", "K", ...).
        self.unit_embed = nn.Embedding(num_units, d_model)

        # Fusion layer for scientific-notation representations.
        # NOTE(review): currently unused in forward(); kept for interface
        # (state-dict) compatibility.
        self.sci_notation = nn.Linear(d_model * 2, d_model)

        # Exponents are clamped to [-10, 10] in forward() -> 21 buckets.
        self.magnitude_embed = nn.Embedding(21, d_model)

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize all sub-module weights (normal, std=0.02) and zero biases."""
        for module in [self.digit_embed, self.position_embed, self.number_proj,
                       self.unit_embed, self.sci_notation, self.magnitude_embed]:
            if hasattr(module, 'weight'):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if hasattr(module, 'bias') and module.bias is not None:
                nn.init.zeros_(module.bias)

    @staticmethod
    def _unit_bucket(unit_str: str, num_buckets: int) -> int:
        """
        Deterministic bucket index for a unit string.

        Built-in ``hash()`` on str is randomized per process (PYTHONHASHSEED),
        which would map the same unit to a different embedding on every run
        and silently break checkpoint consistency; use a fixed polynomial
        rolling hash instead.
        """
        h = 0
        for ch in unit_str:
            h = (h * 31 + ord(ch)) % num_buckets
        return h

    def encode_number(
        self,
        number_str: str,
        device: torch.device,
    ) -> torch.Tensor:
        """
        Encode a number string using digit-level encoding.

        Args:
            number_str: String representation of number (e.g., "123.45e-6")
            device: Torch device

        Returns:
            Number embedding (d_model,) — mean over digit positions.
        """
        # Only the digits matter here; signs, '.', 'e', '×' are dropped.
        digits = [int(d) for d in re.findall(r'\d', number_str)]
        if not digits:
            digits = [0]  # degenerate input still yields a valid embedding

        # Truncate or zero-pad to a fixed width of max_digits.
        if len(digits) > self.max_digits:
            digits = digits[:self.max_digits]
        else:
            digits = digits + [0] * (self.max_digits - len(digits))

        digits_tensor = torch.tensor(digits, dtype=torch.long, device=device)
        positions = torch.arange(self.max_digits, device=device)

        digit_emb = self.digit_embed(digits_tensor)      # (max_digits, 64)
        pos_emb = self.position_embed(positions)         # (max_digits, 64)

        combined = torch.cat([digit_emb, pos_emb], dim=-1)   # (max_digits, 128)
        number_emb = self.number_proj(combined)              # (max_digits, d_model)

        # Pool over digit positions to a single vector.
        return number_emb.mean(dim=0)

    def detect_numbers(
        self,
        text: str,
    ) -> List[Tuple[str, int, int, Optional[str]]]:
        """
        Detect numbers in text with optional units and scientific notation.

        Matches plain numbers ("42"), decimals ("123.45"), e-notation
        ("1.5e-6"), and "×10^n" notation, optionally followed by a unit
        token (letters, '°', '%').

        Returns:
            List of (number_str, start_char, end_char, unit_str)
        """
        pattern = r'(\d+(?:\.\d+)?(?:[eE][+-]?\d+)?(?:×10\^?[+-]?\d+)?)(?:\s*([a-zA-Z°%]+))?'

        matches = []
        for match in re.finditer(pattern, text):
            number_str = match.group(1)
            unit_str = match.group(2) if match.group(2) else None
            matches.append((number_str, match.start(), match.end(), unit_str))

        return matches

    def forward(
        self,
        x: torch.Tensor,
        text: Optional[List[str]] = None,
        number_positions: Optional[List[List[Tuple[int, int, str]]]] = None,
    ) -> torch.Tensor:
        """
        Forward pass through numerical reasoning module.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            text: Optional original text strings (used to auto-detect numbers
                when number_positions is not supplied)
            number_positions: Optional per-batch list of
                (start_token, end_token, number_str) or
                (start_token, end_token, number_str, unit_str) tuples

        Returns:
            Numerical-enhanced representation (batch, seq_len, d_model)
        """
        batch, seq_len, d_model = x.shape
        device = x.device

        # Auto-detect numbers from raw text when explicit positions are absent.
        if number_positions is None and text is not None:
            number_positions = []
            for b in range(batch):
                numbers = self.detect_numbers(text[b])

                token_nums = []
                for num_str, start_char, end_char, unit_str in numbers:
                    # Rough char->token heuristic (~4 chars per token);
                    # NOTE(review): confirm against the actual tokenizer.
                    start_tok = max(0, start_char // 4)
                    end_tok = min(seq_len, end_char // 4 + 1)
                    token_nums.append((start_tok, end_tok, num_str, unit_str))
                number_positions.append(token_nums)

        output = x.clone()

        if number_positions:
            for b in range(batch):
                nums_b = number_positions[b] if b < len(number_positions) else []

                for entry in nums_b:
                    # The documented external format is a 3-tuple (no unit);
                    # internally-detected entries carry a 4th unit element.
                    if len(entry) == 4:
                        start_tok, end_tok, num_str, unit_str = entry
                    else:
                        start_tok, end_tok, num_str = entry
                        unit_str = None

                    if end_tok <= start_tok or start_tok >= seq_len:
                        continue

                    start_tok = min(start_tok, seq_len - 1)
                    end_tok = min(end_tok, seq_len)

                    number_emb = self.encode_number(num_str, device)

                    if unit_str:
                        # Deterministic bucket (see _unit_bucket): built-in
                        # hash() is randomized per process.
                        unit_id = self._unit_bucket(unit_str, self.unit_embed.num_embeddings)
                        unit_emb = self.unit_embed(torch.tensor(unit_id, device=device))
                        number_emb = number_emb + unit_emb

                    # Add an order-of-magnitude embedding for scientific notation.
                    if 'e' in num_str.lower() or '×10' in num_str:
                        exp_match = re.search(r'[eE]([+-]?\d+)|×10\^?([+-]?\d+)', num_str)
                        if exp_match:
                            exp = int(exp_match.group(1) or exp_match.group(2))
                            exp = max(-10, min(10, exp))  # clamp to embedding range
                            magnitude_emb = self.magnitude_embed(torch.tensor(exp + 10, device=device))
                            number_emb = number_emb + magnitude_emb

                    # Inject the number representation at the number's first token.
                    output[b, start_tok, :] += number_emb

        return output

    def compute_numerical_loss(
        self,
        x: torch.Tensor,
        number_mask: torch.Tensor,
        target_values: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute auxiliary loss for numerical reasoning.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            number_mask: Mask for number tokens (batch, seq_len)
            target_values: Target numeric values (batch, seq_len) or None

        Returns:
            Scalar loss tensor on x's device/dtype. Currently a zero
            placeholder — no value-prediction head exists yet.
        """
        # TODO: implement a real value-prediction head. Return a zero scalar
        # tensor (not a Python float, which contradicted the annotated
        # torch.Tensor return type) so callers can safely add it to a total
        # loss and call .backward().
        return x.new_zeros(())
|
| |
|
| |
|
def test_numerical_module():
    """Smoke-test NumericalReasoningModule: forward pass preserves shape."""
    dim, n_batch, n_tokens = 512, 2, 128

    reasoner = NumericalReasoningModule(dim)

    hidden = torch.randn(n_batch, n_tokens, dim)
    samples = [
        "The speed of light is 2.998×10^8 m/s and Planck's constant is 6.626×10^-34 J·s.",
        "Calculate: 123.45 + 67.89 = ? The answer is 191.34.",
    ]

    enhanced = reasoner(hidden, text=samples)
    print(f"Input shape: {hidden.shape}")
    print(f"Output shape: {enhanced.shape}")
    assert enhanced.shape == hidden.shape

    print("NumericalReasoningModule test passed!")
|
| |
|
| |
|
| | if __name__ == "__main__":
|
| | test_numerical_module()
|
| |
|