Spaces:
Sleeping
Sleeping
File size: 6,053 Bytes
fda8fb3 2620860 fda8fb3 aca1293 fda8fb3 3bd6f97 fda8fb3 3bd6f97 fda8fb3 3bd6f97 fda8fb3 3bd6f97 fda8fb3 3bd6f97 fda8fb3 3bd6f97 fda8fb3 3bd6f97 fda8fb3 2620860 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 | from __future__ import annotations
import inspect
from dataclasses import dataclass
from typing import Any
@dataclass(frozen=True, slots=True)
class ModelSupport:
supports_attribution: bool
reason: str | None
layer_path: str | None
attention_attr: str | None
layer_count: int
attention_impl: str | None
_LAYER_PATH_CANDIDATES: tuple[tuple[str, ...], ...] = (
("model", "layers"),
("model", "model", "layers"),
("transformer", "h"),
("gpt_neox", "layers"),
("model", "h_stack"),
("model", "l_stack"),
("layers",),
)
_ATTENTION_ATTR_CANDIDATES: tuple[str, ...] = ("self_attn", "attn", "attention")
def _resolve_attr_chain(obj: Any, path: tuple[str, ...]) -> Any | None:
current = obj
for segment in path:
current = getattr(current, segment, None)
if current is None:
return None
return current
def _get_attention_impl(model: Any) -> str | None:
config = getattr(model, "config", None)
if config is None:
return None
return getattr(config, "_attn_implementation", None) or getattr(config, "attn_implementation", None)
def _locate_decoder_layers(model: Any) -> tuple[list[Any] | None, str | None]:
    """Find the model's stack of decoder layers.

    Returns ``(layers, dotted_path)``, where ``dotted_path`` can be resolved
    again with ``_resolve_attr_chain(model, tuple(dotted_path.split(".")))``.
    Returns ``(None, None)`` when no non-empty stack is found.
    """
    # Imported here, outside the best-effort ``try`` blocks below, so a missing
    # torch fails loudly. Otherwise it would be swallowed and every model would
    # be reported as "unsupported".
    import torch

    # Step 1: known attribute paths, in priority order.
    for candidate in _LAYER_PATH_CANDIDATES:
        try:
            maybe_layers = _resolve_attr_chain(model, candidate)
            if maybe_layers is None or not hasattr(maybe_layers, "__iter__"):
                continue
            items = list(maybe_layers)
        except Exception:
            # Best-effort probe: arbitrary model attributes may raise on
            # access or iteration; treat that as "not this path".
            continue
        if items and isinstance(items[0], torch.nn.Module):
            return items, ".".join(candidate)

    # Step 2: aggressive discovery. Take the first non-empty ModuleList among
    # the model's public attributes, or one level inside common container
    # attributes.
    container_names = ("model", "transformer", "gpt_neox", "h_stack", "l_stack")
    for attr_name in dir(model):
        if attr_name.startswith("_"):
            continue
        try:
            attr = getattr(model, attr_name, None)
            if isinstance(attr, torch.nn.ModuleList) and len(attr) > 0:
                return list(attr), attr_name
            if attr_name in container_names:
                for sub_attr_name in dir(attr):
                    if sub_attr_name.startswith("_"):
                        continue
                    sub_attr = getattr(attr, sub_attr_name, None)
                    if isinstance(sub_attr, torch.nn.ModuleList) and len(sub_attr) > 0:
                        return list(sub_attr), f"{attr_name}.{sub_attr_name}"
        except Exception:
            # Best-effort: property access on arbitrary attributes may raise.
            continue

    return None, None


