ai-model-xray / analysis.py
CognitiveEngineering's picture
Deploy AI Model X-Ray - structural health scanner
451cf40 verified
Raw
History Blame Contribute Delete
19.8 kB
"""Structural-health backend for AI Model X-Ray.
Extracts attention-head connectivity graphs from any HuggingFace transformer and
computes the spectral simplicial hierarchy per layer:
lambda2(T(G)) <= lambda2(G) (and the chain lambda2(T3) <= lambda2(T(G)) <= lambda2(G))
The graph / hierarchy / fragility math is ported verbatim from the validated
Modal experiments (`modal_jepa.py`, `modal_cross_arch.py`): lambda2 (algebraic
connectivity), triangle_graph T(G), the fragility index FI (fraction of edges in
zero triangles) and the per-layer `hierarchy` record. Only the surrounding
plumbing — model-type detection, per-layer regime classification, the precomputed
reference cache and the `scan_model` entry point — is new here.
Each layer is classified into a pruning regime from its coherence ratio
rho = lambda2(T(G)) / lambda2(G) and its FI:
    regime    condition                  meaning                advice
    immune    rho > 0.8 and FI == 0      triangle-redundant  -> safe to prune
    buffer    0.5 <= rho <= 0.8          partial redundancy  -> prune with caution
    critical  rho < 0.5                  low redundancy      -> do not prune
(layers with rho > 0.8 but FI > 0 are classified as buffer; layers where T(G)
is disconnected have rho undefined -> also treated as buffer)
"""
from __future__ import annotations
import json
import os
import statistics
# --------------------------------------------------------------------------- #
# Constants
# --------------------------------------------------------------------------- #
# Numerical tolerance: used when testing lambda2(G) against zero, FI against
# zero, and lambda2(T(G)) - lambda2(G) for a hierarchy violation.
TOL = 1e-9
THRESHOLD = 0.3 # Pearson r edge threshold (matches the reference runs)
SCALING_FACTOR = 500.0 # structural-integrity gauge: score = 100 - FI * factor
# Directory containing this module; CACHE_PATH is resolved relative to it.
HERE = os.path.dirname(os.path.abspath(__file__))
# JSON file with the precomputed reference results (read by `_load_cache`).
CACHE_PATH = os.path.join(HERE, "data", "cross_architecture_spectral.json")
# 16 diverse English sentences for text models (verbatim from modal_cross_arch).
# `_extract_text` feeds them to the model four at a time and averages the
# resulting attention maps over all of them.
SENTENCES = [
"The quick brown fox jumps over the lazy dog near the river.",
"Scientists discovered a new species of butterfly in the rainforest.",
"She carefully placed the fragile vase on the wooden shelf.",
"Global markets rallied after the central bank cut interest rates.",
"He learned to play the violin when he was only six years old.",
"The ancient castle stood silently on the hill above the village.",
"Machine learning models require large amounts of training data.",
"A gentle breeze carried the scent of blossoms across the garden.",
"The committee voted unanimously to approve the new budget proposal.",
"Children laughed and played in the park on a sunny afternoon.",
"The spacecraft entered orbit after a seven month journey to Mars.",
"Fresh bread and strong coffee filled the small bakery with warmth.",
"Engineers tested the bridge under heavy load before opening it.",
"The novel explores themes of memory, loss, and the passage of time.",
"Volunteers planted hundreds of trees along the eroding coastline.",
"Quantum computers may one day solve problems classical machines cannot.",
]
# Pre-loaded models surfaced in the dropdown: id -> display metadata.
# `cache_key` names the model's block under "models" in the reference cache;
# None means there is no precomputed block, so `cached_result` returns None and
# `get_result` falls through to a live scan.
PRELOADED = {
"bert-base-uncased": {"label": "BERT", "type": "Encoder", "params": "110M", "cache_key": "BERT"},
"gpt2": {"label": "GPT-2", "type": "Decoder", "params": "124M", "cache_key": "GPT-2"},
"google/vit-base-patch16-224": {"label": "ViT", "type": "Vision", "params": "86M", "cache_key": "ViT"},
"distilbert-base-uncased": {"label": "DistilBERT","type": "Encoder", "params": "66M", "cache_key": None},
}
# ========================================================================= #
# Graph / hierarchy helpers (ported verbatim from modal_jepa.py)
# ========================================================================= #
def lambda2(G):
    """Algebraic connectivity of G: the second-smallest Laplacian eigenvalue.

    Returns the triple ``(value, connected, n)`` where ``n`` is the node count
    and ``connected`` is networkx's connectivity verdict for G. A graph with
    fewer than two nodes has no second eigenvalue and yields ``(None, False, n)``.

    Graphs of up to 1800 nodes get a full dense eigendecomposition. Larger ones
    ask a sparse solver for just two eigenvalues (shift-invert about 1e-8
    first, smallest-magnitude mode if that attempt raises)."""
    import networkx as nx
    import numpy as np

    node_count = G.number_of_nodes()
    if node_count < 2:
        return None, False, node_count
    is_conn = nx.is_connected(G)
    laplacian = nx.laplacian_matrix(G)

    if node_count > 1800:
        from scipy.sparse.linalg import eigsh
        sparse_lap = laplacian.astype(float)
        try:
            pair = eigsh(sparse_lap, k=2, sigma=1e-8, which="LM",
                         return_eigenvectors=False)
        except Exception:
            # Shift-invert can fail to factorise; retry without the shift.
            pair = eigsh(sparse_lap, k=2, which="SM", return_eigenvectors=False)
        return float(np.sort(pair)[1]), is_conn, node_count

    # eigvalsh returns eigenvalues in ascending order, so index 1 is lambda2.
    spectrum = np.linalg.eigvalsh(laplacian.toarray().astype(float))
    return float(spectrum[1]), is_conn, node_count
