| """ |
| π¬ Neural Model Analyzer β Advanced Model Introspection Tool |
| Supports every model on Hugging Face. Runs on free CPU Spaces. |
| """ |
|
|
| import gradio as gr |
| import torch |
| import os |
| import gc |
| import numpy as np |
| import re |
| from collections import OrderedDict, defaultdict, Counter |
| from pathlib import Path |
| import traceback |
|
|
| try: |
| from safetensors import safe_open |
| HAS_SAFETENSORS = True |
| except ImportError: |
| HAS_SAFETENSORS = False |
|
|
| try: |
| import onnx |
| from onnx import numpy_helper |
| HAS_ONNX = True |
| except ImportError: |
| HAS_ONNX = False |
|
|
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| from matplotlib.colors import LinearSegmentedColormap |
| import warnings |
| warnings.filterwarnings("ignore") |
|
|
| from architecture_detector import detect_architectures, infer_model_config, format_detection_report |
|
|
| |
|
|
# Dark, GitHub-flavored theme injected into the Gradio app shell.
CUSTOM_CSS = """
.gradio-container { max-width: 1400px !important; }
.main-header {
    text-align: center; padding: 24px 10px 6px 10px;
    background: linear-gradient(135deg, #0d1117 0%, #161b22 100%);
    border-radius: 16px; margin-bottom: 10px;
    border: 1px solid #30363d;
}
.main-header h1 {
    font-size: 2.4em; font-weight: 800; margin: 0;
    background: linear-gradient(135deg, #58a6ff 0%, #bc8cff 50%, #f778ba 100%);
    -webkit-background-clip: text; -webkit-text-fill-color: transparent;
}
.main-header p { color: #8b949e; font-size: 0.95em; margin: 6px 0 0 0; }
.info-panel {
    background: #0d1117; border-left: 3px solid #58a6ff;
    padding: 10px 14px; border-radius: 0 8px 8px 0;
    margin: 6px 0; font-size: 0.88em; color: #c9d1d9;
}
/* Ensure all textareas are fully scrollable */
textarea {
    overflow-y: auto !important;
    resize: vertical !important;
}
footer { display: none !important; }
"""


# Hard cap on uploaded checkpoint size (MB); guards the free-CPU-tier memory.
MAX_FILE_SIZE_MB = 4096
|
|
|
|
| def _sizeof_fmt(num_bytes): |
| for unit in ["B", "KB", "MB", "GB"]: |
| if abs(num_bytes) < 1024: |
| return f"{num_bytes:.1f} {unit}" |
| num_bytes /= 1024 |
| return f"{num_bytes:.1f} TB" |
|
|
|
|
| def _fmt(n): |
| if n >= 1e9: return f"{n / 1e9:.2f}B" |
| if n >= 1e6: return f"{n / 1e6:.2f}M" |
| if n >= 1e3: return f"{n / 1e3:.1f}K" |
| return str(n) |
|
|
|
|
| |
|
|
def load_state_dict_safe(file_path: str) -> tuple:
    """Load a checkpoint file into a flat ``OrderedDict`` of CPU tensors.

    Supported formats: ``.pth``/``.pt``/``.bin`` (PyTorch pickles),
    ``.safetensors`` and ``.onnx``.

    Returns:
        (state_dict, metadata): ``state_dict`` maps parameter names to
        tensors; ``metadata`` records file/format info plus anything else
        found in the checkpoint (epoch, optimizer presence, embedded
        config, ONNX graph I/O, ...).

    Raises:
        ValueError: unsupported extension, or file exceeds MAX_FILE_SIZE_MB.
        ImportError: the optional safetensors/onnx dependency is missing.
    """
    ext = Path(file_path).suffix.lower()
    file_size = os.path.getsize(file_path)
    metadata = {
        "format": ext,
        "file_size_bytes": file_size,
        "file_size_str": _sizeof_fmt(file_size),
        "file_name": Path(file_path).name,
    }
    state_dict = OrderedDict()

    # Reject oversized files before reading anything into memory.
    if file_size > MAX_FILE_SIZE_MB * 1024 * 1024:
        raise ValueError(
            f"File is {_sizeof_fmt(file_size)} β exceeds {MAX_FILE_SIZE_MB} MB limit for free CPU."
        )

    gc.collect()

    if ext in (".pth", ".pt"):
        # NOTE(review): weights_only=False executes arbitrary pickled code —
        # only safe for trusted uploads. Falls back to the safe loader on error.
        try:
            data = torch.load(file_path, map_location="cpu", weights_only=False)
        except Exception:
            data = torch.load(file_path, map_location="cpu", weights_only=True)

        if isinstance(data, torch.nn.Module):
            # A fully serialized module: keep its repr and extract the weights.
            metadata["has_full_model"] = True
            metadata["model_repr"] = repr(data)[:5000]
            state_dict = OrderedDict(data.state_dict())
        elif isinstance(data, dict):
            # Unwrap the usual checkpoint wrapper keys used by trainers.
            for ck in ["state_dict", "model_state_dict", "model", "net", "params",
                       "network", "g_ema", "gen", "generator", "discriminator",
                       "ema_model", "module", "teacher", "student"]:
                if ck in data:
                    candidate = data[ck]
                    if isinstance(candidate, dict):
                        state_dict = OrderedDict(candidate)
                        metadata["checkpoint_key_used"] = ck
                        break
                    elif isinstance(candidate, torch.nn.Module):
                        state_dict = OrderedDict(candidate.state_dict())
                        metadata["checkpoint_key_used"] = ck
                        metadata["has_full_model"] = True
                        break

            # No wrapper key matched: collect loose tensors, or tensors one
            # level down inside sub-dicts.
            if not state_dict:
                tensors = {k: v for k, v in data.items() if isinstance(v, torch.Tensor)}
                if tensors:
                    state_dict = OrderedDict(tensors)
                else:
                    for top_k, top_v in data.items():
                        if isinstance(top_v, dict):
                            sub = {f"{top_k}.{k}": v for k, v in top_v.items()
                                   if isinstance(v, torch.Tensor)}
                            state_dict.update(sub)

            # Record non-weight checkpoint payload (training bookkeeping).
            extra = [k for k in data.keys()
                     if not isinstance(data[k], (torch.Tensor, torch.nn.Module, dict))]
            if extra:
                metadata["extra_checkpoint_keys"] = extra[:20]
            for ik in ["epoch", "global_step", "step", "iteration", "best_metric", "best_loss"]:
                if ik in data:
                    metadata[ik] = data[ik]
            if any(k in data for k in ["optimizer", "optimizer_state_dict", "optim"]):
                metadata["has_optimizer_state"] = True
            if any(k in data for k in ["scheduler", "lr_scheduler"]):
                metadata["has_lr_scheduler"] = True
            if "config" in data:
                metadata["embedded_config"] = str(data["config"])[:3000]
            if "args" in data:
                metadata["training_args"] = str(data["args"])[:2000]

            # Drop the raw checkpoint early to keep peak memory down.
            del data; gc.collect()
        else:
            metadata["note"] = f"File contains {type(data).__name__}."

    elif ext == ".bin":
        # NOTE(review): weights_only=False again — trusted input assumed.
        data = torch.load(file_path, map_location="cpu", weights_only=False)
        if isinstance(data, dict):
            state_dict = OrderedDict({k: v for k, v in data.items() if isinstance(v, torch.Tensor)})
            metadata["format_note"] = "PyTorch .bin (HuggingFace style)"
        del data; gc.collect()

    elif ext in (".safetensors",):
        if not HAS_SAFETENSORS:
            raise ImportError("safetensors library not installed")
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key)
            # File-level metadata is optional; ignore failures.
            try:
                sf_meta = f.metadata()
                if sf_meta:
                    metadata["safetensors_metadata"] = dict(sf_meta)
            except Exception:
                pass
        metadata["format_note"] = "SafeTensors"

    elif ext == ".onnx":
        if not HAS_ONNX:
            raise ImportError("onnx library not installed")
        model = onnx.load(file_path)
        # Initializers are the graph's stored weights; copy() detaches the
        # numpy array from the protobuf buffer before wrapping in a tensor.
        for init in model.graph.initializer:
            arr = numpy_helper.to_array(init)
            state_dict[init.name] = torch.from_numpy(arr.copy())
        metadata["format_note"] = "ONNX"
        metadata["onnx_opset"] = [op.version for op in model.opset_import]
        metadata["onnx_ir_version"] = model.ir_version
        # Dynamic dims have dim_value == 0; fall back to dim_param or "?".
        metadata["onnx_inputs"] = [
            {"name": inp.name,
             "shape": [d.dim_value if d.dim_value else (d.dim_param or "?")
                       for d in inp.type.tensor_type.shape.dim]}
            for inp in model.graph.input
        ]
        metadata["onnx_outputs"] = [
            {"name": out.name,
             "shape": [d.dim_value if d.dim_value else (d.dim_param or "?")
                       for d in out.type.tensor_type.shape.dim]}
            for out in model.graph.output
        ]
        op_counts = Counter(node.op_type for node in model.graph.node)
        metadata["onnx_ops"] = dict(op_counts.most_common(30))
        del model; gc.collect()

    else:
        raise ValueError(f"Unsupported: '{ext}'. Use .pth .pt .bin .safetensors .onnx")

    # Strip DataParallel-style prefixes so keys are comparable across runs.
    cleaned = OrderedDict()
    for k, v in state_dict.items():
        cleaned[re.sub(r"^(module\.|model\.module\.)", "", k)] = v
    state_dict = cleaned

    return state_dict, metadata
