Spaces:
Sleeping
Sleeping
// Hierarchical graph layout using dagre. We feed it the IR and produce
// React Flow-compatible Node[] and Edge[] arrays.
//
// Subflows: modules with descendants are rendered as React Flow *group*
// containers that visually wrap their children. The hierarchy is derived
// from the per-node `attrs.scope` string (slash-separated): any scope that
// has at least one descendant in the IR becomes a synthetic container.
// Wrapper IRNodes (`opType: "Module"` with no weights) are absorbed into
// their container so we don't render both `model` and "group:model" side by
// side.
| import dagre from "@dagrejs/dagre"; | |
| import { Position, type Node, type Edge } from "reactflow"; | |
| import type { IRGraph, IRNode } from "../types"; | |
// Default leaf footprint. Calibrated against the visual density of
// hfviewer.com (short pills, 10-11px label, no decoration): larger values
// blow up the canvas, smaller ones make the text unreadable.
const NODE_WIDTH = 110;
const NODE_HEIGHT = 22;

// Footprint used for nodes whose id marks them as a cluster.
const CLUSTER_NODE_WIDTH = 130;
const CLUSTER_NODE_HEIGHT = 26;

/**
 * Synthetic ops (Add, Mul, MatMul, Softmax, RoPE, ...) have no params and a
 * single short label, so they get an extra-compact footprint. This keeps a
 * fully expanded transformer block from spreading over dozens of bulky rows.
 */
const SYNTHETIC_NODE_WIDTH = 70;
const SYNTHETIC_NODE_HEIGHT = 20;

/** Identity/boundary ops are "data flow stubs" and are drawn minimally. */
const IDENTITY_NODE_WIDTH = 44;
const IDENTITY_NODE_HEIGHT = 16;

/**
 * Padding dagre reserves inside a compound parent, so the label band has
 * room and children do not touch the border.
 */
const GROUP_PADDING_X = 8;
const GROUP_PADDING_TOP = 12;
const GROUP_PADDING_BOTTOM = 6;

/** Prefix that turns a scope path into the id of its group container. */
const GROUP_ID_PREFIX = "group:";
/**
 * Op types matched by exact name; consulted before any pattern.
 * "Input"/"Output" are listed here as well: neither literal matches any of
 * the patterns below, so checking them first gives the same result.
 */
const EXACT_OP_COLORS: ReadonlyMap<string, string> = new Map([
  // Semantic super-node types (from grouping.ts).
  ["TransformerLayer", "#5d4eff"],
  ["Attention", "#8a5cff"],
  ["MLP", "#3c8ce6"],
  ["Embedding", "#f0925b"],
  ["Norm", "#a06bd6"],
  ["Linear", "#4f78e0"],
  ["RoPE", "#c25fe6"],
  ["Router", "#ffae5b"],
  ["Expert", "#5e9efb"],
  ["Module", "#506280"],
  // Graph boundary pills.
  ["Input", "#f6a96b"],
  ["Output", "#e25fa3"],
]);

/** ONNX op-level fallback: the first pattern found anywhere in the name wins. */
const OP_COLOR_PATTERNS: ReadonlyArray<readonly [RegExp, string]> = [
  [/MatMul|Gemm|Conv/, "#5b8def"],
  [/Attention|Softmax/, "#7c5cff"],
  [/Norm|LayerNormalization|RMSNormalization/, "#9b6bd1"],
  [/Add|Mul|Sub|Div|Pow/, "#3aa6c9"],
  [/Gelu|Relu|Silu|Swish|Tanh|Sigmoid/, "#22c1a0"],
  [/Gather|Slice|Reshape|Transpose|Squeeze|Unsqueeze|Concat/, "#7a8aa3"],
  [/Cast|Identity|Constant|Quantize|DequantizeLinear/, "#4a5263"],
];

/**
 * Map an op type to its display colour.
 *
 * @param opType Semantic super-node type or raw op name.
 * @returns A hex colour string; "#6477a8" when nothing matches.
 */
function colorForOp(opType: string): string {
  const exact = EXACT_OP_COLORS.get(opType);
  if (exact !== undefined) return exact;
  for (const [pattern, color] of OP_COLOR_PATTERNS) {
    if (pattern.test(opType)) return color;
  }
  return "#6477a8";
}
/** True when the node's id carries the "cluster:" prefix. */
function isClusterNode(n: IRNode): boolean {
  const prefix = "cluster:";
  return n.id.slice(0, prefix.length) === prefix;
}
// ─── Human-friendly label formatting ─────────────────────────────────────
//
// Raw module names like `input_layernorm`, `q_proj`, `embed_tokens`,
// `post_attention_layernorm`, ... read fine in code but are noisy in a
// graph node. The helpers below map the most common transformer module
// suffixes to short, semantically meaningful labels.
/**
 * Short display label for a normalisation module, from the last segment of
 * its scope.
 *
 * Well-known names map to a fixed role ("pre-attn", "post-attn", "pre-ffn",
 * "post-ffn", "final"). Any other name ending in `_layernorm` or `_norm`
 * has that suffix stripped and remaining underscores turned into dashes.
 *
 * @param tail Last scope segment, e.g. `input_layernorm`.
 * @returns The short label, or "" when the name is not recognised.
 */
