File size: 862 Bytes
fda8fb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from __future__ import annotations

import numpy as np
import pytest

torch = pytest.importorskip("torch")

from app.analysis.hooks import get_stored_attentions
from app.analysis.suppression import compute_attribution_matrix


def test_attribution_matrix_shape_and_cleanup() -> None:
    """Attribution over three token spans yields a well-formed 3x3 result.

    Runs ``compute_attribution_matrix`` on the ``FakeCausalLM`` test model with
    ``take_log=False`` and checks that:

    * the returned ``matrix`` has shape ``(3, 3)`` (one row/column per span),
    * ``raw_matrix`` is zero on its diagonal and in its upper triangle,
    * every entry of ``matrix`` is finite, and
    * ``get_stored_attentions()`` is empty once the call returns, i.e. no
      captured attentions are left behind.
    """
    from tests.conftest import FakeCausalLM

    fake_model = FakeCausalLM()

    # A single sequence of token ids 1..5, shaped (1, 5).
    ids = torch.arange(1, 6, dtype=torch.long).unsqueeze(0)
    # Three spans covering the five tokens; presumably half-open [start, end)
    # intervals -- confirm against compute_attribution_matrix.
    spans = [(0, 2), (2, 4), (4, 5)]

    attribution = compute_attribution_matrix(
        input_ids=ids,
        token_ranges=spans,
        model=fake_model,
        take_log=False,
    )

    assert attribution.matrix.shape == (3, 3)

    raw = attribution.raw_matrix
    # The raw scores must vanish on the diagonal and everywhere above it.
    assert np.allclose(np.diag(raw), 0.0)
    assert np.allclose(np.triu(raw), 0.0)

    assert np.isfinite(attribution.matrix).all()
    # The attention store must be cleared after the computation finishes.
    assert get_stored_attentions() == {}