Spaces:
Running
Running
/**
 * ONNXParser - ONNX model parser.
 * Parses ONNX model files using protobuf.js (window.protobuf) with an inline
 * ONNX proto schema. No onnx-js runtime required.
 * Requirements: 9.1, 9.2, 9.3, 9.4, 9.5
 */
class ONNXParser {
  constructor() {
    /** @type {Object|null} Cached protobuf root */
    this._root = null;
  }

  // ─── Proto Schema ─────────────────────────────────────────────────────────

  /**
   * Minimal ONNX proto schema (subset needed for model exploration).
   * Based on https://github.com/onnx/onnx/blob/main/onnx/onnx.proto3
   * @returns {string} proto3 source text handed to `protobuf.parse`
   */
  _getProtoSchema() {
    return `
      syntax = "proto3";
      package onnx;
      message ModelProto {
        int64 ir_version = 1;
        repeated OperatorSetIdProto opset_import = 8;
        string producer_name = 2;
        string producer_version = 3;
        string domain = 4;
        int64 model_version = 5;
        string doc_string = 6;
        GraphProto graph = 7;
        repeated StringStringEntryProto metadata_props = 14;
      }
      message GraphProto {
        repeated NodeProto node = 1;
        string name = 2;
        repeated TensorProto initializer = 5;
        string doc_string = 10;
        repeated ValueInfoProto input = 11;
        repeated ValueInfoProto output = 12;
        repeated ValueInfoProto value_info = 13;
      }
      message NodeProto {
        repeated string input = 1;
        repeated string output = 2;
        string name = 3;
        string op_type = 4;
        string domain = 7;
        repeated AttributeProto attribute = 5;
        string doc_string = 6;
      }
      message AttributeProto {
        string name = 1;
        string ref_attr_name = 21;
        string doc_string = 13;
        AttributeType type = 20;
        float f = 2;
        int64 i = 3;
        bytes s = 4;
        TensorProto t = 5;
        GraphProto g = 6;
        repeated float floats = 7;
        repeated int64 ints = 8;
        repeated bytes strings = 9;
        repeated TensorProto tensors = 10;
        repeated GraphProto graphs = 11;
        enum AttributeType {
          UNDEFINED = 0;
          FLOAT = 1;
          INT = 2;
          STRING = 3;
          TENSOR = 4;
          GRAPH = 5;
          FLOATS = 6;
          INTS = 7;
          STRINGS = 8;
          TENSORS = 9;
          GRAPHS = 10;
        }
      }
      message ValueInfoProto {
        string name = 1;
        TypeProto type = 2;
        string doc_string = 3;
      }
      message TypeProto {
        message Tensor {
          int32 elem_type = 1;
          TensorShapeProto shape = 2;
        }
        oneof value {
          Tensor tensor_type = 1;
        }
      }
      message TensorShapeProto {
        message Dimension {
          oneof value {
            int64 dim_value = 1;
            string dim_param = 2;
          }
          string denotation = 3;
        }
        repeated Dimension dim = 1;
      }
      message TensorProto {
        repeated int64 dims = 1 [packed = true];
        int32 data_type = 2;
        string name = 8;
        string doc_string = 12;
        bytes raw_data = 9;
        repeated float float_data = 4 [packed = true];
        repeated int32 int32_data = 5 [packed = true];
        repeated int64 int64_data = 7 [packed = true];
        repeated double double_data = 10 [packed = true];
        repeated uint64 uint64_data = 11 [packed = true];
      }
      message OperatorSetIdProto {
        string domain = 1;
        int64 version = 2;
      }
      message StringStringEntryProto {
        string key = 1;
        string value = 2;
      }
    `;
  }

  // ─── Protobuf Root ────────────────────────────────────────────────────────

  /**
   * Build (and cache on `this._root`) the protobuf root from the inline schema.
   * The schema is parsed with `keepCase: false`, which is why the rest of this
   * class reads camelCase field names (`opType`, `tensorType`, ...).
   * @returns {Object} protobuf Root
   * @throws {Error} if the global `protobuf` (protobuf.js) is not defined
   */
  _getRoot() {
    if (this._root) return this._root;
    if (typeof protobuf === 'undefined') {
      throw new Error(
        'protobuf.js is not available. ' +
        'Please ensure the protobuf.js CDN script is loaded before using ONNXParser.'
      );
    }
    this._root = protobuf.parse(this._getProtoSchema(), { keepCase: false }).root;
    return this._root;
  }

  // ─── Public API ───────────────────────────────────────────────────────────

