Spaces:
Sleeping
Sleeping
from __future__ import annotations

import inspect
from dataclasses import dataclass
from typing import Any

import torch
@dataclass
class ModelSupport:
    """Result of probing a model for attention-attribution support.

    Instances are produced by ``describe_model_support``. When
    ``supports_attribution`` is False, ``reason`` explains why.
    """

    # As set by describe_model_support: True only when decoder layers and a
    # per-layer attention attribute were found and attention_impl == "eager".
    supports_attribution: bool
    # Human-readable explanation when unsupported; None when supported.
    reason: str | None
    # Dotted attribute path from the model to its decoder-layer container
    # (e.g. "model.layers"); None if no layers were located.
    layer_path: str | None
    # Attribute name on each decoder layer that holds its attention module.
    attention_attr: str | None
    # Number of decoder layers found (0 when none were located).
    layer_count: int
    # Value of the config's `_attn_implementation` / `attn_implementation`, if any.
    attention_impl: str | None
# Attribute paths where decoder layers are commonly found. They are written as
# dotted strings for readability and tried in this order by the first
# discovery step of describe_model_support.
_LAYER_PATH_CANDIDATES: tuple[tuple[str, ...], ...] = tuple(
    tuple(dotted.split("."))
    for dotted in (
        "model.layers",
        "model.model.layers",
        "transformer.h",
        "gpt_neox.layers",
        "model.h_stack",
        "model.l_stack",
        "layers",
    )
)

# Per-layer attribute names that may hold the attention module, in priority order.
_ATTENTION_ATTR_CANDIDATES: tuple[str, ...] = tuple(
    name for name in ("self_attn", "attn", "attention")
)
def _resolve_attr_chain(obj: Any, path: tuple[str, ...]) -> Any | None:
    """Follow ``path`` one attribute at a time, starting from ``obj``.

    Returns the final attribute value. Returns None as soon as any segment is
    missing or evaluates to None. An empty ``path`` returns ``obj`` itself.
    """
    node = obj
    for name in path:
        node = getattr(node, name, None)
        if node is None:
            break
    return node
def _get_attention_impl(model: Any) -> str | None:
    """Read the attention implementation name from ``model.config``.

    A truthy ``config._attn_implementation`` wins. Otherwise the value of
    ``config.attn_implementation`` is returned, which may itself be None.
    Returns None when the model has no ``config`` or the config is None.
    """
    config = getattr(model, "config", None)
    if config is not None:
        private_value = getattr(config, "_attn_implementation", None)
        if private_value:
            return private_value
        return getattr(config, "attn_implementation", None)
    return None
def _safe_getattr(obj: Any, name: str) -> Any:
    """Return ``getattr(obj, name, None)``, treating any lookup exception as "missing".

    Model objects may expose properties that raise arbitrary errors when read.
    Discovery is best-effort, so such attributes are simply skipped.
    """
    try:
        return getattr(obj, name, None)
    except Exception:
        return None


def _is_nonempty_module_list(value: Any) -> bool:
    """Return True if ``value`` is a ``torch.nn.ModuleList`` with at least one entry."""
    return isinstance(value, torch.nn.ModuleList) and len(value) > 0


def _find_layers_from_candidates(model: Any) -> tuple[list[Any], str] | None:
    """Try the known paths in ``_LAYER_PATH_CANDIDATES``, in order.

    Returns ``(layers, dotted_path)`` for the first path that resolves to a
    non-empty iterable whose first element is a ``torch.nn.Module``.
    Returns None if no path matches.
    """
    for candidate in _LAYER_PATH_CANDIDATES:
        try:
            maybe_layers = _resolve_attr_chain(model, candidate)
            if maybe_layers is None or not hasattr(maybe_layers, "__iter__"):
                continue
            items = list(maybe_layers)
        except Exception:
            # Attribute access / iteration on arbitrary objects is best-effort.
            continue
        # Deliberately outside the try: genuine programming errors (e.g. a
        # missing `torch` import) must surface instead of failing every candidate.
        if items and isinstance(items[0], torch.nn.Module):
            return items, ".".join(candidate)
    return None


def _discover_layers(model: Any) -> tuple[list[Any], str] | None:
    """Fallback scan of ``model`` for a non-empty ``torch.nn.ModuleList``.

    Public attributes are visited in ``dir(model)`` order and each one is
    checked directly. Attributes with common container names are also scanned
    one level deep. Returns ``(layers, dotted_path)`` for the first match,
    else None.
    """
    container_names = ("model", "transformer", "gpt_neox", "h_stack", "l_stack")
    for attr_name in dir(model):
        if attr_name.startswith("_"):
            continue
        attr = _safe_getattr(model, attr_name)
        if _is_nonempty_module_list(attr):
            return list(attr), attr_name
        if attr is None or attr_name not in container_names:
            continue
        try:
            sub_names = dir(attr)
        except Exception:
            continue
        for sub_name in sub_names:
            if sub_name.startswith("_"):
                continue
            # Per-attribute safe lookup: one raising property no longer aborts
            # the scan of the whole container.
            sub_attr = _safe_getattr(attr, sub_name)
            if _is_nonempty_module_list(sub_attr):
                return list(sub_attr), f"{attr_name}.{sub_name}"
    return None


