| """Tests for ROI attention extraction.""" |
|
|
| from unittest.mock import MagicMock |
|
|
| import numpy as np |
| import pytest |
| import torch |
|
|
|
|
| class TestAttentionExtractor: |
| def test_context_manager_returns_list(self): |
| from cortexlab.core.attention import AttentionExtractor |
|
|
| encoder = torch.nn.TransformerEncoder( |
| torch.nn.TransformerEncoderLayer(d_model=32, nhead=4, batch_first=True), |
| num_layers=2, |
| ) |
| with AttentionExtractor(encoder) as maps: |
| x = torch.randn(2, 10, 32) |
| _ = encoder(x) |
|
|
| assert isinstance(maps, list) |
|
|
| def test_hooks_cleaned_up(self): |
| from cortexlab.core.attention import AttentionExtractor |
|
|
| encoder = torch.nn.TransformerEncoder( |
| torch.nn.TransformerEncoderLayer(d_model=32, nhead=4, batch_first=True), |
| num_layers=2, |
| ) |
| with AttentionExtractor(encoder) as _maps: |
| x = torch.randn(1, 5, 32) |
| _ = encoder(x) |
|
|
| |
| assert isinstance(_maps, list) |
|
|
|
|
| class TestAttentionToRoiScores: |
| def test_basic_roi_scores(self): |
| from cortexlab.core.attention import attention_to_roi_scores |
|
|
| |
| attn_maps = [torch.randn(1, 4, 10, 10) for _ in range(2)] |
| roi_indices = { |
| "V1": np.array([0, 1, 2]), |
| "MT": np.array([5, 6]), |
| } |
| scores = attention_to_roi_scores(attn_maps, roi_indices) |
|
|
| assert "V1" in scores |
| assert "MT" in scores |
| assert scores["V1"].shape == (10,) |
| assert scores["MT"].shape == (10,) |
|
|
| def test_with_predictor_weights(self): |
| from cortexlab.core.attention import attention_to_roi_scores |
|
|
| attn_maps = [torch.randn(1, 4, 10, 10)] |
| roi_indices = { |
| "V1": np.array([0, 1]), |
| "A1": np.array([3, 4]), |
| } |
| |
| weights = torch.randn(32, 10) |
| scores = attention_to_roi_scores(attn_maps, roi_indices, predictor_weights=weights) |
|
|
| assert "V1" in scores |
| assert "A1" in scores |
|
|
| def test_empty_attn_maps(self): |
| from cortexlab.core.attention import attention_to_roi_scores |
|
|
| roi_indices = {"V1": np.array([0, 1])} |
| scores = attention_to_roi_scores([], roi_indices) |
|
|
| assert "V1" in scores |
| assert len(scores["V1"]) == 0 |
|
|