// model-explorer/js/tests/flopsEstimator.test.js
// Author: mr4 — commit 9bd422a ("Upload 71 files", verified)
/**
* Unit tests for FlopsEstimator
* Validates: Requirements 34.1, 34.2, 34.3, 34.4, 34.5, 34.6
*/
import { describe, it, expect, beforeEach } from 'vitest';
// ─── Mirror pure logic from FlopsEstimator for testability ──────────────
/**
 * Per-op FLOP formulas keyed by ONNX opType. Each entry is called as
 * `(attrs, inputShapes)`, where `inputShapes[k]` is the shape array of the
 * node's k-th input (a missing shape arrives as `[]`), and returns a count.
 */
const FORMULAS = (() => {
  // Product of every dimension of the first input shape. A missing or empty
  // shape yields 0, and any non-numeric (e.g. symbolic) dimension zeroes the
  // whole product.
  const countElements = (inputShapes) => {
    const shape = inputShapes[0] || [];
    if (shape.length === 0) return 0;
    let product = 1;
    for (let d = 0; d < shape.length; d += 1) {
      product *= typeof shape[d] === 'number' ? shape[d] : 0;
    }
    return product;
  };

  return {
    // 2 * Cout * Hout * Wout * CinPerGroup * Kh * Kw, reading the weight as
    // [Cout, CinPerGroup, Kh, Kw].
    // NOTE(review): Hout/Wout are taken from the *input* spatial dims; any
    // strides/pads/dilations in `attrs` are not consulted.
    Conv(attrs, inputShapes) {
      const x = inputShapes[0] || [];
      const w = inputShapes[1] || [];
      const [outCh, inChPerGroup, kernelH, kernelW] = [0, 1, 2, 3].map((k) => w[k] || 0);
      const outH = x[2] || 0;
      const outW = x[3] || 0;
      return 2 * outCh * outH * outW * inChPerGroup * kernelH * kernelW;
    },

    // 2 * M * N * K using the last two dims of A ([..., M, K]) and the last
    // dim of B ([..., K, N]); leading batch dims are not multiplied in.
    MatMul(attrs, inputShapes) {
      const a = inputShapes[0] || [];
      const b = inputShapes[1] || [];
      const rows = a[a.length - 2] || 0;
      const inner = a[a.length - 1] || 0;
      const cols = b[b.length - 1] || 0;
      return 2 * rows * cols * inner;
    },

    // 2 * M * N * K + M * N, with A read as [M, K] and B as [K, N].
    // NOTE(review): transA/transB in `attrs` are not consulted, so a
    // transposed B ([N, K]) would yield the wrong N.
    Gemm(attrs, inputShapes) {
      const [a, b] = [inputShapes[0] || [], inputShapes[1] || []];
      const rows = a[0] || 0;
      const inner = a[1] || 0;
      const cols = b[1] || 0;
      const matmulPart = 2 * rows * cols * inner;
      return matmulPart + rows * cols;
    },

    // 4 * element count of the first input.
    BatchNormalization(attrs, inputShapes) {
      return 4 * countElements(inputShapes);
    },

    // Elementwise ops below: 1 * element count of the first input.
    Relu(attrs, inputShapes) {
      return countElements(inputShapes);
    },

    Add(attrs, inputShapes) {
      return countElements(inputShapes);
    },

    Mul(attrs, inputShapes) {
      return countElements(inputShapes);
    },
  };
})();
/**
 * Estimate the FLOPs of a single node using the FORMULAS table.
 *
 * @param {{opType?: string, attributes?: object, inputShapes?: Array}} nodeInfo
 * @returns {number} -1 when `opType` has no own entry in FORMULAS; otherwise the
 *   formula's result, replaced by 0 when it is not a finite, non-negative number.
 */
function estimateNodeFlops(nodeInfo) {
  const opType = nodeInfo && nodeInfo.opType;
  // Own-property lookup: a bare `FORMULAS[opType]` also resolves inherited
  // Object.prototype members ('toString', 'constructor', 'hasOwnProperty'),
  // which would be miscounted as supported ops or throw when invoked.
  if (!Object.prototype.hasOwnProperty.call(FORMULAS, opType)) return -1;
  const formula = FORMULAS[opType];
  if (typeof formula !== 'function') return -1;
  const attrs = (nodeInfo && nodeInfo.attributes) || {};
  const inputShapes = (nodeInfo && nodeInfo.inputShapes) || [];
  const flops = formula(attrs, inputShapes);
  return typeof flops === 'number' && Number.isFinite(flops) && flops >= 0 ? flops : 0;
}
/**
 * Render a FLOP count with a K/M/G/T suffix.
 *
 * Values >= 1e3 are divided by the largest unit they reach and shown with two
 * decimals; smaller values are printed unscaled. `null`, `undefined`, anything
 * the global isNaN() rejects, and negative values render as 'N/A'.
 * NOTE(review): the unit is chosen before rounding, so e.g. 999999 renders as
 * '1000.00 KFLOPs'.
 *
 * @param {number|null|undefined} flops
 * @returns {string}
 */
