| | """
|
| | Tests for Music EQ Adapter (Emotional Intelligence).
|
| | """
|
| |
|
| | import pytest
|
| | import torch
|
| |
|
| | from TouchGrass.models.eq_adapter import MusicEQAdapter
|
| |
|
| |
|
class TestMusicEQAdapter:
    """Test suite for MusicEQAdapter.

    ``setup_method`` seeds the global torch RNG before building the adapter,
    so both the module's initial weights and every ``torch.randn`` input used
    below are reproducible from run to run.
    """

    # Number of emotion classes the adapter is expected to score:
    # frustrated, confused, excited, confident.
    NUM_EMOTIONS = 4
    # Width of the context tensor these tests pass as forward()'s second argument.
    CONTEXT_DIM = 5

    def setup_method(self):
        """Seed the RNG and build a fresh adapter for every test."""
        # Seed first so parameter initialisation is deterministic too,
        # not just the random test inputs.
        torch.manual_seed(0)
        self.d_model = 768
        self.batch_size = 4
        self.module = MusicEQAdapter(d_model=self.d_model)

    def _hidden_states(self, seq_len, batch_size=None, **kwargs):
        """Return random hidden states of shape (batch_size, seq_len, d_model).

        ``batch_size`` defaults to ``self.batch_size``; extra keyword
        arguments (e.g. ``requires_grad=True``) are forwarded to ``torch.randn``.
        """
        if batch_size is None:
            batch_size = self.batch_size
        return torch.randn(batch_size, seq_len, self.d_model, **kwargs)

    def test_module_initialization(self):
        """Test that the adapter exposes the expected sub-modules."""
        assert isinstance(self.module.frustration_detector, torch.nn.Sequential)
        assert isinstance(self.module.emotion_classifier, torch.nn.Linear)
        assert isinstance(self.module.simplify_gate, torch.nn.Linear)
        assert isinstance(self.module.encouragement_embed, torch.nn.Embedding)
        assert isinstance(self.module.simplification_strategies, torch.nn.Embedding)

    def test_forward_pass(self):
        """Test forward pass with dummy inputs and no context."""
        seq_len = 10
        output = self.module(self._hidden_states(seq_len))

        for key in ("frustration", "emotion", "encouragement", "simplification"):
            assert key in output
        assert output["frustration"].shape == (self.batch_size, seq_len, 1)
        assert output["emotion"].shape == (self.batch_size, seq_len, self.NUM_EMOTIONS)
        assert output["encouragement"].shape[0] == self.batch_size
        assert output["encouragement"].shape[1] == seq_len
        assert output["simplification"].shape[0] == self.batch_size
        assert output["simplification"].shape[1] == seq_len

    def test_frustration_detector_output_range(self):
        """Test that frustration detector outputs are in [0, 1]."""
        frustration = self.module(self._hidden_states(5))["frustration"]

        assert torch.all(frustration >= 0)
        assert torch.all(frustration <= 1)

    def test_emotion_classifier_output(self):
        """Test that the emotion classifier produces one logit per emotion class."""
        seq_len = 5
        emotion_logits = self.module(self._hidden_states(seq_len))["emotion"]

        assert emotion_logits.shape == (self.batch_size, seq_len, self.NUM_EMOTIONS)

    def test_emotion_classes(self):
        """Test that the classifier has exactly one output per expected emotion."""
        expected_emotions = ["frustrated", "confused", "excited", "confident"]

        # Keep the class-level constant honest with the documented label list.
        assert len(expected_emotions) == self.NUM_EMOTIONS
        assert self.module.emotion_classifier.out_features == len(expected_emotions)

    def test_simplify_gate_transformation(self):
        """Test that the simplification output is d_model wide when context is given."""
        hidden_states = self._hidden_states(5)
        context = torch.randn(self.batch_size, self.CONTEXT_DIM)

        simplification = self.module(hidden_states, context)["simplification"]

        assert simplification.shape[-1] == self.d_model

    def test_encouragement_templates(self):
        """Test that the encouragement embedding table is non-empty."""
        assert self.module.encouragement_embed.num_embeddings > 0
        assert self.module.encouragement_embed.embedding_dim > 0

    def test_simplification_strategies(self):
        """Test that the simplification-strategy embedding table is non-empty."""
        assert self.module.simplification_strategies.num_embeddings > 0
        assert self.module.simplification_strategies.embedding_dim > 0

    def test_high_frustration_detection(self):
        """Test that frustration scores stay finite and in [0, 1] for extreme inputs.

        Large-magnitude hidden states push the detector towards saturation;
        the scores must still be finite, valid probabilities there.
        """
        hidden_states = 100.0 * self._hidden_states(1)

        frustration = self.module(hidden_states)["frustration"]

        assert torch.isfinite(frustration).all()
        assert torch.all((frustration >= 0) & (frustration <= 1))

    def test_different_batch_sizes(self):
        """Test forward pass with different batch sizes."""
        seq_len = 10
        for batch_size in (1, 2, 8):
            output = self.module(self._hidden_states(seq_len, batch_size=batch_size))
            assert output["frustration"].shape[0] == batch_size
            assert output["emotion"].shape[0] == batch_size

    def test_different_seq_lengths(self):
        """Test forward pass with different sequence lengths."""
        for seq_len in (1, 5, 20, 50):
            output = self.module(self._hidden_states(seq_len))
            assert output["frustration"].shape[1] == seq_len
            assert output["emotion"].shape[1] == seq_len

    def test_gradient_flow(self):
        """Test that gradients flow back to the input and the detector weights."""
        hidden_states = self._hidden_states(5, requires_grad=True)

        output = self.module(hidden_states)
        loss = output["frustration"].sum() + output["emotion"].sum()
        loss.backward()

        assert hidden_states.grad is not None
        # Assumes the detector's first layer has a ``weight`` (e.g. nn.Linear).
        assert self.module.frustration_detector[0].weight.grad is not None

    def test_emotion_softmax_normalization(self):
        """Test that the emotion outputs form a valid distribution under softmax.

        ``output["emotion"]`` is treated as raw logits, so the test applies the
        softmax itself. Checking finiteness first gives a clear failure for
        NaN/inf logits instead of a confusing sum mismatch.
        """
        logits = self.module(self._hidden_states(1))["emotion"]

        assert torch.isfinite(logits).all()
        sums = torch.softmax(logits, dim=-1).sum(dim=-1)
        assert torch.allclose(sums, torch.ones_like(sums), atol=1e-5)

    def test_frustration_sigmoid_normalization(self):
        """Test that frustration outputs are in [0, 1] (sigmoid-like range)."""
        frustration = self.module(self._hidden_states(1))["frustration"]

        assert torch.all((frustration >= 0) & (frustration <= 1))

    def test_simplify_gate_sigmoid(self):
        """Test that, with a context, simplification matches the hidden-state shape.

        Only the output shape is observable here; the gate's activation itself
        is not inspected by this test.
        """
        hidden_states = self._hidden_states(1)
        context = torch.randn(self.batch_size, self.CONTEXT_DIM)

        output = self.module(hidden_states, context)

        assert output["simplification"].shape == hidden_states.shape

    def test_context_aware_simplification(self):
        """Test that changing the context changes the simplification output."""
        # Disable train-time stochasticity (e.g. dropout, if the adapter has
        # any) so a difference between the two runs can only come from context.
        self.module.eval()
        hidden_states = self._hidden_states(5)

        # Two contexts that differ only in their first feature.
        context_high = torch.zeros(self.batch_size, self.CONTEXT_DIM)
        context_high[:, 0] = 0.9
        context_low = torch.zeros(self.batch_size, self.CONTEXT_DIM)
        context_low[:, 0] = 0.1

        with torch.no_grad():
            simplification_high = self.module(hidden_states, context_high)["simplification"]
            simplification_low = self.module(hidden_states, context_low)["simplification"]

        assert simplification_high.shape == simplification_low.shape
        simplification_diff = (simplification_high - simplification_low).abs().mean()
        assert torch.isfinite(simplification_diff)
        assert simplification_diff > 0

    def test_encouragement_output_range(self):
        """Test that encouragement outputs are non-empty per-token vectors."""
        seq_len = 5
        encouragement = self.module(self._hidden_states(seq_len))["encouragement"]

        assert encouragement.shape[0] == self.batch_size
        assert encouragement.shape[1] == seq_len
        assert encouragement.shape[2] > 0

    def test_module_without_context(self):
        """Test that forward() works when the context argument is omitted."""
        output = self.module(self._hidden_states(5))

        assert "frustration" in output
        assert "emotion" in output
|
if __name__ == "__main__":
    # pytest.main returns the session's exit code rather than exiting, so
    # propagate it; otherwise a failing run would still exit with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
| |
|