File size: 5,525 Bytes
5c43f61 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 | """
Science-aware losses for Vortex model training.
Combines standard language modeling with auxiliary tasks.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional, Tuple
class VortexLoss(nn.Module):
    """
    Combined loss for Vortex model with science-aware components.

    total_loss = (
        lm_loss * 1.0
        + equation_loss * 0.3
        + domain_loss * 0.1
        + citation_loss * 0.1
        + numerical_loss * 0.2
    )

    The weights above are the defaults. Any subset of them can be overridden
    through ``config["loss_weights"]``; components that are not overridden
    keep their default weight. A component whose inputs are not passed to
    ``forward`` contributes a scalar zero. ``equation_loss`` is currently a
    placeholder and is always zero.
    """

    # Default per-component weights; user-supplied weights are merged on top.
    DEFAULT_LOSS_WEIGHTS: Dict[str, float] = {
        "lm_loss": 1.0,
        "equation_loss": 0.3,
        "domain_loss": 0.1,
        "citation_loss": 0.1,
        "numerical_loss": 0.2,
    }

    def __init__(self, config: Dict):
        """
        Initialize loss.

        Args:
            config: Training config. Its optional ``"loss_weights"`` entry maps
                component names to weights. Components it does not mention
                keep the value from ``DEFAULT_LOSS_WEIGHTS`` (previously they
                silently fell back to 1.0).
        """
        super().__init__()
        user_weights = config.get("loss_weights") or {}
        # Build a fresh dict: defaults first, user overrides on top. This also
        # avoids aliasing the caller's config dict.
        self.loss_weights = {**self.DEFAULT_LOSS_WEIGHTS, **user_weights}

    @staticmethod
    def _zero(device: torch.device) -> torch.Tensor:
        """Return a scalar float32 zero on ``device`` for an inactive component."""
        return torch.tensor(0.0, device=device)

    @staticmethod
    def _lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Token-level cross-entropy averaged over positions whose label != -100.

        No shift is applied here: ``logits[b, t]`` is scored against
        ``labels[b, t]``. NOTE(review): for a next-token objective the caller
        must already have shifted ``labels`` -- confirm with the data pipeline.

        Unlike ``F.cross_entropy(..., reduction="mean")``, this returns 0
        (not NaN) when every label is ignored, e.g. a fully padded batch.
        """
        # reshape (not view) so non-contiguous inputs are accepted.
        flat_logits = logits.reshape(-1, logits.size(-1))
        flat_labels = labels.reshape(-1)
        loss_sum = F.cross_entropy(
            flat_logits,
            flat_labels,
            ignore_index=-100,  # ignore padding
            reduction="sum",
        )
        # clamp(min=1) avoids 0/0 -> NaN; computed on-device, so no host sync.
        num_valid = (flat_labels != -100).sum().clamp(min=1)
        return loss_sum / num_valid

    def forward(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        equation_module: Optional[nn.Module] = None,
        equation_mask: Optional[torch.Tensor] = None,
        domain_logits: Optional[torch.Tensor] = None,
        domain_labels: Optional[torch.Tensor] = None,
        citation_module: Optional[nn.Module] = None,
        citation_mask: Optional[torch.Tensor] = None,
        citation_confidence: Optional[torch.Tensor] = None,
        numerical_module: Optional[nn.Module] = None,
        numerical_mask: Optional[torch.Tensor] = None,
        hidden_states: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute total loss.

        Args:
            logits: (batch, seq_len, vocab_size)
            labels: (batch, seq_len) with token IDs; -100 marks ignored positions
            equation_module: accepted for API compatibility; currently unused
            equation_mask: accepted for API compatibility; currently unused
            domain_logits: (batch, num_domains)
            domain_labels: (batch,)
            citation_module: object exposing
                ``compute_citation_loss(hidden, mask, confidence)``
            citation_mask: (batch, seq_len)
            citation_confidence: (batch, seq_len, 1)
            numerical_module: object exposing
                ``compute_numerical_loss(hidden, mask, targets)``
            numerical_mask: (batch, seq_len)
            hidden_states: optional per-token hidden states passed as the first
                argument to the citation and numerical loss methods. When
                omitted, the legacy zero placeholders are used instead: a
                (batch, seq_len, 1) zeros tensor for citation and a zeros
                tensor shaped like ``logits`` for numerical. NOTE(review):
                those placeholders carry no signal.

        Returns:
            Dictionary with ``lm_loss``, ``equation_loss``, ``domain_loss``,
            ``citation_loss``, ``numerical_loss`` and their weighted sum
            ``total_loss``.
        """
        device = logits.device
        losses: Dict[str, torch.Tensor] = {}

        # 1. Language modeling loss.
        losses["lm_loss"] = self._lm_loss(logits, labels)

        # 2. Equation detection loss: placeholder. It needs hidden states from
        #    the equation module that the forward pass does not expose yet.
        losses["equation_loss"] = self._zero(device)

        # 3. Domain classification loss.
        if domain_logits is not None and domain_labels is not None:
            losses["domain_loss"] = F.cross_entropy(domain_logits, domain_labels)
        else:
            losses["domain_loss"] = self._zero(device)

        # 4. Citation detection loss.
        if (
            citation_module is not None
            and citation_mask is not None
            and citation_confidence is not None
        ):
            citation_hidden = (
                hidden_states
                if hidden_states is not None
                else torch.zeros_like(logits[:, :, :1])  # legacy placeholder
            )
            losses["citation_loss"] = citation_module.compute_citation_loss(
                citation_hidden,
                citation_mask,
                citation_confidence,
            )
        else:
            losses["citation_loss"] = self._zero(device)

        # 5. Numerical reasoning loss (no target values are available here).
        if numerical_module is not None and numerical_mask is not None:
            numerical_hidden = (
                hidden_states
                if hidden_states is not None
                else torch.zeros_like(logits)  # legacy placeholder
            )
            losses["numerical_loss"] = numerical_module.compute_numerical_loss(
                numerical_hidden,
                numerical_mask,
                None,  # target values
            )
        else:
            losses["numerical_loss"] = self._zero(device)

        # Weighted sum. Every component name has a weight after the merge in
        # __init__; the 1.0 fallback only matters if loss_weights was replaced.
        total_loss = self._zero(device)
        for name, loss in losses.items():
            total_loss = total_loss + loss * self.loss_weights.get(name, 1.0)
        losses["total_loss"] = total_loss
        return losses
def test_vortex_loss():
    """
    Smoke-test VortexLoss with only the language-modelling inputs.

    With no auxiliary inputs every auxiliary loss is zero. Since the LM weight
    in the config below is 1.0, the total must therefore equal the LM loss.
    """
    config = {"loss_weights": {
        "lm_loss": 1.0,
        "equation_loss": 0.3,
        "domain_loss": 0.1,
        "citation_loss": 0.1,
        "numerical_loss": 0.2,
    }}
    loss_fn = VortexLoss(config)
    torch.manual_seed(0)  # deterministic inputs so failures are reproducible
    batch_size = 2
    seq_len = 128
    vocab_size = 1000
    logits = torch.randn(batch_size, seq_len, vocab_size)
    labels = torch.randint(0, vocab_size, (batch_size, seq_len))
    losses = loss_fn(logits, labels)
    print("Losses:")
    for name, value in losses.items():
        print(f" {name}: {value.item():.4f}")
    assert "total_loss" in losses
    assert torch.isfinite(losses["total_loss"])
    for aux_name in ("equation_loss", "domain_loss", "citation_loss", "numerical_loss"):
        assert losses[aux_name].item() == 0.0, aux_name
    assert torch.allclose(losses["total_loss"], losses["lm_loss"])
    print("VortexLoss test passed!")


if __name__ == "__main__":
    test_vortex_loss()
|