function formatFlops(flops) {
  const isInvalid = flops == null || isNaN(flops) || flops < 0;
  if (isInvalid) return 'N/A';

  const units = [
    [1e12, 'TFLOPs'],
    [1e9, 'GFLOPs'],
    [1e6, 'MFLOPs'],
    [1e3, 'KFLOPs'],
  ];
  for (const [scale, suffix] of units) {
    if (flops >= scale) {
      return `${(flops / scale).toFixed(2)} ${suffix}`;
    }
  }
  return flops + ' FLOPs';
}
/**
 * Map tensor names to shapes from every place a parsed model records them.
 *
 * Sources are applied in order: model inputs, model outputs,
 * graph.valueInfo (an object whose values are {name, shape}), then
 * graph.initializers. A later source overwrites an earlier one for the same
 * name. Entries lacking a truthy `name` and `shape` (or missing entirely) are
 * skipped.
 *
 * The result has a null prototype. Tensor names that collide with
 * Object.prototype members ('constructor', 'toString', '__proto__', ...) are
 * therefore stored and looked up as ordinary keys; unknown names read back as
 * `undefined`.
 *
 * @param {object|null|undefined} parsedModel
 * @returns {Object<string, Array>} tensor name -> shape
 */
function buildShapeMap(parsedModel) {
  const shapeMap = Object.create(null);
  const graph = (parsedModel && parsedModel.graph) || {};

  const record = (entry) => {
    if (entry && entry.name && entry.shape) shapeMap[entry.name] = entry.shape;
  };

  const sources = [
    (parsedModel && parsedModel.inputs) || [],
    (parsedModel && parsedModel.outputs) || [],
    Object.values(graph.valueInfo || {}),
    graph.initializers || [],
  ];
  for (const entries of sources) {
    for (const entry of entries) record(entry);
  }
  return shapeMap;
}
/**
 * Aggregate FLOP estimates over every node in `parsedModel.graph.nodes`.
 *
 * Each node's input shapes are resolved by tensor name through
 * buildShapeMap(); a name with no own entry in that map contributes `[]`.
 * A node for which estimateNodeFlops() returns a negative value is treated as
 * unsupported:
 * - it is reported with `flops: -1` / `formattedFlops: 'N/A'`;
 * - it is excluded from every total;
 * - its opType is listed once in `unsupportedOps`, in first-seen order.
 *
 * @param {object|null|undefined} parsedModel
 * @returns {{totalFlops: number, formattedTotal: string, perNode: Array,
 *   perOpType: Array, unsupportedOps: string[]}} `perOpType` is sorted by
 *   totalFlops descending; each entry's `percentage` is its share of
 *   totalFlops (0 when the total is 0).
 */
