model-explorer / js /tests /tfliteParser.test.js
mr4's picture
Upload 71 files
9bd422a verified
/**
* Unit tests for TFLiteParser
* Validates: Requirements 2.1, 2.2, 2.3, 2.4, 2.5, 2.6
*/
import { describe, it, expect } from 'vitest';
// ─── Lookup tables (subset for testing) ─────────────────────────────────
/**
 * Builtin operator code → operator name (subset used by these tests).
 * Frozen so a test cannot accidentally mutate the shared table.
 */
const BUILTIN_OPERATORS = Object.freeze({
  0: 'ADD',
  1: 'AVERAGE_POOL_2D',
  2: 'CONCATENATION',
  3: 'CONV_2D',
  4: 'DEPTHWISE_CONV_2D',
  9: 'FULLY_CONNECTED',
  25: 'SOFTMAX',
  32: 'CUSTOM',
});

/** Tensor type code → type name. Frozen for the same reason as above. */
const TENSOR_TYPES = Object.freeze({
  0: 'FLOAT32',
  1: 'FLOAT16',
  2: 'INT32',
  3: 'UINT8',
  4: 'INT64',
  5: 'STRING',
  6: 'BOOL',
  7: 'INT16',
  8: 'COMPLEX64',
  9: 'INT8',
  10: 'FLOAT64',
  11: 'COMPLEX128',
  12: 'UINT64',
  13: 'UINT32',
  14: 'UINT16',
  15: 'INT4',
  16: 'BFLOAT16',
});

/**
 * Type name → bytes per element. INT4 is fractional (0.5), so callers
 * computing a byte size have to round the product themselves.
 */
const BYTES_PER_ELEMENT = Object.freeze({
  FLOAT32: 4,
  FLOAT16: 2,
  INT32: 4,
  UINT8: 1,
  INT64: 8,
  STRING: 1,
  BOOL: 1,
  INT16: 2,
  COMPLEX64: 8,
  INT8: 1,
  FLOAT64: 8,
  COMPLEX128: 16,
  UINT64: 8,
  UINT32: 4,
  UINT16: 2,
  INT4: 0.5,
  BFLOAT16: 2,
});
// ─── FlatBuffer Binary Builder ──────────────────────────────────────────
// Builds valid FlatBuffer binaries using a deferred-offset approach.
// All uoffset_t values point forward (to higher addresses).
//
// Strategy: Write tables with placeholder uoffset fields, then write
// the referenced data AFTER the table, then patch the offsets.
/**
 * Tiny little-endian FlatBuffer writer for fabricating parser inputs.
 *
 * Bytes are written into a fixed-size scratch buffer at the cursor `pos`.
 * Tables are emitted with zeroed uoffset slots; the caller appends the
 * referenced data afterwards and links it with `patch()`.
 */
class FB {
  /**
   * @param {number} [capacity=65536] Scratch buffer size in bytes.
   */
  constructor(capacity = 65536) {
    this.buf = new ArrayBuffer(capacity);
    this.view = new DataView(this.buf);
    this.pos = 0;
  }

  /** Advance `pos` to the next multiple of `n` (no-op when already aligned). */
  align(n) {
    const r = this.pos % n;
    if (r) this.pos += n - r;
  }

  // Absolute-position little-endian writers; none of them move `pos`.
  u8(p, v) { this.view.setUint8(p, v); }
  u16(p, v) { this.view.setUint16(p, v, true); }
  u32(p, v) { this.view.setUint32(p, v, true); }
  i32(p, v) { this.view.setInt32(p, v, true); }

  /** Store at `patchPos` the uoffset_t (relative distance) that reaches `targetPos`. */
  patch(patchPos, targetPos) {
    this.u32(patchPos, targetPos - patchPos);
  }

