model-explorer / js /core /onnxParser.js
mr4's picture
Upload 71 files
9bd422a verified
/**
 * ONNXParser - ONNX model parser.
 * Parses ONNX model files using protobuf.js (window.protobuf) with an inline
 * ONNX proto schema. No onnx-js runtime required.
 * Requirements: 9.1, 9.2, 9.3, 9.4, 9.5
 */
class ONNXParser {
  constructor() {
    /** @type {Object|null} Cached protobuf root */
    this._root = null;
  }

  // ─── Proto Schema ─────────────────────────────────────────────────────────

  /**
   * Minimal ONNX proto schema (subset needed for model exploration).
   * Based on https://github.com/onnx/onnx/blob/main/onnx/onnx.proto3
   *
   * Fields that are not declared here are not exposed on the decoded message.
   * @returns {string} The schema text handed to `protobuf.parse`.
   */
  _getProtoSchema() {
    return `
syntax = "proto3";
package onnx;
message ModelProto {
int64 ir_version = 1;
repeated OperatorSetIdProto opset_import = 8;
string producer_name = 2;
string producer_version = 3;
string domain = 4;
int64 model_version = 5;
string doc_string = 6;
GraphProto graph = 7;
repeated StringStringEntryProto metadata_props = 14;
}
message GraphProto {
repeated NodeProto node = 1;
string name = 2;
repeated TensorProto initializer = 5;
string doc_string = 10;
repeated ValueInfoProto input = 11;
repeated ValueInfoProto output = 12;
repeated ValueInfoProto value_info = 13;
}
message NodeProto {
repeated string input = 1;
repeated string output = 2;
string name = 3;
string op_type = 4;
string domain = 7;
repeated AttributeProto attribute = 5;
string doc_string = 6;
}
message AttributeProto {
string name = 1;
string ref_attr_name = 21;
string doc_string = 13;
AttributeType type = 20;
float f = 2;
int64 i = 3;
bytes s = 4;
TensorProto t = 5;
GraphProto g = 6;
repeated float floats = 7;
repeated int64 ints = 8;
repeated bytes strings = 9;
repeated TensorProto tensors = 10;
repeated GraphProto graphs = 11;
enum AttributeType {
UNDEFINED = 0;
FLOAT = 1;
INT = 2;
STRING = 3;
TENSOR = 4;
GRAPH = 5;
FLOATS = 6;
INTS = 7;
STRINGS = 8;
TENSORS = 9;
GRAPHS = 10;
}
}
message ValueInfoProto {
string name = 1;
TypeProto type = 2;
string doc_string = 3;
}
message TypeProto {
message Tensor {
int32 elem_type = 1;
TensorShapeProto shape = 2;
}
oneof value {
Tensor tensor_type = 1;
}
}
message TensorShapeProto {
message Dimension {
oneof value {
int64 dim_value = 1;
string dim_param = 2;
}
string denotation = 3;
}
repeated Dimension dim = 1;
}
message TensorProto {
repeated int64 dims = 1 [packed = true];
int32 data_type = 2;
string name = 8;
string doc_string = 12;
bytes raw_data = 9;
repeated float float_data = 4 [packed = true];
repeated int32 int32_data = 5 [packed = true];
repeated int64 int64_data = 7 [packed = true];
repeated double double_data = 10 [packed = true];
repeated uint64 uint64_data = 11 [packed = true];
}
message OperatorSetIdProto {
string domain = 1;
int64 version = 2;
}
message StringStringEntryProto {
string key = 1;
string value = 2;
}
`;
  }

  // ─── Protobuf Root ────────────────────────────────────────────────────────

  /**
   * Build (and cache) the protobuf root from the inline schema.
   * @returns {Object} protobuf Root
   * @throws {Error} If the global `protobuf` (protobuf.js) is not defined.
   */
  _getRoot() {
    if (this._root) return this._root;
    if (typeof protobuf === 'undefined') {
      throw new Error(
        'protobuf.js is not available. ' +
        'Please ensure the protobuf.js CDN script is loaded before using ONNXParser.'
      );
    }
    // keepCase: false makes the decoded fields camelCase (op_type -> opType),
    // which is the spelling every extractor below reads.
    this._root = protobuf.parse(this._getProtoSchema(), { keepCase: false }).root;
    return this._root;
  }

  // ─── Public API ───────────────────────────────────────────────────────────

