hf-model-viewer / src /lib /scope.ts
tfrere's picture
tfrere HF Staff
Deploy hf-model-viewer 2026-05-22T16:59:58Z
fc01079 verified
// Parse the hierarchical scope embedded in ONNX node names.
//
// Optimum/transformers ONNX exports preserve the PyTorch module path as
// slash-separated names, e.g.:
// /distilbert/transformer/layer.0/attention/q_lin/MatMul
// /model/layers.5/self_attn/q_proj/MatMul
//
// We turn that into a scope array (["distilbert","transformer","layer.0",...])
// that drives the granularity slider.
import type { IRGraph } from "../types";
import { SYNTHETIC_INPUT_ID, SYNTHETIC_OUTPUT_ID } from "./graph";
/**
 * Ordered list of scope segments for a node, outermost first,
 * e.g. ["distilbert", "transformer", "layer.0"].
 */
export type Scope = string[];
/** Matches a trailing underscore-plus-digits suffix, such as the "_3" in "Constant_3". */
const TRAILING_DUP_SUFFIX = /_\d+$/;
/**
 * Derive the scope array from a node name.
 *
 * How the name is processed:
 * - Any leading slashes are removed first.
 * - If the remainder contains "/", it is split on "/"; otherwise, if it
 *   contains ".", it is split on "." (dotted weight-style names such as
 *   `distilbert.embeddings.word_embeddings.weight`).
 * - Empty segments are discarded.
 * - The final segment (the op identifier, e.g. "MatMul" or "Constant_3", or
 *   the trailing part of a dotted name) is dropped so it does not end up in
 *   the cluster keys.
 *
 * @param name Node name; may be undefined or empty.
 * @returns The scope segments, or [] when the name is missing or contains
 *   neither separator.
 */
export function parseScope(name: string | undefined): Scope {
  if (!name) return [];

  const withoutLeadingSlashes = name.replace(/^\/+/, "");

  // "/" takes priority over "." — a slash-separated path may legitimately
  // contain dotted segments like "layer.0".
  const separator = ["/", "."].find((candidate) =>
    withoutLeadingSlashes.includes(candidate),
  );
  if (separator === undefined) return [];

  const segments = withoutLeadingSlashes
    .split(separator)
    .filter((segment) => segment.length > 0);

  // Everything except the last segment forms the scope.
  return segments.slice(0, -1);
}
/**
 * Compute a scope for every node in the IR.
 *
 * - Synthetic input/output nodes get a sentinel scope (["__io__", "input"] /
 *   ["__io__", "output"]) so they never get grouped with anything else.
 * - Every other node gets the scope parsed from its name.
 * - Nodes whose parsed scope is empty (e.g. anonymous Constants) "borrow" the
 *   scope of the first scoped node found by a breadth-first walk over their
 *   consumers, to avoid polluting the root level. If no scoped consumer is
 *   reachable, the node keeps its empty scope.
 *
 * Borrowed scopes are written into the result as the second pass proceeds, so
 * a later orphan can pick up a scope that an earlier orphan borrowed. A
 * borrowing node shares the same array instance as the node it borrowed from.
 *
 * @param ir Graph whose `nodes` (id, name) and `edges` (source, target) are read.
 * @returns Map from node id to its scope.
 */
export function computeScopes(ir: IRGraph): Map<string, Scope> {
  const scopes = new Map<string, Scope>();

  // First pass: direct parsing.
  for (const n of ir.nodes) {
    if (n.id === SYNTHETIC_INPUT_ID) {
      scopes.set(n.id, ["__io__", "input"]);
      continue;
    }
    if (n.id === SYNTHETIC_OUTPUT_ID) {
      scopes.set(n.id, ["__io__", "output"]);
      continue;
    }
    scopes.set(n.id, parseScope(n.name));
  }

  // Second pass: orphan nodes (empty scope) borrow from their first scoped
  // consumer. Build nodeId → consumer nodeIds index from the edges.
  const consumersByNode = new Map<string, string[]>();
  for (const e of ir.edges) {
    const existing = consumersByNode.get(e.source);
    if (existing) {
      existing.push(e.target);
    } else {
      consumersByNode.set(e.source, [e.target]);
    }
  }

  for (const n of ir.nodes) {
    if (scopes.get(n.id)?.length) continue;
    if (n.id === SYNTHETIC_INPUT_ID || n.id === SYNTHETIC_OUTPUT_ID) continue;

    // BFS over consumers, take first scoped one.
    // The queue is a private copy: the arrays held by consumersByNode are a
    // shared index and must not be drained or extended by the traversal.
    const seen = new Set<string>([n.id]);
    const queue: string[] = [...(consumersByNode.get(n.id) ?? [])];
    let found: Scope | undefined;

    // Advance a head index instead of shift()-ing, so each dequeue is O(1).
    for (let head = 0; head < queue.length; head++) {
      const next = queue[head];
      if (next === undefined || seen.has(next)) continue;
      seen.add(next);

      const s = scopes.get(next);
      if (s && s.length > 0) {
        found = s;
        break;
      }
      for (const c of consumersByNode.get(next) ?? []) queue.push(c);
    }

    if (found) scopes.set(n.id, found);
  }

  return scopes;
}
/**
 * Length of the longest scope among all entries, ignoring the synthetic
 * input/output nodes.
 *
 * @param scopes Map from node id to scope.
 * @returns The greatest scope length found, or 0 when there is no
 *   non-synthetic entry.
 */