function humanizeNorm(tail: string): string {
  switch (tail) {
    case "input_layernorm":
    case "ln1":
    case "ln_1": // GPT-2 spelling
      return "pre-attn";
    case "post_attention_layernorm":
    case "ln2":
    case "ln_2": // GPT-2 spelling
      return "post-attn";
    case "pre_feedforward_layernorm":
      return "pre-ffn";
    case "post_feedforward_layernorm":
      return "post-ffn";
    case "norm":
    case "final_layernorm":
    case "ln_f":
      return "final";
  }
  for (const suffix of ["_layernorm", "_norm"]) {
    if (tail.endsWith(suffix)) {
      // Slice off the trailing suffix itself: `String.replace` with a string
      // pattern removes the FIRST occurrence, which is the wrong one when
      // the suffix text also appears earlier in the name.
      return tail.slice(0, -suffix.length).replace(/_/g, "-");
    }
  }
  return "";
}
/**
 * Short display label for a linear projection, from the last segment of its
 * scope.
 *
 * Known attention/MLP projection names map to fixed labels. Any other name
 * ending in `_proj` has that suffix stripped and underscores turned into
 * spaces.
 *
 * @param tail Last scope segment, e.g. `q_proj`.
 * @returns The short label, or "" when the name is not recognised.
 */
function humanizeLinear(tail: string): string {
  switch (tail) {
    case "q_proj":
      return "Q proj";
    case "k_proj":
      return "K proj";
    case "v_proj":
      return "V proj";
    case "o_proj":
      return "O proj";
    case "qkv_proj":
    case "c_attn":
      return "QKV";
    case "gate_proj":
    case "w1":
      return "gate";
    case "up_proj":
    case "w3":
      return "up";
    case "down_proj":
    case "w2":
      return "down";
    case "c_proj":
    case "c_fc":
      return tail.replace(/_/g, " ");
  }
  const suffix = "_proj";
  if (tail.endsWith(suffix)) {
    // Strip the trailing suffix only: `replace("_proj", "")` would remove the
    // first occurrence, which is wrong when "_proj" also appears earlier.
    return tail.slice(0, -suffix.length).replace(/_/g, " ");
  }
  return "";
}
/**
 * Short display label for an embedding table, from the last segment of its
 * scope. Unrecognised names are returned with underscores turned into spaces.
 */
function humanizeEmbedding(tail: string): string {
  switch (tail) {
    case "embed_tokens":
    case "wte":
    case "word_embeddings":
      return "tokens";
    case "wpe":
    case "embed_positions":
    case "position_embeddings":
      return "positions";
    case "token_type_embeddings":
      return "segments";
    default:
      return tail.split("_").join(" ");
  }
}
/**
 * Generic prettifier for a scope segment: all-digit segments become `#<n>`,
 * anything else has underscores turned into spaces. Empty input yields "".
 */
function humanizeGeneric(tail: string): string {
  if (tail) {
    return /^\d+$/.test(tail) ? "#" + tail : tail.split("_").join(" ");
  }
  return "";
}
/**
 * Guess the Python class name behind a GROUP scope from the HF `model_type`
 * in config.json, mirroring the transformers naming convention. For example
 * `model/layers/0` under modelType=llama renders as `LlamaDecoderLayer`,
 * matching the hfviewer.com convention.
 *
 * The class prefix is the model type in PascalCase. Multi-word model types
 * are joined per segment (`qwen2_moe` becomes `Qwen2Moe`), so the separator
 * is not carried into the label.
 *
 * @param scope Slash-separated scope path of the group.
 * @param modelType HF `model_type`; when missing, no name is produced.
 * @returns The guessed class name, or null when the scope has no known role.
 */
function classNameForScope(scope: string, modelType?: string): string | null {
  if (!modelType) return null;
  const cap = modelType
    .split(/[_\-\s]+/)
    .filter(Boolean)
    .map((word) => word.charAt(0).toUpperCase() + word.slice(1).toLowerCase())
    .join("");
  if (!cap) return null;

  const segs = scope.split("/").filter(Boolean);
  const tail = segs[segs.length - 1] ?? "";
  const parent = segs[segs.length - 2] ?? "";

  // A numeric segment directly under a layer list is one transformer layer.
  const underLayerList =
    parent === "layers" || parent === "h" || parent === "block";
  if (/^\d+$/.test(tail) && underLayerList) return `${cap}DecoderLayer`;

  switch (tail) {
    case "self_attn":
    case "attn":
    case "attention":
      return `${cap}Attention`;
    case "mlp":
    case "feed_forward":
    case "ffn":
      return `${cap}MLP`;
    case "encoder":
      return `${cap}Encoder`;
    case "decoder":
      return `${cap}Decoder`;
    default:
      return null;
  }
}
/**
 * Label for the header band of a group/subflow container.
 *
 * The container itself already signals "this is a sub-block", so the label
 * only has to name it: known aliases collapse onto one canonical word, an
 * all-digit segment reads as `layer #<n>`, and anything else is shown with
 * underscores turned into spaces.
 */
function humanizeGroupLabel(tail: string): string {
  if (!tail) return "";
  switch (tail) {
    case "model":
    case "transformer":
    case "encoder":
    case "decoder":
    case "experts":
      return tail;
    case "layers":
    case "h":
    case "block":
      return "layers";
    case "self_attn":
    case "attn":
    case "attention":
      return "attention";
    case "mlp":
    case "feed_forward":
    case "ffn":
      return "mlp";
  }
  const isIndex = /^\d+$/.test(tail);
  return isIndex ? `layer #${tail}` : tail.split("_").join(" ");
}
/** Two-line label for a graph node, as produced by `formatNodeLabel`. */
interface FormattedLabel {
  /** Primary text. */
  label: string;
  /** Secondary text; the empty string when there is none. */
  sublabel: string;
}
/**
 * Guess the Python class name behind a node from the HF `model_type` in
 * config.json, mirroring the transformers naming convention:
 * `Llama` + `DecoderLayer` = `LlamaDecoderLayer`. This is what hfviewer.com
 * shows on its main node labels.
 *
 * The class prefix is the model type in PascalCase. Multi-word model types
 * are joined per segment (`qwen2_moe` becomes `Qwen2Moe`), so the separator
 * is not carried into the label.
 *
 * @param n Node whose scope and op type drive the guess.
 * @param modelType HF `model_type`; when missing, no name is produced.
 * @returns The guessed class name, or null when no rule applies.
 */
