Spaces:
Sleeping
Sleeping
| // Convert a decoded ONNX `ModelProto` into our internal IR. | |
| // | |
| // The strategy: | |
| // - Each `NodeProto` becomes one `IRNode`. | |
| // - We build edges by mapping every tensor name produced by a node to its | |
| // consumers. Tensors produced as graph inputs become edges from a synthetic | |
| // "input" pseudo-node, and outputs link to a synthetic "output" pseudo-node. | |
| // - `initializer[]` are weight tensors (= the model's parameters). We DO NOT | |
| // create nodes for them; instead we attach them as metadata to the node | |
| // that consumes them. This avoids the "wall of constants" that plagues | |
| // naive ONNX viewers. | |
| // - Attributes are converted to plain JS values for inspection. | |
| import type { IRGraph, IRNode, IREdge, IRWeight, IRTensorInfo } from "../types"; | |
| import { ONNX_DTYPE } from "./onnx"; | |
// Minimal structural types matching the bits of onnx.proto we read. We avoid
// pulling in a heavy generated namespace.
//
// Most fields are declared in both their proto (snake_case) and camelCase
// spellings and the converter reads whichever is present — presumably so it
// works whether or not the decoder preserves proto field names.

/**
 * A protobuf 64-bit integer as the decoder may hand it over: either a plain
 * number or a Long-like object exposing `toNumber()`. Normalised via `toNum`.
 */
type LongLike = number | { toNumber(): number };

/** Subset of `onnx.AttributeProto`; `type` is the AttributeType enum code. */
interface OnnxAttr {
  name: string;
  type?: number;
  f?: number;
  i?: LongLike;
  s?: Uint8Array;
  floats?: number[];
  ints?: LongLike[];
  strings?: Uint8Array[];
}

/** Subset of `onnx.NodeProto` (one operator invocation). */
interface OnnxNodeProto {
  name?: string;
  op_type?: string;
  opType?: string;
  input?: string[];
  output?: string[];
  attribute?: OnnxAttr[];
}

/** Subset of `onnx.TensorProto`: only metadata fields are modelled, no payload. */
interface OnnxTensorProto {
  name?: string;
  dims?: LongLike[];
  data_type?: number;
  dataType?: number;
}

/**
 * One `onnx.TensorShapeProto.Dimension`: a concrete size (`dim_value`) or a
 * symbolic name (`dim_param`).
 */
interface OnnxDim {
  dim_value?: LongLike;
  dimValue?: LongLike;
  dim_param?: string;
  dimParam?: string;
}

/** Subset of `onnx.ValueInfoProto`; only tensor types are modelled. */
interface OnnxValueInfoProto {
  name?: string;
  type?: {
    tensor_type?: {
      elem_type?: number;
      shape?: { dim?: OnnxDim[] };
    };
    tensorType?: {
      elemType?: number;
      shape?: { dim?: OnnxDim[] };
    };
  };
}

/** Subset of `onnx.GraphProto`. */
interface OnnxGraphProto {
  node?: OnnxNodeProto[];
  input?: OnnxValueInfoProto[];
  output?: OnnxValueInfoProto[];
  initializer?: OnnxTensorProto[];
  value_info?: OnnxValueInfoProto[];
  valueInfo?: OnnxValueInfoProto[];
  name?: string;
}

/** Subset of `onnx.OperatorSetIdProto` (one entry of `opset_import`). */
interface OnnxOpsetImport {
  version?: LongLike;
  domain?: string;
}

/** Subset of `onnx.ModelProto`. */
interface OnnxModelProto {
  graph?: OnnxGraphProto;
  producer_name?: string;
  producerName?: string;
  ir_version?: LongLike;
  irVersion?: LongLike;
  opset_import?: OnnxOpsetImport[];
  opsetImport?: OnnxOpsetImport[];
}
/**
 * Normalise a protobuf integer field to a JS number.
 *
 * Plain numbers pass through unchanged, Long-like objects are converted with
 * their `toNumber()`, and a missing value (`undefined` or `null`) reads as 0.
 */
