/**
 * Unit tests for FlopsEstimator
 * Validates: Requirements 34.1, 34.2, 34.3, 34.4, 34.5, 34.6
 */
import { describe, it, expect, beforeEach } from 'vitest';

// ─── Mirror pure logic from FlopsEstimator for testability ──────────────
// The functions below re-implement FlopsEstimator's pure logic so it can be
// tested without loading the production module. Their observable behavior
// must stay identical to FlopsEstimator; otherwise these tests stop
// validating the real implementation.

/**
 * Count the elements of a tensor shape.
 *
 * Non-numeric entries (e.g. symbolic/unknown dims) contribute 0, and an
 * empty shape yields 0 rather than the scalar size of 1.
 *
 * @param {Array<number|string>} shape - Tensor dimensions.
 * @returns {number} Product of the numeric dimensions, or 0.
 */
function elementCount(shape) {
  if (shape.length === 0) return 0;
  let elements = 1;
  for (let i = 0; i < shape.length; i++) {
    elements *= typeof shape[i] === 'number' ? shape[i] : 0;
  }
  return elements;
}

/**
 * Build an element-wise FLOP formula: `opsPerElement` FLOPs for every
 * element of the node's first input.
 *
 * @param {number} opsPerElement - FLOPs charged per input element.
 * @returns {(attrs: object, inputShapes: Array<Array<number|string>>) => number}
 */
function elementwise(opsPerElement) {
  return function (attrs, inputShapes) {
    return opsPerElement * elementCount(inputShapes[0] || []);
  };
}

/**
 * FLOP formulas keyed by opType (ONNX-style operator names). Each takes
 * (attrs, inputShapes) and returns a FLOP count; missing dimensions default
 * to 0 so incomplete shape information yields 0 rather than NaN.
 */
const FORMULAS = {
  // NOTE(review): dims 2 and 3 of the *input* are used as the output spatial
  // size and `attrs` is never read, so stride/padding/dilation are not
  // reflected. The weight is read as [Cout, Cin/group, Kh, Kw].
  Conv(attrs, inputShapes) {
    const input = inputShapes[0] || [];
    const weight = inputShapes[1] || [];
    const cOut = weight[0] || 0;
    const cInPerGroup = weight[1] || 0;
    const kH = weight[2] || 0;
    const kW = weight[3] || 0;
    const hOut = input[2] || 0;
    const wOut = input[3] || 0;
    return 2 * cOut * hOut * wOut * cInPerGroup * kH * kW;
  },
  // Uses the last two dims of A ([..., M, K]) and the last dim of B
  // ([..., K, N]); leading batch dimensions are not multiplied in.
  MatMul(attrs, inputShapes) {
    const a = inputShapes[0] || [];
    const b = inputShapes[1] || [];
    const m = a[a.length - 2] || 0;
    const k = a[a.length - 1] || 0;
    const n = b[b.length - 1] || 0;
    return 2 * m * n * k;
  },
  // Reads A as [M, K] and B as [K, N]; `attrs` (e.g. transpose flags) are
  // not consulted. The trailing M * N term presumably covers the bias add.
  Gemm(attrs, inputShapes) {
    const a = inputShapes[0] || [];
    const b = inputShapes[1] || [];
    const m = a[0] || 0;
    const k = a[1] || 0;
    const n = b[1] || 0;
    return 2 * m * n * k + m * n;
  },
  BatchNormalization: elementwise(4),
  Relu: elementwise(1),
  // NOTE(review): Add and Mul count only the first input's shape, so a
  // broadcast second operand larger than the first would be undercounted.
  Add: elementwise(1),
  Mul: elementwise(1),
};

/**
 * Estimate the FLOPs of a single node.
 *
 * @param {{opType?: string, attributes?: object, inputShapes?: Array<Array<number|string>>}} nodeInfo
 * @returns {number} -1 when the opType has no formula; otherwise the
 *   formula's result, or 0 if that result is not a finite, non-negative number.
 */
function estimateNodeFlops(nodeInfo) {
  // NOTE(review): FORMULAS is a plain object, so inherited keys such as
  // 'toString' also resolve to a function here; kept to mirror production.
  const formula = FORMULAS[nodeInfo?.opType];
  if (!formula) return -1;
  const attrs = nodeInfo?.attributes || {};
  const inputShapes = nodeInfo?.inputShapes || [];
  const flops = formula(attrs, inputShapes);
  const isValid = typeof flops === 'number' && isFinite(flops) && flops >= 0;
  return isValid ? flops : 0;
}

// Unit thresholds for formatFlops, checked from largest to smallest.
const FLOP_UNITS = [
  [1e12, 'TFLOPs'],
  [1e9, 'GFLOPs'],
  [1e6, 'MFLOPs'],
  [1e3, 'KFLOPs'],
];

