/**
 * Unit tests for TFLiteParser
 * Validates: Requirements 2.1, 2.2, 2.3, 2.4, 2.5, 2.6
 */
import { describe, it, expect } from 'vitest';

// ─── Lookup tables (subset for testing) ─────────────────────────────────

/** Builtin operator code → operator name. Only a subset of codes is listed. */
const BUILTIN_OPERATORS = {
  0: 'ADD',
  1: 'AVERAGE_POOL_2D',
  2: 'CONCATENATION',
  3: 'CONV_2D',
  4: 'DEPTHWISE_CONV_2D',
  9: 'FULLY_CONNECTED',
  25: 'SOFTMAX',
  32: 'CUSTOM',
};

/** Tensor type code → dtype name. */
const TENSOR_TYPES = {
  0: 'FLOAT32',
  1: 'FLOAT16',
  2: 'INT32',
  3: 'UINT8',
  4: 'INT64',
  5: 'STRING',
  6: 'BOOL',
  7: 'INT16',
  8: 'COMPLEX64',
  9: 'INT8',
  10: 'FLOAT64',
  11: 'COMPLEX128',
  12: 'UINT64',
  13: 'UINT32',
  14: 'UINT16',
  15: 'INT4',
  16: 'BFLOAT16',
};

/** Dtype name → storage size of one element in bytes (INT4 is half a byte). */
const BYTES_PER_ELEMENT = {
  FLOAT32: 4,
  FLOAT16: 2,
  INT32: 4,
  UINT8: 1,
  INT64: 8,
  STRING: 1,
  BOOL: 1,
  INT16: 2,
  COMPLEX64: 8,
  INT8: 1,
  FLOAT64: 8,
  COMPLEX128: 16,
  UINT64: 8,
  UINT32: 4,
  UINT16: 2,
  INT4: 0.5,
  BFLOAT16: 2,
};

// ─── FlatBuffer Binary Builder ──────────────────────────────────────────
// Builds valid FlatBuffer binaries using a deferred-offset approach.
// All uoffset_t values point forward (to higher addresses).
//
// Strategy: Write tables with placeholder uoffset fields, then write
// the referenced data AFTER the table, then patch the offsets.
/**
 * Minimal little-endian FlatBuffer writer used to synthesize test fixtures.
 *
 * Data is appended at `pos` inside a fixed 64 KiB ArrayBuffer. Offset fields
 * are first written as 0 placeholders and later fixed up with `patch()` once
 * the referenced data has been written at a higher address.
 */
class FB {
  constructor() {
    this.buf = new ArrayBuffer(65536);
    this.view = new DataView(this.buf);
    this.pos = 0;
  }

  /** Advance `pos` to the next multiple of `n` (no-op when already aligned). */
  align(n) {
    const r = this.pos % n;
    if (r) this.pos += n - r;
  }

  /** Write an unsigned byte `v` at absolute position `p`. */
  u8(p, v) {
    this.view.setUint8(p, v);
  }

  /** Write a little-endian uint16 `v` at absolute position `p`. */
  u16(p, v) {
    this.view.setUint16(p, v, true);
  }

  /** Write a little-endian uint32 `v` at absolute position `p`. */
  u32(p, v) {
    this.view.setUint32(p, v, true);
  }

  /** Write a little-endian int32 `v` at absolute position `p`. */
  i32(p, v) {
    this.view.setInt32(p, v, true);
  }

  /** Patch the uoffset_t stored at `patchPos` so it points to `targetPos`. */
  patch(patchPos, targetPos) {
    this.u32(patchPos, targetPos - patchPos);
  }

  /**
   * Write a string: uint32 byte length, UTF-8 bytes, NUL terminator, then
   * padding up to a 4-byte boundary.
   * @param {string} s
   * @returns {number} Position of the length prefix.
   */
  writeStr(s) {
    this.align(4);
    const p = this.pos;
    const enc = new TextEncoder().encode(s);
    this.u32(p, enc.length);
    new Uint8Array(this.buf, p + 4, enc.length).set(enc);
    // The terminator must go through the DataView: indexing the ArrayBuffer
    // itself (`this.buf[i] = 0`) only creates a property and writes no byte.
    this.u8(p + 4 + enc.length, 0);
    this.pos = p + 4 + enc.length + 1;
    this.align(4);
    return p;
  }

