# SeeMODEL / app.py
# priyadip's picture
# Fix: raise max_lines to 5000 and force textarea overflow-y to allow full scroll
# c8c9cf2
"""
πŸ”¬ Neural Model Analyzer β€” Advanced Model Introspection Tool
Supports every model on Hugging Face. Runs on free CPU Spaces.
"""
import gradio as gr
import torch
import os
import gc
import numpy as np
import re
from collections import OrderedDict, defaultdict, Counter
from pathlib import Path
import traceback
try:
from safetensors import safe_open
HAS_SAFETENSORS = True
except ImportError:
HAS_SAFETENSORS = False
try:
import onnx
from onnx import numpy_helper
HAS_ONNX = True
except ImportError:
HAS_ONNX = False
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import warnings
warnings.filterwarnings("ignore")
from architecture_detector import detect_architectures, infer_model_config, format_detection_report
# ─────────────────────────── Styling ───────────────────────────
CUSTOM_CSS = """
.gradio-container { max-width: 1400px !important; }
.main-header {
text-align: center; padding: 24px 10px 6px 10px;
background: linear-gradient(135deg, #0d1117 0%, #161b22 100%);
border-radius: 16px; margin-bottom: 10px;
border: 1px solid #30363d;
}
.main-header h1 {
font-size: 2.4em; font-weight: 800; margin: 0;
background: linear-gradient(135deg, #58a6ff 0%, #bc8cff 50%, #f778ba 100%);
-webkit-background-clip: text; -webkit-text-fill-color: transparent;
}
.main-header p { color: #8b949e; font-size: 0.95em; margin: 6px 0 0 0; }
.info-panel {
background: #0d1117; border-left: 3px solid #58a6ff;
padding: 10px 14px; border-radius: 0 8px 8px 0;
margin: 6px 0; font-size: 0.88em; color: #c9d1d9;
}
/* Ensure all textareas are fully scrollable */
textarea {
overflow-y: auto !important;
resize: vertical !important;
}
footer { display: none !important; }
"""
MAX_FILE_SIZE_MB = 4096
def _sizeof_fmt(num_bytes):
for unit in ["B", "KB", "MB", "GB"]:
if abs(num_bytes) < 1024:
return f"{num_bytes:.1f} {unit}"
num_bytes /= 1024
return f"{num_bytes:.1f} TB"
def _fmt(n):
if n >= 1e9: return f"{n / 1e9:.2f}B"
if n >= 1e6: return f"{n / 1e6:.2f}M"
if n >= 1e3: return f"{n / 1e3:.1f}K"
return str(n)
# ─────────────────────────── Loading ───────────────────────────
def load_state_dict_safe(file_path: str) -> tuple:
    """Load a checkpoint into (state_dict, metadata).

    Supports .pth/.pt/.bin (PyTorch pickles), .safetensors and .onnx.
    Returns an OrderedDict of tensors (with leading 'module.' /
    'model.module.' prefixes stripped) plus a metadata dict describing
    the file and any extras found in the checkpoint.
    Raises ValueError for oversized files or unknown extensions, and
    ImportError when the optional backend library is missing.
    """
    ext = Path(file_path).suffix.lower()
    file_size = os.path.getsize(file_path)
    metadata = {
        "format": ext,
        "file_size_bytes": file_size,
        "file_size_str": _sizeof_fmt(file_size),
        "file_name": Path(file_path).name,
    }
    state_dict = OrderedDict()
    # Hard size cap to protect free-tier CPU Spaces from OOM.
    if file_size > MAX_FILE_SIZE_MB * 1024 * 1024:
        raise ValueError(
            f"File is {_sizeof_fmt(file_size)} β€” exceeds {MAX_FILE_SIZE_MB} MB limit for free CPU."
        )
    gc.collect()
    if ext in (".pth", ".pt"):
        # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects,
        # which can execute code from untrusted uploads; the weights_only=True
        # retry only happens if the permissive load fails.
        try:
            data = torch.load(file_path, map_location="cpu", weights_only=False)
        except Exception:
            data = torch.load(file_path, map_location="cpu", weights_only=True)
        if isinstance(data, torch.nn.Module):
            # Fully pickled module: keep a truncated repr for display.
            metadata["has_full_model"] = True
            metadata["model_repr"] = repr(data)[:5000]
            state_dict = OrderedDict(data.state_dict())
        elif isinstance(data, dict):
            # Try the wrapper keys commonly used by training frameworks.
            for ck in ["state_dict", "model_state_dict", "model", "net", "params",
                       "network", "g_ema", "gen", "generator", "discriminator",
                       "ema_model", "module", "teacher", "student"]:
                if ck in data:
                    candidate = data[ck]
                    if isinstance(candidate, dict):
                        state_dict = OrderedDict(candidate)
                        metadata["checkpoint_key_used"] = ck
                        break
                    elif isinstance(candidate, torch.nn.Module):
                        state_dict = OrderedDict(candidate.state_dict())
                        metadata["checkpoint_key_used"] = ck
                        metadata["has_full_model"] = True
                        break
            if not state_dict:
                # No wrapper matched: take top-level tensors directly, or
                # flatten one level of nested dicts as "outer.inner" keys.
                tensors = {k: v for k, v in data.items() if isinstance(v, torch.Tensor)}
                if tensors:
                    state_dict = OrderedDict(tensors)
                else:
                    for top_k, top_v in data.items():
                        if isinstance(top_v, dict):
                            sub = {f"{top_k}.{k}": v for k, v in top_v.items()
                                   if isinstance(v, torch.Tensor)}
                            state_dict.update(sub)
            # Record non-tensor checkpoint keys and training bookkeeping.
            extra = [k for k in data.keys()
                     if not isinstance(data[k], (torch.Tensor, torch.nn.Module, dict))]
            if extra:
                metadata["extra_checkpoint_keys"] = extra[:20]
            for ik in ["epoch", "global_step", "step", "iteration", "best_metric", "best_loss"]:
                if ik in data:
                    metadata[ik] = data[ik]
            if any(k in data for k in ["optimizer", "optimizer_state_dict", "optim"]):
                metadata["has_optimizer_state"] = True
            if any(k in data for k in ["scheduler", "lr_scheduler"]):
                metadata["has_lr_scheduler"] = True
            if "config" in data:
                metadata["embedded_config"] = str(data["config"])[:3000]
            if "args" in data:
                metadata["training_args"] = str(data["args"])[:2000]
            del data; gc.collect()
        else:
            metadata["note"] = f"File contains {type(data).__name__}."
    elif ext == ".bin":
        # HuggingFace-style pytorch_model.bin: a flat dict of tensors.
        data = torch.load(file_path, map_location="cpu", weights_only=False)
        if isinstance(data, dict):
            state_dict = OrderedDict({k: v for k, v in data.items() if isinstance(v, torch.Tensor)})
        metadata["format_note"] = "PyTorch .bin (HuggingFace style)"
        del data; gc.collect()
    elif ext in (".safetensors",):
        if not HAS_SAFETENSORS:
            raise ImportError("safetensors library not installed")
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key)
            # Header metadata is optional; ignore failures reading it.
            try:
                sf_meta = f.metadata()
                if sf_meta:
                    metadata["safetensors_metadata"] = dict(sf_meta)
            except Exception:
                pass
        metadata["format_note"] = "SafeTensors"
    elif ext == ".onnx":
        if not HAS_ONNX:
            raise ImportError("onnx library not installed")
        model = onnx.load(file_path)
        # Graph initializers hold the weights; copy into torch tensors.
        for init in model.graph.initializer:
            arr = numpy_helper.to_array(init)
            state_dict[init.name] = torch.from_numpy(arr.copy())
        metadata["format_note"] = "ONNX"
        metadata["onnx_opset"] = [op.version for op in model.opset_import]
        metadata["onnx_ir_version"] = model.ir_version
        # Dynamic dims have dim_value == 0; fall back to dim_param or "?".
        metadata["onnx_inputs"] = [
            {"name": inp.name,
             "shape": [d.dim_value if d.dim_value else (d.dim_param or "?")
                       for d in inp.type.tensor_type.shape.dim]}
            for inp in model.graph.input
        ]
        metadata["onnx_outputs"] = [
            {"name": out.name,
             "shape": [d.dim_value if d.dim_value else (d.dim_param or "?")
                       for d in out.type.tensor_type.shape.dim]}
            for out in model.graph.output
        ]
        op_counts = Counter(node.op_type for node in model.graph.node)
        metadata["onnx_ops"] = dict(op_counts.most_common(30))
        del model; gc.collect()
    else:
        raise ValueError(f"Unsupported: '{ext}'. Use .pth .pt .bin .safetensors .onnx")
    # Normalize DataParallel-style key prefixes so downstream name
    # heuristics see canonical module paths.
    cleaned = OrderedDict()
    for k, v in state_dict.items():
        cleaned[re.sub(r"^(module\.|model\.module\.)", "", k)] = v
    state_dict = cleaned
    return state_dict, metadata
