""" Tests for Music EQ Adapter (Emotional Intelligence). """ import pytest import torch from TouchGrass.models.eq_adapter import MusicEQAdapter class TestMusicEQAdapter: """Test suite for MusicEQAdapter.""" def setup_method(self): """Set up test fixtures.""" self.d_model = 768 self.batch_size = 4 self.module = MusicEQAdapter(d_model=self.d_model) def test_module_initialization(self): """Test that module initializes correctly.""" assert isinstance(self.module.frustration_detector, torch.nn.Sequential) assert isinstance(self.module.emotion_classifier, torch.nn.Linear) assert isinstance(self.module.simplify_gate, torch.nn.Linear) assert isinstance(self.module.encouragement_embed, torch.nn.Embedding) assert isinstance(self.module.simplification_strategies, torch.nn.Embedding) def test_forward_pass(self): """Test forward pass with dummy inputs.""" seq_len = 10 hidden_states = torch.randn(self.batch_size, seq_len, self.d_model) output = self.module(hidden_states) assert "frustration" in output assert "emotion" in output assert "encouragement" in output assert "simplification" in output assert output["frustration"].shape == (self.batch_size, seq_len, 1) assert output["emotion"].shape == (self.batch_size, seq_len, 4) # 4 emotion classes assert output["encouragement"].shape[0] == self.batch_size assert output["encouragement"].shape[1] == seq_len assert output["simplification"].shape[0] == self.batch_size assert output["simplification"].shape[1] == seq_len def test_frustration_detector_output_range(self): """Test that frustration detector outputs are in [0, 1].""" seq_len = 5 hidden_states = torch.randn(self.batch_size, seq_len, self.d_model) output = self.module(hidden_states) frustration = output["frustration"] assert torch.all(frustration >= 0) assert torch.all(frustration <= 1) def test_emotion_classifier_output(self): """Test emotion classifier produces logits for 4 classes.""" seq_len = 5 hidden_states = torch.randn(self.batch_size, seq_len, self.d_model) output = self.module(hidden_states) emotion_logits = output["emotion"] assert emotion_logits.shape == (self.batch_size, seq_len, 4) def test_emotion_classes(self): """Test that emotion classes match expected emotions.""" expected_emotions = ["frustrated", "confused", "excited", "confident"] # Check that the linear layer has correct output size assert self.module.emotion_classifier.out_features == len(expected_emotions) def test_simplify_gate_transformation(self): """Test that simplify gate transforms context correctly.""" seq_len = 5 hidden_states = torch.randn(self.batch_size, seq_len, self.d_model) context = torch.randn(self.batch_size, 5) # [frustration, difficulty, ...] output = self.module(hidden_states, context) simplification = output["simplification"] # Simplified output should have same d_model assert simplification.shape[-1] == self.d_model def test_encouragement_templates(self): """Test that encouragement templates are embedded.""" # The module should have embedding for encouragement tokens assert self.module.encouragement_embed.num_embeddings > 0 assert self.module.encouragement_embed.embedding_dim > 0 def test_simplification_strategies(self): """Test that simplification strategies are embedded.""" assert self.module.simplification_strategies.num_embeddings > 0 assert self.module.simplification_strategies.embedding_dim > 0 def test_high_frustration_detection(self): """Test detection of high frustration levels.""" seq_len = 1 hidden_states = torch.randn(self.batch_size, seq_len, self.d_model) output = self.module(hidden_states) frustration = output["frustration"] # Frustration should be some value between 0 and 1 assert torch.all((frustration >= 0) & (frustration <= 1)) def test_different_batch_sizes(self): """Test forward pass with different batch sizes.""" for batch_size in [1, 2, 8]: seq_len = 10 hidden_states = torch.randn(batch_size, seq_len, self.d_model) output = self.module(hidden_states) assert output["frustration"].shape[0] == batch_size assert output["emotion"].shape[0] == batch_size def test_different_seq_lengths(self): """Test forward pass with different sequence lengths.""" for seq_len in [1, 5, 20, 50]: hidden_states = torch.randn(self.batch_size, seq_len, self.d_model) output = self.module(hidden_states) assert output["frustration"].shape[1] == seq_len assert output["emotion"].shape[1] == seq_len def test_gradient_flow(self): """Test that gradients flow through the module.""" seq_len = 5 hidden_states = torch.randn(self.batch_size, seq_len, self.d_model, requires_grad=True) output = self.module(hidden_states) loss = output["frustration"].sum() + output["emotion"].sum() loss.backward() assert hidden_states.grad is not None assert self.module.frustration_detector[0].weight.grad is not None def test_emotion_softmax_normalization(self): """Test that emotion outputs sum to 1 across classes (if softmax applied).""" seq_len = 1 hidden_states = torch.randn(self.batch_size, seq_len, self.d_model) output = self.module(hidden_states) emotion_probs = torch.softmax(output["emotion"], dim=-1) # Sum across emotion dimension should be close to 1 sums = emotion_probs.sum(dim=-1) assert torch.allclose(sums, torch.ones_like(sums), atol=1e-5) def test_frustration_sigmoid_normalization(self): """Test that frustration outputs are in [0, 1] (sigmoid).""" seq_len = 1 hidden_states = torch.randn(self.batch_size, seq_len, self.d_model) output = self.module(hidden_states) frustration = output["frustration"] assert torch.all((frustration >= 0) & (frustration <= 1)) def test_simplify_gate_sigmoid(self): """Test that simplify gate uses sigmoid activation.""" seq_len = 1 hidden_states = torch.randn(self.batch_size, seq_len, self.d_model) context = torch.randn(self.batch_size, 5) output = self.module(hidden_states, context) # The simplification output should be transformed hidden states # We just verify the shape is correct assert output["simplification"].shape == hidden_states.shape def test_context_aware_simplification(self): """Test that simplification is context-aware.""" seq_len = 5 hidden_states = torch.randn(self.batch_size, seq_len, self.d_model) # Two different contexts context1 = torch.tensor([[0.9, 0.0, 0.0, 0.0, 0.0]]).expand(self.batch_size, -1) # High frustration context2 = torch.tensor([[0.1, 0.0, 0.0, 0.0, 0.0]]).expand(self.batch_size, -1) # Low frustration output1 = self.module(hidden_states, context1) output2 = self.module(hidden_states, context2) # Simplifications should differ based on frustration level # (not necessarily in all components, but the outputs should be different) simplification_diff = (output1["simplification"] - output2["simplification"]).abs().mean() # There should be some difference (we can't guarantee large difference without training) # but at least the computation should be different assert output1["simplification"].shape == output2["simplification"].shape def test_encouragement_output_range(self): """Test that encouragement outputs are valid embeddings.""" seq_len = 5 hidden_states = torch.randn(self.batch_size, seq_len, self.d_model) output = self.module(hidden_states) encouragement = output["encouragement"] # Should be some embedding vectors (we can't check exact values) assert encouragement.shape[0] == self.batch_size assert encouragement.shape[1] == seq_len assert encouragement.shape[2] > 0 def test_module_without_context(self): """Test module works without explicit context (uses default).""" seq_len = 5 hidden_states = torch.randn(self.batch_size, seq_len, self.d_model) # Should work with context=None (default) output = self.module(hidden_states) assert "frustration" in output assert "emotion" in output if __name__ == "__main__": pytest.main([__file__, "-v"])