/**
 * Format a FLOP count with a metric suffix and two decimals
 * (values below 1000 are printed as-is).
 *
 * @param {number} flops
 * @returns {string} e.g. '1.50 TFLOPs', '500 FLOPs', or 'N/A' for
 *   null/undefined, NaN, or negative input.
 */
function formatFlops(flops) {
  // The global isNaN coerces its argument, so non-numeric values also map to 'N/A'.
  if (flops == null || isNaN(flops) || flops < 0) return 'N/A';
  for (const [scale, unit] of FLOP_UNITS) {
    if (flops >= scale) return `${(flops / scale).toFixed(2)} ${unit}`;
  }
  return `${flops} FLOPs`;
}

/**
 * Map tensor names to shapes from every shape source in a parsed model.
 * Later sources overwrite earlier ones: inputs, then outputs, then
 * graph.valueInfo, then graph.initializers.
 *
 * @param {object} parsedModel
 * @returns {Object<string, Array<number|string>>}
 */
function buildShapeMap(parsedModel) {
  const shapeMap = {};
  const graph = parsedModel?.graph;

  // Record one { name, shape } entry when both fields are present.
  const record = (entry) => {
    if (entry.name && entry.shape) shapeMap[entry.name] = entry.shape;
  };
  // Record every entry of an array-like list in order.
  const recordAll = (entries) => {
    for (let i = 0; i < entries.length; i++) record(entries[i]);
  };

  recordAll(parsedModel?.inputs || []);
  recordAll(parsedModel?.outputs || []);
  // valueInfo is a name-keyed object whose values may be null.
  for (const vi of Object.values(graph?.valueInfo || {})) {
    if (vi) record(vi);
  }
  recordAll(graph?.initializers || []);
  return shapeMap;
}

/**
 * Estimate FLOPs for every node of a parsed model.
 *
 * @param {object} parsedModel
 * @returns {{
 *   totalFlops: number,
 *   formattedTotal: string,
 *   perNode: Array<{nodeId: *, nodeName: *, opType: string, flops: number, formattedFlops: string}>,
 *   perOpType: Array<{opType: string, totalFlops: number, formattedFlops: string, count: number, percentage: number}>,
 *   unsupportedOps: string[]
 * }} perOpType is sorted by totalFlops, descending; unsupported nodes are
 *   reported with flops -1 and excluded from totals and perOpType.
 */
function compute(parsedModel) {
  const nodes = parsedModel?.graph?.nodes || [];
  const shapeMap = buildShapeMap(parsedModel);
  let totalFlops = 0;
  const perNode = [];
  const opTypeMap = {};
  const unsupportedSet = {};

  for (let i = 0; i < nodes.length; i++) {
    const node = nodes[i];
    const opType = node.opType || '';
    const nodeInputs = node.inputs || [];
    // Unknown input names resolve to an empty shape.
    const inputShapes = [];
    for (let j = 0; j < nodeInputs.length; j++) {
      inputShapes.push(shapeMap[nodeInputs[j]] || []);
    }

    const flops = estimateNodeFlops({ opType, attributes: node.attributes || {}, inputShapes });
    if (flops < 0) {
      unsupportedSet[opType] = true;
      perNode.push({ nodeId: node.id, nodeName: node.name, opType, flops: -1, formattedFlops: 'N/A' });
      continue;
    }

    totalFlops += flops;
    perNode.push({ nodeId: node.id, nodeName: node.name, opType, flops, formattedFlops: formatFlops(flops) });
    if (!opTypeMap[opType]) opTypeMap[opType] = { totalFlops: 0, count: 0 };
    opTypeMap[opType].totalFlops += flops;
    opTypeMap[opType].count += 1;
  }

  // Percentages need the final total, so they are computed after the loop.
  const perOpType = Object.keys(opTypeMap).map((ot) => {
    const entry = opTypeMap[ot];
    return {
      opType: ot,
      totalFlops: entry.totalFlops,
      formattedFlops: formatFlops(entry.totalFlops),
      count: entry.count,
      percentage: totalFlops > 0 ? (entry.totalFlops / totalFlops) * 100 : 0,
    };
  });
  perOpType.sort((a, b) => b.totalFlops - a.totalFlops);

  return {
    totalFlops,
    formattedTotal: formatFlops(totalFlops),
    perNode,
    perOpType,
    unsupportedOps: Object.keys(unsupportedSet),
  };
}

// ─── Tests ──────────────────────────────────────────────────────────────────

