File size: 1,006 Bytes
e0969aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# /// script
# dependencies = ["torch", "kernels"]
# ///
"""Test the Erdos-Straus kernel against CPU-verified values."""
import torch
from kernels import get_kernel

# Resolve the kernel object registered under "cahlen/erdos-straus-cuda"; its
# `count` entry point is what the checks below exercise.
# NOTE(review): presumably get_kernel fetches a prebuilt kernel from the
# Hugging Face Hub on first use — confirm against the `kernels` docs.
erdos_straus = get_kernel("cahlen/erdos-straus-cuda")

# Reference table: test prime p -> expected f(p).
# The values were verified on the CPU (exact algorithm match).  Insertion
# order matters: it is the order in which the primes are handed to the kernel.
KNOWN = {
    2: 1,
    3: 3,
    5: 2,
    7: 7,
    11: 9,
    13: 4,
    17: 4,
    19: 11,
    23: 21,
    29: 7,
    31: 19,
    37: 9,
    41: 7,
    43: 14,
    97: 8,
    101: 16,
    1009: 19,
}

# Run the kernel once over every reference prime.  KNOWN's insertion order
# fixes the order of the input tensor, and the returned counts are paired back
# with the primes by position.
primes = torch.tensor(list(KNOWN), dtype=torch.int64, device="cuda")
counts = erdos_straus.count(primes)
count_values = counts.tolist()  # convert once; reused by both checks below

# zip() silently drops unmatched entries, so reject a size mismatch up front
# instead of letting it surface later as a bare KeyError (too few counts) or
# as values that are never checked (too many counts).
if len(count_values) != len(KNOWN):
    raise AssertionError(
        f"kernel returned {len(count_values)} counts for {len(KNOWN)} primes"
    )
results = dict(zip(KNOWN, count_values))

# Compare every kernel result with its reference value, reporting each
# mismatch individually so a single run shows all failures.
passed = failed = 0
for p, expected in KNOWN.items():
    got = results[p]
    if got != expected:
        print(f"  FAIL: f({p}) = {got}, expected {expected}")
        failed += 1
    else:
        passed += 1

print(f"{passed}/{passed+failed} tests passed")

# Raise explicitly instead of using `assert`: assert statements are removed
# when Python runs with -O, which would let a failing run exit successfully.
if failed:
    raise AssertionError(f"{failed} tests failed!")
# A count below 1 for any prime is treated as a counterexample.
if not all(c >= 1 for c in count_values):
    raise AssertionError("COUNTEREXAMPLE FOUND!")
print("Conjecture holds for all test primes")