""" πŸ”¬ Neural Model Analyzer β€” Advanced Model Introspection Tool Supports every model on Hugging Face. Runs on free CPU Spaces. """ import gradio as gr import torch import os import gc import numpy as np import re from collections import OrderedDict, defaultdict, Counter from pathlib import Path import traceback try: from safetensors import safe_open HAS_SAFETENSORS = True except ImportError: HAS_SAFETENSORS = False try: import onnx from onnx import numpy_helper HAS_ONNX = True except ImportError: HAS_ONNX = False import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt from matplotlib.colors import LinearSegmentedColormap import warnings warnings.filterwarnings("ignore") from architecture_detector import detect_architectures, infer_model_config, format_detection_report # ─────────────────────────── Styling ─────────────────────────── CUSTOM_CSS = """ .gradio-container { max-width: 1400px !important; } .main-header { text-align: center; padding: 24px 10px 6px 10px; background: linear-gradient(135deg, #0d1117 0%, #161b22 100%); border-radius: 16px; margin-bottom: 10px; border: 1px solid #30363d; } .main-header h1 { font-size: 2.4em; font-weight: 800; margin: 0; background: linear-gradient(135deg, #58a6ff 0%, #bc8cff 50%, #f778ba 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; } .main-header p { color: #8b949e; font-size: 0.95em; margin: 6px 0 0 0; } .info-panel { background: #0d1117; border-left: 3px solid #58a6ff; padding: 10px 14px; border-radius: 0 8px 8px 0; margin: 6px 0; font-size: 0.88em; color: #c9d1d9; } /* Ensure all textareas are fully scrollable */ textarea { overflow-y: auto !important; resize: vertical !important; } footer { display: none !important; } """ MAX_FILE_SIZE_MB = 4096 def _sizeof_fmt(num_bytes): for unit in ["B", "KB", "MB", "GB"]: if abs(num_bytes) < 1024: return f"{num_bytes:.1f} {unit}" num_bytes /= 1024 return f"{num_bytes:.1f} TB" def _fmt(n): if n >= 1e9: return f"{n / 1e9:.2f}B" if n >= 
1e6: return f"{n / 1e6:.2f}M" if n >= 1e3: return f"{n / 1e3:.1f}K" return str(n) # ─────────────────────────── Loading ─────────────────────────── def load_state_dict_safe(file_path: str) -> tuple: ext = Path(file_path).suffix.lower() file_size = os.path.getsize(file_path) metadata = { "format": ext, "file_size_bytes": file_size, "file_size_str": _sizeof_fmt(file_size), "file_name": Path(file_path).name, } state_dict = OrderedDict() if file_size > MAX_FILE_SIZE_MB * 1024 * 1024: raise ValueError( f"File is {_sizeof_fmt(file_size)} β€” exceeds {MAX_FILE_SIZE_MB} MB limit for free CPU." ) gc.collect() if ext in (".pth", ".pt"): try: data = torch.load(file_path, map_location="cpu", weights_only=False) except Exception: data = torch.load(file_path, map_location="cpu", weights_only=True) if isinstance(data, torch.nn.Module): metadata["has_full_model"] = True metadata["model_repr"] = repr(data)[:5000] state_dict = OrderedDict(data.state_dict()) elif isinstance(data, dict): for ck in ["state_dict", "model_state_dict", "model", "net", "params", "network", "g_ema", "gen", "generator", "discriminator", "ema_model", "module", "teacher", "student"]: if ck in data: candidate = data[ck] if isinstance(candidate, dict): state_dict = OrderedDict(candidate) metadata["checkpoint_key_used"] = ck break elif isinstance(candidate, torch.nn.Module): state_dict = OrderedDict(candidate.state_dict()) metadata["checkpoint_key_used"] = ck metadata["has_full_model"] = True break if not state_dict: tensors = {k: v for k, v in data.items() if isinstance(v, torch.Tensor)} if tensors: state_dict = OrderedDict(tensors) else: for top_k, top_v in data.items(): if isinstance(top_v, dict): sub = {f"{top_k}.{k}": v for k, v in top_v.items() if isinstance(v, torch.Tensor)} state_dict.update(sub) extra = [k for k in data.keys() if not isinstance(data[k], (torch.Tensor, torch.nn.Module, dict))] if extra: metadata["extra_checkpoint_keys"] = extra[:20] for ik in ["epoch", "global_step", "step", "iteration", 
"best_metric", "best_loss"]: if ik in data: metadata[ik] = data[ik] if any(k in data for k in ["optimizer", "optimizer_state_dict", "optim"]): metadata["has_optimizer_state"] = True if any(k in data for k in ["scheduler", "lr_scheduler"]): metadata["has_lr_scheduler"] = True if "config" in data: metadata["embedded_config"] = str(data["config"])[:3000] if "args" in data: metadata["training_args"] = str(data["args"])[:2000] del data; gc.collect() else: metadata["note"] = f"File contains {type(data).__name__}." elif ext == ".bin": data = torch.load(file_path, map_location="cpu", weights_only=False) if isinstance(data, dict): state_dict = OrderedDict({k: v for k, v in data.items() if isinstance(v, torch.Tensor)}) metadata["format_note"] = "PyTorch .bin (HuggingFace style)" del data; gc.collect() elif ext in (".safetensors",): if not HAS_SAFETENSORS: raise ImportError("safetensors library not installed") with safe_open(file_path, framework="pt", device="cpu") as f: for key in f.keys(): state_dict[key] = f.get_tensor(key) try: sf_meta = f.metadata() if sf_meta: metadata["safetensors_metadata"] = dict(sf_meta) except Exception: pass metadata["format_note"] = "SafeTensors" elif ext == ".onnx": if not HAS_ONNX: raise ImportError("onnx library not installed") model = onnx.load(file_path) for init in model.graph.initializer: arr = numpy_helper.to_array(init) state_dict[init.name] = torch.from_numpy(arr.copy()) metadata["format_note"] = "ONNX" metadata["onnx_opset"] = [op.version for op in model.opset_import] metadata["onnx_ir_version"] = model.ir_version metadata["onnx_inputs"] = [ {"name": inp.name, "shape": [d.dim_value if d.dim_value else (d.dim_param or "?") for d in inp.type.tensor_type.shape.dim]} for inp in model.graph.input ] metadata["onnx_outputs"] = [ {"name": out.name, "shape": [d.dim_value if d.dim_value else (d.dim_param or "?") for d in out.type.tensor_type.shape.dim]} for out in model.graph.output ] op_counts = Counter(node.op_type for node in 
model.graph.node) metadata["onnx_ops"] = dict(op_counts.most_common(30)) del model; gc.collect() else: raise ValueError(f"Unsupported: '{ext}'. Use .pth .pt .bin .safetensors .onnx") cleaned = OrderedDict() for k, v in state_dict.items(): cleaned[re.sub(r"^(module\.|model\.module\.)", "", k)] = v state_dict = cleaned return state_dict, metadata # ─────────────────────────── Analysis ─────────────────────────── def build_summary(state_dict, metadata, detections, config) -> str: total_params = sum(t.numel() for t in state_dict.values()) total_bytes = sum(t.numel() * t.element_size() for t in state_dict.values()) dtypes = sorted(set(str(t.dtype) for t in state_dict.values())) s = "+" + "=" * 63 + "+\n" s += "| \U0001f52c NEURAL MODEL ANALYSIS REPORT |\n" s += "+" + "=" * 63 + "+\n\n" s += "\U0001f4c1 FILE INFO\n" s += f" File name: {metadata.get('file_name', '?')}\n" s += f" File format: {metadata.get('format', '?')} {metadata.get('format_note', '')}\n" s += f" File size: {metadata.get('file_size_str', '?')}\n" if metadata.get("checkpoint_key_used"): s += f" Checkpoint key: '{metadata['checkpoint_key_used']}'\n" if metadata.get("has_full_model"): s += f" Full model: Yes (nn.Module)\n" if metadata.get("has_optimizer_state"): s += f" Optimizer: Included\n" if metadata.get("has_lr_scheduler"): s += f" LR Scheduler: Included\n" for k in ["epoch", "global_step", "step", "best_metric", "best_loss"]: if k in metadata: s += f" {k}: {metadata[k]}\n" s += f"\n\U0001f4ca PARAMETER OVERVIEW\n" s += f" Total parameters: {_fmt(total_params)} ({total_params:,})\n" s += f" Weight memory: {_sizeof_fmt(total_bytes)}\n" s += f" Number of tensors: {len(state_dict):,}\n" s += f" Data types: {', '.join(dtypes)}\n" trainable = sum(t.numel() for k, t in state_dict.items() if "running_" not in k and "num_batches" not in k) non_train = total_params - trainable if non_train > 0: s += f" Trainable (est.): {_fmt(trainable)}\n" s += f" Non-trainable: {_fmt(non_train)}\n" s += 
f"\n{format_detection_report(detections, config)}" # ONNX if metadata.get("onnx_inputs"): s += "\n\U0001f4e5 ONNX INPUTS\n" for inp in metadata["onnx_inputs"]: s += f" {inp['name']}: {inp['shape']}\n" if metadata.get("onnx_outputs"): s += "\n\U0001f4e4 ONNX OUTPUTS\n" for out in metadata["onnx_outputs"]: s += f" {out['name']}: {out['shape']}\n" if metadata.get("onnx_ops"): s += "\n\U0001f527 ONNX OPS (top 15)\n" for op, cnt in list(metadata["onnx_ops"].items())[:15]: s += f" {op}: {cnt}\n" # Input/Output inference if not metadata.get("onnx_inputs"): s += "\n\U0001f4e5 INFERRED INPUT\n" for k, t in state_dict.items(): lk = k.lower() if any(x in lk for x in ["embed_tokens.weight", "word_embed", "wte.weight", "patch_embed.proj.weight", "conv1.weight", "feature_extractor.conv"]): if t.ndim == 2: s += f" Token input: vocab={t.shape[0]:,}, dim={t.shape[1]}\n" s += f" Format: [batch, seq_len] (token IDs)\n" elif t.ndim == 4: s += f" Image input: {t.shape[1]} channels, kernel {t.shape[2]}x{t.shape[3]}\n" ch = t.shape[1] s += f" Format: [batch, {ch}, H, W] ({'RGB' if ch == 3 else 'grayscale' if ch == 1 else f'{ch}-ch'})\n" break s += "\n\U0001f4e4 INFERRED OUTPUT\n" for k in reversed(list(state_dict.keys())): t = state_dict[k] lk = k.lower() if any(x in lk for x in ["lm_head.weight", "classifier.weight", "cls.predictions", "qa_outputs.weight", "score.weight", "head.weight"]): if t.ndim == 2: s += f" Output: {k}\n" s += f" Shape: [{t.shape[0]:,} classes/tokens, {t.shape[1]:,} hidden]\n" elif t.ndim == 1: s += f" Output bias: {k} -> {t.shape[0]:,} outputs\n" break if metadata.get("embedded_config"): s += f"\n\U0001f4cb EMBEDDED CONFIG\n {metadata['embedded_config'][:1200]}\n" if metadata.get("model_repr"): s += f"\n\U0001f3d7 MODEL REPR (nn.Module)\n{metadata['model_repr'][:3000]}\n" return s def build_layer_tree(state_dict) -> str: def _count(tree): total = 0 for v in tree.values(): if isinstance(v, dict) and "numel" in v: total += v["numel"] elif isinstance(v, dict): total 
+= _count(v) return total tree = OrderedDict() for key, tensor in state_dict.items(): parts = key.split(".") node = tree for part in parts[:-1]: if part not in node: node[part] = OrderedDict() node = node[part] node[parts[-1]] = { "shape": list(tensor.shape), "dtype": str(tensor.dtype), "numel": tensor.numel(), "mb": tensor.numel() * tensor.element_size() / 1048576, } lines = [] def _render(subtree, prefix="", depth=0): items = list(subtree.items()) for i, (key, val) in enumerate(items): last = (i == len(items) - 1) conn = "└── " if last else "β”œβ”€β”€ " ext = " " if last else "β”‚ " if isinstance(val, dict) and "shape" in val: sh = "x".join(map(str, val["shape"])) sz = f"{val['mb']:.3f} MB" if val["mb"] >= 0.001 else f"{val['numel']} el" lines.append(f"{prefix}{conn}\U0001f539 {key} [{sh}] {val['dtype']} ({sz})") elif isinstance(val, dict): cnt = _count(val) cnt_s = f" ({_fmt(cnt)} params)" if cnt > 0 else "" lines.append(f"{prefix}{conn}\U0001f4e6 {key}{cnt_s}") _render(val, prefix + ext, depth + 1) _render(tree) return "\n".join(lines) if lines else "No hierarchy found." 
def infer_all_layers(state_dict) -> str:
    """Infer a layer type for every module from parameter names and shapes.

    Classification is purely heuristic: tensor rank + substrings of the
    lowercased module path decide between Conv/Embedding/Norm/Attention/FFN/
    Head/etc. Also reports crude "connection" statistics by matching the
    out-dim of one weight against the in-dim of later weights.
    """
    layers = OrderedDict()
    seen = set()
    for key, tensor in state_dict.items():
        parts = key.rsplit(".", 1)
        mod = parts[0] if len(parts) == 2 else ""
        param = parts[-1]
        shape = tensor.shape
        ndim = tensor.ndim
        ml = mod.lower()
        # Classify each module once, from its first parameter encountered.
        if mod in seen:
            continue
        seen.add(mod)
        ltype = "Unknown"
        details = {}
        if param == "weight":
            if ndim == 4:
                ltype = "Conv2d"
                details = {"out": shape[0], "in": shape[1], "k": f"{shape[2]}x{shape[3]}"}
                # in-channels == 1 per filter group => depthwise (heuristic).
                if shape[1] == 1:
                    ltype = "DepthwiseConv2d"
                if any(x in ml for x in ["transpose", "deconv"]):
                    ltype = "ConvTranspose2d"
            elif ndim == 3:
                ltype = "Conv1d"
                details = {"out": shape[0], "in": shape[1], "k": shape[2]}
            elif ndim == 5:
                ltype = "Conv3d"
                details = {"out": shape[0], "in": shape[1]}
            elif ndim == 2:
                # 2-D weights: disambiguate by naming convention, ordered from
                # most specific to the generic Linear fallback.
                if any(x in ml for x in ["embed", "token", "wte", "wpe"]):
                    ltype = "Embedding"
                    details = {"vocab": shape[0], "dim": shape[1]}
                elif any(x in ml for x in ["norm", "ln_", "layernorm", "rmsnorm"]):
                    ltype = "Norm"
                    details = {"features": shape[0]}
                elif any(x in ml for x in ["q_proj", "k_proj", "v_proj", "qkv", "query", "key", "value"]):
                    ltype = "Attention Projection"
                    details = {"out": shape[0], "in": shape[1]}
                elif any(x in ml for x in ["o_proj", "out_proj"]) and "attn" in ml:
                    ltype = "Attention Output"
                    details = {"out": shape[0], "in": shape[1]}
                elif any(x in ml for x in ["gate_proj", "up_proj", "wi", "fc1", "c_fc", "intermediate"]):
                    ltype = "FFN Up/Gate"
                    details = {"out": shape[0], "in": shape[1]}
                elif any(x in ml for x in ["down_proj", "wo", "fc2", "c_proj"]) and "attn" not in ml:
                    ltype = "FFN Down"
                    details = {"out": shape[0], "in": shape[1]}
                elif any(x in ml for x in ["lm_head", "classifier", "cls", "score", "qa_output"]):
                    ltype = "Head / Classifier"
                    details = {"classes": shape[0], "hidden": shape[1]}
                else:
                    ltype = "Linear"
                    details = {"out": shape[0], "in": shape[1]}
            elif ndim == 1:
                if any(x in ml for x in ["norm", "bn", "ln"]):
                    ltype = "Norm"
                    details = {"features": shape[0]}
                else:
                    ltype = "Bias/Scale"
                    details = {"dim": shape[0]}
        elif param == "bias":
            ltype = "Bias"; details = {"dim": shape[0]}
        elif "running_mean" in param:
            ltype = "BatchNorm stats"
        elif "A_log" in param or param == "D":
            ltype = "Mamba SSM"
        elif "inv_freq" in param:
            ltype = "Rotary inv_freq"
        layers[mod] = {"type": ltype, "details": details}

    # β€”β€” Report: distribution bar chart + per-module table β€”β€”
    text = "\U0001f9e9 INFERRED LAYER TYPES\n\n"
    tc = Counter(v["type"] for v in layers.values())
    text += "DISTRIBUTION:\n"
    for lt, cnt in tc.most_common(25):
        bar = "\u2588" * min(cnt, 40)
        text += f" {lt:<30} {cnt:>5} {bar}\n"
    text += f"\nDETAILED ({len(layers)} modules):\n"
    text += f"{'Module':<60} {'Type':<25} {'Details'}\n"
    text += "\u2500" * 115 + "\n"
    for path, info in list(layers.items())[:150]:
        short = path if len(path) < 58 else "\u2026" + path[-56:]
        det = ", ".join(f"{k}={v}" for k, v in info["details"].items())
        text += f"{short:<60} {info['type']:<25} {det}\n"
    if len(layers) > 150:
        text += f"\n \u2026 and {len(layers) - 150} more\n"

    # Connections: for each >=2-D weight record (module, in_dim, out_dim)
    # in checkpoint order, then match out-dims to later in-dims.
    text += "\n\n\U0001f4e1 LAYER CONNECTIONS\n" + "\u2500" * 60 + "\n"
    io_dims = []
    for key, tensor in state_dict.items():
        parts = key.rsplit(".", 1)
        mod = parts[0] if len(parts) == 2 else ""
        if parts[-1] == "weight" and tensor.ndim >= 2:
            # NOTE(review): both ternary branches yield shape[1]; for conv
            # weights shape[1] is in-channels, so the result is the same.
            io_dims.append((mod, tensor.shape[1] if tensor.ndim == 2 else tensor.shape[1], tensor.shape[0]))
    direct = reshape = 0
    for i in range(len(io_dims) - 1):
        if io_dims[i][2] == io_dims[i + 1][1]:
            direct += 1
        else:
            reshape += 1
    # Look ahead up to 11 layers for dim matches that skip over neighbors,
    # deduplicated by top-level module pair.
    residual = []
    seen_p = set()
    for i in range(len(io_dims)):
        for j in range(i + 2, min(i + 12, len(io_dims))):
            if io_dims[i][2] == io_dims[j][1]:
                pair = (io_dims[i][0].split(".")[0], io_dims[j][0].split(".")[0])
                if pair not in seen_p:
                    residual.append((io_dims[i][0], io_dims[j][0], io_dims[i][2]))
                    seen_p.add(pair)
    text += f" Direct (dim match): {direct}\n"
    text += f" Reshape (dim change): {reshape}\n"
    text += f" Possible skip/residual: {len(residual)}\n\n"
    if residual:
        text += " SKIP/RESIDUAL CANDIDATES:\n"
        for src, dst, dim in residual[:30]:
            text += f" {src} ...({dim})...> {dst}\n"
        if len(residual) > 30:
            text += f" \u2026 and {len(residual) - 30} more\n"
    return text


