"""Tests for modality attribution."""
from unittest.mock import MagicMock
import numpy as np
import pytest
import torch
from tests.conftest import make_segments
def _make_mock_model(n_vertices=100):
"""Create a mock model for attribution testing."""
model = MagicMock()
model.feature_dims = {"text": (2, 32), "audio": (2, 32), "video": (2, 32)}
model.eval = MagicMock()
call_count = [0]
def fake_forward(batch, **kwargs):
call_count[0] += 1
B = 2
# Make text-ablated predictions differ more to simulate importance
if torch.all(batch.data.get("text", torch.ones(1)) == 0):
return torch.ones(B, n_vertices, 10) * 0.5
elif torch.all(batch.data.get("audio", torch.ones(1)) == 0):
return torch.ones(B, n_vertices, 10) * 0.8
elif torch.all(batch.data.get("video", torch.ones(1)) == 0):
return torch.ones(B, n_vertices, 10) * 0.9
return torch.ones(B, n_vertices, 10) * 1.0
model.side_effect = fake_forward
model.return_value = torch.ones(2, n_vertices, 10)
# Override __call__ to use our function
model.__class__ = type("MockModel", (), {
"__call__": staticmethod(fake_forward),
"feature_dims": {"text": (2, 32), "audio": (2, 32), "video": (2, 32)},
"eval": lambda self: None,
})
# Simpler approach: just use a plain class
class FakeModel:
feature_dims = {"text": (2, 32), "audio": (2, 32), "video": (2, 32)}
def eval(self): pass
def __call__(self, batch, **kwargs):
return fake_forward(batch, **kwargs)
return FakeModel()
class TestModalityAttributor:
    """Tests for ablation-based modality attribution.

    Uses the fake model from ``_make_mock_model``: ablating "text" perturbs
    predictions the most, then "audio", then "video", so the expected
    ordering of attribution scores is known in advance.
    """

    @staticmethod
    def _make_batch(batch_size=2):
        """Build a SegmentData batch with random text/audio/video features.

        Shapes match the fake model's ``feature_dims`` of (2, 32) per
        modality, with 20 time steps and ``batch_size`` items.
        """
        from neuralset.dataloader import SegmentData

        data = {
            "text": torch.randn(batch_size, 2, 32, 20),
            "audio": torch.randn(batch_size, 2, 32, 20),
            "video": torch.randn(batch_size, 2, 32, 20),
            "subject_id": torch.zeros(batch_size, dtype=torch.long),
        }
        return SegmentData(data=data, segments=make_segments(batch_size))

    def test_ablation_basic(self):
        """Attribution returns per-vertex scores for every modality."""
        from cortexlab.inference.attribution import ModalityAttributor

        attributor = ModalityAttributor(_make_mock_model())
        scores = attributor.attribute(self._make_batch())
        assert "text" in scores
        assert "audio" in scores
        assert "video" in scores
        assert scores["text"].shape == (100,)

    def test_text_most_important(self):
        """Score ordering follows the size of the ablation-induced change."""
        from cortexlab.inference.attribution import ModalityAttributor

        attributor = ModalityAttributor(_make_mock_model())
        scores = attributor.attribute(self._make_batch())
        # Text ablation causes the biggest change (1.0 -> 0.5 = 0.5 diff)
        assert scores["text"].mean() > scores["audio"].mean()
        assert scores["audio"].mean() > scores["video"].mean()

    def test_normalised_scores_sum_to_one(self):
        """Normalised per-modality scores form a partition of unity."""
        from cortexlab.inference.attribution import ModalityAttributor

        attributor = ModalityAttributor(_make_mock_model())
        scores = attributor.attribute(self._make_batch())
        total = (
            scores["text_normalised"]
            + scores["audio_normalised"]
            + scores["video_normalised"]
        )
        np.testing.assert_allclose(total, 1.0, atol=1e-6)

    def test_with_roi_indices(self):
        """ROI-aggregated scores are reported when roi_indices is supplied."""
        from cortexlab.inference.attribution import ModalityAttributor

        roi_indices = {
            "V1": np.array([0, 1, 2, 3, 4]),
            "MT": np.array([10, 11, 12]),
        }
        attributor = ModalityAttributor(_make_mock_model(), roi_indices=roi_indices)
        scores = attributor.attribute(self._make_batch())
        assert "text_roi" in scores
        assert "V1" in scores["text_roi"]
        assert "MT" in scores["text_roi"]