# Vortex-7b-V1 / training/losses.py
# Uploaded by Zandy-Wandy (commit bf64b03, verified)
"""
Science-aware losses for Vortex model training.
Combines standard language modeling with auxiliary tasks.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional, Tuple
class VortexLoss(nn.Module):
    """
    Composite training loss for the Vortex model.

    Combines the primary language-modeling cross-entropy with several
    auxiliary science-aware terms, each scaled by a configurable weight:

        total_loss = (
            lm_loss * 1.0
            + equation_loss * 0.3
            + domain_loss * 0.1
            + citation_loss * 0.1
            + numerical_loss * 0.2
        )

    Auxiliary terms whose inputs are not supplied contribute a zero scalar,
    so the loss degrades gracefully to plain LM training.
    """

    def __init__(self, config: Dict):
        """
        Initialize loss.

        Args:
            config: Training config; may contain a "loss_weights" dict
                mapping component names to float weights. Components
                absent from the dict fall back to a weight of 1.0.
        """
        super().__init__()
        default_weights = {
            "lm_loss": 1.0,
            "equation_loss": 0.3,
            "domain_loss": 0.1,
            "citation_loss": 0.1,
            "numerical_loss": 0.2,
        }
        self.loss_weights = config.get("loss_weights", default_weights)

    def forward(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        equation_module: Optional[nn.Module] = None,
        equation_mask: Optional[torch.Tensor] = None,
        domain_logits: Optional[torch.Tensor] = None,
        domain_labels: Optional[torch.Tensor] = None,
        citation_module: Optional[nn.Module] = None,
        citation_mask: Optional[torch.Tensor] = None,
        citation_confidence: Optional[torch.Tensor] = None,
        numerical_module: Optional[nn.Module] = None,
        numerical_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute the weighted total loss and its components.

        Args:
            logits: (batch, seq_len, vocab_size) model output scores
            labels: (batch, seq_len) target token IDs; -100 entries ignored
            equation_module: EquationModule for equation loss (unwired placeholder)
            equation_mask: (batch, seq_len) 1 if token is inside an equation
            domain_logits: (batch, num_domains) domain classifier scores
            domain_labels: (batch,) domain class indices
            citation_module: CitationModule providing compute_citation_loss
            citation_mask: (batch, seq_len) citation-token mask
            citation_confidence: (batch, seq_len, 1) per-token confidence
            numerical_module: NumericalReasoningModule providing compute_numerical_loss
            numerical_mask: (batch, seq_len) numeric-token mask

        Returns:
            Dict with "total_loss" plus each individual component loss.
        """
        device = logits.device

        def zero() -> torch.Tensor:
            # Scalar placeholder for components that are disabled/unwired.
            return torch.tensor(0.0, device=device)

        # 1. Language modeling cross-entropy.
        # NOTE(review): labels are used as-is (no shift here) — assumes the
        # caller pre-shifts labels for next-token prediction; confirm upstream.
        losses: Dict[str, torch.Tensor] = {
            "lm_loss": F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.view(-1),
                ignore_index=-100,  # ignore padding
            )
        }

        # 2. Equation detection loss — placeholder: the original wiring needs
        # hidden states from the equation module, so both branches yield zero.
        losses["equation_loss"] = zero()

        # 3. Domain classification loss.
        if domain_logits is not None and domain_labels is not None:
            losses["domain_loss"] = F.cross_entropy(domain_logits, domain_labels)
        else:
            losses["domain_loss"] = zero()

        # 4. Citation detection loss (dummy hidden states until wired up).
        have_citation = (
            citation_module is not None
            and citation_mask is not None
            and citation_confidence is not None
        )
        if have_citation:
            losses["citation_loss"] = citation_module.compute_citation_loss(
                torch.zeros_like(logits[:, :, :1]),  # dummy
                citation_mask,
                citation_confidence,
            )
        else:
            losses["citation_loss"] = zero()

        # 5. Numerical reasoning loss (dummy hidden states until wired up).
        if numerical_module is not None and numerical_mask is not None:
            losses["numerical_loss"] = numerical_module.compute_numerical_loss(
                torch.zeros_like(logits),  # dummy hidden states
                numerical_mask,
                None,  # target values
            )
        else:
            losses["numerical_loss"] = zero()

        # Weighted sum over all components; unknown names default to 1.0.
        total = zero()
        for name, value in losses.items():
            total = total + value * self.loss_weights.get(name, 1.0)
        losses["total_loss"] = total
        return losses
def test_vortex_loss():
    """Smoke-test VortexLoss on random data: runs forward and checks output keys."""
    weights = {
        "lm_loss": 1.0,
        "equation_loss": 0.3,
        "domain_loss": 0.1,
        "citation_loss": 0.1,
        "numerical_loss": 0.2,
    }
    loss_fn = VortexLoss({"loss_weights": weights})

    # Random batch: only the LM path is exercised (aux modules omitted).
    batch_size, seq_len, vocab_size = 2, 128, 1000
    logits = torch.randn(batch_size, seq_len, vocab_size)
    labels = torch.randint(0, vocab_size, (batch_size, seq_len))

    losses = loss_fn(logits, labels)

    print("Losses:")
    for name, value in losses.items():
        print(f"  {name}: {value.item():.4f}")

    assert "total_loss" in losses
    print("VortexLoss test passed!")
# Run the smoke test when this module is executed directly.
if __name__ == "__main__":
    test_vortex_loss()