File size: 6,053 Bytes
fda8fb3
 
2620860
fda8fb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aca1293
 
 
fda8fb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3bd6f97
 
fda8fb3
3bd6f97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fda8fb3
 
 
 
3bd6f97
fda8fb3
 
 
 
 
 
3bd6f97
fda8fb3
3bd6f97
 
 
 
 
 
 
 
 
 
 
fda8fb3
3bd6f97
 
fda8fb3
 
 
 
 
3bd6f97
 
fda8fb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2620860
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
from __future__ import annotations

import inspect
from dataclasses import dataclass
from typing import Any

import torch


@dataclass(frozen=True, slots=True)
class ModelSupport:
    """Immutable summary of whether a model can be used for attribution.

    Instances are produced by ``describe_model_support``; the field comments
    below describe the values that function assigns.
    """

    # True only when decoder layers and an attention attribute were found and
    # the attention implementation is "eager".
    supports_attribution: bool
    # Human-readable explanation when unsupported; None on success.
    reason: str | None
    # Dotted attribute path from the model to its layer container, or None
    # when no layers were located.
    layer_path: str | None
    # Name of the attention attribute present on every layer, or None.
    attention_attr: str | None
    # Number of layers found (0 when none were located).
    layer_count: int
    # Attention implementation read from the model config, if any.
    attention_impl: str | None


# Attribute paths (one tuple of attribute names per path) that are probed, in
# this order, when looking for a model's layer container. The first path that
# resolves to a non-empty iterable of modules wins, so order matters.
_LAYER_PATH_CANDIDATES: tuple[tuple[str, ...], ...] = (
    ("model", "layers"),
    ("model", "model", "layers"),
    ("transformer", "h"),
    ("gpt_neox", "layers"),
    ("model", "h_stack"),
    ("model", "l_stack"),
    ("layers",),
)
# Attribute names tried, in order, on each layer to find its attention module;
# a name is accepted only if every layer has a non-None attribute by that name.
_ATTENTION_ATTR_CANDIDATES: tuple[str, ...] = ("self_attn", "attn", "attention")


def _resolve_attr_chain(obj: Any, path: tuple[str, ...]) -> Any | None:
    current = obj
    for segment in path:
        current = getattr(current, segment, None)
        if current is None:
            return None
    return current


def _get_attention_impl(model: Any) -> str | None:
    config = getattr(model, "config", None)
    if config is None:
        return None
    return getattr(config, "_attn_implementation", None) or getattr(config, "attn_implementation", None)


def _find_layers_from_candidates(model: Any) -> tuple[list[Any] | None, str | None]:
    """Probe ``_LAYER_PATH_CANDIDATES`` for a non-empty iterable of modules.

    Returns ``(layers, dotted_path)`` for the first matching candidate, or
    ``(None, None)`` when no candidate matches.
    """
    # Resolved outside the try block so a missing torch fails loudly instead
    # of being swallowed by the best-effort exception handler below.
    module_cls = torch.nn.Module
    for candidate in _LAYER_PATH_CANDIDATES:
        try:
            maybe_layers = _resolve_attr_chain(model, candidate)
            if maybe_layers is None or not hasattr(maybe_layers, "__iter__"):
                continue
            found = list(maybe_layers)
        except Exception:
            # Best-effort probing: attribute access or iteration on an
            # arbitrary model object may raise; just try the next candidate.
            continue
        if found and isinstance(found[0], module_cls):
            return found, ".".join(candidate)
    return None, None


def _find_layers_by_scanning(model: Any) -> tuple[list[Any] | None, str | None]:
    """Scan public attributes of *model* for a non-empty ``ModuleList``.

    Direct attributes are checked first; attributes with a well-known
    container name are additionally searched one level deep. Returns
    ``(layers, dotted_path)`` or ``(None, None)``.
    """
    # Resolved outside the try block so a missing torch fails loudly.
    module_list_cls = torch.nn.ModuleList
    container_names = ("model", "transformer", "gpt_neox", "h_stack", "l_stack")

    for attr_name in dir(model):
        if attr_name.startswith("_"):
            continue
        try:
            attr = getattr(model, attr_name, None)
            if isinstance(attr, module_list_cls) and len(attr) > 0:
                return list(attr), attr_name

            # Check one level deep for common container names.
            if attr_name not in container_names:
                continue
            for sub_attr_name in dir(attr):
                if sub_attr_name.startswith("_"):
                    continue
                sub_attr = getattr(attr, sub_attr_name, None)
                if isinstance(sub_attr, module_list_cls) and len(sub_attr) > 0:
                    return list(sub_attr), f"{attr_name}.{sub_attr_name}"
        except Exception:
            # Best-effort discovery: properties on arbitrary model objects
            # may raise on access; skip the offending attribute.
            continue
    return None, None


