// Parse the hierarchical scope embedded in ONNX node names.
//
// Optimum/transformers ONNX exports preserve the PyTorch module path as
// slash-separated names, e.g.:
//   /distilbert/transformer/layer.0/attention/q_lin/MatMul
//   /model/layers.5/self_attn/q_proj/MatMul
//
// We turn that into a scope array (["distilbert","transformer","layer.0",...])
// that drives the granularity slider.

import type { IRGraph } from "../types";
import { SYNTHETIC_INPUT_ID, SYNTHETIC_OUTPUT_ID } from "./graph";

export type Scope = string[];

const TRAILING_DUP_SUFFIX = /_\d+$/;

/**
 * Extract the scope array from a node name.
 *
 * Rules:
 * - Leading slashes are stripped.
 * - The last segment is the op identifier (e.g. "MatMul", "Constant_3") and
 *   is intentionally *excluded* from the scope so it does not pollute the
 *   cluster keys.
 * - For dotted-only names (e.g. `distilbert.embeddings.word_embeddings.weight`),
 *   we fall back to splitting on dots, again dropping the last segment.
 * - Empty segments (from doubled or trailing separators) are discarded.
 *
 * @param name - Raw node name; may be undefined or empty.
 * @returns The scope segments, or `[]` for unscoped / missing names.
 */
export function parseScope(name: string | undefined): Scope {
  if (!name) return [];
  const trimmed = name.replace(/^\/+/, "");

  // Slash-separated paths take precedence; dots are only a separator when the
  // name contains no slash at all (so "layer.0" stays one segment in a path).
  let separator: string;
  if (trimmed.includes("/")) {
    separator = "/";
  } else if (trimmed.includes(".")) {
    separator = ".";
  } else {
    return [];
  }

  // Drop the last segment (the op / weight name itself).
  return trimmed.split(separator).filter(Boolean).slice(0, -1);
}

/**
 * Breadth-first search over the transitive consumers of `startId`, returning
 * the scope of the first visited node whose scope is non-empty.
 *
 * @returns The scope array stored in `scopes` for that consumer (the same
 *   instance, not a copy), or `undefined` when no scoped consumer is reachable.
 */
function firstScopedConsumer(
  startId: string,
  scopes: ReadonlyMap<string, Scope>,
  consumersByNode: ReadonlyMap<string, readonly string[]>,
): Scope | undefined {
  const seen = new Set<string>([startId]);
  // Copy the direct consumers: the queue grows during traversal and must not
  // alias (and thereby corrupt) the shared consumer index.
  const queue: string[] = [...(consumersByNode.get(startId) ?? [])];

  // Walk with a head index instead of shift(), which is O(n) per dequeue.
  for (let head = 0; head < queue.length; head++) {
    const next = queue[head];
    if (next === undefined || seen.has(next)) continue;
    seen.add(next);

    const scope = scopes.get(next);
    if (scope && scope.length > 0) return scope;

    for (const consumer of consumersByNode.get(next) ?? []) {
      queue.push(consumer);
    }
  }
  return undefined;
}

/**
 * Compute a scope for every node in the IR. Synthetic input/output nodes get
 * a sentinel scope so they never get grouped with anything else. Nodes that
 * have no scope of their own (e.g. anonymous Constants) "borrow" the scope of
 * their first scoped consumer to avoid polluting the root level.
 *
 * Orphans are resolved in `ir.nodes` order, and a borrowed scope is visible to
 * orphans processed later, so an orphan may inherit a scope that one of its
 * consumers itself borrowed. Orphans with no scoped consumer keep `[]`.
 *
 * @param ir - Graph whose `nodes` (id, name) and `edges` (source, target) are read.
 * @returns Map from node id to its scope.
 */