# ─────────────────────────── Analysis ───────────────────────────
def build_summary(state_dict, metadata, detections, config) -> str:
    """Assemble the main plain-text analysis report.

    `metadata` is the dict from load_state_dict_safe; `detections` and
    `config` come from architecture_detector and are rendered through
    format_detection_report. Returns one large report string.
    """
    total_params = sum(t.numel() for t in state_dict.values())
    total_bytes = sum(t.numel() * t.element_size() for t in state_dict.values())
    dtypes = sorted(set(str(t.dtype) for t in state_dict.values()))
    s = "+" + "=" * 63 + "+\n"
    s += "| \U0001f52c NEURAL MODEL ANALYSIS REPORT |\n"
    s += "+" + "=" * 63 + "+\n\n"
    s += "\U0001f4c1 FILE INFO\n"
    s += f" File name: {metadata.get('file_name', '?')}\n"
    s += f" File format: {metadata.get('format', '?')} {metadata.get('format_note', '')}\n"
    s += f" File size: {metadata.get('file_size_str', '?')}\n"
    # Optional checkpoint details only appear if the loader recorded them.
    if metadata.get("checkpoint_key_used"):
        s += f" Checkpoint key: '{metadata['checkpoint_key_used']}'\n"
    if metadata.get("has_full_model"):
        s += f" Full model: Yes (nn.Module)\n"
    if metadata.get("has_optimizer_state"):
        s += f" Optimizer: Included\n"
    if metadata.get("has_lr_scheduler"):
        s += f" LR Scheduler: Included\n"
    for k in ["epoch", "global_step", "step", "best_metric", "best_loss"]:
        if k in metadata:
            s += f" {k}: {metadata[k]}\n"
    s += f"\n\U0001f4ca PARAMETER OVERVIEW\n"
    s += f" Total parameters: {_fmt(total_params)} ({total_params:,})\n"
    s += f" Weight memory: {_sizeof_fmt(total_bytes)}\n"
    s += f" Number of tensors: {len(state_dict):,}\n"
    s += f" Data types: {', '.join(dtypes)}\n"
    # Heuristic: BatchNorm running stats/counters are the only tensors we
    # can identify as non-trainable from names alone.
    trainable = sum(t.numel() for k, t in state_dict.items()
                    if "running_" not in k and "num_batches" not in k)
    non_train = total_params - trainable
    if non_train > 0:
        s += f" Trainable (est.): {_fmt(trainable)}\n"
        s += f" Non-trainable: {_fmt(non_train)}\n"
    s += f"\n{format_detection_report(detections, config)}"
    # ONNX graph I/O and op counts, if the loader captured them.
    if metadata.get("onnx_inputs"):
        s += "\n\U0001f4e5 ONNX INPUTS\n"
        for inp in metadata["onnx_inputs"]:
            s += f" {inp['name']}: {inp['shape']}\n"
    if metadata.get("onnx_outputs"):
        s += "\n\U0001f4e4 ONNX OUTPUTS\n"
        for out in metadata["onnx_outputs"]:
            s += f" {out['name']}: {out['shape']}\n"
    if metadata.get("onnx_ops"):
        s += "\n\U0001f527 ONNX OPS (top 15)\n"
        for op, cnt in list(metadata["onnx_ops"].items())[:15]:
            s += f" {op}: {cnt}\n"
    # For non-ONNX files, guess the model's input/output interface from
    # well-known embedding / stem / head weight names.
    if not metadata.get("onnx_inputs"):
        s += "\n\U0001f4e5 INFERRED INPUT\n"
        for k, t in state_dict.items():
            lk = k.lower()
            if any(x in lk for x in ["embed_tokens.weight", "word_embed", "wte.weight",
                                     "patch_embed.proj.weight", "conv1.weight",
                                     "feature_extractor.conv"]):
                if t.ndim == 2:
                    s += f" Token input: vocab={t.shape[0]:,}, dim={t.shape[1]}\n"
                    s += f" Format: [batch, seq_len] (token IDs)\n"
                elif t.ndim == 4:
                    s += f" Image input: {t.shape[1]} channels, kernel {t.shape[2]}x{t.shape[3]}\n"
                    ch = t.shape[1]
                    s += f" Format: [batch, {ch}, H, W] ({'RGB' if ch == 3 else 'grayscale' if ch == 1 else f'{ch}-ch'})\n"
                break
        s += "\n\U0001f4e4 INFERRED OUTPUT\n"
        # Scan from the end: output heads typically come last in the dict.
        for k in reversed(list(state_dict.keys())):
            t = state_dict[k]
            lk = k.lower()
            if any(x in lk for x in ["lm_head.weight", "classifier.weight", "cls.predictions",
                                     "qa_outputs.weight", "score.weight", "head.weight"]):
                if t.ndim == 2:
                    s += f" Output: {k}\n"
                    s += f" Shape: [{t.shape[0]:,} classes/tokens, {t.shape[1]:,} hidden]\n"
                elif t.ndim == 1:
                    s += f" Output bias: {k} -> {t.shape[0]:,} outputs\n"
                break
    if metadata.get("embedded_config"):
        s += f"\n\U0001f4cb EMBEDDED CONFIG\n {metadata['embedded_config'][:1200]}\n"
    if metadata.get("model_repr"):
        s += f"\n\U0001f3d7 MODEL REPR (nn.Module)\n{metadata['model_repr'][:3000]}\n"
    return s
def build_layer_tree(state_dict) -> str:
    """Render the dotted parameter names as an ASCII tree.

    Leaves show shape/dtype/size; interior nodes show their summed
    parameter count. Returns "No hierarchy found." for an empty dict.
    """

    def _subtotal(node):
        # Total parameter count of every leaf beneath `node`.
        acc = 0
        for child in node.values():
            if isinstance(child, dict) and "numel" in child:
                acc += child["numel"]
            elif isinstance(child, dict):
                acc += _subtotal(child)
        return acc

    # Build the nested mapping: one level per dotted name segment.
    root = OrderedDict()
    for full_name, tensor in state_dict.items():
        *branches, leaf = full_name.split(".")
        cursor = root
        for seg in branches:
            cursor = cursor.setdefault(seg, OrderedDict())
        cursor[leaf] = {
            "shape": list(tensor.shape), "dtype": str(tensor.dtype),
            "numel": tensor.numel(),
            "mb": tensor.numel() * tensor.element_size() / 1048576,
        }

    rendered = []

    def _emit(node, prefix=""):
        entries = list(node.items())
        count = len(entries)
        for pos, (label, val) in enumerate(entries):
            is_last = pos == count - 1
            conn = "└── " if is_last else "β”œβ”€β”€ "
            pad = " " if is_last else "β”‚ "
            if isinstance(val, dict) and "shape" in val:
                # Leaf tensor: show MB when large enough, else element count.
                sh = "x".join(map(str, val["shape"]))
                sz = f"{val['mb']:.3f} MB" if val["mb"] >= 0.001 else f"{val['numel']} el"
                rendered.append(f"{prefix}{conn}\U0001f539 {label} [{sh}] {val['dtype']} ({sz})")
            elif isinstance(val, dict):
                cnt = _subtotal(val)
                cnt_s = f" ({_fmt(cnt)} params)" if cnt > 0 else ""
                rendered.append(f"{prefix}{conn}\U0001f4e6 {label}{cnt_s}")
                _emit(val, prefix + pad)

    _emit(root)
    return "\n".join(rendered) if rendered else "No hierarchy found."
def infer_all_layers(state_dict) -> str:
    """Heuristically classify every module from its parameter names/shapes.

    Produces a text report with a type distribution, a per-module table
    (first 150 entries), and a rough connectivity sketch based on
    matching input/output dimensions in file order.
    """
    layers = OrderedDict()
    seen = set()
    for key, tensor in state_dict.items():
        parts = key.rsplit(".", 1)
        mod = parts[0] if len(parts) == 2 else ""
        param = parts[-1]
        shape = tensor.shape
        ndim = tensor.ndim
        ml = mod.lower()
        # One classification per module: the first parameter seen wins.
        if mod in seen:
            continue
        seen.add(mod)
        ltype = "Unknown"
        details = {}
        if param == "weight":
            if ndim == 4:
                ltype = "Conv2d"
                details = {"out": shape[0], "in": shape[1], "k": f"{shape[2]}x{shape[3]}"}
                # in-channels == 1 suggests depthwise; name hints override.
                if shape[1] == 1: ltype = "DepthwiseConv2d"
                if any(x in ml for x in ["transpose", "deconv"]): ltype = "ConvTranspose2d"
            elif ndim == 3:
                ltype = "Conv1d"
                details = {"out": shape[0], "in": shape[1], "k": shape[2]}
            elif ndim == 5:
                ltype = "Conv3d"
                details = {"out": shape[0], "in": shape[1]}
            elif ndim == 2:
                # 2-D weights are ambiguous; disambiguate via name substrings.
                if any(x in ml for x in ["embed", "token", "wte", "wpe"]):
                    ltype = "Embedding"
                    details = {"vocab": shape[0], "dim": shape[1]}
                elif any(x in ml for x in ["norm", "ln_", "layernorm", "rmsnorm"]):
                    ltype = "Norm"
                    details = {"features": shape[0]}
                elif any(x in ml for x in ["q_proj", "k_proj", "v_proj", "qkv", "query", "key", "value"]):
                    ltype = "Attention Projection"
                    details = {"out": shape[0], "in": shape[1]}
                elif any(x in ml for x in ["o_proj", "out_proj"]) and "attn" in ml:
                    ltype = "Attention Output"
                    details = {"out": shape[0], "in": shape[1]}
                elif any(x in ml for x in ["gate_proj", "up_proj", "wi", "fc1", "c_fc", "intermediate"]):
                    ltype = "FFN Up/Gate"
                    details = {"out": shape[0], "in": shape[1]}
                elif any(x in ml for x in ["down_proj", "wo", "fc2", "c_proj"]) and "attn" not in ml:
                    ltype = "FFN Down"
                    details = {"out": shape[0], "in": shape[1]}
                elif any(x in ml for x in ["lm_head", "classifier", "cls", "score", "qa_output"]):
                    ltype = "Head / Classifier"
                    details = {"classes": shape[0], "hidden": shape[1]}
                else:
                    ltype = "Linear"
                    details = {"out": shape[0], "in": shape[1]}
            elif ndim == 1:
                if any(x in ml for x in ["norm", "bn", "ln"]):
                    ltype = "Norm"
                    details = {"features": shape[0]}
                else:
                    ltype = "Bias/Scale"
                    details = {"dim": shape[0]}
        elif param == "bias":
            ltype = "Bias"; details = {"dim": shape[0]}
        elif "running_mean" in param:
            ltype = "BatchNorm stats"
        elif "A_log" in param or param == "D":
            ltype = "Mamba SSM"
        elif "inv_freq" in param:
            ltype = "Rotary inv_freq"
        layers[mod] = {"type": ltype, "details": details}
    text = "\U0001f9e9 INFERRED LAYER TYPES\n\n"
    tc = Counter(v["type"] for v in layers.values())
    text += "DISTRIBUTION:\n"
    for lt, cnt in tc.most_common(25):
        # ASCII bar capped at 40 cells.
        bar = "\u2588" * min(cnt, 40)
        text += f" {lt:<30} {cnt:>5} {bar}\n"
    text += f"\nDETAILED ({len(layers)} modules):\n"
    text += f"{'Module':<60} {'Type':<25} {'Details'}\n"
    text += "\u2500" * 115 + "\n"
    for path, info in list(layers.items())[:150]:
        short = path if len(path) < 58 else "\u2026" + path[-56:]
        det = ", ".join(f"{k}={v}" for k, v in info["details"].items())
        text += f"{short:<60} {info['type']:<25} {det}\n"
    if len(layers) > 150:
        text += f"\n \u2026 and {len(layers) - 150} more\n"
    # Connections: compare each weight's out-dim against the next weight's
    # in-dim in file order (crude, order-based heuristic only).
    text += "\n\n\U0001f4e1 LAYER CONNECTIONS\n" + "\u2500" * 60 + "\n"
    io_dims = []
    for key, tensor in state_dict.items():
        parts = key.rsplit(".", 1)
        mod = parts[0] if len(parts) == 2 else ""
        if parts[-1] == "weight" and tensor.ndim >= 2:
            # NOTE(review): both ternary branches yield tensor.shape[1] β€” the
            # condition is redundant; the in-dim is always shape[1] here.
            io_dims.append((mod, tensor.shape[1] if tensor.ndim == 2 else tensor.shape[1],
                            tensor.shape[0]))
    direct = reshape = 0
    for i in range(len(io_dims) - 1):
        if io_dims[i][2] == io_dims[i + 1][1]:
            direct += 1
        else:
            reshape += 1
    # Skip/residual candidates: out-dim of layer i matches in-dim of a
    # layer 2..11 positions later, deduped by top-level module pair.
    residual = []
    seen_p = set()
    for i in range(len(io_dims)):
        for j in range(i + 2, min(i + 12, len(io_dims))):
            if io_dims[i][2] == io_dims[j][1]:
                pair = (io_dims[i][0].split(".")[0], io_dims[j][0].split(".")[0])
                if pair not in seen_p:
                    residual.append((io_dims[i][0], io_dims[j][0], io_dims[i][2]))
                    seen_p.add(pair)
    text += f" Direct (dim match): {direct}\n"
    text += f" Reshape (dim change): {reshape}\n"
    text += f" Possible skip/residual: {len(residual)}\n\n"
    if residual:
        text += " SKIP/RESIDUAL CANDIDATES:\n"
        for src, dst, dim in residual[:30]:
            text += f" {src} ...({dim})...> {dst}\n"
        if len(residual) > 30:
            text += f" \u2026 and {len(residual) - 30} more\n"
    return text