def compute_weight_stats(state_dict) -> tuple:
    """Compute per-tensor statistics and flag suspicious values.

    Returns:
        (stats, text): a list of per-tensor stat dicts (used by the heatmap
        plot) and a formatted text report with a health-issue summary.
    """
    stats = []
    for key, tensor in state_dict.items():
        if tensor.numel() == 0:
            continue
        t = tensor.float()  # upcast so half/bf16 stats don't overflow
        entry = {
            "name": key,
            "shape": list(tensor.shape),
            "dtype": str(tensor.dtype),
            "numel": tensor.numel(),
            "mb": tensor.numel() * tensor.element_size() / 1048576,
            "mean": t.mean().item(),
            "std": t.std().item() if t.numel() > 1 else 0.0,
            "min": t.min().item(),
            "max": t.max().item(),
            "abs_mean": t.abs().mean().item(),
            "sparsity": (t == 0).float().mean().item() * 100,
            "l2": t.norm(2).item(),
        }
        # Rule-of-thumb health checks; thresholds are heuristic.
        issues = []
        if entry["std"] < 1e-7 and entry["numel"] > 10:
            issues.append("Near-zero var")
        if entry["sparsity"] > 90:
            issues.append(">90% sparse")
        if abs(entry["mean"]) > 10:
            issues.append("Large mean")
        if entry["std"] > 100:
            issues.append("High std")
        if np.isnan(entry["mean"]) or np.isinf(entry["mean"]):
            issues.append("NaN/Inf!")
        entry["issues"] = issues
        stats.append(entry)
    total_iss = sum(len(s["issues"]) for s in stats)
    text = "\u2696 WEIGHT STATISTICS\n\n"
    if total_iss == 0:
        text += "All weights healthy.\n\n"
    else:
        text += f"{total_iss} issues:\n"
        for s in stats:
            for iss in s["issues"]:
                text += f" {s['name']}: {iss}\n"
        text += "\n"
    text += f"{'Tensor':<55} {'Shape':<18} {'Params':>10} {'Mean':>10} {'Std':>10} {'Sparse':>8}\n"
    text += "\u2500" * 115 + "\n"
    for s in stats[:120]:
        name = s["name"] if len(s["name"]) < 53 else "\u2026" + s["name"][-52:]
        sh = "x".join(map(str, s["shape"]))
        if len(sh) > 16:
            sh = sh[:13] + "\u2026"
        text += (f"{name:<55} {sh:<18} {_fmt(s['numel']):>10} "
                 f"{s['mean']:>10.5f} {s['std']:>10.5f} {s['sparsity']:>7.1f}%\n")
    if len(stats) > 120:
        text += f"\n\u2026 and {len(stats) - 120} more\n"
    return stats, text


