"""Test the Erdos-Straus kernel against CPU-verified values.

For each prime p in ``KNOWN``, the value returned by the kernel's ``count``
for p is compared with an independently computed CPU value.  A tested prime
with a count below 1 is reported as a counterexample.

NOTE(review): the count is presumably the number of solutions of
4/p = 1/x + 1/y + 1/z in positive integers (x <= y <= z) -- confirm against
the kernel's documentation.
"""
import torch
from kernels import get_kernel


erdos_straus = get_kernel("cahlen/erdos-straus-cuda")


# prime -> expected value of ``erdos_straus.count`` for that prime (CPU-verified).
KNOWN = {
    2: 1, 3: 3, 5: 2, 7: 7, 11: 9, 13: 4, 17: 4, 19: 11,
    23: 21, 29: 7, 31: 19, 37: 9, 41: 7, 43: 14, 97: 8,
    101: 16, 1009: 19,
}


primes = torch.tensor(list(KNOWN.keys()), dtype=torch.int64, device="cuda")
counts = erdos_straus.count(primes)
# Copy the result to the host once and reuse the list below.
count_values = counts.tolist()
# zip() would silently drop primes (or extra outputs) on a length mismatch,
# so verify the kernel returned exactly one count per input prime.
if len(count_values) != len(KNOWN):
    raise AssertionError(
        f"kernel returned {len(count_values)} counts for {len(KNOWN)} primes"
    )
results = dict(zip(KNOWN.keys(), count_values))


passed = failed = 0
for p, expected in KNOWN.items():
    got = results[p]
    if got != expected:
        print(f" FAIL: f({p}) = {got}, expected {expected}")
        failed += 1
    else:
        passed += 1


print(f"{passed}/{passed+failed} tests passed")
# Check for counterexamples BEFORE the generic failure check: every expected
# value above is >= 1, so a zero count is also a mismatch, and checking
# failures first would mask it.
# Explicit raises (not ``assert``) so the checks still run under ``python -O``.
counterexamples = [p for p, c in results.items() if c < 1]
if counterexamples:
    raise AssertionError(f"COUNTEREXAMPLE FOUND! primes: {counterexamples}")
if failed:
    raise AssertionError(f"{failed} tests failed!")
print("Conjecture holds for all test primes")
|
|