function classNameForNode(n: IRNode, modelType?: string): string | null {
  if (!modelType) return null;
  const cap = modelType
    .split(/[_\-\s]+/)
    .filter(Boolean)
    .map((word) => word.charAt(0).toUpperCase() + word.slice(1).toLowerCase())
    .join("");
  if (!cap) return null;

  // Split the scope once and read the last two segments from it.
  const segs = scopeOf(n).split("/").filter(Boolean);
  const tail = segs[segs.length - 1] ?? "";
  const parentTail = segs[segs.length - 2] ?? "";

  // A numeric tail under a `layers`/`h`/`block` parent is a transformer layer.
  const underLayerList =
    parentTail === "layers" || parentTail === "h" || parentTail === "block";
  if (/^\d+$/.test(tail) && underLayerList) return `${cap}DecoderLayer`;

  switch (n.opType) {
    case "TransformerLayer":
      return `${cap}DecoderLayer`;
    case "Attention":
      return `${cap}Attention`;
    case "MLP":
      return `${cap}MLP`;
    case "Norm": {
      // Common naming pattern in modern Llama-family models: LlamaRMSNorm.
      // Compared case-insensitively, like the prefix derivation above.
      const family = modelType.toLowerCase();
      const usesRmsNorm =
        family === "llama" ||
        family === "qwen2" ||
        family === "qwen3" ||
        family === "mistral";
      return usesRmsNorm ? `${cap}RMSNorm` : null;
    }
    default:
      return null;
  }
}
/**
 * Compute the label/sublabel pair shown on a leaf node.
 *
 * Precedence:
 *  1. a preset `attrs.label` of 1-24 characters is used verbatim;
 *  2. a Linear whose scope tail is `lm_head` is shown as "LM head";
 *  3. a Module with an all-digit scope tail is shown under its guessed class
 *     name, falling back to `Layer #<n>`;
 *  4. otherwise the label is picked per op type.
 *
 * @param n Node to label.
 * @param modelType HF `model_type`, used to guess Python class names.
 */
function formatNodeLabel(n: IRNode, modelType?: string): FormattedLabel {
  const only = (label: string): FormattedLabel => ({ label, sublabel: "" });
  // Guessed Python class name, or `fallback` when no rule applies.
  const classOr = (fallback: string): string =>
    classNameForNode(n, modelType) || fallback;

  // Synthetic ops carry an intentional, concise label (e.g. "+", "h_in")
  // that is already optimised for display.
  const preset = n.attrs.label as string | undefined;
  if (preset && preset.length > 0 && preset.length <= 24) return only(preset);

  const kind = n.opType;
  const tail = scopeOf(n).split("/").filter(Boolean).pop() ?? "";

  // The `lm_head` Linear is the model's classification head: treat it as a
  // first-class concept rather than "Linear / lm_head".
  if (kind === "Linear" && tail === "lm_head") return only("LM head");

  // Module wrapper with a numeric tail: prefer the Python class name (e.g.
  // `LlamaDecoderLayer`) over the bare layer index.
  if (kind === "Module" && /^\d+$/.test(tail)) {
    return only(classOr(`Layer #${tail}`));
  }

  switch (kind) {
    case "Embedding":
      return { label: "Embedding", sublabel: humanizeEmbedding(tail) };
    case "Attention":
      return only(classOr("Attention"));
    case "MLP":
      return only(classOr("MLP"));
    case "TransformerLayer":
      return only(classOr("Decoder layer"));
    case "Norm":
      return { label: classOr("Norm"), sublabel: humanizeNorm(tail) };
    case "Linear":
      return { label: "Linear", sublabel: humanizeLinear(tail) || tail };
    case "RoPE":
    case "Router":
      return only(kind);
    case "Expert":
      return { label: "Expert", sublabel: humanizeGeneric(tail) };
    case "Module":
      // Generic wrapper: the scope tail is usually a Python identifier.
      return only(humanizeGeneric(tail) || "Module");
    default: {
      // Show the scope tail as a sublabel unless it only repeats the op type.
      const pretty = humanizeGeneric(tail);
      const repeatsOpType = pretty.toLowerCase() === kind.toLowerCase();
      return { label: kind, sublabel: pretty && !repeatsOpType ? pretty : "" };
    }
  }
}
/** True for the empty scope and for any scope starting with `__io__`. */
function isIoScope(scope: string): boolean {
  if (scope === "") return true;
  return scope.lastIndexOf("__io__", 0) === 0;
}
/**
 * Read a node's scope path from `attrs.scope`, trimmed of surrounding
 * whitespace.
 *
 * @returns The trimmed scope, or "" when the attribute is missing or is not
 *   a string. The `typeof` guard replaces a type assertion that would throw
 *   on `.trim()` for a non-string, non-nullish value.
 */
function scopeOf(n: IRNode): string {
  const raw: unknown = n.attrs.scope;
  return typeof raw === "string" ? raw.trim() : "";
}
/**
 * List every proper ancestor of a scope, shallowest first.
 * `"a/b/c"` yields `["a", "a/b"]`; empty segments are ignored.
 */
function ancestorScopes(scope: string): string[] {
  const parts = scope.split("/").filter((part) => part.length > 0);
  return parts
    .slice(0, -1)
    .map((_, depth) => parts.slice(0, depth + 1).join("/"));
}
/**
 * Scope one level up, with empty segments ignored. Returns "" when the
 * scope has at most one segment.
 */
