File size: 2,394 Bytes
9635a89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""Tests for cross-subject adaptation."""

from unittest.mock import MagicMock

import numpy as np
import pytest
import torch


def _make_mock_model(n_subjects=3, hidden=32, n_vertices=50):
    """Build a ``MagicMock`` standing in for a model with predictor weights.

    The returned mock exposes:

    * ``config`` with ``linear_baseline=False``, ``low_rank_head=None`` and
      ``hidden`` set to the *hidden* argument;
    * ``aggregate_features``, ``transformer_forward`` and ``pooler`` stubs,
      each returning one pre-drawn random tensor;
    * ``eval`` as a plain mock;
    * ``predictor`` whose ``weights`` is a ``torch.nn.Parameter`` of shape
      ``(n_subjects, hidden, n_vertices)`` and whose ``bias`` is ``None``.
    """

    def _returns_random(*shape):
        # Stub callable that always hands back the same random tensor.
        return MagicMock(return_value=torch.randn(*shape))

    mock_model = MagicMock()
    mock_model.config = MagicMock(
        linear_baseline=False,
        low_rank_head=None,
        hidden=hidden,
    )

    # Random tensors are drawn in a fixed order (features, transformer,
    # pooler, then predictor weights) so seeded runs stay reproducible.
    mock_model.aggregate_features = _returns_random(2, 10, hidden)
    mock_model.transformer_forward = _returns_random(2, 10, hidden)
    mock_model.pooler = _returns_random(2, hidden, 10)
    mock_model.eval = MagicMock()

    head = MagicMock(bias=None)
    head.weights = torch.nn.Parameter(
        torch.randn(n_subjects, hidden, n_vertices)
    )
    mock_model.predictor = head

    return mock_model


def _make_calibration_loader(n_batches=2, batch_size=2, n_vertices=50):
    """Create a mock calibration data loader.

    Args:
        n_batches: Number of batches to generate.
        batch_size: Number of items (and segments) in each batch.
        n_vertices: Size of the second axis of the random ``"fmri"`` tensor.

    Returns:
        A list of ``n_batches`` ``SegmentData`` objects. Each one carries a
        ``data`` dict with random ``"text"`` (``batch_size x 2 x 32 x 10``)
        and ``"fmri"`` (``batch_size x n_vertices x 10``) tensors plus an
        all-zero ``int64`` ``"subject_id"`` tensor, and ``batch_size``
        one-second segments on the ``"test"`` timeline starting at
        ``0.0, 1.0, ...``.
    """
    # Project imports are kept function-local, as before, but are resolved
    # once up front instead of being re-executed on every loop iteration.
    from neuralset.dataloader import SegmentData
    import neuralset.segments as seg

    batches = []
    for _ in range(n_batches):
        data = {
            "text": torch.randn(batch_size, 2, 32, 10),
            "fmri": torch.randn(batch_size, n_vertices, 10),
            "subject_id": torch.zeros(batch_size, dtype=torch.long),
        }
        segments = [
            seg.Segment(start=float(i), duration=1.0, timeline="test")
            for i in range(batch_size)
        ]
        batches.append(SegmentData(data=data, segments=segments))
    return batches


class TestSubjectAdapter:
    """Checks for building a ``SubjectAdapter`` and injecting it into a model."""

    def test_nearest_neighbor(self):
        """``from_nearest_neighbor`` yields weights shaped for one subject."""
        from cortexlab.core.subject import SubjectAdapter

        adapter = SubjectAdapter.from_nearest_neighbor(
            _make_mock_model(),
            _make_calibration_loader(),
        )

        weights_shape = adapter._weights.shape
        assert weights_shape[0] == 1  # one new subject
        assert weights_shape[1] == 32  # hidden dim
        assert weights_shape[2] == 50  # n_vertices

    def test_inject_into_model(self):
        """Injecting into a 3-subject mock returns id 3 and leaves 4 rows."""
        from cortexlab.core.subject import SubjectAdapter

        base_model = _make_mock_model(n_subjects=3)
        new_subject_weights = torch.randn(1, 32, 50)
        adapter = SubjectAdapter(weights=new_subject_weights)

        assigned_id = adapter.inject_into_model(base_model)

        assert assigned_id == 3
        assert base_model.predictor.weights.shape[0] == 4