// Convert a decoded ONNX `ModelProto` into our internal IR.
//
// The strategy:
//   - Each `NodeProto` becomes one `IRNode`.
//   - We build edges by mapping every tensor name produced by a node to its
//     consumers. Graph inputs become edges from a synthetic "input"
//     pseudo-node, and graph outputs link to a synthetic "output" pseudo-node.
//   - `initializer[]` are weight tensors (= the model's parameters). We DO NOT
//     create nodes for them; instead we attach them as metadata to the node
//     that consumes them. This avoids the "wall of constants" that plagues
//     naive ONNX viewers. Initializers that are also listed in `graph.input`
//     are likewise not reported as graph inputs.
//   - Attributes are converted to plain JS values for inspection.

import type { IRGraph, IRNode, IREdge, IRWeight, IRTensorInfo } from "../types";
import { ONNX_DTYPE } from "./onnx";

// Minimal structural types matching the bits of onnx.proto we read. We avoid
// pulling in a heavy generated namespace. Both snake_case and camelCase field
// spellings are declared because the code below accepts either.

/** A 64-bit protobuf integer: either a plain number or an object exposing `toNumber()`. */
type LongLike = number | { toNumber(): number };

interface OnnxAttr {
  name: string;
  type?: number;
  f?: number;
  i?: LongLike;
  s?: Uint8Array;
  floats?: number[];
  ints?: Array<LongLike>;
  strings?: Uint8Array[];
}

interface OnnxNodeProto {
  name?: string;
  op_type?: string;
  opType?: string;
  input?: string[];
  output?: string[];
  attribute?: OnnxAttr[];
}

interface OnnxTensorProto {
  name?: string;
  dims?: Array<LongLike>;
  data_type?: number;
  dataType?: number;
}

interface OnnxDim {
  dim_value?: LongLike | null;
  dimValue?: LongLike | null;
  dim_param?: string | null;
  dimParam?: string | null;
}

interface OnnxTensorTypeProto {
  elem_type?: number;
  elemType?: number;
  shape?: { dim?: OnnxDim[] };
}

interface OnnxValueInfoProto {
  name?: string;
  type?: {
    tensor_type?: OnnxTensorTypeProto;
    tensorType?: OnnxTensorTypeProto;
  };
}

interface OnnxGraphProto {
  node?: OnnxNodeProto[];
  input?: OnnxValueInfoProto[];
  output?: OnnxValueInfoProto[];
  initializer?: OnnxTensorProto[];
  value_info?: OnnxValueInfoProto[];
  valueInfo?: OnnxValueInfoProto[];
  name?: string;
}

interface OnnxOpsetImport {
  version?: LongLike;
  domain?: string;
}

interface OnnxModelProto {
  graph?: OnnxGraphProto;
  producer_name?: string;
  producerName?: string;
  ir_version?: LongLike;
  irVersion?: LongLike;
  opset_import?: OnnxOpsetImport[];
  opsetImport?: OnnxOpsetImport[];
}

// One shared decoder; constructing a TextDecoder per attribute is wasted work.
const utf8Decoder = new TextDecoder("utf-8");

/**
 * Normalise a protobuf integer to a JS number.
 *
 * @param v A number, an object with `toNumber()`, or nothing.
 * @returns The numeric value, or 0 when `v` is `undefined`/`null`.
 */
function toNum(v: LongLike | undefined | null): number {
  if (v === undefined || v === null) return 0;
  if (typeof v === "number") return v;
  return v.toNumber();
}

/**
 * Decode a protobuf `bytes` field as UTF-8 text.
 *
 * @returns The decoded string, or "" when `bytes` is missing.
 */
function decodeUtf8(bytes: Uint8Array | undefined): string {
  if (!bytes) return "";
  return utf8Decoder.decode(bytes);
}

/**
 * Convert one ONNX attribute into a plain JS value for inspection.
 *
 * @returns number / string / array for the scalar and list kinds handled
 *   below; `null` for every other attribute type.
 */
function decodeAttr(a: OnnxAttr): unknown {
  // AttributeType: 1=FLOAT 2=INT 3=STRING 4=TENSOR 6=FLOATS 7=INTS 8=STRINGS
  switch (a.type) {
    case 1:
      return a.f;
    case 2:
      return toNum(a.i);
    case 3:
      return decodeUtf8(a.s);
    case 6:
      return a.floats ?? [];
    case 7:
      return (a.ints ?? []).map(toNum);
    case 8:
      return (a.strings ?? []).map(decodeUtf8);
    default:
      return null;
  }
}