def compute_weight_stats(state_dict) -> tuple:
    """Compute per-tensor summary statistics and a formatted report.

    Returns (stats, text): `stats` is a list of dicts, one per non-empty
    tensor, each carrying basic moments, sparsity, L2 norm and a list of
    detected health issues; `text` is a plain-text table with a health
    summary at the top.
    """
    stats = []
    for name, tensor in state_dict.items():
        if tensor.numel() == 0:
            continue
        values = tensor.float()
        record = {
            "name": name,
            "shape": list(tensor.shape),
            "dtype": str(tensor.dtype),
            "numel": tensor.numel(),
            "mb": tensor.numel() * tensor.element_size() / 1048576,
            "mean": values.mean().item(),
            # std of a single element is undefined; report 0.0 instead.
            "std": values.std().item() if values.numel() > 1 else 0.0,
            "min": values.min().item(),
            "max": values.max().item(),
            "abs_mean": values.abs().mean().item(),
            "sparsity": (values == 0).float().mean().item() * 100,
            "l2": values.norm(2).item(),
        }
        # Flag suspicious tensors (dead weights, blown-up scales, NaNs).
        flags = []
        if record["std"] < 1e-7 and record["numel"] > 10:
            flags.append("Near-zero var")
        if record["sparsity"] > 90:
            flags.append(">90% sparse")
        if abs(record["mean"]) > 10:
            flags.append("Large mean")
        if record["std"] > 100:
            flags.append("High std")
        if np.isnan(record["mean"]) or np.isinf(record["mean"]):
            flags.append("NaN/Inf!")
        record["issues"] = flags
        stats.append(record)
    issue_total = sum(len(r["issues"]) for r in stats)
    text = "\u2696 WEIGHT STATISTICS\n\n"
    if issue_total == 0:
        text += "All weights healthy.\n\n"
    else:
        text += f"{issue_total} issues:\n"
        for r in stats:
            for flag in r["issues"]:
                text += f" {r['name']}: {flag}\n"
        text += "\n"
    text += f"{'Tensor':<55} {'Shape':<18} {'Params':>10} {'Mean':>10} {'Std':>10} {'Sparse':>8}\n"
    text += "\u2500" * 115 + "\n"
    # Cap the table at 120 rows to keep the textbox responsive.
    for r in stats[:120]:
        label = r["name"] if len(r["name"]) < 53 else "\u2026" + r["name"][-52:]
        dims = "x".join(map(str, r["shape"]))
        if len(dims) > 16:
            dims = dims[:13] + "\u2026"
        text += (f"{label:<55} {dims:<18} {_fmt(r['numel']):>10} "
                 f"{r['mean']:>10.5f} {r['std']:>10.5f} {r['sparsity']:>7.1f}%\n")
    if len(stats) > 120:
        text += f"\n\u2026 and {len(stats) - 120} more\n"
    return stats, text
# ─────────────────────────── Plots ───────────────────────────
def _style_ax(ax):
ax.set_facecolor("#161b22")
ax.tick_params(colors="#8b949e", labelsize=7)
for sp in ax.spines.values(): sp.set_color("#30363d")
def plot_distributions(state_dict, max_n=20):
    """Grid of per-tensor weight histograms (up to `max_n`, evenly sampled).

    Values are clipped to the 1st..99th percentile so outliers don't
    flatten the histograms. Returns a matplotlib Figure.
    """
    # Prefer sizable >=2-D tensors; fall back to anything with >10 elements.
    candidates = [(k, v) for k, v in state_dict.items() if v.numel() > 100 and v.ndim >= 2]
    if not candidates:
        candidates = [(k, v) for k, v in state_dict.items() if v.numel() > 10]
    if not candidates:
        fig, ax = plt.subplots(figsize=(10, 3)); fig.patch.set_facecolor("#0d1117")
        ax.text(0.5, 0.5, "No layers for histogram", ha="center", va="center", color="#8b949e"); ax.axis("off")
        return fig
    # Evenly subsample across the file order rather than taking the first max_n.
    if len(candidates) > max_n:
        idx = np.linspace(0, len(candidates)-1, max_n, dtype=int)
        candidates = [candidates[i] for i in idx]
    n = len(candidates); cols = min(4, n); rows = (n + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(4.5*cols, 3*rows))
    fig.patch.set_facecolor("#0d1117")
    # subplots() returns a scalar axes for n == 1; normalize to a 2-D grid.
    if n == 1: axes = np.array([[axes]])
    axes = np.atleast_2d(axes)
    cmap = plt.cm.plasma
    for idx, (key, tensor) in enumerate(candidates):
        r, c = divmod(idx, cols); ax = axes[r, c]; _style_ax(ax)
        data = tensor.float().flatten().numpy()
        # Clip to the central 98% so extreme outliers don't dominate.
        q1, q99 = np.percentile(data, [1, 99])
        clipped = data[(data >= q1) & (data <= q99)]
        ax.hist(clipped, bins=70, color=cmap(idx/max(n-1,1)), alpha=0.85, edgecolor="none", density=True)
        ax.axvline(0, color="#ff6b6b", ls="--", alpha=0.4, lw=0.7)
        short = ".".join(key.split(".")[-2:])
        if len(short) > 30: short = "\u2026" + short[-27:]
        ax.set_title(short, fontsize=7.5, color="#c9d1d9", pad=3)
    # Hide any unused grid cells.
    for idx in range(n, rows*cols):
        r, c = divmod(idx, cols); axes[r, c].axis("off")
    fig.suptitle("Weight Distributions", fontsize=13, color="#c9d1d9", y=1.01)
    plt.tight_layout(); return fig
def plot_module_sizes(state_dict):
    """Horizontal bar chart of parameter counts per top-level module
    (largest 30). Returns a matplotlib Figure."""
    mod_params = defaultdict(int)
    # Aggregate by the first dotted-name segment.
    for k, t in state_dict.items(): mod_params[k.split(".")[0]] += t.numel()
    sorted_m = sorted(mod_params.items(), key=lambda x: -x[1])[:30]
    if not sorted_m:
        fig, ax = plt.subplots(figsize=(8,3)); fig.patch.set_facecolor("#0d1117")
        ax.text(0.5,0.5,"No modules",ha="center",va="center",color="#8b949e"); ax.axis("off"); return fig
    names = [m[0] for m in sorted_m]; counts = [m[1] for m in sorted_m]
    fig, ax = plt.subplots(figsize=(12, max(3, len(names)*0.35))); fig.patch.set_facecolor("#0d1117"); _style_ax(ax)
    colors = plt.cm.viridis(np.linspace(0.2, 0.9, len(names)))
    bars = ax.barh(range(len(names)), counts, color=colors, edgecolor="none", height=0.7)
    ax.set_yticks(range(len(names))); ax.set_yticklabels(names, fontsize=8.5, color="#c9d1d9"); ax.invert_yaxis()
    ax.set_xlabel("Parameters", color="#c9d1d9"); ax.set_title("Params by Module", color="#c9d1d9", fontsize=12)
    # Annotate each bar with its abbreviated count.
    for bar, cnt in zip(bars, counts):
        ax.text(bar.get_width()+max(counts)*0.01, bar.get_y()+bar.get_height()/2, _fmt(cnt), va="center", fontsize=7.5, color="#8b949e")
    plt.tight_layout(); return fig