# ─────────────────────────── Plots ───────────────────────────
def _style_ax(ax):
    """Apply the shared dark theme to a matplotlib Axes."""
    ax.set_facecolor("#161b22")
    ax.tick_params(colors="#8b949e", labelsize=7)
    for sp in ax.spines.values():
        sp.set_color("#30363d")
def plot_distributions(state_dict, max_n=20):
    """Plot value histograms (1stβ€“99th percentile) for up to max_n tensors."""
    # Prefer real weight matrices; fall back to anything non-trivial.
    candidates = [(k, v) for k, v in state_dict.items() if v.numel() > 100 and v.ndim >= 2]
    if not candidates:
        candidates = [(k, v) for k, v in state_dict.items() if v.numel() > 10]
    if not candidates:
        fig, ax = plt.subplots(figsize=(10, 3)); fig.patch.set_facecolor("#0d1117")
        ax.text(0.5, 0.5, "No layers for histogram", ha="center", va="center", color="#8b949e"); ax.axis("off")
        return fig
    if len(candidates) > max_n:
        # Evenly sample across network depth instead of taking the first max_n.
        idx = np.linspace(0, len(candidates)-1, max_n, dtype=int)
        candidates = [candidates[i] for i in idx]
    n = len(candidates); cols = min(4, n); rows = (n + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(4.5*cols, 3*rows))
    fig.patch.set_facecolor("#0d1117")
    # Normalize axes to a 2-D grid regardless of subplot count.
    if n == 1:
        axes = np.array([[axes]])
    axes = np.atleast_2d(axes)
    cmap = plt.cm.plasma
    for idx, (key, tensor) in enumerate(candidates):
        r, c = divmod(idx, cols); ax = axes[r, c]; _style_ax(ax)
        data = tensor.float().flatten().numpy()
        # Clip to 1%/99% so a few outliers don't flatten the histogram.
        q1, q99 = np.percentile(data, [1, 99])
        clipped = data[(data >= q1) & (data <= q99)]
        ax.hist(clipped, bins=70, color=cmap(idx/max(n-1,1)), alpha=0.85, edgecolor="none", density=True)
        ax.axvline(0, color="#ff6b6b", ls="--", alpha=0.4, lw=0.7)
        short = ".".join(key.split(".")[-2:])
        if len(short) > 30:
            short = "\u2026" + short[-27:]
        ax.set_title(short, fontsize=7.5, color="#c9d1d9", pad=3)
    # Hide unused grid cells.
    for idx in range(n, rows*cols):
        r, c = divmod(idx, cols); axes[r, c].axis("off")
    fig.suptitle("Weight Distributions", fontsize=13, color="#c9d1d9", y=1.01)
    plt.tight_layout(); return fig