/**
 * Decode a tensor shape into a list of concrete sizes and symbolic names.
 *
 * A non-empty `dim_param` wins over `dim_value`: a decoder may expose the
 * unset numeric member as `null` or as a zero default, and reporting that as
 * a size of 0 would hide the symbolic dimension.
 *
 * @returns One entry per dimension: a number, a symbolic name, or "?" when
 *   neither is available. Empty array when `dims` is missing.
 */
function decodeShape(dims: OnnxDim[] | undefined): (number | string)[] {
  if (!dims) return [];
  return dims.map((d) => {
    const param = d.dim_param ?? d.dimParam;
    if (typeof param === "string" && param.length > 0) return param;
    const value = d.dim_value ?? d.dimValue;
    // `!= null` rejects both undefined and null.
    if (value != null) return toNum(value);
    return "?";
  });
}

/**
 * Decode a `ValueInfoProto` into name / shape / dtype.
 *
 * @returns Tensor info; `name` is "" when absent, `dtype` falls back to
 *   `TYPE_<n>` when the element type is not in `ONNX_DTYPE`.
 */
function decodeValueInfo(vi: OnnxValueInfoProto): IRTensorInfo {
  const tensorType: OnnxTensorTypeProto =
    vi.type?.tensor_type ?? vi.type?.tensorType ?? {};
  const elem = tensorType.elem_type ?? tensorType.elemType ?? 0;
  return {
    name: vi.name ?? "",
    shape: decodeShape(tensorType.shape?.dim),
    dtype: ONNX_DTYPE[elem] ?? `TYPE_${elem}`,
  };
}

/**
 * Summarise an initializer tensor as weight metadata (no tensor data is read).
 *
 * @returns Weight info; `numParams` is the product of the dims, which is 1
 *   for a tensor with no dims.
 */
function makeWeight(t: OnnxTensorProto): IRWeight {
  const shape = (t.dims ?? []).map(toNum);
  const numParams = shape.reduce((a, b) => a * b, 1);
  const dt = t.data_type ?? t.dataType ?? 0;
  return {
    name: t.name ?? "",
    shape,
    dtype: ONNX_DTYPE[dt] ?? `TYPE_${dt}`,
    numParams,
  };
}

const SYNTHETIC_INPUT_ID = "__graph_input__";
const SYNTHETIC_OUTPUT_ID = "__graph_output__";

/**
 * Build the internal IR graph from a decoded ONNX model.
 *
 * @param modelProto A decoded `ModelProto`-shaped object. It is not validated;
 *   missing fields are treated as empty, and `null`/`undefined` yields an
 *   empty graph.
 * @returns The IR nodes (including the synthetic Input/Output pseudo-nodes
 *   when the graph has inputs/outputs), the edges between them, and summary
 *   metadata.
 */