def plot_heatmap(stats):
    """Heatmap of per-layer stats (mean/std/abs-mean/sparsity), with each
    metric column min-max normalized independently.

    `stats` is the list produced by compute_weight_stats. Returns a
    matplotlib Figure.
    """
    # Prefer real weight tensors; fall back to everything if none match.
    ws = [s for s in stats if s["numel"] > 10 and "weight" in s["name"]]
    if not ws: ws = stats
    if len(ws) < 2:
        fig, ax = plt.subplots(figsize=(8,3)); fig.patch.set_facecolor("#0d1117")
        ax.text(0.5,0.5,"Not enough layers",ha="center",va="center",color="#8b949e"); ax.axis("off"); return fig
    # Cap at 50 rows, sampled evenly across the layer list.
    if len(ws) > 50:
        idx = np.linspace(0, len(ws)-1, 50, dtype=int); ws = [ws[i] for i in idx]
    metrics = ["mean","std","abs_mean","sparsity"]; labels = ["Mean","Std","Abs Mean","Sparsity%"]
    mat = np.array([[s[m] for m in metrics] for s in ws])
    # Normalize each metric column to [0, 1] so all share one color scale.
    for j in range(mat.shape[1]):
        rng = mat[:,j].max()-mat[:,j].min()
        if rng > 0: mat[:,j] = (mat[:,j]-mat[:,j].min())/rng
    fig, ax = plt.subplots(figsize=(8, max(4, len(ws)*0.28))); fig.patch.set_facecolor("#0d1117"); _style_ax(ax)
    cmap = LinearSegmentedColormap.from_list("c", ["#0d1117","#238636","#f0e68c","#da3633"])
    im = ax.imshow(mat, aspect="auto", cmap=cmap, interpolation="nearest")
    names = [".".join(s["name"].split(".")[-3:])[-35:] for s in ws]
    ax.set_yticks(range(len(names))); ax.set_yticklabels(names, fontsize=6, color="#c9d1d9")
    ax.set_xticks(range(len(labels))); ax.set_xticklabels(labels, fontsize=9, color="#c9d1d9")
    ax.set_title("Stats Heatmap", color="#c9d1d9", fontsize=12); plt.colorbar(im, ax=ax, shrink=0.8)
    plt.tight_layout(); return fig
