// Hierarchical graph layout built on dagre. The IR is converted into React
// Flow-compatible Node[] and Edge[] arrays.
//
// Subflows: scopes that have descendants are rendered as React Flow *group*
// containers that visually wrap their children. The hierarchy is read from
// each node's slash-separated `attrs.scope` string: any scope with at least
// one descendant in the IR becomes a synthetic container. Wrapper IRNodes
// (`opType: "Module"` with no weights) are absorbed into their container so
// we never render both `model` and "group:model" side by side.

import dagre from "@dagrejs/dagre";
import { Position, type Node, type Edge } from "reactflow";
import type { IRGraph, IRNode } from "../types";

// Node footprints, tuned to the visual density of hfviewer.com: short pills
// with a 10-11px label and no decoration. Larger values blow up the canvas;
// smaller ones make the text unreadable.
const NODE_WIDTH = 110;
const NODE_HEIGHT = 22;
const CLUSTER_NODE_WIDTH = 130;
const CLUSTER_NODE_HEIGHT = 26;

/** Synthetic ops (Add, Mul, MatMul, Softmax, RoPE, ...) carry no params and
 *  show one short label, so they are kept extra compact; otherwise a fully
 *  expanded transformer block would sprawl over dozens of bulky rows. */
const SYNTHETIC_NODE_WIDTH = 70;
const SYNTHETIC_NODE_HEIGHT = 20;

/** Identity/boundary ops are drawn as minimal "data flow stubs". */
const IDENTITY_NODE_WIDTH = 44;
const IDENTITY_NODE_HEIGHT = 16;

/** Extra room requested around a compound parent so its label band fits and
 *  children do not touch the border. */
const GROUP_PADDING_X = 8;
const GROUP_PADDING_TOP = 12;
const GROUP_PADDING_BOTTOM = 6;

const GROUP_ID_PREFIX = "group:";

/** Exact-match colours for semantic super-node types (from grouping.ts).
 *  A Map (not an object literal) so prototype keys like "toString" miss. */
const SEMANTIC_OP_COLORS = new Map<string, string>([
  ["TransformerLayer", "#5d4eff"],
  ["Attention", "#8a5cff"],
  ["MLP", "#3c8ce6"],
  ["Embedding", "#f0925b"],
  ["Norm", "#a06bd6"],
  ["Linear", "#4f78e0"],
  ["RoPE", "#c25fe6"],
  ["Router", "#ffae5b"],
  ["Expert", "#5e9efb"],
  ["Module", "#506280"],
]);

/** ONNX op-level fallbacks, tried in order; the first matching pattern wins. */
const ONNX_OP_COLOR_RULES: ReadonlyArray<readonly [RegExp, string]> = [
  [/MatMul|Gemm|Conv/, "#5b8def"],
  [/Attention|Softmax/, "#7c5cff"],
  [/Norm|LayerNormalization|RMSNormalization/, "#9b6bd1"],
  [/Add|Mul|Sub|Div|Pow/, "#3aa6c9"],
  [/Gelu|Relu|Silu|Swish|Tanh|Sigmoid/, "#22c1a0"],
  [/Gather|Slice|Reshape|Transpose|Squeeze|Unsqueeze|Concat/, "#7a8aa3"],
  [/Cast|Identity|Constant|Quantize|DequantizeLinear/, "#4a5263"],
];

/** Graph boundary pills, checked only after the ONNX patterns. */
const IO_OP_COLORS = new Map<string, string>([
  ["Input", "#f6a96b"],
  ["Output", "#e25fa3"],
]);

const DEFAULT_OP_COLOR = "#6477a8";

/**
 * Pick the accent colour for a node from its op type.
 *
 * Precedence: exact semantic type, then the first matching ONNX pattern,
 * then the Input/Output pills, then a neutral default.
 */
function colorForOp(opType: string): string {
  const semantic = SEMANTIC_OP_COLORS.get(opType);
  if (semantic !== undefined) return semantic;

  const rule = ONNX_OP_COLOR_RULES.find(([pattern]) => pattern.test(opType));
  if (rule !== undefined) return rule[1];

  return IO_OP_COLORS.get(opType) ?? DEFAULT_OP_COLOR;
}

/** True for cluster nodes, which are identified by a `cluster:` id prefix. */
function isClusterNode(n: IRNode): boolean {
  const clusterPrefix = "cluster:";
  return n.id.slice(0, clusterPrefix.length) === clusterPrefix;
}

