Spaces:
Sleeping
Sleeping
File size: 5,608 Bytes
// Parse the hierarchical scope embedded in ONNX node names.
//
// Optimum/transformers ONNX exports preserve the PyTorch module path as
// slash-separated names, e.g.:
// /distilbert/transformer/layer.0/attention/q_lin/MatMul
// /model/layers.5/self_attn/q_proj/MatMul
//
// We turn that into a scope array (["distilbert","transformer","layer.0",...])
// that drives the granularity slider.
import type { IRGraph } from "../types";
import { SYNTHETIC_INPUT_ID, SYNTHETIC_OUTPUT_ID } from "./graph";
/**
 * Ordered list of module-path segments for a node, outermost first
 * (e.g. ["distilbert", "transformer", "layer.0"]). Empty for unscoped names.
 */
export type Scope = string[];
/** Matches a trailing "_<digits>" suffix, such as the "_3" in "Constant_3". */
const TRAILING_DUP_SUFFIX = /_\d+$/;
/**
 * Derive the scope (module path) encoded in a node name.
 *
 * How the name is interpreted:
 *  - Any run of leading slashes is removed first.
 *  - If the remainder contains "/", it is split on "/"; otherwise, if it
 *    contains ".", it is split on "." (weight-style names such as
 *    `distilbert.embeddings.word_embeddings.weight`).
 *  - Empty segments are discarded, then the final segment (the op or
 *    parameter identifier, e.g. "MatMul", "Constant_3", "weight") is removed
 *    so it never becomes part of a cluster key.
 *
 * @param name Node name; may be undefined or empty.
 * @returns The scope segments, outermost first; [] when the name is missing
 *          or contains neither separator.
 */
export function parseScope(name: string | undefined): Scope {
  if (!name) return [];

  const withoutLeadingSlashes = name.replace(/^\/+/, "");

  // "/" is checked before ".", so slash-separated paths keep dotted
  // segments such as "layer.0" intact.
  const separator = ["/", "."].find((candidate) =>
    withoutLeadingSlashes.includes(candidate),
  );
  if (separator === undefined) return [];

  const segments = withoutLeadingSlashes
    .split(separator)
    .filter((segment) => segment !== "");

  // Discard the trailing identifier; a no-op when nothing is left.
  segments.pop();
  return segments;
}
/**
 * Compute a scope for every node in the IR.
 *
 * - The synthetic input/output nodes get the sentinel scopes
 *   ["__io__", "input"] / ["__io__", "output"] so they never get grouped
 *   with anything else.
 * - Every other node gets `parseScope(node.name)`.
 * - Nodes left with an empty scope (e.g. anonymous Constants) "borrow" the
 *   scope of the nearest scoped consumer, found by a breadth-first walk over
 *   outgoing edges, to avoid polluting the root level. Synthetic IO nodes
 *   are never used as donors, so the sentinel scope cannot leak onto real
 *   nodes. A node with no scoped consumer keeps its empty scope.
 *
 * Orphans are resolved in `ir.nodes` order, so an orphan that has already
 *   borrowed a scope can act as a donor for a later one. The borrowed value is
 *   the donor's own array (not a copy).
 *
 * @param ir Graph whose nodes and edges are inspected; it is not modified.
 * @returns Map from node id to that node's scope.
 */
export function computeScopes(ir: IRGraph): Map<string, Scope> {
  const isSynthetic = (id: string): boolean =>
    id === SYNTHETIC_INPUT_ID || id === SYNTHETIC_OUTPUT_ID;

  const scopes = new Map<string, Scope>();

  // First pass: direct parsing.
  for (const n of ir.nodes) {
    if (n.id === SYNTHETIC_INPUT_ID) {
      scopes.set(n.id, ["__io__", "input"]);
      continue;
    }
    if (n.id === SYNTHETIC_OUTPUT_ID) {
      scopes.set(n.id, ["__io__", "output"]);
      continue;
    }
    scopes.set(n.id, parseScope(n.name));
  }

  // Second pass: orphan nodes (empty scope) borrow from their first scoped
  // consumer. Build nodeId → consumer nodeIds index from the edges.
  const consumersByNode = new Map<string, string[]>();
  for (const e of ir.edges) {
    const arr = consumersByNode.get(e.source) ?? [];
    arr.push(e.target);
    consumersByNode.set(e.source, arr);
  }

  for (const n of ir.nodes) {
    if (isSynthetic(n.id)) continue;
    if (scopes.get(n.id)?.length) continue;

    // BFS over consumers, take first scoped one.
    const seen = new Set<string>([n.id]);
    // Work on a copy: the queue is consumed and extended during the walk,
    // and doing that on the indexed array itself would corrupt
    // `consumersByNode` for every later orphan that traverses this node.
    const queue: string[] = [...(consumersByNode.get(n.id) ?? [])];
    let found: Scope | null = null;

    // A moving head index replaces repeated `shift()` calls.
    for (let head = 0; head < queue.length; head++) {
      const next = queue[head];
      if (next === undefined || seen.has(next)) continue;
      seen.add(next);

      // Never borrow the IO sentinel scope, and do not walk through IO nodes.
      if (isSynthetic(next)) continue;

      const s = scopes.get(next);
      if (s && s.length > 0) {
        found = s;
        break;
      }
      for (const c of consumersByNode.get(next) ?? []) queue.push(c);
    }

    if (found) scopes.set(n.id, found);
  }

  return scopes;
}
/**
 * Length of the longest scope among the entries of `scopes`, ignoring the
 * synthetic input and output nodes.
 *
 * @param scopes Map from node id to scope.
 * @returns The greatest scope length, or 0 when no non-IO entry exists.
 */
