Spaces:
Running
Running
/**
 * Unit tests for FlopsEstimator (estimateNodeFlops, formatFlops, compute).
 * Validates: Requirements 34.1, 34.2, 34.3, 34.4, 34.5, 34.6
 */
| import { describe, it, expect, beforeEach } from 'vitest'; | |
// ─── Mirror of FlopsEstimator's pure logic (keep in sync with the production module) ───
/**
 * Read one dimension of a shape, treating a missing or falsy entry as 0.
 *
 * @param {Array} shape - Tensor shape (list of dimensions).
 * @param {number} index - Position to read; out-of-range indices yield 0.
 * @returns {*} The stored dimension, or 0 when it is absent/falsy.
 */
function dimAt(shape, index) {
  return shape[index] || 0;
}

/**
 * Multiply all dimensions of a shape into an element count.
 * Non-numeric (e.g. symbolic) dimensions contribute 0, and an empty shape
 * yields 0 rather than 1.
 *
 * @param {Array} shape - Tensor shape.
 * @returns {number} Element count used by the element-wise formulas.
 */
function countElements(shape) {
  if (shape.length === 0) return 0;
  let product = 1;
  for (let idx = 0; idx < shape.length; idx += 1) {
    const dim = shape[idx];
    product *= typeof dim === 'number' ? dim : 0;
  }
  return product;
}

// One FLOP per element of the first input; shared by Relu, Add and Mul.
const elementwiseFlops = (attrs, inputShapes) => countElements(inputShapes[0] || []);

/**
 * FLOP formulas keyed by ONNX opType. Each formula takes (attrs, inputShapes)
 * and returns a raw count; none of them currently reads attrs.
 */
const FORMULAS = {
  // 2 * Cout * H * W * CinPerGroup * Kh * Kw. H and W are taken from the
  // input shape, so stride/padding/dilation are not accounted for.
  Conv(attrs, inputShapes) {
    const x = inputShapes[0] || [];
    const w = inputShapes[1] || [];
    return 2 * dimAt(w, 0) * dimAt(x, 2) * dimAt(x, 3) * dimAt(w, 1) * dimAt(w, 2) * dimAt(w, 3);
  },

  // 2 * M * N * K using the last two dims of A and the last dim of B;
  // any leading (batch) dimensions are ignored.
  MatMul(attrs, inputShapes) {
    const a = inputShapes[0] || [];
    const b = inputShapes[1] || [];
    const rows = dimAt(a, a.length - 2);
    const inner = dimAt(a, a.length - 1);
    const cols = dimAt(b, b.length - 1);
    return 2 * rows * cols * inner;
  },

  // 2 * M * N * K plus M * N (presumably the C/bias add). A is read as
  // [M, K] and B as [K, N]; transA/transB attributes are not consulted.
  Gemm(attrs, inputShapes) {
    const a = inputShapes[0] || [];
    const b = inputShapes[1] || [];
    const rows = dimAt(a, 0);
    const inner = dimAt(a, 1);
    const cols = dimAt(b, 1);
    return 2 * rows * cols * inner + rows * cols;
  },

  // 4 FLOPs per element of the first input.
  BatchNormalization(attrs, inputShapes) {
    return 4 * countElements(inputShapes[0] || []);
  },

  Relu: elementwiseFlops,
  Add: elementwiseFlops,
  Mul: elementwiseFlops,
};
/**
 * Estimate the FLOPs of a single graph node.
 *
 * Only opTypes that are own entries of FORMULAS are supported. Names that
 * exist only on Object.prototype ("constructor", "toString",
 * "hasOwnProperty", "__proto__", ...) are reported as unsupported. Without
 * this check they would be invoked as formulas, and some of them throw.
 *
 * @param {{opType?: string, attributes?: object, inputShapes?: Array}} nodeInfo
 *   Node description; may be null/undefined.
 * @returns {number} -1 for an unsupported or missing opType. Otherwise the
 *   formula's FLOP count, with NaN, infinite, negative or non-numeric results
 *   clamped to 0.
 */