  /**
   * Append a string: u32 byte length, UTF-8 bytes, one NUL terminator,
   * then padding up to a 4-byte boundary.
   * @param {string} s
   * @returns {number} Position of the length prefix.
   */
  writeStr(s) {
    this.align(4);
    const p = this.pos;
    const enc = new TextEncoder().encode(s);
    this.u32(p, enc.length);
    new Uint8Array(this.buf, p + 4, enc.length).set(enc);
    // Write the terminator through the DataView: indexing the ArrayBuffer
    // itself would only create a stray property and leave the byte untouched.
    this.u8(p + 4 + enc.length, 0);
    this.pos = p + 4 + enc.length + 1;
    this.align(4);
    return p;
  }

  /**
   * Append a vector of int32: u32 element count followed by the elements.
   * @param {number[]} arr
   * @returns {number} Position of the element count.
   */
  writeI32Vec(arr) {
    this.align(4);
    const p = this.pos;
    this.u32(p, arr.length);
    for (let i = 0; i < arr.length; i++) this.i32(p + 4 + i * 4, arr[i]);
    this.pos = p + 4 + arr.length * 4;
    return p;
  }

  /**
   * Append a vector of `count` zeroed uoffset_t slots to be patched later.
   * @param {number} count
   * @returns {{vecPos: number, elemPositions: number[]}} Position of the
   *   element count and the position of every slot.
   */
  writeOffsetVecPlaceholder(count) {
    this.align(4);
    const p = this.pos;
    this.u32(p, count);
    this.pos += 4;
    const elems = [];
    for (let i = 0; i < count; i++) {
      elems.push(this.pos);
      this.u32(this.pos, 0); // placeholder
      this.pos += 4;
    }
    return { vecPos: p, elemPositions: elems };
  }

  /**
   * Append a vtable immediately followed by its table.
   *
   * Vtable: u16 vtable size, u16 table size, then one u16 per field holding
   * the field's offset from the table start (0 = field absent).
   * Table: i32 distance back to the vtable, then the inline field values.
   *
   * @param {Array<{type: 'u8'|'u32'|'i32'|'uoffset', value?: number}|null>} fieldDefs
   *   One entry per field index; `null` marks an absent field. For 'uoffset'
   *   the value is ignored and 0 is written; the caller patches it later.
   * @returns {{tablePos: number, fieldPositions: Object<number, number>}}
   *   Table start and the absolute position of every present field, keyed
   *   by field index.
   */
  writeTable(fieldDefs) {
    // Lay the fields out after the 4-byte soffset, each aligned to its size.
    const layouts = [];
    let dataOff = 4;
    for (const f of fieldDefs) {
      if (f === null) { layouts.push(null); continue; }
      const sz = f.type === 'u8' ? 1 : 4;
      const r = dataOff % sz;
      if (r) dataOff += sz - r;
      layouts.push({ off: dataOff, sz });
      dataOff += sz;
    }
    if (dataOff % 4) dataOff += 4 - (dataOff % 4);
    const objSize = dataOff;

    // Vtable
    this.align(4);
    const vtPos = this.pos;
    const vtSize = 4 + fieldDefs.length * 2;
    this.u16(vtPos, vtSize);
    this.u16(vtPos + 2, objSize);
    for (let i = 0; i < fieldDefs.length; i++) {
      this.u16(vtPos + 4 + i * 2, layouts[i] ? layouts[i].off : 0);
    }
    this.pos = vtPos + vtSize;
    this.align(4);

    // Table
    const tblPos = this.pos;
    this.i32(tblPos, tblPos - vtPos); // soffset to vtable
    this.pos = tblPos + objSize;

    // Write inline values and record field positions
    const fieldPositions = {};
    for (let i = 0; i < fieldDefs.length; i++) {
      const f = fieldDefs[i];
      const l = layouts[i];
      if (!f || !l) continue;
      const fPos = tblPos + l.off;
      fieldPositions[i] = fPos;
      switch (f.type) {
        case 'u8': this.u8(fPos, f.value || 0); break;
        case 'u32': this.u32(fPos, f.value || 0); break;
        case 'i32': this.i32(fPos, f.value || 0); break;
        case 'uoffset': this.u32(fPos, 0); break; // placeholder
      }
    }
    return { tablePos: tblPos, fieldPositions };
  }