// ─── Human-friendly label formatting ─────────────────────────────────────
//
// Raw module names like `input_layernorm`, `q_proj`, `embed_tokens` or
// `post_attention_layernorm` read fine in code but are noisy in a graph
// node. The helpers below map the most common transformer module suffixes
// to short, semantically meaningful labels.
/**
 * Short role label for a normalisation layer, derived from the last scope
 * segment (e.g. `input_layernorm` → "pre-attn"). Returns "" when the name is
 * not recognised.
 */
function humanizeNorm(tail: string): string {
  switch (tail) {
    case "input_layernorm":
    case "ln1":
    case "ln_1": // GPT-2 spelling, alongside `ln_f` below.
      return "pre-attn";
    case "post_attention_layernorm":
    case "ln2":
    case "ln_2":
      return "post-attn";
    case "pre_feedforward_layernorm":
      return "pre-ffn";
    case "post_feedforward_layernorm":
      return "post-ffn";
    case "norm":
    case "final_layernorm":
    case "ln_f":
      return "final";
  }
  // Generic `<role>_layernorm` / `<role>_norm`: keep the role, hyphenated.
  // Strip the trailing suffix itself, not the first occurrence of it.
  for (const suffix of ["_layernorm", "_norm"]) {
    if (tail.endsWith(suffix)) {
      return tail.slice(0, -suffix.length).replace(/_/g, "-");
    }
  }
  return "";
}

/**
 * Short role label for a Linear projection, derived from the last scope
 * segment (e.g. `q_proj` → "Q proj", `gate_proj`/`w1` → "gate"). Returns ""
 * when the name is not recognised.
 */
function humanizeLinear(tail: string): string {
  switch (tail) {
    case "q_proj":
      return "Q proj";
    case "k_proj":
      return "K proj";
    case "v_proj":
      return "V proj";
    case "o_proj":
      return "O proj";
    case "qkv_proj":
    case "c_attn":
      return "QKV";
    case "gate_proj":
    case "w1":
      return "gate";
    case "up_proj":
    case "w3":
      return "up";
    case "down_proj":
    case "w2":
      return "down";
    case "c_proj":
    case "c_fc":
      return tail.replace(/_/g, " ");
  }
  const suffix = "_proj";
  if (tail.endsWith(suffix)) {
    return tail.slice(0, -suffix.length).replace(/_/g, " ");
  }
  return "";
}

/** Short label for an embedding table (tokens / positions / segments). */
function humanizeEmbedding(tail: string): string {
  if (tail === "embed_tokens" || tail === "wte" || tail === "word_embeddings") return "tokens";
  if (tail === "wpe" || tail === "embed_positions" || tail === "position_embeddings") return "positions";
  if (tail === "token_type_embeddings") return "segments";
  return tail.replace(/_/g, " ");
}

/** Generic prettifier: numeric indices become `#N`, underscores become spaces. */
function humanizeGeneric(tail: string): string {
  if (!tail) return "";
  if (/^\d+$/.test(tail)) return `#${tail}`;
  return tail.replace(/_/g, " ");
}

/** ModuleList attribute names that hold a stack of transformer layers. */
const LAYER_LIST_NAMES: ReadonlySet<string> = new Set(["layers", "h", "block"]);

/**
 * True when `scope` ends in a numeric index directly under a layer list
 * (e.g. `model/layers/0`), i.e. the scope names a single transformer layer.
 */
function isLoopScope(scope: string): boolean {
  const segs = scope.split("/").filter(Boolean);
  const tail = segs[segs.length - 1] ?? "";
  const parent = segs[segs.length - 2] ?? "";
  return /^\d+$/.test(tail) && LAYER_LIST_NAMES.has(parent);
}

/**
 * PascalCase prefix for transformers class names, built from an HF
 * `model_type`. Each `_`/`-` separated part is capitalised, so `llama` →
 * `Llama` and `qwen2_moe` → `Qwen2Moe` (matching e.g. `Qwen2MoeDecoderLayer`).
 */
function pascalModelType(modelType: string): string {
  return modelType
    .split(/[_-]+/)
    .filter(Boolean)
    .map((part) => part.charAt(0).toUpperCase() + part.slice(1).toLowerCase())
    .join("");
}

/**
 * Build the Python class name that most likely backs a given GROUP scope,
 * given the HF `model_type` from config.json. Mirrors the transformers
 * naming convention, so a `model/layers/0` group under modelType=llama
 * renders as `LlamaDecoderLayer`, matching the hfviewer.com convention.
 * Returns null when no model type is known or the scope is unrecognised.
 */
