Spaces:
Sleeping
Sleeping
| """Structural-health backend for AI Model X-Ray. | |
| Extracts attention-head connectivity graphs from any HuggingFace transformer and | |
| computes the spectral simplicial hierarchy per layer: | |
| lambda2(T(G)) <= lambda2(G) (and the chain lambda2(T3) <= lambda2(T(G)) <= lambda2(G)) | |
| The graph / hierarchy / fragility math is ported verbatim from the validated | |
| Modal experiments (`modal_jepa.py`, `modal_cross_arch.py`): lambda2 (algebraic | |
| connectivity), triangle_graph T(G), the fragility index FI (fraction of edges in | |
| zero triangles) and the per-layer `hierarchy` record. Only the surrounding | |
| plumbing — model-type detection, per-layer regime classification, the precomputed | |
| reference cache and the `scan_model` entry point — is new here. | |
| Each layer is classified into a pruning regime from its coherence ratio | |
| rho = lambda2(T(G)) / lambda2(G) and its FI: | |
| immune rho > 0.8 and FI == 0 triangle-redundant -> safe to prune | |
| buffer 0.5 <= rho <= 0.8 partial redundancy -> prune with caution | |
| critical rho < 0.5 low redundancy -> do not prune | |
(layers where G or T(G) is disconnected have rho undefined -> treated as buffer)
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import statistics | |
# --------------------------------------------------------------------------- #
# Constants
# --------------------------------------------------------------------------- #
# Numerical tolerance shared by the spectral checks: lambda2(G) must exceed it
# for rho to be defined, a hierarchy violation needs lambda2(T(G)) - lambda2(G)
# > TOL, and a layer can only be "immune" with FI <= TOL.
TOL = 1e-9
THRESHOLD = 0.3 # Pearson r edge threshold (matches the reference runs)
SCALING_FACTOR = 500.0 # structural-integrity gauge: score = 100 - FI * factor
# Directory containing this module; the precomputed reference results are
# read from data/cross_architecture_spectral.json beneath it.
HERE = os.path.dirname(os.path.abspath(__file__))
CACHE_PATH = os.path.join(HERE, "data", "cross_architecture_spectral.json")
# 16 diverse English sentences for text models (verbatim from modal_cross_arch).
# This is the full probe set for every live text-model scan.
SENTENCES = [
    "The quick brown fox jumps over the lazy dog near the river.",
    "Scientists discovered a new species of butterfly in the rainforest.",
    "She carefully placed the fragile vase on the wooden shelf.",
    "Global markets rallied after the central bank cut interest rates.",
    "He learned to play the violin when he was only six years old.",
    "The ancient castle stood silently on the hill above the village.",
    "Machine learning models require large amounts of training data.",
    "A gentle breeze carried the scent of blossoms across the garden.",
    "The committee voted unanimously to approve the new budget proposal.",
    "Children laughed and played in the park on a sunny afternoon.",
    "The spacecraft entered orbit after a seven month journey to Mars.",
    "Fresh bread and strong coffee filled the small bakery with warmth.",
    "Engineers tested the bridge under heavy load before opening it.",
    "The novel explores themes of memory, loss, and the passage of time.",
    "Volunteers planted hundreds of trees along the eroding coastline.",
    "Quantum computers may one day solve problems classical machines cannot.",
]
# Pre-loaded models surfaced in the dropdown: id -> display metadata.
# "cache_key" names the model's block in the cache file's "models" map;
# None means there is no precomputed block, so the model needs a live scan.
PRELOADED = {
    "bert-base-uncased": {"label": "BERT", "type": "Encoder", "params": "110M", "cache_key": "BERT"},
    "gpt2": {"label": "GPT-2", "type": "Decoder", "params": "124M", "cache_key": "GPT-2"},
    "google/vit-base-patch16-224": {"label": "ViT", "type": "Vision", "params": "86M", "cache_key": "ViT"},
    "distilbert-base-uncased": {"label": "DistilBERT","type": "Encoder", "params": "66M", "cache_key": None},
}
# ========================================================================= #
# Graph / hierarchy helpers (math ported from modal_jepa.py)
# ========================================================================= #
def lambda2(G):
    """Algebraic connectivity of G: its second-smallest Laplacian eigenvalue.

    Returns a ``(value, connected, n)`` triple, where ``n`` is the node count
    and ``connected`` comes from ``nx.is_connected``. ``value`` is None when G
    has fewer than two nodes. Otherwise it is computed with a dense symmetric
    eigensolver for n <= 1800 and with sparse ``eigsh`` above that.
    Mathematically the value is 0 exactly when G is disconnected; numerically
    expect something close to 0."""
    import networkx as nx
    import numpy as np

    n = G.number_of_nodes()
    if n < 2:
        return None, False, n
    connected = nx.is_connected(G)

    if n > 1800:
        # Large graph: sparse solver. Try shift-invert just above 0 first and
        # fall back to plain smallest-magnitude iteration if that fails.
        from scipy.sparse.linalg import eigsh
        lap = nx.laplacian_matrix(G).astype(float)
        try:
            vals = eigsh(lap, k=2, sigma=1e-8, which="LM", return_eigenvectors=False)
        except Exception:
            vals = eigsh(lap, k=2, which="SM", return_eigenvectors=False)
        return float(np.sort(vals)[1]), connected, n

    dense_lap = nx.laplacian_matrix(G).toarray().astype(float)
    spectrum = np.linalg.eigvalsh(dense_lap)  # ascending order
    return float(spectrum[1]), connected, n
| def _triangles(G, adj): | |
| seen = set() | |
| for v in G.nodes(): | |
| nbrs = sorted(adj[v]) | |
| for x in range(len(nbrs)): | |
| for y in range(x + 1, len(nbrs)): | |
| a, b = nbrs[x], nbrs[y] | |
| if b in adj[a]: | |
| tri = tuple(sorted((v, a, b))) | |
| if tri not in seen: | |
| seen.add(tri) | |
| yield tri | |
def triangle_graph(G):
    """Build the triangle graph T(G).

    Node ``k`` of T(G) stands for the k-th edge of ``G.edges()``. Two such
    nodes are joined iff their edges are two sides of a common triangle of G.
    Edges of G that lie in no triangle become isolated nodes of T(G)."""
    import networkx as nx

    edge_list = [tuple(sorted(e)) for e in G.edges()]
    position = {edge: k for k, edge in enumerate(edge_list)}

    T = nx.Graph()
    T.add_nodes_from(range(len(edge_list)))
    neighbours = {node: set(G.neighbors(node)) for node in G.nodes()}
    for u, v, w in _triangles(G, neighbours):
        # Look the three sides up with the same sorted-pair key as edge_list.
        side_uv = position[tuple(sorted((u, v)))]
        side_uw = position[tuple(sorted((u, w)))]
        side_vw = position[tuple(sorted((v, w)))]
        T.add_edges_from(((side_uv, side_uw), (side_uv, side_vw), (side_uw, side_vw)))
    return T
def compute_fi_from_adj(A):
    """SAL fragility index: the fraction of edges that lie in zero triangles.

    A: n x n symmetric adjacency (no self-loops); any nonzero entry counts as
    an edge. High FI = fragile. Conventions kept from the reference runs:
    0.0 when n < 2, and 1.0 when the graph has nodes but no edges.

    Edge (i, j) lies in a triangle iff rows i and j share a neighbour, i.e.
    (B @ B.T)[i, j] > 0 with B the 0/1 adjacency. The product is taken in
    int64 on purpose. ``graph_from_corr`` hands in an int8 matrix, and an
    int8 dot product wraps around: 256 common neighbours would wrap to 0 and
    a well-supported edge would be miscounted as fragile."""
    import numpy as np

    A = np.asarray(A)
    n = A.shape[0]
    if n < 2:
        return 0.0
    B = (A != 0).astype(np.int64)
    common = B @ B.T                   # common[i, j] = shared neighbours of i, j
    iu, ju = np.triu_indices(n, k=1)   # every unordered pair once (i < j)
    is_edge = B[iu, ju] != 0
    total = int(is_edge.sum())
    if total == 0:
        return 1.0
    fragile = int((common[iu, ju][is_edge] == 0).sum())
    return fragile / total
def graph_from_corr(C, threshold):
    """Threshold a correlation matrix into a simple undirected graph.

    Nodes i and j are joined iff C[i, j] > threshold (strictly). The diagonal
    is cleared, so there are no self-loops. Returns ``(G, A)``, where A is the
    int8 0/1 adjacency matrix that G was built from."""
    import networkx as nx
    import numpy as np

    adjacency = (C > threshold).astype(np.int8)
    np.fill_diagonal(adjacency, 0)
    graph = nx.from_numpy_array(adjacency)
    return graph, adjacency
| def _density(G): | |
| n = G.number_of_nodes() | |
| if n < 2: | |
| return 0.0 | |
| return 2.0 * G.number_of_edges() / (n * (n - 1)) | |
def _count_edges_triangles_capped(G, cap):
    """Count the triangles of G, giving up as soon as the count exceeds ``cap``.

    Despite the name, only triangles are counted. The result is exact when it
    is <= cap; otherwise it is cap + 1, which is enough for the caller to see
    that the threshold must be raised."""
    neighbour_sets = {node: set(G.neighbors(node)) for node in G.nodes()}
    found = 0
    for _triangle in _triangles(G, neighbour_sets):
        found += 1
        if found > cap:
            break
    return found
def hierarchy(C, threshold, edge_cap=400, giant=False):
    """Per-layer spectral-hierarchy record for a correlation matrix C.

    C is presumably a square correlation matrix such as ``_corr`` returns
    (TODO confirm for other callers). Steps:
      1. Threshold C into G with ``graph_from_corr``. While G has more than
         ``edge_cap`` edges or more than 4000 triangles, and the threshold is
         still below 0.95, raise the threshold by 0.02 and rebuild G, so that
         T(G) stays tractable. The threshold actually used is reported in
         ``threshold`` / ``threshold_bumped``.
      2. FI, the component count and ``orig_nodes`` / ``orig_edges`` are
         measured on that full thresholded graph.
      3. With ``giant=True`` and G disconnected, G is replaced by its largest
         connected component for everything that follows ("nodes", "edges",
         "density", lambda2, T(G)).
      4. lambda2 is computed for G and for its triangle graph T(G).
    rho / violation are defined ONLY where both G and T(G) are connected
    (``eligible``). rho = lambda2(T(G)) / lambda2(G), or None when not
    eligible or when lambda2(G) <= TOL. ``violation`` flags lambda2(T(G))
    exceeding lambda2(G) by more than TOL, i.e. a break of the inequality
    lambda2(T(G)) <= lambda2(G)."""
    import networkx as nx
    t = threshold
    bumped = False
    tri_cap = 4000  # triangle budget per graph before the threshold is raised
    while True:
        G, A = graph_from_corr(C, t)
        ntri = _count_edges_triangles_capped(G, tri_cap)
        # Too many edges or triangles, and still room to tighten: step up by 0.02.
        if (G.number_of_edges() > edge_cap or ntri > tri_cap) and t < 0.95:
            t = round(t + 0.02, 4); bumped = True  # round to 4 dp so repeated steps do not accumulate float error
            continue
        break
    # Measured on the full thresholded graph, before any giant-component cut.
    fi = compute_fi_from_adj(A)
    n_comp = nx.number_connected_components(G) if G.number_of_nodes() else 0
    orig_nodes, orig_edges = G.number_of_nodes(), G.number_of_edges()
    if giant and orig_nodes >= 1 and not nx.is_connected(G):
        G = G.subgraph(max(nx.connected_components(G), key=len)).copy()
    l2G, gconn, gn = lambda2(G)  # gn (node count) is not used below
    TG = triangle_graph(G)
    tconn = TG.number_of_nodes() >= 2 and nx.is_connected(TG)
    l2TG, _, _ = lambda2(TG)
    eligible = bool(gconn and tconn)
    return {
        "threshold": t, "threshold_bumped": bumped, "giant": giant,
        "orig_nodes": int(orig_nodes), "orig_edges": int(orig_edges),
        "n_components": int(n_comp),
        # From here on G may be the giant component (see step 3).
        "nodes": int(G.number_of_nodes()), "edges": int(G.number_of_edges()),
        "density": _density(G),
        "G_connected": bool(gconn),
        "TG_nodes": int(TG.number_of_nodes()), "TG_edges": int(TG.number_of_edges()),
        "TG_connected": bool(tconn),
        "eligible": eligible,
        "l2G": l2G, "l2TG": l2TG,
        "rho": (l2TG / l2G) if (eligible and l2G and l2G > TOL) else None,
        "fi": fi,
        "violation": bool(eligible and l2TG is not None and l2G is not None
                          and l2TG - l2G > TOL),
    }
# ========================================================================= #
# Regime classification
# ========================================================================= #
def classify_regime(rec):
    """Assign a per-layer hierarchy record to a pruning regime.

    Returns 'immune', 'buffer' or 'critical':
      * rho undefined (None), e.g. G or T(G) disconnected -> 'buffer'
        (partial structure, prune with caution)
      * rho < 0.5                                          -> 'critical'
      * rho > 0.8 and FI <= TOL                            -> 'immune'
      * anything else                                      -> 'buffer'
    A missing or None FI counts as 0."""
    rho = rec.get("rho")
    fragility = rec.get("fi") or 0.0
    if rho is None:
        return "buffer"
    if rho < 0.5:
        return "critical"
    if rho > 0.8 and fragility <= TOL:
        return "immune"
    return "buffer"
# Display metadata per pruning regime. Keys match the values returned by
# classify_regime: a human-readable label, a hex colour and one-line advice.
REGIME_META = {
    "immune": {"label": "Immune", "color": "#10B981", "advice": "safe to prune"},
    "buffer": {"label": "Buffer", "color": "#EF9F27", "advice": "prune with caution"},
    "critical": {"label": "Critical", "color": "#E24B4A", "advice": "do not prune"},
}
def summarize(result):
    """Reduce a scan-shaped result to the aggregates the dashboard shows.

    ``result`` must provide 'per_layer', 'n_layers' and 'n_heads', and may
    provide 'fi_base'. The rho statistics and the violation count consider
    eligible layers only; the regime counts cover every layer. The headline
    fragility is the global base FI when present, otherwise the mean per-layer
    FI (0.0 if there is none). The integrity score is
    100 - FI * SCALING_FACTOR, clamped to [0, 100]."""
    layers = result["per_layer"]
    eligible = [rec for rec in layers if rec["eligible"]]
    rhos = [rec["rho"] for rec in eligible if rec["rho"] is not None]
    fis = [rec["fi"] for rec in layers if rec["fi"] is not None]
    regimes = [classify_regime(rec) for rec in layers]

    if rhos:
        rho_mean, rho_min, rho_max = statistics.mean(rhos), min(rhos), max(rhos)
    else:
        rho_mean = rho_min = rho_max = None

    fi_head = result.get("fi_base")
    if fi_head is None:
        fi_head = statistics.mean(fis) if fis else 0.0
    raw_score = 100.0 - fi_head * SCALING_FACTOR
    score = max(0.0, min(100.0, raw_score))

    return {
        "n_layers": result["n_layers"],
        "n_heads": result["n_heads"],
        "n_immune": regimes.count("immune"),
        "n_buffer": regimes.count("buffer"),
        "n_critical": regimes.count("critical"),
        "n_eligible": len(eligible),
        "violations": len([rec for rec in eligible if rec["violation"]]),
        "rho_mean": rho_mean,
        "rho_min": rho_min,
        "rho_max": rho_max,
        "fi_head": fi_head,
        "score": score,
        "regimes": regimes,
    }
def score_band(score):
    """Map a structural-integrity score to its gauge ``(label, colour)`` band.

    Lower bounds are inclusive: >= 92 Over-redundant, >= 80 Healthy,
    >= 50 Buffer zone, anything lower is Critical."""
    bands = (
        (92, "Over-redundant", "#10B981"),
        (80, "Healthy", "#10B981"),
        (50, "Buffer zone", "#EF9F27"),
    )
    for floor, label, colour in bands:
        if score >= floor:
            return label, colour
    return "Critical", "#E24B4A"
# ========================================================================= #
# Precomputed reference cache
# ========================================================================= #
_CACHE = None  # parsed contents of CACHE_PATH, filled on first use


def _load_cache():
    """Return the parsed reference JSON, reading CACHE_PATH only on the first call."""
    global _CACHE
    if _CACHE is not None:
        return _CACHE
    with open(CACHE_PATH, encoding="utf-8") as fh:
        _CACHE = json.load(fh)
    return _CACHE
def reference_rows():
    """Rows of the architecture-comparison table (I-JEPA, BERT, GPT-2, ViT).

    Returns a fresh list, so callers can reorder it without touching the cache."""
    rows = _load_cache()["comparison"]
    return list(rows)
def cached_result(model_id):
    """Build a scan-shaped result for a pre-loaded model from the reference cache.

    Returns None when the model is not in PRELOADED, when it has no
    ``cache_key`` (e.g. DistilBERT, which needs a live scan), or when the
    cache's "models" map has no block for that key. The result has the same
    keys as ``scan_model``'s, with ``source`` set to "cached"."""
    meta = PRELOADED.get(model_id)
    key = meta.get("cache_key") if meta else None
    if not key:
        return None
    entry = _load_cache()["models"].get(key)
    if entry is None:
        return None
    return dict(
        model_id=model_id,
        label=meta["label"],
        arch=meta["type"],
        modality=entry["modality"],
        n_layers=entry["n_layers"],
        n_heads=entry["n_heads"],
        hidden=entry["hidden"],
        params=meta["params"],
        per_layer=entry["per_layer"],
        fi_base=entry["masking"]["base_fi"],
        source="cached",
    )