  /** @returns {ArrayBuffer} Copy of the bytes written so far, padded to 4 bytes. */
  done() {
    this.align(4);
    return this.buf.slice(0, this.pos);
  }
}
/**
 * Assemble a single-subgraph TFLite model as a FlatBuffer for the tests.
 *
 * Byte layout, low → high address:
 *   [root offset] [Model vtable + table] [everything the Model references] ...
 * Each table is emitted with zeroed uoffset slots first; the referenced data
 * is appended afterwards and the slot is patched, so every uoffset_t points
 * forward.
 *
 * @param {object} [opts]
 * @param {number} [opts.version=3] Written to Model field 0.
 * @param {string} [opts.description=''] Model field 3; the field is absent when empty.
 * @param {Array<{builtinCode: number, customCode?: string}>} [opts.operatorCodes=[]]
 * @param {Array<{name?: string, shape?: number[], typeCode?: number}>} [opts.tensors=[]]
 * @param {Array<{opcodeIndex: number}>} [opts.operators=[]]
 * @param {number[]} [opts.inputIndices=[]]
 * @param {number[]} [opts.outputIndices=[]]
 * @returns {ArrayBuffer} The finished buffer, trimmed to the bytes written.
 */
function buildTFLiteBuffer(opts = {}) {
  const {
    version = 3,
    description = '',
    operatorCodes = [],
    tensors = [],
    operators = [],
    inputIndices = [],
    outputIndices = [],
  } = opts;

  // A uoffset slot when the field has data, otherwise an absent (null) field.
  // A slot therefore exists in `fieldPositions` exactly when its data exists.
  const slotIf = (present) => (present ? { type: 'uoffset' } : null);

  const builder = new FB();
  builder.pos = 4; // bytes 0..3 hold the root offset

  // ── Model table ──
  const modelTbl = builder.writeTable([
    { type: 'u32', value: version },      // 0: version
    slotIf(operatorCodes.length > 0),     // 1: operator_codes
    { type: 'uoffset' },                  // 2: subgraphs
    slotIf(description),                  // 3: description
  ]);
  builder.u32(0, modelTbl.tablePos);

  // ── Model.description ──
  const descriptionSlot = modelTbl.fieldPositions[3];
  if (descriptionSlot !== undefined) {
    builder.patch(descriptionSlot, builder.writeStr(description));
  }

  // ── Model.operator_codes: offset vector first, then one table per code ──
  const opcodesSlot = modelTbl.fieldPositions[1];
  if (opcodesSlot !== undefined) {
    const codeVec = builder.writeOffsetVecPlaceholder(operatorCodes.length);
    builder.patch(opcodesSlot, codeVec.vecPos);
    for (const [idx, code] of operatorCodes.entries()) {
      const codeTbl = builder.writeTable([
        { type: 'u8', value: code.builtinCode }, // 0: builtin_code
        slotIf(code.customCode),                 // 1: custom_code
        { type: 'i32', value: 1 },               // 2: version
      ]);
      builder.patch(codeVec.elemPositions[idx], codeTbl.tablePos);
      const customSlot = codeTbl.fieldPositions[1];
      if (customSlot !== undefined) {
        builder.patch(customSlot, builder.writeStr(code.customCode));
      }
    }
  }

  // ── Model.subgraphs: always exactly one SubGraph ──
  const subgraphVec = builder.writeOffsetVecPlaceholder(1);
  builder.patch(modelTbl.fieldPositions[2], subgraphVec.vecPos);
  const subgraphTbl = builder.writeTable([
    slotIf(tensors.length > 0),       // 0: tensors
    slotIf(inputIndices.length > 0),  // 1: inputs
    slotIf(outputIndices.length > 0), // 2: outputs
    slotIf(operators.length > 0),     // 3: operators
    null,                             // 4: name
  ]);
  builder.patch(subgraphVec.elemPositions[0], subgraphTbl.tablePos);
  const {
    0: tensorsSlot,
    1: inputsSlot,
    2: outputsSlot,
    3: operatorsSlot,
  } = subgraphTbl.fieldPositions;

  // ── SubGraph.inputs / SubGraph.outputs ──
  if (inputsSlot !== undefined) {
    builder.patch(inputsSlot, builder.writeI32Vec(inputIndices));
  }
  if (outputsSlot !== undefined) {
    builder.patch(outputsSlot, builder.writeI32Vec(outputIndices));
  }

  // ── SubGraph.tensors ──
  if (tensorsSlot !== undefined) {
    const tensorVec = builder.writeOffsetVecPlaceholder(tensors.length);
    builder.patch(tensorsSlot, tensorVec.vecPos);
    for (const [idx, tensor] of tensors.entries()) {
      const tensorTbl = builder.writeTable([
        slotIf(tensor.shape && tensor.shape.length > 0), // 0: shape
        { type: 'u8', value: tensor.typeCode || 0 },     // 1: type
        { type: 'u32', value: 0 },                       // 2: buffer
        slotIf(tensor.name),                             // 3: name
      ]);
      builder.patch(tensorVec.elemPositions[idx], tensorTbl.tablePos);
      const shapeSlot = tensorTbl.fieldPositions[0];
      if (shapeSlot !== undefined) {
        builder.patch(shapeSlot, builder.writeI32Vec(tensor.shape));
      }
      const nameSlot = tensorTbl.fieldPositions[3];
      if (nameSlot !== undefined) {
        builder.patch(nameSlot, builder.writeStr(tensor.name));
      }
    }
  }

  // ── SubGraph.operators ──
  if (operatorsSlot !== undefined) {
    const operatorVec = builder.writeOffsetVecPlaceholder(operators.length);
    builder.patch(operatorsSlot, operatorVec.vecPos);
    for (const [idx, operator] of operators.entries()) {
      const operatorTbl = builder.writeTable([
        { type: 'u32', value: operator.opcodeIndex }, // 0: opcode_index
      ]);
      builder.patch(operatorVec.elemPositions[idx], operatorTbl.tablePos);
    }
  }

  return builder.done();
}
// ─── Re-implement TFLiteParser.parse() for testability ──────────────────
/**
 * Resolve a table field to its absolute byte position through the vtable.
 *
 * @param {DataView} view
 * @param {number} tablePos Absolute position of the table.
 * @param {number} fieldIndex Zero-based field index.
 * @returns {number} Absolute position of the field, or 0 when the vtable is
 *   too short to list it or lists it as absent.
 */