function classNameForScope(scope: string, modelType?: string): string | null {
  if (!modelType) return null;
  const cap = pascalModelType(modelType);
  if (isLoopScope(scope)) return `${cap}DecoderLayer`;
  const tail = lastSeg(scope);
  if (tail === "self_attn" || tail === "attn" || tail === "attention") return `${cap}Attention`;
  if (tail === "mlp" || tail === "feed_forward" || tail === "ffn") return `${cap}MLP`;
  if (tail === "encoder") return `${cap}Encoder`;
  if (tail === "decoder") return `${cap}Decoder`;
  return null;
}

/**
 * Labels for the group/subflow header band. We aim for short, recognisable
 * names: the container already conveys "this is a sub-block", so we only
 * need to name what it is.
 */
function humanizeGroupLabel(tail: string): string {
  if (!tail) return "";
  if (tail === "model") return "model";
  if (tail === "transformer") return "transformer";
  if (tail === "encoder") return "encoder";
  if (tail === "decoder") return "decoder";
  if (LAYER_LIST_NAMES.has(tail)) return "layers";
  if (tail === "self_attn" || tail === "attn" || tail === "attention") return "attention";
  if (tail === "mlp" || tail === "feed_forward" || tail === "ffn") return "mlp";
  if (tail === "experts") return "experts";
  if (/^\d+$/.test(tail)) return `layer #${tail}`;
  return tail.replace(/_/g, " ");
}

interface FormattedLabel {
  label: string;
  sublabel: string;
}

/** Model types whose norm layers are named `<Model>RMSNorm` in transformers. */
const RMS_NORM_MODEL_TYPES: ReadonlySet<string> = new Set(["llama", "qwen2", "qwen3", "mistral"]);

/**
 * Build the Python class name that most likely backs a given node, using
 * the HF `model_type` from config.json. We mirror the transformers naming
 * convention: `Llama` + `DecoderLayer` = `LlamaDecoderLayer`. This is what
 * hfviewer.com shows on its main node labels. Returns null when no model
 * type is known or the node kind has no conventional class name.
 */
function classNameForNode(n: IRNode, modelType?: string): string | null {
  if (!modelType) return null;
  const cap = pascalModelType(modelType);

  // A numeric tail under a `layers`/`h`/`block` parent is a transformer layer.
  if (isLoopScope(scopeOf(n))) return `${cap}DecoderLayer`;

  switch (n.opType) {
    case "TransformerLayer":
      return `${cap}DecoderLayer`;
    case "Attention":
      return `${cap}Attention`;
    case "MLP":
      return `${cap}MLP`;
    case "Norm":
      // Common naming in modern Llama-family models: LlamaRMSNorm.
      return RMS_NORM_MODEL_TYPES.has(modelType) ? `${cap}RMSNorm` : null;
    default:
      return null;
  }
}

/**
 * Primary label and optional sublabel for an IR node.
 *
 * @param n - the node to label.
 * @param modelType - optional HF `model_type`, used to surface transformers
 *   class names such as `LlamaDecoderLayer`.
 */
function formatNodeLabel(n: IRNode, modelType?: string): FormattedLabel {
  // Synthetic ops carry an intentional, concise label (e.g. "+", "h_in")
  // that is already optimised for display.
  const synLabel = n.attrs.label;
  if (typeof synLabel === "string" && synLabel.length > 0 && synLabel.length <= 24) {
    return { label: synLabel, sublabel: "" };
  }

  const opType = n.opType;
  const scope = scopeOf(n);
  const tail = lastSeg(scope);

  // A top-level Linear named lm_head is the model's output head: treat it
  // as a first-class concept rather than "Linear / lm_head".
  if (opType === "Linear" && tail === "lm_head") {
    return { label: "LM head", sublabel: "" };
  }

  // Folded transformer-layer wrapper: a numeric tail *under a layer list*.
  // Other numeric tails (e.g. `experts/3`) are not layers, so they fall
  // through to the generic Module labelling below.
  if (opType === "Module" && isLoopScope(scope)) {
    const cn = classNameForNode(n, modelType);
    return { label: cn ?? `Layer #${tail}`, sublabel: "" };
  }

  switch (opType) {
    case "Embedding":
      return { label: "Embedding", sublabel: humanizeEmbedding(tail) };
    case "Attention":
      return { label: classNameForNode(n, modelType) ?? "Attention", sublabel: "" };
    case "MLP":
      return { label: classNameForNode(n, modelType) ?? "MLP", sublabel: "" };
    case "TransformerLayer":
      return { label: classNameForNode(n, modelType) ?? "Decoder layer", sublabel: "" };
    case "Norm":
      return { label: classNameForNode(n, modelType) ?? "Norm", sublabel: humanizeNorm(tail) };
    case "Linear":
      return { label: "Linear", sublabel: humanizeLinear(tail) || tail };
    case "RoPE":
      return { label: "RoPE", sublabel: "" };
    case "Router":
      return { label: "Router", sublabel: "" };
    case "Expert":
      return { label: "Expert", sublabel: humanizeGeneric(tail) };
    case "Module": {
      // A generic Module wrapper; its scope tail is usually a Python
      // identifier, so humanise it.
      const pretty = humanizeGeneric(tail);
      return { label: pretty || "Module", sublabel: "" };
    }
    default: {
      const pretty = humanizeGeneric(tail);
      const sublabel = pretty && pretty.toLowerCase() !== opType.toLowerCase() ? pretty : "";
      return { label: opType, sublabel };
    }
  }
}

