"""Structural-health backend for AI Model X-Ray.

Extracts attention-head connectivity graphs from any HuggingFace transformer
and computes the spectral simplicial hierarchy per layer:

    lambda2(T(G)) <= lambda2(G)

(and the chain lambda2(T3) <= lambda2(T(G)) <= lambda2(G)).

The graph / hierarchy / fragility math is ported verbatim from the validated
Modal experiments (`modal_jepa.py`, `modal_cross_arch.py`): lambda2
(algebraic connectivity), triangle_graph T(G), the fragility index FI
(fraction of edges in zero triangles) and the per-layer `hierarchy` record.
Only the surrounding plumbing — model-type detection, per-layer regime
classification, the precomputed reference cache and the `scan_model` entry
point — is new here.

Each layer is classified into a pruning regime from its coherence ratio
rho = lambda2(T(G)) / lambda2(G) and its FI:

    regime     condition               meaning               advice
    ---------  ----------------------  --------------------  ------------------
    immune     rho > 0.8 and FI == 0   triangle-redundant    safe to prune
    buffer     0.5 <= rho <= 0.8       partial redundancy    prune with caution
    critical   rho < 0.5               low redundancy        do not prune

Layers where T(G) is disconnected have rho undefined and are treated as
buffer; so is a layer with rho > 0.8 but FI > 0.
"""
from __future__ import annotations

import json
import os
import statistics

# --------------------------------------------------------------------------- #
# Constants
# --------------------------------------------------------------------------- #

# Pearson-r cutoff for drawing a head-head edge (matches the reference runs).
THRESHOLD = 0.3

# Numerical tolerance used when comparing eigenvalues and FI against zero.
TOL = 1e-9

# Structural-integrity gauge: score = 100 - FI * SCALING_FACTOR.
SCALING_FACTOR = 500.0

# The precomputed reference cache sits next to this module, under data/.
HERE = os.path.dirname(os.path.abspath(__file__))
CACHE_PATH = os.path.join(HERE, "data", "cross_architecture_spectral.json")

# 16 diverse English sentences for text models (verbatim from modal_cross_arch).
SENTENCES = [
    "The quick brown fox jumps over the lazy dog near the river.",
    "Scientists discovered a new species of butterfly in the rainforest.",
    "She carefully placed the fragile vase on the wooden shelf.",
    "Global markets rallied after the central bank cut interest rates.",
    "He learned to play the violin when he was only six years old.",
    "The ancient castle stood silently on the hill above the village.",
    "Machine learning models require large amounts of training data.",
    "A gentle breeze carried the scent of blossoms across the garden.",
    "The committee voted unanimously to approve the new budget proposal.",
    "Children laughed and played in the park on a sunny afternoon.",
    "The spacecraft entered orbit after a seven month journey to Mars.",
    "Fresh bread and strong coffee filled the small bakery with warmth.",
    "Engineers tested the bridge under heavy load before opening it.",
    "The novel explores themes of memory, loss, and the passage of time.",
    "Volunteers planted hundreds of trees along the eroding coastline.",
    "Quantum computers may one day solve problems classical machines cannot.",
]

# Pre-loaded models surfaced in the dropdown: id -> display metadata.
# `cache_key` names the block in the reference cache; None means the model has
# no precomputed block and must be scanned live.
PRELOADED = {
    "bert-base-uncased": {"label": "BERT", "type": "Encoder", "params": "110M", "cache_key": "BERT"},
    "gpt2": {"label": "GPT-2", "type": "Decoder", "params": "124M", "cache_key": "GPT-2"},
    "google/vit-base-patch16-224": {"label": "ViT", "type": "Vision", "params": "86M", "cache_key": "ViT"},
    "distilbert-base-uncased": {"label": "DistilBERT", "type": "Encoder", "params": "66M", "cache_key": None},
}


# ========================================================================= #
# Graph / hierarchy helpers (ported verbatim from modal_jepa.py)
# ========================================================================= #