  /**
   * Parse an ONNX model from an ArrayBuffer.
   *
   * Errors whose message mentions "protobuf.js", "Invalid model" or "empty"
   * are rethrown unchanged; any other failure is wrapped in an Error whose
   * message starts with "Failed to parse ONNX model:" and whose `cause`
   * property holds the original error.
   *
   * @param {ArrayBuffer} modelBuffer
   * @param {Object} [options]
   * @param {string} [options.fileName]
   * @param {number} [options.fileSize]
   * @param {number} [options.loadedAt]
   * @returns {Promise<ParsedModel>} `{ metadata, graph, inputs, outputs, initializers }`
   * @throws {Error} if the buffer is not a non-empty ArrayBuffer, protobuf.js
   *   is missing, or decoding/extraction fails
   */
  async parseModel(modelBuffer, options = {}) {
    if (!modelBuffer || !(modelBuffer instanceof ArrayBuffer)) {
      throw new Error('Invalid model buffer: expected an ArrayBuffer.');
    }
    if (modelBuffer.byteLength === 0) {
      throw new Error('Model buffer is empty. The file may be corrupted or invalid.');
    }

    try {
      const ModelProto = this._getRoot().lookupType('onnx.ModelProto');
      const model = ModelProto.decode(new Uint8Array(modelBuffer));

      const metadata = this._extractMetadata(model, options);
      const inputs = this._extractInputs(model);
      const outputs = this._extractOutputs(model);
      const initializers = this._extractInitializers(model);
      const graph = this._extractGraph(model, initializers);

      return { metadata, graph, inputs, outputs, initializers };
    } catch (err) {
      // `err` may be any thrown value, so read the message defensively.
      const message = err && err.message;
      const passThroughMarkers = ['protobuf.js', 'Invalid model', 'empty'];
      if (message && passThroughMarkers.some(marker => message.includes(marker))) {
        throw err;
      }
      const wrapped = new Error(
        `Failed to parse ONNX model: ${message || 'The file may be corrupted or not a valid ONNX file.'}`
      );
      // Keep the original error reachable for debugging (stack, decoder details).
      wrapped.cause = err;
      throw wrapped;
    }
  }

  // ─── Private Helpers ──────────────────────────────────────────────────────

  /**
   * Convert a Long / BigInt / number (or anything `Number()` accepts) to a JS number.
   *
   * Objects exposing `toNumber()` delegate to it. Plain `{ low, high }` pairs
   * are combined with `low` reinterpreted as an unsigned 32-bit word (it is
   * stored as a signed int32); `high` is reinterpreted as unsigned too when
   * the object's `unsigned` flag is truthy. A missing `high` counts as 0.
   *
   * @param {*} value
   * @returns {number} 0 for null/undefined or values that coerce to NaN/0
   */
  _toLong(value) {
    if (value === null || value === undefined) return 0;
    if (typeof value === 'number') return value;
    if (typeof value === 'bigint') return Number(value);
    if (typeof value === 'object') {
      if (typeof value.toNumber === 'function') return value.toNumber();
      if (typeof value.low === 'number') {
        const rawHigh = typeof value.high === 'number' ? value.high : 0;
        const high = value.unsigned ? rawHigh >>> 0 : rawHigh;
        // `>>> 0` turns the signed low word into its unsigned value so that
        // bit 31 contributes +2^31 instead of -2^31.
        return (value.low >>> 0) + high * 0x100000000;
      }
    }
    return Number(value) || 0;
  }

  /**
   * Decode an attribute byte string to text.
   * @param {*} bytes Uint8Array to decode; other values are returned as-is
   *   (or '' when falsy)
   * @returns {*} decoded string, '' if decoding throws, or the original value
   */
  _decodeBytes(bytes) {
    if (bytes instanceof Uint8Array) {
      try { return new TextDecoder().decode(bytes); } catch { return ''; }
    }
    return bytes || '';
  }

  /**
   * Map a tensor data-type code to its display name via the global
   * `CONFIG.DATA_TYPES` lookup, when that global exists.
   * @param {*} code numeric code (anything `_toLong` accepts)
   * @returns {string} the configured name, or `UNKNOWN(<code>)` when CONFIG,
   *   its DATA_TYPES table, or the entry is missing
   */
  _getDataType(code) {
    const n = this._toLong(code);
    // `typeof` guard: a bare `CONFIG` reference throws if the global is undeclared.
    const table = (typeof CONFIG !== 'undefined' && CONFIG) ? CONFIG.DATA_TYPES : undefined;
    return (table && table[n]) || `UNKNOWN(${n})`;
  }