| # ========================================================================= # | |
| # Live attention extraction + scan | |
| # ========================================================================= # | |
| def _detect_modality(config): | |
| """Classify a loaded model as 'vision' | 'decoder' | 'encoder'.""" | |
| arch = " ".join(getattr(config, "architectures", None) or []).lower() | |
| mt = (getattr(config, "model_type", "") or "").lower() | |
| if "vit" in mt or "vit" in arch or "image" in mt or getattr(config, "image_size", None): | |
| return "vision" | |
| if getattr(config, "is_decoder", False) or "gpt" in mt or "llama" in mt \ | |
| or "causal" in arch or "lmhead" in arch: | |
| return "decoder" | |
| return "encoder" | |
| def _arch_label(modality): | |
| return {"vision": "Vision", "decoder": "Decoder", "encoder": "Encoder"}[modality] | |
| def _corr(M): | |
| import numpy as np | |
| C = np.corrcoef(M) | |
| return np.nan_to_num(C, nan=0.0) | |
def _load_eager_model(model_id, dev):
    """Load ``model_id`` via AutoModel in float32 with eager attention, move it
    to ``dev`` and put it in eval mode.

    Eager attention is requested because fused attention kernels may not
    return the attention weights that ``output_attentions`` asks for."""
    import torch
    from transformers import AutoModel

    model = AutoModel.from_pretrained(
        model_id, attn_implementation="eager", torch_dtype=torch.float32)
    model.to(dev).eval()
    return model