def lambda2(G):
    """Second-smallest Laplacian eigenvalue (algebraic connectivity) of G.

    Returns (val, connected, n). `val` is None when G has fewer than two
    nodes; otherwise it is 0 (up to floating-point error) exactly when G is
    disconnected. Graphs with at most 1800 nodes use a dense symmetric
    eigensolver; larger ones use sparse shift-invert `eigsh`, falling back to
    `which="SM"` if shift-invert raises.
    """
    import networkx as nx
    import numpy as np

    n = G.number_of_nodes()
    if n < 2:
        return None, False, n
    connected = nx.is_connected(G)
    if n <= 1800:
        L = nx.laplacian_matrix(G).toarray().astype(float)
        ev = np.linalg.eigvalsh(L)
        return float(ev[1]), connected, n
    from scipy.sparse.linalg import eigsh

    Ls = nx.laplacian_matrix(G).astype(float)
    try:
        ev = eigsh(Ls, k=2, sigma=1e-8, which="LM", return_eigenvectors=False)
    except Exception:
        ev = eigsh(Ls, k=2, which="SM", return_eigenvectors=False)
    return float(np.sort(ev)[1]), connected, n


def _triangles(G, adj):
    """Yield every triangle of G exactly once, as a sorted node 3-tuple.

    `adj` maps each node of G to the set of its neighbours.
    """
    seen = set()
    for v in G.nodes():
        nbrs = sorted(adj[v])
        for x in range(len(nbrs)):
            for y in range(x + 1, len(nbrs)):
                a, b = nbrs[x], nbrs[y]
                if b in adj[a]:
                    tri = tuple(sorted((v, a, b)))
                    if tri not in seen:
                        seen.add(tri)
                        yield tri


def triangle_graph(G):
    """T(G): nodes = edges of G; two edges adjacent iff they lie in a common
    triangle. Built by enumerating triangles.

    Node i of T(G) is the i-th edge of `G.edges()` (endpoints sorted).
    """
    import networkx as nx

    edges = [tuple(sorted(e)) for e in G.edges()]
    idx = {e: i for i, e in enumerate(edges)}
    T = nx.Graph()
    T.add_nodes_from(range(len(edges)))
    adj = {v: set(G.neighbors(v)) for v in G.nodes()}
    for a, b, c in _triangles(G, adj):
        e_ab = tuple(sorted((a, b)))
        e_ac = tuple(sorted((a, c)))
        e_bc = tuple(sorted((b, c)))
        i, j, k = idx[e_ab], idx[e_ac], idx[e_bc]
        T.add_edge(i, j)
        T.add_edge(i, k)
        T.add_edge(j, k)
    return T


def compute_fi_from_adj(A):
    """SAL fragility index: fraction of edges participating in zero triangles.

    A: n x n binary symmetric adjacency (no self-loops); only the upper
    triangle is read and any nonzero entry counts as an edge. High FI =
    fragile. Returns 0.0 when n < 2 and 1.0 when the graph has no edges.
    """
    import numpy as np

    A = np.asarray(A)
    n = A.shape[0]
    if n < 2:
        return 0.0
    # Count common neighbours in float64, not in A's own dtype: `np.dot` on
    # the int8 matrix from `graph_from_corr` returns an int8 result, so a
    # count of 256, 512, ... wraps to 0 and the edge would be misreported as
    # triangle-free. float64 matmul is BLAS-backed and exact for these small
    # integer counts.
    B = (A != 0).astype(np.float64)
    iu, ju = np.triu_indices(n, k=1)
    is_edge = B[iu, ju] != 0
    total = int(np.count_nonzero(is_edge))
    if total == 0:
        return 1.0
    common = (B @ B)[iu, ju]  # number of common neighbours of (i, j)
    fragile = int(np.count_nonzero(is_edge & (common == 0)))
    return fragile / total


def graph_from_corr(C, threshold):
    """Build a simple graph from a correlation matrix: edge iff C_ij > threshold.

    Returns (G, A) where A is the int8 adjacency with a zeroed diagonal.
    """
    import networkx as nx
    import numpy as np

    A = (C > threshold).astype(np.int8)
    np.fill_diagonal(A, 0)
    G = nx.from_numpy_array(A)
    return G, A