def _triangles(G, adj):
    """Yield every triangle of G exactly once as a sorted 3-tuple of nodes.

    `adj` maps each node to the set of its neighbours. Each node is visited in
    `G.nodes()` order and every pair of its (sorted) neighbours is tested for
    adjacency; a triangle is emitted the first time it is encountered."""
    emitted = set()
    for apex in G.nodes():
        around = sorted(adj[apex])
        for pos, first in enumerate(around):
            first_nbrs = adj[first]
            for second in around[pos + 1:]:
                if second not in first_nbrs:
                    continue
                key = tuple(sorted((apex, first, second)))
                if key in emitted:
                    # Already reported from one of its other two corners.
                    continue
                emitted.add(key)
                yield key
def triangle_graph(G):
    """Build T(G), the triangle graph of G.

    Every edge of G becomes a node of T(G) (numbered in `G.edges()` order), and
    two such nodes are joined whenever the corresponding edges of G belong to a
    common triangle. Edges of G that sit in no triangle stay isolated nodes."""
    import networkx as nx

    canon = [tuple(sorted(pair)) for pair in G.edges()]
    position = dict(zip(canon, range(len(canon))))

    result = nx.Graph()
    result.add_nodes_from(range(len(canon)))

    nbrs = {node: set(G.neighbors(node)) for node in G.nodes()}
    for a, b, c in _triangles(G, nbrs):
        # The three sides of the triangle, looked up by canonical (sorted) key.
        i, j, k = (position[tuple(sorted(side))]
                   for side in ((a, b), (a, c), (b, c)))
        result.add_edges_from(((i, j), (i, k), (j, k)))
    return result