function _getFieldOffset(view, tablePos, fieldIndex) {
  // The table opens with a signed distance that is subtracted to reach its vtable.
  const vtableStart = tablePos - view.getInt32(tablePos, true);
  const vtableBytes = view.getUint16(vtableStart, true);
  // Vtable entries are 2 bytes each and begin after a 4-byte header.
  const entry = 4 + 2 * fieldIndex;
  if (entry < vtableBytes) {
    const relative = view.getUint16(vtableStart + entry, true);
    if (relative !== 0) {
      return tablePos + relative;
    }
  }
  return 0;
}
/**
 * Read a little-endian uint32 table field.
 * @returns {number} The stored value, or `def` when the field is absent.
 */
function _ru32(view, tablePos, fi, def = 0) {
  const fieldPos = _getFieldOffset(view, tablePos, fi);
  if (fieldPos === 0) {
    return def;
  }
  return view.getUint32(fieldPos, true);
}
/**
 * Read a uint8 table field.
 * @returns {number} The stored value, or `def` when the field is absent.
 */
function _ru8(view, tablePos, fi, def = 0) {
  const fieldPos = _getFieldOffset(view, tablePos, fi);
  if (fieldPos === 0) {
    return def;
  }
  return view.getUint8(fieldPos);
}
/**
 * Read a string table field: follow the field's uoffset to a u32 byte length
 * followed by that many UTF-8 bytes.
 *
 * @param {DataView} view
 * @param {number} tablePos Absolute position of the table (relative to `view`).
 * @param {number} fi Field index.
 * @returns {string} The decoded string, or '' when the field is absent.
 */