def _find_attention_attr(layers: list[Any]) -> str | None:
    """Return the first attention attribute name present on every layer."""
    for candidate in _ATTENTION_ATTR_CANDIDATES:
        try:
            if all(getattr(layer, candidate, None) is not None for layer in layers):
                return candidate
        except Exception:
            # Attribute access on a layer may raise something other than
            # AttributeError; treat that candidate as not usable.
            continue
    return None


def describe_model_support(model: Any) -> ModelSupport:
    """Inspect *model* and report whether attribution analysis is possible.

    The check proceeds in three steps:

    1. Locate the layer container, first via the known attribute paths and,
       failing that, by scanning public attributes for a ``ModuleList``.
    2. Find an attention attribute that exists on every located layer.
    3. Require the configured attention implementation to be ``"eager"``.

    Args:
        model: The model object to inspect.

    Returns:
        A ``ModelSupport`` describing what was found. ``supports_attribution``
        is True only when all three steps succeed; otherwise ``reason``
        explains which step failed.
    """
    attn_impl = _get_attention_impl(model)

    # Step 1: known paths first, then aggressive discovery as a fallback.
    layers, layer_path = _find_layers_from_candidates(model)
    if not layers:
        layers, layer_path = _find_layers_by_scanning(model)

    if not layers:
        return ModelSupport(
            supports_attribution=False,
            reason=f"Unsupported model structure: unable to locate decoder layers (Model type: {type(model).__name__}).",
            layer_path=layer_path,
            attention_attr=None,
            layer_count=0,
            attention_impl=attn_impl,
        )

    # Step 2: find the attention attribute shared by all layers.
    attention_attr = _find_attention_attr(layers)
    if attention_attr is None:
        return ModelSupport(
            supports_attribution=False,
            reason="Unsupported attention module layout: no known attention attribute found on decoder layers.",
            layer_path=layer_path,
            attention_attr=None,
            layer_count=len(layers),
            attention_impl=attn_impl,
        )

    # Step 3: the structure is usable, but only the eager implementation is accepted.
    is_eager = attn_impl == "eager"
    return ModelSupport(
        supports_attribution=is_eager,
        reason=None if is_eager else "Attention gradients require attn_implementation='eager'.",
        layer_path=layer_path,
        attention_attr=attention_attr,
        layer_count=len(layers),
        attention_impl=attn_impl,
    )


def get_decoder_layers(model: Any) -> tuple[list[Any], str, str]:
    """Return the model's layers plus the metadata needed to reach them.

    Returns:
        A ``(layers, layer_path, attention_attr)`` tuple, where ``layers`` is
        a list built from the container found at ``layer_path``.

    Raises:
        RuntimeError: If the model does not support attribution, or if the
            reported layer path can no longer be resolved on the model.
    """
    info = describe_model_support(model)
    path = info.layer_path
    attn_name = info.attention_attr

    usable = info.supports_attribution and path is not None and attn_name is not None
    if not usable:
        raise RuntimeError(info.reason or "Model does not support attribution analysis.")

    resolved = _resolve_attr_chain(model, tuple(path.split(".")))
    if resolved is None:
        raise RuntimeError("Model support metadata became inconsistent while resolving layers.")
    return list(resolved), path, attn_name


def should_use_prefix_token_type_ids(model: Any) -> bool:
    """Tell whether *model* should be given ``token_type_ids``.

    True only when ``model.config.prefix_lm`` is truthy and the signature of
    ``model.forward`` has a ``token_type_ids`` parameter. Returns False when
    the config flag is unset, ``forward`` is missing, or its signature cannot
    be inspected.
    """
    prefix_lm_enabled = getattr(getattr(model, "config", None), "prefix_lm", False)
    if prefix_lm_enabled:
        forward_fn = getattr(model, "forward", None)
        if forward_fn is not None:
            try:
                signature = inspect.signature(forward_fn)
            except (TypeError, ValueError):
                # Not a callable, or a callable without an introspectable signature.
                return False
            return "token_type_ids" in signature.parameters
    return False


def add_prefix_token_type_ids(model: Any, encoded: dict[str, Any]) -> dict[str, Any]:
    """Return *encoded* with an all-ones ``token_type_ids`` entry when needed.

    When the model does not call for prefix token type ids, or *encoded* has
    no ``input_ids``, the input mapping is returned unchanged (same object).
    Otherwise a shallow copy is returned with ``token_type_ids`` set to
    ``input_ids.new_ones(input_ids.shape)``; the original is not mutated.
    """
    if should_use_prefix_token_type_ids(model):
        ids = encoded.get("input_ids")
        if ids is not None:
            # Copy first so the caller's mapping is left untouched.
            augmented = dict(encoded)
            augmented["token_type_ids"] = ids.new_ones(ids.shape)
            return augmented
    return encoded