# erdos-straus-cuda / scripts/test_erdos_straus.py
# Last commit by cahlen (e0969aa, verified):
#   "Erdos-Straus CUDA kernel: torch extension with build.toml, tests, README"
# /// script
# dependencies = ["torch", "kernels"]
# ///
"""Test the Erdos-Straus kernel against CPU-verified values."""
import torch
from kernels import get_kernel
# Handle to the compiled CUDA extension, resolved by repository id through
# `kernels.get_kernel`; its `count` entry point is exercised below.
erdos_straus = get_kernel("cahlen/erdos-straus-cuda")

# Reference values of f(p), computed on the CPU with an exact match of the
# kernel's algorithm. f(p) is presumably the number of representations the
# kernel's search finds for 4/p (a value below 1 is treated as a
# counterexample further down) -- NOTE(review): confirm against the kernel.
# Iteration order decides the order in which primes are submitted to the
# kernel and reported.
KNOWN = {
    2: 1,
    3: 3,
    5: 2,
    7: 7,
    11: 9,
    13: 4,
    17: 4,
    19: 11,
    23: 21,
    29: 7,
    31: 19,
    37: 9,
    41: 7,
    43: 14,
    97: 8,
    101: 16,
    1009: 19,
}
# Evaluate every reference prime in a single kernel call on the GPU.
primes = torch.tensor(list(KNOWN.keys()), dtype=torch.int64, device="cuda")
counts = erdos_straus.count(primes)
# Copy the results to the host exactly once; each .tolist() on a CUDA tensor
# is a separate device->host transfer.
count_values = counts.tolist()

# zip() silently truncates, so a result of the wrong length would otherwise
# surface as a confusing KeyError below; fail with a clear message instead.
if len(count_values) != len(KNOWN):
    raise AssertionError(
        f"kernel returned {len(count_values)} counts for {len(KNOWN)} primes"
    )
results = dict(zip(KNOWN.keys(), count_values))

# Compare against the CPU reference, printing every disagreement before
# deciding the overall outcome so the full picture is visible on failure.
passed = failed = 0
for p, expected in KNOWN.items():
    got = results[p]
    if got != expected:
        print(f" FAIL: f({p}) = {got}, expected {expected}")
        failed += 1
    else:
        passed += 1
print(f"{passed}/{passed+failed} tests passed")

# Explicit raises instead of `assert`: assert statements are stripped under
# `python -O`, which would let a failing run report success.
#
# The counterexample check must come before the mismatch check: every
# expected value is >= 1, so a count below 1 is also a mismatch and would
# otherwise only ever be reported as a generic "tests failed".
counterexamples = [p for p, c in results.items() if c < 1]
if counterexamples:
    raise AssertionError(
        f"COUNTEREXAMPLE FOUND! f(p) < 1 for p in {counterexamples}"
    )
if failed:
    raise AssertionError(f"{failed} tests failed!")
print("Conjecture holds for all test primes")