# transformer-xray / scripts/test_architectures.py
# Last commit by AlexWortega:
#   (1) test_architectures.py: 17 archs covered with invariants (head detected, loop present,
#       FLOPs > 0, min depth, source extraction). All pass.
#   (2) FLOPs aggregation fixes: read tree dict 'args'+'repeat' fields (not 'config', which only
#       flat nodes have); Linear+Head both compute 2·in·out; Loop multiplies children by repeat.
#   (3) _eval handles getattr(config, 'X', default) — fixes Llama q_proj resolving
#       num_attention_heads * head_dim.
#   (4) HEAD_ATTR_NAMES adds cls, vocab_projector, embed_out for BERT/DeBERTa/DistilBert variants.
# Commit a5da241 (verified)
"""End-to-end test: load configs for a representative set of architectures,
build the static graph, and validate invariants. Prints a compact report."""
from __future__ import annotations
import os
import sys
import time
import traceback
from dataclasses import dataclass
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from backend.model_loader import _load_config_robust, _hf_token # noqa: E402
from backend.static_graph import build_static_graph, class_source # noqa: E402
from backend.modality import detect_modality, input_info, output_info # noqa: E402
@dataclass
class TestCase:
    """One architecture to sweep, plus the invariants its static graph must satisfy.

    Despite the ``Test`` prefix this is a plain data record, not a pytest test class.
    """

    # This file matches pytest's ``test_*.py`` pattern and the class name matches
    # ``Test*``; without this flag pytest tries to collect it and warns because the
    # dataclass defines ``__init__``. It is unannotated, so it is NOT a dataclass field.
    __test__ = False

    model_id: str  # Hugging Face Hub id handed to _load_config_robust
    expect_head: bool = True  # require at least one node with kind == "Head"
    expect_loop: bool = True  # require at least one node with kind == "Loop"
    expect_moe: bool = False  # require at least one node with kind == "MoE"
    head_kind: str | None = None  # head variant forwarded to build_static_graph / output_info
    min_nodes: int = 5  # minimum number of graph nodes
    max_depth: int = 4  # a *minimum*: the deepest node must reach at least this depth
    note: str = ""  # why this case is interesting; printed when the case fails
# Architectures swept by main(). Each entry pairs a Hub model id with the
# head_kind that run_one forwards to build_static_graph / output_info, plus
# the invariants to check. expect_head=False is set where the case's note says
# the resolved model class has no head.
CASES = [
    # ---- encoder-only (BERT family) ----
    TestCase(
        "prajjwal1/bert-tiny",
        expect_head=False,
        head_kind="encoder",
        min_nodes=10,
        note="config has no architectures field — falls back to BertModel (no head)",
    ),
    TestCase(
        "bert-base-uncased",
        head_kind="masked_lm",
        min_nodes=20,
        note="canonical encoder-only with MLM head (self.cls)",
    ),
    TestCase(
        "distilbert/distilbert-base-uncased",
        head_kind="masked_lm",
        min_nodes=20,
        note="DistilBert head = vocab_projector",
    ),
    TestCase("FacebookAI/roberta-base", head_kind="masked_lm", min_nodes=20),
    TestCase(
        "microsoft/deberta-v3-base",
        expect_head=False,
        head_kind="encoder",
        min_nodes=15,
        note="model_type 'deberta-v2' (hyphen); DebertaV2Model has no head",
    ),
    # ---- decoder-only causal LM (dense) ----
    TestCase(
        "openai-community/gpt2",
        head_kind="causal_lm",
        min_nodes=15,
        note="n_layer alias for num_hidden_layers",
    ),
    TestCase(
        "meta-llama/Llama-3.2-1B",
        head_kind="causal_lm",
        min_nodes=15,
        note="getattr(config, 'head_dim', default) idiom",
    ),
    TestCase("Qwen/Qwen2.5-0.5B", head_kind="causal_lm", min_nodes=15),
    TestCase("microsoft/phi-2", head_kind="causal_lm", min_nodes=15),
    TestCase("google/gemma-2-2b", head_kind="causal_lm", min_nodes=15),
    TestCase("mistralai/Mistral-7B-v0.1", head_kind="causal_lm", min_nodes=15),
    TestCase("stabilityai/stablelm-2-1_6b", head_kind="causal_lm", min_nodes=15),
    # ---- encoder-decoder ----
    TestCase(
        "google-t5/t5-small",
        head_kind="encoder",
        min_nodes=15,
        note="encoder+decoder, ModuleList().append idiom",
    ),
    TestCase(
        "facebook/bart-base",
        expect_head=False,
        head_kind="encoder",
        min_nodes=20,
        note="BartModel (encoder-decoder, no head)",
    ),
    TestCase(
        "openai/whisper-tiny",
        head_kind="causal_lm",
        min_nodes=15,
        note="audio encoder-decoder, proj_out as head",
    ),
    # ---- vision ----
    TestCase(
        "google/vit-base-patch16-224",
        head_kind="image_classification",
        min_nodes=15,
        note="ViT classifier via IfExp",
    ),
    TestCase(
        "openai/clip-vit-base-patch32",
        expect_head=False,
        head_kind="encoder",
        min_nodes=20,
        note="CLIPModel — no classifier head, just text+vision towers",
    ),
]
def has_kind(graph, kind: str) -> bool:
    """Report whether some entry of ``graph["nodes"]`` has ``"kind"`` equal to *kind*.

    Stops scanning at the first match.
    """
    for node in graph["nodes"]:
        if node.get("kind") == kind:
            return True
    return False


