🔬 Neural Model Analyzer
Deep introspection for every model on HuggingFace — 150+ architectures · Free CPU
""" π¬ Neural Model Analyzer β Advanced Model Introspection Tool Supports every model on Hugging Face. Runs on free CPU Spaces. """ import gradio as gr import torch import os import gc import numpy as np import re from collections import OrderedDict, defaultdict, Counter from pathlib import Path import traceback try: from safetensors import safe_open HAS_SAFETENSORS = True except ImportError: HAS_SAFETENSORS = False try: import onnx from onnx import numpy_helper HAS_ONNX = True except ImportError: HAS_ONNX = False import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt from matplotlib.colors import LinearSegmentedColormap import warnings warnings.filterwarnings("ignore") from architecture_detector import detect_architectures, infer_model_config, format_detection_report # βββββββββββββββββββββββββββ Styling βββββββββββββββββββββββββββ CUSTOM_CSS = """ .gradio-container { max-width: 1400px !important; } .main-header { text-align: center; padding: 24px 10px 6px 10px; background: linear-gradient(135deg, #0d1117 0%, #161b22 100%); border-radius: 16px; margin-bottom: 10px; border: 1px solid #30363d; } .main-header h1 { font-size: 2.4em; font-weight: 800; margin: 0; background: linear-gradient(135deg, #58a6ff 0%, #bc8cff 50%, #f778ba 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; } .main-header p { color: #8b949e; font-size: 0.95em; margin: 6px 0 0 0; } .info-panel { background: #0d1117; border-left: 3px solid #58a6ff; padding: 10px 14px; border-radius: 0 8px 8px 0; margin: 6px 0; font-size: 0.88em; color: #c9d1d9; } /* Ensure all textareas are fully scrollable */ textarea { overflow-y: auto !important; resize: vertical !important; } footer { display: none !important; } """ MAX_FILE_SIZE_MB = 4096 def _sizeof_fmt(num_bytes): for unit in ["B", "KB", "MB", "GB"]: if abs(num_bytes) < 1024: return f"{num_bytes:.1f} {unit}" num_bytes /= 1024 return f"{num_bytes:.1f} TB" def _fmt(n): if n >= 1e9: return f"{n / 1e9:.2f}B" if n >= 1e6: 
return f"{n / 1e6:.2f}M" if n >= 1e3: return f"{n / 1e3:.1f}K" return str(n) # βββββββββββββββββββββββββββ Loading βββββββββββββββββββββββββββ def load_state_dict_safe(file_path: str) -> tuple: ext = Path(file_path).suffix.lower() file_size = os.path.getsize(file_path) metadata = { "format": ext, "file_size_bytes": file_size, "file_size_str": _sizeof_fmt(file_size), "file_name": Path(file_path).name, } state_dict = OrderedDict() if file_size > MAX_FILE_SIZE_MB * 1024 * 1024: raise ValueError( f"File is {_sizeof_fmt(file_size)} β exceeds {MAX_FILE_SIZE_MB} MB limit for free CPU." ) gc.collect() if ext in (".pth", ".pt"): try: data = torch.load(file_path, map_location="cpu", weights_only=False) except Exception: data = torch.load(file_path, map_location="cpu", weights_only=True) if isinstance(data, torch.nn.Module): metadata["has_full_model"] = True metadata["model_repr"] = repr(data)[:5000] state_dict = OrderedDict(data.state_dict()) elif isinstance(data, dict): for ck in ["state_dict", "model_state_dict", "model", "net", "params", "network", "g_ema", "gen", "generator", "discriminator", "ema_model", "module", "teacher", "student"]: if ck in data: candidate = data[ck] if isinstance(candidate, dict): state_dict = OrderedDict(candidate) metadata["checkpoint_key_used"] = ck break elif isinstance(candidate, torch.nn.Module): state_dict = OrderedDict(candidate.state_dict()) metadata["checkpoint_key_used"] = ck metadata["has_full_model"] = True break if not state_dict: tensors = {k: v for k, v in data.items() if isinstance(v, torch.Tensor)} if tensors: state_dict = OrderedDict(tensors) else: for top_k, top_v in data.items(): if isinstance(top_v, dict): sub = {f"{top_k}.{k}": v for k, v in top_v.items() if isinstance(v, torch.Tensor)} state_dict.update(sub) extra = [k for k in data.keys() if not isinstance(data[k], (torch.Tensor, torch.nn.Module, dict))] if extra: metadata["extra_checkpoint_keys"] = extra[:20] for ik in ["epoch", "global_step", "step", "iteration", 
"best_metric", "best_loss"]: if ik in data: metadata[ik] = data[ik] if any(k in data for k in ["optimizer", "optimizer_state_dict", "optim"]): metadata["has_optimizer_state"] = True if any(k in data for k in ["scheduler", "lr_scheduler"]): metadata["has_lr_scheduler"] = True if "config" in data: metadata["embedded_config"] = str(data["config"])[:3000] if "args" in data: metadata["training_args"] = str(data["args"])[:2000] del data; gc.collect() else: metadata["note"] = f"File contains {type(data).__name__}." elif ext == ".bin": data = torch.load(file_path, map_location="cpu", weights_only=False) if isinstance(data, dict): state_dict = OrderedDict({k: v for k, v in data.items() if isinstance(v, torch.Tensor)}) metadata["format_note"] = "PyTorch .bin (HuggingFace style)" del data; gc.collect() elif ext in (".safetensors",): if not HAS_SAFETENSORS: raise ImportError("safetensors library not installed") with safe_open(file_path, framework="pt", device="cpu") as f: for key in f.keys(): state_dict[key] = f.get_tensor(key) try: sf_meta = f.metadata() if sf_meta: metadata["safetensors_metadata"] = dict(sf_meta) except Exception: pass metadata["format_note"] = "SafeTensors" elif ext == ".onnx": if not HAS_ONNX: raise ImportError("onnx library not installed") model = onnx.load(file_path) for init in model.graph.initializer: arr = numpy_helper.to_array(init) state_dict[init.name] = torch.from_numpy(arr.copy()) metadata["format_note"] = "ONNX" metadata["onnx_opset"] = [op.version for op in model.opset_import] metadata["onnx_ir_version"] = model.ir_version metadata["onnx_inputs"] = [ {"name": inp.name, "shape": [d.dim_value if d.dim_value else (d.dim_param or "?") for d in inp.type.tensor_type.shape.dim]} for inp in model.graph.input ] metadata["onnx_outputs"] = [ {"name": out.name, "shape": [d.dim_value if d.dim_value else (d.dim_param or "?") for d in out.type.tensor_type.shape.dim]} for out in model.graph.output ] op_counts = Counter(node.op_type for node in 
model.graph.node) metadata["onnx_ops"] = dict(op_counts.most_common(30)) del model; gc.collect() else: raise ValueError(f"Unsupported: '{ext}'. Use .pth .pt .bin .safetensors .onnx") cleaned = OrderedDict() for k, v in state_dict.items(): cleaned[re.sub(r"^(module\.|model\.module\.)", "", k)] = v state_dict = cleaned return state_dict, metadata # βββββββββββββββββββββββββββ Analysis βββββββββββββββββββββββββββ def build_summary(state_dict, metadata, detections, config) -> str: total_params = sum(t.numel() for t in state_dict.values()) total_bytes = sum(t.numel() * t.element_size() for t in state_dict.values()) dtypes = sorted(set(str(t.dtype) for t in state_dict.values())) s = "+" + "=" * 63 + "+\n" s += "| \U0001f52c NEURAL MODEL ANALYSIS REPORT |\n" s += "+" + "=" * 63 + "+\n\n" s += "\U0001f4c1 FILE INFO\n" s += f" File name: {metadata.get('file_name', '?')}\n" s += f" File format: {metadata.get('format', '?')} {metadata.get('format_note', '')}\n" s += f" File size: {metadata.get('file_size_str', '?')}\n" if metadata.get("checkpoint_key_used"): s += f" Checkpoint key: '{metadata['checkpoint_key_used']}'\n" if metadata.get("has_full_model"): s += f" Full model: Yes (nn.Module)\n" if metadata.get("has_optimizer_state"): s += f" Optimizer: Included\n" if metadata.get("has_lr_scheduler"): s += f" LR Scheduler: Included\n" for k in ["epoch", "global_step", "step", "best_metric", "best_loss"]: if k in metadata: s += f" {k}: {metadata[k]}\n" s += f"\n\U0001f4ca PARAMETER OVERVIEW\n" s += f" Total parameters: {_fmt(total_params)} ({total_params:,})\n" s += f" Weight memory: {_sizeof_fmt(total_bytes)}\n" s += f" Number of tensors: {len(state_dict):,}\n" s += f" Data types: {', '.join(dtypes)}\n" trainable = sum(t.numel() for k, t in state_dict.items() if "running_" not in k and "num_batches" not in k) non_train = total_params - trainable if non_train > 0: s += f" Trainable (est.): {_fmt(trainable)}\n" s += f" Non-trainable: {_fmt(non_train)}\n" s += 
f"\n{format_detection_report(detections, config)}" # ONNX if metadata.get("onnx_inputs"): s += "\n\U0001f4e5 ONNX INPUTS\n" for inp in metadata["onnx_inputs"]: s += f" {inp['name']}: {inp['shape']}\n" if metadata.get("onnx_outputs"): s += "\n\U0001f4e4 ONNX OUTPUTS\n" for out in metadata["onnx_outputs"]: s += f" {out['name']}: {out['shape']}\n" if metadata.get("onnx_ops"): s += "\n\U0001f527 ONNX OPS (top 15)\n" for op, cnt in list(metadata["onnx_ops"].items())[:15]: s += f" {op}: {cnt}\n" # Input/Output inference if not metadata.get("onnx_inputs"): s += "\n\U0001f4e5 INFERRED INPUT\n" for k, t in state_dict.items(): lk = k.lower() if any(x in lk for x in ["embed_tokens.weight", "word_embed", "wte.weight", "patch_embed.proj.weight", "conv1.weight", "feature_extractor.conv"]): if t.ndim == 2: s += f" Token input: vocab={t.shape[0]:,}, dim={t.shape[1]}\n" s += f" Format: [batch, seq_len] (token IDs)\n" elif t.ndim == 4: s += f" Image input: {t.shape[1]} channels, kernel {t.shape[2]}x{t.shape[3]}\n" ch = t.shape[1] s += f" Format: [batch, {ch}, H, W] ({'RGB' if ch == 3 else 'grayscale' if ch == 1 else f'{ch}-ch'})\n" break s += "\n\U0001f4e4 INFERRED OUTPUT\n" for k in reversed(list(state_dict.keys())): t = state_dict[k] lk = k.lower() if any(x in lk for x in ["lm_head.weight", "classifier.weight", "cls.predictions", "qa_outputs.weight", "score.weight", "head.weight"]): if t.ndim == 2: s += f" Output: {k}\n" s += f" Shape: [{t.shape[0]:,} classes/tokens, {t.shape[1]:,} hidden]\n" elif t.ndim == 1: s += f" Output bias: {k} -> {t.shape[0]:,} outputs\n" break if metadata.get("embedded_config"): s += f"\n\U0001f4cb EMBEDDED CONFIG\n {metadata['embedded_config'][:1200]}\n" if metadata.get("model_repr"): s += f"\n\U0001f3d7 MODEL REPR (nn.Module)\n{metadata['model_repr'][:3000]}\n" return s def build_layer_tree(state_dict) -> str: def _count(tree): total = 0 for v in tree.values(): if isinstance(v, dict) and "numel" in v: total += v["numel"] elif isinstance(v, dict): total 
+= _count(v) return total tree = OrderedDict() for key, tensor in state_dict.items(): parts = key.split(".") node = tree for part in parts[:-1]: if part not in node: node[part] = OrderedDict() node = node[part] node[parts[-1]] = { "shape": list(tensor.shape), "dtype": str(tensor.dtype), "numel": tensor.numel(), "mb": tensor.numel() * tensor.element_size() / 1048576, } lines = [] def _render(subtree, prefix="", depth=0): items = list(subtree.items()) for i, (key, val) in enumerate(items): last = (i == len(items) - 1) conn = "βββ " if last else "βββ " ext = " " if last else "β " if isinstance(val, dict) and "shape" in val: sh = "x".join(map(str, val["shape"])) sz = f"{val['mb']:.3f} MB" if val["mb"] >= 0.001 else f"{val['numel']} el" lines.append(f"{prefix}{conn}\U0001f539 {key} [{sh}] {val['dtype']} ({sz})") elif isinstance(val, dict): cnt = _count(val) cnt_s = f" ({_fmt(cnt)} params)" if cnt > 0 else "" lines.append(f"{prefix}{conn}\U0001f4e6 {key}{cnt_s}") _render(val, prefix + ext, depth + 1) _render(tree) return "\n".join(lines) if lines else "No hierarchy found." 
def infer_all_layers(state_dict) -> str:
    """Heuristically classify each module from its parameter names/shapes.

    Groups tensors by their parent module path, assigns a layer type
    (Conv2d, Embedding, Attention Projection, ...) from name substrings and
    tensor rank, then appends a crude "connections" section that matches
    consecutive weight in/out dimensions to guess direct links, reshapes,
    and skip/residual candidates. Purely name/shape based — no graph info.
    """
    layers = OrderedDict()
    seen = set()
    for key, tensor in state_dict.items():
        parts = key.rsplit(".", 1)
        mod = parts[0] if len(parts) == 2 else ""
        param = parts[-1]
        shape = tensor.shape
        ndim = tensor.ndim
        ml = mod.lower()
        # Only the first parameter seen per module determines its type
        # (typically ".weight" precedes ".bias" in checkpoint order).
        if mod in seen:
            continue
        seen.add(mod)
        ltype = "Unknown"
        details = {}
        if param == "weight":
            if ndim == 4:
                ltype = "Conv2d"
                details = {"out": shape[0], "in": shape[1], "k": f"{shape[2]}x{shape[3]}"}
                # in_channels == 1 per group is the depthwise signature here;
                # NOTE(review): this also matches true 1-input-channel convs.
                if shape[1] == 1:
                    ltype = "DepthwiseConv2d"
                if any(x in ml for x in ["transpose", "deconv"]):
                    ltype = "ConvTranspose2d"
            elif ndim == 3:
                ltype = "Conv1d"
                details = {"out": shape[0], "in": shape[1], "k": shape[2]}
            elif ndim == 5:
                ltype = "Conv3d"
                details = {"out": shape[0], "in": shape[1]}
            elif ndim == 2:
                # Order matters: first match wins, so the more specific
                # substrings are tested before the generic "Linear" fallback.
                if any(x in ml for x in ["embed", "token", "wte", "wpe"]):
                    ltype = "Embedding"
                    details = {"vocab": shape[0], "dim": shape[1]}
                elif any(x in ml for x in ["norm", "ln_", "layernorm", "rmsnorm"]):
                    ltype = "Norm"
                    details = {"features": shape[0]}
                elif any(x in ml for x in ["q_proj", "k_proj", "v_proj", "qkv", "query", "key", "value"]):
                    ltype = "Attention Projection"
                    details = {"out": shape[0], "in": shape[1]}
                elif any(x in ml for x in ["o_proj", "out_proj"]) and "attn" in ml:
                    ltype = "Attention Output"
                    details = {"out": shape[0], "in": shape[1]}
                elif any(x in ml for x in ["gate_proj", "up_proj", "wi", "fc1", "c_fc", "intermediate"]):
                    ltype = "FFN Up/Gate"
                    details = {"out": shape[0], "in": shape[1]}
                elif any(x in ml for x in ["down_proj", "wo", "fc2", "c_proj"]) and "attn" not in ml:
                    ltype = "FFN Down"
                    details = {"out": shape[0], "in": shape[1]}
                elif any(x in ml for x in ["lm_head", "classifier", "cls", "score", "qa_output"]):
                    ltype = "Head / Classifier"
                    details = {"classes": shape[0], "hidden": shape[1]}
                else:
                    ltype = "Linear"
                    details = {"out": shape[0], "in": shape[1]}
            elif ndim == 1:
                if any(x in ml for x in ["norm", "bn", "ln"]):
                    ltype = "Norm"
                    details = {"features": shape[0]}
                else:
                    ltype = "Bias/Scale"
                    details = {"dim": shape[0]}
        elif param == "bias":
            ltype = "Bias"; details = {"dim": shape[0]}
        elif "running_mean" in param:
            ltype = "BatchNorm stats"
        elif "A_log" in param or param == "D":
            ltype = "Mamba SSM"
        elif "inv_freq" in param:
            ltype = "Rotary inv_freq"
        layers[mod] = {"type": ltype, "details": details}

    # ── Report: distribution histogram + per-module table ──
    text = "\U0001f9e9 INFERRED LAYER TYPES\n\n"
    tc = Counter(v["type"] for v in layers.values())
    text += "DISTRIBUTION:\n"
    for lt, cnt in tc.most_common(25):
        bar = "\u2588" * min(cnt, 40)
        text += f" {lt:<30} {cnt:>5} {bar}\n"
    text += f"\nDETAILED ({len(layers)} modules):\n"
    text += f"{'Module':<60} {'Type':<25} {'Details'}\n"
    text += "\u2500" * 115 + "\n"
    for path, info in list(layers.items())[:150]:
        short = path if len(path) < 58 else "\u2026" + path[-56:]
        det = ", ".join(f"{k}={v}" for k, v in info["details"].items())
        text += f"{short:<60} {info['type']:<25} {det}\n"
    if len(layers) > 150:
        text += f"\n \u2026 and {len(layers) - 150} more\n"

    # Connections: (module, in_dim, out_dim) per >=2-D weight, in key order.
    text += "\n\n\U0001f4e1 LAYER CONNECTIONS\n" + "\u2500" * 60 + "\n"
    io_dims = []
    for key, tensor in state_dict.items():
        parts = key.rsplit(".", 1)
        mod = parts[0] if len(parts) == 2 else ""
        if parts[-1] == "weight" and tensor.ndim >= 2:
            # NOTE(review): both ternary branches yield tensor.shape[1]; the
            # `else` arm was probably meant to differ for conv weights (e.g.
            # fan-in = in_ch * kh * kw). Behavior kept as-is — confirm intent.
            io_dims.append((mod, tensor.shape[1] if tensor.ndim == 2 else tensor.shape[1], tensor.shape[0]))
    direct = reshape = 0
    for i in range(len(io_dims) - 1):
        if io_dims[i][2] == io_dims[i + 1][1]:
            direct += 1
        else:
            reshape += 1
    # Skip/residual candidates: an earlier out_dim matching a later in_dim
    # within a 12-layer lookahead window; dedup by top-level module pair.
    residual = []
    seen_p = set()
    for i in range(len(io_dims)):
        for j in range(i + 2, min(i + 12, len(io_dims))):
            if io_dims[i][2] == io_dims[j][1]:
                pair = (io_dims[i][0].split(".")[0], io_dims[j][0].split(".")[0])
                if pair not in seen_p:
                    residual.append((io_dims[i][0], io_dims[j][0], io_dims[i][2]))
                    seen_p.add(pair)
    text += f" Direct (dim match): {direct}\n"
    text += f" Reshape (dim change): {reshape}\n"
    text += f" Possible skip/residual: {len(residual)}\n\n"
    if residual:
        text += " SKIP/RESIDUAL CANDIDATES:\n"
        for src, dst, dim in residual[:30]:
            text += f" {src} ...({dim})...> {dst}\n"
        if len(residual) > 30:
            text += f" \u2026 and {len(residual) - 30} more\n"
    return text


