// Source: model-explorer/js/ui/flopsEstimator.js
// Author: mr4 - commit 9bd422a ("Upload 71 files", verified)
/**
 * FlopsEstimator - Estimates FLOPs (Floating Point Operations) for ONNX models.
 * Computes per-node and total FLOPs based on opType and resolved tensor shapes.
 *
 * Shape conventions shared by every formula:
 *  - A shape that is missing entirely ([] / undefined) is "unknown" and the node
 *    contributes 0 FLOPs.
 *  - A single dimension that is not a concrete non-negative number (a symbolic
 *    name such as 'batch', -1, null) is counted as 1, so the estimate is per unit
 *    of that dimension instead of collapsing to 0.
 *
 * Requirements: 34.1, 34.2, 34.3, 34.4, 34.5, 34.6
 */
class FlopsEstimator {
  /**
   * @param {string} containerId - ID of the container element for rendering
   */
  constructor(containerId) {
    this._containerId = containerId;
    this._container = document.getElementById(containerId);
    this._result = null;
    if (!this._container) {
      console.warn(`[FlopsEstimator] Container #${containerId} not found`);
    }
  }

  // ─── FLOPs Formulas ───────────────────────────────────────────────────────

  /**
   * Supported opType formulas.
   * Each function receives (attrs, inputShapes, outputShapes?) and returns a FLOPs
   * count; outputShapes is optional, so existing two-argument calls keep working.
   * A fresh shallow copy is returned on every access, so mutating the returned
   * object does not affect the estimator.
   * @returns {Object<string, Function>}
   */
  static get FORMULAS() {
    return Object.assign({}, FlopsEstimator._formulaTable());
  }

  /**
   * Internal formula table, built on first use and cached on the class so that
   * per-node lookups do not recreate every closure.
   * @returns {Object<string, Function>}
   */
  static _formulaTable() {
    if (!FlopsEstimator._cachedFormulas) {
      FlopsEstimator._cachedFormulas = FlopsEstimator._createFormulas();
    }
    return FlopsEstimator._cachedFormulas;
  }

  /**
   * Build the opType → formula table.
   * @returns {Object<string, Function>}
   */
  static _createFormulas() {
    return {
      // Conv: 2 * N * Cout * prod(outSpatial) * CinPerGroup * prod(kernel)
      // weight is [Cout, Cin/group, k1, ..., kn], so 1-D, 2-D and 3-D convs are all handled.
      'Conv': (attrs, inputShapes = [], outputShapes = []) => {
        const input = inputShapes[0] || [];
        const weight = inputShapes[1] || [];
        if (input.length < 3 || weight.length < 3) return 0;
        const kernel = Array.from(weight).slice(2).map((d) => FlopsEstimator._dim(d));
        const outputShape = outputShapes && outputShapes[0];
        const outSpatial = FlopsEstimator._convOutputSpatial(attrs, input, kernel, outputShape);
        return 2 *
          FlopsEstimator._dim(input[0]) *
          FlopsEstimator._dim(weight[0]) *
          FlopsEstimator._product(outSpatial) *
          FlopsEstimator._dim(weight[1]) *
          FlopsEstimator._product(kernel);
      },
      // MatMul: 2 * batch * M * N * K (numpy matmul semantics)
      'MatMul': (attrs, inputShapes = []) => {
        const a = inputShapes[0] || [];
        const b = inputShapes[1] || [];
        if (a.length === 0 || b.length === 0) return 0;
        // A 1-D A is promoted to [1, K] and a 1-D B to [K, 1].
        const aMat = a.length === 1 ? [1, a[0]] : Array.from(a);
        const bMat = b.length === 1 ? [b[0], 1] : Array.from(b);
        const M = FlopsEstimator._dim(aMat[aMat.length - 2]);
        const K = FlopsEstimator._dim(aMat[aMat.length - 1]);
        const N = FlopsEstimator._dim(bMat[bMat.length - 1]);
        // Leading dims broadcast together; each batch entry is an independent product.
        const batch = FlopsEstimator._product(
          FlopsEstimator._broadcastShape([aMat.slice(0, -2), bMat.slice(0, -2)])
        );
        return 2 * batch * M * N * K;
      },
      // Gemm: 2 * M * N * K, plus M * N for the bias add when C is supplied
      'Gemm': (attrs, inputShapes = []) => {
        const a = inputShapes[0] || [];
        const b = inputShapes[1] || [];
        if (a.length < 2 || b.length < 2) return 0;
        // transA / transB swap which axis holds M/K and K/N (e.g. exported Linear layers use transB=1).
        const transA = Number(FlopsEstimator._attr(attrs, 'transA', 0)) === 1;
        const transB = Number(FlopsEstimator._attr(attrs, 'transB', 0)) === 1;
        const M = FlopsEstimator._dim(transA ? a[1] : a[0]);
        const K = FlopsEstimator._dim(transA ? a[0] : a[1]);
        const N = FlopsEstimator._dim(transB ? b[0] : b[1]);
        // C is optional; only count the bias add when a third input is wired up.
        const bias = inputShapes.length > 2 ? M * N : 0;
        return 2 * M * N * K + bias;
      },
      // BatchNormalization: 4 * elements (mean, var, normalize, scale+shift)
      'BatchNormalization': (attrs, inputShapes = []) =>
        4 * FlopsEstimator._numElements(inputShapes[0]),
      // Relu: 1 comparison per element
      'Relu': (attrs, inputShapes = []) =>
        FlopsEstimator._numElements(inputShapes[0]),
      // Add: one op per element of the broadcast output
      'Add': (attrs, inputShapes = []) =>
        FlopsEstimator._numElements(FlopsEstimator._broadcastShape(inputShapes)),
      // Mul: one op per element of the broadcast output
      'Mul': (attrs, inputShapes = []) =>
        FlopsEstimator._numElements(FlopsEstimator._broadcastShape(inputShapes)),
    };
  }