def has_head(graph) -> bool:
    """True when the graph contains a node of kind ``"Head"``."""
    return has_kind(graph, kind="Head")


def has_loop(graph) -> bool:
    """True when the graph contains a node of kind ``"Loop"``."""
    return has_kind(graph, kind="Loop")


def has_moe(graph) -> bool:
    """True when the graph contains a node of kind ``"MoE"``."""
    return has_kind(graph, kind="MoE")
def max_depth(graph) -> int:
    """Largest ``"depth"`` value over the graph's nodes.

    A node without a ``"depth"`` key counts as depth 0, and an empty node list
    yields 0.
    """
    depths = [node.get("depth", 0) for node in graph["nodes"]]
    return max(depths) if depths else 0
def run_one(case: TestCase) -> tuple[bool, str, dict]:
    """Load one model's config, build its static graph, and check the case's invariants.

    Args:
        case: The architecture under test and the invariants it must satisfy.

    Returns:
        ``(ok, message, info)``. ``ok`` is True only when every check passed.
        ``message`` is ``"ok"`` or a ``"; "``-joined list of failures. ``info``
        holds whatever metrics were collected before returning (always
        including ``"id"``), so callers can print a partial row after an
        early failure.
    """
    failures: list[str] = []
    info: dict = {"id": case.model_id}

    # 1. Config load: nothing else can run without it, so fail fast.
    #    perf_counter is monotonic, unlike time.time(), so durations can't go negative.
    t0 = time.perf_counter()
    try:
        cfg = _load_config_robust(case.model_id)
    except Exception as e:
        return False, f"config load: {type(e).__name__}: {e}", info
    info["config_ms"] = int((time.perf_counter() - t0) * 1000)
    model_type = getattr(cfg, "model_type", None)
    info["model_type"] = model_type
    info["modality"] = detect_modality(model_type)

    # 2. Graph build: the code under test, so also print the full traceback.
    t0 = time.perf_counter()
    try:
        graph = build_static_graph(cfg, head_kind=case.head_kind)
    except Exception as e:
        traceback.print_exc()
        return False, f"graph build: {type(e).__name__}: {e}", info
    info["graph_ms"] = int((time.perf_counter() - t0) * 1000)

    nodes = graph["nodes"]
    info["nodes"] = len(nodes)
    info["edges"] = len(graph["edges"])
    info["max_depth"] = max_depth(graph)
    info["has_head"] = has_head(graph)
    info["has_loop"] = has_loop(graph)
    info["has_moe"] = has_moe(graph)
    info["arch"] = graph.get("arch")

    # 3. Structural invariants. case.max_depth is a *minimum* required depth.
    if info["nodes"] < case.min_nodes:
        failures.append(f"too few nodes: {info['nodes']} < {case.min_nodes}")
    if info["max_depth"] < case.max_depth:
        failures.append(f"shallow graph: max_depth={info['max_depth']} < {case.max_depth}")
    if case.expect_head and not info["has_head"]:
        failures.append("no Head-kind node")
    if case.expect_loop and not info["has_loop"]:
        failures.append("no Loop container")
    if case.expect_moe and not info["has_moe"]:
        failures.append("no MoE container")

    # 4. FLOPs sanity: the "<root>" node must exist and carry a non-zero estimate.
    #    A missing root used to skip this check silently.
    root = next((n for n in nodes if n.get("id") == "<root>"), None)
    if root is None:
        failures.append("no <root> node")
    else:
        # Normalise a missing/None estimate to 0 so the report can always format it.
        flops = root.get("flops_per_token") or 0
        info["root_flops_per_token"] = flops
        if flops == 0:
            failures.append("root flops_per_token == 0")

    # 5. Source extraction: fetch the class source of the top-level architecture.
    arch = info["arch"]
    if arch:
        try:
            src = class_source(model_type, arch)
        except Exception as e:
            failures.append(f"source fetch: {type(e).__name__}: {e}")
        else:
            info["source_chars"] = len(src) if src else 0
            if not src:
                failures.append(f"could not fetch source for {arch}")

    # 6. Input/output descriptors must be computable for the detected modality.
    try:
        ii = input_info(info["modality"], cfg)
        oi = output_info(info["modality"], case.head_kind or "encoder", cfg)
        info["input_kind"] = ii.get("kind")
        info["output_kind"] = oi.get("kind")
    except Exception as e:
        failures.append(f"io info: {type(e).__name__}: {e}")

    if failures:
        return False, "; ".join(failures), info
    return True, "ok", info