def compute_weight_stats(state_dict) -> tuple:
    """Compute per-tensor statistics and a plain-text report.

    Returns ``(stats, text)`` where *stats* is a list of dicts (mean, std,
    min/max, sparsity %, L2 norm, size) and *text* lists detected issues
    (near-zero variance, extreme sparsity/mean/std, NaN/Inf) plus a table
    of the first 120 tensors.
    """
    stats = []
    for key, tensor in state_dict.items():
        if tensor.numel() == 0:
            continue
        # Upcast to float32 so half/bf16 tensors get accurate statistics.
        t = tensor.float()
        entry = {
            "name": key,
            "shape": list(tensor.shape),
            "dtype": str(tensor.dtype),
            "numel": tensor.numel(),
            "mb": tensor.numel() * tensor.element_size() / 1048576,
            "mean": t.mean().item(),
            "std": t.std().item() if t.numel() > 1 else 0.0,
            "min": t.min().item(),
            "max": t.max().item(),
            "abs_mean": t.abs().mean().item(),
            "sparsity": (t == 0).float().mean().item() * 100,
            "l2": t.norm(2).item(),
        }
        issues = []
        if entry["std"] < 1e-7 and entry["numel"] > 10:
            issues.append("Near-zero var")
        if entry["sparsity"] > 90:
            issues.append(">90% sparse")
        if abs(entry["mean"]) > 10:
            issues.append("Large mean")
        if entry["std"] > 100:
            issues.append("High std")
        if np.isnan(entry["mean"]) or np.isinf(entry["mean"]):
            issues.append("NaN/Inf!")
        entry["issues"] = issues
        stats.append(entry)

    total_iss = sum(len(s["issues"]) for s in stats)
    text = "\u2696 WEIGHT STATISTICS\n\n"
    if total_iss == 0:
        text += "All weights healthy.\n\n"
    else:
        text += f"{total_iss} issues:\n"
        for s in stats:
            for iss in s["issues"]:
                text += f" {s['name']}: {iss}\n"
        text += "\n"
    text += f"{'Tensor':<55} {'Shape':<18} {'Params':>10} {'Mean':>10} {'Std':>10} {'Sparse':>8}\n"
    text += "\u2500" * 115 + "\n"
    for s in stats[:120]:
        name = s["name"] if len(s["name"]) < 53 else "\u2026" + s["name"][-52:]
        sh = "x".join(map(str, s["shape"]))
        if len(sh) > 16:
            sh = sh[:13] + "\u2026"
        text += (f"{name:<55} {sh:<18} {_fmt(s['numel']):>10} "
                 f"{s['mean']:>10.5f} {s['std']:>10.5f} {s['sparsity']:>7.1f}%\n")
    if len(stats) > 120:
        text += f"\n\u2026 and {len(stats) - 120} more\n"
    return stats, text


