/**
 * FlopsEstimator - Estimates FLOPs (Floating Point Operations) for ONNX models.
 * Computes per-node and total FLOPs based on opType and input tensor shapes.
 * Requirements: 34.1, 34.2, 34.3, 34.4, 34.5, 34.6
 */

/**
 * Shape of the object produced by {@link FlopsEstimator#compute}.
 * @typedef {Object} FlopsResult
 * @property {number} totalFlops - Sum of FLOPs over all supported nodes.
 * @property {string} formattedTotal - `totalFlops` passed through `formatFlops`.
 * @property {Array<{nodeId: string, nodeName: string, opType: string, flops: number, formattedFlops: string}>} perNode
 *   One entry per graph node; unsupported nodes carry `flops: -1` and `formattedFlops: 'N/A'`.
 * @property {Array<{opType: string, totalFlops: number, formattedFlops: string, count: number, percentage: number}>} perOpType
 *   Aggregation over supported opTypes, sorted descending by `totalFlops`.
 * @property {string[]} unsupportedOps - Distinct opTypes that have no formula.
 */
class FlopsEstimator {
  /**
   * @param {string} containerId - ID of the container element for rendering
   */
  constructor(containerId) {
    this._containerId = containerId;
    // Guarded so the estimator can still be used for pure computation where
    // no DOM is available; render()/clear() are no-ops without a container.
    this._container =
      typeof document !== 'undefined' ? document.getElementById(containerId) : null;
    this._result = null;

    if (!this._container) {
      console.warn(`[FlopsEstimator] Container #${containerId} not found`);
    }
  }

  // ─── FLOPs Formulas ───────────────────────────────────────────────────────

  /**
   * Supported opType formulas.
   * Each function receives (attrs, inputShapes) and returns a FLOPs count.
   * A fresh shallow copy is returned on every access, so mutating the returned
   * object never affects the estimator itself.
   * @returns {Object<string, function(Object, Array<Array>): number>}
   */
  static get FORMULAS() {
    return Object.assign({}, FlopsEstimator._formulaTable());
  }

  /**
   * Build the formula table once and cache it on the class.
   * @returns {Object<string, function(Object, Array<Array>): number>} frozen table
   */
  static _formulaTable() {
    if (!FlopsEstimator._formulas) {
      const countElements = FlopsEstimator._countElements;
      const elementWise = (attrs, inputShapes) => countElements(inputShapes[0]);

      FlopsEstimator._formulas = Object.freeze({
        // Conv: 2 * Cout * Hout * Wout * CinPerGroup * Kh * Kw
        // Approximation: the output spatial size is taken from the input
        // (i.e. "same" padding, stride 1) and the batch dimension is not counted.
        Conv(attrs, inputShapes) {
          const input = inputShapes[0] || [];
          const weight = inputShapes[1] || [];
          const cOut = weight[0] || 0;
          const cInPerGroup = weight[1] || 0;
          const kH = weight[2] || 0;
          const kW = weight[3] || 0;
          const hOut = input[2] || 0;
          const wOut = input[3] || 0;
          return 2 * cOut * hOut * wOut * cInPerGroup * kH * kW;
        },

        // MatMul: 2 * M * N * K, using the two trailing dims of A and the last dim of B.
        MatMul(attrs, inputShapes) {
          const a = inputShapes[0] || [];
          const b = inputShapes[1] || [];
          const m = a[a.length - 2] || 0;
          const k = a[a.length - 1] || 0;
          const n = b[b.length - 1] || 0;
          return 2 * m * n * k;
        },

        // Gemm: 2 * M * N * K + M * N (bias)
        // A is (M, K) or (K, M) when transA is set; B is (K, N) or (N, K) when
        // transB is set, so the axes are picked according to those attributes.
        Gemm(attrs, inputShapes) {
          const a = inputShapes[0] || [];
          const b = inputShapes[1] || [];
          const transA = FlopsEstimator._isAttrSet(attrs, 'transA');
          const transB = FlopsEstimator._isAttrSet(attrs, 'transB');
          const m = (transA ? a[1] : a[0]) || 0;
          const k = (transA ? a[0] : a[1]) || 0;
          const n = (transB ? b[0] : b[1]) || 0;
          return 2 * m * n * k + m * n;
        },

        // BatchNormalization: 4 * elements (mean, var, normalize, scale+shift)
        BatchNormalization(attrs, inputShapes) {
          return 4 * countElements(inputShapes[0]);
        },

        // Relu: 1 comparison per element
        Relu: elementWise,

        // Add / Mul: one operation per element of the first input
        Add: elementWise,
        Mul: elementWise
      });
    }
    return FlopsEstimator._formulas;
  }

  /**
   * Number of elements described by a shape.
   * Returns 0 for a missing/empty shape; any non-numeric dimension (for
   * example a symbolic dimension name) is treated as 0, which zeroes the product.
   * @param {Array} shape
   * @returns {number}
   */
  static _countElements(shape) {
    if (!shape || !shape.length) return 0;
    let total = 1;
    for (let axis = 0; axis < shape.length; axis++) {
      const dim = shape[axis];
      total *= typeof dim === 'number' ? dim : 0;
    }
    return total;
  }