function _rstr(view, tablePos, fi) {
  const fieldPos = _getFieldOffset(view, tablePos, fi);
  if (fieldPos === 0) return '';
  const strPos = fieldPos + view.getUint32(fieldPos, true);
  const byteLen = view.getUint32(strPos, true);
  // Positions are relative to the view, but Uint8Array offsets are relative
  // to the underlying buffer, so add the view's own offset into the buffer.
  const bytes = new Uint8Array(view.buffer, view.byteOffset + strPos + 4, byteLen);
  return new TextDecoder('utf-8').decode(bytes);
}
/**
 * Locate a vector table field by following its uoffset.
 * @returns {{pos: number, length: number}|null} Position of the first element
 *   and the element count, or null when the field is absent.
 */
function _rvec(view, tablePos, fi) {
  const fieldPos = _getFieldOffset(view, tablePos, fi);
  if (fieldPos === 0) {
    return null;
  }
  const vecPos = fieldPos + view.getUint32(fieldPos, true);
  // The u32 element count comes first; the elements start right after it.
  const firstElem = vecPos + 4;
  const count = view.getUint32(vecPos, true);
  return { pos: firstElem, length: count };
}
/**
 * Read a vector-of-int32 table field into a plain array.
 * @returns {number[]} The elements in order; empty when the field is absent.
 */
function _ri32vec(view, tablePos, fi) {
  const vec = _rvec(view, tablePos, fi);
  const values = [];
  if (!vec) {
    return values;
  }
  let cursor = vec.pos;
  for (let remaining = vec.length; remaining > 0; remaining--) {
    values.push(view.getInt32(cursor, true));
    cursor += 4;
  }
  return values;
}
/**
 * Follow the uoffset stored at `ep`.
 * @returns {number} `ep` plus the little-endian uint32 stored there.
 */
function _deref(view, ep) {
  const relative = view.getUint32(ep, true);
  return ep + relative;
}
/**
 * Parse a TFLite FlatBuffer and summarise the model and its first subgraph.
 *
 * Never throws: every failure, including out-of-range reads on malformed
 * input, is reported as `{ success: false, error }`.
 *
 * @param {ArrayBuffer} buffer
 * @returns {{success: true, data: {version: number, description: string,
 *   operators: Array, operatorCodes: Array, tensors: Array, subgraphs: number,
 *   inputIndices: number[], outputIndices: number[]}}
 *   | {success: false, error: string}}
 */
