# cot-anc / app/core/model_support.py
# Last commit by BART-ender: "fix(core): implement robust auto-discovery for model layers"
# Commit 3bd6f97 (verified)
from __future__ import annotations

import inspect
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Any

import torch
@dataclass(frozen=True, slots=True)
class ModelSupport:
supports_attribution: bool
reason: str | None
layer_path: str | None
attention_attr: str | None
layer_count: int
attention_impl: str | None
# Attribute chains probed, in priority order, to find the decoder-layer stack;
# the first chain that resolves to a non-empty sequence of modules is used.
_LAYER_PATH_CANDIDATES: tuple[tuple[str, ...], ...] = tuple(
    tuple(dotted.split("."))
    for dotted in (
        "model.layers",
        "model.model.layers",
        "transformer.h",
        "gpt_neox.layers",
        "model.h_stack",
        "model.l_stack",
        "layers",
    )
)
# Attribute names probed, in order, on every decoder layer to find its attention module.
_ATTENTION_ATTR_CANDIDATES: tuple[str, ...] = tuple("self_attn attn attention".split())
def _resolve_attr_chain(obj: Any, path: tuple[str, ...]) -> Any | None:
    """Follow ``path`` attribute by attribute starting from ``obj``.

    Returns the final attribute, or ``None`` as soon as any step is missing
    (``AttributeError``) or evaluates to ``None``. An empty ``path`` returns
    ``obj`` unchanged.
    """
    target = obj
    for attr_name in path:
        target = getattr(target, attr_name, None)
        if target is None:
            # Missing link: stop walking; ``target`` is already None.
            break
    return target
def _get_attention_impl(model: Any) -> str | None:
    """Read the attention implementation name from ``model.config``.

    Prefers ``config._attn_implementation`` when it is truthy; otherwise
    returns whatever ``config.attn_implementation`` holds (or ``None`` if
    absent). Returns ``None`` when the model has no ``config``.
    """
    config = getattr(model, "config", None)
    if config is None:
        return None
    private_impl = getattr(config, "_attn_implementation", None)
    if private_impl:
        return private_impl
    # Fall back to the public attribute, returned as-is even if falsy.
    return getattr(config, "attn_implementation", None)
def _as_module_list(candidate: Any) -> list[Any] | None:
    """Materialize ``candidate`` as a list if it is a non-empty iterable of ``torch.nn.Module``.

    Returns ``None`` for ``None``, non-iterables, empty iterables, iterables
    that refuse iteration, or iterables containing any non-module item.
    """
    if candidate is None or not hasattr(candidate, "__iter__"):
        return None
    try:
        items = list(candidate)
    except Exception:
        # Best-effort probe: some objects expose __iter__ but refuse iteration
        # (e.g. a 0-d tensor raises TypeError). Treat those as "not layers".
        return None
    if items and all(isinstance(item, torch.nn.Module) for item in items):
        return items
    return None


def _iter_public_attrs(obj: Any) -> Iterator[tuple[str, Any]]:
    """Yield ``(name, value)`` for each public attribute of ``obj`` in ``dir()`` order.

    Attributes whose lookup raises are skipped, so one misbehaving property
    does not abort the whole scan. Lazy, so scanning stops as soon as the
    caller stops consuming.
    """
    try:
        names = dir(obj)
    except Exception:
        return
    for name in names:
        if name.startswith("_"):
            continue
        try:
            value = getattr(obj, name, None)
        except Exception:
            # Properties on arbitrary model objects may raise; skip them.
            continue
        yield name, value


def _find_layers_by_candidates(model: Any) -> tuple[list[Any] | None, str | None]:
    """Return ``(layers, dotted_path)`` for the first ``_LAYER_PATH_CANDIDATES`` chain that resolves to modules."""
    for candidate in _LAYER_PATH_CANDIDATES:
        try:
            resolved = _resolve_attr_chain(model, candidate)
        except Exception:
            # A property along the chain raised; treat this path as absent.
            continue
        layers = _as_module_list(resolved)
        if layers is not None:
            return layers, ".".join(candidate)
    return None, None


def _discover_layers_by_scan(model: Any) -> tuple[list[Any] | None, str | None]:
    """Fallback search for a non-empty ``torch.nn.ModuleList`` on ``model`` or one level below a known container attribute."""
    container_names = ("model", "transformer", "gpt_neox", "h_stack", "l_stack")
    for attr_name, attr in _iter_public_attrs(model):
        if isinstance(attr, torch.nn.ModuleList) and len(attr) > 0:
            return list(attr), attr_name
        if attr is None or attr_name not in container_names:
            continue
        for sub_name, sub_attr in _iter_public_attrs(attr):
            if isinstance(sub_attr, torch.nn.ModuleList) and len(sub_attr) > 0:
                return list(sub_attr), f"{attr_name}.{sub_name}"
    return None, None


def _find_attention_attr(layers: list[Any]) -> str | None:
    """Return the first ``_ATTENTION_ATTR_CANDIDATES`` name that is present (non-``None``) on every layer."""
    for attention_attr in _ATTENTION_ATTR_CANDIDATES:
        try:
            present = all(getattr(layer, attention_attr, None) is not None for layer in layers)
        except Exception:
            # A raising property on a layer disqualifies this candidate only.
            continue
        if present:
            return attention_attr
    return None