  // ─── Shape / Attribute Helpers ────────────────────────────────────────────

  /**
   * Return a dimension as a number if it is concrete (finite, >= 0), else null.
   * BigInt dims (int64 from some parsers) are converted to numbers.
   * @param {*} value
   * @returns {number|null}
   */
  static _knownDim(value) {
    const n = typeof value === 'bigint' ? Number(value) : value;
    return (typeof n === 'number' && Number.isFinite(n) && n >= 0) ? n : null;
  }

  /**
   * Concrete dimension value, or 1 for symbolic/unknown dimensions.
   * @param {*} value
   * @returns {number}
   */
  static _dim(value) {
    const n = FlopsEstimator._knownDim(value);
    return n === null ? 1 : n;
  }

  /**
   * Product of all dimensions (unknown dims count as 1); an empty shape gives 1.
   * @param {ArrayLike<*>} shape
   * @returns {number}
   */
  static _product(shape) {
    let total = 1;
    for (let i = 0; i < shape.length; i++) {
      total *= FlopsEstimator._dim(shape[i]);
    }
    return total;
  }

  /**
   * Element count of a tensor shape; a missing/empty shape means "unknown" → 0.
   * @param {ArrayLike<*>} shape
   * @returns {number}
   */
  static _numElements(shape) {
    if (!shape || !(shape.length > 0)) return 0;
    return FlopsEstimator._product(shape);
  }

  /**
   * Numpy-style broadcast of several shapes (right-aligned, non-1 dims win).
   * Empty/missing shapes are skipped; if none remain the result is [].
   * @param {Array<ArrayLike<*>>} shapes
   * @returns {number[]}
   */
  static _broadcastShape(shapes) {
    const known = (shapes || []).filter((s) => s && s.length > 0);
    const rank = known.reduce((max, s) => Math.max(max, s.length), 0);
    const out = [];
    for (let i = 0; i < rank; i++) {
      let size = 1;
      for (const s of known) {
        const idx = s.length - rank + i;
        if (idx < 0) continue;
        const d = FlopsEstimator._dim(s[idx]);
        if (d !== 1) size = size === 1 ? d : Math.max(size, d);
      }
      out.push(size);
    }
    return out;
  }

