transformer-xray / scripts /scan_kinds.py
AlexWortega's picture
(1) MoE/Loop prototype gets a '✱ 1 of N' / '🔁 1 of N' badge — makes the single example expert/layer obvious inside its container; (2) T5LayerFF / DenseActDense / DenseGatedActDense classified as MLP not Block (matched 'Layer' before); (3) scan_kinds.py utility to audit class→kind mappings across architectures
43cf1fd verified
"""Scan all class→kind mappings across the test architectures and flag
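suspicious classifications.

For each model, ``scan`` prints the module classes grouped by kind and then
warns about three things: more than two distinct Head classes, Block-kind
classes whose names look like MLP/feed-forward wrappers, and Other-kind nodes
at depth 2 or deeper. (Checks for a Head below top level or for same-kind
nested chains more than 2 deep are not among the checks in ``scan``.)"""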
import os, sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from collections import Counter
from backend.model_loader import _load_config_robust # noqa
from backend.static_graph import build_static_graph, classify # noqa
# Model identifiers audited by this script; each one is handed to scan()
# in the order listed here.
MODELS: list[str] = [
    "prajjwal1/bert-tiny",
    "bert-base-uncased",
    "distilbert/distilbert-base-uncased",
    "FacebookAI/roberta-base",
    "microsoft/deberta-v3-base",
    "openai-community/gpt2",
    "meta-llama/Llama-3.2-1B",
    "Qwen/Qwen2.5-0.5B",
    "microsoft/phi-2",
    "google/gemma-2-2b",
    "google-t5/t5-small",
    "facebook/bart-base",
    "openai/whisper-tiny",
    "google/vit-base-patch16-224",
    "openai/clip-vit-base-patch32",
    "answerdotai/ModernBERT-base",
]
def scan(model_id: str) -> None:
    """Print a class→kind audit for one model and flag suspicious mappings.

    Loads the config for ``model_id`` via ``_load_config_robust``, builds its
    static graph via ``build_static_graph``, prints every module class grouped
    by node kind, and then prints heuristic warnings.

    Args:
        model_id: Model identifier passed to ``_load_config_robust``.

    Returns:
        None. Everything is written to stdout. If loading the config or
        building the graph raises, a single ``FAIL`` line is printed and the
        function returns without auditing.
    """
    try:
        cfg = _load_config_robust(model_id)
        g = build_static_graph(cfg)
    except Exception as e:
        # Deliberately broad: this is an audit tool, so a model that cannot be
        # loaded is reported and skipped instead of aborting the caller's loop.
        print(f"\n=== {model_id} === FAIL: {e}")
        return
    print(f"\n=== {model_id} ({g['arch']}) — {len(g['nodes'])} nodes ===")

    # Group module class names by the kind assigned to their node.
    cls_by_kind: dict[str, set[str]] = {}
    for n in g["nodes"]:
        cls_by_kind.setdefault(n["kind"], set()).add(n["module_class"])
    for kind in sorted(cls_by_kind):
        print(f" {kind:11s} {sorted(cls_by_kind[kind])}")

    # Heuristic warnings.
    # Only warn once more than two distinct classes are classified as Head.
    head_classes = cls_by_kind.get("Head", set())
    if len(head_classes) > 2:
        print(f" ⚠ MULTIPLE HEAD CLASSES: {sorted(head_classes)}")

    # Block classes — should be transformer-layer-y, so a feed-forward-looking
    # name here suggests a misclassification. The "FF" test is a plain suffix
    # check so names such as "T5LayerFF" are caught; the previous expression
    # stripped the suffix before testing for it and therefore never matched.
    # Sorted so the report is deterministic, like the other listings.
    blocks = cls_by_kind.get("Block", set())
    suspicious_blocks = sorted(
        c
        for c in blocks
        if "MLP" in c or "FFN" in c or "FeedForward" in c or c.endswith("FF")
    )
    if suspicious_blocks:
        print(f" ⚠ Block-classified looks MLP-y: {suspicious_blocks}")

    # Other classes — should be top-level model wrappers (depth 0/1) only.
    other_classes = [
        f"d={n['depth']} {n['module_class']}"
        for n in g["nodes"]
        if n["kind"] == "Other"
        and n["depth"] >= 2
        and n["module_class"] not in ("Other", "")
    ]
    if other_classes:
        # Show at most five entries; the ellipsis marks a truncated list.
        print(f" ℹ Other-kind beyond depth-1: {other_classes[:5]}{'…' if len(other_classes) > 5 else ''}")
# Script driver: audit every configured model, in list order.
for model_id in MODELS:
    scan(model_id)