File size: 1,463 Bytes
fda8fb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e518f94
 
 
 
 
2298e0a
e518f94
 
 
 
 
 
2298e0a
e518f94
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from __future__ import annotations

from types import SimpleNamespace

import pytest

pytest.importorskip("torch")

from app.core.model_support import describe_model_support


def test_model_support_accepts_transformer_h_layout() -> None:
    """Layers found at ``transformer.h`` with an ``attn`` attribute are accepted.

    The fake model uses eager attention, so attribution should be enabled.
    The report should also name where the layers and attention submodule live.
    """
    block = SimpleNamespace(attn=object())
    # The same layer object is reused for both positions in the stack.
    fake_model = SimpleNamespace(
        transformer=SimpleNamespace(h=[block] * 2),
        config=SimpleNamespace(_attn_implementation="eager"),
    )

    result = describe_model_support(fake_model)

    assert result.supports_attribution is True
    assert result.layer_path == "transformer.h"
    assert result.attention_attr == "attn"


def test_model_support_rejects_non_eager_attention() -> None:
    """A config requesting ``sdpa`` attention must disable attribution.

    The fixture has a single-layer ``model.layers`` stack whose layer exposes
    ``self_attn``. The returned reason is expected to mention ``eager``.
    """
    decoder_layer = SimpleNamespace(self_attn=object())
    fake_model = SimpleNamespace(
        model=SimpleNamespace(layers=[decoder_layer]),
        config=SimpleNamespace(_attn_implementation="sdpa"),
    )

    result = describe_model_support(fake_model)
    # reason may be None, so normalise it before the substring check.
    message = result.reason or ""

    assert result.supports_attribution is False
    assert "eager" in message


def test_model_support_accepts_hrm_layout() -> None:
    """Layers found at ``model.h_stack`` with an ``attention`` attribute are accepted.

    With eager attention configured, attribution should be supported.
    The report should point at ``model.h_stack`` and the ``attention`` attribute.
    """
    stack_layer = SimpleNamespace(attention=object())
    # Two references to the same layer object make up the stack.
    fake_model = SimpleNamespace(
        model=SimpleNamespace(h_stack=[stack_layer] * 2),
        config=SimpleNamespace(_attn_implementation="eager"),
    )

    result = describe_model_support(fake_model)

    assert result.supports_attribution is True
    assert result.layer_path == "model.h_stack"
    assert result.attention_attr == "attention"