def _find_attention_attr(layers: list[Any]) -> str | None:
    """Return the first ``_ATTENTION_ATTR_CANDIDATES`` name that is non-None on every layer."""
    for attention_attr in _ATTENTION_ATTR_CANDIDATES:
        if all(_safe_getattr(layer, attention_attr) is not None for layer in layers):
            return attention_attr
    return None


def describe_model_support(model: Any) -> ModelSupport:
    """Probe ``model`` and report whether attention attribution can run on it.

    Decoder layers are located in two steps. First, the known paths in
    ``_LAYER_PATH_CANDIDATES`` are tried. If none match, the model's
    attributes are scanned for a non-empty ``torch.nn.ModuleList``. Every
    layer must then expose one of ``_ATTENTION_ATTR_CANDIDATES``, and the
    model config must select the ``"eager"`` attention implementation.

    Args:
        model: The model to inspect. It is presumably a ``torch.nn.Module``,
            but only attribute access is required.

    Returns:
        A ``ModelSupport`` describing what was found. ``supports_attribution``
        is True only when all of the requirements above are met; otherwise
        ``reason`` explains the first failed requirement.
    """
    attn_impl = _get_attention_impl(model)

    found = _find_layers_from_candidates(model) or _discover_layers(model)
    if found is None:
        return ModelSupport(
            supports_attribution=False,
            reason=f"Unsupported model structure: unable to locate decoder layers (Model type: {type(model).__name__}).",
            layer_path=None,
            attention_attr=None,
            layer_count=0,
            attention_impl=attn_impl,
        )
    layers, layer_path = found

    attention_attr = _find_attention_attr(layers)
    if attention_attr is None:
        return ModelSupport(
            supports_attribution=False,
            reason="Unsupported attention module layout: no known attention attribute found on decoder layers.",
            layer_path=layer_path,
            attention_attr=None,
            layer_count=len(layers),
            attention_impl=attn_impl,
        )

    # Attention-weight gradients are only available with the eager implementation.
    if attn_impl != "eager":
        return ModelSupport(
            supports_attribution=False,
            reason="Attention gradients require attn_implementation='eager'.",
            layer_path=layer_path,
            attention_attr=attention_attr,
            layer_count=len(layers),
            attention_impl=attn_impl,
        )

    return ModelSupport(
        supports_attribution=True,
        reason=None,
        layer_path=layer_path,
        attention_attr=attention_attr,
        layer_count=len(layers),
        attention_impl=attn_impl,
    )
def get_decoder_layers(model: Any) -> tuple[list[Any], str, str]:
    """Return the decoder layers of ``model`` together with how they were located.

    Returns:
        A tuple ``(layers, layer_path, attention_attr)``:
        - ``layers`` is a list of the layer objects.
        - ``layer_path`` is the dotted attribute path they were read from.
        - ``attention_attr`` is the per-layer attribute holding attention.

    Raises:
        RuntimeError: if ``describe_model_support`` reports the model as
            unsupported, or if the reported layer path no longer resolves.
    """
    support = describe_model_support(model)
    usable = (
        support.supports_attribution
        and support.layer_path is not None
        and support.attention_attr is not None
    )
    if not usable:
        raise RuntimeError(support.reason or "Model does not support attribution analysis.")

    path = support.layer_path
    resolved = _resolve_attr_chain(model, tuple(path.split(".")))
    if resolved is None:
        raise RuntimeError("Model support metadata became inconsistent while resolving layers.")
    return list(resolved), path, support.attention_attr
def should_use_prefix_token_type_ids(model: Any) -> bool:
    """Return True if ``model`` should be given prefix ``token_type_ids``.

    Two conditions must both hold:
    - ``model.config.prefix_lm`` is truthy.
    - The model's ``forward`` signature has a ``token_type_ids`` parameter.

    A missing config or ``forward``, or a ``forward`` whose signature cannot
    be inspected, yields False.
    """
    prefix_lm = getattr(getattr(model, "config", None), "prefix_lm", False)
    # Only look up `forward` once prefix-LM mode is confirmed.
    forward = getattr(model, "forward", None) if prefix_lm else None
    if forward is None:
        return False
    try:
        signature = inspect.signature(forward)
    except (TypeError, ValueError):
        return False
    return "token_type_ids" in signature.parameters
def add_prefix_token_type_ids(model: Any, encoded: dict[str, Any]) -> dict[str, Any]:
    """Return ``encoded`` with an all-ones ``token_type_ids`` entry for prefix-LM models.

    ``encoded`` itself is returned unchanged in two cases:
    - ``should_use_prefix_token_type_ids(model)`` is False.
    - ``encoded`` has no ``"input_ids"``.

    Otherwise a new plain ``dict`` is returned and the input is not mutated.
    Its ``token_type_ids`` are built with ``input_ids.new_ones(input_ids.shape)``.
    This presumes ``input_ids`` is a torch tensor (TODO confirm against callers).
    """
    if not should_use_prefix_token_type_ids(model):
        return encoded
    input_ids = encoded.get("input_ids")
    if input_ids is None:
        return encoded
    return {**encoded, "token_type_ids": input_ids.new_ones(input_ids.shape)}