def _density(G):
    """Edge density of an undirected simple graph (0.0 for < 2 nodes)."""
    n = G.number_of_nodes()
    if n < 2:
        return 0.0
    return 2.0 * G.number_of_edges() / (n * (n - 1))


def _count_edges_triangles_capped(G, cap):
    """Count triangles of G, stopping early once the count exceeds `cap`."""
    adj = {v: set(G.neighbors(v)) for v in G.nodes()}
    c = 0
    for _ in _triangles(G, adj):
        c += 1
        if c > cap:
            return c
    return c


def hierarchy(C, threshold, edge_cap=400, giant=False):
    """Per-layer spectral-hierarchy record for a correlation matrix C.

    rho / violation are defined ONLY where both G and T(G) are connected
    (`eligible`). If the graph is too dense (more than `edge_cap` edges or
    more than 4000 triangles) the threshold is raised in steps of 0.02, up to
    0.95, so T(G) stays tractable; the threshold actually used is reported.
    `giant=True` evaluates the hierarchy on G's largest connected component.
    FI is always computed on the full thresholded graph.
    """
    import networkx as nx

    t = threshold
    bumped = False
    tri_cap = 4000
    while True:
        G, A = graph_from_corr(C, t)
        ntri = _count_edges_triangles_capped(G, tri_cap)
        if (G.number_of_edges() > edge_cap or ntri > tri_cap) and t < 0.95:
            t = round(t + 0.02, 4)
            bumped = True
            continue
        break
    fi = compute_fi_from_adj(A)
    n_comp = nx.number_connected_components(G) if G.number_of_nodes() else 0
    orig_nodes, orig_edges = G.number_of_nodes(), G.number_of_edges()
    if giant and orig_nodes >= 1 and not nx.is_connected(G):
        G = G.subgraph(max(nx.connected_components(G), key=len)).copy()
    l2G, gconn, _ = lambda2(G)
    TG = triangle_graph(G)
    tconn = TG.number_of_nodes() >= 2 and nx.is_connected(TG)
    l2TG, _, _ = lambda2(TG)
    eligible = bool(gconn and tconn)
    return {
        "threshold": t,
        "threshold_bumped": bumped,
        "giant": giant,
        "orig_nodes": int(orig_nodes),
        "orig_edges": int(orig_edges),
        "n_components": int(n_comp),
        "nodes": int(G.number_of_nodes()),
        "edges": int(G.number_of_edges()),
        "density": _density(G),
        "G_connected": bool(gconn),
        "TG_nodes": int(TG.number_of_nodes()),
        "TG_edges": int(TG.number_of_edges()),
        "TG_connected": bool(tconn),
        "eligible": eligible,
        "l2G": l2G,
        "l2TG": l2TG,
        "rho": (l2TG / l2G) if (eligible and l2G and l2G > TOL) else None,
        "fi": fi,
        "violation": bool(eligible and l2TG is not None and l2G is not None
                          and l2TG - l2G > TOL),
    }


# ========================================================================= #
# Regime classification
# ========================================================================= #

def classify_regime(rec):
    """Map a per-layer hierarchy record to a pruning regime.

    Returns one of 'immune' | 'buffer' | 'critical'. Layers whose triangle
    graph is disconnected (rho undefined) are treated as 'buffer' — partial
    structure, prune with caution. A missing/None FI counts as 0.
    """
    rho = rec.get("rho")
    fi = rec.get("fi") or 0.0
    if rho is None:
        return "buffer"
    if rho > 0.8 and fi <= TOL:
        return "immune"
    if rho < 0.5:
        return "critical"
    return "buffer"


REGIME_META = {
    "immune": {"label": "Immune", "color": "#10B981", "advice": "safe to prune"},
    "buffer": {"label": "Buffer", "color": "#EF9F27", "advice": "prune with caution"},
    "critical": {"label": "Critical", "color": "#E24B4A", "advice": "do not prune"},
}


