File size: 5,608 Bytes
fc01079
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
// Parse the hierarchical scope embedded in ONNX node names.
//
// Optimum/transformers ONNX exports preserve the PyTorch module path as
// slash-separated names, e.g.:
//   /distilbert/transformer/layer.0/attention/q_lin/MatMul
//   /model/layers.5/self_attn/q_proj/MatMul
//
// We turn that into a scope array (["distilbert","transformer","layer.0",...])
// that drives the granularity slider.

import type { IRGraph } from "../types";
import { SYNTHETIC_INPUT_ID, SYNTHETIC_OUTPUT_ID } from "./graph";

/**
 * Hierarchical module path of a node, outermost segment first, e.g.
 * ["distilbert", "transformer", "layer.0", "attention", "q_lin"].
 * An empty array means the node has no scope.
 */
export type Scope = string[];

// Matches a trailing "_<digits>" suffix, such as the "_3" in "Constant_3".
const TRAILING_DUP_SUFFIX = /_\d+$/;

/**
 * Derive the module-path scope from an ONNX node (or weight) name.
 *
 * How a name is decomposed:
 *  - Any leading slashes are removed first.
 *  - If the remainder contains "/", it is split on "/"; otherwise, if it
 *    contains ".", it is split on "." (weight-style names such as
 *    `distilbert.embeddings.word_embeddings.weight`).
 *  - Empty segments are discarded.
 *  - The final segment (the op / leaf identifier, e.g. "MatMul" or
 *    "Constant_3") is left out so it never becomes part of a cluster key.
 *
 * @param name Node name; may be undefined or empty.
 * @returns The scope segments, or [] when the name carries no hierarchy.
 */
export function parseScope(name: string | undefined): Scope {
  if (!name) return [];

  const path = name.replace(/^\/+/, "");

  // Slash-separated paths take precedence over dotted names.
  let separator: string | null = null;
  if (path.includes("/")) {
    separator = "/";
  } else if (path.includes(".")) {
    separator = ".";
  }
  if (separator === null) return [];

  const segments = path.split(separator).filter((segment) => segment !== "");
  // Everything except the trailing leaf segment forms the scope.
  return segments.slice(0, -1);
}

/**
 * Compute a scope for every node in the IR.
 *
 * Pass 1 assigns each node the scope parsed from its name. The synthetic
 * input/output nodes instead receive the sentinel scopes ["__io__", "input"]
 * and ["__io__", "output"] so they never get grouped with anything else.
 *
 * Pass 2 handles nodes whose parsed scope is empty (e.g. anonymous
 * Constants): each one borrows the scope of the first scoped node found by a
 * breadth-first walk over its downstream consumers, so it does not pollute
 * the root level. Borrowed scopes are written back immediately, so a later
 * orphan may pick up a scope that an earlier orphan borrowed. The borrowed
 * array is shared by reference, not copied.
 *
 * Sentinel I/O scopes are never lent out: an unscoped node whose only
 * downstream consumer is a synthetic I/O node stays unscoped instead of
 * joining the "__io__" group.
 *
 * @param ir Graph whose `nodes` (id, name) and `edges` (source, target) are read.
 * @returns Map from node id to its scope; [] for nodes that found none.
 */
export function computeScopes(ir: IRGraph): Map<string, Scope> {
  const isSynthetic = (id: string): boolean =>
    id === SYNTHETIC_INPUT_ID || id === SYNTHETIC_OUTPUT_ID;

  const scopes = new Map<string, Scope>();

  // First pass: direct parsing (sentinels for the synthetic I/O nodes).
  for (const n of ir.nodes) {
    if (n.id === SYNTHETIC_INPUT_ID) {
      scopes.set(n.id, ["__io__", "input"]);
      continue;
    }
    if (n.id === SYNTHETIC_OUTPUT_ID) {
      scopes.set(n.id, ["__io__", "output"]);
      continue;
    }
    scopes.set(n.id, parseScope(n.name));
  }

  // Index: nodeId → ids of the nodes consuming its outputs.
  const consumersByNode = new Map<string, string[]>();
  for (const e of ir.edges) {
    const existing = consumersByNode.get(e.source);
    if (existing) {
      existing.push(e.target);
    } else {
      consumersByNode.set(e.source, [e.target]);
    }
  }

  // Second pass: orphan nodes (empty scope) borrow from their first scoped
  // consumer.
  for (const n of ir.nodes) {
    if (isSynthetic(n.id)) continue;
    if (scopes.get(n.id)?.length) continue;

    const seen = new Set<string>([n.id]);
    // Work on a private copy: the BFS appends to the queue, and the shared
    // consumer index must stay intact for the other orphans.
    const queue: string[] = [...(consumersByNode.get(n.id) ?? [])];
    let found: Scope | null = null;

    // A head cursor replaces Array.prototype.shift(), which is O(n) per call.
    for (let head = 0; head < queue.length; head++) {
      const next = queue[head];
      if (next === undefined || seen.has(next)) continue;
      seen.add(next);

      // Never borrow from (or walk through) the synthetic I/O nodes; their
      // sentinel scope must stay exclusive to them.
      if (isSynthetic(next)) continue;

      const s = scopes.get(next);
      if (s && s.length > 0) {
        found = s;
        break;
      }
      for (const c of consumersByNode.get(next) ?? []) queue.push(c);
    }

    if (found) scopes.set(n.id, found);
  }

  return scopes;
}

