Spaces:
Sleeping
Sleeping
// Parse the hierarchical scope embedded in ONNX node names.
//
// Optimum/transformers ONNX exports preserve the PyTorch module path as
// slash-separated node names, e.g.:
//   /distilbert/transformer/layer.0/attention/q_lin/MatMul
//   /model/layers.5/self_attn/q_proj/MatMul
//
// We turn that into a scope array (["distilbert","transformer","layer.0",...])
// that drives the granularity slider.
| import type { IRGraph } from "../types"; | |
| import { SYNTHETIC_INPUT_ID, SYNTHETIC_OUTPUT_ID } from "./graph"; | |
/**
 * Matches a trailing `_<digits>` suffix at the very end of a name fragment
 * (e.g. the `_3` in "Constant_3").
 */
const TRAILING_DUP_SUFFIX = /_\d+$/;

/**
 * A scope is the list of path segments that enclose a node, outermost first,
 * e.g. ["distilbert", "transformer", "layer.0"].
 */
export type Scope = string[];
/**
 * Extract the scope array from a node name.
 *
 * Rules:
 * - A missing or empty name yields `[]`.
 * - A name containing `/` is split on slashes; empty segments produced by
 *   leading, trailing or doubled slashes are ignored.
 * - Otherwise, a name containing `.` (weight-style, e.g.
 *   `distilbert.embeddings.word_embeddings.weight`) is split on dots.
 * - The last segment is the op / parameter identifier (e.g. "MatMul",
 *   "Constant_3", "weight") and is intentionally *excluded* from the scope so
 *   it does not pollute the cluster keys.
 * - A name with neither separator is unscoped and yields `[]`.
 *
 * The separator is chosen from the raw name, before any slash is discarded.
 * That way a slash-style name with a single segment such as `/layer.0` is an
 * unscoped op (`[]`) rather than being mistaken for a dotted weight name.
 *
 * @param name - Raw node name; may be `undefined` or empty.
 * @returns The enclosing path segments, outermost first.
 */
export function parseScope(name: string | undefined): Scope {
  if (!name) return [];

  // Slash takes precedence: dots inside slash-separated segments ("layer.0")
  // are part of the segment, not separators.
  let separator: string;
  if (name.includes("/")) {
    separator = "/";
  } else if (name.includes(".")) {
    separator = ".";
  } else {
    return [];
  }

  const segments = name.split(separator).filter(Boolean);
  // Drop the final segment (the op / parameter name itself). slice(0, -1) is
  // also safe for zero or one segment: both give [].
  return segments.slice(0, -1);
}
/**
 * Compute a scope for every node in the IR.
 *
 * - The synthetic input/output nodes get the sentinel scopes
 *   `["__io__", "input"]` / `["__io__", "output"]` so they never get grouped
 *   with anything else.
 * - Every other node gets `parseScope(node.name)`.
 * - Nodes whose parsed scope is empty (e.g. anonymous Constants) "borrow" the
 *   scope of the nearest scoped node reachable by following consumer edges
 *   (breadth-first), to avoid polluting the root level. Synthetic IO nodes
 *   are never borrowed from; otherwise an unscoped op feeding the graph
 *   output would inherit the IO sentinel and be grouped with it. A node with
 *   no scoped real consumer keeps its empty scope.
 *
 * Orphans are resolved in `ir.nodes` order, and a node that has already
 * borrowed a scope counts as scoped for the orphans processed after it.
 *
 * @param ir - Graph whose `nodes` (id, name) and `edges` (source, target)
 *   are read; the graph itself is not modified.
 * @returns Map from node id to its scope. A borrowed scope is the same array
 *   instance as the scope of the node it was borrowed from.
 */
export function computeScopes(ir: IRGraph): Map<string, Scope> {
  const isSynthetic = (id: string): boolean =>
    id === SYNTHETIC_INPUT_ID || id === SYNTHETIC_OUTPUT_ID;

  const scopes = new Map<string, Scope>();

  // First pass: direct parsing, with sentinels for the synthetic IO nodes.
  for (const n of ir.nodes) {
    if (n.id === SYNTHETIC_INPUT_ID) {
      scopes.set(n.id, ["__io__", "input"]);
      continue;
    }
    if (n.id === SYNTHETIC_OUTPUT_ID) {
      scopes.set(n.id, ["__io__", "output"]);
      continue;
    }
    scopes.set(n.id, parseScope(n.name));
  }

  // Index: nodeId -> ids of the nodes that consume its outputs.
  const consumersByNode = new Map<string, string[]>();
  for (const e of ir.edges) {
    const consumers = consumersByNode.get(e.source);
    if (consumers) {
      consumers.push(e.target);
    } else {
      consumersByNode.set(e.source, [e.target]);
    }
  }

  // Second pass: orphan nodes (empty scope) borrow from their nearest scoped
  // consumer.
  for (const n of ir.nodes) {
    if (isSynthetic(n.id)) continue;
    if (scopes.get(n.id)?.length) continue;

    const seen = new Set<string>([n.id]);
    // Copy the consumer list: the queue grows during the search and must not
    // write back into the shared index.
    const queue: string[] = [...(consumersByNode.get(n.id) ?? [])];
    let found: Scope | null = null;

    // A head cursor instead of shift() keeps the search linear.
    for (let head = 0; head < queue.length; head++) {
      const next = queue[head];
      if (next === undefined || seen.has(next)) continue;
      seen.add(next);

      // Never borrow the IO sentinel scope, and do not walk through it.
      if (isSynthetic(next)) continue;

      const s = scopes.get(next);
      if (s && s.length > 0) {
        found = s;
        break;
      }
      for (const c of consumersByNode.get(next) ?? []) queue.push(c);
    }

    if (found) scopes.set(n.id, found);
  }

  return scopes;
}
/**
 * Length of the longest scope among all real nodes.
 *
 * Entries keyed by the synthetic input/output ids are ignored.
 *
 * @param scopes - Map from node id to scope.
 * @returns The greatest scope length, or 0 when no real node has a scope.
 */