def plot_memory(state_dict):
    """Two-panel memory view: pie chart of megabytes per dtype, and a bar
    chart of the 12 largest tensors. Returns a matplotlib Figure."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5)); fig.patch.set_facecolor("#0d1117")
    dtype_mem = defaultdict(float)
    # Accumulate MB per dtype.
    for t in state_dict.values(): dtype_mem[str(t.dtype)] += t.numel()*t.element_size()/1048576
    labels = list(dtype_mem.keys()); sizes = list(dtype_mem.values())
    colors = plt.cm.Set2(np.linspace(0,1,max(len(labels),1)))
    ax1.set_facecolor("#0d1117")
    if labels: ax1.pie(sizes, labels=labels, autopct="%1.1f%%", colors=colors, textprops={"color":"#c9d1d9","fontsize":9})
    ax1.set_title("Memory by dtype", color="#c9d1d9", fontsize=12)
    # Right panel: top-12 tensors by element count, labeled with their MB.
    top = sorted(state_dict.items(), key=lambda x:-x[1].numel())[:12]
    names = [".".join(k.split(".")[-2:])[-28:] for k,v in top]
    mem = [v.numel()*v.element_size()/1048576 for k,v in top]
    _style_ax(ax2); colors2 = plt.cm.cool(np.linspace(0.2,0.8,len(names)))
    bars = ax2.barh(range(len(names)), mem, color=colors2, edgecolor="none")
    ax2.set_yticks(range(len(names))); ax2.set_yticklabels(names, fontsize=7.5, color="#c9d1d9"); ax2.invert_yaxis()
    ax2.set_xlabel("MB", color="#c9d1d9"); ax2.set_title("Top 12 Largest", color="#c9d1d9", fontsize=12)
    for bar, m in zip(bars, mem):
        ax2.text(bar.get_width()+max(mem)*0.02, bar.get_y()+bar.get_height()/2, f"{m:.2f} MB", va="center", fontsize=7, color="#8b949e")
    plt.tight_layout(); return fig
def plot_depth(state_dict):
    """Bar chart of parameter count per >=2-D weight tensor, in file order
    (a proxy for network depth). Returns a matplotlib Figure."""
    sizes = []; names = []
    for k, t in state_dict.items():
        if "weight" in k and t.ndim >= 2:
            sizes.append(t.numel())
            short = ".".join(k.split(".")[-2:])
            names.append(short if len(short) < 30 else "\u2026" + short[-27:])
    if len(sizes) < 2:
        fig, ax = plt.subplots(figsize=(10,3)); fig.patch.set_facecolor("#0d1117")
        ax.text(0.5,0.5,"Not enough layers",ha="center",va="center",color="#8b949e"); ax.axis("off"); return fig
    fig, ax = plt.subplots(figsize=(14, 4.5)); fig.patch.set_facecolor("#0d1117"); _style_ax(ax)
    colors = plt.cm.plasma(np.linspace(0.1, 0.9, len(sizes)))
    ax.bar(range(len(sizes)), sizes, color=colors, edgecolor="none", width=0.8)
    ax.set_xlabel("Layer index", color="#c9d1d9"); ax.set_ylabel("Params", color="#c9d1d9")
    ax.set_title("Parameter Count Across Depth", color="#c9d1d9", fontsize=12)
    # Show at most ~25 tick labels to keep the axis readable.
    step = max(1, len(names)//25)
    ax.set_xticks(list(range(len(names)))[::step])
    ax.set_xticklabels(names[::step], rotation=90, fontsize=5.5, color="#8b949e")
    plt.tight_layout(); return fig
# ─────────────────────────── Architecture Overview ───────────────────────────
def _rbox(ax, cx, cy, w, h, label, fc, ec="#ffffff20", fs: int = 9, fw="bold", sub=None, tc="#ffffff"):
    """Draw a rounded box centered at (cx, cy) with a label and an
    optional dimmer sub-label beneath it."""
    from matplotlib.patches import FancyBboxPatch
    box = FancyBboxPatch(
        (cx - w/2, cy - h/2), w, h,
        boxstyle="round,pad=0.07", fc=fc, ec=ec, lw=0.9, zorder=3, clip_on=False,
    )
    ax.add_patch(box)
    if not sub:
        # Single label, dead center.
        ax.text(cx, cy, label, ha="center", va="center",
                color=tc, fontsize=fs, fontweight=fw, zorder=4)
        return
    # Main label nudged up, sub-label below in a translucent white.
    ax.text(cx, cy + 0.12, label, ha="center", va="center",
            color=tc, fontsize=fs, fontweight=fw, zorder=4)
    ax.text(cx, cy - 0.22, sub, ha="center", va="center",
            color="#ffffffaa", fontsize=max(int(fs) - 2, 6), zorder=4)
def _arr(ax, x1, y1, x2, y2, color="#58a6ff", lw=1.5):
ax.annotate("", xy=(x2, y2), xytext=(x1, y1),
arrowprops=dict(arrowstyle="-|>", color=color, lw=lw, mutation_scale=12), zorder=2)
def _draw_transformer_overview(arch, n_heads, n_layers, hidden, ff_dim,
                               vocab, head_dim, max_pos, gqa, n_kv, total_p):
    """Draw a schematic multi-head-attention overview figure.

    Left: transformer-block sidebar plus the Q/K/V projection flow;
    right: a synthetic attention heatmap and a config table. `gqa` and
    `n_kv` describe grouped-query attention when detected upstream.
    Returns a matplotlib Figure.
    """
    # Draw at most 4 head columns; the rest are elided with an ellipsis.
    show_h = min(max(n_heads, 1), 4)
    fig = plt.figure(figsize=(18, 11))
    fig.patch.set_facecolor("#0d1117")
    ax = fig.add_axes((0.01, 0.06, 0.70, 0.92)) # main diagram
    ax_h = fig.add_axes((0.73, 0.58, 0.25, 0.36)) # attention heatmap
    ax_c = fig.add_axes((0.73, 0.06, 0.25, 0.48)) # config panel
    for a in [ax, ax_c]:
        a.set_facecolor("#0d1117"); a.axis("off")
    ax.set_xlim(0, 16); ax.set_ylim(0, 11)
    ax_c.set_xlim(0, 1); ax_c.set_ylim(0, 1)
    ax_h.set_facecolor("#161b22")
    for sp in ax_h.spines.values(): sp.set_color("#30363d")
    ax_h.tick_params(colors="#8b949e", labelsize=6)
    # Role colors: Q / K / V projections, concat output, W^O, final MH-A.
    CQ = "#e67e22"; CK = "#c0392b"; CV = "#2980b9"
    CH = "#16a085"; CO = "#8e44ad"; CA = "#27ae60"
    from matplotlib.patches import FancyBboxPatch
    # ── Left sidebar: transformer block ──
    sx = 0.95
    outer = FancyBboxPatch((0.05, 0.85), 1.8, 9.3,
                           boxstyle="round,pad=0.1", fc="#161b22", ec="#30363d", lw=1.5, zorder=1)
    ax.add_patch(outer)
    ax.text(sx, 10.4, f"Transformer Block Γ—{n_layers}",
            ha="center", va="center", color="#8b949e", fontsize=7.5, style="italic", zorder=4)
    # Sidebar stages, bottom-up: (y, label, fill color).
    sb = [
        (9.5, "Add &\nNorm", "#2d3748"),
        (8.2, "Feed\nForward", "#1a4f7a"),
        (6.9, "Add &\nNorm", "#2d3748"),
        (5.3, "Multi-Head\nAttention", "#1e3a5f"),
        (3.5, "Pos\nEncoding", "#1a3a2a"),
        (2.1, "Token\nEmbedding", "#2d3748"),
        (1.0, "Input", "#1a2332"),
    ]
    for sy, slbl, sc in sb:
        _rbox(ax, sx, sy, 1.55, 0.75, slbl, sc, fs=8, fw="bold")
    # Upward arrows connecting consecutive sidebar stages.
    for i in range(len(sb) - 1):
        _arr(ax, sx, sb[i+1][0]+0.38, sx, sb[i][0]-0.38, color="#58a6ff60", lw=1.0)
    # Residual skip connection
    ax.annotate("", xy=(0.15, 6.9+0.38), xytext=(0.15, 5.3-0.38),
                arrowprops=dict(arrowstyle="-|>", color="#bc8cff", lw=1.2,
                                connectionstyle="arc3,rad=-0.4"), zorder=2)
    ax.text(0.07, 6.1, "+", ha="center", va="center",
            color="#bc8cff", fontsize=11, fontweight="bold")
    # ── Q / K / V rows ──
    rows = [
        (8.8, "Q", CQ),
        (5.8, "K", CK),
        (2.8, "V", CV),
    ]
    # Column x-positions: input, Γ—, weight matrix, =, projected result.
    XIN = 2.8; XMUL = 3.65; XW = 4.5; XEQ = 5.35; XPRI = 6.15
    HHSTART = 7.3; HGAP = 0.82
    for row_y, lbl, color in rows:
        _rbox(ax, XIN, row_y, 0.95, 0.68, lbl, color,
              sub=f"(seq,{hidden})", fs=12, fw="bold")
        ax.text(XMUL, row_y, "Γ—", ha="center", va="center",
                color="#c9d1d9", fontsize=14, fontweight="bold", zorder=4)
        _rbox(ax, XW, row_y, 1.1, 0.68, f"W$^{{{lbl}}}$", color,
              sub=f"({hidden},{hidden})", fs=11, fw="bold")
        ax.text(XEQ, row_y, "=", ha="center", va="center",
                color="#c9d1d9", fontsize=14, fontweight="bold", zorder=4)
        _rbox(ax, XPRI, row_y, 0.95, 0.68, f"{lbl}'", color,
              sub=f"(seq,{hidden})", fs=11, fw="bold")
        _arr(ax, XPRI+0.48, row_y, HHSTART-0.20, row_y, color=color, lw=1.5)
        # Per-head split boxes for this row.
        for hi in range(show_h):
            hx = HHSTART + hi * HGAP
            _rbox(ax, hx, row_y, 0.65, 0.62, f"{lbl}{hi+1}", color, fs=9)
        if n_heads > show_h:
            ax.text(HHSTART + show_h * HGAP - 0.05, row_y, "…",
                    ha="center", va="center", color=color, fontsize=12, zorder=4)
    # d_k label
    dk_x = HHSTART + (show_h - 1) * HGAP / 2
    ax.text(dk_x, rows[0][0] + 0.68, f"d_k = {head_dim}",
            ha="center", va="bottom", color="#f778ba", fontsize=8, zorder=4,
            bbox=dict(boxstyle="round,pad=0.2", fc="#0d1117", ec="#f778ba60", lw=0.8))
    # d_model bracket under heads
    hend_x = HHSTART + (show_h - 1) * HGAP
    ax.annotate("", xy=(hend_x+0.33, rows[2][0]-0.68),
                xytext=(HHSTART-0.33, rows[2][0]-0.68),
                arrowprops=dict(arrowstyle="<->", color="#8b949e", lw=0.8))
    ax.text(dk_x, rows[2][0]-0.84, f"d_model = {hidden}",
            ha="center", va="top", color="#8b949e", fontsize=7.5)
    # Vertical attention flow arrows between head cols
    for hi in range(show_h):
        hx = HHSTART + hi * HGAP
        ax.annotate("", xy=(hx, rows[1][0]+0.31), xytext=(hx, rows[0][0]-0.31),
                    arrowprops=dict(arrowstyle="-|>", color="#ffffff22", lw=0.8), zorder=2)
        ax.annotate("", xy=(hx, rows[2][0]+0.31), xytext=(hx, rows[1][0]-0.31),
                    arrowprops=dict(arrowstyle="-|>", color="#ffffff22", lw=0.8), zorder=2)
    # Attention formula
    fml_x = HHSTART + (show_h - 1) * HGAP / 2
    fml_y = (rows[0][0] + rows[2][0]) / 2
    ax.text(fml_x, fml_y, f"softmax(QKα΅€ / √{head_dim})",
            ha="center", va="center", color="#f0e68c", fontsize=9, style="italic", zorder=4,
            bbox=dict(boxstyle="round,pad=0.3", fc="#14101e", ec="#f0e68c50", lw=0.8))
    # ── Output section: H Γ— W^O = MH-A ──
    out_y = 1.3
    H_x = HHSTART + (show_h - 1) * HGAP / 2
    # Fan the head outputs into the concatenated H box.
    for hi in range(show_h):
        hx = HHSTART + hi * HGAP
        ax.annotate("", xy=(H_x + (hi - show_h/2 + 0.5) * 0.22, out_y+0.40),
                    xytext=(hx, rows[2][0]-0.31),
                    arrowprops=dict(arrowstyle="-|>", color=CV, lw=0.9), zorder=2)
    WO_x = H_x + 1.75
    MHA_x = WO_x + 1.85
    _rbox(ax, H_x, out_y, 1.1, 0.65, "H", CH, sub=f"(seq,{hidden})", fs=12, fw="bold")
    ax.text(WO_x-0.68, out_y, "Γ—", ha="center", va="center",
            color="#c9d1d9", fontsize=14, fontweight="bold", zorder=4)
    _rbox(ax, WO_x, out_y, 1.0, 0.65, "W$^O$", CO, sub=f"({hidden},{hidden})", fs=11, fw="bold")
    ax.text(WO_x+0.73, out_y, "=", ha="center", va="center",
            color="#c9d1d9", fontsize=14, fontweight="bold", zorder=4)
    _rbox(ax, MHA_x, out_y, 1.35, 0.65, "MH-A", CA, sub=f"(seq,{hidden})", fs=11, fw="bold")
    _arr(ax, H_x+0.55, out_y, WO_x-0.50, out_y, color=CH, lw=1.5)
    _arr(ax, WO_x+0.50, out_y, MHA_x-0.68, out_y, color=CO, lw=1.5)
    # Formulas at bottom
    ax.text(7.8, 0.60, "Attention(Q,K,V) = softmax(QKα΅€/√d_k) Β· V",
            ha="center", va="center", color="#c9d1d9", fontsize=9, style="italic", zorder=4)
    ax.text(7.8, 0.22, f"MultiHead = Concat(head₁ … head_{n_heads}) Β· W\u1D3C",
            ha="center", va="center", color="#c9d1d9", fontsize=9, style="italic", zorder=4)
    # Title
    ax.text(8.0, 10.75, arch, ha="center", va="center",
            color="#58a6ff", fontsize=13, fontweight="bold", zorder=4)
    ax.text(8.0, 10.35, f"{_fmt(total_p)} params",
            ha="center", va="center", color="#8b949e", fontsize=8.5, zorder=4)
    # ── Attention heatmap ──
    # Purely illustrative: fixed-seed random attention blended with the
    # identity so the diagonal dominates, then row-normalized.
    seq_d = 8
    np.random.seed(7)
    attn = np.random.dirichlet([1.5] * seq_d, seq_d)
    attn = 0.45 * np.eye(seq_d) + 0.55 * attn
    attn /= attn.sum(axis=1, keepdims=True)
    ax_h.imshow(attn, cmap="YlOrRd", aspect="auto", interpolation="bilinear", vmin=0)
    ax_h.set_title("Attention Visualization", color="#c9d1d9", fontsize=9, pad=4)
    # ── Config panel ──
    ax_c.text(0.5, 0.97, "βš™ Model Config", ha="center", va="top",
              color="#58a6ff", fontsize=10, fontweight="bold")
    ax_c.axhline(y=0.93, color="#30363d", lw=0.8)
    info = [
        ("Architecture", arch.split("(")[0].strip()),
        ("Parameters", _fmt(total_p)),
        ("Layers", str(n_layers)),
        ("Hidden size", str(hidden)),
        ("Attn heads", f"{n_heads}" + (" (GQA)" if gqa else "")),
        ("Head dim", str(head_dim)),
        ("FF dim", _fmt(ff_dim) if ff_dim else "β€”"),
        ("Vocab size", f"{vocab:,}" if vocab else "β€”"),
        ("Max position", str(max_pos) if max_pos else "β€”"),
    ]
    if gqa and n_kv:
        info.insert(5, ("KV heads", str(n_kv)))
    # Two-column key/value rows, top-down.
    for i, (k, v) in enumerate(info[:10]):
        y = 0.89 - i * 0.087
        ax_c.text(0.03, y, k + ":", ha="left", va="center",
                  color="#8b949e", fontsize=8.5)
        ax_c.text(0.97, y, v, ha="right", va="center",
                  color="#e6edf3", fontsize=8.5, fontweight="bold")
    plt.tight_layout(pad=0.2)
    return fig
def _draw_cnn_overview(state_dict, arch, config, total_p):
    """Draw a left-to-right CNN block diagram from top-level module names.

    `config` is accepted for signature parity with the other overview
    drawers but is not used in this function. Returns a matplotlib Figure.
    """
    fig = plt.figure(figsize=(18, 8))
    fig.patch.set_facecolor("#0d1117")
    ax = fig.add_axes((0.02, 0.14, 0.96, 0.80))
    ax.set_facecolor("#0d1117"); ax.axis("off")
    ax.set_xlim(0, 18); ax.set_ylim(0, 8)
    # Collect top modules
    top_mods = {}
    for k, t in state_dict.items():
        top = k.split(".")[0]
        top_mods[top] = top_mods.get(top, 0) + t.numel()
    # First 12 modules in insertion order (not sorted by size).
    blocks = list(top_mods.items())[:12]
    if not blocks:
        blocks = [("conv1", 0), ("layer1", 0), ("fc", 0)]
    # Name-substring -> fill color; first match wins in _block_color.
    colors_map = {"conv": "#2980b9", "layer": "#1a5276", "block": "#1a5276",
                  "stage": "#1a5276", "res": "#1a5276", "pool": "#27ae60",
                  "fc": "#8e44ad", "head": "#8e44ad", "classifier": "#8e44ad",
                  "norm": "#d68910", "bn": "#d68910"}
    def _block_color(name):
        # Pick a color by the first matching substring; gray fallback.
        nl = name.lower()
        for key, col in colors_map.items():
            if key in nl:
                return col
        return "#2d3748"
    ax.text(9, 7.6, arch, ha="center", va="center",
            color="#58a6ff", fontsize=14, fontweight="bold")
    ax.text(9, 7.1, f"{_fmt(total_p)} parameters β€’ CNN / Vision Architecture",
            ha="center", va="center", color="#8b949e", fontsize=9)
    n = len(blocks)
    spacing = 15.0 / max(n, 1)
    _rbox(ax, 0.8, 4.5, 0.85, 0.70, "Input\nImage", "#1a3a2a", fs=8)
    prev_x = 1.23
    for i, (name, params) in enumerate(blocks):
        bx = 1.5 + i * spacing + spacing * 0.5
        color = _block_color(name)
        _rbox(ax, bx, 4.5, max(spacing * 0.82, 0.9), 0.72, name, color,
              sub=_fmt(params) if params else "", fs=9, fw="bold")
        _arr(ax, prev_x, 4.5, bx - spacing * 0.41 - 0.02, 4.5, color="#58a6ff", lw=1.5)
        # NOTE(review): `_block_color(name) and` is redundant here β€” the
        # function always returns a truthy color string, so the expression
        # always evaluates to the next(...) result.
        ax.text(bx, 3.9, _block_color(name) and
                next((v for k, v in [
                    ("conv","Conv"), ("layer","Block"), ("pool","Pool"),
                    ("fc","Linear"), ("head","Head"), ("norm","Norm"),
                    ("res","ResBlock"), ("stage","Stage")] if k in name.lower()), "Module"),
                ha="center", va="center", color="#8b949e", fontsize=7)
        prev_x = bx + spacing * 0.41
    _rbox(ax, prev_x + 0.6, 4.5, 1.0, 0.70, "Output", "#27ae60", fs=9)
    _arr(ax, prev_x, 4.5, prev_x + 0.1, 4.5, color="#58a6ff", lw=1.5)
    # Legend
    legend = [("Conv/Feature", "#2980b9"), ("Layer/Block", "#1a5276"),
              ("Pool/Down", "#27ae60"), ("FC/Head", "#8e44ad"), ("Norm", "#d68910")]
    for i, (lbl, lc) in enumerate(legend):
        _rbox(ax, 2.0 + i * 3.3, 2.2, 2.4, 0.50, lbl, lc, fs=8)
    ax.text(9, 1.5, "Layer type legend", ha="center", va="center",
            color="#8b949e", fontsize=8)
    plt.tight_layout(pad=0.2)
    return fig
def _draw_generic_flow(state_dict, arch, config, total_p):
    """Fallback diagram: top-level modules drawn as width-proportional boxes,
    followed by a short listing of the checkpoint's largest tensors.

    Args:
        state_dict: mapping of parameter name -> tensor (torch-like).
        arch: architecture display name, shown as the figure title.
        config: inferred model config dict (unused; uniform signature).
        total_p: total parameter count, pre-computed by the caller.

    Returns:
        A matplotlib Figure styled for the dark UI theme.
    """
    fig = plt.figure(figsize=(18, 9))
    fig.patch.set_facecolor("#0d1117")
    ax = fig.add_axes((0.02, 0.08, 0.96, 0.85))
    ax.set_facecolor("#0d1117"); ax.axis("off")
    ax.set_xlim(0, 18); ax.set_ylim(0, 9)
    # Parameter total per top-level module prefix.
    module_params = defaultdict(int)
    for key, tensor in state_dict.items():
        module_params[key.split(".")[0]] += tensor.numel()
    largest = sorted(module_params.items(), key=lambda kv: -kv[1])[:10]
    ax.text(9, 8.5, arch, ha="center", va="center",
            color="#58a6ff", fontsize=14, fontweight="bold")
    ax.text(9, 8.0, f"{_fmt(total_p)} parameters",
            ha="center", va="center", color="#8b949e", fontsize=9)
    palette = ["#2980b9","#e67e22","#27ae60","#8e44ad","#c0392b",
               "#1abc9c","#d68910","#2ecc71","#3498db","#9b59b6"]
    # Box widths scale with each module's share of the drawn parameters.
    denom = max(sum(p for _, p in largest), 1)
    cursor = 0.5
    for idx, (name, params) in enumerate(largest):
        width = max(0.9, (params / denom) * 15)
        center = cursor + width / 2
        _rbox(ax, center, 5.5, width - 0.12, 1.1, name, palette[idx % len(palette)],
              sub=_fmt(params), fs=9, fw="bold")
        if idx:
            _arr(ax, cursor, 5.5, cursor + 0.06, 5.5, color="#58a6ff", lw=1.5)
        cursor += width
    # Five biggest individual tensors, listed below the flow row.
    biggest = sorted(state_dict.items(), key=lambda kv: -kv[1].numel())[:5]
    ax.text(9, 3.8, "Largest tensors:", ha="center", va="center",
            color="#58a6ff", fontsize=9, fontweight="bold")
    for idx, (key, tensor) in enumerate(biggest):
        shape_str = "Γ—".join(map(str, tensor.shape))
        label = key if len(key) < 48 else "…" + key[-46:]
        ax.text(9, 3.3 - idx * 0.48, f"{label} [{shape_str}] {_fmt(tensor.numel())}",
                ha="center", va="center", color="#c9d1d9", fontsize=8)
    plt.tight_layout(pad=0.2)
    return fig
def plot_architecture_overview(state_dict, detections, config):
    """Dispatch to the drawing helper that matches the detected model family.

    CNN-like families get the block diagram, diffusion/generative families
    get the generic flow, and everything else falls through to the
    transformer overview.
    """
    arch = detections[0]["family"] if detections else "Unknown Model"
    category = detections[0]["category"] if detections else ""
    # Pull inferred hyperparameters with the same defaults as before.
    hidden = config.get("likely_hidden_size", 768)
    n_layers = config.get("num_layers", 12)
    n_heads = config.get("num_attention_heads", 8)
    head_dim = config.get("head_dim", hidden // max(n_heads, 1))
    ff_dim = config.get("intermediate_size", hidden * 4)
    vocab = config.get("vocab_size", 0)
    max_pos = config.get("max_position_embeddings", 0)
    gqa = config.get("grouped_query_attention", False)
    n_kv = config.get("num_kv_heads", n_heads)
    total_p = sum(t.numel() for t in state_dict.values())
    haystack = (arch + " " + category).lower()
    cnn_terms = ("resnet", "convnext", "efficientnet", "mobilenet", "vgg",
                 "densenet", "yolo", "convnet", "regnet", "detr", "convolution")
    diffusion_terms = ("stable diffusion", "unet", "sdxl", "flux", "dit",
                       "pixart", "vqgan", "stylegan", "esrgan")
    if any(term in haystack for term in cnn_terms):
        return _draw_cnn_overview(state_dict, arch, config, total_p)
    if any(term in haystack for term in diffusion_terms):
        return _draw_generic_flow(state_dict, arch, config, total_p)
    return _draw_transformer_overview(arch, n_heads, n_layers, hidden, ff_dim,
                                      vocab, head_dim, max_pos, gqa, n_kv, total_p)
# ─────────────────────────── Architecture Graph (HTML) ───────────────────────────
def build_architecture_html(state_dict, detections, config):
    """
    Generate an interactive SVG/HTML architecture graph.
    Connections are inferred purely from weight dimension matching —
    no hardcoded architecture checks. Works for any model.

    Parameters:
        state_dict: mapping of parameter name -> tensor (torch-like;
            ``.shape``, ``.ndim``, ``.numel()`` are used).
        detections: list of detection dicts; ``[0]["family"]`` becomes the
            graph title.
        config: inferred model config (accepted for a uniform signature;
            not read by this function).

    Returns:
        A self-contained HTML fragment (string) embedding the SVG graph.
    """
    import html as _hl
    arch = detections[0]["family"] if detections else "Unknown"
    total_p = sum(t.numel() for t in state_dict.values())
    # ── Colour palette (fill / border) ──
    # TC = node fill colours, BC = node border/accent colours, both keyed by
    # the inferred layer type from _infer below.
    TC = {"Embedding":"#1b4332","Q Proj":"#7c3400","K Proj":"#7c1010",
          "V Proj":"#0d3b6e","Attn Out":"#0a2a5e","FFN Up":"#2d0f57",
          "FFN Down":"#57103a","Norm":"#2a3441","Conv2d":"#004d5e",
          "Conv1d":"#003d4a","Linear":"#1b3047","Output":"#14451f","Other":"#1c2328"}
    BC = {"Embedding":"#40c070","Q Proj":"#fb8c00","K Proj":"#ef5350",
          "V Proj":"#42a5f5","Attn Out":"#1976d2","FFN Up":"#ab47bc",
          "FFN Down":"#ec407a","Norm":"#78909c","Conv2d":"#00bcd4",
          "Conv1d":"#009688","Linear":"#546e7a","Output":"#43a047","Other":"#546e7a"}
    def _infer(path, tensor):
        # Heuristically classify one weight tensor from its module path and
        # shape. Returns (layer_type, in_dim, out_dim, detail) or four Nones
        # when the tensor is not a recognisable weight.
        lk, s, nd = path.lower(), tensor.shape, tensor.ndim
        # 4-D / 3-D weights are treated as conv kernels with
        # (out_ch, in_ch, k...) layout — assumption; transposed convs would
        # swap the first two axes (TODO confirm if that matters here).
        if nd == 4: return "Conv2d", int(s[1]), int(s[0]), f"k{s[2]}Γ—{s[3]}"
        if nd == 3: return "Conv1d", int(s[1]), int(s[0]), f"k{s[2]}"
        if nd == 2:
            # 2-D weight assumed (out_features, in_features), linear-style.
            i, o = int(s[1]), int(s[0])
            end = path.rsplit(".", 1)[-1].lower()
            # NOTE(review): callers pass `path` with the trailing ".weight"
            # already stripped, so `end` is the module name (e.g. "query")
            # and the `end == "weight"` comparisons below look unreachable —
            # verify against the extraction loop before relying on them.
            if any(x in lk for x in ["embed","token","wte","wpe","word_embed","tok_embed"]): return "Embedding", i, o, ""
            if "q_proj" in lk or ("query" in lk and end == "weight"): return "Q Proj", i, o, ""
            if "k_proj" in lk or ("key" in lk and end == "weight"): return "K Proj", i, o, ""
            if "v_proj" in lk or ("value" in lk and end == "weight"): return "V Proj", i, o, ""
            if ("o_proj" in lk or "out_proj" in lk) and any(x in lk for x in ["attn","self","mha"]): return "Attn Out", i, o, ""
            if any(x in lk for x in ["gate_proj","up_proj","fc1","wi_0","c_fc","intermediate"]): return "FFN Up", i, o, ""
            if any(x in lk for x in ["down_proj","fc2","c_proj"]) and "attn" not in lk: return "FFN Down", i, o, ""
            if any(x in lk for x in ["lm_head","classifier","cls","score","head.weight"]): return "Output", i, o, ""
            if any(x in lk for x in ["norm","ln_","rms","layer_norm","ln_f"]): return "Norm", i, o, ""
            return "Linear", i, o, ""
        # 1-D norm scales get equal in/out dims so edge matching still works.
        if nd == 1 and any(x in path.lower() for x in ["norm","bn","ln","rms"]):
            return "Norm", int(s[0]), int(s[0]), ""
        return None, None, None, None
    # ── Extract weight modules ──
    # One entry per module path; only the main weight parameter of each
    # module is considered (biases etc. are skipped).
    mods, seen = [], set()
    for key, tensor in state_dict.items():
        parts = key.rsplit(".", 1)
        path = parts[0] if len(parts) == 2 else key
        param = parts[-1] if len(parts) == 2 else key
        if param not in ("weight", "W", "kernel") or path in seen:
            continue
        seen.add(path)
        ltype, ind, outd, det = _infer(path, tensor)
        if ltype is None:
            continue
        pparts = path.split(".")
        grp = ".".join(pparts[:-1]) if len(pparts) > 1 else "__root__"
        mods.append({"path": path, "name": pparts[-1], "group": grp,
                     "type": ltype, "in_dim": ind, "out_dim": outd,
                     "params": int(tensor.numel()), "det": det})
    if not mods:
        return "<div style='color:#8b949e;padding:20px;font-family:monospace'>No weight modules found.</div>"
    # ── Group by parent path ──
    groups = OrderedDict()
    for m in mods:
        groups.setdefault(m["group"], []).append(m)
    # Sample if too many groups (keep first 3, last 3, evenly spaced middle)
    MAX_G = 22
    if len(groups) > MAX_G:
        gl = list(groups.items())
        keep = list(range(min(3, len(gl)))) + list(range(max(len(gl)-3, 3), len(gl)))
        step = max(1, (len(gl) - 6) // (MAX_G - 6))
        for i in range(3, len(gl) - 3, step):
            keep.append(i)
        keep = sorted(set(keep))[:MAX_G]
        groups = OrderedDict(gl[i] for i in keep)
    # Flatten and assign IDs
    flat = []
    for gmods in groups.values():
        for m in gmods:
            m["id"] = len(flat)
            flat.append(m)
    n_total = len(mods)
    # ── Find connections purely from dimension matching ──
    # An edge a->b is drawn when a.out_dim == b.in_dim within a lookahead
    # window of up to 29 later modules; the first adjacent match counts as
    # "direct", any other match as "skip" (residual-like).
    edges, seen_e = [], set()
    for i in range(len(flat)):
        a = flat[i]
        found_direct = False
        for j in range(i + 1, min(i + 30, len(flat))):
            b = flat[j]
            if (i, j) in seen_e or a["out_dim"] <= 0:
                continue
            if a["out_dim"] == b["in_dim"]:
                is_direct = (j == i + 1 and not found_direct)
                etype = "direct" if is_direct else "skip"
                edges.append({"from": i, "to": j, "type": etype, "dim": a["out_dim"]})
                seen_e.add((i, j))
                if is_direct:
                    found_direct = True
                if j - i > 10 and etype == "skip":
                    break  # Don't look too far for skip connections
    # ── Layout: vertical swimlanes (one horizontal band per group) ──
    # NW/NH: node box size; HGAP: gap between nodes; GPX/GPY: band padding;
    # GHH: band header height; GVGAP: gap between bands. All in SVG px.
    NW, NH = 192, 52
    HGAP, GPX, GPY, GHH, GVGAP = 10, 12, 10, 24, 14
    cy = 8
    grects = {}
    for g, gmods in groups.items():
        nm = len(gmods)
        gw = nm * (NW + HGAP) - HGAP + GPX * 2
        gh = GHH + GPY + NH + GPY
        for j, m in enumerate(gmods):
            m["ax"] = GPX + j * (NW + HGAP) + NW // 2  # absolute x in SVG
            m["ay"] = cy + GHH + GPY + NH // 2  # absolute y in SVG
        gl = g.rsplit(".", 1)[-1] if "." in g else g
        grects[g] = {"y": cy, "w": gw, "h": gh,
                     "label": (gl if len(gl) <= 38 else "…" + gl[-36:]),
                     "full": g}
        cy += gh + GVGAP
    svg_w = max((r["w"] for r in grects.values()), default=600) + 60
    svg_h = cy + 10
    # ── Build SVG ──
    def e(s): return _hl.escape(str(s))
    # Arrowhead markers: "ma" (blue) for direct edges, "ms" (purple) for skips.
    parts = ['''<defs>
<marker id="ma" markerWidth="7" markerHeight="7" refX="6" refY="3.5" orient="auto">
<path d="M0,0 L7,3.5 L0,7 z" fill="#58a6ff" opacity="0.9"/>
</marker>
<marker id="ms" markerWidth="7" markerHeight="7" refX="6" refY="3.5" orient="auto">
<path d="M0,0 L7,3.5 L0,7 z" fill="#bc8cff" opacity="0.8"/>
</marker>
</defs>''']
    # Group background bands
    for g, r in grects.items():
        parts.append(
            f'<rect x="2" y="{r["y"]+2}" width="{r["w"]}" height="{r["h"]}" rx="10" fill="#000" opacity="0.35"/>'
            f'<rect x="0" y="{r["y"]}" width="{r["w"]}" height="{r["h"]}" rx="10" fill="#161b22" stroke="#21262d" stroke-width="1.2"/>'
            f'<text x="10" y="{r["y"]+16}" font-family="monospace" font-size="11" fill="#8b949e">{e(r["label"])}</text>')
    # Edges (drawn before nodes)
    for ed in edges:
        a, b = flat[ed["from"]], flat[ed["to"]]
        col = "#58a6ff" if ed["type"] == "direct" else "#bc8cff"
        mk = "ma" if ed["type"] == "direct" else "ms"
        op = "0.75" if ed["type"] == "direct" else "0.45"
        dash = "" if ed["type"] == "direct" else 'stroke-dasharray="5,3"'
        tip = e(f"dim={ed['dim']}")
        same_g = a["group"] == b["group"]
        if same_g:
            # Same band: horizontal cubic curve from a's right to b's left edge.
            x1, y1 = a["ax"] + NW // 2, a["ay"]
            x2, y2 = b["ax"] - NW // 2, b["ay"]
            mx = (x1 + x2) // 2
            d = f"M{x1},{y1} C{mx},{y1} {mx},{y2} {x2},{y2}"
        else:
            # Cross-band: vertical cubic curve from a's bottom to b's top edge.
            x1, y1 = a["ax"], a["ay"] + NH // 2
            x2, y2 = b["ax"], b["ay"] - NH // 2
            dy = y2 - y1
            d = f"M{x1},{y1} C{x1},{int(y1+dy*0.4)} {x2},{int(y1+dy*0.6)} {x2},{y2}"
        parts.append(f'<path d="{d}" fill="none" stroke="{col}" stroke-width="1.6" stroke-opacity="{op}" {dash} marker-end="url(#{mk})"><title>{tip}</title></path>')
    # Module nodes
    for m in flat:
        x0, y0 = m["ax"] - NW // 2, m["ay"] - NH // 2
        fc = TC.get(m["type"], TC["Other"])
        bc = BC.get(m["type"], BC["Other"])
        nm_label = e(m["name"])
        type_lbl = e(m["type"])
        dim_str = f"{m['in_dim']}β†’{m['out_dim']}" if m["in_dim"] != m["out_dim"] else str(m["in_dim"])
        tip = e(f"{m['path']}\nType: {m['type']}\nIn: {m['in_dim']} Out: {m['out_dim']}\nParams: {m['params']:,}")
        parts.append(
            f'<g transform="translate({x0},{y0})" style="cursor:default">'
            f'<rect width="{NW}" height="{NH}" rx="7" fill="{fc}" stroke="{bc}" stroke-width="1.5"/>'
            f'<rect width="{NW}" height="17" rx="7" fill="{bc}" fill-opacity="0.30"/>'
            f'<rect y="10" width="{NW}" height="7" fill="{bc}" fill-opacity="0.30"/>'
            f'<text x="7" y="13" font-family="monospace" font-size="9" fill="#ffffffbb" font-weight="bold">{type_lbl}</text>'
            f'<text x="{NW//2}" y="33" font-family="monospace" font-size="12" fill="#e6edf3" font-weight="bold" text-anchor="middle">{nm_label}</text>'
            f'<text x="{NW//2}" y="47" font-family="monospace" font-size="10" fill="{bc}" text-anchor="middle">{dim_str}</text>'
            f'<title>{tip}</title></g>')
    svg_body = "\n".join(parts)
    # Legend
    # dict.fromkeys keeps first-seen order while de-duplicating types.
    seen_types = list(dict.fromkeys(m["type"] for m in flat))
    legend = "".join(
        f'<span style="display:inline-flex;align-items:center;gap:4px;font-size:10px;'
        f'color:#8b949e;font-family:monospace;margin-right:6px;">'
        f'<span style="width:11px;height:11px;border-radius:3px;background:{BC.get(t,"#546e7a")};'
        f'display:inline-block;border:1px solid {BC.get(t,"#546e7a")}40"></span>{e(t)}</span>'
        for t in seen_types)
    sampled_note = f" (showing {len(flat)} of {n_total})" if n_total > len(flat) else ""
    # Final HTML wrapper: header, scrollable SVG viewport, legend row.
    return f'''<div style="background:#0d1117;border-radius:12px;padding:16px;border:1px solid #30363d;">
<div style="display:flex;justify-content:space-between;align-items:flex-start;margin-bottom:10px;">
<div>
<div style="color:#58a6ff;font-size:14px;font-weight:bold;font-family:monospace;">⬑ {e(arch)} β€” Architecture Graph</div>
<div style="color:#8b949e;font-size:11px;margin-top:4px;">{_fmt(total_p)} params Β· {len(flat)} modules Β· {len(edges)} connections{sampled_note}</div>
</div>
<div style="color:#8b949e;font-size:10px;text-align:right;line-height:1.6;">
Edges drawn where out_dim = in_dim<br>
─── direct &nbsp; - - - skip/residual
</div>
</div>
<div style="overflow:auto;border:1px solid #21262d;border-radius:8px;max-height:660px;background:#0a0e14;">
<svg width="{svg_w}" height="{svg_h}" style="display:block;">{svg_body}</svg>
</div>
<div style="margin-top:10px;display:flex;gap:4px;flex-wrap:wrap;">{legend}</div>
</div>'''
# ─────────────────────────── Main ───────────────────────────
def analyze_model(file):
    """Run the full analysis pipeline on an uploaded checkpoint file.

    Returns an 11-tuple matching the Gradio outputs:
    (arch_html, overview_fig, summary_text, dist_fig, sizes_fig, heat_fig,
     mem_fig, depth_fig, tree_text, types_text, stats_text).
    On any failure the summary slot carries the error text, plot slots are
    None, and text slots are empty strings.
    """
    # Tail of the tuple for error returns: 5 empty plots + 3 empty texts.
    blank_tail = (None, None, None, None, None, "", "", "")
    if file is None:
        return ("", None, "Please upload a model file.") + blank_tail
    try:
        path = file.name if hasattr(file, "name") else file
        state_dict, metadata = load_state_dict_safe(path)
        if not state_dict:
            return ("", None, "No tensors found.") + blank_tail
        detections = detect_architectures(state_dict)
        config = infer_model_config(state_dict, detections)
        summary = build_summary(state_dict, metadata, detections, config)
        tree_text = build_layer_tree(state_dict)
        types_text = infer_all_layers(state_dict)
        stats, stats_text = compute_weight_stats(state_dict)
        # Figure generation is memory-heavy; collect between steps to keep
        # peak RSS low on free CPU Spaces.
        arch_html = build_architecture_html(state_dict, detections, config)
        gc.collect()
        fig_overview = plot_architecture_overview(state_dict, detections, config)
        gc.collect()
        fig_dist = plot_distributions(state_dict)
        gc.collect()
        fig_sizes = plot_module_sizes(state_dict)
        gc.collect()
        fig_heat = plot_heatmap(stats)
        gc.collect()
        fig_mem = plot_memory(state_dict)
        gc.collect()
        fig_depth = plot_depth(state_dict)
        gc.collect()
        return (arch_html, fig_overview, summary, fig_dist, fig_sizes, fig_heat,
                fig_mem, fig_depth, tree_text, types_text, stats_text)
    except Exception as exc:
        # Top-level UI boundary: surface the full traceback to the user.
        message = f"Error:\n{str(exc)}\n\n{traceback.format_exc()}"
        return ("", None, message) + blank_tail
def build_app():
    """Construct and return the Gradio Blocks UI.

    The tab order and output-component order here must stay in sync with the
    11-tuple returned by analyze_model(); see the click handler at the bottom.
    """
    # Dark GitHub-style theme layered on the Base theme.
    theme = gr.themes.Base(
        primary_hue=gr.themes.colors.blue, secondary_hue=gr.themes.colors.purple,
        neutral_hue=gr.themes.colors.gray,
        font=gr.themes.GoogleFont("IBM Plex Sans"),
        font_mono=gr.themes.GoogleFont("IBM Plex Mono"),
    ).set(
        body_background_fill="#0d1117", body_background_fill_dark="#0d1117",
        block_background_fill="#161b22", block_background_fill_dark="#161b22",
        block_border_color="#30363d", block_label_text_color="#c9d1d9",
        input_background_fill="#0d1117", input_background_fill_dark="#0d1117",
        button_primary_background_fill="linear-gradient(135deg, #58a6ff 0%, #bc8cff 100%)",
        button_primary_text_color="#ffffff",
    )
    with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="Neural Model Analyzer") as app:
        # Header banner (styled by .main-header in CUSTOM_CSS).
        gr.HTML("""