  /**
   * Read a node attribute, returning `fallback` when it is absent.
   * NOTE(review): the parser's attribute encoding is not visible here; only raw
   * values and `{ value }` wrappers are recognised — confirm against the parser.
   * @param {Object} attrs
   * @param {string} name
   * @param {*} fallback
   * @returns {*}
   */
  static _attr(attrs, name, fallback) {
    if (!attrs || !Object.prototype.hasOwnProperty.call(attrs, name)) return fallback;
    let raw = attrs[name];
    if (raw !== null && typeof raw === 'object' && !Array.isArray(raw) && 'value' in raw) {
      raw = raw.value;
    }
    return raw == null ? fallback : raw;
  }

  /**
   * Validate an integer-list attribute: exactly `length` concrete numbers, each >= minValue.
   * @returns {number[]|null} the list, or null when missing/malformed
   */
  static _numberList(value, length, minValue) {
    if (!value || typeof value === 'string' || value.length !== length) return null;
    const list = [];
    for (let i = 0; i < length; i++) {
      const n = FlopsEstimator._knownDim(value[i]);
      if (n === null || n < minValue) return null;
      list.push(n);
    }
    return list;
  }

  /**
   * Spatial output size of a convolution.
   * Uses the inferred output shape when all of its spatial dims are concrete;
   * otherwise applies the ONNX Conv output-size rule from strides/pads/dilations/auto_pad.
   * Without padding information, "same" padding is assumed: out = ceil(in / stride).
   * @param {Object} attrs
   * @param {ArrayLike<*>} input - [N, C, d1, ..., dn]
   * @param {number[]} kernel - [k1, ..., kn]
   * @param {ArrayLike<*>} [outputShape] - [N, Cout, o1, ..., on] if known
   * @returns {number[]}
   */
  static _convOutputSpatial(attrs, input, kernel, outputShape) {
    const rank = kernel.length;
    if (outputShape && outputShape.length === rank + 2) {
      const inferred = Array.from(outputShape).slice(2);
      if (inferred.every((d) => FlopsEstimator._knownDim(d) !== null)) {
        return inferred.map((d) => FlopsEstimator._dim(d));
      }
    }
    const strides = FlopsEstimator._numberList(FlopsEstimator._attr(attrs, 'strides', null), rank, 1);
    const dilations = FlopsEstimator._numberList(FlopsEstimator._attr(attrs, 'dilations', null), rank, 1);
    // ONNX pads layout: [x1_begin, ..., xn_begin, x1_end, ..., xn_end]
    const pads = FlopsEstimator._numberList(FlopsEstimator._attr(attrs, 'pads', null), 2 * rank, 0);
    const autoPad = String(FlopsEstimator._attr(attrs, 'auto_pad', 'NOTSET'));
    const isValid = autoPad === 'VALID';
    const useSame = autoPad.indexOf('SAME') === 0 || (!isValid && pads === null);
    const out = [];
    for (let i = 0; i < rank; i++) {
      const size = FlopsEstimator._dim(input[2 + i]);
      const stride = strides ? strides[i] : 1;
      if (useSame) {
        out.push(Math.ceil(size / stride));
        continue;
      }
      const dilation = dilations ? dilations[i] : 1;
      const padTotal = isValid ? 0 : pads[i] + pads[i + rank];
      const effectiveKernel = dilation * (kernel[i] - 1) + 1;
      out.push(Math.max(0, Math.floor((size + padTotal - effectiveKernel) / stride) + 1));
    }
    return out;
  }

  // ─── Shape Lookup ─────────────────────────────────────────────────────────