def plot_module_sizes(state_dict):
    """Horizontal bar chart of parameter counts per top-level module (top 30)."""
    mod_params = defaultdict(int)
    for k, t in state_dict.items():
        mod_params[k.split(".")[0]] += t.numel()
    sorted_m = sorted(mod_params.items(), key=lambda x: -x[1])[:30]
    if not sorted_m:
        fig, ax = plt.subplots(figsize=(8,3)); fig.patch.set_facecolor("#0d1117")
        ax.text(0.5,0.5,"No modules",ha="center",va="center",color="#8b949e"); ax.axis("off"); return fig
    names = [m[0] for m in sorted_m]; counts = [m[1] for m in sorted_m]
    fig, ax = plt.subplots(figsize=(12, max(3, len(names)*0.35))); fig.patch.set_facecolor("#0d1117"); _style_ax(ax)
    colors = plt.cm.viridis(np.linspace(0.2, 0.9, len(names)))
    bars = ax.barh(range(len(names)), counts, color=colors, edgecolor="none", height=0.7)
    ax.set_yticks(range(len(names))); ax.set_yticklabels(names, fontsize=8.5, color="#c9d1d9"); ax.invert_yaxis()
    ax.set_xlabel("Parameters", color="#c9d1d9"); ax.set_title("Params by Module", color="#c9d1d9", fontsize=12)
    # Annotate each bar with the formatted count just past its end.
    for bar, cnt in zip(bars, counts):
        ax.text(bar.get_width()+max(counts)*0.01, bar.get_y()+bar.get_height()/2,
                _fmt(cnt), va="center", fontsize=7.5, color="#8b949e")
    plt.tight_layout(); return fig


def plot_heatmap(stats):
    """Heatmap of per-tensor stats (mean/std/abs-mean/sparsity), column-normalized.

    `stats` is the list produced by compute_weight_stats.
    """
    ws = [s for s in stats if s["numel"] > 10 and "weight" in s["name"]]
    if not ws:
        ws = stats
    if len(ws) < 2:
        fig, ax = plt.subplots(figsize=(8,3)); fig.patch.set_facecolor("#0d1117")
        ax.text(0.5,0.5,"Not enough layers",ha="center",va="center",color="#8b949e"); ax.axis("off"); return fig
    if len(ws) > 50:
        # Evenly subsample to at most 50 rows.
        idx = np.linspace(0, len(ws)-1, 50, dtype=int); ws = [ws[i] for i in idx]
    metrics = ["mean","std","abs_mean","sparsity"]; labels = ["Mean","Std","Abs Mean","Sparsity%"]
    mat = np.array([[s[m] for m in metrics] for s in ws])
    # Min-max normalize each column so colors are comparable across metrics.
    for j in range(mat.shape[1]):
        rng = mat[:,j].max()-mat[:,j].min()
        if rng > 0:
            mat[:,j] = (mat[:,j]-mat[:,j].min())/rng
    fig, ax = plt.subplots(figsize=(8, max(4, len(ws)*0.28))); fig.patch.set_facecolor("#0d1117"); _style_ax(ax)
    cmap = LinearSegmentedColormap.from_list("c", ["#0d1117","#238636","#f0e68c","#da3633"])
    im = ax.imshow(mat, aspect="auto", cmap=cmap, interpolation="nearest")
    names = [".".join(s["name"].split(".")[-3:])[-35:] for s in ws]
    ax.set_yticks(range(len(names))); ax.set_yticklabels(names, fontsize=6, color="#c9d1d9")
    ax.set_xticks(range(len(labels))); ax.set_xticklabels(labels, fontsize=9, color="#c9d1d9")
    ax.set_title("Stats Heatmap", color="#c9d1d9", fontsize=12); plt.colorbar(im, ax=ax, shrink=0.8)
    plt.tight_layout(); return fig