export function computeScopes(ir: IRGraph): Map<string, Scope> {
  const scopes = new Map<string, Scope>();

  // First pass: direct parsing.
  for (const n of ir.nodes) {
    if (n.id === SYNTHETIC_INPUT_ID) {
      scopes.set(n.id, ["__io__", "input"]);
      continue;
    }
    if (n.id === SYNTHETIC_OUTPUT_ID) {
      scopes.set(n.id, ["__io__", "output"]);
      continue;
    }
    scopes.set(n.id, parseScope(n.name));
  }

  // Second pass: orphan nodes (empty scope) borrow from their first scoped
  // consumer. Build nodeId → consumer nodeIds index from the edges.
  const consumersByNode = new Map<string, string[]>();
  for (const e of ir.edges) {
    const consumers = consumersByNode.get(e.source);
    if (consumers) {
      consumers.push(e.target);
    } else {
      consumersByNode.set(e.source, [e.target]);
    }
  }

  for (const n of ir.nodes) {
    // Already scoped (this also covers the synthetic IO sentinels).
    if (scopes.get(n.id)?.length) continue;
    if (n.id === SYNTHETIC_INPUT_ID || n.id === SYNTHETIC_OUTPUT_ID) continue;

    const found = firstScopedConsumer(n.id, scopes, consumersByNode);
    if (found) scopes.set(n.id, found);
  }

  return scopes;
}

/**
 * Max depth across all real (non-IO) nodes.
 *
 * @param scopes - Map from node id to scope, as produced by `computeScopes`.
 * @returns The longest scope length among non-synthetic nodes, or 0 if none.
 */
export function maxScopeDepth(scopes: Map<string, Scope>): number {
  let max = 0;
  for (const [id, s] of scopes) {
    if (id === SYNTHETIC_INPUT_ID || id === SYNTHETIC_OUTPUT_ID) continue;
    if (s.length > max) max = s.length;
  }
  return max;
}

/**
 * Strip generated numeric suffixes from a name fragment so that "Constant_3"
 * and "Constant_4" both map to "Constant". Used for cluster naming display
 * only — never for cluster identity.
 *
 * @param fragment - A single name segment.
 * @returns The fragment without a trailing `_<digits>` suffix.
 */
export function stripDupSuffix(fragment: string): string {
  return fragment.replace(TRAILING_DUP_SUFFIX, "");
}

/**
 * Detect a friendly semantic label for the last segment of a scope.
 * E.g. `attention` -> `{ opType: "Attention", label: "Attention" }`,
 * `layer.5` -> `{ opType: "TransformerLayer", label: "Layer 5" }`.
 *
 * Checks run in order and the first match wins; unmatched segments fall back
 * to `{ opType: "Module", label: <segment> }`. An empty scope yields
 * `{ opType: "Module", label: "model" }`.
 *
 * @param scope - Scope array; only its last segment is inspected.
 * @returns The detected op type and a display label.
 */
export function semanticLabel(scope: Scope): {
  opType: string;
  label: string;
} {
  if (scope.length === 0) return { opType: "Module", label: "model" };

  const last = scope[scope.length - 1] ?? "";
  const lower = last.toLowerCase();

  // Layer index pattern: "layer.5", "layers.12", "h.3"
  const layerMatch = last.match(/^(layer|layers|h|block|blocks)\.(\d+)$/i);
  if (layerMatch) {
    return {
      opType: "TransformerLayer",
      label: `Layer ${layerMatch[2]}`,
    };
  }

  if (/^(self_attn|attention|attn|self_attention)$/.test(lower)) {
    return { opType: "Attention", label: "Attention" };
  }
  if (/^(mlp|ffn|feed_forward|feedforward|intermediate)$/.test(lower)) {
    return { opType: "MLP", label: "MLP" };
  }
  if (/^(embeddings?|word_embeddings?|token_embeddings?|position_embeddings?)$/.test(lower)) {
    return { opType: "Embedding", label: last };
  }
  // Not anchored at the start: any segment ending in "norm" matches.
  if (/(norm|layernorm|layer_norm|rmsnorm|rms_norm)$/.test(lower)) {
    return { opType: "Norm", label: last };
  }
  if (/^(q_proj|k_proj|v_proj|out_proj|o_proj|q_lin|k_lin|v_lin|out_lin)$/.test(lower)) {
    return { opType: "Linear", label: last };
  }
  if (/^(rotary|rope|rotary_emb)$/.test(lower)) {
    return { opType: "RoPE", label: last };
  }
  // Substring match: any segment containing "router" or "gate" lands here.
  if (lower.includes("router") || lower.includes("gate")) {
    return { opType: "Router", label: last };
  }
  if (/expert/.test(lower)) {
    return { opType: "Expert", label: last };
  }

  return { opType: "Module", label: last };
}