  /**
   * Build a map of tensor name → shape from all available sources.
   * Sources, applied in this order (later entries overwrite earlier ones):
   * parsedModel.inputs, parsedModel.outputs, parsedModel.graph.valueInfo, graph.initializers.
   * @param {Object} parsedModel
   * @returns {Object} tensorName → shape array
   */
  _buildShapeMap(parsedModel) {
    const shapeMap = {};
    const graph = (parsedModel && parsedModel.graph) || {};
    const valueInfo = graph.valueInfo || {};
    const sources = [
      (parsedModel && parsedModel.inputs) || [],
      (parsedModel && parsedModel.outputs) || [],
      Object.keys(valueInfo).map((key) => valueInfo[key]),
      graph.initializers || [],
    ];
    for (const entries of sources) {
      for (let i = 0; i < entries.length; i++) {
        const entry = entries[i];
        if (entry && entry.name && entry.shape) {
          shapeMap[entry.name] = entry.shape;
        }
      }
    }
    return shapeMap;
  }

  // ─── Public API ───────────────────────────────────────────────────────────

  /**
   * Estimate FLOPs for a single node given its opType, attributes, and resolved shapes.
   * @param {{ opType: string, attributes: Object, inputShapes: Array<Array<number>>,
   *           outputShapes?: Array<Array<number>> }} nodeInfo
   * @returns {number} Estimated FLOPs (0 when the formula yields a non-finite or
   *   negative value), or -1 if opType is unsupported
   */
  estimateNodeFlops(nodeInfo) {
    const info = nodeInfo || {};
    const opType = info.opType;
    const table = FlopsEstimator._formulaTable();
    // Own-property check so names like "toString" never resolve to Object.prototype members.
    if (typeof opType !== 'string' || !Object.prototype.hasOwnProperty.call(table, opType)) {
      return -1;
    }
    const flops = table[opType](info.attributes || {}, info.inputShapes || [], info.outputShapes || []);
    return (typeof flops === 'number' && Number.isFinite(flops) && flops >= 0) ? flops : 0;
  }

  /**
   * Compute FLOPs for the entire model.
   * @param {Object} parsedModel - The parsed ONNX model
   * @returns {FlopsResult} { totalFlops, formattedTotal, perNode, perOpType (sorted by
   *   totalFlops descending), unsupportedOps }; also stored as the last result
   */
  compute(parsedModel) {
    const graph = (parsedModel && parsedModel.graph) || {};
    const nodes = graph.nodes || [];
    const shapeMap = this._buildShapeMap(parsedModel);
    // Own-property lookup so tensor names like "constructor" never resolve to prototype members.
    const shapeOf = (name) =>
      (Object.prototype.hasOwnProperty.call(shapeMap, name) && shapeMap[name]) || [];

    let totalFlops = 0;
    const perNode = [];
    const opTypeMap = Object.create(null);      // opType → { totalFlops, count }
    const unsupportedSet = Object.create(null); // opType → true

    for (let i = 0; i < nodes.length; i++) {
      const node = nodes[i];
      const opType = node.opType || '';
      const flops = this.estimateNodeFlops({
        opType,
        attributes: node.attributes || {},
        inputShapes: Array.from(node.inputs || [], shapeOf),
        // NOTE(review): assumes the parser lists output tensor names in node.outputs;
        // when absent, formulas fall back to input-derived shapes.
        outputShapes: Array.from(node.outputs || [], shapeOf),
      });
      const supported = flops >= 0;
      perNode.push({
        nodeId: node.id || '',
        nodeName: node.name || '',
        opType,
        flops: supported ? flops : -1,
        formattedFlops: supported ? this.formatFlops(flops) : 'N/A',
      });
      if (!supported) {
        unsupportedSet[opType] = true;
        continue;
      }
      totalFlops += flops;
      const bucket = opTypeMap[opType] || (opTypeMap[opType] = { totalFlops: 0, count: 0 });
      bucket.totalFlops += flops;
      bucket.count += 1;
    }

    const perOpType = Object.keys(opTypeMap).map((ot) => ({
      opType: ot,
      totalFlops: opTypeMap[ot].totalFlops,
      formattedFlops: this.formatFlops(opTypeMap[ot].totalFlops),
      count: opTypeMap[ot].count,
      percentage: totalFlops > 0 ? (opTypeMap[ot].totalFlops / totalFlops) * 100 : 0,
    }));
    perOpType.sort((a, b) => b.totalFlops - a.totalFlops);

    this._result = {
      totalFlops,
      formattedTotal: this.formatFlops(totalFlops),
      perNode,
      perOpType,
      unsupportedOps: Object.keys(unsupportedSet),
    };
    return this._result;
  }