/** True for the empty scope and for `__io__`-prefixed graph-boundary scopes. */
function isIoScope(scope: string): boolean {
  return scope === "" || scope.startsWith("__io__");
}

/** The node's trimmed `attrs.scope`, or "" when absent or not a string. */
function scopeOf(n: IRNode): string {
  const raw = n.attrs.scope;
  return typeof raw === "string" ? raw.trim() : "";
}

/** Proper ancestor scopes, shallowest first: `a/b/c` → [`a`, `a/b`]. */
function ancestorScopes(scope: string): string[] {
  const segs = scope.split("/").filter(Boolean);
  const out: string[] = [];
  for (let i = 1; i < segs.length; i++) out.push(segs.slice(0, i).join("/"));
  return out;
}

/** Immediate parent scope, or "" for single-segment/empty scopes. */
function parentScope(scope: string): string {
  const segs = scope.split("/").filter(Boolean);
  if (segs.length <= 1) return "";
  return segs.slice(0, -1).join("/");
}

/** Last non-empty segment of a scope, or "". */
function lastSeg(scope: string): string {
  const segs = scope.split("/").filter(Boolean);
  return segs[segs.length - 1] ?? "";
}

interface Hierarchy {
  /** scope path → group id (only for scopes that became real groups). */
  groupForScope: Map<string, string>;
  /** group id → direct child ids: IR node ids first (in `ir.nodes` order),
   *  followed by sub-group ids. */
  childrenOf: Map<string, string[]>;
  /** node/group id → parent group id (only for ids that have a parent). */
  parentOf: Map<string, string>;
  /** group id → label to render in the container header. */
  groupLabel: Map<string, string>;
  /** IR node ids that are absorbed (wrapper IRNodes at a group's exact scope,
   *  or at a flattened structural-wrapper scope). They must not be emitted as
   *  separate React Flow nodes, and are NOT recorded in `parentOf`. */
  absorbed: Set<string>;
  /** Parentless ids: IR nodes in `ir.nodes` order (IO pills included where
   *  they occur), then parentless groups. */
  topLevel: string[];
}

/**
 * Build the visual hierarchy:
 *  - A scope P becomes a group iff at least one IR node has a scope deeper
 *    than P (i.e. P has descendants in the IR) and P is not a structural
 *    wrapper (`model`, `transformer`, `layers`, `h`, `block`).
 *  - "Trivial" groups (no IR node of their own and at most one child) are
 *    collapsed: their child is promoted to the parent group.
 *  - A wrapper IRNode (opType="Module", no weights) at the exact scope of a
 *    group is absorbed into the group (its label is kept).
 *  - IO nodes (empty or `__io__` scopes) are never parented.
 */