|
|
|
|
| |
|
|
def build_summary(state_dict, metadata, detections, config) -> str:
    """Assemble the main plain-text analysis report.

    Combines file info, parameter totals, the architecture-detection report
    (from architecture_detector), ONNX graph I/O when present, and
    heuristically inferred model inputs/outputs.
    """
    total_params = sum(t.numel() for t in state_dict.values())
    total_bytes = sum(t.numel() * t.element_size() for t in state_dict.values())
    dtypes = sorted(set(str(t.dtype) for t in state_dict.values()))

    # Banner
    s = "+" + "=" * 63 + "+\n"
    s += "| \U0001f52c NEURAL MODEL ANALYSIS REPORT |\n"
    s += "+" + "=" * 63 + "+\n\n"

    # File-level facts collected by load_state_dict_safe().
    s += "\U0001f4c1 FILE INFO\n"
    s += f" File name: {metadata.get('file_name', '?')}\n"
    s += f" File format: {metadata.get('format', '?')} {metadata.get('format_note', '')}\n"
    s += f" File size: {metadata.get('file_size_str', '?')}\n"
    if metadata.get("checkpoint_key_used"):
        s += f" Checkpoint key: '{metadata['checkpoint_key_used']}'\n"
    if metadata.get("has_full_model"):
        s += f" Full model: Yes (nn.Module)\n"
    if metadata.get("has_optimizer_state"):
        s += f" Optimizer: Included\n"
    if metadata.get("has_lr_scheduler"):
        s += f" LR Scheduler: Included\n"
    for k in ["epoch", "global_step", "step", "best_metric", "best_loss"]:
        if k in metadata:
            s += f" {k}: {metadata[k]}\n"

    s += f"\n\U0001f4ca PARAMETER OVERVIEW\n"
    s += f" Total parameters: {_fmt(total_params)} ({total_params:,})\n"
    s += f" Weight memory: {_sizeof_fmt(total_bytes)}\n"
    s += f" Number of tensors: {len(state_dict):,}\n"
    s += f" Data types: {', '.join(dtypes)}\n"

    # Rough trainable estimate: everything except BatchNorm running stats.
    trainable = sum(t.numel() for k, t in state_dict.items()
                    if "running_" not in k and "num_batches" not in k)
    non_train = total_params - trainable
    if non_train > 0:
        s += f" Trainable (est.): {_fmt(trainable)}\n"
        s += f" Non-trainable: {_fmt(non_train)}\n"

    # Architecture detection report comes from the project-local detector.
    s += f"\n{format_detection_report(detections, config)}"

    # ONNX graphs carry explicit I/O signatures; print them verbatim.
    if metadata.get("onnx_inputs"):
        s += "\n\U0001f4e5 ONNX INPUTS\n"
        for inp in metadata["onnx_inputs"]:
            s += f" {inp['name']}: {inp['shape']}\n"
    if metadata.get("onnx_outputs"):
        s += "\n\U0001f4e4 ONNX OUTPUTS\n"
        for out in metadata["onnx_outputs"]:
            s += f" {out['name']}: {out['shape']}\n"
    if metadata.get("onnx_ops"):
        s += "\n\U0001f527 ONNX OPS (top 15)\n"
        for op, cnt in list(metadata["onnx_ops"].items())[:15]:
            s += f" {op}: {cnt}\n"

    # Otherwise, guess the input from the first embedding/conv weight found.
    if not metadata.get("onnx_inputs"):
        s += "\n\U0001f4e5 INFERRED INPUT\n"
        for k, t in state_dict.items():
            lk = k.lower()
            if any(x in lk for x in ["embed_tokens.weight", "word_embed", "wte.weight",
                                     "patch_embed.proj.weight", "conv1.weight",
                                     "feature_extractor.conv"]):
                if t.ndim == 2:
                    # 2D embedding table => token-ID input.
                    s += f" Token input: vocab={t.shape[0]:,}, dim={t.shape[1]}\n"
                    s += f" Format: [batch, seq_len] (token IDs)\n"
                elif t.ndim == 4:
                    # 4D conv kernel => image input; dim 1 is channel count.
                    s += f" Image input: {t.shape[1]} channels, kernel {t.shape[2]}x{t.shape[3]}\n"
                    ch = t.shape[1]
                    s += f" Format: [batch, {ch}, H, W] ({'RGB' if ch == 3 else 'grayscale' if ch == 1 else f'{ch}-ch'})\n"
                break

        # Guess the output head by scanning keys from the end of the dict.
        s += "\n\U0001f4e4 INFERRED OUTPUT\n"
        for k in reversed(list(state_dict.keys())):
            t = state_dict[k]
            lk = k.lower()
            if any(x in lk for x in ["lm_head.weight", "classifier.weight", "cls.predictions",
                                     "qa_outputs.weight", "score.weight", "head.weight"]):
                if t.ndim == 2:
                    s += f" Output: {k}\n"
                    s += f" Shape: [{t.shape[0]:,} classes/tokens, {t.shape[1]:,} hidden]\n"
                elif t.ndim == 1:
                    s += f" Output bias: {k} -> {t.shape[0]:,} outputs\n"
                break

    if metadata.get("embedded_config"):
        s += f"\n\U0001f4cb EMBEDDED CONFIG\n {metadata['embedded_config'][:1200]}\n"
    if metadata.get("model_repr"):
        s += f"\n\U0001f3d7 MODEL REPR (nn.Module)\n{metadata['model_repr'][:3000]}\n"

    return s
|
|
|
|
def build_layer_tree(state_dict) -> str:
    """Render the state-dict key hierarchy as a Unicode tree.

    Keys are split on "." to build a nested module tree; leaf lines show
    tensor shape, dtype and size. Fix: the box-drawing connectors were
    encoding-corrupted mojibake ("βββ"); restored to proper
    "├── "/"└── "/"│   " glyphs.
    """
    def _count(tree):
        # Recursively sum parameter counts below a subtree node.
        total = 0
        for v in tree.values():
            if isinstance(v, dict) and "numel" in v:
                total += v["numel"]
            elif isinstance(v, dict):
                total += _count(v)
        return total

    # Build the nested tree: intermediate keys become OrderedDict nodes,
    # the final key component becomes a leaf record.
    tree = OrderedDict()
    for key, tensor in state_dict.items():
        parts = key.split(".")
        node = tree
        for part in parts[:-1]:
            if part not in node:
                node[part] = OrderedDict()
            node = node[part]
        node[parts[-1]] = {
            "shape": list(tensor.shape), "dtype": str(tensor.dtype),
            "numel": tensor.numel(),
            "mb": tensor.numel() * tensor.element_size() / 1048576,
        }

    lines = []

    def _render(subtree, prefix="", depth=0):
        items = list(subtree.items())
        for i, (key, val) in enumerate(items):
            last = (i == len(items) - 1)
            # Proper box-drawing connectors (previously mojibake-corrupted).
            conn = "└── " if last else "├── "
            ext = "    " if last else "│   "
            if isinstance(val, dict) and "shape" in val:
                sh = "x".join(map(str, val["shape"]))
                sz = f"{val['mb']:.3f} MB" if val["mb"] >= 0.001 else f"{val['numel']} el"
                lines.append(f"{prefix}{conn}\U0001f539 {key} [{sh}] {val['dtype']} ({sz})")
            elif isinstance(val, dict):
                cnt = _count(val)
                cnt_s = f" ({_fmt(cnt)} params)" if cnt > 0 else ""
                lines.append(f"{prefix}{conn}\U0001f4e6 {key}{cnt_s}")
                _render(val, prefix + ext, depth + 1)

    _render(tree)
    return "\n".join(lines) if lines else "No hierarchy found."
