File size: 4,799 Bytes
9635a89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""Tests for modality attribution."""

from unittest.mock import MagicMock

import numpy as np
import pytest
import torch

from tests.conftest import make_segments


def _make_mock_model(n_vertices=100):
    """Create a mock model for attribution testing."""
    model = MagicMock()
    model.feature_dims = {"text": (2, 32), "audio": (2, 32), "video": (2, 32)}
    model.eval = MagicMock()

    call_count = [0]

    def fake_forward(batch, **kwargs):
        call_count[0] += 1
        B = 2
        # Make text-ablated predictions differ more to simulate importance
        if torch.all(batch.data.get("text", torch.ones(1)) == 0):
            return torch.ones(B, n_vertices, 10) * 0.5
        elif torch.all(batch.data.get("audio", torch.ones(1)) == 0):
            return torch.ones(B, n_vertices, 10) * 0.8
        elif torch.all(batch.data.get("video", torch.ones(1)) == 0):
            return torch.ones(B, n_vertices, 10) * 0.9
        return torch.ones(B, n_vertices, 10) * 1.0

    model.side_effect = fake_forward
    model.return_value = torch.ones(2, n_vertices, 10)
    # Override __call__ to use our function
    model.__class__ = type("MockModel", (), {
        "__call__": staticmethod(fake_forward),
        "feature_dims": {"text": (2, 32), "audio": (2, 32), "video": (2, 32)},
        "eval": lambda self: None,
    })
    # Simpler approach: just use a plain class
    class FakeModel:
        feature_dims = {"text": (2, 32), "audio": (2, 32), "video": (2, 32)}
        def eval(self): pass
        def __call__(self, batch, **kwargs):
            return fake_forward(batch, **kwargs)
    return FakeModel()


class TestModalityAttributor:
    """Behavioural tests for ``ModalityAttributor`` against a fake model."""

    @staticmethod
    def _build_batch():
        # Shared fixture: a two-item batch carrying all three modalities
        # plus subject ids, wrapped in the project's SegmentData container.
        from neuralset.dataloader import SegmentData

        tensors = {
            "text": torch.randn(2, 2, 32, 20),
            "audio": torch.randn(2, 2, 32, 20),
            "video": torch.randn(2, 2, 32, 20),
            "subject_id": torch.zeros(2, dtype=torch.long),
        }
        return SegmentData(data=tensors, segments=make_segments(2))

    def test_ablation_basic(self):
        from cortexlab.inference.attribution import ModalityAttributor

        attributor = ModalityAttributor(_make_mock_model())
        scores = attributor.attribute(self._build_batch())

        for modality in ("text", "audio", "video"):
            assert modality in scores
        assert scores["text"].shape == (100,)

    def test_text_most_important(self):
        from cortexlab.inference.attribution import ModalityAttributor

        attributor = ModalityAttributor(_make_mock_model())
        scores = attributor.attribute(self._build_batch())

        # Text ablation causes the biggest change (1.0 -> 0.5 = 0.5 diff),
        # then audio (-> 0.8), then video (-> 0.9); ranking should follow.
        assert scores["text"].mean() > scores["audio"].mean()
        assert scores["audio"].mean() > scores["video"].mean()

    def test_normalised_scores_sum_to_one(self):
        from cortexlab.inference.attribution import ModalityAttributor

        attributor = ModalityAttributor(_make_mock_model())
        scores = attributor.attribute(self._build_batch())

        total = (
            scores["text_normalised"]
            + scores["audio_normalised"]
            + scores["video_normalised"]
        )
        np.testing.assert_allclose(total, 1.0, atol=1e-6)

    def test_with_roi_indices(self):
        from cortexlab.inference.attribution import ModalityAttributor

        rois = {
            "V1": np.array([0, 1, 2, 3, 4]),
            "MT": np.array([10, 11, 12]),
        }
        attributor = ModalityAttributor(_make_mock_model(), roi_indices=rois)
        scores = attributor.attribute(self._build_batch())

        assert "text_roi" in scores
        assert "V1" in scores["text_roi"]
        assert "MT" in scores["text_roi"]