function estimateNodeFlops(nodeInfo) {
  const opType = nodeInfo && nodeInfo.opType;
  if (!Object.prototype.hasOwnProperty.call(FORMULAS, opType)) return -1;
  const formula = FORMULAS[opType];
  if (typeof formula !== 'function') return -1;

  const attrs = (nodeInfo && nodeInfo.attributes) || {};
  const inputShapes = (nodeInfo && nodeInfo.inputShapes) || [];
  const flops = formula(attrs, inputShapes);
  // Symbolic dimensions make formulas return NaN; never propagate that.
  return Number.isFinite(flops) && flops >= 0 ? flops : 0;
}
// Display units, largest first: [divisor, suffix].
const FLOP_UNITS = Object.freeze([
  [1e12, 'TFLOPs'],
  [1e9, 'GFLOPs'],
  [1e6, 'MFLOPs'],
  [1e3, 'KFLOPs'],
]);

/**
 * Format a FLOP count as a human-readable string.
 *
 * Values >= 1e3 get two decimals and a KFLOPs/MFLOPs/GFLOPs/TFLOPs suffix.
 * Smaller values are printed as-is with a "FLOPs" suffix. The unit is chosen
 * so that rounding never shows "1000.00" of a smaller unit: 999999 formats as
 * "1.00 MFLOPs", not "1000.00 KFLOPs".
 *
 * @param {number} flops - FLOP count.
 * @returns {string} The formatted count, or 'N/A' for null/undefined, NaN,
 *   infinite or negative input.
 */
function formatFlops(flops) {
  if (flops == null) return 'N/A';
  const value = Number(flops);
  if (!Number.isFinite(value) || value < 0) return 'N/A';

  for (let i = 0; i < FLOP_UNITS.length; i += 1) {
    const [scale, suffix] = FLOP_UNITS[i];
    if (value >= scale) {
      const text = (value / scale).toFixed(2);
      // toFixed can round e.g. 999.999 up to "1000.00"; promote to the next
      // larger unit when there is one.
      if (i > 0 && Number(text) >= 1000) {
        const [largerScale, largerSuffix] = FLOP_UNITS[i - 1];
        return `${(value / largerScale).toFixed(2)} ${largerSuffix}`;
      }
      return `${text} ${suffix}`;
    }
  }
  return `${value} FLOPs`;
}
/**
 * Build a tensor-name -> shape lookup from a parsed model.
 *
 * Sources are applied in this order, and a later source overrides an earlier
 * one for the same name:
 *   1. parsedModel.inputs
 *   2. parsedModel.outputs
 *   3. the values of parsedModel.graph.valueInfo
 *   4. parsedModel.graph.initializers
 * Entries need a truthy name and shape; null/undefined entries are skipped.
 *
 * The map is built with Object.fromEntries, so a tensor literally named
 * "__proto__" becomes an ordinary own property. With plain assignment it
 * would replace the returned object's prototype.
 *
 * @param {object|null|undefined} parsedModel
 * @returns {Object<string, Array>} Plain object mapping tensor names to shapes.
 */
function buildShapeMap(parsedModel) {
  const graph = (parsedModel && parsedModel.graph) || {};
  const sources = [
    (parsedModel && parsedModel.inputs) || [],
    (parsedModel && parsedModel.outputs) || [],
    Object.values(graph.valueInfo || {}),
    graph.initializers || [],
  ];

  const entries = [];
  for (const list of sources) {
    for (const info of list) {
      if (info && info.name && info.shape) entries.push([info.name, info.shape]);
    }
  }
  return Object.fromEntries(entries);
}
/**
 * Estimate FLOPs for every node of a parsed model and aggregate the results.
 *
 * Shape lookup:
 *   - Node input names are resolved to shapes through buildShapeMap.
 *   - Names with no own entry, including names like "toString" that only
 *     exist on Object.prototype, resolve to an empty shape.
 *
 * Unsupported ops and skipped nodes:
 *   - Nodes for which estimateNodeFlops returns a negative value are listed
 *     with flops -1 / 'N/A' and excluded from all totals.
 *   - null/undefined node entries are skipped.
 *
 * Per-opType totals use a Map, and unsupported opTypes a Set. Arbitrary
 * opType strings from the model therefore cannot read or write
 * Object.prototype members.
 *
 * @param {object|null|undefined} parsedModel
 * @returns {{totalFlops: number, formattedTotal: string, perNode: Array,
 *   perOpType: Array, unsupportedOps: string[]}}
 *   perOpType is sorted by totalFlops in descending order. Each entry's
 *   percentage is its share of totalFlops (0 when totalFlops is 0).
 */