  /**
   * Write a vector of int32: uint32 element count followed by the elements.
   * @param {number[]} arr
   * @returns {number} Position of the count prefix.
   */
  writeI32Vec(arr) {
    this.align(4);
    const p = this.pos;
    this.u32(p, arr.length);
    for (let i = 0; i < arr.length; i++) {
      this.i32(p + 4 + i * 4, arr[i]);
    }
    this.pos = p + 4 + arr.length * 4;
    return p;
  }

  /**
   * Write a vector of `count` uoffset_t placeholders (all 0).
   * @param {number} count
   * @returns {{ vecPos: number, elemPositions: number[] }} Position of the
   *   count prefix and of each element slot, so the caller can patch them.
   */
  writeOffsetVecPlaceholder(count) {
    this.align(4);
    const p = this.pos;
    this.u32(p, count);
    this.pos += 4;
    const elems = [];
    for (let i = 0; i < count; i++) {
      elems.push(this.pos);
      this.u32(this.pos, 0); // placeholder
      this.pos += 4;
    }
    return { vecPos: p, elemPositions: elems };
  }

  /**
   * Write a vtable immediately followed by its table.
   *
   * @param {Array<{type: 'u8'|'u32'|'i32'|'uoffset', value?: number}|null>} fieldDefs
   *   One entry per field index. `null` marks the field as absent (vtable
   *   slot 0). For 'uoffset' the value is ignored and a 0 placeholder is
   *   written; the caller patches it later.
   * @returns {{ tablePos: number, fieldPositions: Object<number, number> }}
   *   Table position and the absolute position of every present field.
   */
  writeTable(fieldDefs) {
    // Lay out the inline fields after the 4-byte soffset_t, aligning each
    // field to its own size.
    const layouts = [];
    let dataOff = 4;
    for (const f of fieldDefs) {
      if (f === null) {
        layouts.push(null);
        continue;
      }
      const sz = f.type === 'u8' ? 1 : 4;
      const r = dataOff % sz;
      if (r) dataOff += sz - r;
      layouts.push({ off: dataOff, sz });
      dataOff += sz;
    }
    if (dataOff % 4) dataOff += 4 - (dataOff % 4);
    const objSize = dataOff;

    // Vtable: [vtable size][table size][one uint16 offset per field].
    this.align(4);
    const vtPos = this.pos;
    const vtSize = 4 + fieldDefs.length * 2;
    this.u16(vtPos, vtSize);
    this.u16(vtPos + 2, objSize);
    for (let i = 0; i < fieldDefs.length; i++) {
      this.u16(vtPos + 4 + i * 2, layouts[i] ? layouts[i].off : 0);
    }
    this.pos = vtPos + vtSize;
    this.align(4);

    // Table: starts with the signed distance back to its vtable.
    const tblPos = this.pos;
    this.i32(tblPos, tblPos - vtPos);
    this.pos = tblPos + objSize;

    // Write inline values and record where each present field lives.
    const fieldPositions = {};
    for (let i = 0; i < fieldDefs.length; i++) {
      const f = fieldDefs[i];
      const l = layouts[i];
      if (!f || !l) continue;
      const fPos = tblPos + l.off;
      fieldPositions[i] = fPos;
      switch (f.type) {
        case 'u8':
          this.u8(fPos, f.value || 0);
          break;
        case 'u32':
          this.u32(fPos, f.value || 0);
          break;
        case 'i32':
          this.i32(fPos, f.value || 0);
          break;
        case 'uoffset':
          this.u32(fPos, 0); // placeholder
          break;
      }
    }
    return { tablePos: tblPos, fieldPositions };
  }

  /** @returns {ArrayBuffer} Copy of the bytes written so far, 4-byte aligned. */
  done() {
    this.align(4);
    return this.buf.slice(0, this.pos);
  }
}