|
|
|
|
def infer_all_layers(state_dict) -> str:
    """Heuristically classify every module in the state dict.

    Groups keys by parent module, guesses a layer type from tensor rank and
    name substrings, then reports a type distribution, a detailed table,
    and simple dimension-based connectivity estimates.

    Fix: removed a redundant ternary (``shape[1] if ndim == 2 else
    shape[1]`` — both branches were identical).
    """
    layers = OrderedDict()
    seen = set()

    for key, tensor in state_dict.items():
        parts = key.rsplit(".", 1)
        mod = parts[0] if len(parts) == 2 else ""
        param = parts[-1]
        shape = tensor.shape
        ndim = tensor.ndim
        ml = mod.lower()

        # Only the first parameter of each module determines its type.
        # NOTE(review): all dot-free keys share mod == "", so only the first
        # of them gets classified — confirm this is intended.
        if mod in seen:
            continue
        seen.add(mod)

        ltype = "Unknown"
        details = {}

        if param == "weight":
            if ndim == 4:
                ltype = "Conv2d"
                details = {"out": shape[0], "in": shape[1], "k": f"{shape[2]}x{shape[3]}"}
                if shape[1] == 1:
                    ltype = "DepthwiseConv2d"
                if any(x in ml for x in ["transpose", "deconv"]):
                    ltype = "ConvTranspose2d"
            elif ndim == 3:
                ltype = "Conv1d"
                details = {"out": shape[0], "in": shape[1], "k": shape[2]}
            elif ndim == 5:
                ltype = "Conv3d"
                details = {"out": shape[0], "in": shape[1]}
            elif ndim == 2:
                # 2D weights: disambiguate by common naming conventions.
                if any(x in ml for x in ["embed", "token", "wte", "wpe"]):
                    ltype = "Embedding"
                    details = {"vocab": shape[0], "dim": shape[1]}
                elif any(x in ml for x in ["norm", "ln_", "layernorm", "rmsnorm"]):
                    ltype = "Norm"
                    details = {"features": shape[0]}
                elif any(x in ml for x in ["q_proj", "k_proj", "v_proj", "qkv", "query", "key", "value"]):
                    ltype = "Attention Projection"
                    details = {"out": shape[0], "in": shape[1]}
                elif any(x in ml for x in ["o_proj", "out_proj"]) and "attn" in ml:
                    ltype = "Attention Output"
                    details = {"out": shape[0], "in": shape[1]}
                elif any(x in ml for x in ["gate_proj", "up_proj", "wi", "fc1", "c_fc", "intermediate"]):
                    ltype = "FFN Up/Gate"
                    details = {"out": shape[0], "in": shape[1]}
                elif any(x in ml for x in ["down_proj", "wo", "fc2", "c_proj"]) and "attn" not in ml:
                    ltype = "FFN Down"
                    details = {"out": shape[0], "in": shape[1]}
                elif any(x in ml for x in ["lm_head", "classifier", "cls", "score", "qa_output"]):
                    ltype = "Head / Classifier"
                    details = {"classes": shape[0], "hidden": shape[1]}
                else:
                    ltype = "Linear"
                    details = {"out": shape[0], "in": shape[1]}
            elif ndim == 1:
                if any(x in ml for x in ["norm", "bn", "ln"]):
                    ltype = "Norm"
                    details = {"features": shape[0]}
                else:
                    ltype = "Bias/Scale"
                    details = {"dim": shape[0]}
        elif param == "bias":
            ltype = "Bias"
            details = {"dim": shape[0]}
        elif "running_mean" in param:
            ltype = "BatchNorm stats"
        elif "A_log" in param or param == "D":
            ltype = "Mamba SSM"
        elif "inv_freq" in param:
            ltype = "Rotary inv_freq"

        layers[mod] = {"type": ltype, "details": details}

    # Type distribution with a capped bar chart.
    text = "\U0001f9e9 INFERRED LAYER TYPES\n\n"
    tc = Counter(v["type"] for v in layers.values())
    text += "DISTRIBUTION:\n"
    for lt, cnt in tc.most_common(25):
        bar = "\u2588" * min(cnt, 40)
        text += f"  {lt:<30} {cnt:>5} {bar}\n"

    # Per-module table (first 150 modules).
    text += f"\nDETAILED ({len(layers)} modules):\n"
    text += f"{'Module':<60} {'Type':<25} {'Details'}\n"
    text += "\u2500" * 115 + "\n"
    for path, info in list(layers.items())[:150]:
        short = path if len(path) < 58 else "\u2026" + path[-56:]
        det = ", ".join(f"{k}={v}" for k, v in info["details"].items())
        text += f"{short:<60} {info['type']:<25} {det}\n"
    if len(layers) > 150:
        text += f"\n  \u2026 and {len(layers) - 150} more\n"

    # Heuristic connectivity: compare consecutive weight matrices' dims.
    text += "\n\n\U0001f4e1 LAYER CONNECTIONS\n" + "\u2500" * 60 + "\n"
    io_dims = []
    for key, tensor in state_dict.items():
        parts = key.rsplit(".", 1)
        mod = parts[0] if len(parts) == 2 else ""
        if parts[-1] == "weight" and tensor.ndim >= 2:
            # (module, fan_in, fan_out) — dim 1 is fan-in for 2D+ weights.
            io_dims.append((mod, tensor.shape[1], tensor.shape[0]))

    direct = reshape = 0
    for i in range(len(io_dims) - 1):
        if io_dims[i][2] == io_dims[i + 1][1]:
            direct += 1
        else:
            reshape += 1

    # Skip/residual candidates: non-adjacent layers (up to 11 apart) with
    # matching dims, deduplicated by top-level module pair.
    residual = []
    seen_p = set()
    for i in range(len(io_dims)):
        for j in range(i + 2, min(i + 12, len(io_dims))):
            if io_dims[i][2] == io_dims[j][1]:
                pair = (io_dims[i][0].split(".")[0], io_dims[j][0].split(".")[0])
                if pair not in seen_p:
                    residual.append((io_dims[i][0], io_dims[j][0], io_dims[i][2]))
                    seen_p.add(pair)

    text += f"  Direct (dim match): {direct}\n"
    text += f"  Reshape (dim change): {reshape}\n"
    text += f"  Possible skip/residual: {len(residual)}\n\n"
    if residual:
        text += "  SKIP/RESIDUAL CANDIDATES:\n"
        for src, dst, dim in residual[:30]:
            text += f"    {src} ...({dim})...> {dst}\n"
        if len(residual) > 30:
            text += f"    \u2026 and {len(residual) - 30} more\n"

    return text
|
|
|
|
def compute_weight_stats(state_dict) -> tuple:
    """Per-tensor statistics plus a formatted health report.

    Returns:
        (stats, text): ``stats`` is a list of dicts (one per non-empty
        tensor) with basic moments, sparsity, L2 norm and flagged issues;
        ``text`` is a printable summary table.
    """
    stats = []
    for name, tensor in state_dict.items():
        if tensor.numel() == 0:
            continue  # nothing to measure
        as_f = tensor.float()
        n_el = tensor.numel()
        record = {
            "name": name,
            "shape": list(tensor.shape),
            "dtype": str(tensor.dtype),
            "numel": n_el,
            "mb": n_el * tensor.element_size() / 1048576,
            "mean": as_f.mean().item(),
            "std": as_f.std().item() if n_el > 1 else 0.0,
            "min": as_f.min().item(),
            "max": as_f.max().item(),
            "abs_mean": as_f.abs().mean().item(),
            "sparsity": (as_f == 0).float().mean().item() * 100,
            "l2": as_f.norm(2).item(),
        }
        # Flag anomalies that usually indicate corrupt or degenerate weights.
        flags = []
        if record["std"] < 1e-7 and record["numel"] > 10:
            flags.append("Near-zero var")
        if record["sparsity"] > 90:
            flags.append(">90% sparse")
        if abs(record["mean"]) > 10:
            flags.append("Large mean")
        if record["std"] > 100:
            flags.append("High std")
        if np.isnan(record["mean"]) or np.isinf(record["mean"]):
            flags.append("NaN/Inf!")
        record["issues"] = flags
        stats.append(record)

    issue_count = sum(len(r["issues"]) for r in stats)
    text = "\u2696 WEIGHT STATISTICS\n\n"
    if issue_count == 0:
        text += "All weights healthy.\n\n"
    else:
        text += f"{issue_count} issues:\n"
        for r in stats:
            for flag in r["issues"]:
                text += f" {r['name']}: {flag}\n"
        text += "\n"

    # Fixed-width table over the first 120 tensors.
    text += f"{'Tensor':<55} {'Shape':<18} {'Params':>10} {'Mean':>10} {'Std':>10} {'Sparse':>8}\n"
    text += "\u2500" * 115 + "\n"
    for r in stats[:120]:
        label = r["name"] if len(r["name"]) < 53 else "\u2026" + r["name"][-52:]
        dims = "x".join(map(str, r["shape"]))
        if len(dims) > 16:
            dims = dims[:13] + "\u2026"
        text += (f"{label:<55} {dims:<18} {_fmt(r['numel']):>10} "
                 f"{r['mean']:>10.5f} {r['std']:>10.5f} {r['sparsity']:>7.1f}%\n")
    if len(stats) > 120:
        text += f"\n\u2026 and {len(stats) - 120} more\n"

    return stats, text
|
|
|
|
| |
|
|
| def _style_ax(ax): |
| ax.set_facecolor("#161b22") |
| ax.tick_params(colors="#8b949e", labelsize=7) |
| for sp in ax.spines.values(): sp.set_color("#30363d") |
|
|
|
|
def plot_distributions(state_dict, max_n=20):
    """Histogram the value distributions of up to ``max_n`` tensors.

    Prefers 2D+ tensors with >100 elements; falls back to anything with
    >10 elements, then to a placeholder figure when nothing qualifies.
    """
    candidates = [(k, v) for k, v in state_dict.items() if v.numel() > 100 and v.ndim >= 2]
    if not candidates:
        candidates = [(k, v) for k, v in state_dict.items() if v.numel() > 10]
    if not candidates:
        fig, ax = plt.subplots(figsize=(10, 3)); fig.patch.set_facecolor("#0d1117")
        ax.text(0.5, 0.5, "No layers for histogram", ha="center", va="center", color="#8b949e"); ax.axis("off")
        return fig
    # Evenly subsample across depth when there are too many tensors.
    if len(candidates) > max_n:
        idx = np.linspace(0, len(candidates)-1, max_n, dtype=int)
        candidates = [candidates[i] for i in idx]
    n = len(candidates); cols = min(4, n); rows = (n + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(4.5*cols, 3*rows))
    fig.patch.set_facecolor("#0d1117")
    # Normalize axes to a 2D array so (row, col) indexing always works.
    if n == 1: axes = np.array([[axes]])
    axes = np.atleast_2d(axes)
    cmap = plt.cm.plasma
    for idx, (key, tensor) in enumerate(candidates):
        r, c = divmod(idx, cols); ax = axes[r, c]; _style_ax(ax)
        data = tensor.float().flatten().numpy()
        # Clip to the 1st-99th percentile so outliers don't flatten the plot.
        q1, q99 = np.percentile(data, [1, 99])
        clipped = data[(data >= q1) & (data <= q99)]
        ax.hist(clipped, bins=70, color=cmap(idx/max(n-1,1)), alpha=0.85, edgecolor="none", density=True)
        ax.axvline(0, color="#ff6b6b", ls="--", alpha=0.4, lw=0.7)
        short = ".".join(key.split(".")[-2:])
        if len(short) > 30: short = "\u2026" + short[-27:]
        ax.set_title(short, fontsize=7.5, color="#c9d1d9", pad=3)
    # Hide any unused grid cells.
    for idx in range(n, rows*cols):
        r, c = divmod(idx, cols); axes[r, c].axis("off")
    fig.suptitle("Weight Distributions", fontsize=13, color="#c9d1d9", y=1.01)
    plt.tight_layout(); return fig
|
|
|
|
def plot_module_sizes(state_dict):
    """Horizontal bar chart of parameter counts per top-level module (top 30)."""
    mod_params = defaultdict(int)
    # Aggregate by the first dotted key component (top-level module name).
    for k, t in state_dict.items(): mod_params[k.split(".")[0]] += t.numel()
    sorted_m = sorted(mod_params.items(), key=lambda x: -x[1])[:30]
    if not sorted_m:
        fig, ax = plt.subplots(figsize=(8,3)); fig.patch.set_facecolor("#0d1117")
        ax.text(0.5,0.5,"No modules",ha="center",va="center",color="#8b949e"); ax.axis("off"); return fig
    names = [m[0] for m in sorted_m]; counts = [m[1] for m in sorted_m]
    fig, ax = plt.subplots(figsize=(12, max(3, len(names)*0.35))); fig.patch.set_facecolor("#0d1117"); _style_ax(ax)
    colors = plt.cm.viridis(np.linspace(0.2, 0.9, len(names)))
    bars = ax.barh(range(len(names)), counts, color=colors, edgecolor="none", height=0.7)
    ax.set_yticks(range(len(names))); ax.set_yticklabels(names, fontsize=8.5, color="#c9d1d9"); ax.invert_yaxis()
    ax.set_xlabel("Parameters", color="#c9d1d9"); ax.set_title("Params by Module", color="#c9d1d9", fontsize=12)
    # Annotate each bar with a compact count just past its end.
    for bar, cnt in zip(bars, counts):
        ax.text(bar.get_width()+max(counts)*0.01, bar.get_y()+bar.get_height()/2, _fmt(cnt), va="center", fontsize=7.5, color="#8b949e")
    plt.tight_layout(); return fig