def describe_model_support(model: Any) -> ModelSupport:
    """Report whether ``model`` has the structure attention attribution needs.

    The requirements are, in order:
    1. A non-empty stack of decoder layers (see ``_locate_decoder_layers``).
    2. An attention attribute from ``_ATTENTION_ATTR_CANDIDATES`` present
       (not None) on every layer.
    3. The config's attention implementation equals ``"eager"``.

    Unsupported layouts are reported through the returned ``ModelSupport``
    (``supports_attribution=False`` plus a ``reason``) rather than raised.

    Raises:
        ImportError: if torch is not installed.
    """
    attn_impl = _get_attention_impl(model)
    layers, layer_path = _locate_decoder_layers(model)

    if not layers:
        return ModelSupport(
            supports_attribution=False,
            reason=f"Unsupported model structure: unable to locate decoder layers (Model type: {type(model).__name__}).",
            layer_path=layer_path,
            attention_attr=None,
            layer_count=0,
            attention_impl=attn_impl,
        )

    # Step 3: pick the first attention attribute name shared by all layers.
    for attention_attr in _ATTENTION_ATTR_CANDIDATES:
        try:
            present = all(getattr(layer, attention_attr, None) is not None for layer in layers)
        except Exception:
            # Best-effort: a layer's attribute access may raise.
            continue
        if not present:
            continue
        if attn_impl != "eager":
            # Gradients through fused/SDPA attention kernels are not exposed,
            # so only the eager implementation is accepted.
            return ModelSupport(
                supports_attribution=False,
                reason="Attention gradients require attn_implementation='eager'.",
                layer_path=layer_path,
                attention_attr=attention_attr,
                layer_count=len(layers),
                attention_impl=attn_impl,
            )
        return ModelSupport(
            supports_attribution=True,
            reason=None,
            layer_path=layer_path,
            attention_attr=attention_attr,
            layer_count=len(layers),
            attention_impl=attn_impl,
        )

    return ModelSupport(
        supports_attribution=False,
        reason="Unsupported attention module layout: no known attention attribute found on decoder layers.",
        layer_path=layer_path,
        attention_attr=None,
        layer_count=len(layers),
        attention_impl=attn_impl,
    )
def get_decoder_layers(model: Any) -> tuple[list[Any], str, str]:
    """Return ``(layers, layer_path, attention_attr)`` for a supported model.

    Raises:
        RuntimeError: with the support ``reason`` (or a generic message) when
            the model is not usable for attribution. Also raised if the
            recorded layer path no longer resolves.
    """
    support = describe_model_support(model)
    usable = (
        support.supports_attribution
        and support.layer_path is not None
        and support.attention_attr is not None
    )
    if not usable:
        raise RuntimeError(support.reason or "Model does not support attribution analysis.")

    path = support.layer_path
    resolved = _resolve_attr_chain(model, tuple(path.split(".")))
    if resolved is None:
        raise RuntimeError("Model support metadata became inconsistent while resolving layers.")
    return list(resolved), path, support.attention_attr
def should_use_prefix_token_type_ids(model: Any) -> bool:
    """Decide whether prefix ``token_type_ids`` should be passed to ``model``.

    True only when ``model.config.prefix_lm`` is truthy and ``model.forward``
    has a parameter literally named ``token_type_ids``. Returns False when the
    config, the flag or ``forward`` is missing, or when the signature cannot
    be inspected.
    """
    uses_prefix_lm = getattr(getattr(model, "config", None), "prefix_lm", False)
    if not uses_prefix_lm:
        return False

    forward_fn = getattr(model, "forward", None)
    if forward_fn is None:
        return False

    try:
        sig = inspect.signature(forward_fn)
    except (TypeError, ValueError):
        return False
    return "token_type_ids" in sig.parameters
def add_prefix_token_type_ids(model: Any, encoded: dict[str, Any]) -> dict[str, Any]:
    """Attach all-ones ``token_type_ids`` to ``encoded`` when the model needs them.

    When ``should_use_prefix_token_type_ids(model)`` is true and
    ``encoded["input_ids"]`` exists, returns a new shallow-copied dict with
    ``token_type_ids`` built via ``input_ids.new_ones(input_ids.shape)``.
    The input mapping is not mutated; in every other case ``encoded`` itself
    is returned.

    ``input_ids`` presumably is a torch tensor (it must provide ``new_ones``
    and ``shape``) — NOTE(review): confirm with callers.
    """
    if should_use_prefix_token_type_ids(model):
        input_ids = encoded.get("input_ids")
        if input_ids is not None:
            return {**encoded, "token_type_ids": input_ids.new_ones(input_ids.shape)}
    return encoded
|