  /**
   * Format a FLOPs number into a human-readable string.
   * @param {number} flops
   * @returns {string} e.g. "1.20 TFLOPs", "500.00 MFLOPs", "0 FLOPs"; "N/A" for
   *   null, NaN, negative or non-finite input
   */
  formatFlops(flops) {
    if (flops == null) return 'N/A';
    const value = Number(flops);
    if (!Number.isFinite(value) || value < 0) return 'N/A';
    const units = [[1e12, ' TFLOPs'], [1e9, ' GFLOPs'], [1e6, ' MFLOPs'], [1e3, ' KFLOPs']];
    for (const [scale, suffix] of units) {
      if (value >= scale) return (value / scale).toFixed(2) + suffix;
    }
    return value + ' FLOPs';
  }

  /**
   * Render the FLOPs estimation result into the container.
   * @param {FlopsResult} [result] - If omitted, uses the last computed result
   */
  render(result) {
    if (!this._container) return;
    const data = result || this._result;
    if (!data) {
      this._container.innerHTML = '<p class="text-muted">No FLOPs estimation available.</p>';
      return;
    }
    const parts = ['<div class="flops-card">'];
    // Total FLOPs header
    parts.push(
      '<div class="d-flex align-items-center mb-2">',
      '<i class="fas fa-calculator me-2"></i>',
      '<strong>Total FLOPs:</strong>&nbsp;',
      `<span class="flops-total">${this._escapeHtml(data.formattedTotal)}</span>`,
      '</div>'
    );
    // Breakdown by opType
    const perOpType = data.perOpType || [];
    if (perOpType.length > 0) {
      parts.push(
        '<table class="flops-table table table-sm table-bordered mb-2">',
        '<thead><tr><th>Op Type</th><th>FLOPs</th><th>Count</th><th>%</th></tr></thead>',
        '<tbody>'
      );
      for (const entry of perOpType) {
        const pct = Number(entry.percentage) || 0;
        parts.push(
          '<tr>',
          `<td>${this._escapeHtml(entry.opType)}</td>`,
          `<td>${this._escapeHtml(entry.formattedFlops)}</td>`,
          `<td>${this._escapeHtml(entry.count)}</td>`,
          `<td>${pct.toFixed(1)}%</td>`,
          '</tr>'
        );
      }
      parts.push('</tbody></table>');
    }
    // Unsupported ops
    const unsupported = data.unsupportedOps || [];
    if (unsupported.length > 0) {
      const names = unsupported.map((op) => this._escapeHtml(op)).join(', ');
      parts.push(
        '<div class="small text-muted">',
        '<i class="fas fa-info-circle me-1"></i>',
        `Unsupported ops (N/A): ${names}`,
        '</div>'
      );
    }
    parts.push('</div>');
    this._container.innerHTML = parts.join('');
  }

  /**
   * Get the last computed result.
   * @returns {FlopsResult|null}
   */
  getResult() {
    return this._result;
  }

  /**
   * Clear the display and reset internal state.
   */
  clear() {
    this._result = null;
    if (this._container) {
      this._container.innerHTML = '';
    }
  }

  /**
   * Destroy and clean up.
   */
  destroy() {
    this.clear();
    this._container = null;
  }

  // ─── Private Helpers ──────────────────────────────────────────────────────

  /**
   * Escape HTML special characters. Non-string values are stringified first and
   * then escaped as well.
   * @param {*} str
   * @returns {string}
   */
  _escapeHtml(str) {
    const text = typeof str === 'string' ? str : String(str);
    return text
      .replace(/&/g, '&amp;')
      .replace(/</g, '&lt;')
      .replace(/>/g, '&gt;')
      .replace(/"/g, '&quot;')
      .replace(/'/g, '&#39;');
  }
}
// Publish the class on the global window object so other scripts can reach it as window.FlopsEstimator.
Object.assign(window, { FlopsEstimator });