|
|
|
|
def plot_heatmap(stats):
    """Per-layer heatmap of mean/std/abs-mean/sparsity, min-max normalized per column.

    ``stats`` is the list produced by compute_weight_stats().
    """
    # Prefer actual weight tensors with enough elements; fall back to all.
    ws = [s for s in stats if s["numel"] > 10 and "weight" in s["name"]]
    if not ws: ws = stats
    if len(ws) < 2:
        fig, ax = plt.subplots(figsize=(8,3)); fig.patch.set_facecolor("#0d1117")
        ax.text(0.5,0.5,"Not enough layers",ha="center",va="center",color="#8b949e"); ax.axis("off"); return fig
    # Subsample evenly to at most 50 rows.
    if len(ws) > 50:
        idx = np.linspace(0, len(ws)-1, 50, dtype=int); ws = [ws[i] for i in idx]
    metrics = ["mean","std","abs_mean","sparsity"]; labels = ["Mean","Std","Abs Mean","Sparsity%"]
    mat = np.array([[s[m] for m in metrics] for s in ws])
    # Min-max normalize each metric column independently.
    for j in range(mat.shape[1]):
        rng = mat[:,j].max()-mat[:,j].min()
        if rng > 0: mat[:,j] = (mat[:,j]-mat[:,j].min())/rng
    fig, ax = plt.subplots(figsize=(8, max(4, len(ws)*0.28))); fig.patch.set_facecolor("#0d1117"); _style_ax(ax)
    cmap = LinearSegmentedColormap.from_list("c", ["#0d1117","#238636","#f0e68c","#da3633"])
    im = ax.imshow(mat, aspect="auto", cmap=cmap, interpolation="nearest")
    names = [".".join(s["name"].split(".")[-3:])[-35:] for s in ws]
    ax.set_yticks(range(len(names))); ax.set_yticklabels(names, fontsize=6, color="#c9d1d9")
    ax.set_xticks(range(len(labels))); ax.set_xticklabels(labels, fontsize=9, color="#c9d1d9")
    ax.set_title("Stats Heatmap", color="#c9d1d9", fontsize=12); plt.colorbar(im, ax=ax, shrink=0.8)
    plt.tight_layout(); return fig