def _accumulate_attentions(model, batches, n_layers, n_heads, model_id):
    """Average the attention maps of ``model`` over all probe samples.

    ``batches`` yields ``(inputs, batch_size)`` pairs, where ``inputs`` is a
    device-placed keyword mapping for ``model``. Returns ``(acc, seq_len)``:
    acc is an (n_layers, n_heads, seq_len, seq_len) float64 array holding the
    per-sample mean attention. Raises RuntimeError if the model returns no
    attentions or if ``batches`` is empty."""
    import numpy as np
    import torch

    acc = None
    seq_len = None
    count = 0
    for inputs, size in batches:
        with torch.no_grad():
            out = model(**inputs, output_attentions=True)
        atts = out.attentions
        if atts is None or atts[0] is None:
            raise RuntimeError(f"{model_id} returned no attentions")
        if acc is None:
            seq_len = atts[0].shape[-1]
            acc = np.zeros((n_layers, n_heads, seq_len, seq_len), dtype=np.float64)
        for L in range(n_layers):
            # Sum over the first (presumably batch) axis; divided by count below.
            acc[L] += atts[L].sum(dim=0).float().cpu().numpy()
        count += size
    if acc is None:
        raise RuntimeError(f"{model_id}: no probe samples to run")
    acc /= count
    return acc, seq_len