def plot_memory(state_dict):
    """Two panels: memory share per dtype (pie) and 12 largest tensors (bars)."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5)); fig.patch.set_facecolor("#0d1117")
    dtype_mem = defaultdict(float)
    for t in state_dict.values():
        dtype_mem[str(t.dtype)] += t.numel()*t.element_size()/1048576  # MB
    labels = list(dtype_mem.keys()); sizes = list(dtype_mem.values())
    colors = plt.cm.Set2(np.linspace(0,1,max(len(labels),1)))
    ax1.set_facecolor("#0d1117")
    if labels:
        ax1.pie(sizes, labels=labels, autopct="%1.1f%%", colors=colors,
                textprops={"color":"#c9d1d9","fontsize":9})
    ax1.set_title("Memory by dtype", color="#c9d1d9", fontsize=12)
    top = sorted(state_dict.items(), key=lambda x:-x[1].numel())[:12]
    names = [".".join(k.split(".")[-2:])[-28:] for k,v in top]
    mem = [v.numel()*v.element_size()/1048576 for k,v in top]
    _style_ax(ax2); colors2 = plt.cm.cool(np.linspace(0.2,0.8,len(names)))
    bars = ax2.barh(range(len(names)), mem, color=colors2, edgecolor="none")
    ax2.set_yticks(range(len(names))); ax2.set_yticklabels(names, fontsize=7.5, color="#c9d1d9"); ax2.invert_yaxis()
    ax2.set_xlabel("MB", color="#c9d1d9"); ax2.set_title("Top 12 Largest", color="#c9d1d9", fontsize=12)
    for bar, m in zip(bars, mem):
        ax2.text(bar.get_width()+max(mem)*0.02, bar.get_y()+bar.get_height()/2,
                 f"{m:.2f} MB", va="center", fontsize=7, color="#8b949e")
    plt.tight_layout(); return fig


def plot_depth(state_dict):
    """Bar chart of parameter count per >=2-D weight, in checkpoint order."""
    sizes = []; names = []
    for k, t in state_dict.items():
        if "weight" in k and t.ndim >= 2:
            sizes.append(t.numel())
            short = ".".join(k.split(".")[-2:])
            names.append(short if len(short) < 30 else "\u2026" + short[-27:])
    if len(sizes) < 2:
        fig, ax = plt.subplots(figsize=(10,3)); fig.patch.set_facecolor("#0d1117")
        ax.text(0.5,0.5,"Not enough layers",ha="center",va="center",color="#8b949e"); ax.axis("off"); return fig
    fig, ax = plt.subplots(figsize=(14, 4.5)); fig.patch.set_facecolor("#0d1117"); _style_ax(ax)
    colors = plt.cm.plasma(np.linspace(0.1, 0.9, len(sizes)))
    ax.bar(range(len(sizes)), sizes, color=colors, edgecolor="none", width=0.8)
    ax.set_xlabel("Layer index", color="#c9d1d9"); ax.set_ylabel("Params", color="#c9d1d9")
    ax.set_title("Parameter Count Across Depth", color="#c9d1d9", fontsize=12)
    # Thin the x tick labels to at most ~25 entries.
    step = max(1, len(names)//25)
    ax.set_xticks(list(range(len(names)))[::step])
    ax.set_xticklabels(names[::step], rotation=90, fontsize=5.5, color="#8b949e")
    plt.tight_layout(); return fig


# ─────────────────────────── Architecture Overview ───────────────────────────
def _rbox(ax, cx, cy, w, h, label, fc, ec="#ffffff20", fs: int = 9, fw="bold", sub=None, tc="#ffffff"):
    """Draw a rounded box centered at (cx, cy) with a label and optional subtitle."""
    from matplotlib.patches import FancyBboxPatch
    p = FancyBboxPatch((cx - w/2, cy - h/2), w, h, boxstyle="round,pad=0.07",
                       fc=fc, ec=ec, lw=0.9, zorder=3, clip_on=False)
    ax.add_patch(p)
    if sub:
        # Two-line layout: label above center, dimmed subtitle below.
        ax.text(cx, cy + 0.12, label, ha="center", va="center", color=tc,
                fontsize=fs, fontweight=fw, zorder=4)
        ax.text(cx, cy - 0.22, sub, ha="center", va="center", color="#ffffffaa",
                fontsize=max(int(fs) - 2, 6), zorder=4)
    else:
        ax.text(cx, cy, label, ha="center", va="center", color=tc,
                fontsize=fs, fontweight=fw, zorder=4)


def _arr(ax, x1, y1, x2, y2, color="#58a6ff", lw=1.5):
    """Draw an arrow from (x1, y1) to (x2, y2)."""
    ax.annotate("", xy=(x2, y2), xytext=(x1, y1),
                arrowprops=dict(arrowstyle="-|>", color=color, lw=lw, mutation_scale=12), zorder=2)


def _draw_transformer_overview(arch, n_heads, n_layers, hidden, ff_dim, vocab,
                               head_dim, max_pos, gqa, n_kv, total_p):
    """Draw the schematic transformer diagram with a config panel and demo heatmap.

    All coordinates below are hand-tuned positions in the main axes'
    (0..16, 0..11) data space β€” not derived from the model.
    """
    show_h = min(max(n_heads, 1), 4)  # draw at most 4 head columns
    fig = plt.figure(figsize=(18, 11))
    fig.patch.set_facecolor("#0d1117")
    ax = fig.add_axes((0.01, 0.06, 0.70, 0.92))   # main diagram
    ax_h = fig.add_axes((0.73, 0.58, 0.25, 0.36))  # attention heatmap
    ax_c = fig.add_axes((0.73, 0.06, 0.25, 0.48))  # config panel
    for a in [ax, ax_c]:
        a.set_facecolor("#0d1117"); a.axis("off")
    ax.set_xlim(0, 16); ax.set_ylim(0, 11)
    ax_c.set_xlim(0, 1); ax_c.set_ylim(0, 1)
    ax_h.set_facecolor("#161b22")
    for sp in ax_h.spines.values():
        sp.set_color("#30363d")
    ax_h.tick_params(colors="#8b949e", labelsize=6)
    # Palette: Q / K / V / concat H / W^O / final output.
    CQ = "#e67e22"; CK = "#c0392b"; CV = "#2980b9"
    CH = "#16a085"; CO = "#8e44ad"; CA = "#27ae60"
    from matplotlib.patches import FancyBboxPatch
    # ── Left sidebar: transformer block ──
    sx = 0.95
    outer = FancyBboxPatch((0.05, 0.85), 1.8, 9.3, boxstyle="round,pad=0.1",
                           fc="#161b22", ec="#30363d", lw=1.5, zorder=1)
    ax.add_patch(outer)
    ax.text(sx, 10.4, f"Transformer Block Γ—{n_layers}", ha="center", va="center",
            color="#8b949e", fontsize=7.5, style="italic", zorder=4)
    # Stack of sub-blocks, top to bottom: (y, label, color).
    sb = [
        (9.5, "Add &\nNorm", "#2d3748"),
        (8.2, "Feed\nForward", "#1a4f7a"),
        (6.9, "Add &\nNorm", "#2d3748"),
        (5.3, "Multi-Head\nAttention", "#1e3a5f"),
        (3.5, "Pos\nEncoding", "#1a3a2a"),
        (2.1, "Token\nEmbedding", "#2d3748"),
        (1.0, "Input", "#1a2332"),
    ]
    for sy, slbl, sc in sb:
        _rbox(ax, sx, sy, 1.55, 0.75, slbl, sc, fs=8, fw="bold")
    # Upward arrows between consecutive sub-blocks.
    for i in range(len(sb) - 1):
        _arr(ax, sx, sb[i+1][0]+0.38, sx, sb[i][0]-0.38, color="#58a6ff60", lw=1.0)
    # Residual skip connection (curved arrow around the attention block).
    ax.annotate("", xy=(0.15, 6.9+0.38), xytext=(0.15, 5.3-0.38),
                arrowprops=dict(arrowstyle="-|>", color="#bc8cff", lw=1.2,
                                connectionstyle="arc3,rad=-0.4"), zorder=2)
    ax.text(0.07, 6.1, "+", ha="center", va="center", color="#bc8cff",
            fontsize=11, fontweight="bold")
    # ── Q / K / V rows ──
    rows = [
        (8.8, "Q", CQ),
        (5.8, "K", CK),
        (2.8, "V", CV),
    ]
    # Column x-positions: input, Γ—, weight, =, projected result; then heads.
    XIN = 2.8; XMUL = 3.65; XW = 4.5; XEQ = 5.35; XPRI = 6.15
    HHSTART = 7.3; HGAP = 0.82
    for row_y, lbl, color in rows:
        _rbox(ax, XIN, row_y, 0.95, 0.68, lbl, color, sub=f"(seq,{hidden})", fs=12, fw="bold")
        ax.text(XMUL, row_y, "Γ—", ha="center", va="center", color="#c9d1d9",
                fontsize=14, fontweight="bold", zorder=4)
        _rbox(ax, XW, row_y, 1.1, 0.68, f"W$^{{{lbl}}}$", color, sub=f"({hidden},{hidden})", fs=11, fw="bold")
        ax.text(XEQ, row_y, "=", ha="center", va="center", color="#c9d1d9",
                fontsize=14, fontweight="bold", zorder=4)
        _rbox(ax, XPRI, row_y, 0.95, 0.68, f"{lbl}'", color, sub=f"(seq,{hidden})", fs=11, fw="bold")
        _arr(ax, XPRI+0.48, row_y, HHSTART-0.20, row_y, color=color, lw=1.5)
        # Per-head split boxes.
        for hi in range(show_h):
            hx = HHSTART + hi * HGAP
            _rbox(ax, hx, row_y, 0.65, 0.62, f"{lbl}{hi+1}", color, fs=9)
        if n_heads > show_h:
            ax.text(HHSTART + show_h * HGAP - 0.05, row_y, "…", ha="center",
                    va="center", color=color, fontsize=12, zorder=4)
    # d_k label above the head columns.
    dk_x = HHSTART + (show_h - 1) * HGAP / 2
    ax.text(dk_x, rows[0][0] + 0.68, f"d_k = {head_dim}", ha="center", va="bottom",
            color="#f778ba", fontsize=8, zorder=4,
            bbox=dict(boxstyle="round,pad=0.2", fc="#0d1117", ec="#f778ba60", lw=0.8))
    # d_model bracket under heads.
    hend_x = HHSTART + (show_h - 1) * HGAP
    ax.annotate("", xy=(hend_x+0.33, rows[2][0]-0.68), xytext=(HHSTART-0.33, rows[2][0]-0.68),
                arrowprops=dict(arrowstyle="<->", color="#8b949e", lw=0.8))
    ax.text(dk_x, rows[2][0]-0.84, f"d_model = {hidden}", ha="center", va="top",
            color="#8b949e", fontsize=7.5)
    # Vertical attention flow arrows between head columns (Q->K, K->V).
    for hi in range(show_h):
        hx = HHSTART + hi * HGAP
        ax.annotate("", xy=(hx, rows[1][0]+0.31), xytext=(hx, rows[0][0]-0.31),
                    arrowprops=dict(arrowstyle="-|>", color="#ffffff22", lw=0.8), zorder=2)
        ax.annotate("", xy=(hx, rows[2][0]+0.31), xytext=(hx, rows[1][0]-0.31),
                    arrowprops=dict(arrowstyle="-|>", color="#ffffff22", lw=0.8), zorder=2)
    # Attention formula between the Q and V rows.
    fml_x = HHSTART + (show_h - 1) * HGAP / 2
    fml_y = (rows[0][0] + rows[2][0]) / 2
    ax.text(fml_x, fml_y, f"softmax(QKα΅€ / √{head_dim})", ha="center", va="center",
            color="#f0e68c", fontsize=9, style="italic", zorder=4,
            bbox=dict(boxstyle="round,pad=0.3", fc="#14101e", ec="#f0e68c50", lw=0.8))
    # ── Output section: H Γ— W^O = MH-A ──
    out_y = 1.3
    H_x = HHSTART + (show_h - 1) * HGAP / 2
    # Fan-in arrows from each head down to the concat box H.
    for hi in range(show_h):
        hx = HHSTART + hi * HGAP
        ax.annotate("", xy=(H_x + (hi - show_h/2 + 0.5) * 0.22, out_y+0.40),
                    xytext=(hx, rows[2][0]-0.31),
                    arrowprops=dict(arrowstyle="-|>", color=CV, lw=0.9), zorder=2)
    WO_x = H_x + 1.75
    MHA_x = WO_x + 1.85
    _rbox(ax, H_x, out_y, 1.1, 0.65, "H", CH, sub=f"(seq,{hidden})", fs=12, fw="bold")
    ax.text(WO_x-0.68, out_y, "Γ—", ha="center", va="center", color="#c9d1d9",
            fontsize=14, fontweight="bold", zorder=4)
    _rbox(ax, WO_x, out_y, 1.0, 0.65, "W$^O$", CO, sub=f"({hidden},{hidden})", fs=11, fw="bold")
    ax.text(WO_x+0.73, out_y, "=", ha="center", va="center", color="#c9d1d9",
            fontsize=14, fontweight="bold", zorder=4)
    _rbox(ax, MHA_x, out_y, 1.35, 0.65, "MH-A", CA, sub=f"(seq,{hidden})", fs=11, fw="bold")
    _arr(ax, H_x+0.55, out_y, WO_x-0.50, out_y, color=CH, lw=1.5)
    _arr(ax, WO_x+0.50, out_y, MHA_x-0.68, out_y, color=CO, lw=1.5)
    # Formulas at bottom.
    ax.text(7.8, 0.60, "Attention(Q,K,V) = softmax(QKα΅€/√d_k) Β· V", ha="center",
            va="center", color="#c9d1d9", fontsize=9, style="italic", zorder=4)
    ax.text(7.8, 0.22, f"MultiHead = Concat(head₁ … head_{n_heads}) Β· W\u1D3C",
            ha="center", va="center", color="#c9d1d9", fontsize=9, style="italic", zorder=4)
    # Title.
    ax.text(8.0, 10.75, arch, ha="center", va="center", color="#58a6ff",
            fontsize=13, fontweight="bold", zorder=4)
    ax.text(8.0, 10.35, f"{_fmt(total_p)} params", ha="center", va="center",
            color="#8b949e", fontsize=8.5, zorder=4)
    # ── Attention heatmap ──
    # Synthetic illustrative matrix (seeded for determinism), NOT real attention:
    # a diagonal-boosted random row-stochastic matrix.
    seq_d = 8
    np.random.seed(7)
    attn = np.random.dirichlet([1.5] * seq_d, seq_d)
    attn = 0.45 * np.eye(seq_d) + 0.55 * attn
    attn /= attn.sum(axis=1, keepdims=True)
    ax_h.imshow(attn, cmap="YlOrRd", aspect="auto", interpolation="bilinear", vmin=0)
    ax_h.set_title("Attention Visualization", color="#c9d1d9", fontsize=9, pad=4)
    # ── Config panel ──
    ax_c.text(0.5, 0.97, "βš™ Model Config", ha="center", va="top",
              color="#58a6ff", fontsize=10, fontweight="bold")
    ax_c.axhline(y=0.93, color="#30363d", lw=0.8)
    info = [
        ("Architecture", arch.split("(")[0].strip()),
        ("Parameters", _fmt(total_p)),
        ("Layers", str(n_layers)),
        ("Hidden size", str(hidden)),
        ("Attn heads", f"{n_heads}" + (" (GQA)" if gqa else "")),
        ("Head dim", str(head_dim)),
        ("FF dim", _fmt(ff_dim) if ff_dim else "β€”"),
        ("Vocab size", f"{vocab:,}" if vocab else "β€”"),
        ("Max position", str(max_pos) if max_pos else "β€”"),
    ]
    if gqa and n_kv:
        info.insert(5, ("KV heads", str(n_kv)))
    for i, (k, v) in enumerate(info[:10]):
        y = 0.89 - i * 0.087
        ax_c.text(0.03, y, k + ":", ha="left", va="center", color="#8b949e", fontsize=8.5)
        ax_c.text(0.97, y, v, ha="right", va="center", color="#e6edf3",
                  fontsize=8.5, fontweight="bold")
    plt.tight_layout(pad=0.2)
    return fig


def _draw_cnn_overview(state_dict, arch, config, total_p):
    """Draw a left-to-right CNN pipeline from the checkpoint's top-level modules."""
    fig = plt.figure(figsize=(18, 8))
    fig.patch.set_facecolor("#0d1117")
    ax = fig.add_axes((0.02, 0.14, 0.96, 0.80))
    ax.set_facecolor("#0d1117"); ax.axis("off")
    ax.set_xlim(0, 18); ax.set_ylim(0, 8)
    # Collect top modules (insertion order = checkpoint order), at most 12.
    top_mods = {}
    for k, t in state_dict.items():
        top = k.split(".")[0]
        top_mods[top] = top_mods.get(top, 0) + t.numel()
    blocks = list(top_mods.items())[:12]
    if not blocks:
        blocks = [("conv1", 0), ("layer1", 0), ("fc", 0)]
    # Substring -> box color; first match wins in _block_color.
    colors_map = {"conv": "#2980b9", "layer": "#1a5276", "block": "#1a5276",
                  "stage": "#1a5276", "res": "#1a5276", "pool": "#27ae60",
                  "fc": "#8e44ad", "head": "#8e44ad", "classifier": "#8e44ad",
                  "norm": "#d68910", "bn": "#d68910"}

    def _block_color(name):
        nl = name.lower()
        for key, col in colors_map.items():
            if key in nl:
                return col
        return "#2d3748"

    ax.text(9, 7.6, arch, ha="center", va="center", color="#58a6ff",
            fontsize=14, fontweight="bold")
    ax.text(9, 7.1, f"{_fmt(total_p)} parameters β€’ CNN / Vision Architecture",
            ha="center", va="center", color="#8b949e", fontsize=9)
    n = len(blocks)
    spacing = 15.0 / max(n, 1)
    _rbox(ax, 0.8, 4.5, 0.85, 0.70, "Input\nImage", "#1a3a2a", fs=8)
    prev_x = 1.23
    for i, (name, params) in enumerate(blocks):
        bx = 1.5 + i * spacing + spacing * 0.5
        color = _block_color(name)
        _rbox(ax, bx, 4.5, max(spacing * 0.82, 0.9), 0.72, name, color,
              sub=_fmt(params) if params else "", fs=9, fw="bold")
        _arr(ax, prev_x, 4.5, bx - spacing * 0.41 - 0.02, 4.5, color="#58a6ff", lw=1.5)
        # Type caption under each box. NOTE(review): `_block_color(name) and ...`
        # always evaluates the next() call since _block_color returns a
        # non-empty string β€” the `and` is effectively a no-op.
        ax.text(bx, 3.9, _block_color(name) and next((v for k, v in [
            ("conv", "Conv"), ("layer", "Block"), ("pool", "Pool"), ("fc", "Linear"),
            ("head", "Head"), ("norm", "Norm"), ("res", "ResBlock"), ("stage", "Stage")]
            if k in name.lower()), "Module"),
            ha="center", va="center", color="#8b949e", fontsize=7)
        prev_x = bx + spacing * 0.41
    _rbox(ax, prev_x + 0.6, 4.5, 1.0, 0.70, "Output", "#27ae60", fs=9)
    _arr(ax, prev_x, 4.5, prev_x + 0.1, 4.5, color="#58a6ff", lw=1.5)
    # Legend
    legend = [("Conv/Feature", "#2980b9"), ("Layer/Block", "#1a5276"),
              ("Pool/Down", "#27ae60"), ("FC/Head", "#8e44ad"), ("Norm", "#d68910")]
    for i, (lbl, lc) in enumerate(legend):
        _rbox(ax, 2.0 + i * 3.3, 2.2, 2.4, 0.50, lbl, lc, fs=8)
    ax.text(9, 1.5, "Layer type legend", ha="center", va="center",
            color="#8b949e", fontsize=8)
    plt.tight_layout(pad=0.2)
    return fig