describe('FlopsEstimator.estimateNodeFlops', () => {
  it('should compute Conv FLOPs: 2 * Cout * Hout * Wout * CinPerGroup * Kh * Kw', () => {
    // input: [1, 3, 224, 224], weight: [64, 3, 7, 7]
    const flops = estimateNodeFlops({
      opType: 'Conv',
      attributes: {},
      inputShapes: [[1, 3, 224, 224], [64, 3, 7, 7]],
    });
    // 2 * 64 * 224 * 224 * 3 * 7 * 7 = 2 * 64 * 50176 * 147 = 944,111,616
    expect(flops).toBe(2 * 64 * 224 * 224 * 3 * 7 * 7);
    expect(flops).toBe(944111616);
  });

  it('should compute MatMul FLOPs: 2 * M * N * K', () => {
    // A: [32, 128], B: [128, 64]
    const flops = estimateNodeFlops({ opType: 'MatMul', attributes: {}, inputShapes: [[32, 128], [128, 64]] });
    expect(flops).toBe(2 * 32 * 64 * 128);
  });

  it('should compute Gemm FLOPs: 2 * M * N * K + M * N', () => {
    // A: [10, 20], B: [20, 30]
    const flops = estimateNodeFlops({ opType: 'Gemm', attributes: {}, inputShapes: [[10, 20], [20, 30]] });
    expect(flops).toBe(2 * 10 * 30 * 20 + 10 * 30);
  });

  it('should compute BatchNormalization FLOPs: 4 * elements', () => {
    const flops = estimateNodeFlops({ opType: 'BatchNormalization', attributes: {}, inputShapes: [[1, 64, 56, 56]] });
    expect(flops).toBe(4 * 1 * 64 * 56 * 56);
  });

  it('should compute Relu FLOPs: elements', () => {
    const flops = estimateNodeFlops({ opType: 'Relu', attributes: {}, inputShapes: [[1, 64, 56, 56]] });
    expect(flops).toBe(1 * 64 * 56 * 56);
  });

  it('should compute Add FLOPs: elements', () => {
    const flops = estimateNodeFlops({ opType: 'Add', attributes: {}, inputShapes: [[1, 256, 14, 14]] });
    expect(flops).toBe(1 * 256 * 14 * 14);
  });

  it('should compute Mul FLOPs: elements', () => {
    const flops = estimateNodeFlops({ opType: 'Mul', attributes: {}, inputShapes: [[2, 128, 28, 28]] });
    expect(flops).toBe(2 * 128 * 28 * 28);
  });

  it('should return -1 for unsupported opType', () => {
    expect(estimateNodeFlops({ opType: 'Softmax', inputShapes: [[1, 1000]] })).toBe(-1);
    expect(estimateNodeFlops({ opType: 'Sigmoid', inputShapes: [[1, 64]] })).toBe(-1);
  });

  it('should return -1 for missing node info or opType', () => {
    expect(estimateNodeFlops(null)).toBe(-1);
    expect(estimateNodeFlops({})).toBe(-1);
  });

  it('should return 0 for supported op with empty input shapes', () => {
    expect(estimateNodeFlops({ opType: 'Relu', inputShapes: [[]] })).toBe(0);
    expect(estimateNodeFlops({ opType: 'Conv', inputShapes: [[], []] })).toBe(0);
  });
});

describe('FlopsEstimator.formatFlops', () => {
  it('should format TFLOPs', () => {
    expect(formatFlops(1.5e12)).toBe('1.50 TFLOPs');
    expect(formatFlops(1e12)).toBe('1.00 TFLOPs');
  });

  it('should format GFLOPs', () => {
    expect(formatFlops(2.5e9)).toBe('2.50 GFLOPs');
    expect(formatFlops(1e9)).toBe('1.00 GFLOPs');
  });

  it('should format MFLOPs', () => {
    expect(formatFlops(500e6)).toBe('500.00 MFLOPs');
    expect(formatFlops(1e6)).toBe('1.00 MFLOPs');
  });

  it('should format KFLOPs', () => {
    expect(formatFlops(5000)).toBe('5.00 KFLOPs');
    expect(formatFlops(1000)).toBe('1.00 KFLOPs');
  });

  it('should format plain FLOPs for small values', () => {
    expect(formatFlops(500)).toBe('500 FLOPs');
    expect(formatFlops(999)).toBe('999 FLOPs');
    expect(formatFlops(0)).toBe('0 FLOPs');
  });

  it('should return N/A for negative or invalid values', () => {
    expect(formatFlops(-1)).toBe('N/A');
    expect(formatFlops(null)).toBe('N/A');
    expect(formatFlops(undefined)).toBe('N/A');
    expect(formatFlops(NaN)).toBe('N/A');
  });
});