  /**
   * Parse an ONNX model from an ArrayBuffer.
   * @param {ArrayBuffer} modelBuffer
   * @param {Object} [options]
   * @param {string} [options.fileName]
   * @param {number} [options.fileSize]
   * @param {number} [options.loadedAt]
   * @returns {Promise<ParsedModel>} Resolves to
   *   `{ metadata, graph, inputs, outputs, initializers }`.
   * @throws {Error} (as a rejection) when the buffer is not a non-empty
   *   ArrayBuffer, when protobuf.js is missing, or when decoding/extraction
   *   fails. Wrapped failures carry the original error as `cause`.
   */
  async parseModel(modelBuffer, options = {}) {
    if (!modelBuffer || !(modelBuffer instanceof ArrayBuffer)) {
      throw new Error('Invalid model buffer: expected an ArrayBuffer.');
    }
    if (modelBuffer.byteLength === 0) {
      throw new Error('Model buffer is empty. The file may be corrupted or invalid.');
    }
    try {
      const root = this._getRoot();
      const ModelProto = root.lookupType('onnx.ModelProto');
      const model = ModelProto.decode(new Uint8Array(modelBuffer));
      // An explicit `null` bypasses the default parameter, so normalise here.
      const metadata = this._extractMetadata(model, options || {});
      const inputs = this._extractInputs(model);
      const outputs = this._extractOutputs(model);
      const initializers = this._extractInitializers(model);
      const graph = this._extractGraph(model, initializers);
      return { metadata, graph, inputs, outputs, initializers };
    } catch (err) {
      // Errors that already carry one of this parser's own messages are
      // passed through untouched; everything else is wrapped.
      if (err.message && (
        err.message.includes('protobuf.js') ||
        err.message.includes('Invalid model') ||
        err.message.includes('empty')
      )) {
        throw err;
      }
      throw new Error(
        `Failed to parse ONNX model: ${err.message || 'The file may be corrupted or not a valid ONNX file.'}`,
        { cause: err }
      );
    }
  }

  // ─── Private Helpers ──────────────────────────────────────────────────────

  /**
   * Convert Long / BigInt / number to a JS number.
   *
   * Accepts objects exposing `toNumber()` as well as plain
   * `{ low, high, unsigned }` objects. `null`/`undefined` and values that do
   * not convert to a number yield 0.
   * @param {*} value
   * @returns {number}
   */
  _toLong(value) {
    if (value === null || value === undefined) return 0;
    if (typeof value === 'number') return value;
    if (typeof value === 'bigint') return Number(value);
    if (typeof value === 'object') {
      if (typeof value.toNumber === 'function') return value.toNumber();
      if (typeof value.low === 'number') {
        // `low` holds the lower 32 bits as a *signed* int32; it must be read
        // as unsigned, otherwise values with bit 31 set come out wrong
        // (e.g. 2^31 would become -2^31).
        const low = value.low >>> 0;
        const high = value.unsigned ? (value.high >>> 0) : (value.high | 0);
        return high * 0x100000000 + low;
      }
    }
    return Number(value) || 0;
  }

  /**
   * Decode a bytes field to text. Non-Uint8Array values are returned as-is,
   * with falsy values normalised to ''.
   * @param {*} value
   * @returns {*}
   */
  _decodeBytes(value) {
    if (value instanceof Uint8Array) {
      try { return new TextDecoder().decode(value); } catch { return ''; }
    }
    return value || '';
  }

  /**
   * Define `key` as an ordinary own property of `target`.
   *
   * Keys come from the model file; a plain `target[key] = value` with the key
   * `__proto__` would replace the object's prototype (or be silently dropped)
   * instead of storing the entry.
   * @param {Object} target
   * @param {string} key
   * @param {*} value
   */
  _setOwn(target, key, value) {
    Object.defineProperty(target, key, {
      value,
      writable: true,
      enumerable: true,
      configurable: true
    });
  }

  /**
   * Map a numeric tensor data type code to its display name using the global
   * `CONFIG.DATA_TYPES` table.
   * @param {*} code Anything `_toLong` accepts.
   * @returns {string} The configured name, or `UNKNOWN(<code>)` when the table
   *   is unavailable or has no entry for the code.
   */
  _getDataType(code) {
    const n = this._toLong(code);
    // `typeof` guard: referencing an undeclared global would throw.
    const table = (typeof CONFIG !== 'undefined' && CONFIG && CONFIG.DATA_TYPES) || null;
    return (table && table[n]) || `UNKNOWN(${n})`;
  }