# ─────────────────────────── Plots ───────────────────────────

def _style_ax(ax):
    """Apply the shared dark GitHub-style theme to a matplotlib Axes."""
    ax.set_facecolor("#161b22")
    ax.tick_params(colors="#8b949e", labelsize=7)
    for sp in ax.spines.values():
        sp.set_color("#30363d")
def plot_distributions(state_dict, max_n=20): candidates = [(k, v) for k, v in state_dict.items() if v.numel() > 100 and v.ndim >= 2] if not candidates: candidates = [(k, v) for k, v in state_dict.items() if v.numel() > 10] if not candidates: fig, ax = plt.subplots(figsize=(10, 3)); fig.patch.set_facecolor("#0d1117") ax.text(0.5, 0.5, "No layers for histogram", ha="center", va="center", color="#8b949e"); ax.axis("off") return fig if len(candidates) > max_n: idx = np.linspace(0, len(candidates)-1, max_n, dtype=int) candidates = [candidates[i] for i in idx] n = len(candidates); cols = min(4, n); rows = (n + cols - 1) // cols fig, axes = plt.subplots(rows, cols, figsize=(4.5*cols, 3*rows)) fig.patch.set_facecolor("#0d1117") if n == 1: axes = np.array([[axes]]) axes = np.atleast_2d(axes) cmap = plt.cm.plasma for idx, (key, tensor) in enumerate(candidates): r, c = divmod(idx, cols); ax = axes[r, c]; _style_ax(ax) data = tensor.float().flatten().numpy() q1, q99 = np.percentile(data, [1, 99]) clipped = data[(data >= q1) & (data <= q99)] ax.hist(clipped, bins=70, color=cmap(idx/max(n-1,1)), alpha=0.85, edgecolor="none", density=True) ax.axvline(0, color="#ff6b6b", ls="--", alpha=0.4, lw=0.7) short = ".".join(key.split(".")[-2:]) if len(short) > 30: short = "\u2026" + short[-27:] ax.set_title(short, fontsize=7.5, color="#c9d1d9", pad=3) for idx in range(n, rows*cols): r, c = divmod(idx, cols); axes[r, c].axis("off") fig.suptitle("Weight Distributions", fontsize=13, color="#c9d1d9", y=1.01) plt.tight_layout(); return fig def plot_module_sizes(state_dict): mod_params = defaultdict(int) for k, t in state_dict.items(): mod_params[k.split(".")[0]] += t.numel() sorted_m = sorted(mod_params.items(), key=lambda x: -x[1])[:30] if not sorted_m: fig, ax = plt.subplots(figsize=(8,3)); fig.patch.set_facecolor("#0d1117") ax.text(0.5,0.5,"No modules",ha="center",va="center",color="#8b949e"); ax.axis("off"); return fig names = [m[0] for m in sorted_m]; counts = [m[1] for m in sorted_m] 
fig, ax = plt.subplots(figsize=(12, max(3, len(names)*0.35))); fig.patch.set_facecolor("#0d1117"); _style_ax(ax) colors = plt.cm.viridis(np.linspace(0.2, 0.9, len(names))) bars = ax.barh(range(len(names)), counts, color=colors, edgecolor="none", height=0.7) ax.set_yticks(range(len(names))); ax.set_yticklabels(names, fontsize=8.5, color="#c9d1d9"); ax.invert_yaxis() ax.set_xlabel("Parameters", color="#c9d1d9"); ax.set_title("Params by Module", color="#c9d1d9", fontsize=12) for bar, cnt in zip(bars, counts): ax.text(bar.get_width()+max(counts)*0.01, bar.get_y()+bar.get_height()/2, _fmt(cnt), va="center", fontsize=7.5, color="#8b949e") plt.tight_layout(); return fig def plot_heatmap(stats): ws = [s for s in stats if s["numel"] > 10 and "weight" in s["name"]] if not ws: ws = stats if len(ws) < 2: fig, ax = plt.subplots(figsize=(8,3)); fig.patch.set_facecolor("#0d1117") ax.text(0.5,0.5,"Not enough layers",ha="center",va="center",color="#8b949e"); ax.axis("off"); return fig if len(ws) > 50: idx = np.linspace(0, len(ws)-1, 50, dtype=int); ws = [ws[i] for i in idx] metrics = ["mean","std","abs_mean","sparsity"]; labels = ["Mean","Std","Abs Mean","Sparsity%"] mat = np.array([[s[m] for m in metrics] for s in ws]) for j in range(mat.shape[1]): rng = mat[:,j].max()-mat[:,j].min() if rng > 0: mat[:,j] = (mat[:,j]-mat[:,j].min())/rng fig, ax = plt.subplots(figsize=(8, max(4, len(ws)*0.28))); fig.patch.set_facecolor("#0d1117"); _style_ax(ax) cmap = LinearSegmentedColormap.from_list("c", ["#0d1117","#238636","#f0e68c","#da3633"]) im = ax.imshow(mat, aspect="auto", cmap=cmap, interpolation="nearest") names = [".".join(s["name"].split(".")[-3:])[-35:] for s in ws] ax.set_yticks(range(len(names))); ax.set_yticklabels(names, fontsize=6, color="#c9d1d9") ax.set_xticks(range(len(labels))); ax.set_xticklabels(labels, fontsize=9, color="#c9d1d9") ax.set_title("Stats Heatmap", color="#c9d1d9", fontsize=12); plt.colorbar(im, ax=ax, shrink=0.8) plt.tight_layout(); return fig def 
plot_memory(state_dict): fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5)); fig.patch.set_facecolor("#0d1117") dtype_mem = defaultdict(float) for t in state_dict.values(): dtype_mem[str(t.dtype)] += t.numel()*t.element_size()/1048576 labels = list(dtype_mem.keys()); sizes = list(dtype_mem.values()) colors = plt.cm.Set2(np.linspace(0,1,max(len(labels),1))) ax1.set_facecolor("#0d1117") if labels: ax1.pie(sizes, labels=labels, autopct="%1.1f%%", colors=colors, textprops={"color":"#c9d1d9","fontsize":9}) ax1.set_title("Memory by dtype", color="#c9d1d9", fontsize=12) top = sorted(state_dict.items(), key=lambda x:-x[1].numel())[:12] names = [".".join(k.split(".")[-2:])[-28:] for k,v in top] mem = [v.numel()*v.element_size()/1048576 for k,v in top] _style_ax(ax2); colors2 = plt.cm.cool(np.linspace(0.2,0.8,len(names))) bars = ax2.barh(range(len(names)), mem, color=colors2, edgecolor="none") ax2.set_yticks(range(len(names))); ax2.set_yticklabels(names, fontsize=7.5, color="#c9d1d9"); ax2.invert_yaxis() ax2.set_xlabel("MB", color="#c9d1d9"); ax2.set_title("Top 12 Largest", color="#c9d1d9", fontsize=12) for bar, m in zip(bars, mem): ax2.text(bar.get_width()+max(mem)*0.02, bar.get_y()+bar.get_height()/2, f"{m:.2f} MB", va="center", fontsize=7, color="#8b949e") plt.tight_layout(); return fig def plot_depth(state_dict): sizes = []; names = [] for k, t in state_dict.items(): if "weight" in k and t.ndim >= 2: sizes.append(t.numel()) short = ".".join(k.split(".")[-2:]) names.append(short if len(short) < 30 else "\u2026" + short[-27:]) if len(sizes) < 2: fig, ax = plt.subplots(figsize=(10,3)); fig.patch.set_facecolor("#0d1117") ax.text(0.5,0.5,"Not enough layers",ha="center",va="center",color="#8b949e"); ax.axis("off"); return fig fig, ax = plt.subplots(figsize=(14, 4.5)); fig.patch.set_facecolor("#0d1117"); _style_ax(ax) colors = plt.cm.plasma(np.linspace(0.1, 0.9, len(sizes))) ax.bar(range(len(sizes)), sizes, color=colors, edgecolor="none", width=0.8) ax.set_xlabel("Layer 
index", color="#c9d1d9"); ax.set_ylabel("Params", color="#c9d1d9") ax.set_title("Parameter Count Across Depth", color="#c9d1d9", fontsize=12) step = max(1, len(names)//25) ax.set_xticks(list(range(len(names)))[::step]) ax.set_xticklabels(names[::step], rotation=90, fontsize=5.5, color="#8b949e") plt.tight_layout(); return fig # βββββββββββββββββββββββββββ Architecture Overview βββββββββββββββββββββββββββ def _rbox(ax, cx, cy, w, h, label, fc, ec="#ffffff20", fs: int = 9, fw="bold", sub=None, tc="#ffffff"): from matplotlib.patches import FancyBboxPatch p = FancyBboxPatch((cx - w/2, cy - h/2), w, h, boxstyle="round,pad=0.07", fc=fc, ec=ec, lw=0.9, zorder=3, clip_on=False) ax.add_patch(p) if sub: ax.text(cx, cy + 0.12, label, ha="center", va="center", color=tc, fontsize=fs, fontweight=fw, zorder=4) ax.text(cx, cy - 0.22, sub, ha="center", va="center", color="#ffffffaa", fontsize=max(int(fs) - 2, 6), zorder=4) else: ax.text(cx, cy, label, ha="center", va="center", color=tc, fontsize=fs, fontweight=fw, zorder=4) def _arr(ax, x1, y1, x2, y2, color="#58a6ff", lw=1.5): ax.annotate("", xy=(x2, y2), xytext=(x1, y1), arrowprops=dict(arrowstyle="-|>", color=color, lw=lw, mutation_scale=12), zorder=2) def _draw_transformer_overview(arch, n_heads, n_layers, hidden, ff_dim, vocab, head_dim, max_pos, gqa, n_kv, total_p): show_h = min(max(n_heads, 1), 4) fig = plt.figure(figsize=(18, 11)) fig.patch.set_facecolor("#0d1117") ax = fig.add_axes((0.01, 0.06, 0.70, 0.92)) # main diagram ax_h = fig.add_axes((0.73, 0.58, 0.25, 0.36)) # attention heatmap ax_c = fig.add_axes((0.73, 0.06, 0.25, 0.48)) # config panel for a in [ax, ax_c]: a.set_facecolor("#0d1117"); a.axis("off") ax.set_xlim(0, 16); ax.set_ylim(0, 11) ax_c.set_xlim(0, 1); ax_c.set_ylim(0, 1) ax_h.set_facecolor("#161b22") for sp in ax_h.spines.values(): sp.set_color("#30363d") ax_h.tick_params(colors="#8b949e", labelsize=6) CQ = "#e67e22"; CK = "#c0392b"; CV = "#2980b9" CH = "#16a085"; CO = "#8e44ad"; CA = "#27ae60" from 
matplotlib.patches import FancyBboxPatch # ββ Left sidebar: transformer block ββ sx = 0.95 outer = FancyBboxPatch((0.05, 0.85), 1.8, 9.3, boxstyle="round,pad=0.1", fc="#161b22", ec="#30363d", lw=1.5, zorder=1) ax.add_patch(outer) ax.text(sx, 10.4, f"Transformer Block Γ{n_layers}", ha="center", va="center", color="#8b949e", fontsize=7.5, style="italic", zorder=4) sb = [ (9.5, "Add &\nNorm", "#2d3748"), (8.2, "Feed\nForward", "#1a4f7a"), (6.9, "Add &\nNorm", "#2d3748"), (5.3, "Multi-Head\nAttention", "#1e3a5f"), (3.5, "Pos\nEncoding", "#1a3a2a"), (2.1, "Token\nEmbedding", "#2d3748"), (1.0, "Input", "#1a2332"), ] for sy, slbl, sc in sb: _rbox(ax, sx, sy, 1.55, 0.75, slbl, sc, fs=8, fw="bold") for i in range(len(sb) - 1): _arr(ax, sx, sb[i+1][0]+0.38, sx, sb[i][0]-0.38, color="#58a6ff60", lw=1.0) # Residual skip connection ax.annotate("", xy=(0.15, 6.9+0.38), xytext=(0.15, 5.3-0.38), arrowprops=dict(arrowstyle="-|>", color="#bc8cff", lw=1.2, connectionstyle="arc3,rad=-0.4"), zorder=2) ax.text(0.07, 6.1, "+", ha="center", va="center", color="#bc8cff", fontsize=11, fontweight="bold") # ββ Q / K / V rows ββ rows = [ (8.8, "Q", CQ), (5.8, "K", CK), (2.8, "V", CV), ] XIN = 2.8; XMUL = 3.65; XW = 4.5; XEQ = 5.35; XPRI = 6.15 HHSTART = 7.3; HGAP = 0.82 for row_y, lbl, color in rows: _rbox(ax, XIN, row_y, 0.95, 0.68, lbl, color, sub=f"(seq,{hidden})", fs=12, fw="bold") ax.text(XMUL, row_y, "Γ", ha="center", va="center", color="#c9d1d9", fontsize=14, fontweight="bold", zorder=4) _rbox(ax, XW, row_y, 1.1, 0.68, f"W$^{{{lbl}}}$", color, sub=f"({hidden},{hidden})", fs=11, fw="bold") ax.text(XEQ, row_y, "=", ha="center", va="center", color="#c9d1d9", fontsize=14, fontweight="bold", zorder=4) _rbox(ax, XPRI, row_y, 0.95, 0.68, f"{lbl}'", color, sub=f"(seq,{hidden})", fs=11, fw="bold") _arr(ax, XPRI+0.48, row_y, HHSTART-0.20, row_y, color=color, lw=1.5) for hi in range(show_h): hx = HHSTART + hi * HGAP _rbox(ax, hx, row_y, 0.65, 0.62, f"{lbl}{hi+1}", color, fs=9) if n_heads > show_h: 
ax.text(HHSTART + show_h * HGAP - 0.05, row_y, "β¦", ha="center", va="center", color=color, fontsize=12, zorder=4) # d_k label dk_x = HHSTART + (show_h - 1) * HGAP / 2 ax.text(dk_x, rows[0][0] + 0.68, f"d_k = {head_dim}", ha="center", va="bottom", color="#f778ba", fontsize=8, zorder=4, bbox=dict(boxstyle="round,pad=0.2", fc="#0d1117", ec="#f778ba60", lw=0.8)) # d_model bracket under heads hend_x = HHSTART + (show_h - 1) * HGAP ax.annotate("", xy=(hend_x+0.33, rows[2][0]-0.68), xytext=(HHSTART-0.33, rows[2][0]-0.68), arrowprops=dict(arrowstyle="<->", color="#8b949e", lw=0.8)) ax.text(dk_x, rows[2][0]-0.84, f"d_model = {hidden}", ha="center", va="top", color="#8b949e", fontsize=7.5) # Vertical attention flow arrows between head cols for hi in range(show_h): hx = HHSTART + hi * HGAP ax.annotate("", xy=(hx, rows[1][0]+0.31), xytext=(hx, rows[0][0]-0.31), arrowprops=dict(arrowstyle="-|>", color="#ffffff22", lw=0.8), zorder=2) ax.annotate("", xy=(hx, rows[2][0]+0.31), xytext=(hx, rows[1][0]-0.31), arrowprops=dict(arrowstyle="-|>", color="#ffffff22", lw=0.8), zorder=2) # Attention formula fml_x = HHSTART + (show_h - 1) * HGAP / 2 fml_y = (rows[0][0] + rows[2][0]) / 2 ax.text(fml_x, fml_y, f"softmax(QKα΅ / β{head_dim})", ha="center", va="center", color="#f0e68c", fontsize=9, style="italic", zorder=4, bbox=dict(boxstyle="round,pad=0.3", fc="#14101e", ec="#f0e68c50", lw=0.8)) # ββ Output section: H Γ W^O = MH-A ββ out_y = 1.3 H_x = HHSTART + (show_h - 1) * HGAP / 2 for hi in range(show_h): hx = HHSTART + hi * HGAP ax.annotate("", xy=(H_x + (hi - show_h/2 + 0.5) * 0.22, out_y+0.40), xytext=(hx, rows[2][0]-0.31), arrowprops=dict(arrowstyle="-|>", color=CV, lw=0.9), zorder=2) WO_x = H_x + 1.75 MHA_x = WO_x + 1.85 _rbox(ax, H_x, out_y, 1.1, 0.65, "H", CH, sub=f"(seq,{hidden})", fs=12, fw="bold") ax.text(WO_x-0.68, out_y, "Γ", ha="center", va="center", color="#c9d1d9", fontsize=14, fontweight="bold", zorder=4) _rbox(ax, WO_x, out_y, 1.0, 0.65, "W$^O$", CO, 
sub=f"({hidden},{hidden})", fs=11, fw="bold") ax.text(WO_x+0.73, out_y, "=", ha="center", va="center", color="#c9d1d9", fontsize=14, fontweight="bold", zorder=4) _rbox(ax, MHA_x, out_y, 1.35, 0.65, "MH-A", CA, sub=f"(seq,{hidden})", fs=11, fw="bold") _arr(ax, H_x+0.55, out_y, WO_x-0.50, out_y, color=CH, lw=1.5) _arr(ax, WO_x+0.50, out_y, MHA_x-0.68, out_y, color=CO, lw=1.5) # Formulas at bottom ax.text(7.8, 0.60, "Attention(Q,K,V) = softmax(QKα΅/βd_k) Β· V", ha="center", va="center", color="#c9d1d9", fontsize=9, style="italic", zorder=4) ax.text(7.8, 0.22, f"MultiHead = Concat(headβ β¦ head_{n_heads}) Β· W\u1D3C", ha="center", va="center", color="#c9d1d9", fontsize=9, style="italic", zorder=4) # Title ax.text(8.0, 10.75, arch, ha="center", va="center", color="#58a6ff", fontsize=13, fontweight="bold", zorder=4) ax.text(8.0, 10.35, f"{_fmt(total_p)} params", ha="center", va="center", color="#8b949e", fontsize=8.5, zorder=4) # ββ Attention heatmap ββ seq_d = 8 np.random.seed(7) attn = np.random.dirichlet([1.5] * seq_d, seq_d) attn = 0.45 * np.eye(seq_d) + 0.55 * attn attn /= attn.sum(axis=1, keepdims=True) ax_h.imshow(attn, cmap="YlOrRd", aspect="auto", interpolation="bilinear", vmin=0) ax_h.set_title("Attention Visualization", color="#c9d1d9", fontsize=9, pad=4) # ββ Config panel ββ ax_c.text(0.5, 0.97, "β Model Config", ha="center", va="top", color="#58a6ff", fontsize=10, fontweight="bold") ax_c.axhline(y=0.93, color="#30363d", lw=0.8) info = [ ("Architecture", arch.split("(")[0].strip()), ("Parameters", _fmt(total_p)), ("Layers", str(n_layers)), ("Hidden size", str(hidden)), ("Attn heads", f"{n_heads}" + (" (GQA)" if gqa else "")), ("Head dim", str(head_dim)), ("FF dim", _fmt(ff_dim) if ff_dim else "β"), ("Vocab size", f"{vocab:,}" if vocab else "β"), ("Max position", str(max_pos) if max_pos else "β"), ] if gqa and n_kv: info.insert(5, ("KV heads", str(n_kv))) for i, (k, v) in enumerate(info[:10]): y = 0.89 - i * 0.087 ax_c.text(0.03, y, k + ":", ha="left", 
va="center", color="#8b949e", fontsize=8.5) ax_c.text(0.97, y, v, ha="right", va="center", color="#e6edf3", fontsize=8.5, fontweight="bold") plt.tight_layout(pad=0.2) return fig def _draw_cnn_overview(state_dict, arch, config, total_p): fig = plt.figure(figsize=(18, 8)) fig.patch.set_facecolor("#0d1117") ax = fig.add_axes((0.02, 0.14, 0.96, 0.80)) ax.set_facecolor("#0d1117"); ax.axis("off") ax.set_xlim(0, 18); ax.set_ylim(0, 8) # Collect top modules top_mods = {} for k, t in state_dict.items(): top = k.split(".")[0] top_mods[top] = top_mods.get(top, 0) + t.numel() blocks = list(top_mods.items())[:12] if not blocks: blocks = [("conv1", 0), ("layer1", 0), ("fc", 0)] colors_map = {"conv": "#2980b9", "layer": "#1a5276", "block": "#1a5276", "stage": "#1a5276", "res": "#1a5276", "pool": "#27ae60", "fc": "#8e44ad", "head": "#8e44ad", "classifier": "#8e44ad", "norm": "#d68910", "bn": "#d68910"} def _block_color(name): nl = name.lower() for key, col in colors_map.items(): if key in nl: return col return "#2d3748" ax.text(9, 7.6, arch, ha="center", va="center", color="#58a6ff", fontsize=14, fontweight="bold") ax.text(9, 7.1, f"{_fmt(total_p)} parameters β’ CNN / Vision Architecture", ha="center", va="center", color="#8b949e", fontsize=9) n = len(blocks) spacing = 15.0 / max(n, 1) _rbox(ax, 0.8, 4.5, 0.85, 0.70, "Input\nImage", "#1a3a2a", fs=8) prev_x = 1.23 for i, (name, params) in enumerate(blocks): bx = 1.5 + i * spacing + spacing * 0.5 color = _block_color(name) _rbox(ax, bx, 4.5, max(spacing * 0.82, 0.9), 0.72, name, color, sub=_fmt(params) if params else "", fs=9, fw="bold") _arr(ax, prev_x, 4.5, bx - spacing * 0.41 - 0.02, 4.5, color="#58a6ff", lw=1.5) ax.text(bx, 3.9, _block_color(name) and next((v for k, v in [ ("conv","Conv"), ("layer","Block"), ("pool","Pool"), ("fc","Linear"), ("head","Head"), ("norm","Norm"), ("res","ResBlock"), ("stage","Stage")] if k in name.lower()), "Module"), ha="center", va="center", color="#8b949e", fontsize=7) prev_x = bx + spacing * 0.41 
_rbox(ax, prev_x + 0.6, 4.5, 1.0, 0.70, "Output", "#27ae60", fs=9) _arr(ax, prev_x, 4.5, prev_x + 0.1, 4.5, color="#58a6ff", lw=1.5) # Legend legend = [("Conv/Feature", "#2980b9"), ("Layer/Block", "#1a5276"), ("Pool/Down", "#27ae60"), ("FC/Head", "#8e44ad"), ("Norm", "#d68910")] for i, (lbl, lc) in enumerate(legend): _rbox(ax, 2.0 + i * 3.3, 2.2, 2.4, 0.50, lbl, lc, fs=8) ax.text(9, 1.5, "Layer type legend", ha="center", va="center", color="#8b949e", fontsize=8) plt.tight_layout(pad=0.2) return fig def _draw_generic_flow(state_dict, arch, config, total_p): fig = plt.figure(figsize=(18, 9)) fig.patch.set_facecolor("#0d1117") ax = fig.add_axes((0.02, 0.08, 0.96, 0.85)) ax.set_facecolor("#0d1117"); ax.axis("off") ax.set_xlim(0, 18); ax.set_ylim(0, 9) top_mods = {} for k, t in state_dict.items(): top = k.split(".")[0] top_mods[top] = top_mods.get(top, 0) + t.numel() sorted_mods = sorted(top_mods.items(), key=lambda x: -x[1])[:10] ax.text(9, 8.5, arch, ha="center", va="center", color="#58a6ff", fontsize=14, fontweight="bold") ax.text(9, 8.0, f"{_fmt(total_p)} parameters", ha="center", va="center", color="#8b949e", fontsize=9) pal = ["#2980b9","#e67e22","#27ae60","#8e44ad","#c0392b", "#1abc9c","#d68910","#2ecc71","#3498db","#9b59b6"] total_n = max(sum(p for _, p in sorted_mods), 1) x = 0.5 for i, (name, params) in enumerate(sorted_mods): w = max(0.9, (params / total_n) * 15) cx = x + w / 2 _rbox(ax, cx, 5.5, w - 0.12, 1.1, name, pal[i % len(pal)], sub=_fmt(params), fs=9, fw="bold") if i > 0: _arr(ax, x, 5.5, x + 0.06, 5.5, color="#58a6ff", lw=1.5) x += w top_t = sorted(state_dict.items(), key=lambda x: -x[1].numel())[:5] ax.text(9, 3.8, "Largest tensors:", ha="center", va="center", color="#58a6ff", fontsize=9, fontweight="bold") for i, (k, t) in enumerate(top_t): sh = "Γ".join(map(str, t.shape)) short = k if len(k) < 48 else "β¦" + k[-46:] ax.text(9, 3.3 - i * 0.48, f"{short} [{sh}] {_fmt(t.numel())}", ha="center", va="center", color="#c9d1d9", fontsize=8) 
plt.tight_layout(pad=0.2) return fig def plot_architecture_overview(state_dict, detections, config): arch = detections[0]["family"] if detections else "Unknown Model" category = detections[0]["category"] if detections else "" hidden = config.get("likely_hidden_size", 768) n_layers = config.get("num_layers", 12) n_heads = config.get("num_attention_heads", 8) head_dim = config.get("head_dim", hidden // max(n_heads, 1)) ff_dim = config.get("intermediate_size", hidden * 4) vocab = config.get("vocab_size", 0) max_pos = config.get("max_position_embeddings", 0) gqa = config.get("grouped_query_attention", False) n_kv = config.get("num_kv_heads", n_heads) total_p = sum(t.numel() for t in state_dict.values()) al = (arch + " " + category).lower() is_cnn = any(x in al for x in ["resnet","convnext","efficientnet","mobilenet","vgg", "densenet","yolo","convnet","regnet","detr","convolution"]) is_diff = any(x in al for x in ["stable diffusion","unet","sdxl","flux","dit", "pixart","vqgan","stylegan","esrgan"]) if is_cnn: return _draw_cnn_overview(state_dict, arch, config, total_p) elif is_diff: return _draw_generic_flow(state_dict, arch, config, total_p) else: return _draw_transformer_overview(arch, n_heads, n_layers, hidden, ff_dim, vocab, head_dim, max_pos, gqa, n_kv, total_p) # βββββββββββββββββββββββββββ Architecture Graph (HTML) βββββββββββββββββββββββββββ def build_architecture_html(state_dict, detections, config): """ Generate an interactive SVG/HTML architecture graph. Connections are inferred purely from weight dimension matching β no hardcoded architecture checks. Works for any model. 
""" import html as _hl arch = detections[0]["family"] if detections else "Unknown" total_p = sum(t.numel() for t in state_dict.values()) # ββ Colour palette (fill / border) ββ TC = {"Embedding":"#1b4332","Q Proj":"#7c3400","K Proj":"#7c1010", "V Proj":"#0d3b6e","Attn Out":"#0a2a5e","FFN Up":"#2d0f57", "FFN Down":"#57103a","Norm":"#2a3441","Conv2d":"#004d5e", "Conv1d":"#003d4a","Linear":"#1b3047","Output":"#14451f","Other":"#1c2328"} BC = {"Embedding":"#40c070","Q Proj":"#fb8c00","K Proj":"#ef5350", "V Proj":"#42a5f5","Attn Out":"#1976d2","FFN Up":"#ab47bc", "FFN Down":"#ec407a","Norm":"#78909c","Conv2d":"#00bcd4", "Conv1d":"#009688","Linear":"#546e7a","Output":"#43a047","Other":"#546e7a"} def _infer(path, tensor): lk, s, nd = path.lower(), tensor.shape, tensor.ndim if nd == 4: return "Conv2d", int(s[1]), int(s[0]), f"k{s[2]}Γ{s[3]}" if nd == 3: return "Conv1d", int(s[1]), int(s[0]), f"k{s[2]}" if nd == 2: i, o = int(s[1]), int(s[0]) end = path.rsplit(".", 1)[-1].lower() if any(x in lk for x in ["embed","token","wte","wpe","word_embed","tok_embed"]): return "Embedding", i, o, "" if "q_proj" in lk or ("query" in lk and end == "weight"): return "Q Proj", i, o, "" if "k_proj" in lk or ("key" in lk and end == "weight"): return "K Proj", i, o, "" if "v_proj" in lk or ("value" in lk and end == "weight"): return "V Proj", i, o, "" if ("o_proj" in lk or "out_proj" in lk) and any(x in lk for x in ["attn","self","mha"]): return "Attn Out", i, o, "" if any(x in lk for x in ["gate_proj","up_proj","fc1","wi_0","c_fc","intermediate"]): return "FFN Up", i, o, "" if any(x in lk for x in ["down_proj","fc2","c_proj"]) and "attn" not in lk: return "FFN Down", i, o, "" if any(x in lk for x in ["lm_head","classifier","cls","score","head.weight"]): return "Output", i, o, "" if any(x in lk for x in ["norm","ln_","rms","layer_norm","ln_f"]): return "Norm", i, o, "" return "Linear", i, o, "" if nd == 1 and any(x in path.lower() for x in ["norm","bn","ln","rms"]): return "Norm", int(s[0]), 
int(s[0]), "" return None, None, None, None # ββ Extract weight modules ββ mods, seen = [], set() for key, tensor in state_dict.items(): parts = key.rsplit(".", 1) path = parts[0] if len(parts) == 2 else key param = parts[-1] if len(parts) == 2 else key if param not in ("weight", "W", "kernel") or path in seen: continue seen.add(path) ltype, ind, outd, det = _infer(path, tensor) if ltype is None: continue pparts = path.split(".") grp = ".".join(pparts[:-1]) if len(pparts) > 1 else "__root__" mods.append({"path": path, "name": pparts[-1], "group": grp, "type": ltype, "in_dim": ind, "out_dim": outd, "params": int(tensor.numel()), "det": det}) if not mods: return "
Deep introspection for every model on HuggingFace — 150+ architectures · Free CPU