function parse(buffer) {
  const fail = (error) => ({ success: false, error });
  try {
    const isUsable = buffer instanceof ArrayBuffer && buffer.byteLength !== 0;
    if (!isUsable) {
      return fail('File không hợp lệ: buffer rỗng');
    }
    if (buffer.byteLength < 8) {
      return fail('File không hợp lệ: không đủ dữ liệu');
    }

    const view = new DataView(buffer);
    // The first 4 bytes give the position of the root (Model) table.
    const modelPos = view.getUint32(0, true);
    if (modelPos < 4 || modelPos >= buffer.byteLength) {
      return fail('File không hợp lệ: cấu trúc FlatBuffer lỗi');
    }

    const version = _ru32(view, modelPos, 0, 0);
    const description = _rstr(view, modelPos, 3);

    // Operator codes (Model field 1)
    const operatorCodes = [];
    const codeVec = _rvec(view, modelPos, 1);
    const codeCount = codeVec ? codeVec.length : 0;
    for (let i = 0; i < codeCount; i++) {
      const codePos = _deref(view, codeVec.pos + i * 4);
      const builtinCode = _ru8(view, codePos, 0, 0);
      const customCode = _rstr(view, codePos, 1) || null;
      let opcodeName;
      if (builtinCode === 32 && customCode) {
        // Code 32 is 'CUSTOM' in the lookup table: prefer the custom name.
        opcodeName = customCode;
      } else {
        opcodeName = BUILTIN_OPERATORS[builtinCode] || `UNKNOWN_OP_${builtinCode}`;
      }
      operatorCodes.push({ builtinCode, customCode, opcodeName });
    }

    // Subgraphs (Model field 2); only the first one is inspected.
    const subgraphVec = _rvec(view, modelPos, 2);
    const subgraphs = subgraphVec ? subgraphVec.length : 0;
    const tensors = [];
    const operators = [];
    let inputIndices = [];
    let outputIndices = [];

    if (subgraphs > 0) {
      const subgraphPos = _deref(view, subgraphVec.pos);

      const tensorVec = _rvec(view, subgraphPos, 0);
      const tensorCount = tensorVec ? tensorVec.length : 0;
      for (let i = 0; i < tensorCount; i++) {
        const tensorPos = _deref(view, tensorVec.pos + i * 4);
        const shape = _ri32vec(view, tensorPos, 0);
        const typeCode = _ru8(view, tensorPos, 1, 0);
        const name = _rstr(view, tensorPos, 3);
        const dtype = TENSOR_TYPES[typeCode] || `UNKNOWN_TYPE_${typeCode}`;
        const bytesPerElement = BYTES_PER_ELEMENT[dtype] || 1;
        // An empty shape counts as zero elements; dimensions are taken by
        // absolute value before multiplying.
        let elementCount = 0;
        if (shape.length > 0) {
          elementCount = 1;
          for (const dim of shape) elementCount *= Math.abs(dim);
        }
        const byteSize = Math.ceil(elementCount * bytesPerElement);
        tensors.push({ name, shape, dtype, byteSize });
      }

      inputIndices = _ri32vec(view, subgraphPos, 1);
      outputIndices = _ri32vec(view, subgraphPos, 2);

      const operatorVec = _rvec(view, subgraphPos, 3);
      const operatorCount = operatorVec ? operatorVec.length : 0;
      for (let i = 0; i < operatorCount; i++) {
        const operatorPos = _deref(view, operatorVec.pos + i * 4);
        const opcodeIndex = _ru32(view, operatorPos, 0, 0);
        const opcodeName = opcodeIndex < operatorCodes.length
          ? operatorCodes[opcodeIndex].opcodeName
          : `UNKNOWN_OP_${opcodeIndex}`;
        operators.push({ opcodeName, opcodeIndex });
      }
    }

    return {
      success: true,
      data: {
        version,
        description,
        operators,
        operatorCodes,
        tensors,
        subgraphs,
        inputIndices,
        outputIndices,
      },
    };
  } catch (err) {
    return fail(err.message || 'Lỗi không xác định');
  }
}
// ─── Tests ──────────────────────────────────────────────────────────────
describe('TFLiteParser - parse', () => {
  describe('Error handling (Req 2.3, 2.4, 2.5)', () => {
    it('should return error for null buffer', () => {
      expect(parse(null)).toEqual({ success: false, error: 'File không hợp lệ: buffer rỗng' });
    });

    it('should return error for undefined buffer', () => {
      expect(parse(undefined).success).toBe(false);
    });

    it('should return error for empty buffer', () => {
      expect(parse(new ArrayBuffer(0)).error).toBe('File không hợp lệ: buffer rỗng');
    });

    it('should return error for buffer < 8 bytes', () => {
      expect(parse(new ArrayBuffer(4)).error).toBe('File không hợp lệ: không đủ dữ liệu');
    });

    it('should return error for root offset out of range', () => {
      const b = new ArrayBuffer(16);
      new DataView(b).setUint32(0, 99999, true);
      expect(parse(b).success).toBe(false);
    });

    it('should return error for root offset < 4', () => {
      const b = new ArrayBuffer(16);
      new DataView(b).setUint32(0, 2, true);
      expect(parse(b).success).toBe(false);
    });

    it('should never throw for random data', () => {
      // Seeded generator: a failing buffer can be reproduced exactly.
      let seed = 0x1234abcd;
      const nextByte = () => {
        seed = (Math.imul(seed, 1664525) + 1013904223) >>> 0;
        return seed >>> 24;
      };
      for (let round = 0; round < 32; round++) {
        const b = new ArrayBuffer(64);
        const a = new Uint8Array(b);
        for (let i = 0; i < a.length; i++) a[i] = nextByte();
        // A random 32-bit root offset almost never lands inside 64 bytes, so
        // on alternate rounds pin it in range to exercise the table walking.
        if (round % 2 === 1) {
          new DataView(b).setUint32(0, 4 + (round % 15) * 4, true);
        }
        const r = parse(b);
        expect(typeof r.success).toBe('boolean');
        if (!r.success) {
          expect(r.error.length).toBeGreaterThan(0);
        }
      }
    });
  });

  describe('Successful parsing (Req 2.1, 2.2, 2.6)', () => {
    it('should parse a minimal valid TFLite buffer', () => {
      const buf = buildTFLiteBuffer({
        version: 3, description: 'Test model',
        operatorCodes: [{ builtinCode: 3 }],
        tensors: [
          { name: 'input', shape: [1, 224, 224, 3], typeCode: 0 },
          { name: 'output', shape: [1, 1000], typeCode: 0 }
        ],
        operators: [{ opcodeIndex: 0 }],
        inputIndices: [0], outputIndices: [1]
      });
      const r = parse(buf);
      expect(r.success).toBe(true);
      expect(r.data.version).toBe(3);
      expect(r.data.description).toBe('Test model');
      expect(r.data.subgraphs).toBe(1);
    });

    it('should extract operator codes correctly', () => {
      const buf = buildTFLiteBuffer({
        operatorCodes: [{ builtinCode: 3 }, { builtinCode: 9 }, { builtinCode: 25 }],
        tensors: [{ name: 't', shape: [1], typeCode: 0 }],
        operators: [{ opcodeIndex: 0 }, { opcodeIndex: 1 }, { opcodeIndex: 2 }]
      });
      const r = parse(buf);
      expect(r.success).toBe(true);
      expect(r.data.operatorCodes[0].opcodeName).toBe('CONV_2D');
      expect(r.data.operatorCodes[1].opcodeName).toBe('FULLY_CONNECTED');
      expect(r.data.operatorCodes[2].opcodeName).toBe('SOFTMAX');
    });

    it('should extract tensors with correct dtype and byteSize', () => {
      const buf = buildTFLiteBuffer({
        operatorCodes: [{ builtinCode: 0 }],
        tensors: [
          { name: 'float_tensor', shape: [2, 3], typeCode: 0 },
          { name: 'int8_tensor', shape: [10, 10], typeCode: 9 },
          { name: 'uint8_tensor', shape: [5], typeCode: 3 }
        ],
        operators: [{ opcodeIndex: 0 }]
      });
      const r = parse(buf);
      expect(r.success).toBe(true);
      expect(r.data.tensors).toHaveLength(3);
      const ft = r.data.tensors.find(t => t.name === 'float_tensor');
      expect(ft.dtype).toBe('FLOAT32');
      expect(ft.shape).toEqual([2, 3]);
      expect(ft.byteSize).toBe(24);
      const i8 = r.data.tensors.find(t => t.name === 'int8_tensor');
      expect(i8.dtype).toBe('INT8');
      expect(i8.shape).toEqual([10, 10]);
      expect(i8.byteSize).toBe(100);
      // The third fixture tensor is checked too (1 byte per element × 5).
      const u8 = r.data.tensors.find(t => t.name === 'uint8_tensor');
      expect(u8.dtype).toBe('UINT8');
      expect(u8.shape).toEqual([5]);
      expect(u8.byteSize).toBe(5);
    });

    it('should extract input and output indices', () => {
      const buf = buildTFLiteBuffer({
        operatorCodes: [{ builtinCode: 0 }],
        tensors: [{ name: 'in', shape: [1], typeCode: 0 }, { name: 'out', shape: [1], typeCode: 0 }],
        operators: [{ opcodeIndex: 0 }],
        inputIndices: [0], outputIndices: [1]
      });
      const r = parse(buf);
      expect(r.success).toBe(true);
      expect(r.data.inputIndices).toEqual([0]);
      expect(r.data.outputIndices).toEqual([1]);
    });

    it('should extract operators with correct opcodeName', () => {
      const buf = buildTFLiteBuffer({
        operatorCodes: [{ builtinCode: 3 }, { builtinCode: 25 }],
        tensors: [{ name: 't', shape: [1], typeCode: 0 }],
        operators: [{ opcodeIndex: 0 }, { opcodeIndex: 0 }, { opcodeIndex: 1 }]
      });
      const r = parse(buf);
      expect(r.success).toBe(true);
      expect(r.data.operators).toHaveLength(3);
      expect(r.data.operators[0].opcodeName).toBe('CONV_2D');
      expect(r.data.operators[2].opcodeName).toBe('SOFTMAX');
    });
  });

  describe('Result structure invariant (Req 2.5)', () => {
    it('should have data with all required fields when success is true', () => {
      const buf = buildTFLiteBuffer({
        operatorCodes: [{ builtinCode: 0 }],
        tensors: [{ name: 't', shape: [1], typeCode: 0 }],
        operators: [{ opcodeIndex: 0 }]
      });
      const r = parse(buf);
      expect(r.success).toBe(true);
      expect(Array.isArray(r.data.operators)).toBe(true);
      expect(Array.isArray(r.data.tensors)).toBe(true);
      expect(typeof r.data.version).toBe('number');
      expect(typeof r.data.subgraphs).toBe('number');
      expect(Array.isArray(r.data.operatorCodes)).toBe(true);
      expect(Array.isArray(r.data.inputIndices)).toBe(true);
      expect(Array.isArray(r.data.outputIndices)).toBe(true);
    });

    it('should have non-empty error when success is false', () => {
      expect(parse(null).error.length).toBeGreaterThan(0);
    });
  });

  describe('Unknown types handling', () => {
    it('should display UNKNOWN_OP for unknown operator codes', () => {
      const buf = buildTFLiteBuffer({
        operatorCodes: [{ builtinCode: 250 }],
        tensors: [{ name: 't', shape: [1], typeCode: 0 }],
        operators: [{ opcodeIndex: 0 }]
      });
      const r = parse(buf);
      expect(r.success).toBe(true);
      expect(r.data.operatorCodes[0].opcodeName).toBe('UNKNOWN_OP_250');
    });

    it('should display UNKNOWN_TYPE for unknown tensor types', () => {
      const buf = buildTFLiteBuffer({
        operatorCodes: [{ builtinCode: 0 }],
        tensors: [{ name: 't', shape: [1], typeCode: 99 }],
        operators: [{ opcodeIndex: 0 }]
      });
      const r = parse(buf);
      expect(r.success).toBe(true);
      expect(r.data.tensors[0].dtype).toBe('UNKNOWN_TYPE_99');
    });
  });

  describe('Lookup tables', () => {
    it('should have 17 tensor types', () => {
      expect(Object.keys(TENSOR_TYPES)).toHaveLength(17);
    });

    it('should have bytes per element for all tensor types', () => {
      for (const n of Object.values(TENSOR_TYPES)) {
        expect(BYTES_PER_ELEMENT[n]).toBeGreaterThan(0);
      }
    });
  });
});