function toNum(v: number | { toNumber(): number } | undefined): number {
  if (typeof v === "number") return v;
  return v == null ? 0 : v.toNumber();
}
/** Decode optional UTF-8 bytes into a string; absent bytes decode to "". */
function decodeUtf8(bytes: Uint8Array | undefined): string {
  return bytes ? new TextDecoder("utf-8").decode(bytes) : "";
}
/**
 * Convert one ONNX attribute into a plain JS value for inspection.
 *
 * Handled AttributeType codes: 1=FLOAT, 2=INT, 3=STRING, 6=FLOATS, 7=INTS,
 * 8=STRINGS. Any other code (e.g. 4=TENSOR) or a missing `type` is not
 * decoded and yields `null`.
 *
 * Scalars absent from the decoded message fall back to the protobuf default
 * (0 for numbers, "" for strings). List values are always returned as fresh
 * arrays, so callers never alias the decoder's storage.
 */
function decodeAttr(a: OnnxAttr): unknown {
  switch (a.type) {
    case 1: // FLOAT — default to 0 like INT does, instead of leaking undefined
      return a.f ?? 0;
    case 2: // INT — may arrive as a Long-like object
      return toNum(a.i);
    case 3: // STRING — stored as raw bytes
      return decodeUtf8(a.s);
    case 6: // FLOATS — copied, matching INTS/STRINGS which are mapped
      return [...(a.floats ?? [])];
    case 7: // INTS
      return (a.ints ?? []).map(toNum);
    case 8: // STRINGS
      return (a.strings ?? []).map(decodeUtf8);
    default:
      return null;
  }
}
/**
 * Decode a `TensorShapeProto.dim` list into a readable shape.
 *
 * Each dimension becomes its symbolic name (`dim_param`, e.g. "batch") when a
 * non-empty one is present, otherwise its concrete size (`dim_value`),
 * otherwise "?" when neither is set. A missing list yields `[]`.
 */
function decodeShape(dims: OnnxDim[] | undefined): (number | string)[] {
  if (!dims) return [];
  return dims.map((d) => {
    // dim_value and dim_param form a oneof, so at most one is meaningful.
    // Decoders may fill the unset member with null or a default value rather
    // than leaving it undefined, so check the symbolic name first.
    const param = d.dim_param || d.dimParam;
    if (param) return param;
    const v = d.dim_value ?? d.dimValue;
    // `!= null` (not `!== undefined`): a null dim_value means "unset", and
    // toNum(null) would otherwise report it as a size of 0.
    if (v != null) return toNum(v);
    return "?";
  });
}
/**
 * Summarise a `ValueInfoProto` (a graph input, output or value_info entry) as
 * an `IRTensorInfo`: tensor name, decoded shape and element-dtype label.
 *
 * Values without a tensor type decode to an empty shape and element type 0.
 * Element codes missing from `ONNX_DTYPE` are labelled `TYPE_<code>`.
 */
function decodeValueInfo(vi: OnnxValueInfoProto): IRTensorInfo {
  // A checked declaration instead of `as` casts: all fields are optional, so
  // either tensor-type spelling (or the empty fallback) is assignable here.
  const t: {
    elem_type?: number;
    elemType?: number;
    shape?: { dim?: OnnxDim[] };
  } = vi.type?.tensor_type ?? vi.type?.tensorType ?? {};
  const elem = t.elem_type ?? t.elemType ?? 0;
  return {
    name: vi.name ?? "",
    shape: decodeShape(t.shape?.dim),
    dtype: ONNX_DTYPE[elem] ?? `TYPE_${elem}`,
  };
}
/**
 * Describe an initializer tensor as an `IRWeight`: name, dims, dtype label
 * and element count. The count is the product of the dims, so a scalar
 * (empty dims) counts as a single parameter.
 */
