// hf-model-viewer / src/lib/layout.ts
// Author: tfrere (HF Staff)
// Commit fc01079 (verified): Deploy hf-model-viewer 2026-05-22T16:59:58Z
// Hierarchical graph layout using dagre. We feed it the IR and produce
// React Flow-compatible Node[] and Edge[] arrays.
//
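// Subflows: modules with descendants are rendered as React Flow *group*
// containers that visually wrap their children. The hierarchy is derived
// from the per-node `attrs.scope` string (slash-separated): a scope that
// has at least one descendant in the IR becomes a synthetic container,
// except structural wrappers (`model`, `transformer`, `layers`, `h`,
// `block`), which are flattened, and containers with no IR node of their
// own that would wrap at most one child, which are collapsed.
// Wrapper IRNodes (`opType: "Module"` with no weights) sitting exactly at a
// container's scope are absorbed into it, so we don't render both a wrapper
// pill and its `group:<scope>` container side by side; wrappers at
// flattened structural scopes are dropped.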
import dagre from "@dagrejs/dagre";
import { Position, type Node, type Edge } from "reactflow";
import type { IRGraph, IRNode } from "../types";
/** Prefix that turns a scope path into a group node id (`group:<scope>`). */
const GROUP_ID_PREFIX = "group:";

// Node footprints in px, tuned to the visual density of hfviewer.com: short
// pills carrying a 10-11px label and no decoration. Larger values inflate the
// canvas; smaller ones make the text unreadable.

/** Default footprint for op / module nodes. */
const NODE_WIDTH = 110;
const NODE_HEIGHT = 22;

/** Footprint for cluster nodes (ids prefixed with `cluster:`). */
const CLUSTER_NODE_WIDTH = 130;
const CLUSTER_NODE_HEIGHT = 26;

/**
 * Synthetic ops (Add, Mul, MatMul, Softmax, RoPE, ...) have no params and a
 * single short label, so they are kept extra compact; otherwise a fully
 * expanded transformer block sprawls over dozens of bulky rows.
 */
const SYNTHETIC_NODE_WIDTH = 70;
const SYNTHETIC_NODE_HEIGHT = 20;

/** Identity / boundary ops are drawn as minimal "data flow stubs". */
const IDENTITY_NODE_WIDTH = 44;
const IDENTITY_NODE_HEIGHT = 16;

/**
 * Padding requested around each compound (group) node so its header band has
 * room and children don't touch the border.
 * NOTE(review): dagre's documented node attributes are width/height only;
 * confirm these padding values are actually honoured by the layout.
 */
const GROUP_PADDING_X = 8;
const GROUP_PADDING_TOP = 12;
const GROUP_PADDING_BOTTOM = 6;
/** Exact-match colours for the semantic super-node types (from grouping.ts). */
const SEMANTIC_OP_COLORS = new Map<string, string>([
  ["TransformerLayer", "#5d4eff"],
  ["Attention", "#8a5cff"],
  ["MLP", "#3c8ce6"],
  ["Embedding", "#f0925b"],
  ["Norm", "#a06bd6"],
  ["Linear", "#4f78e0"],
  ["RoPE", "#c25fe6"],
  ["Router", "#ffae5b"],
  ["Expert", "#5e9efb"],
  ["Module", "#506280"],
]);

/**
 * ONNX op-level fallback: substring patterns tried in order, first hit wins.
 * Order matters, e.g. `MatMul` must be tested before the generic `Mul`.
 */
const ONNX_OP_COLORS: ReadonlyArray<readonly [RegExp, string]> = [
  [/MatMul|Gemm|Conv/, "#5b8def"],
  [/Attention|Softmax/, "#7c5cff"],
  [/Norm|LayerNormalization|RMSNormalization/, "#9b6bd1"],
  [/Add|Mul|Sub|Div|Pow/, "#3aa6c9"],
  [/Gelu|Relu|Silu|Swish|Tanh|Sigmoid/, "#22c1a0"],
  [/Gather|Slice|Reshape|Transpose|Squeeze|Unsqueeze|Concat/, "#7a8aa3"],
  [/Cast|Identity|Constant|Quantize|DequantizeLinear/, "#4a5263"],
];

/**
 * Fill colour for a node, chosen from its op type: semantic super-node types
 * first, then ONNX op patterns, then the IO pills, then a neutral default.
 */
function colorForOp(opType: string): string {
  const semantic = SEMANTIC_OP_COLORS.get(opType);
  if (semantic !== undefined) return semantic;

  const hit = ONNX_OP_COLORS.find(([pattern]) => pattern.test(opType));
  if (hit) return hit[1];

  switch (opType) {
    case "Input":
      return "#f6a96b";
    case "Output":
      return "#e25fa3";
    default:
      return "#6477a8";
  }
}
/** True when the node is a cluster node, i.e. its id carries the `cluster:` prefix. */
function isClusterNode(n: IRNode): boolean {
  const { id } = n;
  return id.startsWith("cluster:");
}
// ─── Human-friendly label formatting ─────────────────────────────────────
//
// Raw module names like `input_layernorm`, `q_proj`, `embed_tokens`,
// `post_attention_layernorm`, ... read fine in code but are noisy in a
// graph node. The helpers below map the most common transformer module
// suffixes to short, semantically meaningful labels.
/**
 * Map a normalization module's scope tail to a short role label used as the
 * node's sublabel (e.g. `input_layernorm` → "pre-attn").
 *
 * Recognises Llama/Gemma-style names (`input_layernorm`,
 * `post_attention_layernorm`, `*_feedforward_layernorm`, `norm`), GPT-2-style
 * names (`ln_1`, `ln_2`, `ln_f`, plus the unseparated `ln1`/`ln2`) and the
 * `final_layer_norm` / `final_layernorm` spellings. Any other `*_layernorm`
 * or `*_norm` has its trailing suffix stripped and `_` turned into `-`.
 *
 * @returns The role label, or "" when the name carries no recognisable role.
 */
function humanizeNorm(tail: string): string {
  switch (tail) {
    case "input_layernorm":
    case "ln1":
    case "ln_1":
      return "pre-attn";
    case "post_attention_layernorm":
    case "ln2":
    case "ln_2":
      return "post-attn";
    case "pre_feedforward_layernorm":
      return "pre-ffn";
    case "post_feedforward_layernorm":
      return "post-ffn";
    case "norm":
    case "final_layernorm":
    case "final_layer_norm":
    case "ln_f":
      return "final";
  }
  // Strip the actual trailing suffix: String#replace with a string pattern
  // removes the FIRST occurrence, which need not be the trailing one.
  for (const suffix of ["_layernorm", "_norm"]) {
    if (tail.endsWith(suffix)) {
      return tail.slice(0, -suffix.length).replace(/_/g, "-");
    }
  }
  return "";
}
/**
 * Short role label for a Linear layer's scope tail (`q_proj` → "Q proj",
 * `w2` → "down", ...). Unknown `*_proj` names lose their `_proj` and have
 * underscores turned into spaces; anything else yields "".
 */
function humanizeLinear(tail: string): string {
  // A Map (not an object literal) so names like "toString" never hit the
  // prototype chain.
  const aliases = new Map<string, string>([
    ["q_proj", "Q proj"],
    ["k_proj", "K proj"],
    ["v_proj", "V proj"],
    ["o_proj", "O proj"],
    ["qkv_proj", "QKV"],
    ["c_attn", "QKV"],
    ["gate_proj", "gate"],
    ["w1", "gate"],
    ["up_proj", "up"],
    ["w3", "up"],
    ["down_proj", "down"],
    ["w2", "down"],
    ["c_proj", "c proj"],
    ["c_fc", "c fc"],
  ]);
  const alias = aliases.get(tail);
  if (alias !== undefined) return alias;
  return tail.endsWith("_proj")
    ? tail.replace("_proj", "").replace(/_/g, " ")
    : "";
}
/**
 * Sublabel for an Embedding node: "tokens", "positions" or "segments" for
 * well-known table names, otherwise the tail with underscores as spaces.
 */
function humanizeEmbedding(tail: string): string {
  const tokenTables = ["embed_tokens", "wte", "word_embeddings"];
  const positionTables = ["wpe", "embed_positions", "position_embeddings"];
  if (tokenTables.includes(tail)) return "tokens";
  if (positionTables.includes(tail)) return "positions";
  return tail === "token_type_embeddings" ? "segments" : tail.replace(/_/g, " ");
}
/** Generic scope-tail prettifier: "12" → "#12", "lm_head" → "lm head", "" → "". */
function humanizeGeneric(tail: string): string {
  if (!tail) {
    return "";
  }
  return /^\d+$/.test(tail) ? "#" + tail : tail.split("_").join(" ");
}
/**
 * Build the Python class name that most likely backs a given GROUP scope,
 * given the HF `model_type` from config.json. Mirrors the transformers
 * naming convention so a `model/layers/0` group under modelType=llama
 * renders as `LlamaDecoderLayer`, matching the hfviewer.com convention.
 *
 * The model-type prefix upper-cases the first letter of every `_`/`-`
 * separated segment and lower-cases the rest (`qwen2_moe` → `Qwen2Moe`),
 * since transformers class names never contain underscores. Acronym-style
 * prefixes (e.g. `GPT2`) are not reproduced.
 *
 * @returns The class name, or null when `modelType` is missing/empty or the
 *   scope has no recognised role.
 */
function classNameForScope(scope: string, modelType?: string): string | null {
  if (!modelType) return null;
  const prefix = modelType
    .split(/[_-]/)
    .filter(Boolean)
    .map((seg) => seg.charAt(0).toUpperCase() + seg.slice(1).toLowerCase())
    .join("");
  if (!prefix) return null;

  const segs = scope.split("/").filter(Boolean);
  const tail = segs[segs.length - 1] ?? "";
  const parent = segs[segs.length - 2] ?? "";

  // A numeric child of a `layers`/`h`/`block` list is a decoder layer.
  if (/^\d+$/.test(tail) && (parent === "layers" || parent === "h" || parent === "block")) {
    return `${prefix}DecoderLayer`;
  }
  switch (tail) {
    case "self_attn":
    case "attn":
    case "attention":
      return `${prefix}Attention`;
    case "mlp":
    case "feed_forward":
    case "ffn":
      return `${prefix}MLP`;
    case "encoder":
      return `${prefix}Encoder`;
    case "decoder":
      return `${prefix}Decoder`;
    default:
      return null;
  }
}
/**
 * Header-band label for a group/subflow container. Names are kept short and
 * recognisable; the container itself already signals "sub-block".
 * Known ModuleList / attention / MLP spellings are normalised, numeric
 * tails become "layer #N", everything else has underscores as spaces.
 */
function humanizeGroupLabel(tail: string): string {
  if (!tail) return "";
  switch (tail) {
    case "layers":
    case "h":
    case "block":
      return "layers";
    case "self_attn":
    case "attn":
    case "attention":
      return "attention";
    case "mlp":
    case "feed_forward":
    case "ffn":
      return "mlp";
  }
  // `model`, `transformer`, `encoder`, `decoder`, `experts` contain no
  // underscores, so the generic path below returns them unchanged.
  return /^\d+$/.test(tail) ? `layer #${tail}` : tail.replace(/_/g, " ");
}
/** Text shown on an op / cluster node pill. */
interface FormattedLabel {
  /** Main label: a class name, op type, or a synthetic op's own label. */
  label: string;
  /** Secondary detail (e.g. a humanised role); "" when there is none. */
  sublabel: string;
}
/**
 * Build the Python class name that most likely backs a given node, using
 * the HF `model_type` from config.json. We mirror the transformers naming
 * convention: `Llama` + `DecoderLayer` = `LlamaDecoderLayer`. This is what
 * hfviewer.com shows on its main node labels.
 *
 * The prefix upper-cases the first letter of every `_`/`-` separated
 * segment of `modelType` and lower-cases the rest (`qwen2_moe` →
 * `Qwen2Moe`). Acronym-style prefixes (e.g. `GPT2`) are not reproduced.
 *
 * @returns The class name, or null when `modelType` is missing/empty or the
 *   node has no recognised role.
 */
function classNameForNode(n: IRNode, modelType?: string): string | null {
  if (!modelType) return null;
  const prefix = modelType
    .split(/[_-]/)
    .filter(Boolean)
    .map((seg) => seg.charAt(0).toUpperCase() + seg.slice(1).toLowerCase())
    .join("");
  if (!prefix) return null;

  // Model families whose transformers implementation names its norm
  // `<Prefix>RMSNorm` (e.g. LlamaRMSNorm, Qwen2MoeRMSNorm, Phi3RMSNorm).
  const rmsNormModelTypes = new Set([
    "llama",
    "mistral",
    "mixtral",
    "qwen2",
    "qwen2_moe",
    "qwen3",
    "qwen3_moe",
    "gemma",
    "gemma2",
    "phi3",
  ]);

  const segs = scopeOf(n).split("/").filter(Boolean);
  const tail = segs[segs.length - 1] ?? "";
  const parentTail = segs[segs.length - 2] ?? "";

  // A numeric tail under a `layers`/`h`/`block` parent is a transformer layer.
  if (
    /^\d+$/.test(tail) &&
    (parentTail === "layers" || parentTail === "h" || parentTail === "block")
  ) {
    return `${prefix}DecoderLayer`;
  }
  switch (n.opType) {
    case "TransformerLayer":
      return `${prefix}DecoderLayer`;
    case "Attention":
      return `${prefix}Attention`;
    case "MLP":
      return `${prefix}MLP`;
    case "Norm":
      return rmsNormModelTypes.has(modelType.toLowerCase()) ? `${prefix}RMSNorm` : null;
    default:
      return null;
  }
}
/**
 * Compute the primary label and optional sublabel for an IR node pill.
 *
 * Precedence:
 *  1. a short synthetic `attrs.label` (1-24 chars) is used verbatim;
 *  2. a Linear whose scope tail is `lm_head` becomes "LM head";
 *  3. a Module with a numeric scope tail (folded layer wrapper) shows its
 *     transformers class name, else "Layer #N";
 *  4. otherwise formatting depends on `opType`, preferring the class name
 *     from `classNameForNode` where one can be derived.
 */
function formatNodeLabel(n: IRNode, modelType?: string): FormattedLabel {
  // Synthetic ops carry an intentional, concise label (e.g. "+", "h_in").
  // Narrow with `typeof` instead of a cast so a non-string value that also
  // has a `length` (e.g. an array) is never returned as the label.
  const synLabel = n.attrs.label;
  if (typeof synLabel === "string" && synLabel.length > 0 && synLabel.length <= 24) {
    return { label: synLabel, sublabel: "" };
  }

  const opType = n.opType;
  const tail = scopeOf(n).split("/").filter(Boolean).pop() ?? "";

  // Use the transformers class name when one can be derived, else `fallback`.
  const classOr = (fallback: string, sublabel = ""): FormattedLabel => ({
    label: classNameForNode(n, modelType) || fallback,
    sublabel,
  });

  // A top-level Linear named lm_head is the model's output head: treat it as
  // a first-class concept rather than "Linear / lm_head".
  if (opType === "Linear" && tail === "lm_head") {
    return { label: "LM head", sublabel: "" };
  }

  // Folded transformer layer wrapper (numeric tail): surface the Python class
  // name (e.g. `LlamaDecoderLayer`), which is more informative than the bare
  // layer index and matches the hfviewer.com convention.
  if (/^\d+$/.test(tail) && opType === "Module") {
    return classOr(`Layer #${tail}`);
  }

  switch (opType) {
    case "Embedding":
      return { label: "Embedding", sublabel: humanizeEmbedding(tail) };
    case "Attention":
      return classOr("Attention");
    case "MLP":
      return classOr("MLP");
    case "TransformerLayer":
      return classOr("Decoder layer");
    case "Norm":
      return classOr("Norm", humanizeNorm(tail));
    case "Linear":
      return { label: "Linear", sublabel: humanizeLinear(tail) || tail };
    case "RoPE":
      return { label: "RoPE", sublabel: "" };
    case "Router":
      return { label: "Router", sublabel: "" };
    case "Expert":
      return { label: "Expert", sublabel: humanizeGeneric(tail) };
    case "Module":
      // A generic Module wrapper: the scope tail is usually a Python
      // identifier, so humanise it.
      return { label: humanizeGeneric(tail) || "Module", sublabel: "" };
    default: {
      // Raw op: show the op type, plus the scope tail when it adds info.
      const pretty = humanizeGeneric(tail);
      const sublabel =
        pretty && pretty.toLowerCase() !== opType.toLowerCase() ? pretty : "";
      return { label: opType, sublabel };
    }
  }
}
/** IO pills live at the empty scope or under an `__io__`-prefixed scope. */
function isIoScope(scope: string): boolean {
  if (scope.length === 0) return true;
  return scope.startsWith("__io__");
}
/**
 * The slash-separated module scope stored on an IR node (`attrs.scope`),
 * trimmed. Returns "" when the attribute is missing or not a string; the
 * previous cast-then-`.trim()` threw on non-string values such as numbers.
 */
function scopeOf(n: IRNode): string {
  const raw = n.attrs.scope;
  return typeof raw === "string" ? raw.trim() : "";
}
/**
 * Proper ancestor scopes of `scope`, shallowest first
 * (`a/b/c` → [`a`, `a/b`]). Empty segments are ignored.
 */
function ancestorScopes(scope: string): string[] {
  const parts = scope.split("/").filter(Boolean);
  return parts.slice(0, -1).map((_, i) => parts.slice(0, i + 1).join("/"));
}
/** Scope one level up (`a/b/c` → `a/b`); "" for top-level or empty scopes. */
function parentScope(scope: string): string {
  const parts = scope.split("/").filter(Boolean);
  parts.pop();
  return parts.join("/");
}
/** Final non-empty segment of a scope path ("" when there is none). */
function lastSeg(scope: string): string {
  return scope.split("/").filter(Boolean).pop() ?? "";
}
/** Visual containment computed by `buildHierarchy`. */
interface Hierarchy {
  /** scope path → group id (only for scopes that survived as real groups). */
  groupForScope: Map<string, string>;
  /** group id → direct child ids: contained IR node ids first (in
   *  `ir.nodes` order), then sub-group ids (in group-set iteration order). */
  childrenOf: Map<string, string[]>;
  /** child id (IR node or group) → parent group id. Only ids with a parent
   *  group appear; absorbed IR nodes never get an entry. */
  parentOf: Map<string, string>;
  /** group id → label to render in the container header. */
  groupLabel: Map<string, string>;
  /** IR node ids that must not be emitted as React Flow nodes: Module
   *  wrappers at a group's exact scope (their label is promoted to the
   *  group) and Module wrappers at flattened structural scopes such as
   *  `model` or `layers` (dropped). */
  absorbed: Set<string>;
  /** Ids with no parent group: non-absorbed IR nodes (IO pills included) in
   *  `ir.nodes` order, followed by top-level group ids. IO pills are not
   *  reordered to the front/back. */
  topLevel: string[];
}
/**
 * Build the visual hierarchy:
 * - A scope P becomes a group iff at least one IR node has scope deeper
 * than P (i.e. P has descendants in the IR) and P's last segment is not a
 * structural wrapper (`model`, `transformer`, `layers`, `h`, `block`).
 * - "Trivial" groups (no IR node at their exact scope AND at most one
 * child overall — one sub-group or one contained IR node) are collapsed:
 * the group is dropped and its sub-group, if any, is re-parented to the
 * dropped group's parent.
 * - A wrapper IRNode (opType="Module", not synthetic, no weights, at most
 * one member) at the exact scope of a group is absorbed into the group;
 * its `semanticLabel` (or scope tail) becomes the group's label.
 * - A wrapper IRNode at a flattened structural scope is absorbed too, i.e.
 * dropped with nothing rendered in its place.
 * - IR nodes whose scope is empty or starts with `__io__` (IO pills) stay
 * top-level.
 *
 * NOTE: absorbed IR nodes never receive a `parentOf` entry, so callers
 * cannot use `parentOf` to redirect edges that touch them.
 */
function buildHierarchy(ir: IRGraph): Hierarchy {
// 1. Bucket IR nodes by scope, ignoring IO.
const byScope = new Map<string, IRNode[]>();
const allCandidateScopes = new Set<string>();
for (const n of ir.nodes) {
const s = scopeOf(n);
if (isIoScope(s)) continue;
if (!byScope.has(s)) byScope.set(s, []);
byScope.get(s)!.push(n);
allCandidateScopes.add(s);
for (const a of ancestorScopes(s)) allCandidateScopes.add(a);
}
// 2. Mark which scopes have descendants (strictly deeper IR nodes).
const hasDescendants = new Set<string>();
for (const s of byScope.keys()) {
for (const a of ancestorScopes(s)) hasDescendants.add(a);
}
// 3. Initial group decision: a scope is a group if it has descendants.
// (Pure leaf scopes with no children don't need a container.)
// EXCEPT for "structural wrapper" scopes — the bare `model`, the
// ModuleList container `layers` (a.k.a. `h` or `block` in HF code), or
// the global `transformer` root. These carry no semantic information
// that the user cares about, they just add a layer of nesting around
// every node. hfviewer.com flattens them by default, and so do we.
// Only the LAST path segment is checked, so e.g. `encoder/layers` is
// flattened as well.
// The semantic class names (`LlamaDecoderLayer`, `LlamaAttention`,
// `LlamaMLP`) come from scopes one level deeper (`layers/0`, `self_attn`,
// `mlp`), which DO become groups.
const STRUCTURAL_WRAPPERS = new Set<string>([
"model",
"transformer",
"layers",
"h",
"block",
]);
const isStructuralWrapperScope = (scope: string): boolean => {
const tail = lastSeg(scope);
return STRUCTURAL_WRAPPERS.has(tail);
};
const isGroup = new Set<string>();
for (const s of allCandidateScopes) {
if (hasDescendants.has(s) && !isStructuralWrapperScope(s)) isGroup.add(s);
}
// 4. For each group G, record its parent group: the longest proper
// ancestor scope of G that is also a group (null when G is top-level).
// Ancestors that are not groups (e.g. flattened structural wrappers) are
// skipped over.
const parentScopeOfGroup = new Map<string, string | null>();
for (const g of isGroup) {
let cur = parentScope(g);
let parent: string | null = null;
while (cur !== "") {
if (isGroup.has(cur)) {
parent = cur;
break;
}
cur = parentScope(cur);
}
parentScopeOfGroup.set(g, parent);
}
// 5. Collapse "trivial" groups bottom-up: a group with no IR node at its
// own exact scope that wraps at most one child (one sub-group OR one
// contained IR node) is visual noise. We drop it and re-parent its
// sub-group (if any) to the dropped group's parent. Structural wrappers
// such as `layers` were already excluded in step 3; this pass handles
// any remaining single-child containers.
// Repeats until a full pass removes nothing.
const closestContainingGroup = (scope: string): string | null => {
let cur: string = scope;
while (cur !== "") {
if (isGroup.has(cur)) return cur;
cur = parentScope(cur);
}
return null;
};
const sortedDeepFirst = [...isGroup].sort(
(a, b) => b.split("/").length - a.split("/").length,
);
let changed = true;
while (changed) {
changed = false;
// Count direct IR children per group: an IR node belongs to its deepest
// containing group. Sub-groups don't count here. These counts include
// wrapper nodes that step 6 later absorbs.
const irChildCount = new Map<string, number>();
for (const n of ir.nodes) {
const s = scopeOf(n);
if (isIoScope(s)) continue;
const g = closestContainingGroup(s);
if (!g) continue;
irChildCount.set(g, (irChildCount.get(g) ?? 0) + 1);
}
// Count sub-groups per group based on the live parentScopeOfGroup.
const subGroupCount = new Map<string, number>();
const subGroupsOf = new Map<string, string[]>();
for (const other of isGroup) {
const p = parentScopeOfGroup.get(other);
if (p && isGroup.has(p)) {
subGroupCount.set(p, (subGroupCount.get(p) ?? 0) + 1);
const arr = subGroupsOf.get(p) ?? [];
arr.push(other);
subGroupsOf.set(p, arr);
}
}
// The counts above are computed once per pass and are not refreshed as
// groups are dropped below; any drop sets `changed`, so the next pass
// re-evaluates with fresh counts.
for (const g of sortedDeepFirst) {
if (!isGroup.has(g)) continue;
const ownIR = byScope.get(g)?.length ?? 0;
if (ownIR > 0) continue;
const ir1 = irChildCount.get(g) ?? 0;
const subs = subGroupCount.get(g) ?? 0;
if (ir1 + subs <= 1) {
// Trivial: drop g, re-parent any sub-group of g to g's parent.
const myParent = parentScopeOfGroup.get(g) ?? null;
for (const sub of subGroupsOf.get(g) ?? []) {
parentScopeOfGroup.set(sub, myParent);
}
isGroup.delete(g);
parentScopeOfGroup.delete(g);
changed = true;
}
}
}
// 6. Build the result. Group ids are GROUP_ID_PREFIX + full scope path;
// the default label is the humanised last scope segment.
const groupForScope = new Map<string, string>();
const groupLabel = new Map<string, string>();
for (const g of isGroup) {
groupForScope.set(g, GROUP_ID_PREFIX + g);
groupLabel.set(GROUP_ID_PREFIX + g, humanizeGroupLabel(lastSeg(g)));
}
// For each IR node, find its containing group: the deepest group scope
// that is a (non-strict) ancestor of the node's scope. If the node's
// own scope is a group, we still want to assign the node TO that group
// (so e.g. synthetic ops at scope "model/layers/0" go inside
// "group:model/layers/0").
const parentOf = new Map<string, string>();
const absorbed = new Set<string>();
const childrenOf = new Map<string, string[]>();
const ensureChildren = (id: string): string[] => {
let arr = childrenOf.get(id);
if (!arr) {
arr = [];
childrenOf.set(id, arr);
}
return arr;
};
// Group id of the deepest group whose scope is `scope` or one of its
// ancestors; null when no such group exists.
const findContainingGroup = (scope: string): string | null => {
let cur: string | null = scope;
while (cur !== null) {
if (cur === "") return null;
if (isGroup.has(cur)) return groupForScope.get(cur) ?? null;
const p = parentScope(cur);
cur = p === "" ? "" : p;
if (cur === "") return null;
}
return null;
};
// Order IR nodes so that within the same group siblings keep a stable
// sort. We rely on `ir.nodes` order being stable from buildIR.
for (const n of ir.nodes) {
const s = scopeOf(n);
if (isIoScope(s)) continue;
// Wrapper IRNodes that live exactly at a flattened structural scope
// (e.g. the bare `model` module wrapper) carry no meaningful info on
// their own — they were the implicit container above the real
// sub-blocks. Now that we no longer render that container, the
// wrapper would otherwise float as a stray "model" pill. Drop it
// (absorbed with no parentOf entry).
if (
isStructuralWrapperScope(s) &&
n.opType === "Module" &&
n.attrs.synthetic !== true &&
n.weights.length === 0
) {
absorbed.add(n.id);
continue;
}
const gid = findContainingGroup(s);
if (!gid) continue;
// Wrapper absorption: if this IR node is the canonical "Module wrapper"
// at the EXACT scope of the group, absorb it. CRITICAL: a cluster that
// merges multiple members (numChildren > 1) typically carries connectivity
// (e.g. residual adds inside a layer scope). Absorbing such a cluster
// would silently drop edges that point at it, so we only absorb pure
// single-member wrappers.
const gscope = gid.slice(GROUP_ID_PREFIX.length);
const numChildren = (n.attrs.numChildren as number | undefined) ?? 1;
const isWrapperAtGroupScope =
s === gscope &&
n.opType === "Module" &&
n.attrs.synthetic !== true &&
n.weights.length === 0 &&
numChildren <= 1;
if (isWrapperAtGroupScope) {
absorbed.add(n.id);
// Promote the wrapper's label to the group (it's usually more
// informative than the bare last segment). No parentOf entry is
// recorded for the absorbed wrapper.
const sem = n.attrs.semanticLabel as string | undefined;
const raw = sem ?? lastSeg(s);
groupLabel.set(gid, humanizeGroupLabel(raw));
continue;
}
parentOf.set(n.id, gid);
ensureChildren(gid).push(n.id);
}
// 7. Sub-group containment: each group's parent group (if any). These are
// appended to childrenOf after the IR children collected above.
for (const g of isGroup) {
const gid = GROUP_ID_PREFIX + g;
const parentG = parentScopeOfGroup.get(g);
if (parentG && isGroup.has(parentG)) {
const parentGid = GROUP_ID_PREFIX + parentG;
parentOf.set(gid, parentGid);
ensureChildren(parentGid).push(gid);
}
}
// 8. Compute top-level: non-absorbed IR nodes with no parent group (IO
// pills included) in original IR order, followed by top-level groups.
// IO pills are NOT moved to the front/back here.
const topLevel: string[] = [];
const seenTop = new Set<string>();
const pushTop = (id: string): void => {
if (seenTop.has(id)) return;
seenTop.add(id);
topLevel.push(id);
};
for (const n of ir.nodes) {
if (absorbed.has(n.id)) continue;
if (parentOf.has(n.id)) continue;
pushTop(n.id);
}
for (const g of isGroup) {
const gid = GROUP_ID_PREFIX + g;
if (parentOf.has(gid)) continue;
pushTop(gid);
}
return {
groupForScope,
childrenOf,
parentOf,
groupLabel,
absorbed,
topLevel,
};
}
/** Output of `layoutGraph`, ready to hand to React Flow. */
export interface LayoutResult {
  /** Group containers first (parents before children), then leaf nodes. */
  nodes: Node[];
  /** Rendered (non-tree) edges. */
  edges: Edge[];
}
/**
 * Lay out an IR graph top-to-bottom with dagre and convert the result into
 * React Flow nodes and edges.
 *
 * Scopes chosen by `buildHierarchy` become `type: "group"` container nodes,
 * emitted shallow-first so every `parentNode` precedes its children (a React
 * Flow requirement). Absorbed IR nodes are not rendered, and `tree` edges are
 * dropped; the remaining edges drive both dagre's ranking and the rendered
 * React Flow edges.
 *
 * @param ir Graph to lay out. `ir.meta.modelType`, when present, is used to
 *   derive transformers-style class names for node and group labels.
 * @returns React Flow nodes (parents before children) and edges.
 */
export function layoutGraph(ir: IRGraph): LayoutResult {
  const hier = buildHierarchy(ir);
  const modelType = ir.meta.modelType;

  const isIdentity = (n: IRNode) => n.opType === "Identity";
  const dimsOf = (n: IRNode): { w: number; h: number } => {
    if (isClusterNode(n)) return { w: CLUSTER_NODE_WIDTH, h: CLUSTER_NODE_HEIGHT };
    if (isIdentity(n)) return { w: IDENTITY_NODE_WIDTH, h: IDENTITY_NODE_HEIGHT };
    // Synthetic ops (marked `synthetic: true`) are kept slim so a fully
    // expanded transformer block stays vertically compact.
    if (n.attrs.synthetic === true) {
      return { w: SYNTHETIC_NODE_WIDTH, h: SYNTHETIC_NODE_HEIGHT };
    }
    return { w: NODE_WIDTH, h: NODE_HEIGHT };
  };

  // IR nodes that become dagre leaves and React Flow nodes.
  const renderedNodes = ir.nodes.filter((n) => !hier.absorbed.has(n.id));
  const renderedIds = new Set(renderedNodes.map((n) => n.id));

  // Edges that drive both the layout and the rendering. `tree` edges are
  // visually redundant once group containers exist. `buildHierarchy` never
  // records a parent for absorbed wrappers, so an edge touching one has no
  // rendered endpoint to attach to: passing its id to `setEdge` would make
  // graphlib create an unlabeled node that dagre then lays out as an extra
  // node, and React Flow would have nothing to connect. Such edges (along
  // with self-loops and edges to unknown ids) are skipped.
  const visibleEdges = ir.edges.filter(
    (e) =>
      e.kind !== "tree" &&
      e.source !== e.target &&
      renderedIds.has(e.source) &&
      renderedIds.has(e.target),
  );

  // Top-to-bottom pipeline layout, the convention of the model graphs on
  // hfviewer.com (TinyLlama, GPT-2, ViT, T5, Qwen, ...): input on top,
  // output at the bottom, loop-folded decoder layers as single nodes with
  // an ×N badge.
  //
  // Vertical rank spacing is adaptive: nodes are already compact (16-26 px
  // tall), so a generous gap would dominate the canvas and make every block
  // feel "too tall". At op-level granularity a transformer can expose 30-50
  // nodes in a column, so the gap shrinks from 12 down to 4 as the number of
  // rendered nodes grows. The horizontal gap (`nodesep`) is kept constant
  // since siblings rarely share a rank at large counts.
  const ranksep = Math.round(
    Math.max(4, Math.min(12, 12 - (renderedNodes.length - 8) * 0.35)),
  );
  const g = new dagre.graphlib.Graph({ compound: true });
  g.setGraph({
    rankdir: "TB",
    nodesep: 26,
    ranksep,
    marginx: 24,
    marginy: 24,
  });
  g.setDefaultEdgeLabel(() => ({}));

  for (const n of renderedNodes) {
    const { w, h } = dimsOf(n);
    g.setNode(n.id, { width: w, height: h });
  }
  // Groups are compound parents; dagre sizes them from the children they
  // contain.
  // NOTE(review): dagre's documented node attributes are width/height; the
  // padding* keys may be ignored by its layout — verify GROUP_PADDING_*
  // actually takes effect.
  for (const gid of hier.groupForScope.values()) {
    g.setNode(gid, {
      paddingTop: GROUP_PADDING_TOP,
      paddingBottom: GROUP_PADDING_BOTTOM,
      paddingLeft: GROUP_PADDING_X,
      paddingRight: GROUP_PADDING_X,
    });
  }
  for (const [childId, parentGid] of hier.parentOf.entries()) {
    g.setParent(childId, parentGid);
  }
  for (const e of visibleEdges) g.setEdge(e.source, e.target);

  dagre.layout(g);

  // Per-node param shading (heatmap). Normalised over rendered nodes only,
  // so a hidden (absorbed) node cannot set the scale for visible ones.
  const paramsOf = (n: IRNode): number => {
    const fromAttr = n.attrs.totalParams;
    if (typeof fromAttr === "number" && fromAttr > 0) return fromAttr;
    return n.weights.reduce((a, w) => a + w.numParams, 0);
  };
  const paramsPerNode = new Map<string, number>();
  let maxParams = 0;
  for (const n of renderedNodes) {
    const p = paramsOf(n);
    paramsPerNode.set(n.id, p);
    if (p > maxParams) maxParams = p;
  }

  // Build React Flow nodes. IMPORTANT: parents must appear before children
  // in the output array (React Flow requirement).
  const reactNodes: Node[] = [];
  // Top-left corner of a dagre node (dagre reports centres), made relative
  // to its parent group's top-left when it has one.
  const positionOf = (id: string): { x: number; y: number } | null => {
    const dn = g.node(id);
    if (!dn) return null;
    const topLeft = { x: dn.x - dn.width / 2, y: dn.y - dn.height / 2 };
    const parentGid = hier.parentOf.get(id);
    const parent = parentGid ? g.node(parentGid) : undefined;
    if (!parent) return topLeft;
    return {
      x: topLeft.x - (parent.x - parent.width / 2),
      y: topLeft.y - (parent.y - parent.height / 2),
    };
  };

  // 1) Emit groups (parents) in shallow-first order so a child group's
  //    parentNode reference is always already present.
  const groupIds = [...hier.groupForScope.values()].sort(
    (a, b) => a.split("/").length - b.split("/").length,
  );
  // Index IR nodes by exact scope so we can pull rich metadata (semantic
  // class name, repeatCount) for the container that sits at that scope.
  const irByExactScope = new Map<string, IRNode>();
  for (const n of ir.nodes) {
    const s = scopeOf(n);
    if (!irByExactScope.has(s)) irByExactScope.set(s, n);
  }
  for (const gid of groupIds) {
    const dn = g.node(gid);
    const pos = positionOf(gid);
    if (!dn || !pos) continue;
    const parentGid = hier.parentOf.get(gid);
    const scope = gid.slice(GROUP_ID_PREFIX.length);
    const scopeTail = lastSeg(scope);
    const segs = scope.split("/").filter(Boolean);
    const parentTail = segs[segs.length - 2] ?? "";
    // Pull semantic info from the IR node at this exact scope (typically a
    // Module wrapper or a folded-loop cluster). Its repeatCount lets us show
    // the ×N badge on the container itself. The label comes from
    // `classNameForScope` so a layer group under model_type=llama reads as
    // `LlamaDecoderLayer`.
    const scopeNode = irByExactScope.get(scope);
    const className = classNameForScope(scope, modelType);
    const baseLabel = hier.groupLabel.get(gid) ?? "";
    const label = className && className.length <= 36 ? className : baseLabel;
    // The ×N badge only makes sense for groups that wrap a folded loop
    // (a numeric tail under `layers`/`h`/`block`). Elsewhere the repeat
    // count is inherited from a parent loop and would be misleading on
    // attention / mlp / experts children.
    const isLoopGroup =
      /^\d+$/.test(scopeTail) &&
      (parentTail === "layers" || parentTail === "h" || parentTail === "block");
    const repeatCount = isLoopGroup
      ? (scopeNode?.attrs.repeatCount as number | undefined)
      : undefined;
    // Colour the container border by its semantic role, derived from the
    // bare scope tail; unrecognised wrappers get a neutral slate.
    const groupColor = ((tail: string): string => {
      if (/^\d+$/.test(tail)) return colorForOp("TransformerLayer");
      if (tail === "self_attn" || tail === "attn" || tail === "attention")
        return colorForOp("Attention");
      if (tail === "mlp" || tail === "feed_forward" || tail === "ffn")
        return colorForOp("MLP");
      if (tail === "experts") return colorForOp("Expert");
      if (tail === "encoder" || tail === "decoder")
        return colorForOp("TransformerLayer");
      return "#7a8aa3";
    })(scopeTail);
    reactNodes.push({
      id: gid,
      type: "group",
      ...(parentGid ? { parentNode: parentGid } : {}),
      position: pos,
      style: { width: dn.width, height: dn.height },
      data: {
        label,
        sublabel: "",
        color: groupColor,
        isCluster: true,
        intensity: 0,
        isGroupContainer: true,
        repeatCount,
      },
      selectable: false,
      draggable: false,
    });
  }

  // Groups that wrap a folded loop already show a ×N badge on their
  // container; nodes inside such a group must not repeat it.
  const loopGroupIds = new Set<string>();
  for (const [scope, gid] of hier.groupForScope.entries()) {
    const segs = scope.split("/").filter(Boolean);
    const t = segs[segs.length - 1] ?? "";
    const p = segs[segs.length - 2] ?? "";
    if (/^\d+$/.test(t) && (p === "layers" || p === "h" || p === "block")) {
      loopGroupIds.add(gid);
    }
  }
  const isUnderLoopGroup = (startGid: string | undefined): boolean => {
    let cur = startGid;
    while (cur) {
      if (loopGroupIds.has(cur)) return true;
      cur = hier.parentOf.get(cur);
    }
    return false;
  };

  // 2) Emit leaf IR nodes.
  for (const n of renderedNodes) {
    const pos = positionOf(n.id);
    if (!pos) continue;
    const parentGid = hier.parentOf.get(n.id);
    const cluster = isClusterNode(n);
    const { w, h } = dimsOf(n);
    const params = paramsPerNode.get(n.id) ?? 0;
    const intensity = maxParams > 0 ? Math.sqrt(params / maxParams) : 0;
    const { label, sublabel } = formatNodeLabel(n, modelType);
    const repeatCount = isUnderLoopGroup(parentGid)
      ? undefined
      : (n.attrs.repeatCount as number | undefined);
    reactNodes.push({
      id: n.id,
      type: cluster ? "cluster" : "op",
      ...(parentGid ? { parentNode: parentGid } : {}),
      position: pos,
      targetPosition: Position.Top,
      sourcePosition: Position.Bottom,
      data: {
        label,
        sublabel,
        irNode: n,
        color: colorForOp(n.opType),
        isCluster: cluster,
        intensity,
        repeatCount,
      },
      style: { width: w, height: h },
    });
  }

  // Edges (hfviewer look): thin neutral lines for the forward chain,
  // animated orange lines for skip connections, no arrowheads or tensor
  // labels — direction is implied by the top→bottom layout.
  //
  // Routing: every edge is `smoothstep` (orthogonal with rounded corners),
  // which bends around intermediate nodes in a dense column instead of
  // cutting through them like a bezier would. Forward edges use the
  // top/bottom handles ("t" / "s"); residual edges use the right-side
  // handles ("t-r" / "s-r") with a larger offset so they run alongside the
  // column rather than over the forward chain.
  const reactEdges: Edge[] = [];
  for (const e of visibleEdges) {
    const isResidual = e.kind === "residual";
    const style = isResidual
      ? { stroke: "rgba(255,170,90,0.95)", strokeWidth: 1.4 }
      : { stroke: "rgba(200,210,235,0.55)", strokeWidth: 1.2 };
    reactEdges.push({
      id: e.id,
      source: e.source,
      target: e.target,
      sourceHandle: isResidual ? "s-r" : "s",
      targetHandle: isResidual ? "t-r" : "t",
      style,
      type: "smoothstep",
      pathOptions: { borderRadius: 10, offset: isResidual ? 18 : 8 },
      animated: isResidual,
      // Make sure edges crossing group boundaries are not clipped.
      zIndex: 1000,
    });
  }

  return { nodes: reactNodes, edges: reactEdges };
}