def compute_fi_from_adj(A):
    """SAL fragility index: fraction of edges participating in zero triangles.

    A: n x n binary symmetric adjacency (no self-loops). High FI = fragile.

    Returns 0.0 for fewer than two nodes, 1.0 for a graph with no edges, and
    otherwise ``fragile_edges / total_edges`` as a Python float, where an edge
    (i, j) with i < j is fragile when rows i and j share no common neighbour."""
    import numpy as np

    A = np.asarray(A)
    n = A.shape[0]
    if n < 2:
        return 0.0

    present = A != 0
    # Count common neighbours in float64 rather than in A's own dtype: the
    # adjacency is typically int8, where a dot product wraps around (256
    # shared neighbours would read as 0 and flag a well-supported edge as
    # fragile). Sums of 0/1 values are exact in float64.
    wide = present.astype(np.float64)
    common = wide @ wide.T

    # Each undirected edge is counted once, via the strict upper triangle.
    upper = np.triu(present, k=1)
    total = int(np.count_nonzero(upper))
    if total == 0:
        return 1.0
    fragile = int(np.count_nonzero(upper & (common == 0)))
    return fragile / total
def graph_from_corr(C, threshold):
    """Threshold a correlation matrix into a simple undirected graph.

    Nodes i and j are joined iff ``C[i, j] > threshold``; the diagonal is
    cleared so there are no self-loops. Returns ``(G, A)``: the networkx graph
    and the int8 adjacency matrix it was built from."""
    import networkx as nx
    import numpy as np

    adjacency = (C > threshold).astype(np.int8)
    np.fill_diagonal(adjacency, 0)
    return nx.from_numpy_array(adjacency), adjacency
def _density(G):
    """Edge density of G: edges present over node pairs possible.

    Graphs with fewer than two nodes have no possible pairs and report 0.0."""
    nodes = G.number_of_nodes()
    possible = nodes * (nodes - 1)
    return 2.0 * G.number_of_edges() / possible if nodes >= 2 else 0.0
def _count_edges_triangles_capped(G, cap):
    """Count the triangles of G, stopping as soon as the tally exceeds `cap`.

    Returns the exact triangle count when it is at most `cap`, otherwise
    ``cap + 1`` (enumeration is abandoned at that point)."""
    neighbours = {node: set(G.neighbors(node)) for node in G.nodes()}
    tally = 0
    for tally, _ in enumerate(_triangles(G, neighbours), start=1):
        if tally > cap:
            break
    return tally
def hierarchy(C, threshold, edge_cap=400, giant=False):
    """Compute the per-layer spectral-hierarchy record for correlation matrix C.

    The graph G is built by thresholding C. While G has more than `edge_cap`
    edges or more than 4000 triangles, the threshold is raised by 0.02 (until
    it reaches 0.95) and G is rebuilt, so T(G) stays tractable; the threshold
    finally used and whether it was raised are both reported.

    The fragility index, node/edge counts and component count always describe
    the full thresholded graph. With `giant=True` the spectral quantities are
    evaluated on G's largest connected component instead.

    `rho` and `violation` are meaningful only when both G and T(G) are
    connected (`eligible`); `rho` is None otherwise, or when lambda2(G) is not
    above TOL."""
    import networkx as nx

    tri_cap = 4000
    used_t = threshold
    was_bumped = False
    G, A = graph_from_corr(C, used_t)
    while True:
        tri_count = _count_edges_triangles_capped(G, tri_cap)
        too_dense = G.number_of_edges() > edge_cap or tri_count > tri_cap
        if not (too_dense and used_t < 0.95):
            break
        used_t = round(used_t + 0.02, 4)
        was_bumped = True
        G, A = graph_from_corr(C, used_t)

    # Whole-graph statistics, taken before any restriction to a component.
    fi = compute_fi_from_adj(A)
    full_nodes = G.number_of_nodes()
    full_edges = G.number_of_edges()
    n_comp = nx.number_connected_components(G) if full_nodes else 0

    if giant and full_nodes >= 1 and not nx.is_connected(G):
        biggest = max(nx.connected_components(G), key=len)
        G = G.subgraph(biggest).copy()

    l2G, gconn, _ = lambda2(G)
    TG = triangle_graph(G)
    tconn = TG.number_of_nodes() >= 2 and nx.is_connected(TG)
    l2TG = lambda2(TG)[0]
    eligible = bool(gconn and tconn)

    rho = None
    if eligible and l2G and l2G > TOL:
        rho = l2TG / l2G
    violation = bool(
        eligible and l2TG is not None and l2G is not None and l2TG - l2G > TOL
    )

    return {
        "threshold": used_t,
        "threshold_bumped": was_bumped,
        "giant": giant,
        "orig_nodes": int(full_nodes),
        "orig_edges": int(full_edges),
        "n_components": int(n_comp),
        "nodes": int(G.number_of_nodes()),
        "edges": int(G.number_of_edges()),
        "density": _density(G),
        "G_connected": bool(gconn),
        "TG_nodes": int(TG.number_of_nodes()),
        "TG_edges": int(TG.number_of_edges()),
        "TG_connected": bool(tconn),
        "eligible": eligible,
        "l2G": l2G,
        "l2TG": l2TG,
        "rho": rho,
        "fi": fi,
        "violation": violation,
    }