function parentScope(scope: string): string {
  const parts = scope.split("/").filter((part) => part.length > 0);
  parts.pop();
  return parts.join("/");
}
/** Final non-empty segment of a slash-separated scope, or "" if none. */
function lastSeg(scope: string): string {
  const parts = scope.split("/").filter((part) => part.length > 0);
  return parts.pop() ?? "";
}
/** Visual containment hierarchy computed by `buildHierarchy`. */
interface Hierarchy {
  /** Scope path → group id, for the scopes that ended up as groups. */
  groupForScope: Map<string, string>;
  /**
   * Group id → ids of its direct children: IR node ids (in `ir.nodes`
   * order) followed by the ids of nested groups.
   */
  childrenOf: Map<string, string[]>;
  /** Node or group id → parent group id; absent for top-level entries. */
  parentOf: Map<string, string>;
  /** Group id → text for the container header. */
  groupLabel: Map<string, string>;
  /**
   * Ids of wrapper IR nodes that are represented by a container (or dropped
   * with a flattened structural scope). They must not be emitted as separate
   * React Flow nodes.
   */
  absorbed: Set<string>;
  /**
   * Ids without a parent group: non-absorbed IR nodes in `ir.nodes` order,
   * followed by parentless groups.
   */
  topLevel: string[];
}
/**
 * Build the visual hierarchy for an IR graph.
 *
 *  - A scope P becomes a group iff at least one IR node has a scope strictly
 *    deeper than P, and P's last segment is not a structural wrapper
 *    (`model`, `transformer`, `layers`, `h`, `block`).
 *  - "Trivial" groups are collapsed: a group with no IR node at its exact
 *    scope and at most one direct child (IR node or sub-group) is dropped,
 *    and its sub-groups are promoted to its parent.
 *  - A wrapper IR node (opType "Module", not synthetic, no weights) at the
 *    exact scope of a group is absorbed into the group and supplies its
 *    label. The same kind of node at a structural-wrapper scope is dropped.
 *  - Nodes with an IO scope (see `isIoScope`) are never grouped.
 *
 * @param ir Graph whose `nodes` carry a slash-separated `attrs.scope`.
 * @returns Group ids, parent/child maps, labels, absorbed ids, and the
 *   top-level id list.
 */
