File size: 2,160 Bytes
5fcb448
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import tensorflow as tf
import struct, os, sys

import signal
import subprocess

# Known-good model that every mutation starts from.
MODEL_PATH = "models/valid.tflite"
# Directory the corrupted models are written to (kept for later triage).
FUZZ_DIR = "fuzz"
# Seconds a single probe may run before it is reported as a hang.
PROBE_TIMEOUT_S = 30

# Program executed in a fresh interpreter for each corrupted model, so a
# native crash inside the TFLite runtime kills only the child, not the fuzzer.
# The model path and case name arrive via sys.argv instead of being spliced
# into this source text; that avoids brace-escaping and quoting pitfalls
# (an escaped "{{fpath}}" would otherwise become the literal path "{fpath}").
_PROBE_CODE = """
import sys
import tensorflow as tf

path, name = sys.argv[1], sys.argv[2]
try:
    interp = tf.lite.Interpreter(model_path=path)
    try:
        interp.allocate_tensors()
        try:
            interp.invoke()
            print(f"EXECUTED {name}")
        except Exception as e:
            print(f"invoke_fail: {type(e).__name__}")
    except Exception as e:
        print(f"alloc_fail: {type(e).__name__}")
except Exception as e:
    print(f"load_fail: {type(e).__name__}")
"""


def _build_mutations(size):
    """Return the mutation specs as ``(name, offset, patch)`` tuples.

    *size* is the byte length of the valid model; it is used to build an
    offset value pointing 100 bytes past the end of the buffer. The
    ``truncated_half`` and ``all_zeros`` entries carry an empty patch and are
    special-cased by ``_apply_mutation``.
    """
    return [
        ("root_offset_oob", 0, struct.pack("<I", 0xFFFFFFFF)),
        ("root_offset_zero", 0, struct.pack("<I", 0)),
        ("root_offset_past_end", 0, struct.pack("<I", size + 100)),
        ("ident_corrupt", 4, b"XXXX"),
        ("vtable_size_huge", 8, struct.pack("<H", 0xFFFF)),
        ("vtable_size_zero", 8, struct.pack("<H", 0)),
        ("negative_offset", 28, struct.pack("<i", -1)),
        ("truncated_half", 0, b""),  # special: truncate to half
        ("all_zeros", 0, b""),  # special: all zeros
    ]


def _apply_mutation(data, name, pos, val):
    """Return a corrupted copy of *data* for one mutation spec.

    *data* itself is never modified. ``truncated_half`` keeps the first half
    of the bytes, ``all_zeros`` returns a zero-filled buffer of the same
    length, and every other mutation overwrites ``data[pos:pos + len(val)]``
    with *val*.
    """
    if name == "truncated_half":
        return bytearray(data[:len(data) // 2])
    if name == "all_zeros":
        return bytearray(len(data))
    corrupted = bytearray(data)
    corrupted[pos:pos + len(val)] = val
    return corrupted


def _describe(result):
    """Summarise a finished probe (a ``subprocess.CompletedProcess``) in one line."""
    rc = result.returncode
    if rc < 0:
        # On POSIX a negative code means the child was killed by a signal
        # (e.g. -11 == SIGSEGV): the runtime crashed natively.
        try:
            sig_name = signal.Signals(-rc).name
        except ValueError:
            sig_name = f"signal {-rc}"
        return f"CRASH ({sig_name}) <<-- POTENTIAL VULN"
    output = (result.stdout or "").strip()
    if rc != 0:
        # Non-signal failure: the probe harness itself broke (e.g. tensorflow
        # not importable) or, possibly, a native crash on a platform that
        # reports crashes as positive exit codes. Surface the last stderr
        # line rather than printing an empty result.
        err_lines = (result.stderr or "").strip().splitlines()
        detail = err_lines[-1] if err_lines else (output or "no output")
        return f"EXIT {rc}: {detail}"
    return output


def _probe(fpath, name, timeout=PROBE_TIMEOUT_S):
    """Run ``_PROBE_CODE`` on *fpath* in a child interpreter and return a verdict."""
    try:
        result = subprocess.run(
            [sys.executable, "-c", _PROBE_CODE, fpath, name],
            capture_output=True, text=True, timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        # An input that makes the runtime spin forever is a finding too;
        # record it instead of letting the exception abort the whole run.
        return f"HANG (>{timeout}s) <<-- POTENTIAL VULN"
    return _describe(result)


def main():
    """Write one corrupted model per mutation to FUZZ_DIR and probe each one."""
    with open(MODEL_PATH, "rb") as f:
        valid_data = bytearray(f.read())
    os.makedirs(FUZZ_DIR, exist_ok=True)
    for name, pos, val in _build_mutations(len(valid_data)):
        fpath = os.path.join(FUZZ_DIR, f"{name}.tflite")
        with open(fpath, "wb") as f:
            f.write(_apply_mutation(valid_data, name, pos, val))
        print(f"  {name}: {_probe(fpath, name)}")


if __name__ == "__main__":
    main()