def _extract_text(model_id, dev, max_length=32, batch_size=4):
    """Mean attention maps of a text model over the built-in SENTENCES.

    Sentences are tokenised ``batch_size`` at a time and padded/truncated to
    exactly ``max_length`` tokens. Padding positions are part of the averaged
    maps. Returns ``(acc, n_layers, n_heads, hidden_size, seq_len)``, with acc
    shaped (n_layers, n_heads, seq_len, seq_len)."""
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(model_id)
    if tok.pad_token is None:
        # Tokenizers without a pad token (e.g. GPT-2's) reuse EOS for padding.
        tok.pad_token = tok.eos_token
    model = _load_eager_model(model_id, dev)
    cfg = model.config
    n_layers, n_heads = cfg.num_hidden_layers, cfg.num_attention_heads

    def batches():
        for s in range(0, len(SENTENCES), batch_size):
            chunk = SENTENCES[s:s + batch_size]
            inp = tok(chunk, return_tensors="pt", padding="max_length",
                      truncation=True, max_length=max_length).to(dev)
            yield inp, len(chunk)

    acc, seq_len = _accumulate_attentions(model, batches(), n_layers, n_heads, model_id)
    return acc, n_layers, n_heads, int(cfg.hidden_size), int(seq_len)


def _extract_vision(model_id, dev, num_samples=16, batch_size=4):
    """Mean attention maps of a vision model over ``num_samples`` probe images.

    Images come from ``_load_images`` and are preprocessed ``batch_size`` at
    a time by the model's AutoImageProcessor. Returns
    ``(acc, n_layers, n_heads, hidden_size, seq_len)``, with acc shaped
    (n_layers, n_heads, seq_len, seq_len)."""
    from transformers import AutoImageProcessor

    proc = AutoImageProcessor.from_pretrained(model_id)
    model = _load_eager_model(model_id, dev)
    cfg = model.config
    n_layers, n_heads = cfg.num_hidden_layers, cfg.num_attention_heads
    images = _load_images(num_samples)

    def batches():
        for s in range(0, len(images), batch_size):
            chunk = images[s:s + batch_size]
            yield proc(images=chunk, return_tensors="pt").to(dev), len(chunk)

    acc, seq_len = _accumulate_attentions(model, batches(), n_layers, n_heads, model_id)
    return acc, n_layers, n_heads, int(cfg.hidden_size), int(seq_len)
