# Source: cot-anc/tests/test_suppression_shape.py
# Commit fda8fb3 ("Deploy Thought Anchors") by BART-ender, 862 bytes.
from __future__ import annotations
import numpy as np
import pytest
torch = pytest.importorskip("torch")
from app.analysis.hooks import get_stored_attentions
from app.analysis.suppression import compute_attribution_matrix
def test_attribution_matrix_shape_and_cleanup() -> None:
    """Check attribution-matrix shape, zero structure, and hook cleanup.

    Runs ``compute_attribution_matrix`` with ``FakeCausalLM`` from
    ``tests.conftest`` over three token ranges and asserts that:

    * both ``matrix`` and ``raw_matrix`` are ``(n_ranges, n_ranges)``;
    * ``raw_matrix`` is zero on the diagonal and strictly above it;
    * every entry of ``matrix`` is finite;
    * ``get_stored_attentions()`` is empty once the call returns.
    """
    from tests.conftest import FakeCausalLM

    model = FakeCausalLM()
    input_ids = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long)
    # Presumably half-open [start, end) ranges covering all five tokens.
    token_ranges = [(0, 2), (2, 4), (4, 5)]
    n_ranges = len(token_ranges)

    result = compute_attribution_matrix(
        input_ids=input_ids,
        token_ranges=token_ranges,
        model=model,
        take_log=False,
    )

    assert result.matrix.shape == (n_ranges, n_ranges)
    # Pin raw_matrix's shape as well: np.allclose on an empty array is True,
    # so the zero-structure checks below would pass vacuously otherwise.
    assert result.raw_matrix.shape == (n_ranges, n_ranges)
    # Diagonal and strict upper triangle (k=1) are checked separately so a
    # failure reports which region is non-zero; together they are equivalent
    # to np.triu(raw_matrix) (k=0) being all zeros.
    assert np.allclose(np.diag(result.raw_matrix), 0.0)
    assert np.allclose(np.triu(result.raw_matrix, k=1), 0.0)
    assert np.isfinite(result.matrix).all()
    # The hooks must leave no stored attentions behind after the call.
    assert get_stored_attentions() == {}