File size: 4,033 Bytes
9635a89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""Tests for the core FmriEncoder model."""

from unittest.mock import MagicMock

import pytest
import torch


def _make_model(hidden=256, n_outputs=100, n_timesteps=10, modalities=None):
    """Return a compact model built from an ``FmriEncoder`` config for tests.

    Args:
        hidden: Value forwarded as ``hidden`` to the ``FmriEncoder`` config.
        n_outputs: Value forwarded as ``n_outputs`` to ``config.build``.
        n_timesteps: Value forwarded as ``n_output_timesteps`` to ``config.build``.
        modalities: Mapping of modality name to a 2-tuple, forwarded as
            ``feature_dims``. When ``None``, "text", "audio" and "video" are
            used, each mapped to ``(2, 32)``.

    Returns:
        Whatever ``FmriEncoder(...).build(...)`` produces.
    """
    from neuraltrain.models.transformer import TransformerEncoder

    from cortexlab.core.model import FmriEncoder

    feature_dims = modalities
    if feature_dims is None:
        feature_dims = {"text": (2, 32), "audio": (2, 32), "video": (2, 32)}

    # Every dropout knob is zeroed here so the built model has no
    # config-level stochastic dropout.
    return FmriEncoder(
        hidden=hidden,
        max_seq_len=128,
        dropout=0.0,
        modality_dropout=0.0,
        temporal_dropout=0.0,
        linear_baseline=False,
        encoder=TransformerEncoder(depth=2, heads=4),
    ).build(
        feature_dims=feature_dims,
        n_outputs=n_outputs,
        n_output_timesteps=n_timesteps,
    )


def _make_segments(n):
    """Build ``n`` placeholder ``Segment`` objects for use with SegmentData.

    Segment ``i`` has ``start=float(i)``, ``duration=1.0`` and
    ``timeline="test"``.
    """
    from neuralset.segments import Segment

    segments = []
    for index in range(n):
        segments.append(Segment(start=float(index), duration=1.0, timeline="test"))
    return segments


def _make_batch(modalities, batch_size=2, seq_len=20):
    """Assemble a random ``SegmentData`` batch for the given modalities.

    Args:
        modalities: Mapping of modality name to ``(n_layers, feat_dim)``.
        batch_size: Number of items in the batch (and number of segments).
        seq_len: Size of the last tensor dimension.

    Returns:
        A ``SegmentData`` whose ``data`` holds one random tensor of shape
        ``(batch_size, n_layers, feat_dim, seq_len)`` per modality plus an
        all-zero ``int64`` ``"subject_id"`` tensor of length ``batch_size``.
    """
    from neuralset.dataloader import SegmentData

    tensors = {
        key: torch.randn(batch_size, depth, width, seq_len)
        for key, (depth, width) in modalities.items()
    }
    tensors["subject_id"] = torch.zeros(batch_size, dtype=torch.long)
    return SegmentData(data=tensors, segments=_make_segments(batch_size))


class TestFmriEncoderModel:
    """Shape-level smoke tests for models built from ``FmriEncoder`` configs."""

    def test_forward_shape(self):
        """A plain forward pass on text + audio yields a (2, 100, 10) output."""
        feature_dims = {"text": (2, 32), "audio": (2, 32)}
        model = _make_model(modalities=feature_dims)
        out = model(_make_batch(feature_dims))
        assert out.shape == (2, 100, 10), f"Expected (2, 100, 10), got {out.shape}"

    def test_forward_no_pool(self):
        """With ``pool_outputs=False`` only the first two output dims are checked."""
        feature_dims = {"text": (2, 32)}
        model = _make_model(modalities=feature_dims)
        out = model(_make_batch(feature_dims), pool_outputs=False)
        batch_dim = out.shape[0]
        output_dim = out.shape[1]
        assert batch_dim == 2
        assert output_dim == 100

    def test_return_attn(self):
        """``return_attn=True`` returns an ``(output, attn_maps)`` tuple."""
        feature_dims = {"text": (2, 32)}
        model = _make_model(modalities=feature_dims)
        result = model(_make_batch(feature_dims), return_attn=True)
        assert isinstance(result, tuple)
        out, attn_maps = result
        assert out.shape == (2, 100, 10)
        # attn_maps may be empty if the transformer doesn't expose weights
        assert isinstance(attn_maps, list)

    def test_missing_modality_zeros(self):
        """Forward pass keeps its output shape when "audio" is absent from the batch."""
        model = _make_model(modalities={"text": (2, 32), "audio": (2, 32)})
        # Only provide text, not audio
        text_only_batch = _make_batch({"text": (2, 32)})
        out = model(text_only_batch)
        assert out.shape == (2, 100, 10)

    def test_modality_dropout_training(self):
        """Training-mode forward with ``modality_dropout=0.5`` keeps the output shape."""
        from neuraltrain.models.transformer import TransformerEncoder

        from cortexlab.core.model import FmriEncoder

        feature_dims = {"text": (2, 32), "audio": (2, 32)}
        cfg = FmriEncoder(
            hidden=256,
            max_seq_len=128,
            modality_dropout=0.5,
            encoder=TransformerEncoder(depth=2, heads=4),
        )
        model = cfg.build(
            feature_dims=feature_dims,
            n_outputs=100,
            n_output_timesteps=10,
        )
        model.train()
        out = model(_make_batch(feature_dims))
        assert out.shape == (2, 100, 10)

    def test_linear_baseline(self):
        """A config with ``linear_baseline=True`` yields the same output shape."""
        from cortexlab.core.model import FmriEncoder

        feature_dims = {"text": (2, 32)}
        cfg = FmriEncoder(hidden=256, linear_baseline=True)
        model = cfg.build(
            feature_dims=feature_dims,
            n_outputs=100,
            n_output_timesteps=10,
        )
        out = model(_make_batch(feature_dims))
        assert out.shape == (2, 100, 10)