export function maxScopeDepth(scopes: Map<string, Scope>): number {
  let deepest = 0;
  scopes.forEach((scope, nodeId) => {
    const isSynthetic =
      nodeId === SYNTHETIC_INPUT_ID || nodeId === SYNTHETIC_OUTPUT_ID;
    if (!isSynthetic) {
      deepest = Math.max(deepest, scope.length);
    }
  });
  return deepest;
}
/**
 * Remove a generated numeric suffix from a name fragment, so that
 * "Constant_3" and "Constant_4" both become "Constant". Intended for cluster
 * naming display only — never for cluster identity.
 *
 * @param fragment Name fragment to clean up.
 * @returns The fragment without its trailing "_<digits>" part, or the
 *   fragment unchanged when it has no such suffix.
 */
export function stripDupSuffix(fragment: string): string {
  const suffix = TRAILING_DUP_SUFFIX.exec(fragment);
  // The pattern is anchored at the end of the string, so a match always runs
  // to the end and cutting at its start index removes exactly the suffix.
  return suffix === null ? fragment : fragment.slice(0, suffix.index);
}
/**
 * Detect a friendly semantic label for the last segment of a scope.
 *
 * Rules are tried in order; the first match wins:
 * 1. empty scope                                      -> Module, label "model"
 * 2. `layer.N`, `layers.N`, `h.N`, `block.N`, `blocks.N` (case-insensitive)
 *                                                     -> TransformerLayer, label "Layer N"
 * 3. `self_attn`, `attention`, `attn`, `self_attention` -> Attention
 * 4. `mlp`, `ffn`, `feed_forward`, `feedforward`, `intermediate` -> MLP
 * 5. embedding names (`embeddings`, `word_embeddings`, ...) -> Embedding
 * 6. anything ending in `norm`                        -> Norm
 * 7. linear projections (`q_proj`, `o_proj`, `q_lin`, `gate_proj`, ...) -> Linear
 * 8. `rotary`, `rope`, `rotary_emb`                   -> RoPE
 * 9. anything containing `router` or `gate`           -> Router
 * 10. anything containing `expert`                    -> Expert
 * 11. otherwise                                       -> Module
 *
 * Except for rules 1-4, the label is the segment exactly as it was given.
 *
 * @param scope Scope whose last segment is classified.
 * @returns The detected `opType` and a display `label`.
 */
export function semanticLabel(scope: Scope): {
  opType: string;
  label: string;
} {
  if (scope.length === 0) return { opType: "Module", label: "model" };
  const last = scope[scope.length - 1] ?? "";
  const lower = last.toLowerCase();

  // Layer index pattern: "layer.5", "layers.12", "h.3"
  const layerMatch = last.match(/^(layer|layers|h|block|blocks)\.(\d+)$/i);
  if (layerMatch) {
    return {
      opType: "TransformerLayer",
      label: `Layer ${layerMatch[2]}`,
    };
  }
  if (/^(self_attn|attention|attn|self_attention)$/.test(lower)) {
    return { opType: "Attention", label: "Attention" };
  }
  if (/^(mlp|ffn|feed_forward|feedforward|intermediate)$/.test(lower)) {
    return { opType: "MLP", label: "MLP" };
  }
  if (/^(embeddings?|word_embeddings?|token_embeddings?|position_embeddings?)$/.test(lower)) {
    return { opType: "Embedding", label: last };
  }
  if (/(norm|layernorm|layer_norm|rmsnorm|rms_norm)$/.test(lower)) {
    return { opType: "Norm", label: last };
  }
  // gate_proj / up_proj / down_proj are the linear projections of a gated
  // MLP. They must be recognised here, before the substring test below, or
  // "gate_proj" would be reported as a Router.
  if (
    /^(q_proj|k_proj|v_proj|out_proj|o_proj|q_lin|k_lin|v_lin|out_lin|gate_proj|up_proj|down_proj)$/.test(
      lower,
    )
  ) {
    return { opType: "Linear", label: last };
  }
  if (/^(rotary|rope|rotary_emb)$/.test(lower)) {
    return { opType: "RoPE", label: last };
  }
  if (lower.includes("router") || lower.includes("gate")) {
    return { opType: "Router", label: last };
  }
  if (/expert/.test(lower)) {
    return { opType: "Expert", label: last };
  }
  return { opType: "Module", label: last };
}