function buildHierarchy(ir: IRGraph): Hierarchy {
  // 1. Bucket IR nodes by scope, ignoring IO.
  const byScope = new Map<string, IRNode[]>();
  const allCandidateScopes = new Set<string>();
  for (const n of ir.nodes) {
    const s = scopeOf(n);
    if (isIoScope(s)) continue;
    let bucket = byScope.get(s);
    if (!bucket) {
      bucket = [];
      byScope.set(s, bucket);
    }
    bucket.push(n);
    allCandidateScopes.add(s);
    for (const a of ancestorScopes(s)) allCandidateScopes.add(a);
  }

  // 2. Mark which scopes have descendants (strictly deeper IR nodes).
  const hasDescendants = new Set<string>();
  for (const s of byScope.keys()) {
    for (const a of ancestorScopes(s)) hasDescendants.add(a);
  }

  // 3. Initial group decision: a scope is a group if it has descendants,
  //    EXCEPT for structural wrappers — the bare `model`, the ModuleList
  //    container `layers` (a.k.a. `h` / `block`), or the `transformer` root.
  //    They only add a nesting level around every node; hfviewer.com flattens
  //    them, and so do we. Semantic class names (`LlamaDecoderLayer`, ...)
  //    come from scopes one level deeper (`layers/0`, `self_attn`, `mlp`).
  const STRUCTURAL_WRAPPERS = new Set(["model", "transformer", "layers", "h", "block"]);
  const isStructuralWrapperScope = (scope: string): boolean =>
    STRUCTURAL_WRAPPERS.has(lastSeg(scope));

  const isGroup = new Set<string>();
  for (const s of allCandidateScopes) {
    if (hasDescendants.has(s) && !isStructuralWrapperScope(s)) isGroup.add(s);
  }

  /** Deepest live group scope that is a (non-strict) ancestor of `scope`. */
  const closestContainingGroup = (scope: string): string | null => {
    let cur = scope;
    while (cur !== "") {
      if (isGroup.has(cur)) return cur;
      cur = parentScope(cur);
    }
    return null;
  };

  // 4. Each group's parent group: its longest proper ancestor that is a group.
  const parentScopeOfGroup = new Map<string, string | null>();
  for (const grp of isGroup) {
    parentScopeOfGroup.set(grp, closestContainingGroup(parentScope(grp)));
  }

  // 5. Collapse "trivial" groups bottom-up. A group with no IR node at its
  //    own scope and at most one child (IR node or sub-group) is visual
  //    noise: drop it and re-parent its child to the group's parent.
  //
  //    Invariant: every value in `parentScopeOfGroup` is a live group (or
  //    null). Deleting a group re-parents *all* groups currently pointing at
  //    it (scanned live, not from the per-pass snapshot, which misses
  //    grandchildren re-parented earlier in the same pass). The per-pass
  //    counts can only over-estimate after a deletion, so they never cause a
  //    wrong collapse; remaining collapses happen on the next pass.
  const sortedDeepFirst = [...isGroup].sort(
    (a, b) => b.split("/").length - a.split("/").length,
  );
  let changed = true;
  while (changed) {
    changed = false;

    // Direct IR children per group: an IR node belongs to its deepest
    // containing group. Sub-groups are counted separately.
    const irChildCount = new Map<string, number>();
    for (const n of ir.nodes) {
      const s = scopeOf(n);
      if (isIoScope(s)) continue;
      const owner = closestContainingGroup(s);
      if (owner !== null) irChildCount.set(owner, (irChildCount.get(owner) ?? 0) + 1);
    }

    const subGroupCount = new Map<string, number>();
    for (const other of isGroup) {
      const p = parentScopeOfGroup.get(other);
      if (p && isGroup.has(p)) subGroupCount.set(p, (subGroupCount.get(p) ?? 0) + 1);
    }

    for (const grp of sortedDeepFirst) {
      if (!isGroup.has(grp)) continue;
      if ((byScope.get(grp)?.length ?? 0) > 0) continue;
      const total = (irChildCount.get(grp) ?? 0) + (subGroupCount.get(grp) ?? 0);
      if (total > 1) continue;

      const myParent = parentScopeOfGroup.get(grp) ?? null;
      for (const [sub, p] of parentScopeOfGroup) {
        if (p === grp) parentScopeOfGroup.set(sub, myParent);
      }
      isGroup.delete(grp);
      parentScopeOfGroup.delete(grp);
      changed = true;
    }
  }

  // 6. Build the result.
  const groupForScope = new Map<string, string>();
  const groupLabel = new Map<string, string>();
  for (const grp of isGroup) {
    const gid = GROUP_ID_PREFIX + grp;
    groupForScope.set(grp, gid);
    groupLabel.set(gid, humanizeGroupLabel(lastSeg(grp)));
  }

  const parentOf = new Map<string, string>();
  const absorbed = new Set<string>();
  const childrenOf = new Map<string, string[]>();
  const ensureChildren = (id: string): string[] => {
    let arr = childrenOf.get(id);
    if (!arr) {
      arr = [];
      childrenOf.set(id, arr);
    }
    return arr;
  };

  // Containing group id of a scope: the deepest group that is a
  // (non-strict) ancestor, so e.g. synthetic ops at "model/layers/0" land
  // inside "group:model/layers/0".
  const findContainingGroup = (scope: string): string | null => {
    const owner = closestContainingGroup(scope);
    return owner === null ? null : (groupForScope.get(owner) ?? null);
  };

  for (const n of ir.nodes) {
    const s = scopeOf(n);
    if (isIoScope(s)) continue;

    // A plain wrapper at a flattened structural scope (e.g. the bare `model`
    // module) has no container anymore and would otherwise float as a stray
    // "model" pill. Drop it.
    if (
      isStructuralWrapperScope(s) &&
      n.opType === "Module" &&
      n.attrs.synthetic !== true &&
      n.weights.length === 0
    ) {
      absorbed.add(n.id);
      continue;
    }

    const gid = findContainingGroup(s);
    if (!gid) continue;

    // Absorb the canonical Module wrapper at the EXACT group scope. Clusters
    // merging several members (numChildren > 1) can carry connectivity
    // (e.g. residual adds), so only single-member wrappers are absorbed.
    const gscope = gid.slice(GROUP_ID_PREFIX.length);
    const numChildren = (n.attrs.numChildren as number | undefined) ?? 1;
    const isWrapperAtGroupScope =
      s === gscope &&
      n.opType === "Module" &&
      n.attrs.synthetic !== true &&
      n.weights.length === 0 &&
      numChildren <= 1;
    if (isWrapperAtGroupScope) {
      absorbed.add(n.id);
      // Promote the wrapper's semantic label, when present, to the group.
      const sem = n.attrs.semanticLabel as string | undefined;
      groupLabel.set(gid, humanizeGroupLabel(sem ?? lastSeg(s)));
      continue;
    }

    parentOf.set(n.id, gid);
    ensureChildren(gid).push(n.id);
  }

  // 7. Sub-group containment: attach each group to its parent group.
  for (const grp of isGroup) {
    const parentG = parentScopeOfGroup.get(grp);
    if (parentG && isGroup.has(parentG)) {
      const gid = GROUP_ID_PREFIX + grp;
      const parentGid = GROUP_ID_PREFIX + parentG;
      parentOf.set(gid, parentGid);
      ensureChildren(parentGid).push(gid);
    }
  }

  // 8. Top level: parentless, non-absorbed IR nodes in IR order, then
  //    parentless groups.
  const topLevel: string[] = [];
  const seenTop = new Set<string>();
  const pushTop = (id: string): void => {
    if (seenTop.has(id)) return;
    seenTop.add(id);
    topLevel.push(id);
  };
  for (const n of ir.nodes) {
    if (absorbed.has(n.id) || parentOf.has(n.id)) continue;
    pushTop(n.id);
  }
  for (const grp of isGroup) {
    const gid = GROUP_ID_PREFIX + grp;
    if (!parentOf.has(gid)) pushTop(gid);
  }

  return { groupForScope, childrenOf, parentOf, groupLabel, absorbed, topLevel };
}