def _load_images(n):
    """Return up to n PIL RGB probe images.

    The primary source is the first n images of the HF ``uoft-cs/cifar10``
    test split. If that yields fewer than 4 images (e.g. ``datasets`` is
    missing or the hub is unreachable), COCO val2017 images are downloaded
    and appended, stopping as soon as n images are available. There are 8
    fallback URLs, so a fallback-only run returns at most 8 images.
    Raises RuntimeError if no image could be obtained at all."""
    import warnings
    from PIL import Image

    imgs = []
    try:
        from datasets import load_dataset
        ds = load_dataset("uoft-cs/cifar10", split=f"test[:{n}]")
        col = "img" if "img" in ds.column_names else ds.column_names[0]
        imgs = [ds[i][col].convert("RGB") for i in range(len(ds))]
    except Exception as exc:
        # Best effort: any failure just routes us to the COCO fallback, but
        # make the reason visible instead of swallowing it silently.
        warnings.warn(f"CIFAR-10 probe images unavailable ({exc!r}); "
                      "falling back to COCO val2017 URLs", RuntimeWarning)
    if len(imgs) < 4:
        import io
        import urllib.request
        urls = [
            "http://images.cocodataset.org/val2017/000000039769.jpg",
            "http://images.cocodataset.org/val2017/000000000139.jpg",
            "http://images.cocodataset.org/val2017/000000000285.jpg",
            "http://images.cocodataset.org/val2017/000000000632.jpg",
            "http://images.cocodataset.org/val2017/000000000724.jpg",
            "http://images.cocodataset.org/val2017/000000000776.jpg",
            "http://images.cocodataset.org/val2017/000000000785.jpg",
            "http://images.cocodataset.org/val2017/000000000802.jpg",
        ]
        needed = max(n, 1)
        for u in urls:
            if len(imgs) >= needed:
                break  # enough images already; skip the remaining downloads
            try:
                with urllib.request.urlopen(u, timeout=20) as r:
                    imgs.append(Image.open(io.BytesIO(r.read())).convert("RGB"))
            except Exception:
                continue  # unreachable or corrupt URL: try the next one
    if not imgs:
        raise RuntimeError("could not load any probe images")
    return imgs[:n]