/**
 * Build a valid TFLite FlatBuffer containing exactly one subgraph.
 *
 * Layout order (low → high address):
 *   [root offset] [Model vtable+table] [data referenced by Model] ...
 * All uoffset_t point forward.
 *
 * @param {object} [opts]
 * @param {number} [opts.version=3] Model version field.
 * @param {string} [opts.description=''] Omitted from the table when empty.
 * @param {Array<{builtinCode: number, customCode?: string}>} [opts.operatorCodes=[]]
 * @param {Array<{name?: string, shape?: number[], typeCode?: number}>} [opts.tensors=[]]
 * @param {Array<{opcodeIndex: number}>} [opts.operators=[]]
 * @param {number[]} [opts.inputIndices=[]]
 * @param {number[]} [opts.outputIndices=[]]
 * @returns {ArrayBuffer} The serialized model.
 */
function buildTFLiteBuffer(opts = {}) {
  const {
    version = 3,
    description = '',
    operatorCodes = [],
    tensors = [],
    operators = [],
    inputIndices = [],
    outputIndices = [],
  } = opts;

  const fb = new FB();
  fb.pos = 4; // reserve root offset

  // ── Write Model table (with placeholder offsets) ──
  const model = fb.writeTable([
    { type: 'u32', value: version }, // 0: version
    operatorCodes.length > 0 ? { type: 'uoffset' } : null, // 1: operator_codes
    { type: 'uoffset' }, // 2: subgraphs
    description ? { type: 'uoffset' } : null, // 3: description
  ]);
  fb.u32(0, model.tablePos); // root offset

  // ── Write description string ──
  if (description && model.fieldPositions[3] !== undefined) {
    const p = fb.writeStr(description);
    fb.patch(model.fieldPositions[3], p);
  }

  // ── Write operator_codes ──
  if (operatorCodes.length > 0 && model.fieldPositions[1] !== undefined) {
    const opcVec = fb.writeOffsetVecPlaceholder(operatorCodes.length);
    fb.patch(model.fieldPositions[1], opcVec.vecPos);

    // Each OperatorCode table is written AFTER the vector that references it.
    for (let i = 0; i < operatorCodes.length; i++) {
      const oc = operatorCodes[i];
      const ocTable = fb.writeTable([
        { type: 'u8', value: oc.builtinCode }, // 0: builtin_code
        oc.customCode ? { type: 'uoffset' } : null, // 1: custom_code
        { type: 'i32', value: 1 }, // 2: version
      ]);
      fb.patch(opcVec.elemPositions[i], ocTable.tablePos);

      if (oc.customCode && ocTable.fieldPositions[1] !== undefined) {
        const p = fb.writeStr(oc.customCode);
        fb.patch(ocTable.fieldPositions[1], p);
      }
    }
  }

  // ── Write subgraphs (always a vector holding one SubGraph) ──
  const sgVec = fb.writeOffsetVecPlaceholder(1);
  fb.patch(model.fieldPositions[2], sgVec.vecPos);

  const sg = fb.writeTable([
    tensors.length > 0 ? { type: 'uoffset' } : null, // 0: tensors
    inputIndices.length > 0 ? { type: 'uoffset' } : null, // 1: inputs
    outputIndices.length > 0 ? { type: 'uoffset' } : null, // 2: outputs
    operators.length > 0 ? { type: 'uoffset' } : null, // 3: operators
    null, // 4: name
  ]);
  fb.patch(sgVec.elemPositions[0], sg.tablePos);

  // Inputs vector
  if (inputIndices.length > 0 && sg.fieldPositions[1] !== undefined) {
    const p = fb.writeI32Vec(inputIndices);
    fb.patch(sg.fieldPositions[1], p);
  }

  // Outputs vector
  if (outputIndices.length > 0 && sg.fieldPositions[2] !== undefined) {
    const p = fb.writeI32Vec(outputIndices);
    fb.patch(sg.fieldPositions[2], p);
  }

  // Tensors
  if (tensors.length > 0 && sg.fieldPositions[0] !== undefined) {
    const tVec = fb.writeOffsetVecPlaceholder(tensors.length);
    fb.patch(sg.fieldPositions[0], tVec.vecPos);

    for (let i = 0; i < tensors.length; i++) {
      const t = tensors[i];
      const tt = fb.writeTable([
        t.shape && t.shape.length > 0 ? { type: 'uoffset' } : null, // 0: shape
        { type: 'u8', value: t.typeCode || 0 }, // 1: type
        { type: 'u32', value: 0 }, // 2: buffer
        t.name ? { type: 'uoffset' } : null, // 3: name
      ]);
      fb.patch(tVec.elemPositions[i], tt.tablePos);

      if (t.shape && t.shape.length > 0 && tt.fieldPositions[0] !== undefined) {
        const p = fb.writeI32Vec(t.shape);
        fb.patch(tt.fieldPositions[0], p);
      }

      if (t.name && tt.fieldPositions[3] !== undefined) {
        const p = fb.writeStr(t.name);
        fb.patch(tt.fieldPositions[3], p);
      }
    }
  }

  // Operators
  if (operators.length > 0 && sg.fieldPositions[3] !== undefined) {
    const oVec = fb.writeOffsetVecPlaceholder(operators.length);
    fb.patch(sg.fieldPositions[3], oVec.vecPos);

    for (let i = 0; i < operators.length; i++) {
      const op = operators[i];
      const ot = fb.writeTable([
        { type: 'u32', value: op.opcodeIndex }, // 0: opcode_index
      ]);
      fb.patch(oVec.elemPositions[i], ot.tablePos);
    }
  }

  return fb.done();
}