export interface LayoutResult {
  nodes: Node[];
  edges: Edge[];
}

/**
 * Lay out the IR top-to-bottom with dagre and convert it to React Flow
 * nodes (group containers first, then leaf ops) and edges.
 */
export function layoutGraph(ir: IRGraph): LayoutResult {
  const hier = buildHierarchy(ir);
  const modelType = ir.meta.modelType;

  const dimsOf = (n: IRNode): { w: number; h: number } => {
    if (isClusterNode(n)) return { w: CLUSTER_NODE_WIDTH, h: CLUSTER_NODE_HEIGHT };
    if (n.opType === "Identity") return { w: IDENTITY_NODE_WIDTH, h: IDENTITY_NODE_HEIGHT };
    // Synthetic ops are kept slim so an expanded block stays compact.
    if (n.attrs.synthetic === true) return { w: SYNTHETIC_NODE_WIDTH, h: SYNTHETIC_NODE_HEIGHT };
    return { w: NODE_WIDTH, h: NODE_HEIGHT };
  };

  // Top-to-bottom pipeline layout (input on top, output at the bottom), the
  // convention hfviewer.com uses for its model graphs.
  //
  // Rank spacing shrinks as the number of rendered IR nodes grows: nodes are
  // already compact (16–26 px tall), so a generous gap would dominate the
  // canvas. `nodesep` stays constant.
  const visibleCount = ir.nodes.length - hier.absorbed.size;
  const ranksep = Math.round(Math.max(4, Math.min(12, 12 - (visibleCount - 8) * 0.35)));

  const g = new dagre.graphlib.Graph({ compound: true });
  g.setGraph({ rankdir: "TB", nodesep: 26, ranksep, marginx: 24, marginy: 24 });
  g.setDefaultEdgeLabel(() => ({}));

  // Non-absorbed IR nodes are dagre leaves.
  for (const n of ir.nodes) {
    if (hier.absorbed.has(n.id)) continue;
    const { w, h } = dimsOf(n);
    g.setNode(n.id, { width: w, height: h });
  }
  // Groups are compound parents; their box is derived from their children.
  // NOTE(review): confirm the dagre version in use honours these per-node
  // padding attributes.
  for (const gid of hier.groupForScope.values()) {
    g.setNode(gid, {
      paddingTop: GROUP_PADDING_TOP,
      paddingBottom: GROUP_PADDING_BOTTOM,
      paddingLeft: GROUP_PADDING_X,
      paddingRight: GROUP_PADDING_X,
    });
  }
  for (const [childId, parentGid] of hier.parentOf) {
    g.setParent(childId, parentGid);
  }

  // Absorbed wrappers are not emitted and are not in `hier.parentOf`. Map
  // each one to the group at its exact scope, or to null for structural
  // wrappers (which have no container), so no edge references a missing node.
  const absorbedInto = new Map<string, string | null>();
  for (const n of ir.nodes) {
    if (hier.absorbed.has(n.id)) {
      absorbedInto.set(n.id, hier.groupForScope.get(scopeOf(n)) ?? null);
    }
  }
  const resolveEndpoint = (id: string): string | null =>
    absorbedInto.has(id) ? (absorbedInto.get(id) ?? null) : id;

  // dagre edges are anchored on leaves: a group endpoint is replaced by its
  // first leaf descendant so the rank constraint still applies inside it.
  const layoutAnchor = (id: string): string => {
    let cur = id;
    for (;;) {
      const first = hier.childrenOf.get(cur)?.[0];
      if (first === undefined) return cur;
      cur = first;
    }
  };

  // Forward/residual edges drive the layout; tree edges are redundant once
  // group containers exist. Endpoints unknown to the graph are skipped.
  type IrEdge = IRGraph["edges"][number];
  const visibleEdges: Array<{ e: IrEdge; source: string; target: string }> = [];
  for (const e of ir.edges) {
    if (e.kind === "tree") continue;
    const source = resolveEndpoint(e.source);
    const target = resolveEndpoint(e.target);
    if (source === null || target === null || source === target) continue;
    if (!g.hasNode(source) || !g.hasNode(target)) continue;
    visibleEdges.push({ e, source, target });
  }
  for (const { source, target } of visibleEdges) {
    const s = layoutAnchor(source);
    const t = layoutAnchor(target);
    if (s !== t) g.setEdge(s, t);
  }

  dagre.layout(g);

  // Per-node param shading (heatmap).
  const paramsOf = (n: IRNode): number => {
    const fromAttr = n.attrs.totalParams;
    if (typeof fromAttr === "number" && fromAttr > 0) return fromAttr;
    return n.weights.reduce((a, w) => a + w.numParams, 0);
  };
  const paramsPerNode = new Map<string, number>();
  let maxParams = 0;
  for (const n of ir.nodes) {
    const p = paramsOf(n);
    paramsPerNode.set(n.id, p);
    if (p > maxParams) maxParams = p;
  }

  // Top-left position; relative to the parent group's top-left when the id
  // has a laid-out parent (React Flow sub-flow convention).
  const positionOf = (id: string): { x: number; y: number } | null => {
    const dn = g.node(id);
    if (!dn) return null;
    let x = dn.x - dn.width / 2;
    let y = dn.y - dn.height / 2;
    const parentGid = hier.parentOf.get(id);
    const parent = parentGid ? g.node(parentGid) : undefined;
    if (parent) {
      x -= parent.x - parent.width / 2;
      y -= parent.y - parent.height / 2;
    }
    return { x, y };
  };

  // React Flow requires parents to appear before their children.
  const reactNodes: Node[] = [];

  // 1) Groups, shallow-first so a child group's parentNode already exists.
  const groupIds = [...hier.groupForScope.values()].sort(
    (a, b) => a.split("/").length - b.split("/").length,
  );

  // First IR node at each exact scope, for container metadata (repeatCount).
  const irByExactScope = new Map<string, IRNode>();
  for (const n of ir.nodes) {
    const s = scopeOf(n);
    if (!irByExactScope.has(s)) irByExactScope.set(s, n);
  }

  // Container border colour from the bare scope tail; neutral slate fallback.
  const groupColorFor = (tail: string): string => {
    if (/^\d+$/.test(tail)) return colorForOp("TransformerLayer");
    if (tail === "self_attn" || tail === "attn" || tail === "attention") return colorForOp("Attention");
    if (tail === "mlp" || tail === "feed_forward" || tail === "ffn") return colorForOp("MLP");
    if (tail === "experts") return colorForOp("Expert");
    if (tail === "encoder" || tail === "decoder") return colorForOp("TransformerLayer");
    return "#7a8aa3";
  };

  for (const gid of groupIds) {
    const dn = g.node(gid);
    const pos = positionOf(gid);
    if (!dn || !pos) continue;
    const parentGid = hier.parentOf.get(gid);
    const scope = gid.slice(GROUP_ID_PREFIX.length);

    // Class name (e.g. `LlamaDecoderLayer`) wins over the humanised label
    // when it is short enough.
    const className = classNameForScope(scope, modelType);
    const baseLabel = hier.groupLabel.get(gid) ?? "";
    const label = className && className.length <= 36 ? className : baseLabel;

    // The ×N badge only belongs on a folded-loop group; on attention / mlp /
    // experts children an inherited repeat count would be misleading.
    const repeatCount = isLoopScope(scope)
      ? (irByExactScope.get(scope)?.attrs.repeatCount as number | undefined)
      : undefined;

    reactNodes.push({
      id: gid,
      type: "group",
      ...(parentGid ? { parentNode: parentGid } : {}),
      position: pos,
      style: { width: dn.width, height: dn.height },
      data: {
        label,
        sublabel: "",
        color: groupColorFor(lastSeg(scope)),
        isCluster: true,
        intensity: 0,
        isGroupContainer: true,
        repeatCount,
      },
      selectable: false,
      draggable: false,
    });
  }

  // Nodes inside a folded-loop group do not repeat the ×N badge: the
  // container already shows it.
  const loopGroupIds = new Set<string>();
  for (const [scope, gid] of hier.groupForScope) {
    if (isLoopScope(scope)) loopGroupIds.add(gid);
  }
  const isUnderLoopGroup = (startGid: string | undefined): boolean => {
    for (let cur = startGid; cur; cur = hier.parentOf.get(cur)) {
      if (loopGroupIds.has(cur)) return true;
    }
    return false;
  };

  // 2) Leaf IR nodes.
  for (const n of ir.nodes) {
    if (hier.absorbed.has(n.id)) continue;
    const pos = positionOf(n.id);
    if (!pos) continue;
    const parentGid = hier.parentOf.get(n.id);
    const cluster = isClusterNode(n);
    const { w, h } = dimsOf(n);
    const params = paramsPerNode.get(n.id) ?? 0;
    const intensity = maxParams > 0 ? Math.sqrt(params / maxParams) : 0;
    const { label, sublabel } = formatNodeLabel(n, modelType);
    const rawRepeatCount = n.attrs.repeatCount as number | undefined;
    const repeatCount = isUnderLoopGroup(parentGid) ? undefined : rawRepeatCount;

    reactNodes.push({
      id: n.id,
      type: cluster ? "cluster" : "op",
      ...(parentGid ? { parentNode: parentGid } : {}),
      position: pos,
      targetPosition: Position.Top,
      sourcePosition: Position.Bottom,
      data: {
        label,
        sublabel,
        irNode: n,
        color: colorForOp(n.opType),
        isCluster: cluster,
        intensity,
        repeatCount,
      },
      style: { width: w, height: h },
    });
  }

  // Edges (tree edges already dropped; absorbed endpoints already resolved).
  //
  // Routing: every edge is `smoothstep` (orthogonal with rounded corners),
  // which bends around intermediate nodes in dense columns. Forward edges
  // use the default top/bottom handles ("t" / "s"); residual edges use the
  // right-side handles ("t-r" / "s-r") so they run alongside the column.
  const reactEdges: Edge[] = [];
  for (const { e, source, target } of visibleEdges) {
    const isResidual = (e.kind ?? "forward") === "residual";
    // Thin neutral lines for the forward chain, animated orange lines for
    // skip-connections; no arrowheads or tensor labels.
    const style = isResidual
      ? { stroke: "rgba(255,170,90,0.95)", strokeWidth: 1.4 }
      : { stroke: "rgba(200,210,235,0.55)", strokeWidth: 1.2 };
    reactEdges.push({
      id: e.id,
      source,
      target,
      sourceHandle: isResidual ? "s-r" : "s",
      targetHandle: isResidual ? "t-r" : "t",
      style,
      type: "smoothstep",
      pathOptions: { borderRadius: 10, offset: isResidual ? 18 : 8 },
      animated: isResidual,
      // Keep edges that cross group boundaries above the containers.
      zIndex: 1000,
    });
  }

  return { nodes: reactNodes, edges: reactEdges };
}