|
|
|
|
def plot_memory(state_dict):
    """Two-panel memory view: pie by dtype, bar chart of the 12 largest tensors."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5)); fig.patch.set_facecolor("#0d1117")
    # Aggregate memory (MB) per tensor dtype.
    dtype_mem = defaultdict(float)
    for t in state_dict.values(): dtype_mem[str(t.dtype)] += t.numel()*t.element_size()/1048576
    labels = list(dtype_mem.keys()); sizes = list(dtype_mem.values())
    colors = plt.cm.Set2(np.linspace(0,1,max(len(labels),1)))
    ax1.set_facecolor("#0d1117")
    if labels: ax1.pie(sizes, labels=labels, autopct="%1.1f%%", colors=colors, textprops={"color":"#c9d1d9","fontsize":9})
    ax1.set_title("Memory by dtype", color="#c9d1d9", fontsize=12)
    # Right panel: the 12 largest tensors by element count.
    top = sorted(state_dict.items(), key=lambda x:-x[1].numel())[:12]
    names = [".".join(k.split(".")[-2:])[-28:] for k,v in top]
    mem = [v.numel()*v.element_size()/1048576 for k,v in top]
    _style_ax(ax2); colors2 = plt.cm.cool(np.linspace(0.2,0.8,len(names)))
    bars = ax2.barh(range(len(names)), mem, color=colors2, edgecolor="none")
    ax2.set_yticks(range(len(names))); ax2.set_yticklabels(names, fontsize=7.5, color="#c9d1d9"); ax2.invert_yaxis()
    ax2.set_xlabel("MB", color="#c9d1d9"); ax2.set_title("Top 12 Largest", color="#c9d1d9", fontsize=12)
    for bar, m in zip(bars, mem):
        ax2.text(bar.get_width()+max(mem)*0.02, bar.get_y()+bar.get_height()/2, f"{m:.2f} MB", va="center", fontsize=7, color="#8b949e")
    plt.tight_layout(); return fig
|
|
|
|
def plot_depth(state_dict):
    """Bar chart of parameter counts per weight tensor, in state-dict order (proxy for depth)."""
    sizes = []; names = []
    for k, t in state_dict.items():
        # Only 2D+ "weight" tensors; biases and stats are skipped.
        if "weight" in k and t.ndim >= 2:
            sizes.append(t.numel())
            short = ".".join(k.split(".")[-2:])
            names.append(short if len(short) < 30 else "\u2026" + short[-27:])
    if len(sizes) < 2:
        fig, ax = plt.subplots(figsize=(10,3)); fig.patch.set_facecolor("#0d1117")
        ax.text(0.5,0.5,"Not enough layers",ha="center",va="center",color="#8b949e"); ax.axis("off"); return fig
    fig, ax = plt.subplots(figsize=(14, 4.5)); fig.patch.set_facecolor("#0d1117"); _style_ax(ax)
    colors = plt.cm.plasma(np.linspace(0.1, 0.9, len(sizes)))
    ax.bar(range(len(sizes)), sizes, color=colors, edgecolor="none", width=0.8)
    ax.set_xlabel("Layer index", color="#c9d1d9"); ax.set_ylabel("Params", color="#c9d1d9")
    ax.set_title("Parameter Count Across Depth", color="#c9d1d9", fontsize=12)
    # Thin the x labels to at most ~25 ticks.
    step = max(1, len(names)//25)
    ax.set_xticks(list(range(len(names)))[::step])
    ax.set_xticklabels(names[::step], rotation=90, fontsize=5.5, color="#8b949e")
    plt.tight_layout(); return fig
|
|
|
|
| |
|
|
def _rbox(ax, cx, cy, w, h, label, fc, ec="#ffffff20", fs: int = 9, fw="bold", sub=None, tc="#ffffff"):
    """Draw a rounded box centered at (cx, cy) with a label and optional sub-label."""
    from matplotlib.patches import FancyBboxPatch
    box = FancyBboxPatch(
        (cx - w/2, cy - h/2), w, h,
        boxstyle="round,pad=0.07", fc=fc, ec=ec, lw=0.9, zorder=3, clip_on=False,
    )
    ax.add_patch(box)
    if not sub:
        # Single centered label.
        ax.text(cx, cy, label, ha="center", va="center",
                color=tc, fontsize=fs, fontweight=fw, zorder=4)
        return
    # Label above center, dimmer sub-label below.
    ax.text(cx, cy + 0.12, label, ha="center", va="center",
            color=tc, fontsize=fs, fontweight=fw, zorder=4)
    ax.text(cx, cy - 0.22, sub, ha="center", va="center",
            color="#ffffffaa", fontsize=max(int(fs) - 2, 6), zorder=4)
|
|
|
|
| def _arr(ax, x1, y1, x2, y2, color="#58a6ff", lw=1.5): |
| ax.annotate("", xy=(x2, y2), xytext=(x1, y1), |
| arrowprops=dict(arrowstyle="-|>", color=color, lw=lw, mutation_scale=12), zorder=2) |
|
|
|
|
def _draw_transformer_overview(arch, n_heads, n_layers, hidden, ff_dim,
                               vocab, head_dim, max_pos, gqa, n_kv, total_p):
    """Draw a schematic multi-head-attention overview figure.

    Layout: left, a stacked transformer-block column plus an expanded
    Q/K/V projection diagram; top-right, a synthetic (seeded random)
    attention heatmap for illustration only; bottom-right, a config table.

    NOTE(review): several label strings (e.g. "Γ", "β¦", "β") appear
    encoding-corrupted; preserved byte-for-byte here — fix the source
    encoding separately.
    """
    # Draw at most 4 per-head boxes regardless of the real head count.
    show_h = min(max(n_heads, 1), 4)

    fig = plt.figure(figsize=(18, 11))
    fig.patch.set_facecolor("#0d1117")

    # Main diagram axis plus two side panels (heatmap, config table).
    ax = fig.add_axes((0.01, 0.06, 0.70, 0.92))
    ax_h = fig.add_axes((0.73, 0.58, 0.25, 0.36))
    ax_c = fig.add_axes((0.73, 0.06, 0.25, 0.48))

    for a in [ax, ax_c]:
        a.set_facecolor("#0d1117"); a.axis("off")
    ax.set_xlim(0, 16); ax.set_ylim(0, 11)
    ax_c.set_xlim(0, 1); ax_c.set_ylim(0, 1)
    ax_h.set_facecolor("#161b22")
    for sp in ax_h.spines.values(): sp.set_color("#30363d")
    ax_h.tick_params(colors="#8b949e", labelsize=6)

    # Palette: Q/K/V rows, head output, W^O, final MH-A box.
    CQ = "#e67e22"; CK = "#c0392b"; CV = "#2980b9"
    CH = "#16a085"; CO = "#8e44ad"; CA = "#27ae60"

    from matplotlib.patches import FancyBboxPatch

    # Left column: outline of the repeated transformer block.
    sx = 0.95
    outer = FancyBboxPatch((0.05, 0.85), 1.8, 9.3,
                           boxstyle="round,pad=0.1", fc="#161b22", ec="#30363d", lw=1.5, zorder=1)
    ax.add_patch(outer)
    ax.text(sx, 10.4, f"Transformer Block Γ{n_layers}",
            ha="center", va="center", color="#8b949e", fontsize=7.5, style="italic", zorder=4)

    # Sub-blocks inside the column, listed top (9.5) to bottom (1.0).
    sb = [
        (9.5, "Add &\nNorm", "#2d3748"),
        (8.2, "Feed\nForward", "#1a4f7a"),
        (6.9, "Add &\nNorm", "#2d3748"),
        (5.3, "Multi-Head\nAttention", "#1e3a5f"),
        (3.5, "Pos\nEncoding", "#1a3a2a"),
        (2.1, "Token\nEmbedding", "#2d3748"),
        (1.0, "Input", "#1a2332"),
    ]
    for sy, slbl, sc in sb:
        _rbox(ax, sx, sy, 1.55, 0.75, slbl, sc, fs=8, fw="bold")
    for i in range(len(sb) - 1):
        _arr(ax, sx, sb[i+1][0]+0.38, sx, sb[i][0]-0.38, color="#58a6ff60", lw=1.0)

    # Residual connection arrow bypassing the attention block.
    ax.annotate("", xy=(0.15, 6.9+0.38), xytext=(0.15, 5.3-0.38),
                arrowprops=dict(arrowstyle="-|>", color="#bc8cff", lw=1.2,
                                connectionstyle="arc3,rad=-0.4"), zorder=2)
    ax.text(0.07, 6.1, "+", ha="center", va="center",
            color="#bc8cff", fontsize=11, fontweight="bold")

    # Q/K/V rows: input x weight = projection, then per-head boxes.
    rows = [
        (8.8, "Q", CQ),
        (5.8, "K", CK),
        (2.8, "V", CV),
    ]
    # X coordinates: input box, "x" sign, weight box, "=" sign, projection.
    XIN = 2.8; XMUL = 3.65; XW = 4.5; XEQ = 5.35; XPRI = 6.15
    HHSTART = 7.3; HGAP = 0.82  # first head-box x position and spacing

    for row_y, lbl, color in rows:
        _rbox(ax, XIN, row_y, 0.95, 0.68, lbl, color,
              sub=f"(seq,{hidden})", fs=12, fw="bold")
        ax.text(XMUL, row_y, "Γ", ha="center", va="center",
                color="#c9d1d9", fontsize=14, fontweight="bold", zorder=4)
        _rbox(ax, XW, row_y, 1.1, 0.68, f"W$^{{{lbl}}}$", color,
              sub=f"({hidden},{hidden})", fs=11, fw="bold")
        ax.text(XEQ, row_y, "=", ha="center", va="center",
                color="#c9d1d9", fontsize=14, fontweight="bold", zorder=4)
        _rbox(ax, XPRI, row_y, 0.95, 0.68, f"{lbl}'", color,
              sub=f"(seq,{hidden})", fs=11, fw="bold")
        _arr(ax, XPRI+0.48, row_y, HHSTART-0.20, row_y, color=color, lw=1.5)
        for hi in range(show_h):
            hx = HHSTART + hi * HGAP
            _rbox(ax, hx, row_y, 0.65, 0.62, f"{lbl}{hi+1}", color, fs=9)
        # Ellipsis marker when heads were truncated to show_h.
        if n_heads > show_h:
            ax.text(HHSTART + show_h * HGAP - 0.05, row_y, "β¦",
                    ha="center", va="center", color=color, fontsize=12, zorder=4)

    # d_k annotation above the Q head row.
    dk_x = HHSTART + (show_h - 1) * HGAP / 2
    ax.text(dk_x, rows[0][0] + 0.68, f"d_k = {head_dim}",
            ha="center", va="bottom", color="#f778ba", fontsize=8, zorder=4,
            bbox=dict(boxstyle="round,pad=0.2", fc="#0d1117", ec="#f778ba60", lw=0.8))

    # d_model span bracket under the V head row.
    hend_x = HHSTART + (show_h - 1) * HGAP
    ax.annotate("", xy=(hend_x+0.33, rows[2][0]-0.68),
                xytext=(HHSTART-0.33, rows[2][0]-0.68),
                arrowprops=dict(arrowstyle="<->", color="#8b949e", lw=0.8))
    ax.text(dk_x, rows[2][0]-0.84, f"d_model = {hidden}",
            ha="center", va="top", color="#8b949e", fontsize=7.5)

    # Faint vertical connectors Q->K->V within each head column.
    for hi in range(show_h):
        hx = HHSTART + hi * HGAP
        ax.annotate("", xy=(hx, rows[1][0]+0.31), xytext=(hx, rows[0][0]-0.31),
                    arrowprops=dict(arrowstyle="-|>", color="#ffffff22", lw=0.8), zorder=2)
        ax.annotate("", xy=(hx, rows[2][0]+0.31), xytext=(hx, rows[1][0]-0.31),
                    arrowprops=dict(arrowstyle="-|>", color="#ffffff22", lw=0.8), zorder=2)

    # Scaled-dot-product formula between the K and V rows.
    fml_x = HHSTART + (show_h - 1) * HGAP / 2
    fml_y = (rows[0][0] + rows[2][0]) / 2
    ax.text(fml_x, fml_y, f"softmax(QKα΅ / β{head_dim})",
            ha="center", va="center", color="#f0e68c", fontsize=9, style="italic", zorder=4,
            bbox=dict(boxstyle="round,pad=0.3", fc="#14101e", ec="#f0e68c50", lw=0.8))

    # Head outputs fan into H, then multiply by W^O to give MH-A.
    out_y = 1.3
    H_x = HHSTART + (show_h - 1) * HGAP / 2
    for hi in range(show_h):
        hx = HHSTART + hi * HGAP
        ax.annotate("", xy=(H_x + (hi - show_h/2 + 0.5) * 0.22, out_y+0.40),
                    xytext=(hx, rows[2][0]-0.31),
                    arrowprops=dict(arrowstyle="-|>", color=CV, lw=0.9), zorder=2)

    WO_x = H_x + 1.75
    MHA_x = WO_x + 1.85
    _rbox(ax, H_x, out_y, 1.1, 0.65, "H", CH, sub=f"(seq,{hidden})", fs=12, fw="bold")
    ax.text(WO_x-0.68, out_y, "Γ", ha="center", va="center",
            color="#c9d1d9", fontsize=14, fontweight="bold", zorder=4)
    _rbox(ax, WO_x, out_y, 1.0, 0.65, "W$^O$", CO, sub=f"({hidden},{hidden})", fs=11, fw="bold")
    ax.text(WO_x+0.73, out_y, "=", ha="center", va="center",
            color="#c9d1d9", fontsize=14, fontweight="bold", zorder=4)
    _rbox(ax, MHA_x, out_y, 1.35, 0.65, "MH-A", CA, sub=f"(seq,{hidden})", fs=11, fw="bold")
    _arr(ax, H_x+0.55, out_y, WO_x-0.50, out_y, color=CH, lw=1.5)
    _arr(ax, WO_x+0.50, out_y, MHA_x-0.68, out_y, color=CO, lw=1.5)

    # Formula captions along the bottom edge.
    ax.text(7.8, 0.60, "Attention(Q,K,V) = softmax(QKα΅/βd_k) Β· V",
            ha="center", va="center", color="#c9d1d9", fontsize=9, style="italic", zorder=4)
    ax.text(7.8, 0.22, f"MultiHead = Concat(headβ β¦ head_{n_heads}) Β· W\u1D3C",
            ha="center", va="center", color="#c9d1d9", fontsize=9, style="italic", zorder=4)

    # Title: architecture name and total parameter count.
    ax.text(8.0, 10.75, arch, ha="center", va="center",
            color="#58a6ff", fontsize=13, fontweight="bold", zorder=4)
    ax.text(8.0, 10.35, f"{_fmt(total_p)} params",
            ha="center", va="center", color="#8b949e", fontsize=8.5, zorder=4)

    # Synthetic attention heatmap: fixed-seed dirichlet rows blended with an
    # identity diagonal, row-normalized. Purely illustrative, not model data.
    seq_d = 8
    np.random.seed(7)
    attn = np.random.dirichlet([1.5] * seq_d, seq_d)
    attn = 0.45 * np.eye(seq_d) + 0.55 * attn
    attn /= attn.sum(axis=1, keepdims=True)
    ax_h.imshow(attn, cmap="YlOrRd", aspect="auto", interpolation="bilinear", vmin=0)
    ax_h.set_title("Attention Visualization", color="#c9d1d9", fontsize=9, pad=4)

    # Config table panel (at most 10 rows shown).
    ax_c.text(0.5, 0.97, "β Model Config", ha="center", va="top",
              color="#58a6ff", fontsize=10, fontweight="bold")
    ax_c.axhline(y=0.93, color="#30363d", lw=0.8)
    info = [
        ("Architecture", arch.split("(")[0].strip()),
        ("Parameters", _fmt(total_p)),
        ("Layers", str(n_layers)),
        ("Hidden size", str(hidden)),
        ("Attn heads", f"{n_heads}" + (" (GQA)" if gqa else "")),
        ("Head dim", str(head_dim)),
        ("FF dim", _fmt(ff_dim) if ff_dim else "β"),
        ("Vocab size", f"{vocab:,}" if vocab else "β"),
        ("Max position", str(max_pos) if max_pos else "β"),
    ]
    if gqa and n_kv:
        info.insert(5, ("KV heads", str(n_kv)))
    for i, (k, v) in enumerate(info[:10]):
        y = 0.89 - i * 0.087
        ax_c.text(0.03, y, k + ":", ha="left", va="center",
                  color="#8b949e", fontsize=8.5)
        ax_c.text(0.97, y, v, ha="right", va="center",
                  color="#e6edf3", fontsize=8.5, fontweight="bold")

    plt.tight_layout(pad=0.2)
    return fig
|
|
|
|
def _draw_cnn_overview(state_dict, arch, config, total_p):
    """Render a left-to-right block diagram for CNN / vision checkpoints.

    Each top-level module of the state dict becomes one colored block
    (color keyed on common name substrings such as "conv", "pool", "fc"),
    chained by arrows from an Input box to an Output box, with a color
    legend underneath.

    Args:
        state_dict: mapping of parameter name -> tensor.
        arch: detected architecture family, used as the figure title.
        config: inferred model config; unused here but kept so all
            _draw_* helpers share the same signature.
        total_p: total parameter count across all tensors.

    Returns:
        A matplotlib Figure styled for the dark UI theme.
    """
    fig = plt.figure(figsize=(18, 8))
    fig.patch.set_facecolor("#0d1117")
    ax = fig.add_axes((0.02, 0.14, 0.96, 0.80))
    ax.set_facecolor("#0d1117"); ax.axis("off")
    ax.set_xlim(0, 18); ax.set_ylim(0, 8)

    # Aggregate parameter counts per top-level module (first dotted part);
    # only the first 12 modules, in state-dict order, are drawn.
    top_mods = defaultdict(int)
    for k, t in state_dict.items():
        top_mods[k.split(".")[0]] += t.numel()
    blocks = list(top_mods.items())[:12]
    if not blocks:
        # Placeholder pipeline when the state dict yields no modules.
        blocks = [("conv1", 0), ("layer1", 0), ("fc", 0)]

    # Block fill color by substring match (first hit wins).
    colors_map = {"conv": "#2980b9", "layer": "#1a5276", "block": "#1a5276",
                  "stage": "#1a5276", "res": "#1a5276", "pool": "#27ae60",
                  "fc": "#8e44ad", "head": "#8e44ad", "classifier": "#8e44ad",
                  "norm": "#d68910", "bn": "#d68910"}

    # Short role label shown under each block (first hit wins; order matters
    # and intentionally differs from colors_map priority).
    role_labels = [("conv", "Conv"), ("layer", "Block"), ("pool", "Pool"),
                   ("fc", "Linear"), ("head", "Head"), ("norm", "Norm"),
                   ("res", "ResBlock"), ("stage", "Stage")]

    def _block_color(name):
        nl = name.lower()
        for key, col in colors_map.items():
            if key in nl:
                return col
        return "#2d3748"

    def _block_role(name):
        nl = name.lower()
        return next((lbl for key, lbl in role_labels if key in nl), "Module")

    # Title and parameter count.
    ax.text(9, 7.6, arch, ha="center", va="center",
            color="#58a6ff", fontsize=14, fontweight="bold")
    ax.text(9, 7.1, f"{_fmt(total_p)} parameters β’ CNN / Vision Architecture",
            ha="center", va="center", color="#8b949e", fontsize=9)

    # Main pipeline: input box, one block per module, output box.
    n = len(blocks)
    spacing = 15.0 / max(n, 1)
    _rbox(ax, 0.8, 4.5, 0.85, 0.70, "Input\nImage", "#1a3a2a", fs=8)
    prev_x = 1.23
    for i, (name, params) in enumerate(blocks):
        bx = 1.5 + i * spacing + spacing * 0.5
        _rbox(ax, bx, 4.5, max(spacing * 0.82, 0.9), 0.72, name,
              _block_color(name), sub=_fmt(params) if params else "", fs=9, fw="bold")
        _arr(ax, prev_x, 4.5, bx - spacing * 0.41 - 0.02, 4.5, color="#58a6ff", lw=1.5)
        # BUGFIX: the original prefixed this label with `_block_color(name) and`,
        # which is always truthy (the helper never returns a falsy value) and
        # therefore a dead, misleading expression. The label lookup itself is
        # unchanged, just factored into _block_role.
        ax.text(bx, 3.9, _block_role(name),
                ha="center", va="center", color="#8b949e", fontsize=7)
        prev_x = bx + spacing * 0.41
    _rbox(ax, prev_x + 0.6, 4.5, 1.0, 0.70, "Output", "#27ae60", fs=9)
    _arr(ax, prev_x, 4.5, prev_x + 0.1, 4.5, color="#58a6ff", lw=1.5)

    # Legend row mapping colors to layer roles.
    legend = [("Conv/Feature", "#2980b9"), ("Layer/Block", "#1a5276"),
              ("Pool/Down", "#27ae60"), ("FC/Head", "#8e44ad"), ("Norm", "#d68910")]
    for i, (lbl, lc) in enumerate(legend):
        _rbox(ax, 2.0 + i * 3.3, 2.2, 2.4, 0.50, lbl, lc, fs=8)
    ax.text(9, 1.5, "Layer type legend", ha="center", va="center",
            color="#8b949e", fontsize=8)
    plt.tight_layout(pad=0.2)
    return fig
|
|
|
|
def _draw_generic_flow(state_dict, arch, config, total_p):
    """Draw a generic left-to-right module flow for models with no
    specialized diagram.

    Boxes are sized proportionally to each top-level module's share of the
    parameters (ten largest modules shown); the five largest individual
    tensors are listed underneath with their shapes.

    Args:
        state_dict: mapping of parameter name -> tensor.
        arch: detected architecture family, used as the figure title.
        config: inferred model config; unused, kept for signature parity.
        total_p: total parameter count across all tensors.

    Returns:
        A matplotlib Figure styled for the dark UI theme.
    """
    fig = plt.figure(figsize=(18, 9))
    fig.patch.set_facecolor("#0d1117")
    ax = fig.add_axes((0.02, 0.08, 0.96, 0.85))
    ax.set_facecolor("#0d1117"); ax.axis("off")
    ax.set_xlim(0, 18); ax.set_ylim(0, 9)

    # Parameter count per top-level module, largest first, capped at 10.
    per_module = defaultdict(int)
    for key, tensor in state_dict.items():
        per_module[key.split(".")[0]] += tensor.numel()
    sorted_mods = sorted(per_module.items(), key=lambda kv: -kv[1])[:10]

    # Title and total parameter count.
    ax.text(9, 8.5, arch, ha="center", va="center",
            color="#58a6ff", fontsize=14, fontweight="bold")
    ax.text(9, 8.0, f"{_fmt(total_p)} parameters",
            ha="center", va="center", color="#8b949e", fontsize=9)

    palette = ["#2980b9", "#e67e22", "#27ae60", "#8e44ad", "#c0392b",
               "#1abc9c", "#d68910", "#2ecc71", "#3498db", "#9b59b6"]
    denom = max(sum(p for _, p in sorted_mods), 1)
    cursor = 0.5
    for idx, (name, params) in enumerate(sorted_mods):
        # Box width proportional to parameter share, floored for readability.
        width = max(0.9, (params / denom) * 15)
        _rbox(ax, cursor + width / 2, 5.5, width - 0.12, 1.1, name,
              palette[idx % len(palette)], sub=_fmt(params), fs=9, fw="bold")
        if idx > 0:
            _arr(ax, cursor, 5.5, cursor + 0.06, 5.5, color="#58a6ff", lw=1.5)
        cursor += width

    # List the five largest individual tensors below the flow.
    biggest = sorted(state_dict.items(), key=lambda kv: -kv[1].numel())[:5]
    ax.text(9, 3.8, "Largest tensors:", ha="center", va="center",
            color="#58a6ff", fontsize=9, fontweight="bold")
    for idx, (key, tensor) in enumerate(biggest):
        shape_s = "Γ".join(map(str, tensor.shape))
        trimmed = key if len(key) < 48 else "β¦" + key[-46:]
        ax.text(9, 3.3 - idx * 0.48, f"{trimmed} [{shape_s}] {_fmt(tensor.numel())}",
                ha="center", va="center", color="#c9d1d9", fontsize=8)

    plt.tight_layout(pad=0.2)
    return fig
|
|
|
|
def plot_architecture_overview(state_dict, detections, config):
    """Dispatch to the most appropriate overview diagram for this model.

    Routing is keyword-based over the detected family + category string:
    CNN/vision keywords select the block diagram, diffusion/GAN keywords
    select the generic flow, and everything else falls through to the
    transformer diagram.

    Args:
        state_dict: mapping of parameter name -> tensor.
        detections: ranked architecture detections; [0] supplies the
            family/category (falls back to "Unknown Model").
        config: inferred model config dict; missing keys get defaults.

    Returns:
        A matplotlib Figure from one of the _draw_* helpers.
    """
    if detections:
        arch, category = detections[0]["family"], detections[0]["category"]
    else:
        arch, category = "Unknown Model", ""

    hidden = config.get("likely_hidden_size", 768)
    n_layers = config.get("num_layers", 12)
    n_heads = config.get("num_attention_heads", 8)
    head_dim = config.get("head_dim", hidden // max(n_heads, 1))
    ff_dim = config.get("intermediate_size", hidden * 4)
    vocab = config.get("vocab_size", 0)
    max_pos = config.get("max_position_embeddings", 0)
    gqa = config.get("grouped_query_attention", False)
    n_kv = config.get("num_kv_heads", n_heads)
    total_p = sum(t.numel() for t in state_dict.values())

    # Keyword routing over the combined, lowercased family + category label.
    label = (arch + " " + category).lower()
    cnn_keys = ("resnet", "convnext", "efficientnet", "mobilenet", "vgg",
                "densenet", "yolo", "convnet", "regnet", "detr", "convolution")
    diff_keys = ("stable diffusion", "unet", "sdxl", "flux", "dit",
                 "pixart", "vqgan", "stylegan", "esrgan")

    if any(key in label for key in cnn_keys):
        return _draw_cnn_overview(state_dict, arch, config, total_p)
    if any(key in label for key in diff_keys):
        return _draw_generic_flow(state_dict, arch, config, total_p)
    return _draw_transformer_overview(arch, n_heads, n_layers, hidden, ff_dim,
                                      vocab, head_dim, max_pos, gqa, n_kv, total_p)
|
|
|
|
| |
|
|
def build_architecture_html(state_dict, detections, config):
    """
    Generate an interactive SVG/HTML architecture graph.
    Connections are inferred purely from weight dimension matching;
    no hardcoded architecture checks. Works for any model.

    Args:
        state_dict: mapping of parameter name -> tensor (needs .shape,
            .ndim and .numel(), i.e. torch-like tensors).
        detections: ranked architecture detections; detections[0]["family"]
            becomes the graph title (falls back to "Unknown").
        config: inferred model config; accepted for signature parity with
            the other builders, not read here.

    Returns:
        A self-contained HTML string: header, scrollable SVG graph, legend.
    """
    import html as _hl
    arch = detections[0]["family"] if detections else "Unknown"
    total_p = sum(t.numel() for t in state_dict.values())

    # Fill (TC) and border (BC) colors, keyed by inferred layer type.
    TC = {"Embedding":"#1b4332","Q Proj":"#7c3400","K Proj":"#7c1010",
          "V Proj":"#0d3b6e","Attn Out":"#0a2a5e","FFN Up":"#2d0f57",
          "FFN Down":"#57103a","Norm":"#2a3441","Conv2d":"#004d5e",
          "Conv1d":"#003d4a","Linear":"#1b3047","Output":"#14451f","Other":"#1c2328"}
    BC = {"Embedding":"#40c070","Q Proj":"#fb8c00","K Proj":"#ef5350",
          "V Proj":"#42a5f5","Attn Out":"#1976d2","FFN Up":"#ab47bc",
          "FFN Down":"#ec407a","Norm":"#78909c","Conv2d":"#00bcd4",
          "Conv1d":"#009688","Linear":"#546e7a","Output":"#43a047","Other":"#546e7a"}

    def _infer(path, tensor):
        # Classify one weight tensor from its dotted path and shape.
        # Returns (layer_type, in_dim, out_dim, detail) or (None,)*4 when
        # the tensor is not a module weight worth drawing.
        lk, s, nd = path.lower(), tensor.shape, tensor.ndim
        if nd == 4: return "Conv2d", int(s[1]), int(s[0]), f"k{s[2]}Γ{s[3]}"
        if nd == 3: return "Conv1d", int(s[1]), int(s[0]), f"k{s[2]}"
        if nd == 2:
            # 2-D weights are read as (out_features, in_features), the
            # torch Linear layout: columns -> in, rows -> out.
            i, o = int(s[1]), int(s[0])
            end = path.rsplit(".", 1)[-1].lower()
            if any(x in lk for x in ["embed","token","wte","wpe","word_embed","tok_embed"]): return "Embedding", i, o, ""
            if "q_proj" in lk or ("query" in lk and end == "weight"): return "Q Proj", i, o, ""
            if "k_proj" in lk or ("key" in lk and end == "weight"): return "K Proj", i, o, ""
            if "v_proj" in lk or ("value" in lk and end == "weight"): return "V Proj", i, o, ""
            if ("o_proj" in lk or "out_proj" in lk) and any(x in lk for x in ["attn","self","mha"]): return "Attn Out", i, o, ""
            if any(x in lk for x in ["gate_proj","up_proj","fc1","wi_0","c_fc","intermediate"]): return "FFN Up", i, o, ""
            if any(x in lk for x in ["down_proj","fc2","c_proj"]) and "attn" not in lk: return "FFN Down", i, o, ""
            if any(x in lk for x in ["lm_head","classifier","cls","score","head.weight"]): return "Output", i, o, ""
            if any(x in lk for x in ["norm","ln_","rms","layer_norm","ln_f"]): return "Norm", i, o, ""
            return "Linear", i, o, ""
        if nd == 1 and any(x in path.lower() for x in ["norm","bn","ln","rms"]):
            return "Norm", int(s[0]), int(s[0]), ""
        return None, None, None, None

    # Collect one node per module path, keyed off its main weight tensor
    # (biases and other parameters of the same module are skipped).
    mods, seen = [], set()
    for key, tensor in state_dict.items():
        parts = key.rsplit(".", 1)
        path = parts[0] if len(parts) == 2 else key
        param = parts[-1] if len(parts) == 2 else key
        if param not in ("weight", "W", "kernel") or path in seen:
            continue
        seen.add(path)
        ltype, ind, outd, det = _infer(path, tensor)
        if ltype is None:
            continue
        pparts = path.split(".")
        grp = ".".join(pparts[:-1]) if len(pparts) > 1 else "__root__"
        mods.append({"path": path, "name": pparts[-1], "group": grp,
                     "type": ltype, "in_dim": ind, "out_dim": outd,
                     "params": int(tensor.numel()), "det": det})

    if not mods:
        return "<div style='color:#8b949e;padding:20px;font-family:monospace'>No weight modules found.</div>"

    # Group modules by parent path, preserving state-dict order.
    groups = OrderedDict()
    for m in mods:
        groups.setdefault(m["group"], []).append(m)

    # Sample deep models down to MAX_G rows: always keep the first and last
    # three groups, plus an evenly strided selection of the middle.
    MAX_G = 22
    if len(groups) > MAX_G:
        gl = list(groups.items())
        keep = list(range(min(3, len(gl)))) + list(range(max(len(gl)-3, 3), len(gl)))
        step = max(1, (len(gl) - 6) // (MAX_G - 6))
        for i in range(3, len(gl) - 3, step):
            keep.append(i)
        keep = sorted(set(keep))[:MAX_G]
        groups = OrderedDict(gl[i] for i in keep)

    # Flatten the (possibly sampled) groups and assign sequential node ids.
    flat = []
    for gmods in groups.values():
        for m in gmods:
            m["id"] = len(flat)
            flat.append(m)

    n_total = len(mods)

    # Infer edges from dimension matching within a 30-node lookahead:
    # a.out_dim == b.in_dim. The immediate successor (j == i+1) is the one
    # "direct" edge; any later match is drawn as a "skip"/residual edge.
    # Scanning stops once a skip is found more than 10 nodes away.
    edges, seen_e = [], set()
    for i in range(len(flat)):
        a = flat[i]
        found_direct = False
        for j in range(i + 1, min(i + 30, len(flat))):
            b = flat[j]
            if (i, j) in seen_e or a["out_dim"] <= 0:
                continue
            if a["out_dim"] == b["in_dim"]:
                is_direct = (j == i + 1 and not found_direct)
                etype = "direct" if is_direct else "skip"
                edges.append({"from": i, "to": j, "type": etype, "dim": a["out_dim"]})
                seen_e.add((i, j))
                if is_direct:
                    found_direct = True
                if j - i > 10 and etype == "skip":
                    break

    # Layout constants (SVG px): node width/height, horizontal gap, group
    # padding x/y, group header height, vertical gap between groups.
    NW, NH = 192, 52
    HGAP, GPX, GPY, GHH, GVGAP = 10, 12, 10, 24, 14
    cy = 8
    grects = {}
    for g, gmods in groups.items():
        nm = len(gmods)
        gw = nm * (NW + HGAP) - HGAP + GPX * 2
        gh = GHH + GPY + NH + GPY
        for j, m in enumerate(gmods):
            # Node center (ax/ay) in absolute SVG coordinates.
            m["ax"] = GPX + j * (NW + HGAP) + NW // 2
            m["ay"] = cy + GHH + GPY + NH // 2
        gl = g.rsplit(".", 1)[-1] if "." in g else g
        grects[g] = {"y": cy, "w": gw, "h": gh,
                     "label": (gl if len(gl) <= 38 else "β¦" + gl[-36:]),
                     "full": g}
        cy += gh + GVGAP

    svg_w = max((r["w"] for r in grects.values()), default=600) + 60
    svg_h = cy + 10

    # HTML-escape helper for every value interpolated into the markup.
    def e(s): return _hl.escape(str(s))

    # Arrowhead markers: "ma" (blue) for direct edges, "ms" (purple) for skips.
    parts = ['''<defs>
<marker id="ma" markerWidth="7" markerHeight="7" refX="6" refY="3.5" orient="auto">
<path d="M0,0 L7,3.5 L0,7 z" fill="#58a6ff" opacity="0.9"/>
</marker>
<marker id="ms" markerWidth="7" markerHeight="7" refX="6" refY="3.5" orient="auto">
<path d="M0,0 L7,3.5 L0,7 z" fill="#bc8cff" opacity="0.8"/>
</marker>
</defs>''']

    # Group containers first (drop shadow, panel, label) so nodes and edges
    # render on top of them.
    for g, r in grects.items():
        parts.append(
            f'<rect x="2" y="{r["y"]+2}" width="{r["w"]}" height="{r["h"]}" rx="10" fill="#000" opacity="0.35"/>'
            f'<rect x="0" y="{r["y"]}" width="{r["w"]}" height="{r["h"]}" rx="10" fill="#161b22" stroke="#21262d" stroke-width="1.2"/>'
            f'<text x="10" y="{r["y"]+16}" font-family="monospace" font-size="11" fill="#8b949e">{e(r["label"])}</text>')

    # Edges as cubic Bezier paths: horizontal S-curves within a group,
    # vertical S-curves between groups. Skips are dashed and translucent.
    for ed in edges:
        a, b = flat[ed["from"]], flat[ed["to"]]
        col = "#58a6ff" if ed["type"] == "direct" else "#bc8cff"
        mk = "ma" if ed["type"] == "direct" else "ms"
        op = "0.75" if ed["type"] == "direct" else "0.45"
        dash = "" if ed["type"] == "direct" else 'stroke-dasharray="5,3"'
        tip = e(f"dim={ed['dim']}")
        same_g = a["group"] == b["group"]
        if same_g:
            # Side-to-side: leave the right edge of a, enter the left of b.
            x1, y1 = a["ax"] + NW // 2, a["ay"]
            x2, y2 = b["ax"] - NW // 2, b["ay"]
            mx = (x1 + x2) // 2
            d = f"M{x1},{y1} C{mx},{y1} {mx},{y2} {x2},{y2}"
        else:
            # Top-to-bottom: leave the bottom of a, enter the top of b.
            x1, y1 = a["ax"], a["ay"] + NH // 2
            x2, y2 = b["ax"], b["ay"] - NH // 2
            dy = y2 - y1
            d = f"M{x1},{y1} C{x1},{int(y1+dy*0.4)} {x2},{int(y1+dy*0.6)} {x2},{y2}"
        parts.append(f'<path d="{d}" fill="none" stroke="{col}" stroke-width="1.6" stroke-opacity="{op}" {dash} marker-end="url(#{mk})"><title>{tip}</title></path>')

    # Node boxes: type badge header, module name, dimension footer, and a
    # <title> tooltip with the full path and parameter count.
    for m in flat:
        x0, y0 = m["ax"] - NW // 2, m["ay"] - NH // 2
        fc = TC.get(m["type"], TC["Other"])
        bc = BC.get(m["type"], BC["Other"])
        nm_label = e(m["name"])
        type_lbl = e(m["type"])
        dim_str = f"{m['in_dim']}β{m['out_dim']}" if m["in_dim"] != m["out_dim"] else str(m["in_dim"])
        tip = e(f"{m['path']}\nType: {m['type']}\nIn: {m['in_dim']} Out: {m['out_dim']}\nParams: {m['params']:,}")
        parts.append(
            f'<g transform="translate({x0},{y0})" style="cursor:default">'
            f'<rect width="{NW}" height="{NH}" rx="7" fill="{fc}" stroke="{bc}" stroke-width="1.5"/>'
            f'<rect width="{NW}" height="17" rx="7" fill="{bc}" fill-opacity="0.30"/>'
            f'<rect y="10" width="{NW}" height="7" fill="{bc}" fill-opacity="0.30"/>'
            f'<text x="7" y="13" font-family="monospace" font-size="9" fill="#ffffffbb" font-weight="bold">{type_lbl}</text>'
            f'<text x="{NW//2}" y="33" font-family="monospace" font-size="12" fill="#e6edf3" font-weight="bold" text-anchor="middle">{nm_label}</text>'
            f'<text x="{NW//2}" y="47" font-family="monospace" font-size="10" fill="{bc}" text-anchor="middle">{dim_str}</text>'
            f'<title>{tip}</title></g>')

    svg_body = "\n".join(parts)

    # Legend chips, one per layer type actually present (first-seen order).
    seen_types = list(dict.fromkeys(m["type"] for m in flat))
    legend = "".join(
        f'<span style="display:inline-flex;align-items:center;gap:4px;font-size:10px;'
        f'color:#8b949e;font-family:monospace;margin-right:6px;">'
        f'<span style="width:11px;height:11px;border-radius:3px;background:{BC.get(t,"#546e7a")};'
        f'display:inline-block;border:1px solid {BC.get(t,"#546e7a")}40"></span>{e(t)}</span>'
        for t in seen_types)

    sampled_note = f" (showing {len(flat)} of {n_total})" if n_total > len(flat) else ""

    return f'''<div style="background:#0d1117;border-radius:12px;padding:16px;border:1px solid #30363d;">
<div style="display:flex;justify-content:space-between;align-items:flex-start;margin-bottom:10px;">
<div>
<div style="color:#58a6ff;font-size:14px;font-weight:bold;font-family:monospace;">⬑ {e(arch)} β Architecture Graph</div>
<div style="color:#8b949e;font-size:11px;margin-top:4px;">{_fmt(total_p)} params Β· {len(flat)} modules Β· {len(edges)} connections{sampled_note}</div>
</div>
<div style="color:#8b949e;font-size:10px;text-align:right;line-height:1.6;">
Edges drawn where out_dim = in_dim<br>
βββ direct - - - skip/residual
</div>
</div>
<div style="overflow:auto;border:1px solid #21262d;border-radius:8px;max-height:660px;background:#0a0e14;">
<svg width="{svg_w}" height="{svg_h}" style="display:block;">{svg_body}</svg>
</div>
<div style="margin-top:10px;display:flex;gap:4px;flex-wrap:wrap;">{legend}</div>
</div>'''
|
|
|
|
| |
|
|
def analyze_model(file):
    """Run the full analysis pipeline on an uploaded checkpoint.

    Returns the 11-tuple wired to the Gradio outputs:
    (arch_html, overview_fig, summary_text, dist_fig, sizes_fig, heat_fig,
     mem_fig, depth_fig, tree_text, types_text, stats_text).
    On missing input or any exception, the summary slot carries a message
    and the remaining slots are blanked.
    """
    blank = ("", None, "", None, None, None, None, None, "", "", "")
    if file is None:
        return ("", None, "Please upload a model file.") + blank[3:]

    try:
        path = getattr(file, "name", file)
        state_dict, metadata = load_state_dict_safe(path)
        if not state_dict:
            return ("", None, "No tensors found.") + blank[3:]

        detections = detect_architectures(state_dict)
        config = infer_model_config(state_dict, detections)

        # Text artifacts.
        summary = build_summary(state_dict, metadata, detections, config)
        tree_text = build_layer_tree(state_dict)
        types_text = infer_all_layers(state_dict)
        stats, stats_text = compute_weight_stats(state_dict)

        # Visual artifacts; collect garbage between steps to keep peak
        # memory low on large checkpoints (free CPU Space).
        arch_html = build_architecture_html(state_dict, detections, config)
        gc.collect()
        fig_overview = plot_architecture_overview(state_dict, detections, config)
        gc.collect()
        fig_dist = plot_distributions(state_dict)
        gc.collect()
        fig_sizes = plot_module_sizes(state_dict)
        gc.collect()
        fig_heat = plot_heatmap(stats)
        gc.collect()
        fig_mem = plot_memory(state_dict)
        gc.collect()
        fig_depth = plot_depth(state_dict)
        gc.collect()

        return (arch_html, fig_overview, summary, fig_dist, fig_sizes,
                fig_heat, fig_mem, fig_depth, tree_text, types_text, stats_text)

    except Exception as exc:
        err = f"Error:\n{str(exc)}\n\n{traceback.format_exc()}"
        return ("", None, err) + blank[3:]
|
|
|
|
def build_app():
    """Construct and return the Gradio Blocks application.

    Layout: a dark GitHub-style theme, an upload column with the
    "Analyze Model" button, and one tab per analysis artifact. The
    button's 11 outputs map one-to-one onto the tab components, in the
    exact order analyze_model returns them.
    """
    # Dark theme tuned to match the CUSTOM_CSS palette (GitHub-dark colors).
    theme = gr.themes.Base(
        primary_hue=gr.themes.colors.blue, secondary_hue=gr.themes.colors.purple,
        neutral_hue=gr.themes.colors.gray,
        font=gr.themes.GoogleFont("IBM Plex Sans"),
        font_mono=gr.themes.GoogleFont("IBM Plex Mono"),
    ).set(
        body_background_fill="#0d1117", body_background_fill_dark="#0d1117",
        block_background_fill="#161b22", block_background_fill_dark="#161b22",
        block_border_color="#30363d", block_label_text_color="#c9d1d9",
        input_background_fill="#0d1117", input_background_fill_dark="#0d1117",
        button_primary_background_fill="linear-gradient(135deg, #58a6ff 0%, #bc8cff 100%)",
        button_primary_text_color="#ffffff",
    )

    with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="Neural Model Analyzer") as app:
        # Page header banner.
        gr.HTML("""
        <div class="main-header">
        <h1>\U0001f52c Neural Model Analyzer</h1>
        <p>Deep introspection for <b>every</b> model on HuggingFace — 150+ architectures · Free CPU</p>
        </div>""")

        # Upload controls plus an informational panel of supported families.
        with gr.Row():
            with gr.Column(scale=1):
                file_input = gr.File(label="Upload Model", file_types=[".pth",".pt",".bin",".safetensors",".onnx"], type="filepath")
                analyze_btn = gr.Button("Analyze Model", variant="primary", size="lg")
                gr.HTML("""<div class="info-panel">
        <b>All HF model families supported:</b> BERT, RoBERTa, DeBERTa, ALBERT, ELECTRA,
        GPT-2, GPT-Neo, LLaMA, Mistral, Mixtral, Phi, Qwen, Falcon, BLOOM, OPT, Mamba, RWKV,
        T5, BART, Pegasus, MarianMT, NLLB, ViT, DeiT, BEiT, Swin, ConvNeXt, ResNet,
        EfficientNet, DINOv2, DETR, YOLO, SAM, SegFormer, Mask2Former, CLIP, BLIP, LLaVA,
        Whisper, Wav2Vec2, HuBERT, MusicGen, Bark, Stable Diffusion, SDXL, FLUX, DiT,
        ControlNet, LoRA, ESM, PatchTST, GNN, <b>and 100+ more</b>
        </div>""")

        # One tab per artifact; component order here must match the output
        # list in the click() wiring below.
        with gr.Tabs():
            with gr.Tab("πΈ Architecture Graph"):
                arch_html_out = gr.HTML(label="Architecture Graph")
            with gr.Tab("π Model Overview"):
                overview_plot = gr.Plot(label="Architecture Overview")
            with gr.Tab("Summary & Architecture"):
                summary_out = gr.Textbox(label="Full Analysis", lines=40, max_lines=5000, interactive=False)
            with gr.Tab("Layer Tree"):
                tree_out = gr.Textbox(label="Hierarchy", lines=40, max_lines=5000, interactive=False)
            with gr.Tab("Types & Connections"):
                types_out = gr.Textbox(label="Types + Connections", lines=40, max_lines=5000, interactive=False)
            with gr.Tab("Weight Stats"):
                stats_out = gr.Textbox(label="Statistics", lines=40, max_lines=5000, interactive=False)
            with gr.Tab("Distributions"):
                dist_plot = gr.Plot(label="Weight Distributions")
            with gr.Tab("Module Sizes"):
                sizes_plot = gr.Plot(label="Module Parameters")
            with gr.Tab("Heatmap"):
                heat_plot = gr.Plot(label="Stats Heatmap")
            with gr.Tab("Memory"):
                mem_plot = gr.Plot(label="Memory")
            with gr.Tab("Depth Profile"):
                depth_plot = gr.Plot(label="Depth")

        # Wire the button; output order matches analyze_model's return tuple.
        analyze_btn.click(fn=analyze_model, inputs=[file_input],
                          outputs=[arch_html_out, overview_plot, summary_out, dist_plot, sizes_plot,
                                   heat_plot, mem_plot, depth_plot, tree_out, types_out, stats_out])

    return app
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the Gradio app and serve it.
    build_app().launch()
|
|