Initial upload: torch-compatible CUDA kernel with pybind11 bindings and CPU tests
Browse files- README.md +58 -0
- build.toml +12 -0
- scripts/test.py +212 -0
- src/zaremba_density.cu +210 -0
- torch-ext/torch_binding.cpp +25 -0
- torch-ext/torch_binding.h +5 -0
README.md
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
tags:
|
| 4 |
+
- kernels
|
| 5 |
+
- cuda
|
| 6 |
+
- number-theory
|
| 7 |
+
- continued-fractions
|
| 8 |
+
- zaremba-conjecture
|
| 9 |
+
datasets:
|
| 10 |
+
- cahlen/zaremba-density
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
# Zaremba Density CUDA Kernel
|
| 14 |
+
|
| 15 |
+
GPU-accelerated computation of Zaremba density: for a digit set A and bound N, counts how many denominators d <= N have a continued fraction representation with all partial quotients in A.
|
| 16 |
+
|
| 17 |
+
This kernel was used to produce the results in [cahlen/zaremba-density](https://huggingface.co/datasets/cahlen/zaremba-density), computing densities for all 1,023 subsets of {1,...,10} at various scales up to 10^12.
|
| 18 |
+
|
| 19 |
+
## Algorithm
|
| 20 |
+
|
| 21 |
+
1. **CPU prefix generation**: Enumerate CF prefixes to a fixed depth, sorted by denominator descending for load balancing
|
| 22 |
+
2. **GPU persistent threads**: Each thread atomically claims prefixes and recursively extends them, marking denominators in a bitset
|
| 23 |
+
3. **CPU shallow marking**: Mark denominators from short CFs that fall below the prefix depth
|
| 24 |
+
4. **GPU popcount**: Count set bits in the bitset
|
| 25 |
+
|
| 26 |
+
The persistent-thread work-stealing design ensures good GPU utilization even when prefix subtrees vary widely in size.
|
| 27 |
+
|
| 28 |
+
## API
|
| 29 |
+
|
| 30 |
+
```python
|
| 31 |
+
import torch
|
| 32 |
+
# After building with the kernels build system:
|
| 33 |
+
from zaremba_density import count_representable
|
| 34 |
+
|
| 35 |
+
digits = torch.tensor([1, 2, 3], dtype=torch.int32)
|
| 36 |
+
result = count_representable(100, digits)
|
| 37 |
+
print(f"Representable: {result.item()} / 100")
|
| 38 |
+
# Expected: 98 (two denominators in 1..100 have no expansion with quotients in {1,2,3}; note d=1 itself is always representable via the empty CF)
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
### `count_representable(max_d: int, digits: Tensor[int32]) -> Tensor[int64]`
|
| 42 |
+
|
| 43 |
+
- **max_d**: Upper bound on denominators to check (inclusive)
|
| 44 |
+
- **digits**: 1-D tensor of allowed partial quotient digits
|
| 45 |
+
- **returns**: 1-element int64 tensor with the count of representable denominators
|
| 46 |
+
|
| 47 |
+
## Known Values
|
| 48 |
+
|
| 49 |
+
| Digit set | N | Count | Density |
|
| 50 |
+
|-----------|---|-------|---------|
|
| 51 |
+
| {1} | 10 | 5 | 50% |
|
| 52 |
+
| {1,2} | 20 | 16 | 80% |
|
| 53 |
+
| {1,2,3} | 100 | 98 | 98% |
|
| 54 |
+
| {1,2,3,4,5} | 10^6 | 999,987 | 99.9987% |
|
| 55 |
+
|
| 56 |
+
## Hardware
|
| 57 |
+
|
| 58 |
+
Developed and tested on RTX 5090 (32GB) and 8xB200 cluster (~1.4TB VRAM). For N > 10^9, a GPU with >= 16GB VRAM is recommended due to bitset size.
|
build.toml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[general]
|
| 2 |
+
name = "zaremba_density"
|
| 3 |
+
universal = false
|
| 4 |
+
|
| 5 |
+
[torch]
|
| 6 |
+
src = ["torch-ext/torch_binding.cpp", "torch-ext/torch_binding.h"]
|
| 7 |
+
|
| 8 |
+
[kernel.zaremba_density]
|
| 9 |
+
backend = "cuda"
|
| 10 |
+
cuda-capabilities = ["8.0", "9.0", "10.0", "12.0"]
|
| 11 |
+
src = ["src/zaremba_density.cu"]
|
| 12 |
+
depends = ["torch"]
|
scripts/test.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
CPU-only test for zaremba_density kernel logic.
|
| 4 |
+
|
| 5 |
+
Verifies the continued fraction denominator enumeration algorithm
|
| 6 |
+
against known values without requiring a GPU.
|
| 7 |
+
|
| 8 |
+
The algorithm enumerates all denominators d <= N such that there exists
|
| 9 |
+
a fraction a/d (gcd(a,d)=1) whose CF expansion [0; a_1, ..., a_k]
|
| 10 |
+
uses only partial quotients a_i from the digit set A.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import sys
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def count_representable_cpu(max_d: int, digits: list[int]) -> tuple[int, set[int]]:
    """
    CPU reference implementation of the Zaremba density algorithm.

    Enumerates all CF denominators <= max_d with partial quotients in `digits`.
    Uses iterative DFS on the CF tree, matching the GPU kernel logic.

    The CF tree starts from [0; a] for each a in digits, giving convergent 1/a.
    Each node (q_prev, q) extends to (q, a*q + q_prev) for each digit a.

    Args:
        max_d: Inclusive upper bound on denominators to enumerate.
        digits: Allowed partial quotients; every entry must be >= 1.

    Returns:
        (count, set_of_representable_d).

    Raises:
        ValueError: If any digit is < 1.  A digit of 0 makes the recurrence
            q_new = a*q_curr + q_prev stop growing (it just swaps the pair),
            so the DFS would loop forever; negative digits are meaningless
            for a CF expansion.
    """
    if any(a < 1 for a in digits):
        raise ValueError(f"all digits must be positive integers, got {digits}")

    # d=1 is always representable (empty CF) -- but only counts if it fits.
    representable = {1} if max_d >= 1 else set()

    # DFS stack of (q_prev, q_curr) denominator pairs.
    # Initial: CF [0; a] has q_prev=1, q=a.
    stack = [(1, a) for a in digits if a <= max_d]

    while stack:
        q_prev, q_curr = stack.pop()
        representable.add(q_curr)
        for a in digits:
            q_new = a * q_curr + q_prev
            # Denominators grow strictly (a >= 1), so this prune bounds the DFS.
            if q_new <= max_d:
                stack.append((q_curr, q_new))

    # Every pushed denominator was already <= max_d, so no final filter needed.
    return len(representable), representable
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def test_digits_1_n10():
    """A={1}, N=10: representable denominators are exactly the Fibonacci numbers.

    [0;1]=1/1, [0;1,1]=1/2, [0;1,1,1]=2/3, [0;1,1,1,1]=3/5, [0;1,1,1,1,1]=5/8
    Denominators: {1, 2, 3, 5, 8} = 5 values.
    """
    count, reprs = count_representable_cpu(10, [1])
    # Compare directly against the hand-computed Fibonacci set.
    assert reprs == {1, 2, 3, 5, 8}, f"A={{1}}, N=10: expected Fibonacci {{1,2,3,5,8}}, got {sorted(reprs)}"
    assert count == 5, f"A={{1}}, N=10: expected 5, got {count}"
    print(f"PASS: A={{1}}, N=10 -> {count} representable = {sorted(reprs)} (Fibonacci numbers)")
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def test_digits_1_n100():
    """A={1}, N=100: should be exactly the Fibonacci numbers <= 100."""
    count, reprs = count_representable_cpu(100, [1])
    # Build the expected set: Fibonacci numbers 1, 2, 3, 5, 8, ..., 89.
    fib = set()
    prev, curr = 1, 1
    while prev <= 100:
        fib.add(prev)
        prev, curr = curr, prev + curr
    assert reprs == fib, f"A={{1}}, N=100: expected Fibonacci, got {sorted(reprs)}"
    print(f"PASS: A={{1}}, N=100 -> {count} representable (Fibonacci: {sorted(reprs)})")
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def test_digits_12_n20():
    """A={1,2}, N=20: verify exact set of representable denominators."""
    count, reprs = count_representable_cpu(20, [1, 2])
    exceptions = sorted(d for d in range(1, 21) if d not in reprs)
    print(f" A={{1,2}}, N=20: {count}/20 representable")
    print(f" Representable: {sorted(reprs)}")
    print(f" Exceptions: {exceptions}")
    # Sanity checks on the shape of the result.
    assert 1 in reprs, "d=1 should always be representable"
    assert count > 10, f"A={{1,2}} should cover most of 1..20, got only {count}"
    # A={1} is a sub-alphabet of A={1,2}, so every Fibonacci denominator must appear.
    for fib in (1, 2, 3, 5, 8, 13):
        assert fib in reprs, f"Fibonacci {fib} should be representable with A={{1,2}}"
    print(f"PASS: A={{1,2}}, N=20 -> {count} representable, {len(exceptions)} exceptions = {exceptions}")
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def test_digits_123_n100():
    """A={1,2,3}, N=100: verify representable count."""
    count, reprs = count_representable_cpu(100, [1, 2, 3])
    exceptions = sorted(d for d in range(1, 101) if d not in reprs)
    print(f" A={{1,2,3}}, N=100: {count}/100 representable")
    print(f" Exceptions: {exceptions}")
    assert count >= 90, f"A={{1,2,3}} should cover >= 90% at N=100, got {count}"
    # Enlarging the alphabet can never lose denominators.
    _, reprs_12 = count_representable_cpu(100, [1, 2])
    assert reprs_12.issubset(reprs), "A={1,2} representable should be subset of A={1,2,3}"
    print(f"PASS: A={{1,2,3}}, N=100 -> {count} representable, {len(exceptions)} exceptions")
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def test_digits_1_n1():
    """Edge case: N=1, any digits containing 1."""
    # Only d=1 fits under the bound, and it is always representable (empty CF).
    count, reprs = count_representable_cpu(1, [1])
    assert count == 1, f"A={{1}}, N=1: expected 1, got {count}"
    assert reprs == {1}, f"A={{1}}, N=1: should be {{1}}, got {reprs}"
    print(f"PASS: A={{1}}, N=1 -> {count} representable")
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def test_digits_2_n10():
    """A={2}, N=10: only CFs with all partial quotients = 2."""
    count, reprs = count_representable_cpu(10, [2])
    # Denominator chain: [0;2] -> 2, [0;2,2] -> 5; next is 2*5+2 = 12 > 10.
    # Together with the trivial d=1 that gives exactly {1, 2, 5}.
    for d in (1, 2, 5):
        assert d in reprs
    print(f" A={{2}}, N=10: representable = {sorted(reprs)}")
    assert count == 3, f"A={{2}}, N=10: expected 3, got {count}"
    print(f"PASS: A={{2}}, N=10 -> {count} representable = {sorted(reprs)}")
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def test_digits_12345_small():
    """A={1,2,3,4,5} at small N -- should cover almost everything."""
    # With five allowed quotients the CF tree is dense; only a couple of
    # denominators below 50 should be missed.
    count, _ = count_representable_cpu(50, [1, 2, 3, 4, 5])
    print(f" A={{1,2,3,4,5}}, N=50 -> {count}/50 representable")
    assert count >= 48, f"A={{1,2,3,4,5}}, N=50: expected >= 48, got {count}"
    print(f"PASS: A={{1,2,3,4,5}}, N=50 -> {count} representable (>= 48)")
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def test_monotonicity():
    """Adding digits can only increase or maintain the count."""
    # Counts for nested alphabets {1}, {1,2}, ..., {1..5} at N=100.
    c1, c12, c123, c1234, c12345 = (
        count_representable_cpu(100, list(range(1, k + 1)))[0] for k in range(1, 6)
    )
    assert c1 <= c12 <= c123 <= c1234 <= c12345, \
        f"Monotonicity failed: {c1}, {c12}, {c123}, {c1234}, {c12345}"
    print(f"PASS: Monotonicity: {c1} <= {c12} <= {c123} <= {c1234} <= {c12345}")
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def test_cf_recurrence():
    """Verify the CF recurrence q_{n+1} = a_{n+1} * q_n + q_{n-1} directly."""
    # Walk [0; 3, 1, 4, 1, 5] through the recurrence, collecting denominators:
    # 3 -> 1*3+1=4 -> 4*4+3=19 -> 1*19+4=23 -> 5*23+19=134
    expected_denoms = [3]
    prev, curr = 1, 3
    for digit in [1, 4, 1, 5]:
        prev, curr = curr, digit * curr + prev
        expected_denoms.append(curr)
    assert expected_denoms == [3, 4, 19, 23, 134], f"CF recurrence failed: {expected_denoms}"
    print(f"PASS: CF recurrence [3,1,4,1,5] -> denominators {expected_denoms}")
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def test_subset_inclusion():
    """If A is a subset of B, then representable(A) is a subset of representable(B)."""
    # Compute the representable sets once, keyed by alphabet.
    sets = {
        tuple(alpha): count_representable_cpu(200, alpha)[1]
        for alpha in ([1], [1, 2], [1, 2, 3], [4, 5], [1, 2, 3, 4, 5])
    }

    assert sets[(1,)].issubset(sets[(1, 2)]), "A={1} not subset of A={1,2}"
    assert sets[(1, 2)].issubset(sets[(1, 2, 3)]), "A={1,2} not subset of A={1,2,3}"
    assert sets[(1, 2, 3)].issubset(sets[(1, 2, 3, 4, 5)]), "A={1,2,3} not subset of A={1,2,3,4,5}"
    assert sets[(4, 5)].issubset(sets[(1, 2, 3, 4, 5)]), "A={4,5} not subset of A={1,2,3,4,5}"
    print(f"PASS: Subset inclusion verified for nested digit sets")
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
if __name__ == "__main__":
    banner = "=" * 60
    print(banner)
    print("Zaremba Density -- CPU Reference Tests")
    print(banner)
    print()

    tests = [
        test_digits_1_n10,
        test_digits_1_n100,
        test_digits_12_n20,
        test_digits_123_n100,
        test_digits_1_n1,
        test_digits_2_n10,
        test_digits_12345_small,
        test_monotonicity,
        test_cf_recurrence,
        test_subset_inclusion,
    ]

    passed = failed = 0
    for test in tests:
        try:
            test()
        except AssertionError as e:
            print(f"FAIL: {test.__name__}: {e}")
            failed += 1
        except Exception as e:
            print(f"ERROR: {test.__name__}: {e}")
            failed += 1
        else:
            passed += 1
        print()

    print(banner)
    print(f"Results: {passed} passed, {failed} failed")
    print(banner)
    # Non-zero exit code on any failure so CI can gate on this script.
    sys.exit(0 if failed == 0 else 1)
|
src/zaremba_density.cu
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Zaremba Density CUDA Kernel — Torch-compatible version
|
| 3 |
+
*
|
| 4 |
+
* Enumerates all continued fraction denominators <= N with partial quotients
|
| 5 |
+
* from a given digit set A. Uses a persistent-thread work-stealing design
|
| 6 |
+
* with a bitset to track representable denominators.
|
| 7 |
+
*
|
| 8 |
+
* Original: zaremba_density_gpu.cu (standalone CLI)
|
| 9 |
+
* This version: torch C++ extension wrapper around the same kernel logic.
|
| 10 |
+
*/
|
| 11 |
+
|
| 12 |
+
#include <cstdint>
|
| 13 |
+
#include <cstdio>
|
| 14 |
+
#include <cstring>
|
| 15 |
+
|
| 16 |
+
#include <cuda_runtime.h>
|
| 17 |
+
|
| 18 |
+
#define MAX_DIGITS 10
|
| 19 |
+
#define MAX_DEPTH 200
|
| 20 |
+
|
| 21 |
+
// Mark a denominator in the bitset
|
| 22 |
+
// Mark a denominator in the bitset.
// Denominator d maps to bit index d (byte d >> 3, bit d & 7); out-of-range
// values (d < 1 or d > max_d) are silently ignored.
// CUDA has no byte-wide atomicOr, so the update is widened to an atomic OR
// on the enclosing 4-byte-aligned word, shifting the byte into position.
// NOTE(review): the word addressing assumes little-endian byte layout on the
// device, which holds for all current NVIDIA GPUs.
__device__ void mark(uint64_t d, uint8_t *bitset, uint64_t max_d) {
    if (d < 1 || d > max_d) return;
    uint64_t byte_idx = d >> 3;
    uint8_t bit = 1 << (d & 7);
    // byte_idx & ~3 = word-aligned base; 8 * (byte_idx & 3) = bit offset of
    // this byte within the 32-bit word.
    atomicOr((unsigned int*)&bitset[byte_idx & ~3],
             (unsigned int)bit << (8 * (byte_idx & 3)));
}
|
| 29 |
+
|
| 30 |
+
// Persistent-thread kernel: each thread self-schedules prefixes via atomic counter
|
| 31 |
+
// Persistent-thread kernel: each thread self-schedules prefixes via atomic counter.
//
// Each prefix is a CF node stored as 4 uint64 words (p_prev, p, q_prev, q);
// a thread claims one via atomicAdd on *progress, marks its denominator, and
// then runs an explicit-stack DFS extending the CF by every digit until the
// denominator exceeds max_d.  Work-stealing via the shared counter keeps
// threads busy even though subtree sizes vary wildly.
__global__ void enumerate_persistent(
    uint64_t *prefixes, int num_prefixes,
    int *digits, int num_digits,
    uint8_t *bitset, uint64_t max_d,
    int *progress)
{
    // Per-thread DFS stack in local memory; depth MAX_DEPTH (200) is far
    // beyond any reachable CF depth for 64-bit denominators (the slowest
    // growth, all-1s digits, is Fibonacci, which overflows ~depth 90).
    struct { uint64_t p_prev, p, q_prev, q; } stack[MAX_DEPTH];

    while (true) {
        // Claim the next unprocessed prefix; exit when the queue is drained.
        int my_prefix = atomicAdd(progress, 1);
        if (my_prefix >= num_prefixes) return;

        uint64_t pp0 = prefixes[my_prefix * 4 + 0];
        uint64_t p0 = prefixes[my_prefix * 4 + 1];
        uint64_t qp0 = prefixes[my_prefix * 4 + 2];
        uint64_t q0 = prefixes[my_prefix * 4 + 3];

        // The prefix's own denominator is representable.
        mark(q0, bitset, max_d);

        // Seed the stack with the prefix's in-range children.
        int sp = 0;
        for (int i = num_digits - 1; i >= 0; i--) {
            uint64_t a = digits[i];
            uint64_t q_new = a * q0 + qp0;
            // NOTE(review): sp >= MAX_DEPTH silently drops children; safe
            // only because MAX_DEPTH exceeds the reachable depth (see above).
            if (q_new > max_d || sp >= MAX_DEPTH) continue;
            stack[sp].p_prev = p0; stack[sp].p = a * p0 + pp0;
            stack[sp].q_prev = q0; stack[sp].q = q_new;
            sp++;
        }

        // Iterative DFS: pop a node, mark its denominator, push children
        // whose denominator a*q + q_prev still fits under max_d.
        while (sp > 0) {
            sp--;
            uint64_t pp = stack[sp].p_prev, p = stack[sp].p;
            uint64_t qp = stack[sp].q_prev, q = stack[sp].q;
            mark(q, bitset, max_d);
            for (int i = num_digits - 1; i >= 0; i--) {
                uint64_t a = digits[i];
                uint64_t q_new = a * q + qp;
                if (q_new > max_d || sp >= MAX_DEPTH) continue;
                stack[sp].p_prev = p; stack[sp].p = a * p + pp;
                stack[sp].q_prev = q; stack[sp].q = q_new;
                sp++;
            }
        }
    }
}
|
| 76 |
+
|
| 77 |
+
// Count set bits in the bitset
|
| 78 |
+
// Count set bits in the bitset.
// One thread per byte; per-byte popcounts are accumulated into *count with a
// 64-bit atomicAdd (skipped when the byte is empty to cut atomic traffic).
// The bitset spans bits 1..max_d, so max_byte = floor(max_d/8) + 1 bytes are
// live; in the final byte only bits 0..(max_d % 8) are meaningful and the
// rest are masked off before counting.
__global__ void count_marked(uint8_t *bitset, uint64_t max_d, uint64_t *count) {
    uint64_t tid = blockIdx.x * (uint64_t)blockDim.x + threadIdx.x;
    uint64_t max_byte = (max_d + 8) / 8;
    if (tid >= max_byte) return;
    uint8_t b = bitset[tid];
    int bits = __popc((unsigned int)b);
    if (tid == max_byte - 1) {
        int valid_bits = (max_d % 8) + 1;
        bits = __popc((unsigned int)(b & ((1 << valid_bits) - 1)));
    }
    if (bits > 0) atomicAdd((unsigned long long*)count, (unsigned long long)bits);
}
|
| 90 |
+
|
| 91 |
+
// C++ host function called from torch binding
|
| 92 |
+
// C++ host function called from torch binding.
//
// Pipeline (see README): (1) CPU DFS generates CF prefixes at a fixed depth,
// (2) GPU persistent-thread kernel extends each prefix subtree into a bitset,
// (3) CPU re-walks the shallow part of the tree (depth < PREFIX_DEPTH) to
// mark denominators the prefixes skip over, (4) GPU popcount of the bitset.
//
// NOTE(review): no cudaMalloc/cudaMemcpy/launch error checking anywhere; an
// allocation failure (e.g. bitset for very large max_d) would propagate as
// silent garbage.  Also, prefixes are emitted in DFS order, not sorted by
// denominator descending as the README describes -- confirm which is intended.
extern "C" int64_t zaremba_count_representable(int64_t max_d, int *h_digits, int num_digits) {
    // Generate prefixes on CPU.  Deeper prefixes for large N give the GPU
    // more, smaller work items for better load balance.
    int PREFIX_DEPTH = 8;
    if (max_d >= 1000000000LL) PREFIX_DEPTH = 15;

    // NOTE(review): 20M prefixes * 4 * 8B = 640 MB host allocation regardless
    // of problem size; prefixes beyond the cap are silently dropped below,
    // which would undercount -- confirm the cap is never hit at target scales.
    int max_prefixes = 20000000;
    uint64_t *h_prefixes = new uint64_t[max_prefixes * 4];
    int np = 0;

    // CF node: numerator pair (pp, p), denominator pair (qp, q), tree depth.
    struct PfxEntry { uint64_t pp, p, qp, q; int depth; };
    PfxEntry *stk = new PfxEntry[max_prefixes];
    int ssp = 0;
    // Roots: [0; a] has convergent 1/a, i.e. (pp,p,qp,q) = (0,1,1,a).
    for (int i = 0; i < num_digits; i++) {
        stk[ssp] = {0, 1, 1, (uint64_t)h_digits[i], 1};
        ssp++;
    }
    // DFS: nodes at PREFIX_DEPTH become GPU work items; shallower nodes are
    // expanded further on the CPU.
    while (ssp > 0) {
        ssp--;
        uint64_t pp = stk[ssp].pp, p = stk[ssp].p;
        uint64_t qp = stk[ssp].qp, q = stk[ssp].q;
        int dep = stk[ssp].depth;
        if (q > (uint64_t)max_d) continue;
        if (dep >= PREFIX_DEPTH) {
            if (np < max_prefixes) {
                h_prefixes[np*4+0] = pp; h_prefixes[np*4+1] = p;
                h_prefixes[np*4+2] = qp; h_prefixes[np*4+3] = q;
                np++;
            }
        } else {
            for (int i = num_digits - 1; i >= 0; i--) {
                uint64_t qn = (uint64_t)h_digits[i] * q + qp;
                if (qn > (uint64_t)max_d || ssp >= max_prefixes - 1) continue;
                stk[ssp] = {p, (uint64_t)h_digits[i] * p + pp, q, qn, dep + 1};
                ssp++;
            }
        }
    }
    delete[] stk;

    // GPU allocation: one bit per denominator 1..max_d (bit index d).
    uint64_t bitset_bytes = ((uint64_t)max_d + 8) / 8;
    uint8_t *d_bs;
    cudaMalloc(&d_bs, bitset_bytes);
    cudaMemset(d_bs, 0, bitset_bytes);

    int *d_digits;
    cudaMalloc(&d_digits, num_digits * sizeof(int));
    cudaMemcpy(d_digits, h_digits, num_digits * sizeof(int), cudaMemcpyHostToDevice);

    uint64_t *d_prefixes;
    cudaMalloc(&d_prefixes, (uint64_t)np * 4 * sizeof(uint64_t));
    cudaMemcpy(d_prefixes, h_prefixes, (uint64_t)np * 4 * sizeof(uint64_t), cudaMemcpyHostToDevice);

    // Shared work-stealing counter consumed by the persistent threads.
    int *d_progress;
    cudaMalloc(&d_progress, sizeof(int));
    cudaMemset(d_progress, 0, sizeof(int));

    // Launch.  Grid size is capped; the kernel is persistent, so fewer
    // threads than prefixes is fine -- threads loop until the counter drains.
    int block_size = 256;
    int grid_size = (np + block_size - 1) / block_size;
    if (grid_size > 65535) grid_size = 65535;

    enumerate_persistent<<<grid_size, block_size>>>(
        d_prefixes, np, d_digits, num_digits, d_bs, (uint64_t)max_d, d_progress);
    cudaDeviceSynchronize();

    // Mark shallow denominators on CPU: the GPU only saw nodes at depth >=
    // PREFIX_DEPTH, so denominators of shorter CFs are filled in here.
    uint8_t *h_bs = new uint8_t[bitset_bytes];
    cudaMemcpy(h_bs, d_bs, bitset_bytes, cudaMemcpyDeviceToHost);
    h_bs[0] |= (1 << 1); // d=1
    // d=1 is always representable (empty CF) but never appears as a CF
    // denominator in the tree walk, so it is set explicitly.

    // Re-run the same shallow DFS, marking every denominator up to
    // PREFIX_DEPTH (inclusive) directly in the host copy of the bitset.
    // NOTE(review): fixed 500000-entry stack with silent drop on overflow
    // (csp >= 499999) -- bounded in practice by the shallow depth, but worth
    // an explicit check.
    PfxEntry *cstk = new PfxEntry[500000];
    int csp = 0;
    for (int i = 0; i < num_digits; i++) {
        cstk[csp] = {0, 1, 1, (uint64_t)h_digits[i], 1};
        csp++;
    }
    while (csp > 0) {
        csp--;
        uint64_t q = cstk[csp].q;
        int dep = cstk[csp].depth;
        if (q > (uint64_t)max_d) continue;
        h_bs[q >> 3] |= (1 << (q & 7));
        if (dep >= PREFIX_DEPTH) continue;
        uint64_t pp = cstk[csp].pp, p = cstk[csp].p, qp = cstk[csp].qp;
        for (int i = 0; i < num_digits; i++) {
            uint64_t qn = (uint64_t)h_digits[i] * q + qp;
            if (qn > (uint64_t)max_d || csp >= 499999) continue;
            cstk[csp] = {p, (uint64_t)h_digits[i] * p + pp, q, qn, dep + 1};
            csp++;
        }
    }
    delete[] cstk;
    // Push the merged (GPU + shallow CPU) bitset back for the count pass.
    cudaMemcpy(d_bs, h_bs, bitset_bytes, cudaMemcpyHostToDevice);

    // Count on GPU: one thread per bitset byte.
    uint64_t *d_count;
    cudaMalloc(&d_count, sizeof(uint64_t));
    cudaMemset(d_count, 0, sizeof(uint64_t));
    {
        uint64_t max_byte = ((uint64_t)max_d + 8) / 8;
        int gd = (int)((max_byte + 255) / 256);
        count_marked<<<gd, 256>>>(d_bs, (uint64_t)max_d, d_count);
        cudaDeviceSynchronize();
    }
    int64_t covered = 0;
    cudaMemcpy(&covered, d_count, sizeof(uint64_t), cudaMemcpyDeviceToHost);

    // Cleanup
    cudaFree(d_count);
    cudaFree(d_bs);
    cudaFree(d_digits);
    cudaFree(d_prefixes);
    cudaFree(d_progress);
    delete[] h_prefixes;
    delete[] h_bs;

    return covered;
}
|
torch-ext/torch_binding.cpp
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <limits>

#include <torch/extension.h>

#include "torch_binding.h"
|
| 3 |
+
|
| 4 |
+
// Defined in the CUDA kernel
|
| 5 |
+
extern "C" int64_t zaremba_count_representable(int64_t max_d, int *h_digits, int num_digits);
|
| 6 |
+
|
| 7 |
+
// Count denominators <= max_d whose continued fraction uses only quotients
// from `digits`.  Validates the input tensor, then delegates to the CUDA
// host routine.  Returns a 1-element int64 CPU tensor holding the count.
torch::Tensor count_representable(int64_t max_d, torch::Tensor digits) {
    TORCH_CHECK(digits.dtype() == torch::kInt32, "digits must be int32");
    TORCH_CHECK(digits.dim() == 1, "digits must be 1-D");
    TORCH_CHECK(digits.is_cpu(), "digits must be on CPU");
    TORCH_CHECK(max_d > 0, "max_d must be positive");

    // data_ptr() is read as a dense array by the kernel, so a strided view
    // (e.g. t[::2]) would be read incorrectly -- normalize first.
    torch::Tensor dense = digits.contiguous();

    int64_t num_digits = dense.size(0);
    TORCH_CHECK(num_digits > 0, "digits must be non-empty");
    TORCH_CHECK(num_digits <= std::numeric_limits<int>::max(),
                "too many digits");

    int *h_digits = dense.data_ptr<int>();
    // Non-positive partial quotients are meaningless for a CF expansion and
    // would stall the host-side enumeration; reject them up front.
    for (int64_t i = 0; i < num_digits; ++i) {
        TORCH_CHECK(h_digits[i] >= 1,
                    "digits must be positive, got ", h_digits[i]);
    }

    int64_t count = zaremba_count_representable(
        max_d, h_digits, static_cast<int>(num_digits));

    // 1-element CPU tensor so the result composes with other torch code.
    return torch::tensor({count}, torch::dtype(torch::kInt64));
}
|
| 21 |
+
|
| 22 |
+
// Python module entry point: exposes count_representable(max_d, digits) via
// pybind11.  TORCH_EXTENSION_NAME is injected by the torch extension build.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("count_representable", &count_representable,
          "Count denominators <= max_d representable with given CF digit set");
}
|
torch-ext/torch_binding.h
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/types.h>
|
| 4 |
+
|
| 5 |
+
torch::Tensor count_representable(int64_t max_d, torch::Tensor digits);
|