def summarize(result):
    """Collapse a scan result (per_layer + meta) into dashboard-ready aggregates.

    The headline fragility is `result["fi_base"]` (global-graph FI) when
    present, else the mean per-layer FI; the gauge score is
    100 - fi_head * SCALING_FACTOR clamped to [0, 100].
    """
    pl = result["per_layer"]
    eligible = [r for r in pl if r["eligible"]]
    rhos = [r["rho"] for r in eligible if r["rho"] is not None]
    fis = [r["fi"] for r in pl if r["fi"] is not None]
    regimes = [classify_regime(r) for r in pl]
    rho_mean = statistics.mean(rhos) if rhos else None
    # Headline fragility = full-graph (global) base FI when available, else mean.
    fi_head = result.get("fi_base")
    if fi_head is None:
        fi_head = statistics.mean(fis) if fis else 0.0
    score = max(0.0, min(100.0, 100.0 - fi_head * SCALING_FACTOR))
    return {
        "n_layers": result["n_layers"],
        "n_heads": result["n_heads"],
        "n_immune": regimes.count("immune"),
        "n_buffer": regimes.count("buffer"),
        "n_critical": regimes.count("critical"),
        "n_eligible": len(eligible),
        "violations": sum(1 for r in eligible if r["violation"]),
        "rho_mean": rho_mean,
        "rho_min": min(rhos) if rhos else None,
        "rho_max": max(rhos) if rhos else None,
        "fi_head": fi_head,
        "score": score,
        "regimes": regimes,
    }


def score_band(score):
    """(label, color) for the structural-integrity gauge value."""
    if score >= 92:
        return "Over-redundant", "#10B981"
    if score >= 80:
        return "Healthy", "#10B981"
    if score >= 50:
        return "Buffer zone", "#EF9F27"
    return "Critical", "#E24B4A"


# ========================================================================= #
# Precomputed reference cache
# ========================================================================= #

_CACHE = None


def _load_cache():
    """Load and memoise the reference JSON at CACHE_PATH (read once per process)."""
    global _CACHE
    if _CACHE is None:
        with open(CACHE_PATH, encoding="utf-8") as f:
            _CACHE = json.load(f)
    return _CACHE


def reference_rows():
    """The architecture-comparison table rows (I-JEPA, BERT, GPT-2, ViT)."""
    cache = _load_cache()
    return list(cache["comparison"])


def cached_result(model_id):
    """Return a scan-shaped result for a pre-loaded model from the cache, or
    None if the model has no precomputed block (e.g. DistilBERT -> needs a
    live scan).

    A block without a `masking.base_fi` entry yields `fi_base=None`, which
    makes `summarize` fall back to the mean per-layer FI.
    """
    meta = PRELOADED.get(model_id)
    if not meta or not meta.get("cache_key"):
        return None
    cache = _load_cache()
    block = cache["models"].get(meta["cache_key"])
    if block is None:
        return None
    return {
        "model_id": model_id,
        "label": meta["label"],
        "arch": meta["type"],
        "modality": block["modality"],
        "n_layers": block["n_layers"],
        "n_heads": block["n_heads"],
        "hidden": block["hidden"],
        "params": meta["params"],
        "per_layer": block["per_layer"],
        "fi_base": (block.get("masking") or {}).get("base_fi"),
        "source": "cached",
    }


# ========================================================================= #
# Live attention extraction + scan
# ========================================================================= #

def _detect_modality(config):
    """Classify a loaded model config as 'vision' | 'decoder' | 'encoder'."""
    arch = " ".join(getattr(config, "architectures", None) or []).lower()
    mt = (getattr(config, "model_type", "") or "").lower()
    if "vit" in mt or "vit" in arch or "image" in mt or getattr(config, "image_size", None):
        return "vision"
    if (getattr(config, "is_decoder", False) or "gpt" in mt or "llama" in mt
            or "causal" in arch or "lmhead" in arch):
        return "decoder"
    return "encoder"


def _arch_label(modality):
    """Display label for a modality string returned by `_detect_modality`."""
    return {"vision": "Vision", "decoder": "Decoder", "encoder": "Encoder"}[modality]