  /**
   * Read an integer flag attribute (such as Gemm's transA/transB) as a boolean.
   * Accepts a raw number/boolean/bigint/numeric string, or a wrapper object
   * exposing the value as `.value` or `.i`.
   * NOTE(review): the attribute encoding produced by the model parser is not
   * visible here; unrecognised encodings are treated as "not set".
   * @param {Object} attrs
   * @param {string} name
   * @returns {boolean}
   */
  static _isAttrSet(attrs, name) {
    let raw = attrs == null ? undefined : attrs[name];
    if (raw !== null && typeof raw === 'object') {
      raw = raw.value !== undefined ? raw.value : raw.i;
    }
    const kind = typeof raw;
    if (kind !== 'number' && kind !== 'boolean' && kind !== 'bigint' && kind !== 'string') {
      return false;
    }
    const numeric = Number(raw);
    return Number.isFinite(numeric) && numeric !== 0;
  }

  // ─── Shape Lookup ─────────────────────────────────────────────────────────

  /**
   * Build a map of tensor name → shape from all available sources.
   * Sources, in increasing precedence: parsedModel.inputs, parsedModel.outputs,
   * parsedModel.graph.valueInfo, parsedModel.graph.initializers.
   * Entries without both a name and a shape are skipped.
   * @param {Object} parsedModel
   * @returns {Object} tensorName → shape array
   */
  _buildShapeMap(parsedModel) {
    const shapeMap = {};
    const graph = (parsedModel && parsedModel.graph) || {};

    // Works for both arrays and keyed objects: Object.keys yields indices or keys.
    const collect = (entries) => {
      if (!entries || typeof entries !== 'object') return;
      const keys = Object.keys(entries);
      for (let i = 0; i < keys.length; i++) {
        const entry = entries[keys[i]];
        if (entry && entry.name && entry.shape) {
          shapeMap[entry.name] = entry.shape;
        }
      }
    };

    collect(parsedModel && parsedModel.inputs); // Model inputs
    collect(parsedModel && parsedModel.outputs); // Model outputs
    collect(graph.valueInfo); // Intermediate value info
    collect(graph.initializers); // Initializers (weights)

    return shapeMap;
  }

  // ─── Public API ───────────────────────────────────────────────────────────

  /**
   * Estimate FLOPs for a single node given its opType, attributes, and resolved input shapes.
   * @param {{ opType: string, attributes: Object, inputShapes: Array<Array<number>> }} nodeInfo
   * @returns {number} Estimated FLOPs, or -1 if opType is unsupported.
   *   A formula result that is not a finite, non-negative number is reported as 0.
   */
  estimateNodeFlops(nodeInfo) {
    const opType = nodeInfo && nodeInfo.opType;
    const table = FlopsEstimator._formulaTable();

    // Own-property check: names such as "toString" must not resolve to
    // members inherited from Object.prototype.
    if (!Object.prototype.hasOwnProperty.call(table, opType)) return -1;

    const attrs = (nodeInfo && nodeInfo.attributes) || {};
    const inputShapes = (nodeInfo && nodeInfo.inputShapes) || [];
    const flops = table[opType](attrs, inputShapes);

    return typeof flops === 'number' && Number.isFinite(flops) && flops >= 0 ? flops : 0;
  }

  /**
   * Compute FLOPs for the entire model and remember the result for render()/getResult().
   * @param {Object} parsedModel - The parsed ONNX model
   * @returns {FlopsResult}
   */
  compute(parsedModel) {
    const nodes = (parsedModel && parsedModel.graph && parsedModel.graph.nodes) || [];
    const shapeMap = this._buildShapeMap(parsedModel);
    const shapeOf = (tensorName) =>
      Object.prototype.hasOwnProperty.call(shapeMap, tensorName) ? shapeMap[tensorName] || [] : [];

    let totalFlops = 0;
    const perNode = [];
    const byOpType = new Map(); // opType → { totalFlops, count }
    const unsupported = new Set(); // distinct unsupported opTypes, in first-seen order

    for (let i = 0; i < nodes.length; i++) {
      const node = nodes[i];
      const opType = node.opType || '';
      const nodeInputs = node.inputs || [];

      // Resolve input shapes for this node; unknown tensors get an empty shape.
      const inputShapes = [];
      for (let j = 0; j < nodeInputs.length; j++) {
        inputShapes.push(shapeOf(nodeInputs[j]));
      }

      const flops = this.estimateNodeFlops({
        opType: opType,
        attributes: node.attributes || {},
        inputShapes: inputShapes
      });
      const isSupported = flops >= 0;

      perNode.push({
        nodeId: node.id || '',
        nodeName: node.name || '',
        opType: opType,
        flops: isSupported ? flops : -1,
        formattedFlops: isSupported ? this.formatFlops(flops) : 'N/A'
      });

      if (!isSupported) {
        unsupported.add(opType);
        continue;
      }

      totalFlops += flops;
      const bucket = byOpType.get(opType) || { totalFlops: 0, count: 0 };
      bucket.totalFlops += flops;
      bucket.count += 1;
      byOpType.set(opType, bucket);
    }

    // Build perOpType array sorted descending by totalFlops
    const perOpType = [];
    byOpType.forEach((bucket, opType) => {
      perOpType.push({
        opType: opType,
        totalFlops: bucket.totalFlops,
        formattedFlops: this.formatFlops(bucket.totalFlops),
        count: bucket.count,
        percentage: totalFlops > 0 ? (bucket.totalFlops / totalFlops) * 100 : 0
      });
    });
    perOpType.sort((a, b) => b.totalFlops - a.totalFlops);

    this._result = {
      totalFlops: totalFlops,
      formattedTotal: this.formatFlops(totalFlops),
      perNode: perNode,
      perOpType: perOpType,
      unsupportedOps: Array.from(unsupported)
    };
    return this._result;
  }

