Spaces:
Running
Running
/**
 * FlopsEstimator - Estimates FLOPs (Floating Point Operations) for ONNX models.
 * Computes per-node and total FLOPs from each node's opType, attributes and
 * resolved input tensor shapes, aggregates them per opType, and renders a
 * summary table into a DOM container.
 *
 * Shape conventions used by every formula:
 * - A missing or empty shape means "unknown" and yields 0 FLOPs for the node.
 * - A dimension that is not a finite, non-negative number (e.g. a symbolic
 *   "batch" axis or -1) is counted as 1, so models with dynamic dimensions get
 *   a per-unit (typically per-sample) estimate instead of 0.
 *
 * Requirements: 34.1, 34.2, 34.3, 34.4, 34.5, 34.6
 */

/**
 * @typedef {Object} FlopsResult
 * @property {number} totalFlops - Sum of FLOPs over all supported nodes.
 * @property {string} formattedTotal - `totalFlops` formatted by formatFlops().
 * @property {Array<{nodeId: string, nodeName: string, opType: string, flops: number, formattedFlops: string}>} perNode
 *   One entry per graph node; `flops` is -1 and `formattedFlops` is 'N/A' for unsupported ops.
 * @property {Array<{opType: string, totalFlops: number, formattedFlops: string, count: number, percentage: number}>} perOpType
 *   Supported opTypes sorted by descending totalFlops; `percentage` is the share of totalFlops.
 * @property {Array<string>} unsupportedOps - Distinct opTypes without a formula, in first-seen order.
 */
class FlopsEstimator {
  /**
   * @param {string} containerId - ID of the container element for rendering
   */
  constructor(containerId) {
    this._containerId = containerId;
    // `document` does not exist outside the browser (e.g. headless tests); the
    // estimator still computes there, it just has nothing to render into.
    this._container = typeof document !== 'undefined'
      ? document.getElementById(containerId)
      : null;
    this._result = null;
    if (!this._container) {
      console.warn(`[FlopsEstimator] Container #${containerId} not found`);
    }
  }

  // --- FLOPs Formulas --------------------------------------------------------

