# TouchGrass-7b / tests / test_eq_adapter.py
"""
Tests for Music EQ Adapter (Emotional Intelligence).
"""
import pytest
import torch
from TouchGrass.models.eq_adapter import MusicEQAdapter
class TestMusicEQAdapter:
    """Test suite for MusicEQAdapter (emotional-intelligence adapter).

    Each test builds a fresh adapter via ``setup_method`` so state never
    leaks between tests. The adapter's forward pass is expected to return a
    dict with keys ``frustration``, ``emotion``, ``encouragement`` and
    ``simplification`` (shape contracts asserted below).
    """

    def setup_method(self):
        """Create a fresh adapter and fixture dimensions before every test."""
        self.d_model = 768
        self.batch_size = 4
        self.module = MusicEQAdapter(d_model=self.d_model)

    def test_module_initialization(self):
        """The adapter exposes its expected sub-modules with the right types."""
        assert isinstance(self.module.frustration_detector, torch.nn.Sequential)
        assert isinstance(self.module.emotion_classifier, torch.nn.Linear)
        assert isinstance(self.module.simplify_gate, torch.nn.Linear)
        assert isinstance(self.module.encouragement_embed, torch.nn.Embedding)
        assert isinstance(self.module.simplification_strategies, torch.nn.Embedding)

    def test_forward_pass(self):
        """Forward pass returns all four heads with the documented shapes."""
        seq_len = 10
        hidden_states = torch.randn(self.batch_size, seq_len, self.d_model)
        output = self.module(hidden_states)
        assert "frustration" in output
        assert "emotion" in output
        assert "encouragement" in output
        assert "simplification" in output
        assert output["frustration"].shape == (self.batch_size, seq_len, 1)
        # 4 emotion classes: frustrated / confused / excited / confident
        assert output["emotion"].shape == (self.batch_size, seq_len, 4)
        assert output["encouragement"].shape[0] == self.batch_size
        assert output["encouragement"].shape[1] == seq_len
        assert output["simplification"].shape[0] == self.batch_size
        assert output["simplification"].shape[1] == seq_len

    def test_frustration_detector_output_range(self):
        """Frustration scores are bounded in [0, 1] (sigmoid-activated head)."""
        seq_len = 5
        hidden_states = torch.randn(self.batch_size, seq_len, self.d_model)
        output = self.module(hidden_states)
        frustration = output["frustration"]
        assert torch.all(frustration >= 0)
        assert torch.all(frustration <= 1)

    def test_emotion_classifier_output(self):
        """Emotion head produces one logit per class for all 4 classes."""
        seq_len = 5
        hidden_states = torch.randn(self.batch_size, seq_len, self.d_model)
        output = self.module(hidden_states)
        emotion_logits = output["emotion"]
        assert emotion_logits.shape == (self.batch_size, seq_len, 4)

    def test_emotion_classes(self):
        """Classifier width matches the documented set of emotion labels."""
        expected_emotions = ["frustrated", "confused", "excited", "confident"]
        # Check that the linear layer has the correct output size.
        assert self.module.emotion_classifier.out_features == len(expected_emotions)

    def test_simplify_gate_transformation(self):
        """Context-conditioned simplification preserves the model dimension."""
        seq_len = 5
        hidden_states = torch.randn(self.batch_size, seq_len, self.d_model)
        # Context vector layout: [frustration, difficulty, ...] (5 features).
        context = torch.randn(self.batch_size, 5)
        output = self.module(hidden_states, context)
        simplification = output["simplification"]
        # Simplified output should keep the same d_model as the input.
        assert simplification.shape[-1] == self.d_model

    def test_encouragement_templates(self):
        """Encouragement templates are stored in a non-empty embedding table."""
        assert self.module.encouragement_embed.num_embeddings > 0
        assert self.module.encouragement_embed.embedding_dim > 0

    def test_simplification_strategies(self):
        """Simplification strategies are stored in a non-empty embedding table."""
        assert self.module.simplification_strategies.num_embeddings > 0
        assert self.module.simplification_strategies.embedding_dim > 0

    def test_high_frustration_detection(self):
        """Single-token inputs still produce valid [0, 1] frustration scores."""
        seq_len = 1
        hidden_states = torch.randn(self.batch_size, seq_len, self.d_model)
        output = self.module(hidden_states)
        frustration = output["frustration"]
        assert torch.all((frustration >= 0) & (frustration <= 1))

    def test_different_batch_sizes(self):
        """Outputs track the batch dimension across several batch sizes."""
        for batch_size in [1, 2, 8]:
            seq_len = 10
            hidden_states = torch.randn(batch_size, seq_len, self.d_model)
            output = self.module(hidden_states)
            assert output["frustration"].shape[0] == batch_size
            assert output["emotion"].shape[0] == batch_size

    def test_different_seq_lengths(self):
        """Outputs track the sequence dimension across several lengths."""
        for seq_len in [1, 5, 20, 50]:
            hidden_states = torch.randn(self.batch_size, seq_len, self.d_model)
            output = self.module(hidden_states)
            assert output["frustration"].shape[1] == seq_len
            assert output["emotion"].shape[1] == seq_len

    def test_gradient_flow(self):
        """Backprop reaches both the inputs and the adapter's parameters."""
        seq_len = 5
        hidden_states = torch.randn(
            self.batch_size, seq_len, self.d_model, requires_grad=True
        )
        output = self.module(hidden_states)
        loss = output["frustration"].sum() + output["emotion"].sum()
        loss.backward()
        assert hidden_states.grad is not None
        # First layer of the frustration head must have received gradients.
        assert self.module.frustration_detector[0].weight.grad is not None

    def test_emotion_softmax_normalization(self):
        """Softmaxed emotion logits form a valid distribution (sum to 1)."""
        seq_len = 1
        hidden_states = torch.randn(self.batch_size, seq_len, self.d_model)
        output = self.module(hidden_states)
        emotion_probs = torch.softmax(output["emotion"], dim=-1)
        sums = emotion_probs.sum(dim=-1)
        assert torch.allclose(sums, torch.ones_like(sums), atol=1e-5)

    def test_frustration_sigmoid_normalization(self):
        """Frustration outputs stay in [0, 1] for single-token sequences."""
        seq_len = 1
        hidden_states = torch.randn(self.batch_size, seq_len, self.d_model)
        output = self.module(hidden_states)
        frustration = output["frustration"]
        assert torch.all((frustration >= 0) & (frustration <= 1))

    def test_simplify_gate_sigmoid(self):
        """Gated simplification returns a tensor shaped like the input."""
        seq_len = 1
        hidden_states = torch.randn(self.batch_size, seq_len, self.d_model)
        context = torch.randn(self.batch_size, 5)
        output = self.module(hidden_states, context)
        # The simplification output is the transformed hidden states;
        # we verify the shape round-trips exactly.
        assert output["simplification"].shape == hidden_states.shape

    def test_context_aware_simplification(self):
        """Simplification is computed per-context and yields finite values.

        Without training we cannot guarantee a *large* difference between the
        high- and low-frustration outputs, so we only require that the
        comparison is well-defined (finite, non-negative) and that shapes
        agree. (Previously the diff was computed but never asserted on.)
        """
        seq_len = 5
        hidden_states = torch.randn(self.batch_size, seq_len, self.d_model)
        # Two contexts differing only in the frustration feature.
        context1 = torch.tensor([[0.9, 0.0, 0.0, 0.0, 0.0]]).expand(self.batch_size, -1)
        context2 = torch.tensor([[0.1, 0.0, 0.0, 0.0, 0.0]]).expand(self.batch_size, -1)
        output1 = self.module(hidden_states, context1)
        output2 = self.module(hidden_states, context2)
        simplification_diff = (
            output1["simplification"] - output2["simplification"]
        ).abs().mean()
        # Fix for dead local: actually assert on the computed difference.
        assert torch.isfinite(simplification_diff)
        assert simplification_diff >= 0
        assert output1["simplification"].shape == output2["simplification"].shape

    def test_encouragement_output_range(self):
        """Encouragement head emits per-token embedding vectors."""
        seq_len = 5
        hidden_states = torch.randn(self.batch_size, seq_len, self.d_model)
        output = self.module(hidden_states)
        encouragement = output["encouragement"]
        # Embedding values themselves are unconstrained; check layout only.
        assert encouragement.shape[0] == self.batch_size
        assert encouragement.shape[1] == seq_len
        assert encouragement.shape[2] > 0

    def test_module_without_context(self):
        """The adapter works with context omitted (context=None default)."""
        seq_len = 5
        hidden_states = torch.randn(self.batch_size, seq_len, self.d_model)
        output = self.module(hidden_states)
        assert "frustration" in output
        assert "emotion" in output
# Allow running this file directly as a script: `python test_eq_adapter.py`.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])