function makeWeight(t: OnnxTensorProto): IRWeight {
  const dtypeCode = t.data_type ?? t.dataType ?? 0;
  const dims = (t.dims ?? []).map(toNum);
  let count = 1;
  for (const d of dims) {
    count *= d;
  }
  return {
    name: t.name ?? "",
    shape: dims,
    dtype: ONNX_DTYPE[dtypeCode] ?? `TYPE_${dtypeCode}`,
    numParams: count,
  };
}
/** Id of the pseudo-node that emits every graph-level input tensor. */
const SYNTHETIC_INPUT_ID = "__graph_input__";

/** Id of the pseudo-node that consumes every graph-level output tensor. */
const SYNTHETIC_OUTPUT_ID = "__graph_output__";
/** Index the declared type of every graph input, output and value_info tensor by name. */
function indexTensorTypes(graph: OnnxGraphProto): Map<string, IRTensorInfo> {
  const index = new Map<string, IRTensorInfo>();
  const valueInfos = graph.value_info ?? graph.valueInfo ?? [];
  for (const vi of [...(graph.input ?? []), ...(graph.output ?? []), ...valueInfos]) {
    const info = decodeValueInfo(vi);
    if (info.name) index.set(info.name, info);
  }
  return index;
}

/** Index initializers (= weights) by tensor name so consuming nodes can claim them. */
function indexInitializers(initializers: OnnxTensorProto[]): Map<string, IRWeight> {
  const index = new Map<string, IRWeight>();
  for (const t of initializers) {
    const w = makeWeight(t);
    if (w.name) index.set(w.name, w);
  }
  return index;
}

/**
 * Convert the `idx`-th NodeProto into an IRNode. Inputs backed by an
 * initializer are moved into `weights`; the remaining inputs keep their order.
 */
function convertNode(n: OnnxNodeProto, idx: number, initIndex: Map<string, IRWeight>): IRNode {
  const opType = n.op_type ?? n.opType ?? "Unknown";
  const weights: IRWeight[] = [];
  const inputs: string[] = [];
  for (const tensor of n.input ?? []) {
    const weight = initIndex.get(tensor);
    if (weight) weights.push(weight);
    else inputs.push(tensor);
  }
  // Object.fromEntries defines own data properties, so an attribute named
  // "__proto__" in an untrusted model cannot replace the object's prototype.
  const attrs: Record<string, unknown> = Object.fromEntries(
    (n.attribute ?? []).map((a): [string, unknown] => [a.name, decodeAttr(a)]),
  );
  return {
    id: `n${idx}_${opType}`,
    opType,
    // Unnamed nodes fall back to a positional name.
    name: n.name || `${opType}_${idx}`,
    inputs,
    outputs: n.output ?? [],
    attrs,
    weights,
  };
}

/**
 * Link every consumed tensor to its producer: either a real node or the
 * synthetic input node (for graph inputs). Graph outputs are linked to the
 * synthetic output node. Empty tensor names, which ONNX uses for omitted
 * optional inputs/outputs, never form edges. A tensor consumed more than once
 * by the same node yields a single edge, so edge ids stay unique.
 */
function buildEdges(
  nodes: IRNode[],
  graphInputs: IRTensorInfo[],
  graphOutputs: IRTensorInfo[],
  tensorTypes: Map<string, IRTensorInfo>,
): IREdge[] {
  // Tensor name -> id of the node producing it.
  const producerOf = new Map<string, string>();
  for (const n of nodes) {
    for (const out of n.outputs) {
      if (out !== "") producerOf.set(out, n.id);
    }
  }
  for (const i of graphInputs) {
    if (i.name !== "") producerOf.set(i.name, SYNTHETIC_INPUT_ID);
  }

  const edges: IREdge[] = [];
  const seen = new Set<string>();
  const link = (source: string, target: string, tensor: string): void => {
    const id = `${source}__${target}__${tensor}`;
    if (seen.has(id)) return;
    seen.add(id);
    const info = tensorTypes.get(tensor);
    edges.push({ id, source, target, tensor, shape: info?.shape, dtype: info?.dtype });
  };

  for (const n of nodes) {
    for (const inp of n.inputs) {
      if (inp === "") continue;
      const src = producerOf.get(inp);
      if (src) link(src, n.id, inp);
    }
  }
  for (const out of graphOutputs) {
    const src = producerOf.get(out.name);
    if (src) link(src, SYNTHETIC_OUTPUT_ID, out.name);
  }
  return edges;
}