  /**
   * Supported opType formulas.
   * Each function receives (attrs, inputShapes) and returns a FLOPs count.
   * A fresh object is built on every access, so mutating the returned object
   * does not affect the estimator.
   * @returns {Object<string, function(Object, Array<Array<number>>): number>}
   */
  static get FORMULAS() {
    const E = FlopsEstimator; // short alias for the static shape/attribute helpers
    return {
      // Conv: 2 * N * Cout * prod(outSpatial) * CinPerGroup * prod(kernel).
      // Any spatial rank (Conv1d/2d/3d). Output sizes honour strides, dilations,
      // explicit pads and auto_pad; when padding is not specified, "same"
      // padding is assumed (out = ceil(in / stride)).
      'Conv': function (attrs, inputShapes) {
        const x = inputShapes[0] || [];
        const w = inputShapes[1] || [];
        if (x.length < 3 || w.length < 3) return 0;
        const rank = w.length - 2;
        const strides = E._attr(attrs, 'strides', null);
        const dilations = E._attr(attrs, 'dilations', null);
        const pads = E._attr(attrs, 'pads', null);
        // NOTE(review): assumes auto_pad arrives as a string; other encodings fall
        // back to the "same" padding assumption below.
        const autoPad = String(E._attr(attrs, 'auto_pad', 'NOTSET')) || 'NOTSET';
        // ONNX pads layout: [x1_begin, x2_begin, ..., x1_end, x2_end, ...].
        const hasExplicitPads = autoPad === 'NOTSET' && pads != null && pads.length >= 2 * rank;
        let outSpatial = 1;
        let kernelSize = 1;
        for (let i = 0; i < rank; i++) {
          const inDim = E._dimOr(x[i + 2], 1);
          const k = E._dimOr(w[i + 2], 1);
          const stride = E._attrAt(strides, i, 1, 1);
          const dilation = E._attrAt(dilations, i, 1, 1);
          const span = dilation * (k - 1) + 1; // effective (dilated) kernel extent
          let out;
          if (autoPad === 'VALID') {
            out = Math.floor((inDim - span) / stride) + 1;
          } else if (hasExplicitPads) {
            const padTotal = E._attrAt(pads, i, 0, 0) + E._attrAt(pads, i + rank, 0, 0);
            out = Math.floor((inDim + padTotal - span) / stride) + 1;
          } else {
            out = Math.ceil(inDim / stride);
          }
          outSpatial *= Math.max(out, 0);
          kernelSize *= k;
        }
        const batch = E._dimOr(x[0], 1);
        const cout = E._dimOr(w[0], 1);
        const cinPerGroup = E._dimOr(w[1], 1);
        return 2 * batch * cout * outSpatial * cinPerGroup * kernelSize;
      },
      // MatMul: 2 * batch * M * N * K with numpy.matmul semantics: leading
      // (batch) dims broadcast, a 1-D A acts as [1, K] and a 1-D B as [K, 1].
      'MatMul': function (attrs, inputShapes) {
        const a = inputShapes[0] || [];
        const b = inputShapes[1] || [];
        if (a.length === 0 || b.length === 0) return 0;
        const M = a.length >= 2 ? E._dimOr(a[a.length - 2], 1) : 1;
        const K = E._dimOr(a[a.length - 1], 1);
        const N = b.length >= 2 ? E._dimOr(b[b.length - 1], 1) : 1;
        const batchShape = E._broadcastShape([a.slice(0, -2), b.slice(0, -2)]);
        const batch = batchShape.reduce((product, d) => product * d, 1);
        return 2 * batch * M * N * K;
      },
      // Gemm: Y = alpha * A' * B' + beta * C, where A' = transA ? A^T : A (M x K)
      // and B' = transB ? B^T : B (K x N). Costs 2 * M * N * K, plus M * N for
      // the bias addition when the optional C input is present.
      'Gemm': function (attrs, inputShapes) {
        const a = inputShapes[0] || [];
        const b = inputShapes[1] || [];
        if (a.length === 0 || b.length === 0) return 0;
        const transA = Boolean(Number(E._attr(attrs, 'transA', 0)));
        const transB = Boolean(Number(E._attr(attrs, 'transB', 0)));
        const M = E._dimOr(transA ? a[1] : a[0], 1);
        const K = E._dimOr(transA ? a[0] : a[1], 1);
        const N = E._dimOr(transB ? b[0] : b[1], 1);
        const bias = inputShapes.length > 2 ? M * N : 0;
        return 2 * M * N * K + bias;
      },
      // BatchNormalization: 4 * elements (mean, var, normalize, scale+shift)
      'BatchNormalization': function (attrs, inputShapes) {
        return 4 * E._numel(inputShapes[0]);
      },
      // Relu: 1 comparison per element
      'Relu': function (attrs, inputShapes) {
        return E._numel(inputShapes[0]);
      },
      // Add: 1 op per element of the broadcast output
      'Add': function (attrs, inputShapes) {
        return E._numel(E._broadcastShape(inputShapes));
      },
      // Mul: 1 op per element of the broadcast output
      'Mul': function (attrs, inputShapes) {
        return E._numel(E._broadcastShape(inputShapes));
      }
    };
  }

  // --- Shape / Attribute Helpers ---------------------------------------------

  /**
   * Normalize one dimension value.
   * @param {*} d - Raw dimension taken from a shape array.
   * @param {number} fallback - Returned when `d` is not a known dimension.
   * @returns {number} `d` if it is a finite, non-negative number, else `fallback`.
   */
  static _dimOr(d, fallback) {
    return (typeof d === 'number' && isFinite(d) && d >= 0) ? d : fallback;
  }