  /**
   * Read the tensor shape out of a decoded TypeProto.
   * @param {Object|null|undefined} typeProto
   * @returns {Array<string|number>} one entry per dimension: the `dimParam`
   *   string when non-empty, otherwise `dimValue` as a number (0 when unset).
   *   Empty array if there is no tensor type/shape or anything throws.
   */
  _extractShape(typeProto) {
    try {
      const tensorType = typeProto && typeProto.tensorType;
      if (!tensorType || !tensorType.shape) return [];
      return (tensorType.shape.dim || []).map(dim => (
        dim.dimParam ? dim.dimParam : this._toLong(dim.dimValue)
      ));
    } catch { return []; }
  }

  /**
   * Read the element data-type name out of a decoded TypeProto.
   * @param {Object|null|undefined} typeProto
   * @returns {string} name from `_getDataType`, or 'UNKNOWN' when there is no
   *   tensor type or the lookup throws
   */
  _extractDataType(typeProto) {
    try {
      const tensorType = typeProto && typeProto.tensorType;
      if (!tensorType) return 'UNKNOWN';
      return this._getDataType(tensorType.elemType);
    } catch { return 'UNKNOWN'; }
  }

  /**
   * Collect model-level metadata.
   * @param {Object} model decoded ModelProto
   * @param {Object} options same object passed to `parseModel`
   * @returns {Object} producer info, opset info, IR version, metadata_props as
   *   `customAttributes`, and the file details from `options`
   */
  _extractMetadata(model, options) {
    let opsetVersion = 0;
    const opsetImport = [];
    if (model.opsetImport && model.opsetImport.length > 0) {
      // Prefer the entry with an empty domain; fall back to the first entry.
      const primary = model.opsetImport.find(op => !op.domain) || model.opsetImport[0];
      opsetVersion = this._toLong(primary.version);
      for (const op of model.opsetImport) {
        opsetImport.push({
          domain: op.domain || '',
          version: this._toLong(op.version)
        });
      }
    }

    const customAttributes = {};
    if (Array.isArray(model.metadataProps)) {
      for (const prop of model.metadataProps) {
        if (prop.key) customAttributes[prop.key] = prop.value || '';
      }
    }

    return {
      producerName: model.producerName || '',
      producerVersion: model.producerVersion || '',
      opsetVersion,
      opsetImport,
      irVersion: this._toLong(model.irVersion),
      customAttributes,
      fileName: options.fileName || '',
      fileSize: options.fileSize || 0,
      loadedAt: options.loadedAt || Date.now()
    };
  }

  /**
   * List the graph inputs, skipping unnamed entries and any entry whose name
   * matches an initializer.
   * @param {Object} model decoded ModelProto
   * @returns {Array<Object>} `{ name, shape, dataType, description, isOptional }`
   */
  _extractInputs(model) {
    const graph = model.graph;
    if (!graph || !graph.input) return [];
    const initializerNames = new Set((graph.initializer || []).map(init => init.name));
    return graph.input
      .filter(input => input.name && !initializerNames.has(input.name))
      .map(input => ({
        name: input.name,
        shape: this._extractShape(input.type),
        dataType: this._extractDataType(input.type),
        description: input.docString || undefined,
        isOptional: false
      }));
  }

  /**
   * List the named graph outputs.
   * @param {Object} model decoded ModelProto
   * @returns {Array<Object>} `{ name, shape, dataType, description, isOptional }`
   */
  _extractOutputs(model) {
    const graph = model.graph;
    if (!graph || !graph.output) return [];
    return graph.output
      .filter(output => output.name)
      .map(output => ({
        name: output.name,
        shape: this._extractShape(output.type),
        dataType: this._extractDataType(output.type),
        description: output.docString || undefined,
        isOptional: false
      }));
  }

  /**
   * Summarise each initializer tensor (no tensor data is read).
   *
   * `elementCount` is the product of the dims: an empty dims list (a scalar)
   * counts as 1 element and any 0-length dimension yields 0 elements.
   * `size` is `elementCount` times `_getBytesPerElement` for the data type.
   *
   * @param {Object} model decoded ModelProto
   * @returns {Array<Object>} `{ name, shape, dataType, size, elementCount }`
   */
  _extractInitializers(model) {
    const graph = model.graph;
    if (!graph || !graph.initializer) return [];
    return graph.initializer.map(init => {
      const shape = (init.dims || []).map(dim => this._toLong(dim));
      const typeCode = this._toLong(init.dataType);
      const elementCount = shape.reduce((product, dim) => product * dim, 1);
      return {
        name: init.name || '',
        shape,
        dataType: this._getDataType(typeCode),
        size: elementCount * this._getBytesPerElement(typeCode),
        elementCount
      };
    });
  }