  /**
   * Read the tensor shape out of a TypeProto.
   * @param {Object} typeProto
   * @returns {Array<string|number>} One entry per dimension: the symbolic
   *   `dimParam` when present, otherwise the numeric `dimValue` (0 when
   *   neither is set). Empty when there is no tensor type/shape or on error.
   */
  _extractShape(typeProto) {
    try {
      const tt = typeProto && typeProto.tensorType;
      if (!tt || !tt.shape) return [];
      return (tt.shape.dim || []).map(d => {
        if (d.dimParam) return d.dimParam;
        return this._toLong(d.dimValue);
      });
    } catch { return []; }
  }

  /**
   * Read the element data type name out of a TypeProto.
   * @param {Object} typeProto
   * @returns {string} Data type name, or 'UNKNOWN' when there is no tensor
   *   type or on error.
   */
  _extractDataType(typeProto) {
    try {
      const tt = typeProto && typeProto.tensorType;
      if (!tt) return 'UNKNOWN';
      return this._getDataType(tt.elemType);
    } catch { return 'UNKNOWN'; }
  }

  /**
   * Collect model-level metadata.
   * @param {Object} model Decoded ModelProto.
   * @param {Object} [options] Same options object as `parseModel`.
   * @returns {Object} producerName, producerVersion, opsetVersion,
   *   opsetImport, irVersion, customAttributes, fileName, fileSize, loadedAt.
   */
  _extractMetadata(model, options = {}) {
    const opts = options || {};
    let opsetVersion = 0;
    const opsetImport = [];
    if (model.opsetImport && model.opsetImport.length > 0) {
      // The default ONNX operator set is identified by an empty domain or by
      // its alias 'ai.onnx'; fall back to the first entry if neither occurs.
      const isDefaultDomain = op => !op.domain || op.domain === 'ai.onnx';
      const def = model.opsetImport.find(isDefaultDomain) || model.opsetImport[0];
      opsetVersion = this._toLong(def.version);
      for (const op of model.opsetImport) {
        opsetImport.push({
          domain: op.domain || '',
          version: this._toLong(op.version)
        });
      }
    }
    const customAttributes = {};
    if (Array.isArray(model.metadataProps)) {
      for (const p of model.metadataProps) {
        if (p.key) this._setOwn(customAttributes, p.key, p.value || '');
      }
    }
    return {
      producerName: model.producerName || '',
      producerVersion: model.producerVersion || '',
      opsetVersion,
      opsetImport,
      irVersion: this._toLong(model.irVersion),
      customAttributes,
      fileName: opts.fileName || '',
      fileSize: opts.fileSize || 0,
      loadedAt: opts.loadedAt || Date.now()
    };
  }

  /**
   * List the graph inputs, skipping unnamed entries and entries that share a
   * name with an initializer.
   * @param {Object} model Decoded ModelProto.
   * @returns {Array<Object>} `{ name, shape, dataType, description, isOptional }`.
   */
  _extractInputs(model) {
    const graph = model.graph;
    if (!graph || !graph.input) return [];
    const initNames = new Set((graph.initializer || []).map(i => i.name));
    return graph.input
      .filter(i => i.name && !initNames.has(i.name))
      .map(i => ({
        name: i.name,
        shape: this._extractShape(i.type),
        dataType: this._extractDataType(i.type),
        description: i.docString || undefined,
        isOptional: false
      }));
  }

  /**
   * List the named graph outputs.
   * @param {Object} model Decoded ModelProto.
   * @returns {Array<Object>} `{ name, shape, dataType, description, isOptional }`.
   */
  _extractOutputs(model) {
    const graph = model.graph;
    if (!graph || !graph.output) return [];
    return graph.output
      .filter(o => o.name)
      .map(o => ({
        name: o.name,
        shape: this._extractShape(o.type),
        dataType: this._extractDataType(o.type),
        description: o.docString || undefined,
        isOptional: false
      }));
  }

  /**
   * Summarise the graph initializers (tensor contents are not read).
   * @param {Object} model Decoded ModelProto.
   * @returns {Array<Object>} `{ name, shape, dataType, size, elementCount }`
   *   where `size` is `elementCount` times the per-element byte width.
   */
  _extractInitializers(model) {
    const graph = model.graph;
    if (!graph || !graph.initializer) return [];
    return graph.initializer.map(init => {
      const shape = (init.dims || []).map(d => this._toLong(d));
      const typeCode = this._toLong(init.dataType);
      // Product of the dims: empty dims is a scalar (1 element) and any
      // zero-length dimension makes the tensor empty (0 elements).
      const elementCount = shape.reduce((count, dim) => count * dim, 1);
      return {
        name: init.name || '',
        shape,
        dataType: this._getDataType(typeCode),
        size: elementCount * this._getBytesPerElement(typeCode),
        elementCount
      };
    });
  }