function compute(parsedModel) {
  const nodes = (parsedModel && parsedModel.graph && parsedModel.graph.nodes) || [];
  const shapeMap = buildShapeMap(parsedModel);
  const hasOwn = Object.prototype.hasOwnProperty;

  let totalFlops = 0;
  const perNode = [];
  // Map/Set instead of plain objects: opType strings come from the model, and
  // names like 'constructor' or '__proto__' would otherwise hit
  // Object.prototype (corrupting `Object` or silently dropping entries).
  const opTypeStats = new Map();
  const unsupported = new Set();

  for (const node of nodes) {
    const opType = node.opType || '';
    // Own-property lookup so tensor names such as 'toString' do not resolve
    // to inherited functions when shapeMap is an ordinary object.
    const inputShapes = Array.from(node.inputs || [], (name) =>
      (hasOwn.call(shapeMap, name) ? shapeMap[name] : undefined) || []);
    const flops = estimateNodeFlops({ opType, attributes: node.attributes || {}, inputShapes });

    if (flops < 0) {
      unsupported.add(opType);
      perNode.push({ nodeId: node.id, nodeName: node.name, opType, flops: -1, formattedFlops: 'N/A' });
      continue;
    }

    totalFlops += flops;
    perNode.push({ nodeId: node.id, nodeName: node.name, opType, flops, formattedFlops: formatFlops(flops) });
    const stats = opTypeStats.get(opType) || { totalFlops: 0, count: 0 };
    stats.totalFlops += flops;
    stats.count += 1;
    opTypeStats.set(opType, stats);
  }

  const perOpType = Array.from(opTypeStats, ([opType, stats]) => ({
    opType,
    totalFlops: stats.totalFlops,
    formattedFlops: formatFlops(stats.totalFlops),
    count: stats.count,
    percentage: totalFlops > 0 ? (stats.totalFlops / totalFlops) * 100 : 0,
  }));
  perOpType.sort((a, b) => b.totalFlops - a.totalFlops);

  return {
    totalFlops,
    formattedTotal: formatFlops(totalFlops),
    perNode,
    perOpType,
    unsupportedOps: [...unsupported],
  };
}
// ─── Tests ──────────────────────────────────────────────────────────────────
describe('FlopsEstimator.estimateNodeFlops', () => {
  it('should compute Conv FLOPs: 2 * Cout * Hout * Wout * CinPerGroup * Kh * Kw', () => {
    // input: [1, 3, 224, 224], weight: [64, 3, 7, 7]
    // Hout/Wout are taken from the input's spatial dims by the mirrored formula.
    const flops = estimateNodeFlops({
      opType: 'Conv',
      attributes: {},
      inputShapes: [[1, 3, 224, 224], [64, 3, 7, 7]]
    });
    // 2 * 64 * 224 * 224 * 3 * 7 * 7 = 2 * 64 * 50176 * 147 = 944,111,616
    expect(flops).toBe(2 * 64 * 224 * 224 * 3 * 7 * 7);
  });
  it('should compute MatMul FLOPs: 2 * M * N * K', () => {
    // A: [32, 128], B: [128, 64]  ->  M = 32, K = 128, N = 64
    const flops = estimateNodeFlops({
      opType: 'MatMul',
      attributes: {},
      inputShapes: [[32, 128], [128, 64]]
    });
    expect(flops).toBe(2 * 32 * 64 * 128);
  });
  it('should compute Gemm FLOPs: 2 * M * N * K + M * N', () => {
    // A: [10, 20], B: [20, 30]  ->  M = 10, K = 20, N = 30
    const flops = estimateNodeFlops({
      opType: 'Gemm',
      attributes: {},
      inputShapes: [[10, 20], [20, 30]]
    });
    expect(flops).toBe(2 * 10 * 30 * 20 + 10 * 30);
  });
  it('should compute BatchNormalization FLOPs: 4 * elements', () => {
    const flops = estimateNodeFlops({
      opType: 'BatchNormalization',
      attributes: {},
      inputShapes: [[1, 64, 56, 56]]
    });
    expect(flops).toBe(4 * 1 * 64 * 56 * 56);
  });
  it('should compute Relu FLOPs: elements', () => {
    const flops = estimateNodeFlops({
      opType: 'Relu',
      attributes: {},
      inputShapes: [[1, 64, 56, 56]]
    });
    expect(flops).toBe(1 * 64 * 56 * 56);
  });
  it('should compute Add FLOPs: elements', () => {
    const flops = estimateNodeFlops({
      opType: 'Add',
      attributes: {},
      inputShapes: [[1, 256, 14, 14]]
    });
    expect(flops).toBe(1 * 256 * 14 * 14);
  });
  it('should compute Mul FLOPs: elements', () => {
    const flops = estimateNodeFlops({
      opType: 'Mul',
      attributes: {},
      inputShapes: [[2, 128, 28, 28]]
    });
    expect(flops).toBe(2 * 128 * 28 * 28);
  });
  it('should return -1 for unsupported opType', () => {
    expect(estimateNodeFlops({ opType: 'Softmax', inputShapes: [[1, 1000]] })).toBe(-1);
    expect(estimateNodeFlops({ opType: 'Sigmoid', inputShapes: [[1, 64]] })).toBe(-1);
    // A missing nodeInfo has no opType, so it is reported as unsupported too.
    expect(estimateNodeFlops(null)).toBe(-1);
  });
  it('should return 0 for supported op with empty input shapes', () => {
    expect(estimateNodeFlops({ opType: 'Relu', inputShapes: [[]] })).toBe(0);
    expect(estimateNodeFlops({ opType: 'Conv', inputShapes: [[], []] })).toBe(0);
  });
});
describe('FlopsEstimator.formatFlops', () => {
  it('should format TFLOPs', () => {
    expect(formatFlops(1.5e12)).toBe('1.50 TFLOPs');
    expect(formatFlops(1e12)).toBe('1.00 TFLOPs');
  });
  it('should format GFLOPs', () => {
    expect(formatFlops(2.5e9)).toBe('2.50 GFLOPs');
    expect(formatFlops(1e9)).toBe('1.00 GFLOPs');
  });
  it('should format MFLOPs', () => {
    expect(formatFlops(500e6)).toBe('500.00 MFLOPs');
    expect(formatFlops(1e6)).toBe('1.00 MFLOPs');
  });
  it('should format KFLOPs', () => {
    expect(formatFlops(5000)).toBe('5.00 KFLOPs');
    expect(formatFlops(1000)).toBe('1.00 KFLOPs');
    // Just below the next unit boundary stays in KFLOPs.
    expect(formatFlops(999000)).toBe('999.00 KFLOPs');
  });
  it('should format plain FLOPs for small values', () => {
    expect(formatFlops(500)).toBe('500 FLOPs');
    // Largest value that is still printed unscaled.
    expect(formatFlops(999)).toBe('999 FLOPs');
    expect(formatFlops(0)).toBe('0 FLOPs');
  });
  it('should return N/A for negative or invalid values', () => {
    expect(formatFlops(-1)).toBe('N/A');
    expect(formatFlops(null)).toBe('N/A');
    expect(formatFlops(undefined)).toBe('N/A');
    expect(formatFlops(NaN)).toBe('N/A');
  });
});
describe('FlopsEstimator.compute', () => {
  it('should compute total FLOPs across all supported nodes', () => {
    const model = {
      inputs: [{ name: 'input', shape: [1, 3, 224, 224] }],
      outputs: [{ name: 'output', shape: [1, 1000] }],
      graph: {
        nodes: [
          { id: 'n0', name: 'conv1', opType: 'Conv', inputs: ['input', 'conv1_w'], outputs: ['conv1_out'] },
          { id: 'n1', name: 'relu1', opType: 'Relu', inputs: ['conv1_out'], outputs: ['relu1_out'] },
        ],
        edges: [],
        initializers: [
          { name: 'conv1_w', shape: [64, 3, 7, 7] }
        ],
        valueInfo: {
          'conv1_out': { name: 'conv1_out', shape: [1, 64, 224, 224] }
        }
      }
    };
    const result = compute(model);
    const convFlops = 2 * 64 * 224 * 224 * 3 * 7 * 7; // 944,111,616
    const reluFlops = 1 * 64 * 224 * 224;             //   3,211,264
    expect(result.totalFlops).toBe(convFlops + reluFlops);
    // 947,322,880 FLOPs -> 947.32 MFLOPs
    expect(result.formattedTotal).toBe('947.32 MFLOPs');
    expect(result.perNode).toHaveLength(2);
    // Both ops are supported, so each gets its own entry, largest first.
    expect(result.perOpType).toHaveLength(2);
    expect(result.perOpType.map((entry) => entry.opType)).toEqual(['Conv', 'Relu']);
    expect(result.perOpType[0].percentage + result.perOpType[1].percentage).toBeCloseTo(100);
    expect(result.unsupportedOps).toHaveLength(0);
  });
  it('should mark unsupported ops with N/A', () => {
    const model = {
      inputs: [{ name: 'x', shape: [1, 10] }],
      outputs: [],
      graph: {
        nodes: [
          { id: 'n0', name: 'softmax', opType: 'Softmax', inputs: ['x'], outputs: ['y'] }
        ],
        edges: [],
        initializers: [],
        valueInfo: {}
      }
    };
    const result = compute(model);
    expect(result.totalFlops).toBe(0);
    expect(result.unsupportedOps).toContain('Softmax');
    expect(result.perNode[0].formattedFlops).toBe('N/A');
    expect(result.perNode[0].flops).toBe(-1);
  });
  it('should produce perOpType breakdown sorted descending', () => {
    const model = {
      inputs: [{ name: 'x', shape: [1, 64, 56, 56] }],
      outputs: [],
      graph: {
        nodes: [
          { id: 'n0', name: 'relu1', opType: 'Relu', inputs: ['x'], outputs: ['r1'] },
          { id: 'n1', name: 'bn1', opType: 'BatchNormalization', inputs: ['r1'], outputs: ['b1'] },
        ],
        edges: [],
        initializers: [],
        valueInfo: {
          'r1': { name: 'r1', shape: [1, 64, 56, 56] }
        }
      }
    };
    const result = compute(model);
    expect(result.perOpType.length).toBe(2);
    // BN = 4 * elements > Relu = 1 * elements, so BN should be first
    expect(result.perOpType[0].opType).toBe('BatchNormalization');
    expect(result.perOpType[1].opType).toBe('Relu');
  });
  it('should handle null/undefined model gracefully', () => {
    const result = compute(null);
    expect(result.totalFlops).toBe(0);
    expect(result.perNode).toHaveLength(0);
    expect(result.perOpType).toHaveLength(0);
    expect(result.unsupportedOps).toHaveLength(0);
  });
  it('should look up shapes from initializers', () => {
    const model = {
      inputs: [],
      outputs: [],
      graph: {
        nodes: [
          { id: 'n0', name: 'matmul', opType: 'MatMul', inputs: ['a', 'b'], outputs: ['c'] }
        ],
        edges: [],
        initializers: [
          { name: 'a', shape: [4, 8] },
          { name: 'b', shape: [8, 16] }
        ],
        valueInfo: {}
      }
    };
    const result = compute(model);
    expect(result.totalFlops).toBe(2 * 4 * 16 * 8);
  });
});