function compute(parsedModel) {
  const nodes = (parsedModel && parsedModel.graph && parsedModel.graph.nodes) || [];
  const shapeMap = buildShapeMap(parsedModel);
  const hasOwn = (obj, key) => Object.prototype.hasOwnProperty.call(obj, key);

  let totalFlops = 0;
  const perNode = [];
  const opTypeTotals = new Map();
  const unsupported = new Set();

  for (const node of nodes) {
    if (!node) continue; // tolerate null/undefined entries and holes
    const opType = node.opType || '';
    // Array.from (not .map) so holes in node.inputs still yield [].
    const inputShapes = Array.from(node.inputs || [], (name) =>
      (hasOwn(shapeMap, name) && shapeMap[name]) || []);
    const flops = estimateNodeFlops({ opType, attributes: node.attributes || {}, inputShapes });

    if (flops < 0) {
      unsupported.add(opType);
      perNode.push({ nodeId: node.id, nodeName: node.name, opType, flops: -1, formattedFlops: 'N/A' });
      continue;
    }

    totalFlops += flops;
    perNode.push({ nodeId: node.id, nodeName: node.name, opType, flops, formattedFlops: formatFlops(flops) });
    const entry = opTypeTotals.get(opType) || { totalFlops: 0, count: 0 };
    entry.totalFlops += flops;
    entry.count += 1;
    opTypeTotals.set(opType, entry);
  }

  const perOpType = Array.from(opTypeTotals, ([opType, entry]) => ({
    opType,
    totalFlops: entry.totalFlops,
    formattedFlops: formatFlops(entry.totalFlops),
    count: entry.count,
    percentage: totalFlops > 0 ? (entry.totalFlops / totalFlops) * 100 : 0,
  }));
  perOpType.sort((a, b) => b.totalFlops - a.totalFlops);

  return {
    totalFlops,
    formattedTotal: formatFlops(totalFlops),
    perNode,
    perOpType,
    unsupportedOps: Array.from(unsupported),
  };
}
// ─── Tests ──────────────────────────────────────────────────────────────────
describe('FlopsEstimator.estimateNodeFlops', () => {
  it('should compute Conv FLOPs: 2 * Cout * Hout * Wout * CinPerGroup * Kh * Kw', () => {
    // input: [1, 3, 224, 224], weight: [64, 3, 7, 7]
    const flops = estimateNodeFlops({
      opType: 'Conv',
      attributes: {},
      inputShapes: [[1, 3, 224, 224], [64, 3, 7, 7]]
    });
    // 2 * 64 * 224 * 224 * 3 * 7 * 7 = 2 * 64 * 50176 * 147 = 944,111,616
    expect(flops).toBe(2 * 64 * 224 * 224 * 3 * 7 * 7);
    expect(flops).toBe(944111616);
  });

  it('should compute MatMul FLOPs: 2 * M * N * K', () => {
    // A: [32, 128], B: [128, 64]
    const flops = estimateNodeFlops({
      opType: 'MatMul',
      attributes: {},
      inputShapes: [[32, 128], [128, 64]]
    });
    expect(flops).toBe(2 * 32 * 64 * 128);
  });

  it('should compute Gemm FLOPs: 2 * M * N * K + M * N', () => {
    // A: [10, 20], B: [20, 30]
    const flops = estimateNodeFlops({
      opType: 'Gemm',
      attributes: {},
      inputShapes: [[10, 20], [20, 30]]
    });
    expect(flops).toBe(2 * 10 * 30 * 20 + 10 * 30);
  });

  it('should compute BatchNormalization FLOPs: 4 * elements', () => {
    const flops = estimateNodeFlops({
      opType: 'BatchNormalization',
      attributes: {},
      inputShapes: [[1, 64, 56, 56]]
    });
    expect(flops).toBe(4 * 1 * 64 * 56 * 56);
  });

  it('should compute Relu FLOPs: elements', () => {
    const flops = estimateNodeFlops({
      opType: 'Relu',
      attributes: {},
      inputShapes: [[1, 64, 56, 56]]
    });
    expect(flops).toBe(1 * 64 * 56 * 56);
  });

  it('should compute Add FLOPs: elements', () => {
    const flops = estimateNodeFlops({
      opType: 'Add',
      attributes: {},
      inputShapes: [[1, 256, 14, 14]]
    });
    expect(flops).toBe(1 * 256 * 14 * 14);
  });

  it('should compute Mul FLOPs: elements', () => {
    const flops = estimateNodeFlops({
      opType: 'Mul',
      attributes: {},
      inputShapes: [[2, 128, 28, 28]]
    });
    expect(flops).toBe(2 * 128 * 28 * 28);
  });

  it('should return -1 for unsupported opType', () => {
    expect(estimateNodeFlops({ opType: 'Softmax', inputShapes: [[1, 1000]] })).toBe(-1);
    expect(estimateNodeFlops({ opType: 'Sigmoid', inputShapes: [[1, 64]] })).toBe(-1);
  });

  it('should return -1 when nodeInfo is missing', () => {
    expect(estimateNodeFlops(null)).toBe(-1);
    expect(estimateNodeFlops(undefined)).toBe(-1);
  });

  it('should return 0 for supported op with empty input shapes', () => {
    expect(estimateNodeFlops({ opType: 'Relu', inputShapes: [[]] })).toBe(0);
    expect(estimateNodeFlops({ opType: 'Conv', inputShapes: [[], []] })).toBe(0);
  });
});
describe('FlopsEstimator.formatFlops', () => {
  // Assert each [input, expectedText] pair in order.
  const expectFormats = (cases) => {
    for (const [input, expected] of cases) {
      expect(formatFlops(input)).toBe(expected);
    }
  };

  it('should format TFLOPs', () => {
    expectFormats([
      [1.5e12, '1.50 TFLOPs'],
      [1e12, '1.00 TFLOPs'],
    ]);
  });

  it('should format GFLOPs', () => {
    expectFormats([
      [2.5e9, '2.50 GFLOPs'],
      [1e9, '1.00 GFLOPs'],
    ]);
  });

  it('should format MFLOPs', () => {
    expectFormats([
      [500e6, '500.00 MFLOPs'],
      [1e6, '1.00 MFLOPs'],
    ]);
  });

  it('should format KFLOPs', () => {
    expectFormats([
      [5000, '5.00 KFLOPs'],
      [1000, '1.00 KFLOPs'],
    ]);
  });

  it('should format plain FLOPs for small values', () => {
    expectFormats([
      [500, '500 FLOPs'],
      [0, '0 FLOPs'],
    ]);
  });

  it('should return N/A for negative or invalid values', () => {
    expectFormats([
      [-1, 'N/A'],
      [null, 'N/A'],
      [NaN, 'N/A'],
    ]);
  });
});
describe('FlopsEstimator.compute', () => {
  it('should compute total FLOPs across all supported nodes', () => {
    const model = {
      inputs: [{ name: 'input', shape: [1, 3, 224, 224] }],
      outputs: [{ name: 'output', shape: [1, 1000] }],
      graph: {
        nodes: [
          { id: 'n0', name: 'conv1', opType: 'Conv', inputs: ['input', 'conv1_w'], outputs: ['conv1_out'] },
          { id: 'n1', name: 'relu1', opType: 'Relu', inputs: ['conv1_out'], outputs: ['relu1_out'] },
        ],
        edges: [],
        initializers: [
          { name: 'conv1_w', shape: [64, 3, 7, 7] }
        ],
        valueInfo: {
          'conv1_out': { name: 'conv1_out', shape: [1, 64, 224, 224] }
        }
      }
    };
    const result = compute(model);
    const convFlops = 2 * 64 * 224 * 224 * 3 * 7 * 7;
    const reluFlops = 1 * 64 * 224 * 224;
    expect(result.totalFlops).toBe(convFlops + reluFlops);
    expect(result.perNode).toHaveLength(2);
    // Exact per-op breakdown: one entry per op type, Conv dominating.
    expect(result.perOpType.map((entry) => entry.opType)).toEqual(['Conv', 'Relu']);
    expect(result.perOpType[0].totalFlops).toBe(convFlops);
    expect(result.perOpType[1].totalFlops).toBe(reluFlops);
    expect(result.perOpType[0].percentage + result.perOpType[1].percentage).toBeCloseTo(100);
    expect(result.unsupportedOps).toHaveLength(0);
  });

  it('should mark unsupported ops with N/A', () => {
    const model = {
      inputs: [{ name: 'x', shape: [1, 10] }],
      outputs: [],
      graph: {
        nodes: [
          { id: 'n0', name: 'softmax', opType: 'Softmax', inputs: ['x'], outputs: ['y'] }
        ],
        edges: [],
        initializers: [],
        valueInfo: {}
      }
    };
    const result = compute(model);
    expect(result.totalFlops).toBe(0);
    expect(result.unsupportedOps).toContain('Softmax');
    expect(result.perNode[0].formattedFlops).toBe('N/A');
    expect(result.perNode[0].flops).toBe(-1);
    // Unsupported nodes must not appear in the per-op breakdown.
    expect(result.perOpType).toHaveLength(0);
  });

  it('should produce perOpType breakdown sorted descending', () => {
    const model = {
      inputs: [{ name: 'x', shape: [1, 64, 56, 56] }],
      outputs: [],
      graph: {
        nodes: [
          { id: 'n0', name: 'relu1', opType: 'Relu', inputs: ['x'], outputs: ['r1'] },
          { id: 'n1', name: 'bn1', opType: 'BatchNormalization', inputs: ['r1'], outputs: ['b1'] },
        ],
        edges: [],
        initializers: [],
        valueInfo: {
          'r1': { name: 'r1', shape: [1, 64, 56, 56] }
        }
      }
    };
    const result = compute(model);
    expect(result.perOpType.length).toBe(2);
    // BN = 4 * elements > Relu = 1 * elements, so BN should be first
    expect(result.perOpType[0].opType).toBe('BatchNormalization');
    expect(result.perOpType[1].opType).toBe('Relu');
  });

  it('should handle null/undefined model gracefully', () => {
    const result = compute(null);
    expect(result.totalFlops).toBe(0);
    expect(result.perNode).toHaveLength(0);
    expect(result.perOpType).toHaveLength(0);
    expect(result.unsupportedOps).toHaveLength(0);
  });

  it('should look up shapes from initializers', () => {
    const model = {
      inputs: [],
      outputs: [],
      graph: {
        nodes: [
          { id: 'n0', name: 'matmul', opType: 'MatMul', inputs: ['a', 'b'], outputs: ['c'] }
        ],
        edges: [],
        initializers: [
          { name: 'a', shape: [4, 8] },
          { name: 'b', shape: [8, 16] }
        ],
        valueInfo: {}
      }
    };
    const result = compute(model);
    expect(result.totalFlops).toBe(2 * 4 * 16 * 8);
  });
});