  /**
   * Byte width of one element for a numeric data type code.
   * @param {number} code
   * @returns {number} The table entry (0 for code 8), or 4 for codes not in
   *   the table.
   */
  _getBytesPerElement(code) {
    const widths = { 1: 4, 2: 1, 3: 1, 4: 2, 5: 2, 6: 4, 7: 8, 8: 0, 9: 1, 10: 2, 11: 8, 12: 4, 13: 8, 14: 8, 15: 16, 16: 2 };
    const width = widths[code];
    // Explicit undefined check so the 0-width entry is not replaced by the default.
    return width === undefined ? 4 : width;
  }

  /**
   * Convert an AttributeProto into a plain JS value according to its `type`.
   * @param {Object} attr
   * @returns {*} number, string, array, the placeholders '[Tensor]' /
   *   '[Graph]', or null for types that are not handled.
   */
  _extractAttributeValue(attr) {
    const t = this._toLong(attr.type);
    switch (t) {
      case 1: return attr.f !== undefined ? attr.f : null;
      case 2: return attr.i !== undefined ? this._toLong(attr.i) : null;
      case 3: return this._decodeBytes(attr.s);
      case 4: return '[Tensor]';
      case 5: return '[Graph]';
      case 6: return Array.isArray(attr.floats) ? attr.floats : [];
      case 7: return Array.isArray(attr.ints) ? attr.ints.map(i => this._toLong(i)) : [];
      case 8: return Array.isArray(attr.strings) ? attr.strings.map(s => this._decodeBytes(s)) : [];
      default: return null;
    }
  }

  /**
   * Extract value_info from GraphProto into a Map-like object.
   * Each entry maps tensor name → { name, shape, dataType }.
   * Requirements: 21.1, 21.2
   * @param {Object} graph Decoded GraphProto.
   * @returns {Object}
   */
  _extractValueInfo(graph) {
    const valueInfo = {};
    if (!graph || !Array.isArray(graph.valueInfo)) return valueInfo;
    for (const vi of graph.valueInfo) {
      if (!vi.name) continue;
      this._setOwn(valueInfo, vi.name, {
        name: vi.name,
        shape: this._extractShape(vi.type),
        dataType: this._extractDataType(vi.type)
      });
    }
    return valueInfo;
  }

  /**
   * Build the node/edge view of the graph.
   *
   * Nodes get ids `node_<index>`. An edge is created for every node input
   * that is produced as an output of a different node; duplicate
   * (source, target, tensor) triples are emitted once.
   * @param {Object} model Decoded ModelProto.
   * @param {Array<Object>} initializers Passed through onto the result.
   * @returns {Object} `{ name, nodes, edges, initializers, valueInfo }`.
   */
  _extractGraph(model, initializers) {
    const graph = model.graph;
    if (!graph) return { name: '', nodes: [], edges: [], initializers, valueInfo: {} };

    const outputToNodeId = new Map();
    const nodes = (graph.node || []).map((node, idx) => {
      const nodeId = `node_${idx}`;
      const nodeName = node.name || `${node.opType || 'Unknown'}_${idx}`;
      (node.output || []).forEach(o => { if (o) outputToNodeId.set(o, nodeId); });
      const attributes = {};
      if (Array.isArray(node.attribute)) {
        for (const a of node.attribute) {
          if (a.name) this._setOwn(attributes, a.name, this._extractAttributeValue(a));
        }
      }
      return {
        id: nodeId,
        name: nodeName,
        opType: node.opType || '',
        attributes,
        inputs: (node.input || []).filter(Boolean),
        outputs: (node.output || []).filter(Boolean),
        domain: node.domain || ''
      };
    });

    const edges = [];
    const edgeSet = new Set();
    for (const node of nodes) {
      for (const inp of node.inputs) {
        const src = outputToNodeId.get(inp);
        if (src && src !== node.id) {
          // Node ids never contain '|', so the separators keep the key
          // unambiguous; plain concatenation would let
          // node_1 + node_2 + "3x" collide with node_1 + node_23 + "x".
          const key = `${src}|${node.id}|${inp}`;
          if (!edgeSet.has(key)) {
            edgeSet.add(key);
            edges.push({ source: src, target: node.id, label: inp });
          }
        }
      }
    }

    const valueInfo = this._extractValueInfo(graph);
    return { name: graph.name || '', nodes, edges, initializers, valueInfo };
  }
}
// Expose the parser as a browser global. The typeof guard keeps the script
// loadable where `window` is not defined (a bare reference would throw a
// ReferenceError there); in a browser page the assignment is unchanged.
if (typeof window !== 'undefined') {
  window.ONNXParser = ONNXParser;
}