  /**
   * Element count of a tensor shape.
   * @param {Array<*>} shape
   * @returns {number} 0 for a missing/empty shape (unknown); otherwise the
   *   product of its dims, with unknown dims counted as 1.
   */
  static _numel(shape) {
    if (!shape || shape.length === 0) return 0;
    let count = 1;
    for (let i = 0; i < shape.length; i++) {
      count *= FlopsEstimator._dimOr(shape[i], 1);
    }
    return count;
  }

  /**
   * NumPy-style (right-aligned) broadcast of several shapes.
   * Missing/empty shapes are skipped and unknown dims count as 1. For each axis
   * the result is the non-1 size if any input has one, otherwise 1.
   * @param {Array<Array<*>>} shapes
   * @returns {Array<number>} Broadcast shape; [] when every input is empty.
   */
  static _broadcastShape(shapes) {
    const known = (shapes || []).filter((s) => s && s.length > 0);
    let rank = 0;
    for (const s of known) rank = Math.max(rank, s.length);
    const out = [];
    for (let axis = 0; axis < rank; axis++) {
      let size = 1;
      for (const s of known) {
        const idx = axis - (rank - s.length);
        if (idx < 0) continue;
        const d = FlopsEstimator._dimOr(s[idx], 1);
        if (d !== 1) size = d;
      }
      out.push(size);
    }
    return out;
  }

  /**
   * Read a node attribute.
   * NOTE(review): attribute encoding depends on the model parser; both raw
   * values and `{ value: ... }` wrappers are accepted — confirm against it.
   * @param {Object} attrs - Node attributes.
   * @param {string} name - Attribute name.
   * @param {*} fallback - Returned when the attribute is absent or null.
   * @returns {*}
   */
  static _attr(attrs, name, fallback) {
    if (!attrs || typeof attrs !== 'object' || !Object.prototype.hasOwnProperty.call(attrs, name)) {
      return fallback;
    }
    let value = attrs[name];
    if (value !== null && typeof value === 'object' && !Array.isArray(value) &&
        Object.prototype.hasOwnProperty.call(value, 'value')) {
      value = value.value;
    }
    return value == null ? fallback : value;
  }

  /**
   * Read element `index` of a list-valued attribute as a number.
   * @param {*} list - Attribute value (array-like) or null.
   * @param {number} index
   * @param {number} fallback - Returned when missing, non-numeric or below `minValue`.
   * @param {number} minValue - Smallest accepted value.
   * @returns {number}
   */
  static _attrAt(list, index, fallback, minValue) {
    if (list == null || typeof list.length !== 'number') return fallback;
    const value = Number(list[index]);
    return (isFinite(value) && value >= minValue) ? value : fallback;
  }

  // --- Shape Lookup ----------------------------------------------------------

  /**
   * Build a map of tensor name -> shape from all available sources.
   * Sources (later ones win on name clashes): parsedModel.inputs,
   * parsedModel.outputs, parsedModel.graph.valueInfo, graph.initializers.
   * @param {Object} parsedModel
   * @returns {Object} Null-prototype dictionary tensorName -> shape array
   */
  _buildShapeMap(parsedModel) {
    // Null prototype: tensor names come from the model file, so names such as
    // "__proto__" or "toString" must not interact with Object.prototype.
    const shapeMap = Object.create(null);
    const record = (entry) => {
      if (entry && entry.name && entry.shape) {
        shapeMap[entry.name] = entry.shape;
      }
    };
    const graph = (parsedModel && parsedModel.graph) || {};
    const inputs = (parsedModel && parsedModel.inputs) || [];
    const outputs = (parsedModel && parsedModel.outputs) || [];
    const valueInfo = graph.valueInfo || {};
    const initializers = graph.initializers || [];

    for (let i = 0; i < inputs.length; i++) record(inputs[i]);
    for (let i = 0; i < outputs.length; i++) record(outputs[i]);
    Object.keys(valueInfo).forEach((key) => record(valueInfo[key]));
    for (let i = 0; i < initializers.length; i++) record(initializers[i]);
    return shapeMap;
  }