/**
 * Version of the default ONNX operator set: the import with an empty/absent
 * or "ai.onnx" domain, else the first import, else 0.
 */
function mainOpsetVersion(model: OnnxModelProto): number {
  const imports = model.opset_import ?? model.opsetImport ?? [];
  const main = imports.find((o) => !o.domain || o.domain === "ai.onnx") ?? imports[0];
  return main ? toNum(main.version) : 0;
}

/**
 * Convert a decoded ONNX `ModelProto` into the internal IR graph.
 *
 * @param modelProto Object produced by the protobuf decoder. It is read
 *   structurally, and missing fields degrade to empty lists, default labels or
 *   0. A null/undefined `modelProto` itself is not handled.
 * @returns Nodes (real ops plus synthetic Input/Output pseudo-nodes when the
 *   graph declares inputs/outputs), producer→consumer edges, and model metadata.
 */
export function modelProtoToIR(modelProto: unknown): IRGraph {
  const model = modelProto as OnnxModelProto;
  const graph = model.graph ?? {};

  const tensorTypes = indexTensorTypes(graph);
  const initIndex = indexInitializers(graph.initializer ?? []);

  // Before IR version 4 every initializer also had to be listed in
  // graph.input. Those entries are weights, not runtime inputs, so keep them
  // off the synthetic input node (and out of meta.inputs).
  const inputs = (graph.input ?? [])
    .map(decodeValueInfo)
    .filter((info) => !initIndex.has(info.name));
  const outputs = (graph.output ?? []).map(decodeValueInfo);

  const irNodes = (graph.node ?? []).map((n, idx) => convertNode(n, idx, initIndex));
  // Edges are derived from the real nodes; the pseudo-nodes are injected below.
  const irEdges = buildEdges(irNodes, inputs, outputs, tensorTypes);

  if (inputs.length > 0) {
    irNodes.unshift({
      id: SYNTHETIC_INPUT_ID,
      opType: "Input",
      name: "inputs",
      inputs: [],
      outputs: inputs.map((i) => i.name),
      attrs: {},
      weights: [],
    });
  }
  if (outputs.length > 0) {
    irNodes.push({
      id: SYNTHETIC_OUTPUT_ID,
      opType: "Output",
      name: "outputs",
      inputs: outputs.map((o) => o.name),
      outputs: [],
      attrs: {},
      weights: [],
    });
  }

  // Each initializer counts once, however many nodes consume it.
  const totalParams = Array.from(initIndex.values()).reduce(
    (acc, w) => acc + w.numParams,
    0,
  );

  return {
    nodes: irNodes,
    edges: irEdges,
    meta: {
      modelName: graph.name ?? "model",
      producer: model.producer_name ?? model.producerName ?? "unknown",
      irVersion: toNum(model.ir_version ?? model.irVersion),
      opsetVersion: mainOpsetVersion(model),
      totalParams,
      // Includes the synthetic Input/Output pseudo-nodes when present.
      nodeCount: irNodes.length,
      inputs,
      outputs,
    },
  };
}
// Ids of the synthetic Output/Input pseudo-nodes added by modelProtoToIR,
// presumably exported so callers can tell them apart from real ONNX ops.
export { SYNTHETIC_OUTPUT_ID, SYNTHETIC_INPUT_ID };