# ========================================================================= #
# Regime classification
# ========================================================================= #
def classify_regime(rec):
    """Assign a per-layer hierarchy record to a pruning regime.

    Returns 'immune', 'buffer' or 'critical'. A record with no rho (triangle
    graph disconnected, so the ratio is undefined) counts as 'buffer'. Below
    0.5 the layer is 'critical'; above 0.8 it is 'immune' only if its FI is
    zero to within TOL; everything else is 'buffer'."""
    rho = rec.get("rho")
    if rho is None:
        return "buffer"
    if rho < 0.5:
        return "critical"
    fragility = rec.get("fi") or 0.0
    if rho > 0.8 and fragility <= TOL:
        return "immune"
    return "buffer"
# Display metadata per regime name returned by `classify_regime`:
# human-readable label, hex colour and the pruning advice shown for it.
REGIME_META = {
"immune": {"label": "Immune", "color": "#10B981", "advice": "safe to prune"},
"buffer": {"label": "Buffer", "color": "#EF9F27", "advice": "prune with caution"},
"critical": {"label": "Critical", "color": "#E24B4A", "advice": "do not prune"},
}
def summarize(result):
    """Reduce a scan result (per_layer + meta) to dashboard-ready aggregates.

    Produces regime counts, rho statistics over the eligible layers, the
    violation count, the headline fragility and the 0-100 integrity score
    (``100 - fi_head * SCALING_FACTOR``, clamped)."""
    layers = result["per_layer"]
    eligible = [rec for rec in layers if rec["eligible"]]
    rhos = [rec["rho"] for rec in eligible if rec["rho"] is not None]
    layer_fis = [rec["fi"] for rec in layers if rec["fi"] is not None]
    regimes = [classify_regime(rec) for rec in layers]
    tally = {name: regimes.count(name) for name in ("immune", "buffer", "critical")}

    # Headline fragility: the global base FI if the result carries one,
    # otherwise the mean of the per-layer values (0.0 when there are none).
    fi_head = result.get("fi_base")
    if fi_head is None:
        fi_head = statistics.mean(layer_fis) if layer_fis else 0.0

    raw_score = 100.0 - fi_head * SCALING_FACTOR
    score = max(0.0, min(100.0, raw_score))

    return {
        "n_layers": result["n_layers"],
        "n_heads": result["n_heads"],
        "n_immune": tally["immune"],
        "n_buffer": tally["buffer"],
        "n_critical": tally["critical"],
        "n_eligible": len(eligible),
        "violations": len([rec for rec in eligible if rec["violation"]]),
        "rho_mean": statistics.mean(rhos) if rhos else None,
        "rho_min": min(rhos) if rhos else None,
        "rho_max": max(rhos) if rhos else None,
        "fi_head": fi_head,
        "score": score,
        "regimes": regimes,
    }
def score_band(score):
    """Return the (label, color) pair for a structural-integrity gauge value."""
    # Bands from the top down: the first floor the score reaches wins.
    bands = (
        (92, "Over-redundant", "#10B981"),
        (80, "Healthy", "#10B981"),
        (50, "Buffer zone", "#EF9F27"),
    )
    for floor, label, color in bands:
        if score >= floor:
            return label, color
    return "Critical", "#E24B4A"