  // --- Public API ------------------------------------------------------------

  /**
   * Estimate FLOPs for a single node given its opType, attributes, and resolved input shapes.
   * @param {{ opType: string, attributes: Object, inputShapes: Array<Array<number>> }} nodeInfo
   * @returns {number} Estimated FLOPs (0 when the formula yields a non-finite or
   *   negative value), or -1 if opType is unsupported
   */
  estimateNodeFlops(nodeInfo) {
    const opType = nodeInfo && nodeInfo.opType;
    const formulas = FlopsEstimator.FORMULAS;
    // Own-property lookup: opType is model-controlled, and names such as
    // "constructor" or "__proto__" must not resolve to inherited members.
    const formula = Object.prototype.hasOwnProperty.call(formulas, opType) ? formulas[opType] : null;
    if (typeof formula !== 'function') return -1;
    const attrs = (nodeInfo && nodeInfo.attributes) || {};
    const inputShapes = (nodeInfo && nodeInfo.inputShapes) || [];
    const flops = formula(attrs, inputShapes);
    return (typeof flops === 'number' && isFinite(flops) && flops >= 0) ? flops : 0;
  }

  /**
   * Compute FLOPs for the entire model and store the result (see getResult()).
   * @param {Object} parsedModel - The parsed ONNX model
   * @returns {FlopsResult}
   */
  compute(parsedModel) {
    const nodes = (parsedModel && parsedModel.graph && parsedModel.graph.nodes) || [];
    const shapeMap = this._buildShapeMap(parsedModel);
    let totalFlops = 0;
    const perNode = [];
    // Map/Set instead of plain objects because opType strings are model-controlled.
    const opTypeStats = new Map(); // opType -> { totalFlops, count }
    const unsupported = new Set();

    for (let i = 0; i < nodes.length; i++) {
      const node = nodes[i];
      const opType = node.opType || '';
      const nodeInputs = node.inputs || [];
      const inputShapes = [];
      for (let j = 0; j < nodeInputs.length; j++) {
        inputShapes.push(shapeMap[nodeInputs[j]] || []);
      }
      const flops = this.estimateNodeFlops({
        opType: opType,
        attributes: node.attributes || {},
        inputShapes: inputShapes
      });
      const supported = flops >= 0;
      perNode.push({
        nodeId: node.id || '',
        nodeName: node.name || '',
        opType: opType,
        flops: supported ? flops : -1,
        formattedFlops: supported ? this.formatFlops(flops) : 'N/A'
      });
      if (!supported) {
        unsupported.add(opType);
        continue;
      }
      totalFlops += flops;
      let stats = opTypeStats.get(opType);
      if (!stats) {
        stats = { totalFlops: 0, count: 0 };
        opTypeStats.set(opType, stats);
      }
      stats.totalFlops += flops;
      stats.count += 1;
    }

    // perOpType sorted descending by totalFlops (stable for ties).
    const perOpType = [];
    opTypeStats.forEach((stats, opType) => {
      perOpType.push({
        opType: opType,
        totalFlops: stats.totalFlops,
        formattedFlops: this.formatFlops(stats.totalFlops),
        count: stats.count,
        percentage: totalFlops > 0 ? (stats.totalFlops / totalFlops) * 100 : 0
      });
    });
    perOpType.sort((a, b) => b.totalFlops - a.totalFlops);

    this._result = {
      totalFlops: totalFlops,
      formattedTotal: this.formatFlops(totalFlops),
      perNode: perNode,
      perOpType: perOpType,
      unsupportedOps: Array.from(unsupported)
    };
    return this._result;
  }

