// File: hf-model-viewer/src/lib/graph.ts
// Last commit: "Deploy hf-model-viewer 2026-05-22T16:59:58Z" (fc01079, verified) by tfrere (HF Staff).
// Convert a decoded ONNX `ModelProto` into our internal IR.
//
// The strategy:
// - Each `NodeProto` becomes one `IRNode`.
// - We build edges by mapping every tensor name produced by a node to its
// consumers. Tensors produced as graph inputs become edges from a synthetic
// "input" pseudo-node, and outputs link to a synthetic "output" pseudo-node.
// - `initializer[]` are weight tensors (= the model's parameters). We DO NOT
// create nodes for them; instead we attach them as metadata to the node
// that consumes them. This avoids the "wall of constants" that plagues
// naive ONNX viewers.
// - Attributes are converted to plain JS values for inspection.
import type { IRGraph, IRNode, IREdge, IRWeight, IRTensorInfo } from "../types";
import { ONNX_DTYPE } from "./onnx";
// Minimal structural types matching the bits of onnx.proto we read. We avoid
// pulling in a heavy generated namespace.
// Subset of an ONNX node attribute: only the fields `decodeAttr` reads.
interface OnnxAttr {
name: string;
// Discriminator: `decodeAttr` switches on this to pick which value field to read.
type?: number;
f?: number;
// Plain number, or an object exposing `toNumber()` (presumably a 64-bit Long
// wrapper from the protobuf decoder — confirm against the decoder in use).
i?: number | { toNumber(): number };
// Raw bytes; `decodeAttr` decodes them as UTF-8.
s?: Uint8Array;
floats?: number[];
ints?: Array<number | { toNumber(): number }>;
// Each entry is raw bytes, decoded as UTF-8 by `decodeAttr`.
strings?: Uint8Array[];
}
// Subset of an ONNX graph node as read by `modelProtoToIR`.
interface OnnxNodeProto {
name?: string;
// The operator type is declared under both spellings; `modelProtoToIR` reads
// `op_type` first and falls back to `opType`.
op_type?: string;
opType?: string;
// Names of the tensors this node consumes / produces.
input?: string[];
output?: string[];
attribute?: OnnxAttr[];
}
// Subset of an ONNX tensor (initializer): only what `makeWeight` reads.
// No field for the tensor's actual data is declared here.
interface OnnxTensorProto {
name?: string;
// Each extent is a plain number or an object exposing `toNumber()`.
dims?: Array<number | { toNumber(): number }>;
// Element-type code under either spelling; `makeWeight` prefers `data_type`.
data_type?: number;
dataType?: number;
}
// One dimension of a tensor shape, as read by `decodeShape`: either a numeric
// extent (`dim_value` / `dimValue`) or a symbolic name (`dim_param` / `dimParam`).
interface OnnxDim {
dim_value?: number | { toNumber(): number };
dimValue?: number | { toNumber(): number };
dim_param?: string;
dimParam?: string;
}
// Name plus tensor type/shape description of a graph value, as read by
// `decodeValueInfo`. The tensor type is declared under both the snake_case
// and camelCase spellings; only tensor-typed values are modelled here.
interface OnnxValueInfoProto {
name?: string;
type?: {
// snake_case spelling; tried first by `decodeValueInfo`.
tensor_type?: {
elem_type?: number;
shape?: { dim?: OnnxDim[] };
};
// camelCase spelling; used when `tensor_type` is absent.
tensorType?: {
elemType?: number;
shape?: { dim?: OnnxDim[] };
};
};
}
// Subset of the ONNX graph message consumed by `modelProtoToIR`.
interface OnnxGraphProto {
node?: OnnxNodeProto[];
input?: OnnxValueInfoProto[];
output?: OnnxValueInfoProto[];
// Constant tensors; `modelProtoToIR` attaches them to consuming nodes as weights.
initializer?: OnnxTensorProto[];
// Type/shape info for additional tensors, under either spelling
// (`value_info` is read first).
value_info?: OnnxValueInfoProto[];
valueInfo?: OnnxValueInfoProto[];
name?: string;
}
// One opset import entry. `modelProtoToIR` reports the version of the entry
// whose domain is missing, empty, or "ai.onnx" (falling back to the first entry).
interface OnnxOpsetImport {
version?: number | { toNumber(): number };
domain?: string;
}
// Subset of the top-level ONNX model message read by `modelProtoToIR`.
// Multi-word fields are declared under both snake_case and camelCase; the
// reader tries snake_case first.
interface OnnxModelProto {
graph?: OnnxGraphProto;
producer_name?: string;
producerName?: string;
ir_version?: number | { toNumber(): number };
irVersion?: number | { toNumber(): number };
opset_import?: OnnxOpsetImport[];
opsetImport?: OnnxOpsetImport[];
}
// Coerce an integer-like protobuf field to a plain JS number.
//
// - missing value (`undefined`, or a runtime `null`) -> 0
// - plain number -> returned unchanged
// - anything else -> result of its `toNumber()` method
function toNum(v: number | { toNumber(): number } | undefined): number {
  // Loose equality deliberately matches both `undefined` and `null`.
  if (v == null) return 0;
  return typeof v === "number" ? v : v.toNumber();
}
// Decode a byte buffer as UTF-8 text; a missing buffer yields the empty string.
function decodeUtf8(bytes: Uint8Array | undefined): string {
  return bytes ? new TextDecoder("utf-8").decode(bytes) : "";
}
// Turn an ONNX attribute into a plain JS value for inspection.
//
// Handled attribute type codes:
//   1 FLOAT   -> number (`a.f`, may be undefined if the field is absent)
//   2 INT     -> number
//   3 STRING  -> UTF-8 decoded string
//   6 FLOATS  -> number[]
//   7 INTS    -> number[]
//   8 STRINGS -> string[]
// Every other code (including 4 = TENSOR and a missing type) yields `null`.
function decodeAttr(a: OnnxAttr): unknown {
  const kind = a.type;
  if (kind === 1) return a.f;
  if (kind === 2) return toNum(a.i);
  if (kind === 3) return decodeUtf8(a.s);
  if (kind === 6) return a.floats ?? [];
  if (kind === 7) return (a.ints ?? []).map((n) => toNum(n));
  if (kind === 8) return (a.strings ?? []).map((raw) => decodeUtf8(raw));
  return null;
}
// Decode a tensor shape into a list of extents.
//
// Each dimension becomes:
//   - its symbolic name, when a non-empty `dim_param` / `dimParam` is present;
//   - otherwise its numeric extent (`dim_value` / `dimValue`) via `toNum`;
//   - otherwise whatever (possibly empty) name string was present, or "?".
// A missing `dims` array yields an empty shape.
//
// The symbolic name is checked first, and `null` is treated the same as
// `undefined`: value and name form a oneof in onnx.proto, and decoders may
// represent the unset member as `null` or as a default (0 / ""). Checking the
// value first would then report a symbolic dimension such as "batch" as 0.
function decodeShape(dims: OnnxDim[] | undefined): (number | string)[] {
  if (!dims) return [];
  return dims.map((d) => {
    const param = d.dim_param ?? d.dimParam;
    if (typeof param === "string" && param.length > 0) return param;
    const value = d.dim_value ?? d.dimValue;
    if (value != null) return toNum(value);
    return param ?? "?";
  });
}
// Summarise a value-info entry as { name, shape, dtype }.
//
// The tensor type is taken from `tensor_type`, else `tensorType`, else treated
// as empty. A missing element type defaults to code 0; codes without an entry
// in `ONNX_DTYPE` are rendered as `TYPE_<code>`. A missing name becomes "".
function decodeValueInfo(vi: OnnxValueInfoProto): IRTensorInfo {
  // One declaration wide enough for either spelling, so no casts are needed below.
  const tensor: {
    elem_type?: number;
    elemType?: number;
    shape?: { dim?: OnnxDim[] };
  } = vi.type?.tensor_type ?? vi.type?.tensorType ?? {};

  const elemCode = tensor.elem_type ?? tensor.elemType ?? 0;
  const shape = decodeShape(tensor.shape?.dim);
  const dtype = ONNX_DTYPE[elemCode] ?? `TYPE_${elemCode}`;

  return { name: vi.name ?? "", shape, dtype };
}
// Describe an initializer tensor as an IR weight: name, shape, dtype label and
// element count.
//
// The element count is the product of all extents, so a tensor with no dims
// counts as 1. Type codes missing from `ONNX_DTYPE` are rendered as
// `TYPE_<code>`; a missing name becomes "".
function makeWeight(t: OnnxTensorProto): IRWeight {
  const shape = (t.dims ?? []).map((extent) => toNum(extent));

  let numParams = 1;
  for (const extent of shape) numParams *= extent;

  const typeCode = t.data_type ?? t.dataType ?? 0;
  const dtype = ONNX_DTYPE[typeCode] ?? `TYPE_${typeCode}`;

  return { name: t.name ?? "", shape, dtype, numParams };
}
// Ids of the two pseudo-nodes `modelProtoToIR` injects: the source that
// "produces" every graph-level input, and the sink that "consumes" every
// graph-level output.
const SYNTHETIC_INPUT_ID = "__graph_input__";
const SYNTHETIC_OUTPUT_ID = "__graph_output__";
// Convert a decoded ONNX `ModelProto` into the internal IR graph.
//
// - Every graph node becomes one `IRNode` with id `n<index>_<opType>`.
//   Unnamed nodes are named `<opType>_<index>`.
// - Node inputs that name an initializer are attached to the node as
//   `weights` instead of being kept as inputs.
// - Edges connect each tensor's producer to its consumers. Graph-level inputs
//   are produced by a synthetic "Input" node, and graph-level outputs are
//   consumed by a synthetic "Output" node. Each synthetic node is only added
//   when there is at least one input / output.
// - Graph inputs that are also initializers are not reported as inputs: some
//   exporters list every weight in `graph.input`, which would otherwise turn
//   the Input node and `meta.inputs` into a wall of constants.
// - Empty tensor names never take part in edges. ONNX uses "" as the
//   placeholder for an omitted optional input or output, so matching on it
//   would link unrelated nodes together.
// - Edge ids are unique. A node that consumes the same tensor more than once
//   (e.g. Mul(x, x)) yields a single edge.
//
// `modelProto` is treated as an already-decoded message and is not validated.
// Missing fields fall back to empty/default values.
export function modelProtoToIR(modelProto: unknown): IRGraph {
  const model = modelProto as OnnxModelProto;
  const graph = model.graph ?? {};
  const nodes = graph.node ?? [];
  const initializers = graph.initializer ?? [];
  const valueInfos = graph.value_info ?? graph.valueInfo ?? [];

  // Shape/dtype lookup for edge annotation, built from every value-info the
  // graph declares (inputs, outputs and intermediate tensors).
  const tensorTypeIndex = new Map<string, IRTensorInfo>();
  for (const vi of [...(graph.input ?? []), ...(graph.output ?? []), ...valueInfos]) {
    const info = decodeValueInfo(vi);
    if (info.name) tensorTypeIndex.set(info.name, info);
  }

  // Initializers (= weights) indexed by tensor name so we can attach them to
  // their consumer node.
  const initIndex = new Map<string, IRWeight>();
  for (const t of initializers) {
    const w = makeWeight(t);
    if (w.name) initIndex.set(w.name, w);
  }

  // Real graph inputs only: entries that merely re-declare an initializer are
  // dropped (see header comment).
  const inputs = (graph.input ?? [])
    .map(decodeValueInfo)
    .filter((i) => !initIndex.has(i.name));
  const outputs = (graph.output ?? []).map(decodeValueInfo);

  // Build IR nodes.
  const irNodes: IRNode[] = [];
  nodes.forEach((n, idx) => {
    const opType = n.op_type ?? n.opType ?? "Unknown";
    const name = n.name && n.name.length > 0 ? n.name : `${opType}_${idx}`;
    const id = `n${idx}_${opType}`;

    // Split the node's inputs into weights (initializers) and real tensors.
    const weights: IRWeight[] = [];
    const realInputs: string[] = [];
    for (const t of n.input ?? []) {
      const w = initIndex.get(t);
      if (w) weights.push(w);
      else realInputs.push(t);
    }

    const attrs: Record<string, unknown> = {};
    for (const a of n.attribute ?? []) {
      attrs[a.name] = decodeAttr(a);
    }

    irNodes.push({
      id,
      opType,
      name,
      inputs: realInputs,
      outputs: n.output ?? [],
      attrs,
      weights,
    });
  });

  // Map tensor name -> producer node id. Empty names are placeholders for
  // omitted optional outputs and must never be registered as a real tensor.
  const producerOf = new Map<string, string>();
  for (const n of irNodes) {
    for (const out of n.outputs) {
      if (out !== "") producerOf.set(out, n.id);
    }
  }
  // Graph-level inputs are produced by a synthetic source node.
  for (const i of inputs) {
    if (i.name !== "") producerOf.set(i.name, SYNTHETIC_INPUT_ID);
  }

  // Build edges, skipping any edge whose id has already been emitted.
  const irEdges: IREdge[] = [];
  const seenEdgeIds = new Set<string>();
  const pushEdge = (source: string, target: string, tensor: string) => {
    const id = `${source}__${target}__${tensor}`;
    if (seenEdgeIds.has(id)) return;
    seenEdgeIds.add(id);
    const info = tensorTypeIndex.get(tensor);
    irEdges.push({
      id,
      source,
      target,
      tensor,
      shape: info?.shape,
      dtype: info?.dtype,
    });
  };
  for (const n of irNodes) {
    for (const inp of n.inputs) {
      // "" (omitted optional input) is never in `producerOf`, so it is skipped here.
      const src = producerOf.get(inp);
      if (src) pushEdge(src, n.id, inp);
    }
  }
  // Outputs -> synthetic sink.
  for (const out of outputs) {
    const src = producerOf.get(out.name);
    if (src) pushEdge(src, SYNTHETIC_OUTPUT_ID, out.name);
  }

  // Inject synthetic source/sink nodes (after edge building, so they are not
  // themselves scanned as producers/consumers above).
  if (inputs.length > 0) {
    irNodes.unshift({
      id: SYNTHETIC_INPUT_ID,
      opType: "Input",
      name: "inputs",
      inputs: [],
      outputs: inputs.map((i) => i.name),
      attrs: {},
      weights: [],
    });
  }
  if (outputs.length > 0) {
    irNodes.push({
      id: SYNTHETIC_OUTPUT_ID,
      opType: "Output",
      name: "outputs",
      inputs: outputs.map((o) => o.name),
      outputs: [],
      attrs: {},
      weights: [],
    });
  }

  // Parameter count sums over every named initializer, whether or not a node
  // consumes it.
  let totalParams = 0;
  for (const w of initIndex.values()) totalParams += w.numParams;

  // Prefer the default ONNX domain (missing, "" or "ai.onnx"); otherwise fall
  // back to the first opset import; 0 when there is none.
  const opsetImports = model.opset_import ?? model.opsetImport ?? [];
  const mainOpset =
    opsetImports.find((o) => !o.domain || o.domain === "ai.onnx") ?? opsetImports[0];
  const opsetVersion = mainOpset ? toNum(mainOpset.version) : 0;

  return {
    nodes: irNodes,
    edges: irEdges,
    meta: {
      modelName: graph.name ?? "model",
      producer: model.producer_name ?? model.producerName ?? "unknown",
      irVersion: toNum(model.ir_version ?? model.irVersion),
      opsetVersion,
      totalParams,
      // NOTE(review): this count includes the synthetic Input/Output nodes,
      // as it always has — confirm that is what consumers of `meta` expect.
      nodeCount: irNodes.length,
      inputs,
      outputs,
    },
  };
}
// Re-export the pseudo-node ids so other modules can recognise the synthetic
// source/sink nodes in the IR.
export { SYNTHETIC_INPUT_ID, SYNTHETIC_OUTPUT_ID };