""" Science-aware losses for Vortex model training. Combines standard language modeling with auxiliary tasks. """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Dict, Optional, Tuple class VortexLoss(nn.Module): """ Combined loss for Vortex model with science-aware components. total_loss = ( lm_loss * 1.0 + equation_loss * 0.3 + domain_loss * 0.1 + citation_loss * 0.1 + numerical_loss * 0.2 ) """ def __init__(self, config: Dict): """ Initialize loss. Args: config: Training config with loss_weights """ super().__init__() self.loss_weights = config.get("loss_weights", { "lm_loss": 1.0, "equation_loss": 0.3, "domain_loss": 0.1, "citation_loss": 0.1, "numerical_loss": 0.2, }) def forward( self, logits: torch.Tensor, labels: torch.Tensor, equation_module: Optional[nn.Module] = None, equation_mask: Optional[torch.Tensor] = None, domain_logits: Optional[torch.Tensor] = None, domain_labels: Optional[torch.Tensor] = None, citation_module: Optional[nn.Module] = None, citation_mask: Optional[torch.Tensor] = None, citation_confidence: Optional[torch.Tensor] = None, numerical_module: Optional[nn.Module] = None, numerical_mask: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: """ Compute total loss. Args: logits: (batch, seq_len, vocab_size) labels: (batch, seq_len) with token IDs equation_module: EquationModule for equation loss equation_mask: (batch, seq_len) 1 if token in equation domain_logits: (batch, num_domains) domain_labels: (batch,) citation_module: CitationModule for citation loss citation_mask: (batch, seq_len) citation_confidence: (batch, seq_len, 1) numerical_module: NumericalReasoningModule numerical_mask: (batch, seq_len) Returns: Dictionary with total loss and component losses """ losses = {} # 1. Language modeling loss (next token prediction) lm_loss = F.cross_entropy( logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-100, # ignore padding ) losses["lm_loss"] = lm_loss # 2. Equation detection loss if equation_module is not None and equation_mask is not None: # Need hidden states from equation module - would need to modify forward pass # For now, placeholder equation_loss = torch.tensor(0.0, device=logits.device) losses["equation_loss"] = equation_loss else: losses["equation_loss"] = torch.tensor(0.0, device=logits.device) # 3. Domain classification loss if domain_logits is not None and domain_labels is not None: domain_loss = F.cross_entropy(domain_logits, domain_labels) losses["domain_loss"] = domain_loss else: losses["domain_loss"] = torch.tensor(0.0, device=logits.device) # 4. Citation detection loss if citation_module is not None and citation_mask is not None and citation_confidence is not None: citation_loss = citation_module.compute_citation_loss( # Would need hidden states - placeholder torch.zeros_like(logits[:, :, :1]), # dummy citation_mask, citation_confidence, ) losses["citation_loss"] = citation_loss else: losses["citation_loss"] = torch.tensor(0.0, device=logits.device) # 5. Numerical reasoning loss if numerical_module is not None and numerical_mask is not None: numerical_loss = numerical_module.compute_numerical_loss( torch.zeros_like(logits), # dummy hidden states numerical_mask, None, # target values ) losses["numerical_loss"] = numerical_loss else: losses["numerical_loss"] = torch.tensor(0.0, device=logits.device) # Weighted sum total_loss = torch.tensor(0.0, device=logits.device) for name, loss in losses.items(): weight = self.loss_weights.get(name, 1.0) total_loss = total_loss + loss * weight losses["total_loss"] = total_loss return losses def test_vortex_loss(): """Test the loss function.""" config = {"loss_weights": { "lm_loss": 1.0, "equation_loss": 0.3, "domain_loss": 0.1, "citation_loss": 0.1, "numerical_loss": 0.2, }} loss_fn = VortexLoss(config) batch_size = 2 seq_len = 128 vocab_size = 1000 logits = torch.randn(batch_size, seq_len, vocab_size) labels = torch.randint(0, vocab_size, (batch_size, seq_len)) losses = loss_fn(logits, labels) print("Losses:") for name, value in losses.items(): print(f" {name}: {value.item():.4f}") assert "total_loss" in losses print("VortexLoss test passed!") if __name__ == "__main__": test_vortex_loss()