// ─── Re-implement TFLiteParser.parse() for testability ──────────────────

/**
 * Resolve the absolute position of field `fieldIndex` of the table at
 * `tablePos` by looking it up in the table's vtable.
 * @returns {number} Absolute field position, or 0 when the field is absent
 *   (index beyond the vtable, or vtable slot is 0).
 */
function _getFieldOffset(view, tablePos, fieldIndex) {
  const vtableRelOffset = view.getInt32(tablePos, true);
  const vtablePos = tablePos - vtableRelOffset;
  const vtableSize = view.getUint16(vtablePos, true);
  const fieldVtableOffset = 4 + fieldIndex * 2;
  if (fieldVtableOffset >= vtableSize) return 0;
  const fieldRelOffset = view.getUint16(vtablePos + fieldVtableOffset, true);
  if (fieldRelOffset === 0) return 0;
  return tablePos + fieldRelOffset;
}

/** Read uint32 field `fi` of the table at `tablePos`; `def` when absent. */
function _ru32(view, tablePos, fi, def = 0) {
  const o = _getFieldOffset(view, tablePos, fi);
  return o === 0 ? def : view.getUint32(o, true);
}

/** Read uint8 field `fi` of the table at `tablePos`; `def` when absent. */
function _ru8(view, tablePos, fi, def = 0) {
  const o = _getFieldOffset(view, tablePos, fi);
  return o === 0 ? def : view.getUint8(o);
}

/** Read string field `fi` of the table at `tablePos`; '' when absent. */
function _rstr(view, tablePos, fi) {
  const o = _getFieldOffset(view, tablePos, fi);
  if (o === 0) return '';
  const rel = view.getUint32(o, true);
  const sp = o + rel;
  const len = view.getUint32(sp, true);
  // Positions are relative to the view, so add the view's own offset into
  // the underlying buffer before slicing the string bytes out of it.
  const bytes = new Uint8Array(view.buffer, view.byteOffset + sp + 4, len);
  return new TextDecoder('utf-8').decode(bytes);
}

/**
 * Locate vector field `fi` of the table at `tablePos`.
 * @returns {{ pos: number, length: number } | null} Position of the first
 *   element and the element count, or null when the field is absent.
 */
function _rvec(view, tablePos, fi) {
  const o = _getFieldOffset(view, tablePos, fi);
  if (o === 0) return null;
  const rel = view.getUint32(o, true);
  const vp = o + rel;
  return { pos: vp + 4, length: view.getUint32(vp, true) };
}

/** Read int32-vector field `fi` as a plain array; [] when absent. */
function _ri32vec(view, tablePos, fi) {
  const v = _rvec(view, tablePos, fi);
  if (!v) return [];
  const r = [];
  for (let i = 0; i < v.length; i++) {
    r.push(view.getInt32(v.pos + i * 4, true));
  }
  return r;
}