def _draw_generic_flow(state_dict, arch, config, total_p):
    """Fallback diagram: top modules as size-proportional boxes + largest tensors."""
    fig = plt.figure(figsize=(18, 9))
    fig.patch.set_facecolor("#0d1117")
    ax = fig.add_axes((0.02, 0.08, 0.96, 0.85))
    ax.set_facecolor("#0d1117"); ax.axis("off")
    ax.set_xlim(0, 18); ax.set_ylim(0, 9)
    top_mods = {}
    for k, t in state_dict.items():
        top = k.split(".")[0]
        top_mods[top] = top_mods.get(top, 0) + t.numel()
    sorted_mods = sorted(top_mods.items(), key=lambda x: -x[1])[:10]
    ax.text(9, 8.5, arch, ha="center", va="center", color="#58a6ff",
            fontsize=14, fontweight="bold")
    ax.text(9, 8.0, f"{_fmt(total_p)} parameters", ha="center", va="center",
            color="#8b949e", fontsize=9)
    pal = ["#2980b9", "#e67e22", "#27ae60", "#8e44ad", "#c0392b",
           "#1abc9c", "#d68910", "#2ecc71", "#3498db", "#9b59b6"]
    total_n = max(sum(p for _, p in sorted_mods), 1)
    x = 0.5
    for i, (name, params) in enumerate(sorted_mods):
        # Box width proportional to the module's parameter share, min 0.9.
        w = max(0.9, (params / total_n) * 15)
        cx = x + w / 2
        _rbox(ax, cx, 5.5, w - 0.12, 1.1, name, pal[i % len(pal)],
              sub=_fmt(params), fs=9, fw="bold")
        if i > 0:
            _arr(ax, x, 5.5, x + 0.06, 5.5, color="#58a6ff", lw=1.5)
        x += w
    top_t = sorted(state_dict.items(), key=lambda x: -x[1].numel())[:5]
    ax.text(9, 3.8, "Largest tensors:", ha="center", va="center",
            color="#58a6ff", fontsize=9, fontweight="bold")
    for i, (k, t) in enumerate(top_t):
        sh = "Γ—".join(map(str, t.shape))
        short = k if len(k) < 48 else "…" + k[-46:]
        ax.text(9, 3.3 - i * 0.48, f"{short} [{sh}] {_fmt(t.numel())}",
                ha="center", va="center", color="#c9d1d9", fontsize=8)
    plt.tight_layout(pad=0.2)
    return fig


