cortexlab / tests /test_attention.py
SID2000's picture
Upload folder using huggingface_hub
9635a89 verified
"""Tests for ROI attention extraction."""
from unittest.mock import MagicMock
import numpy as np
import pytest
import torch
class TestAttentionExtractor:
    """Tests for the ``AttentionExtractor`` context manager."""

    @staticmethod
    def _make_encoder():
        """Build the small encoder shared by the tests in this class.

        Returns:
            A ``torch.nn.TransformerEncoder`` with two layers, ``d_model=32``,
            four heads and ``batch_first=True``.
        """
        layer = torch.nn.TransformerEncoderLayer(d_model=32, nhead=4, batch_first=True)
        return torch.nn.TransformerEncoder(layer, num_layers=2)

    @staticmethod
    def _hook_count(model):
        """Count forward and forward-pre hooks over *model* and its submodules.

        Args:
            model: Any ``torch.nn.Module``.

        Returns:
            Total number of hooks currently registered across ``model.modules()``.
        """
        # NOTE(review): ``_forward_hooks`` / ``_forward_pre_hooks`` are private
        # ``nn.Module`` attributes; PyTorch exposes no public way to list hooks.
        return sum(
            len(module._forward_hooks) + len(module._forward_pre_hooks)
            for module in model.modules()
        )

    def test_context_manager_returns_list(self):
        """Entering ``AttentionExtractor`` yields a list object."""
        from cortexlab.core.attention import AttentionExtractor

        encoder = self._make_encoder()
        with AttentionExtractor(encoder) as maps:
            x = torch.randn(2, 10, 32)
            _ = encoder(x)
        assert isinstance(maps, list)

    def test_hooks_cleaned_up(self):
        """Leaving the context leaves no extra hooks on the encoder."""
        from cortexlab.core.attention import AttentionExtractor

        encoder = self._make_encoder()
        hooks_before = self._hook_count(encoder)
        with AttentionExtractor(encoder) as _maps:
            x = torch.randn(1, 5, 32)
            _ = encoder(x)
        assert isinstance(_maps, list)
        # Hooks should be cleaned up after exiting context: whatever the
        # extractor registered must be gone, restoring the pre-entry count.
        # (Assumes the extractor captures attention via module hooks — its
        # implementation is not visible from this file; TODO confirm.)
        assert self._hook_count(encoder) == hooks_before
class TestAttentionToRoiScores:
    """Checks for ``attention_to_roi_scores`` on typical and empty inputs."""

    def test_basic_roi_scores(self):
        """Each requested ROI gets a score vector of shape ``(10,)``."""
        from cortexlab.core.attention import attention_to_roi_scores

        # Two simulated attention layers, each (B=1, heads=4, T=10, T=10).
        layer_shape = (1, 4, 10, 10)
        simulated_maps = [torch.randn(*layer_shape) for _layer in range(2)]
        rois = {"V1": np.array([0, 1, 2]), "MT": np.array([5, 6])}

        scores = attention_to_roi_scores(simulated_maps, rois)

        expected_rois = ("V1", "MT")
        # Membership is checked for every ROI before any shape is inspected.
        for roi_name in expected_rois:
            assert roi_name in scores
        for roi_name in expected_rois:
            assert scores[roi_name].shape == (10,)

    def test_with_predictor_weights(self):
        """Passing ``predictor_weights`` still yields an entry per ROI."""
        from cortexlab.core.attention import attention_to_roi_scores

        simulated_maps = [torch.randn(1, 4, 10, 10)]
        rois = {"V1": np.array([0, 1]), "A1": np.array([3, 4])}
        # Predictor weights laid out as (hidden=32, n_vertices=10).
        weights = torch.randn(32, 10)

        scores = attention_to_roi_scores(
            simulated_maps,
            rois,
            predictor_weights=weights,
        )

        for roi_name in ("V1", "A1"):
            assert roi_name in scores

    def test_empty_attn_maps(self):
        """With no attention maps the ROI key is present but its scores are empty."""
        from cortexlab.core.attention import attention_to_roi_scores

        rois = {"V1": np.array([0, 1])}
        no_maps = []

        scores = attention_to_roi_scores(no_maps, rois)

        assert "V1" in scores
        assert len(scores["V1"]) == 0