  /**
   * Bytes per element for a tensor data-type code.
   * Codes presumably follow onnx.TensorProto.DataType (8 = STRING, which has
   * no fixed width and therefore maps to 0).
   * @param {number} code
   * @returns {number} table value for known codes (including 0), else 4
   */
  _getBytesPerElement(code) {
    const widths = { 1:4, 2:1, 3:1, 4:2, 5:2, 6:4, 7:8, 8:0, 9:1, 10:2, 11:8, 12:4, 13:8, 14:8, 15:16, 16:2 };
    // Own-property test rather than `||` so the legitimate 0 entry is honoured.
    return Object.prototype.hasOwnProperty.call(widths, code) ? widths[code] : 4;
  }

  /**
   * Convert a decoded AttributeProto into a plain JS value, based on its
   * `type` code (see the AttributeType enum in the schema).
   * @param {Object} attr decoded AttributeProto
   * @returns {*} number, string, array, the placeholders '[Tensor]' / '[Graph]',
   *   or null for UNDEFINED, TENSORS, GRAPHS and unrecognised types
   */
  _extractAttributeValue(attr) {
    switch (this._toLong(attr.type)) {
      case 1: // FLOAT
        return attr.f !== undefined ? attr.f : null;
      case 2: // INT
        return attr.i !== undefined ? this._toLong(attr.i) : null;
      case 3: // STRING
        return this._decodeBytes(attr.s);
      case 4: // TENSOR
        return '[Tensor]';
      case 5: // GRAPH
        return '[Graph]';
      case 6: // FLOATS
        return Array.isArray(attr.floats) ? attr.floats : [];
      case 7: // INTS
        return Array.isArray(attr.ints) ? attr.ints.map(i => this._toLong(i)) : [];
      case 8: // STRINGS
        return Array.isArray(attr.strings) ? attr.strings.map(s => this._decodeBytes(s)) : [];
      default:
        return null;
    }
  }

  /**
   * Extract value_info from GraphProto into a Map-like object.
   * Each entry maps tensor name → { name, shape, dataType }.
   * Requirements: 21.1, 21.2
   * @param {Object|null|undefined} graph decoded GraphProto
   * @returns {Object} empty object when `graph.valueInfo` is not an array
   */
  _extractValueInfo(graph) {
    const valueInfo = {};
    if (!graph || !Array.isArray(graph.valueInfo)) return valueInfo;
    for (const info of graph.valueInfo) {
      if (!info.name) continue;
      valueInfo[info.name] = {
        name: info.name,
        shape: this._extractShape(info.type),
        dataType: this._extractDataType(info.type)
      };
    }
    return valueInfo;
  }

  /**
   * Build the node/edge view of the graph.
   *
   * Nodes get ids `node_<index>`; unnamed nodes are named `<opType>_<index>`.
   * An edge is added for each node input that is produced by a different
   * node, de-duplicated on (source, target, tensor name).
   *
   * @param {Object} model decoded ModelProto
   * @param {Array<Object>} initializers result of `_extractInitializers`,
   *   passed through onto the returned graph
   * @returns {Object} `{ name, nodes, edges, initializers, valueInfo }`
   */
  _extractGraph(model, initializers) {
    const graph = model.graph;
    if (!graph) return { name: '', nodes: [], edges: [], initializers, valueInfo: {} };

    // tensor name → id of the node that outputs it (last writer wins)
    const producerByTensor = new Map();
    const nodes = (graph.node || []).map((node, idx) => {
      const nodeId = `node_${idx}`;
      const outputs = (node.output || []).filter(Boolean);
      for (const tensorName of outputs) producerByTensor.set(tensorName, nodeId);

      const attributes = {};
      if (Array.isArray(node.attribute)) {
        for (const attr of node.attribute) {
          if (attr.name) attributes[attr.name] = this._extractAttributeValue(attr);
        }
      }

      return {
        id: nodeId,
        name: node.name || `${node.opType || 'Unknown'}_${idx}`,
        opType: node.opType || '',
        attributes,
        inputs: (node.input || []).filter(Boolean),
        outputs,
        domain: node.domain || ''
      };
    });

    const edges = [];
    const seenEdges = new Set();
    for (const node of nodes) {
      for (const tensorName of node.inputs) {
        const source = producerByTensor.get(tensorName);
        // Inputs with no producing node (graph inputs, initializers) and
        // self-references produce no edge.
        if (!source || source === node.id) continue;
        const key = `${source}→${node.id}→${tensorName}`;
        if (seenEdges.has(key)) continue;
        seenEdges.add(key);
        edges.push({ source, target: node.id, label: tensorName });
      }
    }

    return {
      name: graph.name || '',
      nodes,
      edges,
      initializers,
      valueInfo: this._extractValueInfo(graph)
    };
  }
}

// Expose the class as a browser global; skipped where `window` does not exist
// so that loading this file outside a window context does not throw.
if (typeof window !== 'undefined') {
  window.ONNXParser = ONNXParser;
}