"""Scan class→kind mappings across the test architectures and flag suspicious ones.

For every model id in ``MODELS`` the static graph is built from the model's
config, the module classes are printed grouped by the kind assigned to them,
and the following heuristic warnings are emitted:

* more than two distinct classes classified as ``Head``;
* ``Block``-classified classes whose names look like MLP / feed-forward wrappers;
* ``Other``-classified nodes found at depth 2 or deeper.
"""
import os, sys

# Make the repository root importable when this file is run directly as a script.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from collections import Counter
from backend.model_loader import _load_config_robust  # noqa
from backend.static_graph import build_static_graph, classify  # noqa

MODELS = [
    "prajjwal1/bert-tiny",
    "bert-base-uncased",
    "distilbert/distilbert-base-uncased",
    "FacebookAI/roberta-base",
    "microsoft/deberta-v3-base",
    "openai-community/gpt2",
    "meta-llama/Llama-3.2-1B",
    "Qwen/Qwen2.5-0.5B",
    "microsoft/phi-2",
    "google/gemma-2-2b",
    "google-t5/t5-small",
    "facebook/bart-base",
    "openai/whisper-tiny",
    "google/vit-base-patch16-224",
    "openai/clip-vit-base-patch32",
    "answerdotai/ModernBERT-base",
]


def _looks_mlp_like(class_name: str) -> bool:
    """Return True when *class_name* reads like an MLP / feed-forward wrapper."""
    # The previous check, `"FF" in c.removesuffix("FF")[-2:]`, stripped the very
    # suffix it was looking for and therefore could only match names ending in
    # "FFFF"; a plain suffix test is what catches names such as "T5LayerFF".
    return (
        any(token in class_name for token in ("MLP", "FFN", "FeedForward"))
        or class_name.endswith("FF")
    )


def scan(model_id: str):
    """Print the class→kind grouping and heuristic warnings for one model.

    Args:
        model_id: Identifier passed to ``_load_config_robust``.

    Returns:
        None. All results are printed. If loading the config or building the
        graph raises, a single ``FAIL`` line is printed and the function
        returns, so one broken model does not abort the whole scan.
    """
    try:
        cfg = _load_config_robust(model_id)
        g = build_static_graph(cfg)
    except Exception as e:  # deliberate best-effort boundary: report and move on
        print(f"\n=== {model_id} === FAIL: {e}")
        return

    nodes = g["nodes"]
    print(f"\n=== {model_id} ({g['arch']}) — {len(nodes)} nodes ===")

    # Group module class names by the kind assigned to their nodes.
    cls_by_kind: dict[str, set[str]] = {}
    for n in nodes:
        cls_by_kind.setdefault(n["kind"], set()).add(n["module_class"])
    for kind in sorted(cls_by_kind):
        print(f" {kind:11s} {sorted(cls_by_kind[kind])}")

    # Heuristic warnings.
    head_classes = cls_by_kind.get("Head", set())
    # NOTE(review): the threshold is "> 2", so exactly two distinct Head classes
    # are not reported — confirm that is intended.
    if len(head_classes) > 2:
        print(f" ⚠ MULTIPLE HEAD CLASSES: {sorted(head_classes)}")

    # Block classes — should be transformer-layer-y, not MLP wrappers.
    # Sorted so the report is stable across runs (set iteration order is not).
    suspicious_blocks = sorted(
        c for c in cls_by_kind.get("Block", set()) if _looks_mlp_like(c)
    )
    if suspicious_blocks:
        print(f" ⚠ Block-classified looks MLP-y: {suspicious_blocks}")

    # Other classes — should be top-level model wrappers (depth 0/1) only.
    other_classes = [
        f"d={n['depth']} {n['module_class']}"
        for n in nodes
        if n["kind"] == "Other"
        and n["depth"] >= 2
        and n["module_class"] not in ("Other", "")
    ]
    if other_classes:
        # Show at most five entries; the ellipsis marks a truncated list.
        print(f" ℹ Other-kind beyond depth-1: {other_classes[:5]}{'…' if len(other_classes) > 5 else ''}")


# Run the scan for every configured architecture.
for m in MODELS:
    scan(m)