  /**
   * Format a FLOPs number into a human-readable string.
   * @param {number} flops
   * @returns {string} e.g. "1.20 TFLOPs", "500.00 MFLOPs", "0 FLOPs"; 'N/A' for
   *   null, NaN or negative input
   */
  formatFlops(flops) {
    if (flops == null || isNaN(flops) || flops < 0) return 'N/A';
    if (flops >= 1e12) return (flops / 1e12).toFixed(2) + ' TFLOPs';
    if (flops >= 1e9) return (flops / 1e9).toFixed(2) + ' GFLOPs';
    if (flops >= 1e6) return (flops / 1e6).toFixed(2) + ' MFLOPs';
    if (flops >= 1e3) return (flops / 1e3).toFixed(2) + ' KFLOPs';
    return flops + ' FLOPs';
  }

  /**
   * Render the FLOPs estimation result into the container.
   * All model-derived text is HTML-escaped before it is inserted via innerHTML.
   * @param {FlopsResult} [result] - If omitted, uses the last computed result
   */
  render(result) {
    if (!this._container) return;
    const data = result || this._result;
    if (!data) {
      this._container.innerHTML = '<p class="text-muted">No FLOPs estimation available.</p>';
      return;
    }
    const esc = (value) => this._escapeHtml(value);

    let html = '<div class="flops-card">';
    // Total FLOPs header
    html += '<div class="d-flex align-items-center mb-2">';
    html += '<i class="fas fa-calculator me-2"></i>';
    html += '<strong>Total FLOPs:</strong> ';
    html += '<span class="flops-total">' + esc(data.formattedTotal) + '</span>';
    html += '</div>';
    // Breakdown by opType
    if (data.perOpType && data.perOpType.length > 0) {
      html += '<table class="flops-table table table-sm table-bordered mb-2">';
      html += '<thead><tr><th>Op Type</th><th>FLOPs</th><th>Count</th><th>%</th></tr></thead>';
      html += '<tbody>';
      for (let i = 0; i < data.perOpType.length; i++) {
        const entry = data.perOpType[i];
        html += '<tr>';
        html += '<td>' + esc(entry.opType) + '</td>';
        html += '<td>' + esc(entry.formattedFlops) + '</td>';
        html += '<td>' + esc(entry.count) + '</td>';
        html += '<td>' + entry.percentage.toFixed(1) + '%</td>';
        html += '</tr>';
      }
      html += '</tbody></table>';
    }
    // Unsupported ops
    if (data.unsupportedOps && data.unsupportedOps.length > 0) {
      html += '<div class="small text-muted">';
      html += '<i class="fas fa-info-circle me-1"></i>';
      html += 'Unsupported ops (N/A): ' + data.unsupportedOps.map(esc).join(', ');
      html += '</div>';
    }
    html += '</div>';
    this._container.innerHTML = html;
  }

  /**
   * Get the last computed result.
   * @returns {FlopsResult|null}
   */
  getResult() {
    return this._result;
  }

  /**
   * Clear the display and reset internal state.
   */
  clear() {
    this._result = null;
    if (this._container) {
      this._container.innerHTML = '';
    }
  }

  /**
   * Destroy and clean up.
   */
  destroy() {
    this.clear();
    this._container = null;
  }

  // --- Private Helpers -------------------------------------------------------

  /**
   * Escape HTML special characters for safe insertion into element content
   * and quoted attribute values. Non-string input is stringified first.
   * @param {*} str
   * @returns {string}
   */
  _escapeHtml(str) {
    const text = typeof str === 'string' ? str : String(str);
    return text
      .replace(/&/g, '&amp;')
      .replace(/</g, '&lt;')
      .replace(/>/g, '&gt;')
      .replace(/"/g, '&quot;')
      .replace(/'/g, '&#39;');
  }
}
// Expose the class as a browser global. Guarded so that loading this script
// where `window` does not exist (e.g. a Web Worker or Node) does not throw.
if (typeof window !== 'undefined') {
  window.FlopsEstimator = FlopsEstimator;
}