/** Follow the uoffset_t stored at `ep` and return the position it targets. */
function _deref(view, ep) {
  return ep + view.getUint32(ep, true);
}

/**
 * Parse a TFLite FlatBuffer into a summary of its first subgraph.
 *
 * Never throws: any exception raised while reading is caught and reported
 * as `{ success: false, error }`.
 *
 * @param {ArrayBuffer} buffer
 * @returns {{ success: true, data: object } | { success: false, error: string }}
 *   On success `data` holds version, description, operators, operatorCodes,
 *   tensors, subgraphs (count), inputIndices and outputIndices.
 */
function parse(buffer) {
  try {
    if (!buffer || !(buffer instanceof ArrayBuffer) || buffer.byteLength === 0) {
      return { success: false, error: 'File không hợp lệ: buffer rỗng' };
    }
    if (buffer.byteLength < 8) {
      return { success: false, error: 'File không hợp lệ: không đủ dữ liệu' };
    }

    const view = new DataView(buffer);
    const rootOff = view.getUint32(0, true);
    if (rootOff >= buffer.byteLength || rootOff < 4) {
      return { success: false, error: 'File không hợp lệ: cấu trúc FlatBuffer lỗi' };
    }

    const mp = rootOff;
    const version = _ru32(view, mp, 0, 0);
    const description = _rstr(view, mp, 3);

    const operatorCodes = [];
    const opcVec = _rvec(view, mp, 1);
    if (opcVec) {
      for (let i = 0; i < opcVec.length; i++) {
        const ocp = _deref(view, opcVec.pos + i * 4);
        const bc = _ru8(view, ocp, 0, 0);
        const cc = _rstr(view, ocp, 1) || null;
        // Code 32 is CUSTOM: prefer the custom name when one is provided.
        const name =
          bc === 32 && cc ? cc : BUILTIN_OPERATORS[bc] || `UNKNOWN_OP_${bc}`;
        operatorCodes.push({ builtinCode: bc, customCode: cc, opcodeName: name });
      }
    }

    const sgVec = _rvec(view, mp, 2);
    const sgCount = sgVec ? sgVec.length : 0;
    let tensors = [];
    let operators = [];
    let inputIndices = [];
    let outputIndices = [];

    // Only the first subgraph is inspected.
    if (sgCount > 0) {
      const sgp = _deref(view, sgVec.pos);

      const tVec = _rvec(view, sgp, 0);
      if (tVec) {
        for (let i = 0; i < tVec.length; i++) {
          const tp = _deref(view, tVec.pos + i * 4);
          const shape = _ri32vec(view, tp, 0);
          const tc = _ru8(view, tp, 1, 0);
          const name = _rstr(view, tp, 3);
          const dtype = TENSOR_TYPES[tc] || `UNKNOWN_TYPE_${tc}`;
          // Unknown dtypes are counted as one byte per element.
          const bpe = BYTES_PER_ELEMENT[dtype] || 1;
          // Element count is the product of |dim|; an empty shape counts as 0.
          const ec = shape.length > 0 ? shape.reduce((a, d) => a * Math.abs(d), 1) : 0;
          tensors.push({ name, shape, dtype, byteSize: Math.ceil(ec * bpe) });
        }
      }

      inputIndices = _ri32vec(view, sgp, 1);
      outputIndices = _ri32vec(view, sgp, 2);

      const oVec = _rvec(view, sgp, 3);
      if (oVec) {
        for (let i = 0; i < oVec.length; i++) {
          const op = _deref(view, oVec.pos + i * 4);
          const oi = _ru32(view, op, 0, 0);
          const on =
            oi < operatorCodes.length ? operatorCodes[oi].opcodeName : `UNKNOWN_OP_${oi}`;
          operators.push({ opcodeName: on, opcodeIndex: oi });
        }
      }
    }

    return {
      success: true,
      data: {
        version,
        description,
        operators,
        operatorCodes,
        tensors,
        subgraphs: sgCount,
        inputIndices,
        outputIndices,
      },
    };
  } catch (err) {
    return { success: false, error: err.message || 'Lỗi không xác định' };
  }
}

// ─── Tests ──────────────────────────────────────────────────────────────