function buildHierarchy(ir: IRGraph): Hierarchy {
  // 1. Bucket IR nodes by exact scope, ignoring IO, and collect every scope
  //    that appears either directly or as an ancestor.
  const byScope = new Map<string, IRNode[]>();
  const allCandidateScopes = new Set<string>();
  for (const n of ir.nodes) {
    const s = scopeOf(n);
    if (isIoScope(s)) continue;
    const bucket = byScope.get(s);
    if (bucket) bucket.push(n);
    else byScope.set(s, [n]);
    allCandidateScopes.add(s);
    for (const a of ancestorScopes(s)) allCandidateScopes.add(a);
  }

  // 2. Mark which scopes have descendants (strictly deeper IR nodes).
  const hasDescendants = new Set<string>();
  for (const s of byScope.keys()) {
    for (const a of ancestorScopes(s)) hasDescendants.add(a);
  }

  // 3. Initial group decision: a scope is a group if it has descendants.
  //    (Pure leaf scopes with no children don't need a container.)
  //    EXCEPT for "structural wrapper" scopes: the bare `model`, the
  //    ModuleList container `layers` (a.k.a. `h` or `block` in HF code), or
  //    the global `transformer` root. These carry no semantic information
  //    the user cares about; they only add a level of nesting around every
  //    node. hfviewer.com flattens them by default, and so do we. The
  //    semantic class names (`LlamaDecoderLayer`, `LlamaAttention`,
  //    `LlamaMLP`) come from scopes one level deeper (`layers/0`,
  //    `self_attn`, `mlp`), which DO become groups.
  const STRUCTURAL_WRAPPERS = new Set<string>([
    "model",
    "transformer",
    "layers",
    "h",
    "block",
  ]);
  const isStructuralWrapperScope = (scope: string): boolean =>
    STRUCTURAL_WRAPPERS.has(lastSeg(scope));

  // Iterating `allCandidateScopes` (rather than `hasDescendants`) fixes the
  // insertion order of `isGroup`, which in turn fixes the output order.
  const isGroup = new Set<string>();
  for (const s of allCandidateScopes) {
    if (hasDescendants.has(s) && !isStructuralWrapperScope(s)) isGroup.add(s);
  }

  // Nearest scope at or above `scope` that is currently a group. Reads
  // `isGroup` live, so it reflects collapses performed in step 5.
  const closestContainingGroup = (scope: string): string | null => {
    let cur = scope;
    while (cur !== "") {
      if (isGroup.has(cur)) return cur;
      cur = parentScope(cur);
    }
    return null;
  };

  // 4. For each group, its parent group is the longest proper ancestor
  //    scope that is also a group.
  const parentScopeOfGroup = new Map<string, string | null>();
  for (const g of isGroup) {
    parentScopeOfGroup.set(g, closestContainingGroup(parentScope(g)));
  }

  // 5. Collapse "trivial" groups bottom-up: a group whose only purpose is
  //    to wrap a single child (one sub-group OR one IR node assigned to it,
  //    and no own IR node) is visual noise. We drop it and re-parent its
  //    sub-groups to its parent. This is what makes a wrapper disappear
  //    when it folds down to a single `×N` cluster: the user sees the badge
  //    on the rep node directly instead of through an empty container.
  const sortedDeepFirst = [...isGroup].sort(
    (a, b) => b.split("/").length - a.split("/").length,
  );
  let changed = true;
  while (changed) {
    changed = false;

    // Direct IR children per group: an IR node belongs to its deepest
    // containing group. Sub-groups don't count here.
    const irChildCount = new Map<string, number>();
    for (const n of ir.nodes) {
      const s = scopeOf(n);
      if (isIoScope(s)) continue;
      const owner = closestContainingGroup(s);
      if (!owner) continue;
      irChildCount.set(owner, (irChildCount.get(owner) ?? 0) + 1);
    }

    // Direct sub-groups per group, from the current parentScopeOfGroup.
    const subGroupsOf = new Map<string, string[]>();
    for (const other of isGroup) {
      const p = parentScopeOfGroup.get(other);
      if (p && isGroup.has(p)) {
        const siblings = subGroupsOf.get(p) ?? [];
        siblings.push(other);
        subGroupsOf.set(p, siblings);
      }
    }

    for (const g of sortedDeepFirst) {
      if (!isGroup.has(g)) continue;
      if ((byScope.get(g)?.length ?? 0) > 0) continue; // has its own IR node
      const directIr = irChildCount.get(g) ?? 0;
      const subCount = subGroupsOf.get(g)?.length ?? 0;
      if (directIr + subCount > 1) continue;

      // Trivial: drop g and hand its sub-groups to g's parent. Re-parent
      // against the LIVE parent map, not the snapshot above: a group that
      // was promoted into g earlier in this same pass is missing from the
      // snapshot, and would otherwise keep pointing at the deleted g and
      // end up top-level.
      const myParent = parentScopeOfGroup.get(g) ?? null;
      for (const [child, parent] of parentScopeOfGroup) {
        if (parent === g) parentScopeOfGroup.set(child, myParent);
      }
      isGroup.delete(g);
      parentScopeOfGroup.delete(g);
      changed = true;
    }
  }

  // 6. Build the result.
  const groupForScope = new Map<string, string>();
  const groupLabel = new Map<string, string>();
  for (const g of isGroup) {
    const gid = GROUP_ID_PREFIX + g;
    groupForScope.set(g, gid);
    groupLabel.set(gid, humanizeGroupLabel(lastSeg(g)));
  }

  const parentOf = new Map<string, string>();
  const absorbed = new Set<string>();
  const childrenOf = new Map<string, string[]>();
  const ensureChildren = (id: string): string[] => {
    let arr = childrenOf.get(id);
    if (!arr) {
      arr = [];
      childrenOf.set(id, arr);
    }
    return arr;
  };

  // Containing group of a node: the deepest group scope that is a
  // (non-strict) ancestor of the node's scope. A node whose own scope is a
  // group is assigned TO that group (so e.g. synthetic ops at scope
  // "model/layers/0" go inside the group for that scope).
  const findContainingGroup = (scope: string): string | null => {
    const owner = closestContainingGroup(scope);
    return owner === null ? null : (groupForScope.get(owner) ?? null);
  };

  // A pure container node: a Module that is not synthetic and has no weights.
  const isBareModuleWrapper = (n: IRNode): boolean =>
    n.opType === "Module" &&
    n.attrs.synthetic !== true &&
    n.weights.length === 0;

  // Walk nodes in `ir.nodes` order so siblings inside a group keep a stable
  // order.
  for (const n of ir.nodes) {
    const s = scopeOf(n);
    if (isIoScope(s)) continue;

    // Wrapper nodes at a flattened structural scope (e.g. the bare `model`
    // module wrapper) were the implicit container above the real
    // sub-blocks. That container is not rendered, so the wrapper would
    // float as a stray pill. Drop it.
    if (isStructuralWrapperScope(s) && isBareModuleWrapper(n)) {
      absorbed.add(n.id);
      continue;
    }

    const gid = findContainingGroup(s);
    if (!gid) continue;

    // Wrapper absorption: if this IR node is the canonical "Module wrapper"
    // at the EXACT scope of the group, absorb it. CRITICAL: a cluster that
    // merges multiple members (numChildren > 1) typically carries
    // connectivity (e.g. residual adds inside a layer scope). Absorbing
    // such a cluster would silently drop edges that point at it, so only
    // pure single-member wrappers are absorbed.
    const gscope = gid.slice(GROUP_ID_PREFIX.length);
    const numChildren = (n.attrs.numChildren as number | undefined) ?? 1;
    if (s === gscope && isBareModuleWrapper(n) && numChildren <= 1) {
      absorbed.add(n.id);
      // Promote the wrapper's label to the group (it's usually more
      // informative than the bare last segment).
      const sem = n.attrs.semanticLabel as string | undefined;
      groupLabel.set(gid, humanizeGroupLabel(sem ?? lastSeg(s)));
      continue;
    }

    parentOf.set(n.id, gid);
    ensureChildren(gid).push(n.id);
  }

  // 7. Sub-group containment: each group's parent group (if any).
  for (const g of isGroup) {
    const parentG = parentScopeOfGroup.get(g);
    if (parentG && isGroup.has(parentG)) {
      const gid = GROUP_ID_PREFIX + g;
      const parentGid = GROUP_ID_PREFIX + parentG;
      parentOf.set(gid, parentGid);
      ensureChildren(parentGid).push(gid);
    }
  }

  // 8. Top level: non-absorbed IR nodes without a parent group, in IR
  //    order, followed by groups without a parent group.
  const topLevel: string[] = [];
  const seenTop = new Set<string>();
  const pushTop = (id: string): void => {
    if (seenTop.has(id)) return;
    seenTop.add(id);
    topLevel.push(id);
  };
  for (const n of ir.nodes) {
    if (absorbed.has(n.id) || parentOf.has(n.id)) continue;
    pushTop(n.id);
  }
  for (const g of isGroup) {
    const gid = GROUP_ID_PREFIX + g;
    if (!parentOf.has(gid)) pushTop(gid);
  }

  return {
    groupForScope,
    childrenOf,
    parentOf,
    groupLabel,
    absorbed,
    topLevel,
  };
}
/** Output of `layoutGraph`: the two arrays React Flow consumes. */
export interface LayoutResult {
  /** React Flow nodes. */
  nodes: Node[];
  /** React Flow edges. */
  edges: Edge[];
}
| export function layoutGraph(ir: IRGraph): LayoutResult { | |
| const hier = buildHierarchy(ir); | |
| const modelType = ir.meta.modelType; | |
| const isIdentity = (n: IRNode) => n.opType === "Identity"; | |
| const dimsOf = (n: IRNode): { w: number; h: number } => { | |
| if (isClusterNode(n)) return { w: CLUSTER_NODE_WIDTH, h: CLUSTER_NODE_HEIGHT }; | |
| if (isIdentity(n)) return { w: IDENTITY_NODE_WIDTH, h: IDENTITY_NODE_HEIGHT }; | |
| // Synthetic ops (no weights, marked `synthetic: true`) are kept slim | |
| // so a fully expanded transformer block stays vertically compact. | |
| if (n.attrs.synthetic === true) { | |
| return { w: SYNTHETIC_NODE_WIDTH, h: SYNTHETIC_NODE_HEIGHT }; | |
| } | |
| return { w: NODE_WIDTH, h: NODE_HEIGHT }; | |
| }; | |
| // Build a dagre compound graph. Each parent group is a dagre node with no | |
| // intrinsic dimensions; dagre computes a bounding box from the children | |
| // it contains, padded by GROUP_PADDING_*. | |
| // Top-to-bottom pipeline layout. This is the convention used by every | |
| // model graph on hfviewer.com (TinyLlama, GPT-2, ViT, T5, Qwen, ...): | |
| // input on top, output on the bottom, decoder/transformer layers as | |
| // single nodes with an ×N badge to the right when loop-folded. | |
| // | |
| // Vertical rank spacing is adaptive: we already keep nodes compact | |
| // (height 38-26 px), so a generous gap would dominate the canvas and | |
| // make every block feel "too tall". At the Op-level granularity a | |
| // transformer can expose 30-50 nodes in a column, so we shrink the | |
| // gap further still. The horizontal gap (`nodesep`) is kept constant | |
| // since siblings rarely share a rank at large counts. | |
| const visibleCount = ir.nodes.length; | |
| const ranksep = Math.round( | |
| Math.max(4, Math.min(12, 12 - (visibleCount - 8) * 0.35)), | |
| ); | |
| const g = new dagre.graphlib.Graph({ compound: true }); | |
| g.setGraph({ | |
| rankdir: "TB", | |
| nodesep: 26, | |
| ranksep, | |
| marginx: 24, | |
| marginy: 24, | |
| }); | |
| g.setDefaultEdgeLabel(() => ({})); | |
| const irById = new Map<string, IRNode>(); | |
| for (const n of ir.nodes) irById.set(n.id, n); | |
| // Emit IR nodes (those not absorbed) as dagre leaves. | |
| for (const n of ir.nodes) { | |
| if (hier.absorbed.has(n.id)) continue; | |
| const { w, h } = dimsOf(n); | |
| g.setNode(n.id, { width: w, height: h }); | |
| } | |
| // Emit groups as compound parents. | |
| for (const gid of hier.groupForScope.values()) { | |
| g.setNode(gid, { | |
| paddingTop: GROUP_PADDING_TOP, | |
| paddingBottom: GROUP_PADDING_BOTTOM, | |
| paddingLeft: GROUP_PADDING_X, | |
| paddingRight: GROUP_PADDING_X, | |
| }); | |
| } | |
| // Wire parent → child relationships. | |
| for (const [childId, parentGid] of hier.parentOf.entries()) { | |
| g.setParent(childId, parentGid); | |
| } | |
| // Forward/residual edges drive the layout. Tree edges are visually | |
| // redundant once we have group containers; we drop them entirely. | |
| // Edges whose endpoint is an absorbed wrapper get rewritten to the | |
| // wrapper's containing group (so dagre still has something to anchor | |
| // the rank to). | |
| const layoutSource = (id: string): string => { | |
| if (hier.absorbed.has(id)) { | |
| const gid = hier.parentOf.get(id); | |
| if (gid) return gid; | |
| } | |
| return id; | |
| }; | |
| for (const e of ir.edges) { | |
| if (e.kind === "tree") continue; | |
| const s = layoutSource(e.source); | |
| const t = layoutSource(e.target); | |
| if (s === t) continue; | |
| g.setEdge(s, t); | |
| } | |
| dagre.layout(g); | |
| // Per-node param shading (heatmap). | |
| const paramsOf = (n: IRNode): number => { | |
| const fromAttr = n.attrs.totalParams; | |
| if (typeof fromAttr === "number" && fromAttr > 0) return fromAttr; | |
| return n.weights.reduce((a, w) => a + w.numParams, 0); | |
| }; | |
| const paramsPerNode = new Map<string, number>(); | |
| for (const n of ir.nodes) paramsPerNode.set(n.id, paramsOf(n)); | |
| let maxParams = 0; | |
| for (const p of paramsPerNode.values()) if (p > maxParams) maxParams = p; | |
| // Build React Flow nodes. IMPORTANT: parents must appear before children | |
| // in the output array (React Flow requirement). | |
| const reactNodes: Node[] = []; | |
| const positionOf = (id: string): { x: number; y: number } | null => { | |
| const dn = g.node(id); | |
| if (!dn) return null; | |
| const parentGid = hier.parentOf.get(id); | |
| if (parentGid) { | |
| const parent = g.node(parentGid); | |
| if (!parent) return { x: dn.x - dn.width / 2, y: dn.y - dn.height / 2 }; | |
| const parentTopLeftX = parent.x - parent.width / 2; | |
| const parentTopLeftY = parent.y - parent.height / 2; | |
| return { | |
| x: dn.x - dn.width / 2 - parentTopLeftX, | |
| y: dn.y - dn.height / 2 - parentTopLeftY, | |
| }; | |
| } | |
| return { x: dn.x - dn.width / 2, y: dn.y - dn.height / 2 }; | |
| }; | |
| // 1) Emit groups (parents) in shallow-first order so a child group's | |
| // parentNode reference is always already present. | |
| const groupIds = [...hier.groupForScope.values()].sort((a, b) => { | |
| const da = a.split("/").length; | |
| const db = b.split("/").length; | |
| return da - db; | |
| }); | |
| // Index IR nodes by exact scope so we can pull rich metadata (semantic | |
| // class name, repeatCount) for the container that sits at that scope. | |
| const irByExactScope = new Map<string, IRNode>(); | |
| for (const n of ir.nodes) { | |
| const s = scopeOf(n); | |
| if (!irByExactScope.has(s)) irByExactScope.set(s, n); | |
| } | |
| for (const gid of groupIds) { | |
| const dn = g.node(gid); | |
| if (!dn) continue; | |
| const pos = positionOf(gid)!; | |
| const parentGid = hier.parentOf.get(gid); | |
| const scope = gid.slice(GROUP_ID_PREFIX.length); | |
| const scopeTail = lastSeg(scope); | |
| const segs = scope.split("/").filter(Boolean); | |
| const parentTail = segs[segs.length - 2] ?? ""; | |
| // Pull semantic info from the IR node at this exact scope (typically a | |
| // Module wrapper or a folded-loop cluster). Its repeatCount lets us | |
| // show the ×N badge on the container itself (matching hfviewer's loop | |
| // visualization). The label comes from `classNameForScope` so a layer | |
| // group under model_type=llama reads as `LlamaDecoderLayer`. | |
| const scopeNode = irByExactScope.get(scope); | |
| const className = classNameForScope(scope, modelType); | |
| const baseLabel = hier.groupLabel.get(gid) ?? ""; | |
| const label = | |
| className && className.length <= 36 ? className : baseLabel; | |
| // The ×N badge only makes sense for groups that wrap a folded loop | |
| // (i.e. a numeric tail under `layers`/`h`/`block`). Otherwise the | |
| // repeat count is inherited from a parent loop and would be misleading | |
| // on attention / mlp / experts children. | |
| const isLoopGroup = | |
| /^\d+$/.test(scopeTail) && | |
| (parentTail === "layers" || parentTail === "h" || parentTail === "block"); | |
| const repeatCount = isLoopGroup | |
| ? (scopeNode?.attrs.repeatCount as number | undefined) | |
| : undefined; | |
| // Border color of the container, picked from the last scope segment. | |
| // Purely numeric segments and `encoder` / `decoder` share the | |
| // TransformerLayer color; attention, MLP and experts tails get their own | |
| // op color; every other segment uses a neutral slate. | |
| const groupColor = ((segment: string): string => { | |
| if (/^\d+$/.test(segment)) return colorForOp("TransformerLayer"); | |
| switch (segment) { | |
| case "self_attn": | |
| case "attn": | |
| case "attention": | |
| return colorForOp("Attention"); | |
| case "mlp": | |
| case "feed_forward": | |
| case "ffn": | |
| return colorForOp("MLP"); | |
| case "experts": | |
| return colorForOp("Expert"); | |
| case "encoder": | |
| case "decoder": | |
| return colorForOp("TransformerLayer"); | |
| default: | |
| return "#7a8aa3"; | |
| } | |
| })(scopeTail); | |
| reactNodes.push({ | |
| id: gid, | |
| type: "group", | |
| ...(parentGid ? { parentNode: parentGid } : {}), | |
| position: pos, | |
| style: { width: dn.width, height: dn.height }, | |
| data: { | |
| label, | |
| sublabel: "", | |
| color: groupColor, | |
| isCluster: true, | |
| intensity: 0, | |
| isGroupContainer: true, | |
| repeatCount, | |
| }, | |
| selectable: false, | |
| draggable: false, | |
| }); | |
| } | |
| // Collect the ids of groups that stand for a folded loop: their scope ends | |
| // in a purely numeric segment sitting directly under `layers`, `h` or | |
| // `block`. The leaf pass uses this set to suppress the per-node ×N badge | |
| // for anything nested inside such a group, so the count is not repeated | |
| // on every child. | |
| const loopGroupIds = new Set<string>(); | |
| for (const [scopePath, groupId] of hier.groupForScope.entries()) { | |
| const parts = scopePath.split("/").filter(Boolean); | |
| const tailSeg = parts[parts.length - 1] ?? ""; | |
| const parentSeg = parts[parts.length - 2] ?? ""; | |
| if (!/^\d+$/.test(tailSeg)) continue; | |
| if (parentSeg !== "layers" && parentSeg !== "h" && parentSeg !== "block") { | |
| continue; | |
| } | |
| loopGroupIds.add(groupId); | |
| } | |
| // True when `startGid` itself, or any ancestor reached by following | |
| // `hier.parentOf`, is a loop group. An undefined (or empty) start id | |
| // yields false. | |
| const isUnderLoopGroup = (startGid: string | undefined): boolean => { | |
| for (let gid = startGid; gid; gid = hier.parentOf.get(gid)) { | |
| if (loopGroupIds.has(gid)) return true; | |
| } | |
| return false; | |
| }; | |
| // 2) Emit leaf IR nodes. Every IR node that was not absorbed and that has | |
| // an entry in `g` becomes one React Flow node of type "cluster" or "op". | |
| for (const n of ir.nodes) { | |
| // Absorbed nodes are not rendered as nodes of their own. | |
| if (hier.absorbed.has(n.id)) continue; | |
| const dn = g.node(n.id); | |
| // `g` has no entry for this id, so there is nothing to place. | |
| if (!dn) continue; | |
| // NOTE(review): the non-null assertion presumably relies on `positionOf` | |
| // succeeding whenever `g.node` did; confirm against its definition. | |
| const pos = positionOf(n.id)!; | |
| const parentGid = hier.parentOf.get(n.id); | |
| const cluster = isClusterNode(n); | |
| const { w, h } = dimsOf(n); | |
| const color = colorForOp(n.opType); | |
| // Nodes without an entry in `paramsPerNode` count as 0 params. | |
| const params = paramsPerNode.get(n.id) ?? 0; | |
| // Square root of this node's share of `maxParams`; forced to 0 when | |
| // `maxParams` is not positive, which also avoids dividing by zero. | |
| const intensity = maxParams > 0 ? Math.sqrt(params / maxParams) : 0; | |
| const { label, sublabel } = formatNodeLabel(n, modelType); | |
| const rawRepeatCount = n.attrs.repeatCount as number | undefined; | |
| // Drop the node's own repeat count when its parent group, or any group | |
| // above it, is a loop group. | |
| const repeatCount = isUnderLoopGroup(parentGid) | |
| ? undefined | |
| : rawRepeatCount; | |
| reactNodes.push({ | |
| id: n.id, | |
| type: cluster ? "cluster" : "op", | |
| // `parentNode` is only set when the node has a parent group. | |
| ...(parentGid ? { parentNode: parentGid } : {}), | |
| position: pos, | |
| targetPosition: Position.Top, | |
| sourcePosition: Position.Bottom, | |
| data: { | |
| label, | |
| sublabel, | |
| irNode: n, | |
| color, | |
| isCluster: cluster, | |
| intensity, | |
| repeatCount, | |
| }, | |
| style: { width: w, height: h }, | |
| }); | |
| } | |
| // Edges: drop tree-kind (the container conveys that). Also rewrite any | |
| // endpoint that points to an absorbed wrapper. | |
| // | |
| // Routing strategy: | |
| // - All edges use the `smoothstep` type (orthogonal routing with rounded | |
| // corners). This is much better than the default bezier for densely | |
| // packed columns: a bezier from B → D drawn through C visually crosses | |
| // C; a smoothstep one bends around it. | |
| // - Forward edges connect the default top/bottom handles ("t" / "s") so | |
| // they trace the main flow of the column. | |
| // - Residual (skip-connection) edges connect the *right-side* handles | |
| // ("s-r" / "t-r") instead, so they run alongside the column rather | |
| // than overlapping the forward chain. The combination of side-handles | |
| // + smoothstep produces the familiar "side-arc" look used by tools | |
| // like hfviewer.com. | |
| const reactEdges: Edge[] = []; | |
| for (const e of ir.edges) { | |
| // Tree edges are never rendered. | |
| if (e.kind === "tree") continue; | |
| // An absorbed endpoint is replaced by its `hier.parentOf` entry; when | |
| // there is no such entry the original id is kept. | |
| const src = hier.absorbed.has(e.source) | |
| ? (hier.parentOf.get(e.source) ?? e.source) | |
| : e.source; | |
| const tgt = hier.absorbed.has(e.target) | |
| ? (hier.parentOf.get(e.target) ?? e.target) | |
| : e.target; | |
| // Skip edges whose two endpoints are the same id after rewriting. | |
| if (src === tgt) continue; | |
| // Edges without an explicit kind are treated as forward edges. | |
| const kind = e.kind ?? "forward"; | |
| const isResidual = kind === "residual"; | |
| // hfviewer styling: thin neutral lines for the forward chain, animated | |
| // solid orange lines for skip-connections. No arrowheads, no per-edge | |
| // tensor labels — the direction is implied by the top→bottom layout, | |
| // and shape annotations would just add noise. | |
| // NOTE(review): React Flow's `animated` flag normally draws a moving | |
| // dashed stroke, so "solid" may depend on a CSS override elsewhere; | |
| // confirm. | |
| const style = isResidual | |
| ? { | |
| stroke: "rgba(255,170,90,0.95)", | |
| strokeWidth: 1.4, | |
| } | |
| : { stroke: "rgba(200,210,235,0.55)", strokeWidth: 1.2 }; | |
| reactEdges.push({ | |
| id: e.id, | |
| source: src, | |
| target: tgt, | |
| // Residual edges use the right-side handles, all others the default | |
| // top/bottom ones. | |
| sourceHandle: isResidual ? "s-r" : "s", | |
| targetHandle: isResidual ? "t-r" : "t", | |
| style, | |
| type: "smoothstep", | |
| // Residual edges get a larger offset (18 vs 8) than forward edges. | |
| pathOptions: { borderRadius: 10, offset: isResidual ? 18 : 8 }, | |
| animated: isResidual, | |
| // Make sure edges crossing group boundaries are not clipped. | |
| zIndex: 1000, | |
| }); | |
| } | |
| return { nodes: reactNodes, edges: reactEdges }; | |
| } | |