<div class="main-header">
<h1>\U0001f52c Neural Model Analyzer</h1>
<p>Deep introspection for <b>every</b> model on HuggingFace &mdash; 150+ architectures &middot; Free CPU</p>
</div>""")
        with gr.Row():
            with gr.Column(scale=1):
                file_input = gr.File(label="Upload Model", file_types=[".pth",".pt",".bin",".safetensors",".onnx"], type="filepath")
                analyze_btn = gr.Button("Analyze Model", variant="primary", size="lg")
                gr.HTML("""<div class="info-panel">
<b>All HF model families supported:</b> BERT, RoBERTa, DeBERTa, ALBERT, ELECTRA,
GPT-2, GPT-Neo, LLaMA, Mistral, Mixtral, Phi, Qwen, Falcon, BLOOM, OPT, Mamba, RWKV,
T5, BART, Pegasus, MarianMT, NLLB, ViT, DeiT, BEiT, Swin, ConvNeXt, ResNet,
EfficientNet, DINOv2, DETR, YOLO, SAM, SegFormer, Mask2Former, CLIP, BLIP, LLaVA,
Whisper, Wav2Vec2, HuBERT, MusicGen, Bark, Stable Diffusion, SDXL, FLUX, DiT,
ControlNet, LoRA, ESM, PatchTST, GNN, <b>and 100+ more</b>
</div>""")
        with gr.Tabs():
            with gr.Tab("πŸ•Έ Architecture Graph"):
                arch_html_out = gr.HTML(label="Architecture Graph")
            with gr.Tab("πŸ”­ Model Overview"):
                overview_plot = gr.Plot(label="Architecture Overview")
            # max_lines=5000 (plus the textarea overflow rule in CUSTOM_CSS)
            # lets long reports scroll fully instead of being truncated.
            with gr.Tab("Summary & Architecture"):
                summary_out = gr.Textbox(label="Full Analysis", lines=40, max_lines=5000, interactive=False)
            with gr.Tab("Layer Tree"):
                tree_out = gr.Textbox(label="Hierarchy", lines=40, max_lines=5000, interactive=False)
            with gr.Tab("Types & Connections"):
                types_out = gr.Textbox(label="Types + Connections", lines=40, max_lines=5000, interactive=False)
            with gr.Tab("Weight Stats"):
                stats_out = gr.Textbox(label="Statistics", lines=40, max_lines=5000, interactive=False)
            with gr.Tab("Distributions"):
                dist_plot = gr.Plot(label="Weight Distributions")
            with gr.Tab("Module Sizes"):
                sizes_plot = gr.Plot(label="Module Parameters")
            with gr.Tab("Heatmap"):
                heat_plot = gr.Plot(label="Stats Heatmap")
            with gr.Tab("Memory"):
                mem_plot = gr.Plot(label="Memory")
            with gr.Tab("Depth Profile"):
                depth_plot = gr.Plot(label="Depth")
        # Output order must match analyze_model()'s return tuple exactly.
        analyze_btn.click(fn=analyze_model, inputs=[file_input],
                          outputs=[arch_html_out, overview_plot, summary_out, dist_plot, sizes_plot,
                                   heat_plot, mem_plot, depth_plot, tree_out, types_out, stats_out])
    return app
if __name__ == "__main__":
    # Build the UI and start the Gradio server. The module-level `app` name
    # is kept in case hosting tooling imports it — TODO confirm whether the
    # Space relies on that.
    app = build_app()
    app.launch()