def scan_model(model_id, num_samples=16):
    """Run a full live structural scan of a HuggingFace transformer.

    Probe data is chosen by the detected modality:
      * text models use the 16 built-in SENTENCES (``num_samples`` is not
        used on this path);
      * vision models use ``num_samples`` probe images.
    For each layer, the Pearson correlations between the heads' mean
    attention maps are thresholded at THRESHOLD, and ``hierarchy`` records
    lambda2(G), lambda2(T(G)), rho and FI. A global graph over every head of
    every layer supplies ``fi_base``, the headline fragility. Returns a
    scan-shaped dict consumable by ``summarize`` and the UI renderers."""
    import torch
    from transformers import AutoConfig

    dev = "cuda" if torch.cuda.is_available() else "cpu"
    modality = _detect_modality(AutoConfig.from_pretrained(model_id))
    if modality == "vision":
        extracted = _extract_vision(model_id, dev, num_samples)
    else:
        extracted = _extract_text(model_id, dev)
    acc, n_layers, n_heads, hidden, _seq_len = extracted

    def layer_record(layer):
        # Rows = heads, columns = that head's flattened mean attention map.
        rec = hierarchy(_corr(acc[layer].reshape(n_heads, -1)), THRESHOLD, edge_cap=400)
        rec["layer"] = layer
        return rec

    per_layer = [layer_record(layer) for layer in range(n_layers)]

    # Global graph over all heads, with the edge cap effectively disabled
    # (hierarchy's triangle budget still applies). giant=True restricts the
    # spectral part to the largest component; its FI covers the whole graph.
    all_heads = acc.reshape(n_layers * n_heads, -1)
    global_rec = hierarchy(_corr(all_heads), THRESHOLD, edge_cap=10 ** 9, giant=True)

    meta = PRELOADED.get(model_id, {})
    return {
        "model_id": model_id,
        "label": meta.get("label") or model_id.split("/")[-1],
        "arch": _arch_label(modality),
        "modality": modality,
        "n_layers": n_layers,
        "n_heads": n_heads,
        "hidden": hidden,
        "params": meta.get("params"),
        "per_layer": per_layer,
        "fi_base": global_rec["fi"],
        "source": "live",
    }
def get_result(model_id, force_rescan=False):
    """Return scan results, preferring the precomputed cache.

    Unless ``force_rescan`` is set, a pre-loaded model that has a cached block
    is served by ``cached_result``. Custom models, models without a cache
    block (e.g. DistilBERT) and forced re-scans go through ``scan_model``."""
    cached = None if force_rescan else cached_result(model_id)
    return cached if cached is not None else scan_model(model_id)