| """
|
| Loss functions for TouchGrass fine-tuning.
|
| Includes standard LM loss and music-specific auxiliary losses.
|
| """
|
|
|
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
|
|
|
class TouchGrassLoss(nn.Module):
    """
    Combined loss for TouchGrass fine-tuning.

    Components:
    - LM loss (standard next-token cross-entropy)
    - EQ loss (frustration detection auxiliary)
    - Music module losses (tab validation, theory accuracy, etc.)
    """

    # Fallback weights used for any key missing from the training config.
    _DEFAULT_WEIGHTS = {
        "lm_loss": 1.0,
        "eq_loss": 0.1,
        "music_module_loss": 0.05,
    }

    def __init__(self, config: Dict):
        """
        Initialize loss.

        Args:
            config: Training config; may contain a (possibly partial)
                "loss_weights" dict with keys "lm_loss", "eq_loss",
                "music_module_loss".
        """
        super().__init__()
        # Merge user-supplied weights over the defaults so a partial
        # loss_weights dict cannot cause a KeyError in forward().
        self.loss_weights = {
            **self._DEFAULT_WEIGHTS,
            **config.get("loss_weights", {}),
        }

    def forward(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        eq_outputs: Optional[Dict[str, torch.Tensor]] = None,
        eq_labels: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        music_module_outputs: Optional[Dict[str, torch.Tensor]] = None,
        music_labels: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute total loss.

        Args:
            logits: Model logits [batch, seq_len, vocab_size]
            labels: Target labels [batch, seq_len]; positions set to -100
                are ignored by the LM loss.
            eq_outputs: EQ adapter outputs (frustration_score, emotion_logits)
            eq_labels: (emotion_labels, frustration_labels)
            music_module_outputs: Outputs from music modules
            music_labels: Ground truth for music tasks

        Returns:
            Dictionary with total_loss and component losses; every value is
            a tensor (zero tensors for absent components).
        """
        losses = {}

        # Standard causal-LM loss: shift so position t predicts token t+1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        lm_loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
        )
        losses["lm_loss"] = lm_loss

        # EQ auxiliary loss (zero tensor when labels/outputs are absent so
        # the returned dict is uniformly tensor-valued).
        if eq_outputs is not None and eq_labels is not None:
            emotion_labels, frustration_labels = eq_labels
            eq_loss = self._compute_eq_loss(eq_outputs, emotion_labels, frustration_labels)
        else:
            eq_loss = torch.zeros((), device=logits.device)
        losses["eq_loss"] = eq_loss

        # Music module auxiliary loss.
        if music_module_outputs is not None and music_labels is not None:
            music_loss = self._compute_music_module_loss(music_module_outputs, music_labels)
            if not torch.is_tensor(music_loss):
                # No recognized task heads matched the labels; keep a tensor
                # so callers can safely call .item() on the dict entry.
                music_loss = torch.zeros((), device=logits.device)
        else:
            music_loss = torch.zeros((), device=logits.device)
        losses["music_module_loss"] = music_loss

        # Weighted sum of all components.
        total_loss = (
            self.loss_weights["lm_loss"] * lm_loss
            + self.loss_weights["eq_loss"] * eq_loss
            + self.loss_weights["music_module_loss"] * music_loss
        )
        losses["total_loss"] = total_loss

        return losses

    def _compute_eq_loss(
        self,
        eq_outputs: Dict[str, torch.Tensor],
        emotion_labels: torch.Tensor,
        frustration_labels: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute EQ auxiliary loss (emotion CE + frustration BCE).

        Args:
            eq_outputs: Dictionary with emotion_logits [batch, n_emotions]
                and frustration_score [batch, 1] (probabilities in [0, 1],
                since binary_cross_entropy expects probabilities — TODO
                confirm the adapter applies a sigmoid).
            emotion_labels: Ground truth emotion classes [batch]
            frustration_labels: Ground truth frustration (0/1) [batch]

        Returns:
            EQ loss (scalar tensor)
        """
        emotion_logits = eq_outputs["emotion_logits"]
        emotion_loss = F.cross_entropy(emotion_logits, emotion_labels)

        # squeeze(-1) only drops the trailing singleton dim; a bare
        # squeeze() would also collapse a batch of 1 to a 0-d tensor and
        # make binary_cross_entropy fail on the shape mismatch.
        frustration_score = eq_outputs["frustration_score"].squeeze(-1)
        frustration_loss = F.binary_cross_entropy(frustration_score, frustration_labels.float())

        return emotion_loss + frustration_loss

    def _compute_music_module_loss(
        self,
        music_outputs: Dict[str, torch.Tensor],
        music_labels: Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """
        Compute music module auxiliary losses, averaged over the tasks
        present in both the outputs and the labels.

        Args:
            music_outputs: Dictionary with outputs from various music modules
            music_labels: Ground truth labels for music tasks

        Returns:
            Mean music module loss (scalar tensor), or float 0.0 when no
            task matched (forward() converts that to a zero tensor).
        """
        total_loss = 0.0
        count = 0

        # Tab playability head: binary validity score vs. 0/1 label.
        if "tab_validity" in music_outputs and "tab_valid" in music_labels:
            tab_loss = F.binary_cross_entropy(
                music_outputs["tab_validity"].squeeze(-1),
                music_labels["tab_valid"].float(),
            )
            total_loss += tab_loss
            count += 1

        # Difficulty classification head.
        if "difficulty_logits" in music_outputs and "difficulty" in music_labels:
            diff_loss = F.cross_entropy(
                music_outputs["difficulty_logits"],
                music_labels["difficulty"],
            )
            total_loss += diff_loss
            count += 1

        # Chord quality classification head.
        if "chord_quality_logits" in music_outputs and "chord_quality" in music_labels:
            chord_loss = F.cross_entropy(
                music_outputs["chord_quality_logits"],
                music_labels["chord_quality"],
            )
            total_loss += chord_loss
            count += 1

        # Scale degree classification head.
        if "scale_degree_logits" in music_outputs and "scale_degree" in music_labels:
            scale_loss = F.cross_entropy(
                music_outputs["scale_degree_logits"],
                music_labels["scale_degree"],
            )
            total_loss += scale_loss
            count += 1

        # Average so adding more heads doesn't inflate the auxiliary weight.
        if count > 0:
            total_loss = total_loss / count

        return total_loss
|
|
|
|
|
def compute_lora_gradient_norm(model: nn.Module) -> float:
    """
    Compute the global L2 norm of gradients over all trainable parameters.

    Useful for monitoring training stability. With a LoRA setup where base
    weights are frozen, the trainable parameters are the LoRA weights, so
    this measures exactly the LoRA gradient norm.

    Args:
        model: Model whose gradients to measure (after backward()).

    Returns:
        sqrt(sum of squared per-parameter gradient L2 norms); 0.0 if no
        trainable parameter has a gradient.
    """
    total_sq = 0.0
    for param in model.parameters():
        if param.requires_grad and param.grad is not None:
            # detach() already yields a non-tracking view; the legacy
            # .data accessor the original chained after it was redundant.
            total_sq += param.grad.detach().norm(2).item() ** 2
    return total_sq ** 0.5
|
|
|
|
|
def get_parameter_groups(model: nn.Module, weight_decay: float = 0.1) -> List[Dict]:
    """
    Get optimizer parameter groups with selective weight decay.

    Weight decay is applied only to regular (matrix) weights; biases and
    normalization parameters — matched by name fragment — get zero decay,
    per common transformer fine-tuning practice. Frozen parameters are
    excluded entirely.

    Note: the ``List`` annotation requires ``List`` to be imported from
    ``typing`` (it was missing from the original import block).

    Args:
        model: Model whose trainable parameters are grouped.
        weight_decay: Decay coefficient for the decayed group.

    Returns:
        Two optimizer-style groups: [decayed params, non-decayed params].
    """
    # Name fragments that mark parameters excluded from weight decay.
    no_decay_markers = ("bias", "layer_norm", "layernorm", "ln")
    decay_params = []
    no_decay_params = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen (non-LoRA) weights are not optimized

        if any(marker in name.lower() for marker in no_decay_markers):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    return [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
|
|
|
|
|
def test_losses():
    """Smoke-test the loss functions with random inputs and print results."""
    config = {
        "loss_weights": {
            "lm_loss": 1.0,
            "eq_loss": 0.1,
            "music_module_loss": 0.05,
        }
    }
    loss_fn = TouchGrassLoss(config)

    batch_size = 2
    seq_len = 10
    vocab_size = 32000

    # Logits must cover the full sequence: forward() performs the causal
    # shift itself. (The original built logits with seq_len - 1 steps,
    # which made the shifted logits and labels disagree in length and
    # crashed cross_entropy.)
    logits = torch.randn(batch_size, seq_len, vocab_size)
    labels = torch.randint(0, vocab_size, (batch_size, seq_len))

    # Fake EQ adapter outputs: 4 emotion classes, sigmoid-like scores.
    eq_outputs = {
        "emotion_logits": torch.randn(batch_size, 4),
        "frustration_score": torch.rand(batch_size, 1),
    }
    emotion_labels = torch.randint(0, 4, (batch_size,))
    frustration_labels = torch.randint(0, 2, (batch_size,))

    # Call via __call__ so nn.Module hooks run, as in real training code.
    losses = loss_fn(
        logits=logits,
        labels=labels,
        eq_outputs=eq_outputs,
        eq_labels=(emotion_labels, frustration_labels),
    )

    print("Loss components:")
    for key, value in losses.items():
        print(f"  {key}: {value.item():.4f}")

    # Gradient-norm helper on a toy model with a synthetic gradient.
    model = torch.nn.Linear(10, 10)
    model.weight.grad = torch.randn_like(model.weight)
    grad_norm = compute_lora_gradient_norm(model)
    print(f"\nGradient norm: {grad_norm:.4f}")

    print("\nLoss functions test complete!")
|
|
|
|
|
# Run the smoke test when this module is executed as a script.
if __name__ == "__main__":

    test_losses()