# ========================================================================= #
# Precomputed reference cache
# ========================================================================= #
_CACHE = None


def _load_cache():
    """Return the parsed reference cache, reading CACHE_PATH on first use.

    The decoded JSON is kept in the module-level `_CACHE` so later calls do
    not touch the file again."""
    global _CACHE
    if _CACHE is not None:
        return _CACHE
    with open(CACHE_PATH, encoding="utf-8") as handle:
        _CACHE = json.load(handle)
    return _CACHE
def reference_rows():
    """Rows of the architecture-comparison table (I-JEPA, BERT, GPT-2, ViT).

    A fresh list is returned, so callers may reorder it without disturbing the
    cached data."""
    rows = _load_cache()["comparison"]
    return list(rows)
def cached_result(model_id):
    """Look up the precomputed scan for a pre-loaded model.

    Returns a scan-shaped dict (``source`` set to "cached"), or None when the
    model is not in PRELOADED, has no cache key (e.g. DistilBERT -> needs a
    live scan), or its key has no block in the cache file."""
    meta = PRELOADED.get(model_id)
    key = meta.get("cache_key") if meta else None
    if not key:
        return None

    block = _load_cache()["models"].get(key)
    if block is None:
        return None

    return {
        "model_id": model_id,
        "label": meta["label"],
        "arch": meta["type"],
        "modality": block["modality"],
        "n_layers": block["n_layers"],
        "n_heads": block["n_heads"],
        "hidden": block["hidden"],
        "params": meta["params"],
        "per_layer": block["per_layer"],
        "fi_base": block["masking"]["base_fi"],
        "source": "cached",
    }
# ========================================================================= #
# Live attention extraction + scan
# ========================================================================= #
def _detect_modality(config):
    """Classify a model config as 'vision', 'decoder' or 'encoder'.

    Vision is checked first (a "vit"/"image" marker in the model type or
    architecture names, or a truthy `image_size`), then decoder (the
    `is_decoder` flag, a "gpt"/"llama" model type, or a "causal"/"lmhead"
    architecture name). Anything else is treated as an encoder."""
    arch = " ".join(getattr(config, "architectures", None) or []).lower()
    mt = (getattr(config, "model_type", "") or "").lower()

    looks_visual = (
        "vit" in mt
        or "vit" in arch
        or "image" in mt
        or getattr(config, "image_size", None)
    )
    if looks_visual:
        return "vision"

    looks_causal = (
        getattr(config, "is_decoder", False)
        or any(tag in mt for tag in ("gpt", "llama"))
        or any(tag in arch for tag in ("causal", "lmhead"))
    )
    return "decoder" if looks_causal else "encoder"
def _arch_label(modality):
    """Display label for a modality string; unknown modalities raise KeyError."""
    labels = {"vision": "Vision", "decoder": "Decoder", "encoder": "Encoder"}
    return labels[modality]
def _corr(M):
    """Row-wise Pearson correlation matrix of M with NaNs replaced by 0.

    Each row of M is one variable (one head's flattened attention signature).
    A constant row has zero variance, so its correlations are 0/0; those
    entries (including that row's diagonal entry) come back as 0.0.

    The result is always 2-D: for a single row numpy returns a 0-d scalar,
    which is promoted to a 1x1 matrix so downstream matrix operations such as
    filling the diagonal keep working for one-head models."""
    import numpy as np

    # The 0/0 on constant rows is expected and handled just below, so the
    # floating-point warnings it would raise are suppressed here.
    with np.errstate(divide="ignore", invalid="ignore"):
        C = np.corrcoef(M)
    return np.atleast_2d(np.nan_to_num(C, nan=0.0))