export function maxScopeDepth(scopes: Map<string, Scope>): number {
  let deepest = 0;
  scopes.forEach((scope, nodeId) => {
    const isIoNode =
      nodeId === SYNTHETIC_INPUT_ID || nodeId === SYNTHETIC_OUTPUT_ID;
    if (!isIoNode && scope.length > deepest) {
      deepest = scope.length;
    }
  });
  return deepest;
}
/**
 * Remove a generated numeric suffix from a name fragment, so "Constant_3"
 * and "Constant_4" both become "Constant". Intended for cluster display
 * names only — never for cluster identity.
 *
 * @param fragment Name fragment, possibly ending in "_<digits>".
 * @returns The fragment without that trailing suffix; unchanged when the
 *          fragment has none.
 */
export function stripDupSuffix(fragment: string): string {
  const withoutSuffix = fragment.replace(TRAILING_DUP_SUFFIX, "");
  return withoutSuffix;
}
/**
 * Classify a scope by its innermost (last) segment.
 *
 * Examples:
 *   []                      -> { opType: "Module", label: "model" }
 *   [..., "layer.5"]        -> { opType: "TransformerLayer", label: "Layer 5" }
 *   [..., "attention"]      -> { opType: "Attention", label: "Attention" }
 *   [..., "input_layernorm"] -> { opType: "Norm", label: "input_layernorm" }
 *
 * Apart from the layer-index pattern, matching is done on the lower-cased
 * segment; rules are tried in order and the first match wins. Segments that
 * match nothing are reported as a generic "Module" labelled with the segment
 * itself.
 *
 * @param scope Scope whose last segment is inspected.
 * @returns The detected op type and a display label.
 */
export function semanticLabel(scope: Scope): {
  opType: string;
  label: string;
} {
  if (scope.length === 0) return { opType: "Module", label: "model" };

  const segment = scope.slice(-1)[0] ?? "";

  // Layer index pattern: "layer.5", "layers.12", "h.3"
  const layerIndex = /^(layer|layers|h|block|blocks)\.(\d+)$/i.exec(segment);
  if (layerIndex) {
    return { opType: "TransformerLayer", label: `Layer ${layerIndex[2]}` };
  }

  const normalized = segment.toLowerCase();

  // Ordered [predicate, opType, label] rules; earlier entries take priority.
  const rules: Array<[(s: string) => boolean, string, string]> = [
    [(s) => /^(self_attn|attention|attn|self_attention)$/.test(s), "Attention", "Attention"],
    [(s) => /^(mlp|ffn|feed_forward|feedforward|intermediate)$/.test(s), "MLP", "MLP"],
    [
      (s) => /^(embeddings?|word_embeddings?|token_embeddings?|position_embeddings?)$/.test(s),
      "Embedding",
      segment,
    ],
    // Suffix match only, so prefixed names like "input_layernorm" qualify.
    [(s) => /(norm|layernorm|layer_norm|rmsnorm|rms_norm)$/.test(s), "Norm", segment],
    [
      (s) => /^(q_proj|k_proj|v_proj|out_proj|o_proj|q_lin|k_lin|v_lin|out_lin)$/.test(s),
      "Linear",
      segment,
    ],
    [(s) => /^(rotary|rope|rotary_emb)$/.test(s), "RoPE", segment],
    [(s) => s.includes("router") || s.includes("gate"), "Router", segment],
    [(s) => /expert/.test(s), "Expert", segment],
  ];

  for (const [applies, opType, label] of rules) {
    if (applies(normalized)) return { opType, label };
  }

  return { opType: "Module", label: segment };
}
|