| """
|
| Science-aware losses for Vortex model training.
|
| Combines standard language modeling with auxiliary tasks.
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from typing import Dict, Optional, Tuple
|
|
|
|
|
class VortexLoss(nn.Module):
    """
    Combined loss for Vortex model with science-aware components.

    total_loss = (
        lm_loss * 1.0
        + equation_loss * 0.3
        + domain_loss * 0.1
        + citation_loss * 0.1
        + numerical_loss * 0.2
    )
    """

    # Fallback weights used when the config omits "loss_weights".
    DEFAULT_LOSS_WEIGHTS = {
        "lm_loss": 1.0,
        "equation_loss": 0.3,
        "domain_loss": 0.1,
        "citation_loss": 0.1,
        "numerical_loss": 0.2,
    }

    def __init__(self, config: Dict):
        """
        Initialize loss.

        Args:
            config: Training config with optional "loss_weights" mapping
                component name -> scalar weight.
        """
        super().__init__()
        # Copy so later mutation of self.loss_weights never aliases the
        # caller's config dict (or the shared class-level default).
        self.loss_weights = dict(
            config.get("loss_weights", self.DEFAULT_LOSS_WEIGHTS)
        )

    def forward(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        equation_module: Optional[nn.Module] = None,
        equation_mask: Optional[torch.Tensor] = None,
        domain_logits: Optional[torch.Tensor] = None,
        domain_labels: Optional[torch.Tensor] = None,
        citation_module: Optional[nn.Module] = None,
        citation_mask: Optional[torch.Tensor] = None,
        citation_confidence: Optional[torch.Tensor] = None,
        numerical_module: Optional[nn.Module] = None,
        numerical_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute total loss.

        Args:
            logits: (batch, seq_len, vocab_size)
            labels: (batch, seq_len) with token IDs; -100 entries ignored
            equation_module: EquationModule for equation loss
            equation_mask: (batch, seq_len) 1 if token in equation
            domain_logits: (batch, num_domains)
            domain_labels: (batch,)
            citation_module: CitationModule for citation loss
            citation_mask: (batch, seq_len)
            citation_confidence: (batch, seq_len, 1)
            numerical_module: NumericalReasoningModule
            numerical_mask: (batch, seq_len)

        Returns:
            Dictionary with total loss and component losses.
        """

        def _zero() -> torch.Tensor:
            # Scalar zero on logits' device AND dtype. The previous
            # torch.tensor(0.0, device=...) was always float32, which could
            # silently upcast the total loss under fp16/bf16 training.
            return logits.new_zeros(())

        losses: Dict[str, torch.Tensor] = {}

        # Standard next-token LM loss; -100 labels (padding) are ignored.
        losses["lm_loss"] = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            labels.view(-1),
            ignore_index=-100,
        )

        # Equation loss: placeholder. The previous implementation returned a
        # zero tensor on BOTH branches, so equation_module / equation_mask
        # were accepted but never used; collapsed to a single zero here.
        # TODO: wire in equation_module once its loss API is defined.
        losses["equation_loss"] = _zero()

        # Domain classification loss over (batch, num_domains) logits.
        if domain_logits is not None and domain_labels is not None:
            losses["domain_loss"] = F.cross_entropy(domain_logits, domain_labels)
        else:
            losses["domain_loss"] = _zero()

        # Citation loss. NOTE(review): the first argument is a zero
        # placeholder for (batch, seq_len, 1) citation logits — presumably
        # real citation-head outputs should be threaded in here; confirm
        # against CitationModule.compute_citation_loss's contract.
        if (
            citation_module is not None
            and citation_mask is not None
            and citation_confidence is not None
        ):
            losses["citation_loss"] = citation_module.compute_citation_loss(
                torch.zeros_like(logits[:, :, :1]),
                citation_mask,
                citation_confidence,
            )
        else:
            losses["citation_loss"] = _zero()

        # Numerical-reasoning loss. Same zero-placeholder caveat as above.
        if numerical_module is not None and numerical_mask is not None:
            losses["numerical_loss"] = numerical_module.compute_numerical_loss(
                torch.zeros_like(logits),
                numerical_mask,
                None,
            )
        else:
            losses["numerical_loss"] = _zero()

        # Weighted sum of components; unknown names default to weight 1.0.
        total_loss = _zero()
        for name, loss in losses.items():
            total_loss = total_loss + loss * self.loss_weights.get(name, 1.0)

        losses["total_loss"] = total_loss

        return losses
|
|
|
|
|
def test_vortex_loss():
    """Smoke-test VortexLoss: build random inputs and print each component."""
    weights = {
        "lm_loss": 1.0,
        "equation_loss": 0.3,
        "domain_loss": 0.1,
        "citation_loss": 0.1,
        "numerical_loss": 0.2,
    }
    loss_fn = VortexLoss({"loss_weights": weights})

    # Small random batch shaped like (batch, seq_len) LM training data.
    batch_size, seq_len, vocab_size = 2, 128, 1000
    logits = torch.randn(batch_size, seq_len, vocab_size)
    labels = torch.randint(0, vocab_size, (batch_size, seq_len))

    losses = loss_fn(logits, labels)

    print("Losses:")
    for name, value in losses.items():
        print(f"  {name}: {value.item():.4f}")

    assert "total_loss" in losses
    print("VortexLoss test passed!")
|
|
|
|
|
# Entry point: run the smoke test when this module is executed directly.
if __name__ == "__main__":

    test_vortex_loss()
|
|
|