def _extract_text(model_id, dev, max_length=32):
    """Average per-head attention maps of a text model over SENTENCES.

    Loads the tokenizer and model for `model_id` (eager attention, float32),
    runs the probe sentences four at a time, each padded/truncated to
    `max_length` tokens, and averages the attention tensors over all sentences.

    Returns ``(acc, n_layers, n_heads, hidden_size, seq_len)`` where `acc` is a
    float64 array of shape (n_layers, n_heads, seq_len, seq_len).
    Raises RuntimeError if the model returns no attentions."""
    import numpy as np
    import torch
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        # Padding is required below; reuse the EOS token when none is defined.
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModel.from_pretrained(
        model_id, attn_implementation="eager", torch_dtype=torch.float32)
    model.to(dev).eval()
    cfg = model.config
    n_layers = cfg.num_hidden_layers
    n_heads = cfg.num_attention_heads

    batch_size = 4
    totals = None
    seq_len = None
    seen = 0
    for start in range(0, len(SENTENCES), batch_size):
        chunk = SENTENCES[start:start + batch_size]
        encoded = tokenizer(chunk, return_tensors="pt", padding="max_length",
                            truncation=True, max_length=max_length).to(dev)
        with torch.no_grad():
            attentions = model(**encoded, output_attentions=True).attentions
        if attentions is None or attentions[0] is None:
            raise RuntimeError(f"{model_id} returned no attentions")
        if totals is None:
            # Size the accumulator from the first batch's attention shape.
            seq_len = attentions[0].shape[-1]
            totals = np.zeros((n_layers, n_heads, seq_len, seq_len),
                              dtype=np.float64)
        for layer in range(n_layers):
            # Sum over the batch axis; the division below turns it into a mean.
            totals[layer] += attentions[layer].sum(dim=0).float().cpu().numpy()
        seen += len(chunk)
    totals /= seen
    return totals, n_layers, n_heads, int(cfg.hidden_size), int(seq_len)
def _extract_vision(model_id, dev, num_samples=16):
    """Average per-head attention maps of a vision model over probe images.

    Loads the image processor and model for `model_id` (eager attention,
    float32), fetches up to `num_samples` images via `_load_images`, runs them
    four at a time and averages the attention tensors over all images.

    Returns ``(acc, n_layers, n_heads, hidden_size, seq_len)`` where `acc` is a
    float64 array of shape (n_layers, n_heads, seq_len, seq_len).
    Raises RuntimeError if the model returns no attentions."""
    import numpy as np
    import torch
    from transformers import AutoImageProcessor, AutoModel

    processor = AutoImageProcessor.from_pretrained(model_id)
    model = AutoModel.from_pretrained(
        model_id, attn_implementation="eager", torch_dtype=torch.float32)
    model.to(dev).eval()
    cfg = model.config
    n_layers = cfg.num_hidden_layers
    n_heads = cfg.num_attention_heads

    images = _load_images(num_samples)
    batch_size = 4
    totals = None
    seq_len = None
    seen = 0
    for start in range(0, len(images), batch_size):
        chunk = images[start:start + batch_size]
        pixels = processor(images=chunk, return_tensors="pt").to(dev)
        with torch.no_grad():
            attentions = model(**pixels, output_attentions=True).attentions
        if attentions is None or attentions[0] is None:
            raise RuntimeError(f"{model_id} returned no attentions")
        if totals is None:
            # Size the accumulator from the first batch's attention shape.
            seq_len = attentions[0].shape[-1]
            totals = np.zeros((n_layers, n_heads, seq_len, seq_len),
                              dtype=np.float64)
        for layer in range(n_layers):
            # Sum over the batch axis; the division below turns it into a mean.
            totals[layer] += attentions[layer].sum(dim=0).float().cpu().numpy()
        seen += len(chunk)
    totals /= seen
    return totals, n_layers, n_heads, int(cfg.hidden_size), int(seq_len)