export function modelProtoToIR(modelProto: unknown): IRGraph {
  const model = (modelProto ?? {}) as OnnxModelProto;
  const graph = model.graph ?? {};
  const nodes = graph.node ?? [];
  const initializers = graph.initializer ?? [];
  const valueInfos = graph.value_info ?? graph.valueInfo ?? [];

  // Initializers (= weights) indexed by tensor name so we can attach them to
  // their consumer node.
  const initIndex = new Map<string, IRWeight>();
  for (const t of initializers) {
    const w = makeWeight(t);
    if (w.name) initIndex.set(w.name, w);
  }
  const initNames = new Set(initIndex.keys());

  // `graph.input` may also list initializers; those are weights, not real
  // graph inputs, so keep them out of the reported inputs and the synthetic
  // Input node.
  const inputs = (graph.input ?? [])
    .map(decodeValueInfo)
    .filter((i) => !initNames.has(i.name));
  const outputs = (graph.output ?? []).map(decodeValueInfo);

  // Shape/dtype lookup for edge annotation. Later entries overwrite earlier
  // ones with the same name (input, then output, then value_info).
  const tensorTypeIndex = new Map<string, IRTensorInfo>();
  for (const vi of [...(graph.input ?? []), ...(graph.output ?? []), ...valueInfos]) {
    const info = decodeValueInfo(vi);
    if (info.name) tensorTypeIndex.set(info.name, info);
  }

  // Build IR nodes.
  const irNodes: IRNode[] = [];
  nodes.forEach((n, idx) => {
    const opType = n.op_type ?? n.opType ?? "Unknown";
    const name = n.name && n.name.length > 0 ? n.name : `${opType}_${idx}`;
    // The index makes the id unique even when node names repeat or are empty.
    const id = `n${idx}_${opType}`;

    // Split the node's inputs into attached weights and real data inputs.
    const weights: IRWeight[] = [];
    const realInputs: string[] = [];
    for (const t of n.input ?? []) {
      const w = initIndex.get(t);
      if (w) weights.push(w);
      else realInputs.push(t);
    }

    const attrs: Record<string, unknown> = {};
    for (const a of n.attribute ?? []) {
      attrs[a.name] = decodeAttr(a);
    }

    irNodes.push({
      id,
      opType,
      name,
      inputs: realInputs,
      outputs: n.output ?? [],
      attrs,
      weights,
    });
  });

  // Map tensor name -> producer node id. Empty names are never registered:
  // an empty name is a placeholder for an omitted optional tensor, and
  // treating it as a real tensor would connect unrelated nodes.
  const producerOf = new Map<string, string>();
  for (const n of irNodes) {
    for (const out of n.outputs) {
      if (out) producerOf.set(out, n.id);
    }
  }
  // Graph-level inputs are produced by a synthetic source node.
  for (const i of inputs) {
    if (i.name) producerOf.set(i.name, SYNTHETIC_INPUT_ID);
  }

  // Build edges.
  const irEdges: IREdge[] = [];
  const pushEdge = (source: string, target: string, tensor: string): void => {
    const info = tensorTypeIndex.get(tensor);
    irEdges.push({
      id: `${source}__${target}__${tensor}`,
      source,
      target,
      tensor,
      shape: info?.shape,
      dtype: info?.dtype,
    });
  };
  for (const n of irNodes) {
    for (const inp of n.inputs) {
      if (!inp) continue; // omitted optional input
      const src = producerOf.get(inp);
      if (src) pushEdge(src, n.id, inp);
    }
  }
  // Outputs -> synthetic sink.
  for (const out of outputs) {
    if (!out.name) continue;
    const src = producerOf.get(out.name);
    if (src) pushEdge(src, SYNTHETIC_OUTPUT_ID, out.name);
  }

  // Inject synthetic source/sink nodes (after edge building, so they are not
  // themselves scanned as consumers/producers above).
  if (inputs.length > 0) {
    irNodes.unshift({
      id: SYNTHETIC_INPUT_ID,
      opType: "Input",
      name: "inputs",
      inputs: [],
      outputs: inputs.map((i) => i.name),
      attrs: {},
      weights: [],
    });
  }
  if (outputs.length > 0) {
    irNodes.push({
      id: SYNTHETIC_OUTPUT_ID,
      opType: "Output",
      name: "outputs",
      inputs: outputs.map((o) => o.name),
      outputs: [],
      attrs: {},
      weights: [],
    });
  }

  let totalParams = 0;
  for (const w of initIndex.values()) totalParams += w.numParams;

  // Prefer the default-domain opset ("" or "ai.onnx"); otherwise fall back to
  // the first import, or 0 when there is none.
  const opsetImports = model.opset_import ?? model.opsetImport ?? [];
  const mainOpset =
    opsetImports.find((o) => !o.domain || o.domain === "ai.onnx") ?? opsetImports[0];
  const opsetVersion = mainOpset ? toNum(mainOpset.version) : 0;

  return {
    nodes: irNodes,
    edges: irEdges,
    meta: {
      modelName: graph.name ?? "model",
      producer: model.producer_name ?? model.producerName ?? "unknown",
      irVersion: toNum(model.ir_version ?? model.irVersion),
      opsetVersion,
      totalParams,
      // Read after injection, so this count includes the synthetic nodes.
      nodeCount: irNodes.length,
      inputs,
      outputs,
    },
  };
}

export { SYNTHETIC_INPUT_ID, SYNTHETIC_OUTPUT_ID };