export function maxScopeDepth(scopes: Map<string, Scope>): number {
  let deepest = 0;
  scopes.forEach((scope, nodeId) => {
    const isSynthetic =
      nodeId === SYNTHETIC_INPUT_ID || nodeId === SYNTHETIC_OUTPUT_ID;
    if (!isSynthetic) {
      deepest = Math.max(deepest, scope.length);
    }
  });
  return deepest;
}
/**
 * Remove a generated trailing numeric suffix from a name fragment, so that
 * "Constant_3" and "Constant_4" both become "Constant". Fragments without
 * such a suffix are returned unchanged.
 *
 * Intended for cluster naming display only, never for cluster identity.
 *
 * @param fragment - A single name fragment.
 * @returns The fragment with any trailing `TRAILING_DUP_SUFFIX` match removed.
 */
export function stripDupSuffix(fragment: string): string {
  const suffix = TRAILING_DUP_SUFFIX.exec(fragment);
  if (suffix === null) {
    return fragment;
  }
  // The pattern is anchored at the end, so the match runs to the end of the
  // string and everything before its start index is the stripped name.
  return fragment.slice(0, suffix.index);
}
/**
 * Detect a friendly semantic label for a scope, based on its last segment.
 *
 * Examples:
 * - `[]`                      -> `{ opType: "Module", label: "model" }`
 * - `[..., "layer.5"]`        -> `{ opType: "TransformerLayer", label: "Layer 5" }`
 * - `[..., "attention"]`      -> `{ opType: "Attention", label: "Attention" }`
 * - `[..., "gate_proj"]`      -> `{ opType: "Linear", label: "gate_proj" }`
 *
 * Rules are tried in order and the first match wins. Matching is done on the
 * lower-cased segment; apart from the layer / attention / MLP cases the label
 * is the segment exactly as written. Unrecognised segments are reported as a
 * generic "Module".
 *
 * @param scope - Scope array, outermost segment first.
 * @returns The detected `opType` and a display `label`.
 */
export function semanticLabel(scope: Scope): {
  opType: string;
  label: string;
} {
  if (scope.length === 0) return { opType: "Module", label: "model" };
  const last = scope[scope.length - 1] ?? "";
  const lower = last.toLowerCase();

  // Layer index pattern: "layer.5", "layers.12", "h.3", "block.2"
  const layerMatch = last.match(/^(layer|layers|h|block|blocks)\.(\d+)$/i);
  if (layerMatch) {
    return {
      opType: "TransformerLayer",
      label: `Layer ${layerMatch[2]}`,
    };
  }
  if (/^(self_attn|attention|attn|self_attention)$/.test(lower)) {
    return { opType: "Attention", label: "Attention" };
  }
  if (/^(mlp|ffn|feed_forward|feedforward|intermediate)$/.test(lower)) {
    return { opType: "MLP", label: "MLP" };
  }
  if (/^(embeddings?|word_embeddings?|token_embeddings?|position_embeddings?)$/.test(lower)) {
    return { opType: "Embedding", label: last };
  }
  // Suffix match on purpose, so prefixed names such as "sa_layer_norm" count.
  if (/(norm|layernorm|layer_norm|rmsnorm|rms_norm)$/.test(lower)) {
    return { opType: "Norm", label: last };
  }
  // Projection layers. The MLP projections gate_proj / up_proj / down_proj
  // are listed here so they are labelled like their q_proj / o_proj
  // siblings; this rule must stay ahead of the router rule below, whose
  // "gate" substring check would otherwise claim "gate_proj".
  if (
    /^(q_proj|k_proj|v_proj|out_proj|o_proj|q_lin|k_lin|v_lin|out_lin|gate_proj|up_proj|down_proj)$/.test(
      lower,
    )
  ) {
    return { opType: "Linear", label: last };
  }
  if (/^(rotary|rope|rotary_emb)$/.test(lower)) {
    return { opType: "RoPE", label: last };
  }
  if (lower.includes("router") || lower.includes("gate")) {
    return { opType: "Router", label: last };
  }
  if (lower.includes("expert")) {
    return { opType: "Expert", label: last };
  }
  return { opType: "Module", label: last };
}