def plot_architecture_overview(state_dict, detections, config):
    """Dispatch to the CNN / generic / transformer overview drawing.

    `detections` comes from detect_architectures; `config` from
    infer_model_config. Missing config keys fall back to common transformer
    defaults (768 hidden, 12 layers, 8 heads).
    """
    arch = detections[0]["family"] if detections else "Unknown Model"
    category = detections[0]["category"] if detections else ""
    hidden = config.get("likely_hidden_size", 768)
    n_layers = config.get("num_layers", 12)
    n_heads = config.get("num_attention_heads", 8)
    head_dim = config.get("head_dim", hidden // max(n_heads, 1))
    ff_dim = config.get("intermediate_size", hidden * 4)
    vocab = config.get("vocab_size", 0)
    max_pos = config.get("max_position_embeddings", 0)
    gqa = config.get("grouped_query_attention", False)
    n_kv = config.get("num_kv_heads", n_heads)
    total_p = sum(t.numel() for t in state_dict.values())
    # Route on keywords in the detected family + category string.
    al = (arch + " " + category).lower()
    is_cnn = any(x in al for x in ["resnet", "convnext", "efficientnet", "mobilenet", "vgg",
                                   "densenet", "yolo", "convnet", "regnet", "detr", "convolution"])
    is_diff = any(x in al for x in ["stable diffusion", "unet", "sdxl", "flux", "dit",
                                    "pixart", "vqgan", "stylegan", "esrgan"])
    if is_cnn:
        return _draw_cnn_overview(state_dict, arch, config, total_p)
    elif is_diff:
        return _draw_generic_flow(state_dict, arch, config, total_p)
    else:
        return _draw_transformer_overview(arch, n_heads, n_layers, hidden, ff_dim, vocab,
                                          head_dim, max_pos, gqa, n_kv, total_p)


# ─────────────────────────── Architecture Graph (HTML) ───────────────────────────
def build_architecture_html(state_dict, detections, config):
    """ Generate an interactive SVG/HTML architecture graph. Connections are inferred purely from weight dimension matching β€” no hardcoded architecture checks. Works for any model. 
""" import html as _hl arch = detections[0]["family"] if detections else "Unknown" total_p = sum(t.numel() for t in state_dict.values()) # ── Colour palette (fill / border) ── TC = {"Embedding":"#1b4332","Q Proj":"#7c3400","K Proj":"#7c1010", "V Proj":"#0d3b6e","Attn Out":"#0a2a5e","FFN Up":"#2d0f57", "FFN Down":"#57103a","Norm":"#2a3441","Conv2d":"#004d5e", "Conv1d":"#003d4a","Linear":"#1b3047","Output":"#14451f","Other":"#1c2328"} BC = {"Embedding":"#40c070","Q Proj":"#fb8c00","K Proj":"#ef5350", "V Proj":"#42a5f5","Attn Out":"#1976d2","FFN Up":"#ab47bc", "FFN Down":"#ec407a","Norm":"#78909c","Conv2d":"#00bcd4", "Conv1d":"#009688","Linear":"#546e7a","Output":"#43a047","Other":"#546e7a"} def _infer(path, tensor): lk, s, nd = path.lower(), tensor.shape, tensor.ndim if nd == 4: return "Conv2d", int(s[1]), int(s[0]), f"k{s[2]}Γ—{s[3]}" if nd == 3: return "Conv1d", int(s[1]), int(s[0]), f"k{s[2]}" if nd == 2: i, o = int(s[1]), int(s[0]) end = path.rsplit(".", 1)[-1].lower() if any(x in lk for x in ["embed","token","wte","wpe","word_embed","tok_embed"]): return "Embedding", i, o, "" if "q_proj" in lk or ("query" in lk and end == "weight"): return "Q Proj", i, o, "" if "k_proj" in lk or ("key" in lk and end == "weight"): return "K Proj", i, o, "" if "v_proj" in lk or ("value" in lk and end == "weight"): return "V Proj", i, o, "" if ("o_proj" in lk or "out_proj" in lk) and any(x in lk for x in ["attn","self","mha"]): return "Attn Out", i, o, "" if any(x in lk for x in ["gate_proj","up_proj","fc1","wi_0","c_fc","intermediate"]): return "FFN Up", i, o, "" if any(x in lk for x in ["down_proj","fc2","c_proj"]) and "attn" not in lk: return "FFN Down", i, o, "" if any(x in lk for x in ["lm_head","classifier","cls","score","head.weight"]): return "Output", i, o, "" if any(x in lk for x in ["norm","ln_","rms","layer_norm","ln_f"]): return "Norm", i, o, "" return "Linear", i, o, "" if nd == 1 and any(x in path.lower() for x in ["norm","bn","ln","rms"]): return "Norm", int(s[0]), 
int(s[0]), "" return None, None, None, None # ── Extract weight modules ── mods, seen = [], set() for key, tensor in state_dict.items(): parts = key.rsplit(".", 1) path = parts[0] if len(parts) == 2 else key param = parts[-1] if len(parts) == 2 else key if param not in ("weight", "W", "kernel") or path in seen: continue seen.add(path) ltype, ind, outd, det = _infer(path, tensor) if ltype is None: continue pparts = path.split(".") grp = ".".join(pparts[:-1]) if len(pparts) > 1 else "__root__" mods.append({"path": path, "name": pparts[-1], "group": grp, "type": ltype, "in_dim": ind, "out_dim": outd, "params": int(tensor.numel()), "det": det}) if not mods: return "
No weight modules found.
" # ── Group by parent path ── groups = OrderedDict() for m in mods: groups.setdefault(m["group"], []).append(m) # Sample if too many groups (keep first 3, last 3, evenly spaced middle) MAX_G = 22 if len(groups) > MAX_G: gl = list(groups.items()) keep = list(range(min(3, len(gl)))) + list(range(max(len(gl)-3, 3), len(gl))) step = max(1, (len(gl) - 6) // (MAX_G - 6)) for i in range(3, len(gl) - 3, step): keep.append(i) keep = sorted(set(keep))[:MAX_G] groups = OrderedDict(gl[i] for i in keep) # Flatten and assign IDs flat = [] for gmods in groups.values(): for m in gmods: m["id"] = len(flat) flat.append(m) n_total = len(mods) # ── Find connections purely from dimension matching ── edges, seen_e = [], set() for i in range(len(flat)): a = flat[i] found_direct = False for j in range(i + 1, min(i + 30, len(flat))): b = flat[j] if (i, j) in seen_e or a["out_dim"] <= 0: continue if a["out_dim"] == b["in_dim"]: is_direct = (j == i + 1 and not found_direct) etype = "direct" if is_direct else "skip" edges.append({"from": i, "to": j, "type": etype, "dim": a["out_dim"]}) seen_e.add((i, j)) if is_direct: found_direct = True if j - i > 10 and etype == "skip": break # Don't look too far for skip connections # ── Layout: vertical swimlanes (one horizontal band per group) ── NW, NH = 192, 52 HGAP, GPX, GPY, GHH, GVGAP = 10, 12, 10, 24, 14 cy = 8 grects = {} for g, gmods in groups.items(): nm = len(gmods) gw = nm * (NW + HGAP) - HGAP + GPX * 2 gh = GHH + GPY + NH + GPY for j, m in enumerate(gmods): m["ax"] = GPX + j * (NW + HGAP) + NW // 2 # absolute x in SVG m["ay"] = cy + GHH + GPY + NH // 2 # absolute y in SVG gl = g.rsplit(".", 1)[-1] if "." 
in g else g grects[g] = {"y": cy, "w": gw, "h": gh, "label": (gl if len(gl) <= 38 else "…" + gl[-36:]), "full": g} cy += gh + GVGAP svg_w = max((r["w"] for r in grects.values()), default=600) + 60 svg_h = cy + 10 # ── Build SVG ── def e(s): return _hl.escape(str(s)) parts = [''' '''] # Group background bands for g, r in grects.items(): parts.append( f'' f'' f'{e(r["label"])}') # Edges (drawn before nodes) for ed in edges: a, b = flat[ed["from"]], flat[ed["to"]] col = "#58a6ff" if ed["type"] == "direct" else "#bc8cff" mk = "ma" if ed["type"] == "direct" else "ms" op = "0.75" if ed["type"] == "direct" else "0.45" dash = "" if ed["type"] == "direct" else 'stroke-dasharray="5,3"' tip = e(f"dim={ed['dim']}") same_g = a["group"] == b["group"] if same_g: x1, y1 = a["ax"] + NW // 2, a["ay"] x2, y2 = b["ax"] - NW // 2, b["ay"] mx = (x1 + x2) // 2 d = f"M{x1},{y1} C{mx},{y1} {mx},{y2} {x2},{y2}" else: x1, y1 = a["ax"], a["ay"] + NH // 2 x2, y2 = b["ax"], b["ay"] - NH // 2 dy = y2 - y1 d = f"M{x1},{y1} C{x1},{int(y1+dy*0.4)} {x2},{int(y1+dy*0.6)} {x2},{y2}" parts.append(f'{tip}') # Module nodes for m in flat: x0, y0 = m["ax"] - NW // 2, m["ay"] - NH // 2 fc = TC.get(m["type"], TC["Other"]) bc = BC.get(m["type"], BC["Other"]) nm_label = e(m["name"]) type_lbl = e(m["type"]) dim_str = f"{m['in_dim']}β†’{m['out_dim']}" if m["in_dim"] != m["out_dim"] else str(m["in_dim"]) tip = e(f"{m['path']}\nType: {m['type']}\nIn: {m['in_dim']} Out: {m['out_dim']}\nParams: {m['params']:,}") parts.append( f'' f'' f'' f'' f'{type_lbl}' f'{nm_label}' f'{dim_str}' f'{tip}') svg_body = "\n".join(parts) # Legend seen_types = list(dict.fromkeys(m["type"] for m in flat)) legend = "".join( f'' f'{e(t)}' for t in seen_types) sampled_note = f" (showing {len(flat)} of {n_total})" if n_total > len(flat) else "" return f'''
⬑ {e(arch)} β€” Architecture Graph
{_fmt(total_p)} params Β· {len(flat)} modules Β· {len(edges)} connections{sampled_note}
Edges drawn where out_dim = in_dim
─── direct   - - - skip/residual
{svg_body}
{legend}
'''


# ─────────────────────────── Main ───────────────────────────
def analyze_model(file):
    """Run the full analysis pipeline on an uploaded model file.

    Args:
        file: Gradio upload — either an object with a .name path attribute
            or a plain path string.

    Returns:
        An 11-tuple matching the UI outputs: (arch_html, fig_overview,
        summary, fig_dist, fig_sizes, fig_heat, fig_mem, fig_depth,
        tree_text, types_text, stats_text).  On error, the first three
        slots carry the message and the rest are empty placeholders.
    """
    # 11 placeholder slots; error paths reuse slots 3+ to pad the tuple.
    empty = ("", None, "", None, None, None, None, None, "", "", "")
    if file is None:
        return ("", None, "Please upload a model file.") + empty[3:]
    try:
        # Gradio may hand us a tempfile wrapper or a bare path string.
        file_path = file.name if hasattr(file, "name") else file
        state_dict, metadata = load_state_dict_safe(file_path)
        if not state_dict:
            return ("", None, "No tensors found.") + empty[3:]
        detections = detect_architectures(state_dict)
        config = infer_model_config(state_dict, detections)
        summary = build_summary(state_dict, metadata, detections, config)
        tree_text = build_layer_tree(state_dict)
        types_text = infer_all_layers(state_dict)
        stats, stats_text = compute_weight_stats(state_dict)
        # gc.collect() after each figure keeps peak memory down on the
        # free CPU Space tier.
        arch_html = build_architecture_html(state_dict, detections, config); gc.collect()
        fig_overview = plot_architecture_overview(state_dict, detections, config); gc.collect()
        fig_dist = plot_distributions(state_dict); gc.collect()
        fig_sizes = plot_module_sizes(state_dict); gc.collect()
        fig_heat = plot_heatmap(stats); gc.collect()
        fig_mem = plot_memory(state_dict); gc.collect()
        fig_depth = plot_depth(state_dict); gc.collect()
        return (arch_html, fig_overview, summary, fig_dist, fig_sizes,
                fig_heat, fig_mem, fig_depth, tree_text, types_text, stats_text)
    except Exception as e:
        # Surface the traceback in the UI rather than crashing the Space.
        err = f"Error:\n{str(e)}\n\n{traceback.format_exc()}"
        return ("", None, err) + empty[3:]


def build_app():
    """Construct and return the Gradio Blocks application (UI + wiring)."""
    # Dark GitHub-style theme tuned via .set() overrides.
    theme = gr.themes.Base(
        primary_hue=gr.themes.colors.blue,
        secondary_hue=gr.themes.colors.purple,
        neutral_hue=gr.themes.colors.gray,
        font=gr.themes.GoogleFont("IBM Plex Sans"),
        font_mono=gr.themes.GoogleFont("IBM Plex Mono"),
    ).set(
        body_background_fill="#0d1117",
        body_background_fill_dark="#0d1117",
        block_background_fill="#161b22",
        block_background_fill_dark="#161b22",
        block_border_color="#30363d",
        block_label_text_color="#c9d1d9",
        input_background_fill="#0d1117",
        input_background_fill_dark="#0d1117",
        button_primary_background_fill="linear-gradient(135deg, #58a6ff 0%, #bc8cff 100%)",
        button_primary_text_color="#ffffff",
    )
    with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="Neural Model Analyzer") as app:
        # NOTE(review): the HTML literals below look extraction-mangled —
        # the <div> markup appears stripped.  Content left byte-for-byte.
        gr.HTML("""
\U0001f52c Neural Model Analyzer
Deep introspection for every model on HuggingFace — 150+ architectures · Free CPU
""")
        with gr.Row():
            with gr.Column(scale=1):
                file_input = gr.File(label="Upload Model",
                                     file_types=[".pth",".pt",".bin",".safetensors",".onnx"],
                                     type="filepath")
                analyze_btn = gr.Button("Analyze Model", variant="primary", size="lg")
                gr.HTML("""
All HF model families supported: BERT, RoBERTa, DeBERTa, ALBERT, ELECTRA, GPT-2, GPT-Neo, LLaMA, Mistral, Mixtral, Phi, Qwen, Falcon, BLOOM, OPT, Mamba, RWKV, T5, BART, Pegasus, MarianMT, NLLB, ViT, DeiT, BEiT, Swin, ConvNeXt, ResNet, EfficientNet, DINOv2, DETR, YOLO, SAM, SegFormer, Mask2Former, CLIP, BLIP, LLaVA, Whisper, Wav2Vec2, HuBERT, MusicGen, Bark, Stable Diffusion, SDXL, FLUX, DiT, ControlNet, LoRA, ESM, PatchTST, GNN, and 100+ more
""")
        # One tab per analysis artifact; components are wired to
        # analyze_model's 11-tuple output below (order must match).
        with gr.Tabs():
            with gr.Tab("πŸ•Έ Architecture Graph"):
                arch_html_out = gr.HTML(label="Architecture Graph")
            with gr.Tab("πŸ”­ Model Overview"):
                overview_plot = gr.Plot(label="Architecture Overview")
            with gr.Tab("Summary & Architecture"):
                summary_out = gr.Textbox(label="Full Analysis", lines=40, max_lines=5000, interactive=False)
            with gr.Tab("Layer Tree"):
                tree_out = gr.Textbox(label="Hierarchy", lines=40, max_lines=5000, interactive=False)
            with gr.Tab("Types & Connections"):
                types_out = gr.Textbox(label="Types + Connections", lines=40, max_lines=5000, interactive=False)
            with gr.Tab("Weight Stats"):
                stats_out = gr.Textbox(label="Statistics", lines=40, max_lines=5000, interactive=False)
            with gr.Tab("Distributions"):
                dist_plot = gr.Plot(label="Weight Distributions")
            with gr.Tab("Module Sizes"):
                sizes_plot = gr.Plot(label="Module Parameters")
            with gr.Tab("Heatmap"):
                heat_plot = gr.Plot(label="Stats Heatmap")
            with gr.Tab("Memory"):
                mem_plot = gr.Plot(label="Memory")
            with gr.Tab("Depth Profile"):
                depth_plot = gr.Plot(label="Depth")
        analyze_btn.click(fn=analyze_model, inputs=[file_input],
                          outputs=[arch_html_out, overview_plot, summary_out,
                                   dist_plot, sizes_plot, heat_plot, mem_plot,
                                   depth_plot, tree_out, types_out, stats_out])
    return app


if __name__ == "__main__":
    app = build_app()
    app.launch()