def main() -> int:
    """Run every case in ``CASES``, print one status row per case, then a summary.

    Returns:
        Process exit code: 0 when every case passed, 1 otherwise.
    """
    print(f"running {len(CASES)} architecture tests\n")
    pass_count = 0
    for case in CASES:
        try:
            ok, msg, info = run_one(case)
        except Exception as e:
            # run_one guards only its config-load, graph-build, source and io
            # steps; anything else (e.g. an unexpected graph shape) would
            # otherwise abort the whole sweep. Record it and keep going.
            traceback.print_exc()
            ok, msg, info = False, f"unexpected: {type(e).__name__}: {e}", {"id": case.model_id}
        if ok:
            pass_count += 1
        marker = "✓" if ok else "✗"
        # A missing or None estimate renders as 0.00 instead of crashing the row.
        flops = info.get("root_flops_per_token") or 0
        print(
            f"{marker} {case.model_id:50s} "
            f"nodes={info.get('nodes', '-'):>4} d={info.get('max_depth', '-')} "
            f"head={info.get('has_head', '-')!s:5} loop={info.get('has_loop', '-')!s:5} "
            f"moe={info.get('has_moe', '-')!s:5} "
            f"flops={flops / 1e9:.2f}G/tok "
            f"{'' if ok else '— ' + msg}"
        )
        if case.note and not ok:
            print(f" ↳ note: {case.note}")
    print(f"\n{pass_count}/{len(CASES)} passed")
    return 0 if pass_count == len(CASES) else 1
if __name__ == "__main__":
    # main() returns 0/1; raising SystemExit with it sets the process exit status.
    raise SystemExit(main())