from __future__ import annotations import inspect from dataclasses import dataclass from typing import Any @dataclass(frozen=True, slots=True) class ModelSupport: supports_attribution: bool reason: str | None layer_path: str | None attention_attr: str | None layer_count: int attention_impl: str | None _LAYER_PATH_CANDIDATES: tuple[tuple[str, ...], ...] = ( ("model", "layers"), ("model", "model", "layers"), ("transformer", "h"), ("gpt_neox", "layers"), ("model", "h_stack"), ("model", "l_stack"), ("layers",), ) _ATTENTION_ATTR_CANDIDATES: tuple[str, ...] = ("self_attn", "attn", "attention") def _resolve_attr_chain(obj: Any, path: tuple[str, ...]) -> Any | None: current = obj for segment in path: current = getattr(current, segment, None) if current is None: return None return current def _get_attention_impl(model: Any) -> str | None: config = getattr(model, "config", None) if config is None: return None return getattr(config, "_attn_implementation", None) or getattr(config, "attn_implementation", None) def describe_model_support(model: Any) -> ModelSupport: attn_impl = _get_attention_impl(model) layers = None layer_path = None # Step 1: Try hardcoded candidates for candidate in _LAYER_PATH_CANDIDATES: try: maybe_layers = _resolve_attr_chain(model, candidate) if maybe_layers is not None and hasattr(maybe_layers, "__iter__"): l = list(maybe_layers) if len(l) > 0 and isinstance(l[0], torch.nn.Module): layers = l layer_path = ".".join(candidate) break except Exception: continue # Step 2: Aggressive discovery if candidates failed if not layers: for attr_name in dir(model): if attr_name.startswith("_"): continue try: attr = getattr(model, attr_name, None) if isinstance(attr, torch.nn.ModuleList) and len(attr) > 0: layers = list(attr) layer_path = attr_name break # Check one level deep for common container names if attr_name in ("model", "transformer", "gpt_neox", "h_stack", "l_stack"): for sub_attr_name in dir(attr): if sub_attr_name.startswith("_"): continue sub_attr = getattr(attr, sub_attr_name, None) if isinstance(sub_attr, torch.nn.ModuleList) and len(sub_attr) > 0: layers = list(sub_attr) layer_path = f"{attr_name}.{sub_attr_name}" break if layers: break except Exception: continue if not layers: return ModelSupport( supports_attribution=False, reason=f"Unsupported model structure: unable to locate decoder layers (Model type: {type(model).__name__}).", layer_path=layer_path, attention_attr=None, layer_count=0, attention_impl=attn_impl, ) # Step 3: Find attention attribute for attention_attr in _ATTENTION_ATTR_CANDIDATES: try: if all(getattr(layer, attention_attr, None) is not None for layer in layers): if attn_impl != "eager": return ModelSupport( supports_attribution=False, reason="Attention gradients require attn_implementation='eager'.", layer_path=layer_path, attention_attr=attention_attr, layer_count=len(layers), attention_impl=attn_impl, ) return ModelSupport( supports_attribution=True, reason=None, layer_path=layer_path, attention_attr=attention_attr, layer_count=len(layers), attention_impl=attn_impl, ) except Exception: continue return ModelSupport( supports_attribution=False, reason="Unsupported attention module layout: no known attention attribute found on decoder layers.", layer_path=layer_path, attention_attr=None, layer_count=len(layers), attention_impl=attn_impl, ) def get_decoder_layers(model: Any) -> tuple[list[Any], str, str]: support = describe_model_support(model) if not support.supports_attribution or support.layer_path is None or support.attention_attr is None: reason = support.reason or "Model does not support attribution analysis." raise RuntimeError(reason) layers = _resolve_attr_chain(model, tuple(support.layer_path.split("."))) if layers is None: raise RuntimeError("Model support metadata became inconsistent while resolving layers.") return list(layers), support.layer_path, support.attention_attr def should_use_prefix_token_type_ids(model: Any) -> bool: config = getattr(model, "config", None) if not getattr(config, "prefix_lm", False): return False forward = getattr(model, "forward", None) if forward is None: return False try: parameters = inspect.signature(forward).parameters except (TypeError, ValueError): return False return "token_type_ids" in parameters def add_prefix_token_type_ids(model: Any, encoded: dict[str, Any]) -> dict[str, Any]: if not should_use_prefix_token_type_ids(model): return encoded input_ids = encoded.get("input_ids") if input_ids is None: return encoded encoded = dict(encoded) encoded["token_type_ids"] = input_ids.new_ones(input_ids.shape) return encoded