"""End-to-end test: load configs for a representative set of architectures, build the static graph, and validate invariants. Prints a compact report.""" from __future__ import annotations import os import sys import time import traceback from dataclasses import dataclass sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from backend.model_loader import _load_config_robust, _hf_token # noqa: E402 from backend.static_graph import build_static_graph, class_source # noqa: E402 from backend.modality import detect_modality, input_info, output_info # noqa: E402 @dataclass class TestCase: model_id: str expect_head: bool = True # has at least one Head-kind node expect_loop: bool = True # has at least one Loop container expect_moe: bool = False # has at least one MoE container head_kind: str | None = None # expected head_kind min_nodes: int = 5 max_depth: int = 4 # min(max_depth) — must reach this depth at least note: str = "" # what makes this test interesting CASES = [ # encoder-only with MLM heads TestCase("prajjwal1/bert-tiny", head_kind="encoder", expect_head=False, min_nodes=10, note="config has no architectures field — falls back to BertModel (no head)"), TestCase("bert-base-uncased", head_kind="masked_lm", min_nodes=20, note="canonical encoder-only with MLM head (self.cls)"), TestCase("distilbert/distilbert-base-uncased", head_kind="masked_lm", min_nodes=20, note="DistilBert head = vocab_projector"), TestCase("FacebookAI/roberta-base", head_kind="masked_lm", min_nodes=20), TestCase("microsoft/deberta-v3-base", head_kind="encoder", expect_head=False, min_nodes=15, note="model_type 'deberta-v2' (hyphen); DebertaV2Model has no head"), # decoder-only causal LM (dense) TestCase("openai-community/gpt2", head_kind="causal_lm", min_nodes=15, note="n_layer alias for num_hidden_layers"), TestCase("meta-llama/Llama-3.2-1B", head_kind="causal_lm", min_nodes=15, note="getattr(config, 'head_dim', default) idiom"), TestCase("Qwen/Qwen2.5-0.5B", head_kind="causal_lm", min_nodes=15), TestCase("microsoft/phi-2", head_kind="causal_lm", min_nodes=15), TestCase("google/gemma-2-2b", head_kind="causal_lm", min_nodes=15), TestCase("mistralai/Mistral-7B-v0.1", head_kind="causal_lm", min_nodes=15), TestCase("stabilityai/stablelm-2-1_6b", head_kind="causal_lm", min_nodes=15), # encoder-decoder TestCase("google-t5/t5-small", head_kind="encoder", min_nodes=15, note="encoder+decoder, ModuleList().append idiom"), TestCase("facebook/bart-base", head_kind="encoder", expect_head=False, min_nodes=20, note="BartModel (encoder-decoder, no head)"), TestCase("openai/whisper-tiny", head_kind="causal_lm", min_nodes=15, note="audio encoder-decoder, proj_out as head"), # vision TestCase("google/vit-base-patch16-224", head_kind="image_classification", min_nodes=15, note="ViT classifier via IfExp"), TestCase("openai/clip-vit-base-patch32", head_kind="encoder", expect_head=False, min_nodes=20, note="CLIPModel — no classifier head, just text+vision towers"), ] def has_kind(graph, kind: str) -> bool: return any(n.get("kind") == kind for n in graph["nodes"]) def has_head(graph) -> bool: return has_kind(graph, "Head") def has_loop(graph) -> bool: return has_kind(graph, "Loop") def has_moe(graph) -> bool: return has_kind(graph, "MoE") def max_depth(graph) -> int: return max((n.get("depth", 0) for n in graph["nodes"]), default=0) def run_one(case: TestCase) -> tuple[bool, str, dict]: failures: list[str] = [] info: dict = {"id": case.model_id} t0 = time.time() try: cfg = _load_config_robust(case.model_id) except Exception as e: return False, f"config load: {type(e).__name__}: {e}", info info["config_ms"] = int((time.time() - t0) * 1000) info["model_type"] = getattr(cfg, "model_type", None) info["modality"] = detect_modality(getattr(cfg, "model_type", None)) t0 = time.time() try: graph = build_static_graph(cfg, head_kind=case.head_kind) except Exception as e: traceback.print_exc() return False, f"graph build: {type(e).__name__}: {e}", info info["graph_ms"] = int((time.time() - t0) * 1000) info["nodes"] = len(graph["nodes"]) info["edges"] = len(graph["edges"]) info["max_depth"] = max_depth(graph) info["has_head"] = has_head(graph) info["has_loop"] = has_loop(graph) info["has_moe"] = has_moe(graph) info["arch"] = graph.get("arch") # Invariants if info["nodes"] < case.min_nodes: failures.append(f"too few nodes: {info['nodes']} < {case.min_nodes}") if info["max_depth"] < case.max_depth: failures.append(f"shallow graph: max_depth={info['max_depth']} < {case.max_depth}") if case.expect_head and not info["has_head"]: failures.append("no Head-kind node") if case.expect_loop and not info["has_loop"]: failures.append("no Loop container") if case.expect_moe and not info["has_moe"]: failures.append("no MoE container") # FLOPs sanity — root should have non-zero estimate root = next((n for n in graph["nodes"] if n["id"] == ""), None) if root is not None: info["root_flops_per_token"] = root.get("flops_per_token", 0) if info["root_flops_per_token"] == 0: failures.append("root flops_per_token == 0") # Source extraction — pick a deepest module and try fetching its class source arch = graph.get("arch") if arch: try: src = class_source(getattr(cfg, "model_type", None), arch) info["source_chars"] = len(src) if src else 0 if not src: failures.append(f"could not fetch source for {arch}") except Exception as e: failures.append(f"source fetch: {e}") # Input/Output info sanity try: ii = input_info(info["modality"], cfg) oi = output_info(info["modality"], case.head_kind or "encoder", cfg) info["input_kind"] = ii.get("kind") info["output_kind"] = oi.get("kind") except Exception as e: failures.append(f"io info: {e}") if failures: return False, "; ".join(failures), info return True, "ok", info def main(): print(f"running {len(CASES)} architecture tests\n") results = [] pass_count = 0 for case in CASES: ok, msg, info = run_one(case) results.append((ok, case, msg, info)) if ok: pass_count += 1 marker = "✓" if ok else "✗" print( f"{marker} {case.model_id:50s} " f"nodes={info.get('nodes', '-'):>4} d={info.get('max_depth', '-')} " f"head={info.get('has_head', '-')!s:5} loop={info.get('has_loop', '-')!s:5} " f"moe={info.get('has_moe', '-')!s:5} " f"flops={info.get('root_flops_per_token', 0)/1e9:.2f}G/tok " f"{'' if ok else '— ' + msg}" ) if case.note and not ok: print(f" ↳ note: {case.note}") print(f"\n{pass_count}/{len(CASES)} passed") return 0 if pass_count == len(CASES) else 1 if __name__ == "__main__": sys.exit(main())