"""
Loss functions for TouchGrass fine-tuning.
Includes standard LM loss and music-specific auxiliary losses.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple
class TouchGrassLoss(nn.Module):
"""
Combined loss for TouchGrass fine-tuning.
Components:
- LM loss (standard cross-entropy)
- EQ loss (frustration detection auxiliary)
- Music module losses (tab validation, theory accuracy, etc.)
"""
def __init__(self, config: Dict):
"""
Initialize loss.
Args:
config: Training config with loss_weights
"""
super().__init__()
self.loss_weights = config.get("loss_weights", {
"lm_loss": 1.0,
"eq_loss": 0.1,
"music_module_loss": 0.05,
})
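        # Note: the defaults above apply only when "loss_weights" is absent
        # entirely; a partial override (e.g. {"loss_weights": {"lm_loss": 1.0}})
        # would raise a KeyError in forward() for the missing component weights.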
def forward(
self,
logits: torch.Tensor,
labels: torch.Tensor,
eq_outputs: Optional[Dict[str, torch.Tensor]] = None,
eq_labels: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
music_module_outputs: Optional[Dict[str, torch.Tensor]] = None,
music_labels: Optional[Dict[str, torch.Tensor]] = None,
) -> Dict[str, torch.Tensor]:
"""
Compute total loss.
Args:
logits: Model logits [batch, seq_len, vocab_size]
labels: Target labels [batch, seq_len]
eq_outputs: EQ adapter outputs (frustration_score, emotion_logits, etc.)
eq_labels: (emotion_labels, frustration_labels)
music_module_outputs: Outputs from music modules
music_labels: Ground truth for music tasks
Returns:
Dictionary with total_loss and component losses
"""
losses = {}
# 1. Language modeling loss (always computed)
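        # Shift so logits at position t predict the token at position t + 1
        # (standard next-token objective); positions labeled -100 are ignored,
        # matching the usual padding/masking convention.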
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
lm_loss = F.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
ignore_index=-100,
)
losses["lm_loss"] = lm_loss
# 2. EQ loss (if available)
if eq_outputs is not None and eq_labels is not None:
emotion_labels, frustration_labels = eq_labels
eq_loss = self._compute_eq_loss(eq_outputs, emotion_labels, frustration_labels)
losses["eq_loss"] = eq_loss
        else:
            eq_loss = torch.tensor(0.0, device=logits.device)
            losses["eq_loss"] = eq_loss
# 3. Music module losses (if available)
if music_module_outputs is not None and music_labels is not None:
music_loss = self._compute_music_module_loss(music_module_outputs, music_labels)
losses["music_module_loss"] = music_loss
        else:
            music_loss = torch.tensor(0.0, device=logits.device)
            losses["music_module_loss"] = music_loss
# Total loss
total_loss = (
self.loss_weights["lm_loss"] * lm_loss +
self.loss_weights["eq_loss"] * eq_loss +
self.loss_weights["music_module_loss"] * music_loss
)
losses["total_loss"] = total_loss
return losses
def _compute_eq_loss(
self,
eq_outputs: Dict[str, torch.Tensor],
emotion_labels: torch.Tensor,
frustration_labels: torch.Tensor,
) -> torch.Tensor:
"""
Compute EQ auxiliary loss.
Args:
eq_outputs: Dictionary with emotion_logits, frustration_score
emotion_labels: Ground truth emotion classes [batch]
frustration_labels: Ground truth frustration (0/1) [batch]
Returns:
EQ loss
"""
# Emotion classification loss
emotion_logits = eq_outputs["emotion_logits"]
emotion_loss = F.cross_entropy(emotion_logits, emotion_labels)
# Frustration detection loss (binary)
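        # binary_cross_entropy expects probabilities in [0, 1]; the EQ adapter
        # is assumed to apply a sigmoid before returning frustration_score.
        # squeeze(-1) drops only the trailing dim, so batch size 1 survives.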
        frustration_score = eq_outputs["frustration_score"].squeeze(-1)
        frustration_loss = F.binary_cross_entropy(frustration_score, frustration_labels.float())
return emotion_loss + frustration_loss
def _compute_music_module_loss(
self,
music_outputs: Dict[str, torch.Tensor],
music_labels: Dict[str, torch.Tensor],
) -> torch.Tensor:
"""
Compute music module auxiliary losses.
Args:
music_outputs: Dictionary with outputs from various music modules
music_labels: Ground truth labels for music tasks
Returns:
Music module loss
"""
total_loss = 0.0
count = 0
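        # Each head below is optional: its loss is added only when both the
        # module output and the matching label are present, and the sum is
        # averaged over the heads that fired so the scale stays comparable.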
# Tab validation loss (if present)
if "tab_validity" in music_outputs and "tab_valid" in music_labels:
tab_loss = F.binary_cross_entropy(
music_outputs["tab_validity"].squeeze(),
music_labels["tab_valid"].float(),
)
total_loss += tab_loss
count += 1
# Difficulty classification loss
if "difficulty_logits" in music_outputs and "difficulty" in music_labels:
diff_loss = F.cross_entropy(
music_outputs["difficulty_logits"],
music_labels["difficulty"],
)
total_loss += diff_loss
count += 1
# Chord quality prediction
if "chord_quality_logits" in music_outputs and "chord_quality" in music_labels:
chord_loss = F.cross_entropy(
music_outputs["chord_quality_logits"],
music_labels["chord_quality"],
)
total_loss += chord_loss
count += 1
# Scale degree prediction
if "scale_degree_logits" in music_outputs and "scale_degree" in music_labels:
scale_loss = F.cross_entropy(
music_outputs["scale_degree_logits"],
music_labels["scale_degree"],
)
total_loss += scale_loss
count += 1
        if count > 0:
            return total_loss / count
        # No head fired: return a zero tensor so the result type is always a
        # Tensor (a 0-dim CPU tensor interoperates with CUDA operands).
        return torch.tensor(0.0)
def compute_lora_gradient_norm(model: nn.Module) -> float:
"""
    Compute the L2 norm of gradients over trainable parameters (under LoRA
    fine-tuning, only the adapter weights require grad).
    Useful for monitoring training stability.
"""
total_norm = 0.0
for p in model.parameters():
if p.requires_grad and p.grad is not None:
            param_norm = p.grad.detach().norm(2)
total_norm += param_norm.item() ** 2
return total_norm ** 0.5
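

# Typical use of compute_lora_gradient_norm (a sketch; assumes a surrounding
# training loop that has already called backward() on the total loss):
#
#     losses["total_loss"].backward()
#     grad_norm = compute_lora_gradient_norm(model)
#     if grad_norm > 10.0:  # threshold is illustrative, not from this repo
#         print(f"Warning: large gradient norm {grad_norm:.2f}")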
def get_parameter_groups(model: nn.Module, weight_decay: float = 0.1) -> List[Dict]:
"""
Get parameter groups for optimizer (LoRA-specific).
Apply weight decay only to LoRA weights, not biases/LayerNorm.
"""
# Separate parameters
no_decay = ["bias", "layer_norm", "layernorm", "ln"]
decay_params = []
no_decay_params = []
for name, param in model.named_parameters():
if not param.requires_grad:
continue
if any(nd in name.lower() for nd in no_decay):
no_decay_params.append(param)
else:
decay_params.append(param)
return [
{"params": decay_params, "weight_decay": weight_decay},
{"params": no_decay_params, "weight_decay": 0.0},
]
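

# Example wiring (a sketch): pass the groups straight to AdamW; the learning
# rate here is illustrative, not a value prescribed by this repo.
#
#     optimizer = torch.optim.AdamW(get_parameter_groups(model), lr=2e-4)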
def test_losses():
"""Test the loss functions."""
# Create loss
config = {
"loss_weights": {
"lm_loss": 1.0,
"eq_loss": 0.1,
"music_module_loss": 0.05,
}
}
loss_fn = TouchGrassLoss(config)
# Dummy inputs
batch_size = 2
seq_len = 10
vocab_size = 32000
    logits = torch.randn(batch_size, seq_len, vocab_size)  # full length; forward() does the shift
labels = torch.randint(0, vocab_size, (batch_size, seq_len))
# EQ outputs
eq_outputs = {
"emotion_logits": torch.randn(batch_size, 4),
"frustration_score": torch.rand(batch_size, 1),
}
emotion_labels = torch.randint(0, 4, (batch_size,))
frustration_labels = torch.randint(0, 2, (batch_size,))
# Compute loss
    losses = loss_fn(
logits=logits,
labels=labels,
eq_outputs=eq_outputs,
eq_labels=(emotion_labels, frustration_labels),
)
print("Loss components:")
for key, value in losses.items():
print(f" {key}: {value.item():.4f}")
# Test gradient norm
model = torch.nn.Linear(10, 10)
model.weight.grad = torch.randn_like(model.weight)
grad_norm = compute_lora_gradient_norm(model)
print(f"\nGradient norm: {grad_norm:.4f}")
print("\nLoss functions test complete!")
if __name__ == "__main__":
test_losses()