"""Tests for Music EQ Adapter (Emotional Intelligence)."""
|
|
|
| import pytest
|
| import torch
|
|
|
| from TouchGrass.models.eq_adapter import MusicEQAdapter
|
|
|
|
|
class TestMusicEQAdapter:
    """Test suite for MusicEQAdapter.

    Every test feeds random hidden states of shape
    ``(batch_size, seq_len, d_model)`` to a freshly constructed adapter and
    checks the keys, shapes or value ranges of the dict-like output.
    """

    # Emotion labels the classifier head is expected to cover. Only their
    # count is compared against the module.
    EXPECTED_EMOTIONS = ("frustrated", "confused", "excited", "confident")

    # Width of the optional context tensor passed as the second argument.
    CONTEXT_DIM = 5

    def setup_method(self):
        """Set up test fixtures: a new adapter and the shared dimensions."""
        self.d_model = 768
        self.batch_size = 4
        self.module = MusicEQAdapter(d_model=self.d_model)

    def _random_hidden(self, seq_len, batch_size=None, requires_grad=False):
        """Return random hidden states shaped (batch, seq_len, d_model).

        Args:
            seq_len: Length of the sequence dimension.
            batch_size: Batch dimension; defaults to ``self.batch_size``.
            requires_grad: Whether the tensor should track gradients.
        """
        if batch_size is None:
            batch_size = self.batch_size
        return torch.randn(
            batch_size, seq_len, self.d_model, requires_grad=requires_grad
        )

    def _random_context(self):
        """Return a random context tensor shaped (batch_size, CONTEXT_DIM)."""
        return torch.randn(self.batch_size, self.CONTEXT_DIM)

    def _context_with_first_value(self, value):
        """Return a (batch_size, CONTEXT_DIM) context, zero except column 0."""
        context = torch.zeros(1, self.CONTEXT_DIM)
        context[0, 0] = value
        return context.expand(self.batch_size, -1)

    def test_module_initialization(self):
        """Test that the expected submodules exist with the expected types."""
        assert isinstance(self.module.frustration_detector, torch.nn.Sequential)
        assert isinstance(self.module.emotion_classifier, torch.nn.Linear)
        assert isinstance(self.module.simplify_gate, torch.nn.Linear)
        assert isinstance(self.module.encouragement_embed, torch.nn.Embedding)
        assert isinstance(self.module.simplification_strategies, torch.nn.Embedding)

    def test_forward_pass(self):
        """Test forward pass with dummy inputs: output keys and shapes."""
        seq_len = 10
        output = self.module(self._random_hidden(seq_len))

        for key in ("frustration", "emotion", "encouragement", "simplification"):
            assert key in output

        assert output["frustration"].shape == (self.batch_size, seq_len, 1)
        assert output["emotion"].shape == (
            self.batch_size,
            seq_len,
            len(self.EXPECTED_EMOTIONS),
        )
        # Only the leading (batch, sequence) dimensions are pinned for the
        # embedding-style outputs.
        assert output["encouragement"].shape[0] == self.batch_size
        assert output["encouragement"].shape[1] == seq_len
        assert output["simplification"].shape[0] == self.batch_size
        assert output["simplification"].shape[1] == seq_len

    def test_frustration_detector_output_range(self):
        """Test that frustration detector outputs are in [0, 1]."""
        output = self.module(self._random_hidden(seq_len=5))
        frustration = output["frustration"]

        assert torch.all(frustration >= 0)
        assert torch.all(frustration <= 1)

    def test_emotion_classifier_output(self):
        """Test emotion classifier produces one value per expected class."""
        seq_len = 5
        output = self.module(self._random_hidden(seq_len))
        emotion_logits = output["emotion"]

        assert emotion_logits.shape == (
            self.batch_size,
            seq_len,
            len(self.EXPECTED_EMOTIONS),
        )

    def test_emotion_classes(self):
        """Test that the classifier width matches the expected emotion count."""
        assert self.module.emotion_classifier.out_features == len(
            self.EXPECTED_EMOTIONS
        )

    def test_simplify_gate_transformation(self):
        """Test that simplification keeps d_model as its last dimension."""
        hidden_states = self._random_hidden(seq_len=5)
        output = self.module(hidden_states, self._random_context())

        assert output["simplification"].shape[-1] == self.d_model

    def test_encouragement_templates(self):
        """Test that the encouragement embedding table is non-empty."""
        assert self.module.encouragement_embed.num_embeddings > 0
        assert self.module.encouragement_embed.embedding_dim > 0

    def test_simplification_strategies(self):
        """Test that the simplification strategy embedding table is non-empty."""
        assert self.module.simplification_strategies.num_embeddings > 0
        assert self.module.simplification_strategies.embedding_dim > 0

    def test_high_frustration_detection(self):
        """Test that single-step frustration scores stay within [0, 1]."""
        output = self.module(self._random_hidden(seq_len=1))
        frustration = output["frustration"]

        assert torch.all((frustration >= 0) & (frustration <= 1))

    def test_different_batch_sizes(self):
        """Test forward pass with different batch sizes."""
        seq_len = 10
        for batch_size in [1, 2, 8]:
            output = self.module(self._random_hidden(seq_len, batch_size=batch_size))

            assert output["frustration"].shape[0] == batch_size
            assert output["emotion"].shape[0] == batch_size

    def test_different_seq_lengths(self):
        """Test forward pass with different sequence lengths."""
        for seq_len in [1, 5, 20, 50]:
            output = self.module(self._random_hidden(seq_len))

            assert output["frustration"].shape[1] == seq_len
            assert output["emotion"].shape[1] == seq_len

    def test_gradient_flow(self):
        """Test that gradients reach both the input and the first detector layer."""
        hidden_states = self._random_hidden(seq_len=5, requires_grad=True)

        output = self.module(hidden_states)
        loss = output["frustration"].sum() + output["emotion"].sum()
        loss.backward()

        assert hidden_states.grad is not None
        assert self.module.frustration_detector[0].weight.grad is not None

    def test_emotion_softmax_normalization(self):
        """Test that softmax over the emotion output sums to 1 per position.

        The softmax is applied here in the test, so this checks that the
        emotion output can be normalised into a distribution over classes.
        """
        output = self.module(self._random_hidden(seq_len=1))
        emotion_probs = torch.softmax(output["emotion"], dim=-1)

        sums = emotion_probs.sum(dim=-1)
        assert torch.allclose(sums, torch.ones_like(sums), atol=1e-5)

    def test_frustration_sigmoid_normalization(self):
        """Test that frustration outputs are in [0, 1]."""
        output = self.module(self._random_hidden(seq_len=1))
        frustration = output["frustration"]

        assert torch.all((frustration >= 0) & (frustration <= 1))

    def test_simplify_gate_sigmoid(self):
        """Test that the simplification output has the hidden-state shape.

        Only the shape is asserted; the gate's activation is not inspected.
        """
        hidden_states = self._random_hidden(seq_len=1)
        output = self.module(hidden_states, self._random_context())

        assert output["simplification"].shape == hidden_states.shape

    def test_context_aware_simplification(self):
        """Test that the module accepts two different context tensors.

        Only the output shapes are compared; the simplification values are
        not required to differ between the two contexts.
        """
        hidden_states = self._random_hidden(seq_len=5)

        output1 = self.module(hidden_states, self._context_with_first_value(0.9))
        output2 = self.module(hidden_states, self._context_with_first_value(0.1))

        assert output1["simplification"].shape == output2["simplification"].shape

    def test_encouragement_output_range(self):
        """Test that encouragement output is (batch, seq_len, non-empty dim)."""
        seq_len = 5
        output = self.module(self._random_hidden(seq_len))
        encouragement = output["encouragement"]

        assert encouragement.shape[0] == self.batch_size
        assert encouragement.shape[1] == seq_len
        assert encouragement.shape[2] > 0

    def test_module_without_context(self):
        """Test module works when no context argument is passed."""
        output = self.module(self._random_hidden(seq_len=5))

        assert "frustration" in output
        assert "emotion" in output
|
|
|
|
|
if __name__ == "__main__":
    # Propagate pytest's exit code so that running this file as a script
    # reports test failures to the calling shell instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
|
|