describe('TFLiteParser - parse', () => {
  describe('Error handling (Req 2.3, 2.4, 2.5)', () => {
    it('should return error for null buffer', () => {
      expect(parse(null)).toEqual({ success: false, error: 'File không hợp lệ: buffer rỗng' });
    });

    it('should return error for undefined buffer', () => {
      expect(parse(undefined).success).toBe(false);
    });

    it('should return error for empty buffer', () => {
      expect(parse(new ArrayBuffer(0)).error).toBe('File không hợp lệ: buffer rỗng');
    });

    it('should return error for buffer < 8 bytes', () => {
      expect(parse(new ArrayBuffer(4)).error).toBe('File không hợp lệ: không đủ dữ liệu');
    });

    it('should return error for root offset out of range', () => {
      const b = new ArrayBuffer(16);
      new DataView(b).setUint32(0, 99999, true);
      expect(parse(b).success).toBe(false);
    });

    it('should return error for root offset < 4', () => {
      const b = new ArrayBuffer(16);
      new DataView(b).setUint32(0, 2, true);
      expect(parse(b).success).toBe(false);
    });

    it('should never throw for random data', () => {
      // Fixed-seed LCG instead of Math.random(): the inputs are still
      // arbitrary bytes, but a failing run can be reproduced exactly.
      let state = 0x12345678;
      const nextByte = () => {
        state = (Math.imul(state, 1664525) + 1013904223) >>> 0;
        return state >>> 24;
      };
      for (let round = 0; round < 32; round++) {
        const b = new ArrayBuffer(64);
        const a = new Uint8Array(b);
        for (let i = 0; i < a.length; i++) a[i] = nextByte();
        const r = parse(b);
        expect(typeof r.success).toBe('boolean');
      }
    });
  });

  describe('Successful parsing (Req 2.1, 2.2, 2.6)', () => {
    it('should parse a minimal valid TFLite buffer', () => {
      const buf = buildTFLiteBuffer({
        version: 3,
        description: 'Test model',
        operatorCodes: [{ builtinCode: 3 }],
        tensors: [
          { name: 'input', shape: [1, 224, 224, 3], typeCode: 0 },
          { name: 'output', shape: [1, 1000], typeCode: 0 },
        ],
        operators: [{ opcodeIndex: 0 }],
        inputIndices: [0],
        outputIndices: [1],
      });
      const r = parse(buf);
      expect(r.success).toBe(true);
      expect(r.data.version).toBe(3);
      expect(r.data.description).toBe('Test model');
      expect(r.data.subgraphs).toBe(1);
    });

    it('should extract operator codes correctly', () => {
      const buf = buildTFLiteBuffer({
        operatorCodes: [{ builtinCode: 3 }, { builtinCode: 9 }, { builtinCode: 25 }],
        tensors: [{ name: 't', shape: [1], typeCode: 0 }],
        operators: [{ opcodeIndex: 0 }, { opcodeIndex: 1 }, { opcodeIndex: 2 }],
      });
      const r = parse(buf);
      expect(r.success).toBe(true);
      expect(r.data.operatorCodes[0].opcodeName).toBe('CONV_2D');
      expect(r.data.operatorCodes[1].opcodeName).toBe('FULLY_CONNECTED');
      expect(r.data.operatorCodes[2].opcodeName).toBe('SOFTMAX');
    });

    it('should extract tensors with correct dtype and byteSize', () => {
      const buf = buildTFLiteBuffer({
        operatorCodes: [{ builtinCode: 0 }],
        tensors: [
          { name: 'float_tensor', shape: [2, 3], typeCode: 0 },
          { name: 'int8_tensor', shape: [10, 10], typeCode: 9 },
          { name: 'uint8_tensor', shape: [5], typeCode: 3 },
        ],
        operators: [{ opcodeIndex: 0 }],
      });
      const r = parse(buf);
      expect(r.success).toBe(true);

      const ft = r.data.tensors.find((t) => t.name === 'float_tensor');
      expect(ft.dtype).toBe('FLOAT32');
      expect(ft.shape).toEqual([2, 3]);
      expect(ft.byteSize).toBe(24);

      const i8 = r.data.tensors.find((t) => t.name === 'int8_tensor');
      expect(i8.dtype).toBe('INT8');
      expect(i8.byteSize).toBe(100);

      // The third fixture tensor was previously built but never checked.
      const u8 = r.data.tensors.find((t) => t.name === 'uint8_tensor');
      expect(u8.dtype).toBe('UINT8');
      expect(u8.shape).toEqual([5]);
      expect(u8.byteSize).toBe(5);
    });

    it('should extract input and output indices', () => {
      const buf = buildTFLiteBuffer({
        operatorCodes: [{ builtinCode: 0 }],
        tensors: [
          { name: 'in', shape: [1], typeCode: 0 },
          { name: 'out', shape: [1], typeCode: 0 },
        ],
        operators: [{ opcodeIndex: 0 }],
        inputIndices: [0],
        outputIndices: [1],
      });
      const r = parse(buf);
      expect(r.success).toBe(true);
      expect(r.data.inputIndices).toEqual([0]);
      expect(r.data.outputIndices).toEqual([1]);
    });

    it('should extract operators with correct opcodeName', () => {
      const buf = buildTFLiteBuffer({
        operatorCodes: [{ builtinCode: 3 }, { builtinCode: 25 }],
        tensors: [{ name: 't', shape: [1], typeCode: 0 }],
        operators: [{ opcodeIndex: 0 }, { opcodeIndex: 0 }, { opcodeIndex: 1 }],
      });
      const r = parse(buf);
      expect(r.success).toBe(true);
      expect(r.data.operators).toHaveLength(3);
      expect(r.data.operators[0].opcodeName).toBe('CONV_2D');
      expect(r.data.operators[2].opcodeName).toBe('SOFTMAX');
    });
  });

  describe('Result structure invariant (Req 2.5)', () => {
    it('should have data with all required fields when success is true', () => {
      const buf = buildTFLiteBuffer({
        operatorCodes: [{ builtinCode: 0 }],
        tensors: [{ name: 't', shape: [1], typeCode: 0 }],
        operators: [{ opcodeIndex: 0 }],
      });
      const r = parse(buf);
      expect(r.success).toBe(true);
      expect(Array.isArray(r.data.operators)).toBe(true);
      expect(Array.isArray(r.data.tensors)).toBe(true);
      expect(typeof r.data.version).toBe('number');
      expect(typeof r.data.subgraphs).toBe('number');
      expect(Array.isArray(r.data.operatorCodes)).toBe(true);
      expect(Array.isArray(r.data.inputIndices)).toBe(true);
      expect(Array.isArray(r.data.outputIndices)).toBe(true);
    });

    it('should have non-empty error when success is false', () => {
      expect(parse(null).error.length).toBeGreaterThan(0);
    });
  });

  describe('Unknown types handling', () => {
    it('should display UNKNOWN_OP for unknown operator codes', () => {
      const buf = buildTFLiteBuffer({
        operatorCodes: [{ builtinCode: 250 }],
        tensors: [{ name: 't', shape: [1], typeCode: 0 }],
        operators: [{ opcodeIndex: 0 }],
      });
      const r = parse(buf);
      expect(r.success).toBe(true);
      expect(r.data.operatorCodes[0].opcodeName).toBe('UNKNOWN_OP_250');
    });

    it('should display UNKNOWN_TYPE for unknown tensor types', () => {
      const buf = buildTFLiteBuffer({
        operatorCodes: [{ builtinCode: 0 }],
        tensors: [{ name: 't', shape: [1], typeCode: 99 }],
        operators: [{ opcodeIndex: 0 }],
      });
      const r = parse(buf);
      expect(r.success).toBe(true);
      expect(r.data.tensors[0].dtype).toBe('UNKNOWN_TYPE_99');
    });
  });

  describe('Lookup tables', () => {
    it('should have 17 tensor types', () => {
      expect(Object.keys(TENSOR_TYPES)).toHaveLength(17);
    });

    it('should have bytes per element for all tensor types', () => {
      for (const n of Object.values(TENSOR_TYPES)) {
        expect(BYTES_PER_ELEMENT[n]).toBeGreaterThan(0);
      }
    });
  });
});