/**
 * Length of the longest scope among real nodes.
 *
 * Entries keyed by the synthetic input/output ids are ignored. Returns 0 when
 * the map is empty or every remaining scope is empty.
 */
export function maxScopeDepth(scopes: Map<string, Scope>): number {
  let deepest = 0;
  scopes.forEach((scope, nodeId) => {
    const isIoNode =
      nodeId === SYNTHETIC_INPUT_ID || nodeId === SYNTHETIC_OUTPUT_ID;
    if (!isIoNode) {
      deepest = Math.max(deepest, scope.length);
    }
  });
  return deepest;
}

/**
 * Remove a generated trailing "_<digits>" suffix from a name fragment, so
 * that "Constant_3" and "Constant_4" both become "Constant". Intended for
 * cluster display names only — never for cluster identity.
 *
 * @param fragment Name fragment to clean up.
 * @returns The fragment without its numeric suffix, or unchanged if it has none.
 */
export function stripDupSuffix(fragment: string): string {
  const suffix = TRAILING_DUP_SUFFIX.exec(fragment);
  if (suffix === null) return fragment;
  // The pattern is anchored at the end, so the match runs to the last char.
  return fragment.slice(0, suffix.index);
}

/**
 * Detect a friendly semantic category and label for a scope.
 *
 * Only the innermost (last) segment is inspected. Rules are tried in the
 * order written below and the first match wins:
 *  - `layer.5`, `layers.12`, `h.3`, `block.1`, `blocks.2`
 *      -> opType "TransformerLayer", label "Layer <index>"
 *  - `attention`, `self_attn`, ...   -> opType "Attention", label "Attention"
 *  - `mlp`, `ffn`, `intermediate`, ... -> opType "MLP", label "MLP"
 *  - embedding / norm / linear-projection / rotary / router / expert names
 *      -> the matching opType, label is the segment as written
 *  - anything else -> opType "Module", label is the segment as written
 *
 * @param scope Scope to label; an empty scope denotes the model root.
 * @returns The category (`opType`) and display text (`label`); an empty scope
 *   yields opType "Module" with label "model".
 */
export function semanticLabel(scope: Scope): {
  opType: string;
  label: string;
} {
  if (scope.length === 0) return { opType: "Module", label: "model" };
  const last = scope[scope.length - 1] ?? "";
  const lower = last.toLowerCase();

  // Layer index pattern: "layer.5", "layers.12", "h.3"
  const layerMatch = last.match(/^(layer|layers|h|block|blocks)\.(\d+)$/i);
  if (layerMatch) {
    return {
      opType: "TransformerLayer",
      label: `Layer ${layerMatch[2]}`,
    };
  }

  if (/^(self_attn|attention|attn|self_attention)$/.test(lower)) {
    return { opType: "Attention", label: "Attention" };
  }
  if (/^(mlp|ffn|feed_forward|feedforward|intermediate)$/.test(lower)) {
    return { opType: "MLP", label: "MLP" };
  }
  if (/^(embeddings?|word_embeddings?|token_embeddings?|position_embeddings?)$/.test(lower)) {
    return { opType: "Embedding", label: last };
  }
  // Not anchored at the start on purpose: also catches prefixed names such as
  // "input_layernorm" or "post_attention_layernorm".
  if (/(norm|layernorm|layer_norm|rmsnorm|rms_norm)$/.test(lower)) {
    return { opType: "Norm", label: last };
  }
  // Linear projections. gate_proj / up_proj / down_proj (gated-MLP
  // projections) must be matched here, before the router rule: otherwise the
  // "gate" substring test below would mislabel gate_proj as a Router.
  if (
    /^(q_proj|k_proj|v_proj|out_proj|o_proj|gate_proj|up_proj|down_proj|q_lin|k_lin|v_lin|out_lin)$/.test(
      lower,
    )
  ) {
    return { opType: "Linear", label: last };
  }
  if (/^(rotary|rope|rotary_emb)$/.test(lower)) {
    return { opType: "RoPE", label: last };
  }
  // Substring match: any remaining name containing "router" or "gate".
  if (lower.includes("router") || lower.includes("gate")) {
    return { opType: "Router", label: last };
  }
  if (/expert/.test(lower)) {
    return { opType: "Expert", label: last };
  }
  return { opType: "Module", label: last };
}