def _corr(M):
    """Row-wise Pearson correlation of M, with NaNs (zero-variance rows) set to 0."""
    import numpy as np

    C = np.corrcoef(M)
    return np.nan_to_num(C, nan=0.0)


def _mean_attentions(model, model_id, batches, n_layers, n_heads):
    """Average attention maps over a stream of probe batches.

    `batches` yields (inputs, batch_size) pairs, where `inputs` is a mapping
    of model keyword arguments already on the model's device. Returns
    (acc, seq_len) with `acc` of shape (n_layers, n_heads, seq_len, seq_len):
    each head's attention map averaged over all probe inputs.
    Raises RuntimeError if the model returns no attentions or there were no
    probe inputs.
    """
    import numpy as np
    import torch

    acc = None
    seq_len = None
    count = 0
    for inp, size in batches:
        with torch.no_grad():
            out = model(**inp, output_attentions=True)
        atts = out.attentions
        if atts is None or atts[0] is None:
            raise RuntimeError(f"{model_id} returned no attentions")
        if acc is None:
            seq_len = atts[0].shape[-1]
            acc = np.zeros((n_layers, n_heads, seq_len, seq_len), dtype=np.float64)
        for L in range(n_layers):
            # Sum over the batch axis here; divide by the total count below.
            acc[L] += atts[L].sum(dim=0).float().cpu().numpy()
        count += size
    if count == 0:
        raise RuntimeError(f"{model_id}: no probe inputs to run")
    acc /= count
    return acc, seq_len


def _extract_text(model_id, dev, max_length=32):
    """Mean attention maps of a text model over SENTENCES (batches of 4).

    Every sentence is padded/truncated to `max_length` tokens. Returns
    (acc, n_layers, n_heads, hidden_size, seq_len).
    """
    import torch
    from transformers import AutoModel, AutoTokenizer

    tok = AutoTokenizer.from_pretrained(model_id)
    if tok.pad_token is None:
        # Some tokenizers define no pad token; reuse EOS so fixed-length
        # padding works.
        tok.pad_token = tok.eos_token
    model = AutoModel.from_pretrained(
        model_id, attn_implementation="eager", torch_dtype=torch.float32)
    model.to(dev).eval()
    cfg = model.config
    n_layers, n_heads = cfg.num_hidden_layers, cfg.num_attention_heads

    def batches():
        for s in range(0, len(SENTENCES), 4):
            batch = SENTENCES[s:s + 4]
            inp = tok(batch, return_tensors="pt", padding="max_length",
                      truncation=True, max_length=max_length).to(dev)
            yield inp, len(batch)

    acc, seq_len = _mean_attentions(model, model_id, batches(), n_layers, n_heads)
    return acc, n_layers, n_heads, int(cfg.hidden_size), int(seq_len)


def _extract_vision(model_id, dev, num_samples=16):
    """Mean attention maps of a vision model over `num_samples` probe images
    (batches of 4). Returns (acc, n_layers, n_heads, hidden_size, seq_len).
    """
    import torch
    from transformers import AutoImageProcessor, AutoModel

    proc = AutoImageProcessor.from_pretrained(model_id)
    model = AutoModel.from_pretrained(
        model_id, attn_implementation="eager", torch_dtype=torch.float32)
    model.to(dev).eval()
    cfg = model.config
    n_layers, n_heads = cfg.num_hidden_layers, cfg.num_attention_heads
    images = _load_images(num_samples)

    def batches():
        for s in range(0, len(images), 4):
            batch = images[s:s + 4]
            yield proc(images=batch, return_tensors="pt").to(dev), len(batch)

    acc, seq_len = _mean_attentions(model, model_id, batches(), n_layers, n_heads)
    return acc, n_layers, n_heads, int(cfg.hidden_size), int(seq_len)