def _load_images(n):
    """Return up to n PIL RGB probe images.

    The first choice is the HF cifar10 test split. If that yields fewer than
    ``min(n, 4)`` images (including when loading fails outright), a fixed list
    of COCO validation URLs is tried, stopping as soon as n images are in hand.
    Both sources are best-effort: individual failures are skipped.

    A non-positive n returns an empty list without touching the network.
    Raises RuntimeError if no image could be loaded from either source; the
    most recent underlying failure, if any, is attached as its cause."""
    if n <= 0:
        return []

    imgs = []
    last_err = None
    try:
        from datasets import load_dataset
        ds = load_dataset("uoft-cs/cifar10", split=f"test[:{n}]")
        col = "img" if "img" in ds.column_names else ds.column_names[0]
        imgs = [ds[i][col].convert("RGB") for i in range(len(ds))]
    except Exception as err:
        # Best-effort source: remember why it failed and try the fallback.
        last_err = err

    if len(imgs) < min(n, 4):
        import io
        import urllib.request
        from PIL import Image
        urls = [
            "http://images.cocodataset.org/val2017/000000039769.jpg",
            "http://images.cocodataset.org/val2017/000000000139.jpg",
            "http://images.cocodataset.org/val2017/000000000285.jpg",
            "http://images.cocodataset.org/val2017/000000000632.jpg",
            "http://images.cocodataset.org/val2017/000000000724.jpg",
            "http://images.cocodataset.org/val2017/000000000776.jpg",
            "http://images.cocodataset.org/val2017/000000000785.jpg",
            "http://images.cocodataset.org/val2017/000000000802.jpg",
        ]
        for u in urls:
            if len(imgs) >= n:
                # Enough images already; skip the remaining downloads.
                break
            try:
                with urllib.request.urlopen(u, timeout=20) as r:
                    imgs.append(Image.open(io.BytesIO(r.read())).convert("RGB"))
            except Exception as err:
                # One bad URL should not abort the scan; move on to the next.
                last_err = err

    if not imgs:
        raise RuntimeError("could not load any probe images") from last_err
    return imgs[:n]
def scan_model(model_id, num_samples=16):
    """Run a full live structural scan of a HuggingFace transformer.

    The model's config decides its modality; vision models are probed with
    `num_samples` images, everything else with the built-in sentence set. For
    each layer the heads' flattened attention maps are correlated (Pearson r),
    thresholded at THRESHOLD into a graph, and passed through `hierarchy`. A
    second, global graph over every head of every layer supplies the headline
    base FI.

    Returns a scan-shaped dict (``source`` set to "live") consumable by
    `summarize` and the UI renderers."""
    import torch
    from transformers import AutoConfig

    dev = "cuda" if torch.cuda.is_available() else "cpu"
    modality = _detect_modality(AutoConfig.from_pretrained(model_id))
    if modality == "vision":
        extracted = _extract_vision(model_id, dev, num_samples)
    else:
        extracted = _extract_text(model_id, dev)
    acc, n_layers, n_heads, hidden, _seq_len = extracted

    per_layer = []
    for layer in range(n_layers):
        # One row per head: its averaged attention map, flattened.
        head_sigs = acc[layer].reshape(n_heads, -1)
        rec = hierarchy(_corr(head_sigs), THRESHOLD, edge_cap=400)
        rec["layer"] = layer
        per_layer.append(rec)

    # Global graph base FI over all heads (for the headline fragility number).
    all_sigs = acc.reshape(n_layers * n_heads, -1)
    global_rec = hierarchy(_corr(all_sigs), THRESHOLD, edge_cap=10 ** 9, giant=True)

    preset = PRELOADED.get(model_id, {})
    return {
        "model_id": model_id,
        "label": preset.get("label") or model_id.split("/")[-1],
        "arch": _arch_label(modality),
        "modality": modality,
        "n_layers": n_layers,
        "n_heads": n_heads,
        "hidden": hidden,
        "params": preset.get("params"),
        "per_layer": per_layer,
        "fi_base": global_rec["fi"],
        "source": "live",
    }
def get_result(model_id, force_rescan=False):
    """Cache-first scan result for `model_id`.

    Unless `force_rescan` is set, a precomputed block for a pre-loaded model is
    returned when one exists. Otherwise (custom models, DistilBERT, or an
    explicit re-scan) a live scan is run."""
    cached = None if force_rescan else cached_result(model_id)
    return scan_model(model_id) if cached is None else cached