describe('FlopsEstimator.compute', () => {
  it('should compute total FLOPs across all supported nodes', () => {
    const model = {
      inputs: [{ name: 'input', shape: [1, 3, 224, 224] }],
      outputs: [{ name: 'output', shape: [1, 1000] }],
      graph: {
        nodes: [
          { id: 'n0', name: 'conv1', opType: 'Conv', inputs: ['input', 'conv1_w'], outputs: ['conv1_out'] },
          { id: 'n1', name: 'relu1', opType: 'Relu', inputs: ['conv1_out'], outputs: ['relu1_out'] },
        ],
        edges: [],
        initializers: [{ name: 'conv1_w', shape: [64, 3, 7, 7] }],
        valueInfo: { conv1_out: { name: 'conv1_out', shape: [1, 64, 224, 224] } },
      },
    };
    const result = compute(model);
    const convFlops = 2 * 64 * 224 * 224 * 3 * 7 * 7; // 944,111,616
    const reluFlops = 1 * 64 * 224 * 224; // 3,211,264
    expect(result.totalFlops).toBe(convFlops + reluFlops);
    expect(result.formattedTotal).toBe('947.32 MFLOPs');
    expect(result.perNode).toHaveLength(2);
    expect(result.perNode.map((n) => n.formattedFlops)).toEqual(['944.11 MFLOPs', '3.21 MFLOPs']);
    expect(result.perOpType.map((e) => e.opType)).toEqual(['Conv', 'Relu']);
    const percentageSum = result.perOpType.reduce((sum, e) => sum + e.percentage, 0);
    expect(percentageSum).toBeCloseTo(100);
    expect(result.unsupportedOps).toEqual([]);
  });

  it('should mark unsupported ops with N/A', () => {
    const model = {
      inputs: [{ name: 'x', shape: [1, 10] }],
      outputs: [],
      graph: {
        nodes: [{ id: 'n0', name: 'softmax', opType: 'Softmax', inputs: ['x'], outputs: ['y'] }],
        edges: [],
        initializers: [],
        valueInfo: {},
      },
    };
    const result = compute(model);
    expect(result.totalFlops).toBe(0);
    expect(result.unsupportedOps).toContain('Softmax');
    expect(result.perNode[0].formattedFlops).toBe('N/A');
    expect(result.perNode[0].flops).toBe(-1);
    expect(result.perOpType).toHaveLength(0);
  });

  it('should exclude unsupported ops from totals in a mixed graph', () => {
    const model = {
      inputs: [{ name: 'x', shape: [1, 10] }],
      outputs: [],
      graph: {
        nodes: [
          { id: 'n0', name: 'relu', opType: 'Relu', inputs: ['x'], outputs: ['r'] },
          { id: 'n1', name: 'softmax', opType: 'Softmax', inputs: ['r'], outputs: ['y'] },
        ],
        edges: [],
        initializers: [],
        valueInfo: {},
      },
    };
    const result = compute(model);
    expect(result.totalFlops).toBe(10);
    expect(result.formattedTotal).toBe('10 FLOPs');
    expect(result.perNode).toHaveLength(2);
    expect(result.perNode[1].flops).toBe(-1);
    expect(result.perOpType).toHaveLength(1);
    expect(result.perOpType[0].opType).toBe('Relu');
    expect(result.perOpType[0].percentage).toBe(100);
    expect(result.unsupportedOps).toEqual(['Softmax']);
  });

  it('should produce perOpType breakdown sorted descending', () => {
    const model = {
      inputs: [{ name: 'x', shape: [1, 64, 56, 56] }],
      outputs: [],
      graph: {
        nodes: [
          { id: 'n0', name: 'relu1', opType: 'Relu', inputs: ['x'], outputs: ['r1'] },
          { id: 'n1', name: 'bn1', opType: 'BatchNormalization', inputs: ['r1'], outputs: ['b1'] },
        ],
        edges: [],
        initializers: [],
        valueInfo: { r1: { name: 'r1', shape: [1, 64, 56, 56] } },
      },
    };
    const result = compute(model);
    expect(result.perOpType.length).toBe(2);
    // BN = 4 * elements > Relu = 1 * elements, so BN should be first
    expect(result.perOpType[0].opType).toBe('BatchNormalization');
    expect(result.perOpType[1].opType).toBe('Relu');
  });

  it('should handle null/undefined model gracefully', () => {
    const result = compute(null);
    expect(result.totalFlops).toBe(0);
    expect(result.perNode).toHaveLength(0);
    expect(result.perOpType).toHaveLength(0);
    expect(result.unsupportedOps).toHaveLength(0);
  });

  it('should look up shapes from initializers', () => {
    const model = {
      inputs: [],
      outputs: [],
      graph: {
        nodes: [{ id: 'n0', name: 'matmul', opType: 'MatMul', inputs: ['a', 'b'], outputs: ['c'] }],
        edges: [],
        initializers: [
          { name: 'a', shape: [4, 8] },
          { name: 'b', shape: [8, 16] },
        ],
        valueInfo: {},
      },
    };
    const result = compute(model);
    expect(result.totalFlops).toBe(2 * 4 * 16 * 8);
  });
});