def _load_images(n):
    """Return up to n PIL RGB probe images.

    Tries the first n images of the HF `uoft-cs/cifar10` test split. If fewer
    than 4 load, tops up from a fixed list of COCO val2017 URLs, each fetched
    best-effort, stopping as soon as n images are available.
    Raises RuntimeError if no image at all could be loaded.
    """
    from PIL import Image

    imgs = []
    try:
        from datasets import load_dataset

        ds = load_dataset("uoft-cs/cifar10", split=f"test[:{n}]")
        col = "img" if "img" in ds.column_names else ds.column_names[0]
        imgs = [ds[i][col].convert("RGB") for i in range(len(ds))]
    except Exception:
        # Deliberate best-effort: fall through to the URL fallback below.
        pass
    if len(imgs) < 4:
        import io
        import urllib.request

        urls = [
            "http://images.cocodataset.org/val2017/000000039769.jpg",
            "http://images.cocodataset.org/val2017/000000000139.jpg",
            "http://images.cocodataset.org/val2017/000000000285.jpg",
            "http://images.cocodataset.org/val2017/000000000632.jpg",
            "http://images.cocodataset.org/val2017/000000000724.jpg",
            "http://images.cocodataset.org/val2017/000000000776.jpg",
            "http://images.cocodataset.org/val2017/000000000785.jpg",
            "http://images.cocodataset.org/val2017/000000000802.jpg",
        ]
        for u in urls:
            # Only imgs[:n] is returned, so stop downloading once that is full.
            if len(imgs) >= n:
                break
            try:
                with urllib.request.urlopen(u, timeout=20) as r:
                    imgs.append(Image.open(io.BytesIO(r.read())).convert("RGB"))
            except Exception:
                # Skip an unreachable/undecodable URL; the others may succeed.
                pass
    if not imgs:
        raise RuntimeError("could not load any probe images")
    return imgs[:n]


def scan_model(model_id, num_samples=16):
    """Full structural scan of a HuggingFace transformer.

    Loads the model with output_attentions, runs it on probe data appropriate
    to its modality (the 16 SENTENCES for text, `num_samples` images for
    vision; `num_samples` is ignored for text), builds the per-layer
    attention-correlation graph (Pearson r > THRESHOLD) and computes the
    spectral hierarchy + fragility index per layer, plus a global graph over
    all heads whose FI becomes `fi_base`. Returns a scan-shaped dict
    consumable by `summarize` and the UI renderers.
    """
    import torch
    from transformers import AutoConfig

    dev = "cuda" if torch.cuda.is_available() else "cpu"
    cfg = AutoConfig.from_pretrained(model_id)
    modality = _detect_modality(cfg)
    if modality == "vision":
        acc, n_layers, n_heads, hidden, _seq_len = _extract_vision(model_id, dev, num_samples)
    else:
        acc, n_layers, n_heads, hidden, _seq_len = _extract_text(model_id, dev)

    per_layer = []
    for L in range(n_layers):
        # One row per head: its flattened mean attention map in this layer.
        C = _corr(acc[L].reshape(n_heads, -1))
        rec = hierarchy(C, THRESHOLD, edge_cap=400)
        rec["layer"] = L
        per_layer.append(rec)

    # Global graph base FI over all heads (for the headline fragility number).
    global_sigs = acc.reshape(n_layers * n_heads, -1)
    grec = hierarchy(_corr(global_sigs), THRESHOLD, edge_cap=10 ** 9, giant=True)

    meta = PRELOADED.get(model_id, {})
    label = meta.get("label") or model_id.split("/")[-1]
    params = meta.get("params")
    return {
        "model_id": model_id,
        "label": label,
        "arch": _arch_label(modality),
        "modality": modality,
        "n_layers": n_layers,
        "n_heads": n_heads,
        "hidden": hidden,
        "params": params,
        "per_layer": per_layer,
        "fi_base": grec["fi"],
        "source": "live",
    }


def get_result(model_id, force_rescan=False):
    """Cache-first scan: return the precomputed block for a pre-loaded model,
    or run a live scan (custom models, DistilBERT, or an explicit re-scan)."""
    if not force_rescan:
        cached = cached_result(model_id)
        if cached is not None:
            return cached
    return scan_model(model_id)