"""Tests for the core FmriEncoder model."""
from unittest.mock import MagicMock
import pytest
import torch
def _make_model(hidden=256, n_outputs=100, n_timesteps=10, modalities=None):
    """Construct a compact model from an ``FmriEncoder`` config for tests.

    Args:
        hidden: Value forwarded to ``FmriEncoder(hidden=...)``.
        n_outputs: Value forwarded to ``build(n_outputs=...)``.
        n_timesteps: Value forwarded to ``build(n_output_timesteps=...)``.
        modalities: Mapping handed to ``build(feature_dims=...)``. When
            ``None``, a text/audio/video mapping with ``(2, 32)`` entries
            is used.

    Returns:
        Whatever ``FmriEncoder.build`` returns for the given settings.
    """
    # Imported lazily so that merely collecting this module does not require
    # the project packages to be importable.
    from neuraltrain.models.transformer import TransformerEncoder
    from cortexlab.core.model import FmriEncoder

    feature_dims = (
        {"text": (2, 32), "audio": (2, 32), "video": (2, 32)}
        if modalities is None
        else modalities
    )

    # All dropout variants are disabled and the transformer path is selected
    # explicitly (linear_baseline=False).
    encoder_config = FmriEncoder(
        hidden=hidden,
        max_seq_len=128,
        dropout=0.0,
        modality_dropout=0.0,
        temporal_dropout=0.0,
        linear_baseline=False,
        encoder=TransformerEncoder(depth=2, heads=4),
    )
    return encoder_config.build(
        feature_dims=feature_dims,
        n_outputs=n_outputs,
        n_output_timesteps=n_timesteps,
    )
def _make_segments(n):
    """Return ``n`` placeholder ``Segment`` objects for a ``SegmentData``.

    Segment ``i`` starts at ``float(i)``, has a duration of ``1.0`` and uses
    the timeline name ``"test"``.
    """
    import neuralset.segments as seg

    segments = []
    for index in range(n):
        segments.append(
            seg.Segment(start=float(index), duration=1.0, timeline="test")
        )
    return segments
def _make_batch(modalities, batch_size=2, seq_len=20):
    """Assemble a random ``SegmentData`` batch covering ``modalities``.

    Args:
        modalities: Mapping of modality name to an ``(n_layers, feat_dim)``
            pair; each entry becomes a random tensor of shape
            ``(batch_size, n_layers, feat_dim, seq_len)``.
        batch_size: Leading dimension of every tensor and the number of
            dummy segments attached to the batch.
        seq_len: Trailing dimension of every modality tensor.

    Returns:
        A ``SegmentData`` holding the modality tensors plus an all-zero
        ``subject_id`` tensor of dtype ``torch.long``.
    """
    from neuralset.dataloader import SegmentData

    payload = {
        modality: torch.randn(batch_size, layer_count, feature_size, seq_len)
        for modality, (layer_count, feature_size) in modalities.items()
    }
    # Every sample is assigned subject index 0.
    payload["subject_id"] = torch.zeros(batch_size, dtype=torch.long)
    return SegmentData(data=payload, segments=_make_segments(batch_size))
class TestFmriEncoderModel:
    """Shape-level smoke tests for models built from ``FmriEncoder``."""

    def test_forward_shape(self):
        """A plain forward pass returns a (2, 100, 10) tensor."""
        feature_dims = {"text": (2, 32), "audio": (2, 32)}
        net = _make_model(modalities=feature_dims)
        inputs = _make_batch(feature_dims)

        out = net(inputs)

        assert out.shape == (2, 100, 10), f"Expected (2, 100, 10), got {out.shape}"

    def test_forward_no_pool(self):
        """With ``pool_outputs=False`` only the first two dims are checked."""
        feature_dims = {"text": (2, 32)}
        net = _make_model(modalities=feature_dims)
        inputs = _make_batch(feature_dims)

        unpooled = net(inputs, pool_outputs=False)

        # The remaining dimension(s) are intentionally left unchecked.
        assert unpooled.shape[0] == 2
        assert unpooled.shape[1] == 100

    def test_return_attn(self):
        """``return_attn=True`` yields an ``(output, attention list)`` tuple."""
        feature_dims = {"text": (2, 32)}
        net = _make_model(modalities=feature_dims)
        inputs = _make_batch(feature_dims)

        returned = net(inputs, return_attn=True)

        assert isinstance(returned, tuple)
        prediction, attention = returned
        assert prediction.shape == (2, 100, 10)
        # The list may be empty if the transformer doesn't expose weights.
        assert isinstance(attention, list)

    def test_missing_modality_zeros(self):
        """A batch lacking a configured modality still produces full output."""
        net = _make_model(modalities={"text": (2, 32), "audio": (2, 32)})

        # Only provide text, not audio.
        text_only = _make_batch({"text": (2, 32)})

        prediction = net(text_only)
        assert prediction.shape == (2, 100, 10)

    def test_modality_dropout_training(self):
        """Train-mode forward with ``modality_dropout=0.5`` keeps the shape."""
        feature_dims = {"text": (2, 32), "audio": (2, 32)}
        from neuraltrain.models.transformer import TransformerEncoder
        from cortexlab.core.model import FmriEncoder

        dropout_config = FmriEncoder(
            hidden=256,
            max_seq_len=128,
            modality_dropout=0.5,
            encoder=TransformerEncoder(depth=2, heads=4),
        )
        net = dropout_config.build(
            feature_dims=feature_dims,
            n_outputs=100,
            n_output_timesteps=10,
        )
        net.train()

        prediction = net(_make_batch(feature_dims))
        assert prediction.shape == (2, 100, 10)

    def test_linear_baseline(self):
        """A ``linear_baseline=True`` config produces the same output shape."""
        feature_dims = {"text": (2, 32)}
        from cortexlab.core.model import FmriEncoder

        baseline_config = FmriEncoder(hidden=256, linear_baseline=True)
        net = baseline_config.build(
            feature_dims=feature_dims,
            n_outputs=100,
            n_output_timesteps=10,
        )

        prediction = net(_make_batch(feature_dims))
        assert prediction.shape == (2, 100, 10)
|