  /**
   * Format a FLOPs number into a human-readable string.
   * @param {number} flops
   * @returns {string} e.g. "1.20 TFLOPs", "500.00 MFLOPs", "0 FLOPs";
   *   "N/A" for null/undefined, NaN or negative input.
   */
  formatFlops(flops) {
    if (flops == null || isNaN(flops) || flops < 0) return 'N/A';
    if (flops >= 1e12) return (flops / 1e12).toFixed(2) + ' TFLOPs';
    if (flops >= 1e9) return (flops / 1e9).toFixed(2) + ' GFLOPs';
    if (flops >= 1e6) return (flops / 1e6).toFixed(2) + ' MFLOPs';
    if (flops >= 1e3) return (flops / 1e3).toFixed(2) + ' KFLOPs';
    return flops + ' FLOPs';
  }

  /**
   * Render the FLOPs estimation result into the container.
   * Does nothing when the container element was not found.
   * NOTE(review): the element/class names below are a reconstruction — confirm
   * them against the stylesheet that styles this panel.
   * @param {FlopsResult} [result] - If omitted, uses the last computed result
   */
  render(result) {
    if (!this._container) return;

    const data = result || this._result;
    if (!data) {
      this._container.innerHTML =
        '<div class="flops-empty">No FLOPs estimation available.</div>';
      return;
    }

    const esc = (value) => this._escapeHtml(value);
    let html = '<div class="flops-estimator">';

    // Total FLOPs header
    html += '<div class="flops-total">';
    html += '<span class="flops-total-label">Total FLOPs: </span>';
    html += '<span class="flops-total-value">' + esc(data.formattedTotal) + '</span>';
    html += '</div>';

    // Breakdown by opType
    if (data.perOpType && data.perOpType.length > 0) {
      html += '<table class="flops-table">';
      html += '<thead><tr><th>Op Type</th><th>FLOPs</th><th>Count</th><th>%</th></tr></thead>';
      html += '<tbody>';
      for (let i = 0; i < data.perOpType.length; i++) {
        const entry = data.perOpType[i];
        // A caller-supplied result may lack a numeric percentage; show 0 rather than throw.
        const percentage = Number.isFinite(entry.percentage) ? entry.percentage : 0;
        html += '<tr>';
        html += '<td>' + esc(entry.opType) + '</td>';
        html += '<td>' + esc(entry.formattedFlops) + '</td>';
        html += '<td>' + esc(entry.count) + '</td>';
        html += '<td>' + percentage.toFixed(1) + '%</td>';
        html += '</tr>';
      }
      html += '</tbody></table>';
    }

    // Unsupported ops
    if (data.unsupportedOps && data.unsupportedOps.length > 0) {
      html += '<div class="flops-unsupported">';
      html += 'Unsupported ops (N/A): ' + data.unsupportedOps.map(esc).join(', ');
      html += '</div>';
    }

    html += '</div>';
    this._container.innerHTML = html;
  }

  /**
   * Get the last computed result.
   * @returns {FlopsResult|null}
   */
  getResult() {
    return this._result;
  }

  /**
   * Clear the display and reset internal state.
   */
  clear() {
    this._result = null;
    if (this._container) {
      this._container.innerHTML = '';
    }
  }

  /**
   * Destroy and clean up: clears the display and drops the container reference.
   */
  destroy() {
    this.clear();
    this._container = null;
  }

  // ─── Private Helpers ──────────────────────────────────────────────────────

  /**
   * Escape HTML special characters.
   * Non-string values are converted with String() and then escaped as well.
   * @param {*} str
   * @returns {string}
   */
  _escapeHtml(str) {
    const text = typeof str === 'string' ? str : String(str);
    return text
      .replace(/&/g, '&amp;') // must run first so later entities are not double-escaped
      .replace(/</g, '&lt;')
      .replace(/>/g, '&gt;')
      .replace(/"/g, '&quot;')
      .replace(/'/g, '&#39;');
  }
}

// Expose as a browser global; skipped where no `window` exists.
if (typeof window !== 'undefined') {
  window.FlopsEstimator = FlopsEstimator;
}