def describe_model_support(model: Any) -> ModelSupport:
    """Inspect ``model`` and report whether attention attribution can run on it.

    Discovery proceeds in three steps:

    1. Try the known attribute chains in ``_LAYER_PATH_CANDIDATES``; a chain
       counts only if it resolves to a non-empty iterable of ``torch.nn.Module``.
    2. Otherwise scan public attributes for a non-empty ``torch.nn.ModuleList``,
       either directly on ``model`` or one level below a known container name.
    3. Pick the first ``_ATTENTION_ATTR_CANDIDATES`` name present on every layer.

    Attribution is reported as supported only when all three steps succeed and
    the model config's attention implementation is ``"eager"``.

    Args:
        model: The model to inspect. Presumably a Hugging Face ``PreTrainedModel``
            or other ``torch.nn.Module`` — TODO confirm against callers.

    Returns:
        A ``ModelSupport``. When ``supports_attribution`` is False, ``reason``
        explains which step failed.
    """
    attn_impl = _get_attention_impl(model)

    layers, layer_path = _find_layers_by_candidates(model)
    if not layers:
        layers, layer_path = _discover_layers_by_scan(model)

    if not layers:
        return ModelSupport(
            supports_attribution=False,
            reason=f"Unsupported model structure: unable to locate decoder layers (Model type: {type(model).__name__}).",
            layer_path=layer_path,
            attention_attr=None,
            layer_count=0,
            attention_impl=attn_impl,
        )

    attention_attr = _find_attention_attr(layers)
    if attention_attr is None:
        return ModelSupport(
            supports_attribution=False,
            reason="Unsupported attention module layout: no known attention attribute found on decoder layers.",
            layer_path=layer_path,
            attention_attr=None,
            layer_count=len(layers),
            attention_impl=attn_impl,
        )

    # Attention-weight gradients are only exposed by the eager implementation.
    if attn_impl != "eager":
        return ModelSupport(
            supports_attribution=False,
            reason="Attention gradients require attn_implementation='eager'.",
            layer_path=layer_path,
            attention_attr=attention_attr,
            layer_count=len(layers),
            attention_impl=attn_impl,
        )

    return ModelSupport(
        supports_attribution=True,
        reason=None,
        layer_path=layer_path,
        attention_attr=attention_attr,
        layer_count=len(layers),
        attention_impl=attn_impl,
    )
def get_decoder_layers(model: Any) -> tuple[list[Any], str, str]:
    """Return ``model``'s decoder layers together with how they were located.

    Returns:
        ``(layers, layer_path, attention_attr)``. ``layers`` is a fresh list
        obtained by re-resolving ``layer_path`` on ``model``.

    Raises:
        RuntimeError: if ``describe_model_support`` reports the model as
            unsupported (the message is its ``reason``, or a generic fallback),
            or if ``layer_path`` no longer resolves on ``model``.
    """
    support = describe_model_support(model)
    layer_path = support.layer_path
    attention_attr = support.attention_attr
    usable = support.supports_attribution and layer_path is not None and attention_attr is not None
    if not usable:
        raise RuntimeError(support.reason or "Model does not support attribution analysis.")

    resolved = _resolve_attr_chain(model, tuple(layer_path.split(".")))
    if resolved is None:
        raise RuntimeError("Model support metadata became inconsistent while resolving layers.")
    return list(resolved), layer_path, attention_attr
def should_use_prefix_token_type_ids(model: Any) -> bool:
    """Decide whether ``token_type_ids`` should be passed for a prefix-LM model.

    Returns True only when ``model.config.prefix_lm`` is truthy and
    ``model.forward`` has an introspectable signature containing a
    ``token_type_ids`` parameter. A missing config, a missing ``forward``, or
    a ``forward`` whose signature cannot be inspected all yield False.
    """
    config = getattr(model, "config", None)
    # Only look up ``forward`` once prefix-LM mode is confirmed.
    forward = getattr(model, "forward", None) if getattr(config, "prefix_lm", False) else None
    if forward is None:
        return False
    try:
        accepted = inspect.signature(forward).parameters
    except (TypeError, ValueError):
        # Non-callables and some builtins have no retrievable signature.
        return False
    return "token_type_ids" in accepted
def add_prefix_token_type_ids(model: Any, encoded: dict[str, Any]) -> dict[str, Any]:
    """Attach all-ones ``token_type_ids`` shaped like ``input_ids`` when the model wants them.

    ``encoded`` is returned untouched when ``should_use_prefix_token_type_ids``
    is False or there is no ``input_ids`` entry. Otherwise a new plain ``dict``
    copy is returned, and the caller's mapping is not mutated. ``input_ids`` is
    assumed to be tensor-like, exposing ``new_ones`` and ``shape`` (presumably
    a ``torch.Tensor`` — TODO confirm).

    NOTE(review): the copy is a plain ``dict``, so any mapping subclass (e.g. a
    tokenizer's batch-encoding object) is dropped; confirm callers don't rely
    on its methods.
    """
    input_ids = encoded.get("input_ids") if should_use_prefix_token_type_ids(model) else None
    if input_ids is None:
        return encoded
    return {**encoded, "token_type_ids": input_ids.new_ones(input_ids.shape)}