cahlen commited on
Commit
fe9f881
·
verified ·
1 Parent(s): 461c4da

Initial upload: torch-compatible CUDA kernel with pybind11 bindings and CPU tests

Browse files
README.md ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - kernels
5
+ - cuda
6
+ - number-theory
7
+ - continued-fractions
8
+ - zaremba-conjecture
9
+ datasets:
10
+ - cahlen/zaremba-density
11
+ ---
12
+
13
+ # Zaremba Density CUDA Kernel
14
+
15
+ GPU-accelerated computation of Zaremba density: for a digit set A and bound N, counts how many denominators d <= N have a continued fraction representation with all partial quotients in A.
16
+
17
+ This kernel was used to produce the results in [cahlen/zaremba-density](https://huggingface.co/datasets/cahlen/zaremba-density), computing densities for all 1,023 subsets of {1,...,10} at various scales up to 10^12.
18
+
19
+ ## Algorithm
20
+
21
+ 1. **CPU prefix generation**: Enumerate CF prefixes to a fixed depth, sorted by denominator descending for load balancing
22
+ 2. **GPU persistent threads**: Each thread atomically claims prefixes and recursively extends them, marking denominators in a bitset
23
+ 3. **CPU shallow marking**: Mark denominators from short CFs that fall below the prefix depth
24
+ 4. **GPU popcount**: Count set bits in the bitset
25
+
26
+ The persistent-thread work-stealing design ensures good GPU utilization even when prefix subtrees vary widely in size.
27
+
28
+ ## API
29
+
30
+ ```python
31
+ import torch
32
+ # After building with the kernels build system:
33
+ from zaremba_density import count_representable
34
+
35
+ digits = torch.tensor([1, 2, 3], dtype=torch.int32)
36
+ result = count_representable(100, digits)
37
+ print(f"Representable: {result.item()} / 100")
38
+ # Expected: 98 (only two denominators <= 100, among them d=97, are not representable with {1,2,3}; d=1 is always representable via the empty CF)
39
+ ```
40
+
41
+ ### `count_representable(max_d: int, digits: Tensor[int32]) -> Tensor[int64]`
42
+
43
+ - **max_d**: Upper bound on denominators to check (inclusive)
44
+ - **digits**: 1-D tensor of allowed partial quotient digits
45
+ - **returns**: 1-element int64 tensor with the count of representable denominators
46
+
47
+ ## Known Values
48
+
49
+ | Digit set | N | Count | Density |
50
+ |-----------|---|-------|---------|
51
+ | {1} | 10 | 5 | 50% |
52
+ | {1,2} | 20 | 16 | 80% |
53
+ | {1,2,3} | 100 | 98 | 98% |
54
+ | {1,2,3,4,5} | 10^6 | 999,987 | 99.9987% |
55
+
56
+ ## Hardware
57
+
58
+ Developed and tested on RTX 5090 (32GB) and 8xB200 cluster (~1.4TB VRAM). For N > 10^9, a GPU with >= 16GB VRAM is recommended due to bitset size.
build.toml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [general]
2
+ name = "zaremba_density"
3
+ universal = false
4
+
5
+ [torch]
6
+ src = ["torch-ext/torch_binding.cpp", "torch-ext/torch_binding.h"]
7
+
8
+ [kernel.zaremba_density]
9
+ backend = "cuda"
10
+ cuda-capabilities = ["8.0", "9.0", "10.0", "12.0"]
11
+ src = ["src/zaremba_density.cu"]
12
+ depends = ["torch"]
scripts/test.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ CPU-only test for zaremba_density kernel logic.
4
+
5
+ Verifies the continued fraction denominator enumeration algorithm
6
+ against known values without requiring a GPU.
7
+
8
+ The algorithm enumerates all denominators d <= N such that there exists
9
+ a fraction a/d (gcd(a,d)=1) whose CF expansion [0; a_1, ..., a_k]
10
+ uses only partial quotients a_i from the digit set A.
11
+ """
12
+
13
+ import sys
14
+
15
+
16
def count_representable_cpu(max_d: int, digits: list[int]) -> tuple[int, set[int]]:
    """
    CPU reference implementation of the Zaremba density algorithm.

    Walks the continued-fraction tree with an iterative DFS, mirroring the
    GPU kernel: a node (q_prev, q) has one child (q, a*q + q_prev) per
    digit a, and every node's q is a representable denominator.

    The roots are the depth-1 CFs [0; a], whose convergent is 1/a.

    Returns (count, set_of_representable_d).
    """
    found = set()
    found.add(1)  # d=1 is always representable (empty CF)

    # DFS worklist of (q_prev, q_curr) pairs, seeded with in-range roots.
    pending = [(1, a) for a in digits if a <= max_d]

    while pending:
        prev_q, cur_q = pending.pop()
        found.add(cur_q)
        pending.extend(
            (cur_q, a * cur_q + prev_q)
            for a in digits
            if a * cur_q + prev_q <= max_d
        )

    reachable = {d for d in found if d <= max_d}
    return len(reachable), reachable
48
+
49
+
50
def test_digits_1_n10():
    """A={1}, N=10: the reachable denominators are exactly the Fibonacci numbers.

    [0;1]=1/1, [0;1,1]=1/2, [0;1,1,1]=2/3, [0;1,1,1,1]=3/5, [0;1,1,1,1,1]=5/8,
    giving denominators {1, 2, 3, 5, 8} — five values.
    """
    count, reprs = count_representable_cpu(10, [1])
    expected = {1, 2, 3, 5, 8}
    assert reprs == expected, f"A={{1}}, N=10: expected Fibonacci {{1,2,3,5,8}}, got {sorted(reprs)}"
    assert count == 5, f"A={{1}}, N=10: expected 5, got {count}"
    print(f"PASS: A={{1}}, N=10 -> {count} representable = {sorted(reprs)} (Fibonacci numbers)")
61
+
62
+
63
def test_digits_1_n100():
    """A={1}, N=100: should be exactly the Fibonacci numbers <= 100."""
    count, reprs = count_representable_cpu(100, [1])
    # Build the Fibonacci set 1, 2, 3, 5, 8, 13, 21, 34, 55, 89 up to 100.
    expected = set()
    lo, hi = 1, 1
    while lo <= 100:
        expected.add(lo)
        lo, hi = hi, lo + hi
    assert reprs == expected, f"A={{1}}, N=100: expected Fibonacci, got {sorted(reprs)}"
    print(f"PASS: A={{1}}, N=100 -> {count} representable (Fibonacci: {sorted(reprs)})")
74
+
75
+
76
def test_digits_12_n20():
    """A={1,2}, N=20: verify exact set of representable denominators."""
    count, reprs = count_representable_cpu(20, [1, 2])
    exceptions = sorted(set(range(1, 21)) - reprs)
    print(f" A={{1,2}}, N=20: {count}/20 representable")
    print(f" Representable: {sorted(reprs)}")
    print(f" Exceptions: {exceptions}")
    # Basic sanity checks first.
    assert 1 in reprs, "d=1 should always be representable"
    assert count > 10, f"A={{1,2}} should cover most of 1..20, got only {count}"
    # {1} is contained in {1,2}, so every Fibonacci denominator must appear.
    for fib in (1, 2, 3, 5, 8, 13):
        assert fib in reprs, f"Fibonacci {fib} should be representable with A={{1,2}}"
    print(f"PASS: A={{1,2}}, N=20 -> {count} representable, {len(exceptions)} exceptions = {exceptions}")
90
+
91
+
92
def test_digits_123_n100():
    """A={1,2,3}, N=100: verify representable count."""
    count, reprs = count_representable_cpu(100, [1, 2, 3])
    exceptions = sorted(set(range(1, 101)) - reprs)
    print(f" A={{1,2,3}}, N=100: {count}/100 representable")
    print(f" Exceptions: {exceptions}")
    # Coverage should be high at this scale.
    assert count >= 90, f"A={{1,2,3}} should cover >= 90% at N=100, got {count}"
    # Monotonicity in the digit set: everything reachable with {1,2} stays reachable.
    count_12, reprs_12 = count_representable_cpu(100, [1, 2])
    assert reprs_12 <= reprs, "A={1,2} representable should be subset of A={1,2,3}"
    print(f"PASS: A={{1,2,3}}, N=100 -> {count} representable, {len(exceptions)} exceptions")
104
+
105
+
106
def test_digits_1_n1():
    """Edge case: N=1, any digits containing 1."""
    n_found, found = count_representable_cpu(1, [1])
    assert n_found == 1, f"A={{1}}, N=1: expected 1, got {n_found}"
    assert found == {1}, f"A={{1}}, N=1: should be {{1}}, got {found}"
    print(f"PASS: A={{1}}, N=1 -> {n_found} representable")
112
+
113
+
114
def test_digits_2_n10():
    """A={2}, N=10: only CFs with all partial quotients = 2."""
    count, reprs = count_representable_cpu(10, [2])
    # Denominator chain for digit 2: [0;2] -> 2, [0;2,2] -> 5; the next step
    # gives 2*5 + 2 = 12 > 10. Including the trivial d=1: {1, 2, 5}.
    assert {1, 2, 5} <= reprs
    print(f" A={{2}}, N=10: representable = {sorted(reprs)}")
    assert count == 3, f"A={{2}}, N=10: expected 3, got {count}"
    print(f"PASS: A={{2}}, N=10 -> {count} representable = {sorted(reprs)}")
124
+
125
+
126
def test_digits_12345_small():
    """A={1,2,3,4,5} at small N -- should cover almost everything."""
    hits, _ = count_representable_cpu(50, [1, 2, 3, 4, 5])
    print(f" A={{1,2,3,4,5}}, N=50 -> {hits}/50 representable")
    assert hits >= 48, f"A={{1,2,3,4,5}}, N=50: expected >= 48, got {hits}"
    print(f"PASS: A={{1,2,3,4,5}}, N=50 -> {hits} representable (>= 48)")
132
+
133
+
134
def test_monotonicity():
    """Adding digits can only increase or maintain the count."""
    # Counts for digit sets {1}, {1,2}, ..., {1,2,3,4,5} at N=100.
    counts = [count_representable_cpu(100, list(range(1, k + 1)))[0] for k in range(1, 6)]
    c1, c12, c123, c1234, c12345 = counts
    assert c1 <= c12 <= c123 <= c1234 <= c12345, \
        f"Monotonicity failed: {c1}, {c12}, {c123}, {c1234}, {c12345}"
    print(f"PASS: Monotonicity: {c1} <= {c12} <= {c123} <= {c1234} <= {c12345}")
144
+
145
+
146
def test_cf_recurrence():
    """Verify the CF recurrence q_{n+1} = a_{n+1} * q_n + q_{n-1} directly."""
    # Hand-checked trace for [0; 3, 1, 4, 1, 5]:
    #   start (q_prev, q) = (1, 3); a=1 -> 4; a=4 -> 19; a=1 -> 23; a=5 -> 134
    prev_q, cur_q = 1, 3
    denoms = [cur_q]
    for quotient in (1, 4, 1, 5):
        prev_q, cur_q = cur_q, quotient * cur_q + prev_q
        denoms.append(cur_q)
    assert denoms == [3, 4, 19, 23, 134], f"CF recurrence failed: {denoms}"
    print(f"PASS: CF recurrence [3,1,4,1,5] -> denominators {denoms}")
159
+
160
+
161
def test_subset_inclusion():
    """If A is a subset of B, then representable(A) is a subset of representable(B)."""
    # Representable sets at N=200 for several digit sets, keyed by name.
    reach = {
        "1": count_representable_cpu(200, [1])[1],
        "12": count_representable_cpu(200, [1, 2])[1],
        "123": count_representable_cpu(200, [1, 2, 3])[1],
        "45": count_representable_cpu(200, [4, 5])[1],
        "12345": count_representable_cpu(200, [1, 2, 3, 4, 5])[1],
    }

    assert reach["1"].issubset(reach["12"]), "A={1} not subset of A={1,2}"
    assert reach["12"].issubset(reach["123"]), "A={1,2} not subset of A={1,2,3}"
    assert reach["123"].issubset(reach["12345"]), "A={1,2,3} not subset of A={1,2,3,4,5}"
    assert reach["45"].issubset(reach["12345"]), "A={4,5} not subset of A={1,2,3,4,5}"
    print(f"PASS: Subset inclusion verified for nested digit sets")
174
+
175
+
176
if __name__ == "__main__":
    banner = "=" * 60
    print(banner)
    print("Zaremba Density -- CPU Reference Tests")
    print(banner)
    print()

    all_tests = (
        test_digits_1_n10,
        test_digits_1_n100,
        test_digits_12_n20,
        test_digits_123_n100,
        test_digits_1_n1,
        test_digits_2_n10,
        test_digits_12345_small,
        test_monotonicity,
        test_cf_recurrence,
        test_subset_inclusion,
    )

    passed = 0
    failed = 0
    for case in all_tests:
        try:
            case()
        except AssertionError as e:
            print(f"FAIL: {case.__name__}: {e}")
            failed += 1
        except Exception as e:
            print(f"ERROR: {case.__name__}: {e}")
            failed += 1
        else:
            passed += 1
        print()

    print(banner)
    print(f"Results: {passed} passed, {failed} failed")
    print(banner)
    sys.exit(0 if failed == 0 else 1)
src/zaremba_density.cu ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Zaremba Density CUDA Kernel — Torch-compatible version
3
+ *
4
+ * Enumerates all continued fraction denominators <= N with partial quotients
5
+ * from a given digit set A. Uses a persistent-thread work-stealing design
6
+ * with a bitset to track representable denominators.
7
+ *
8
+ * Original: zaremba_density_gpu.cu (standalone CLI)
9
+ * This version: torch C++ extension wrapper around the same kernel logic.
10
+ */
11
+
12
+ #include <cstdint>
13
+ #include <cstdio>
14
+ #include <cstring>
15
+
16
+ #include <cuda_runtime.h>
17
+
18
+ #define MAX_DIGITS 10
19
+ #define MAX_DEPTH 200
20
+
21
+ // Mark a denominator in the bitset
22
+ __device__ void mark(uint64_t d, uint8_t *bitset, uint64_t max_d) {
23
+ if (d < 1 || d > max_d) return;
24
+ uint64_t byte_idx = d >> 3;
25
+ uint8_t bit = 1 << (d & 7);
26
+ atomicOr((unsigned int*)&bitset[byte_idx & ~3],
27
+ (unsigned int)bit << (8 * (byte_idx & 3)));
28
+ }
29
+
30
+ // Persistent-thread kernel: each thread self-schedules prefixes via atomic counter
31
+ __global__ void enumerate_persistent(
32
+ uint64_t *prefixes, int num_prefixes,
33
+ int *digits, int num_digits,
34
+ uint8_t *bitset, uint64_t max_d,
35
+ int *progress)
36
+ {
37
+ struct { uint64_t p_prev, p, q_prev, q; } stack[MAX_DEPTH];
38
+
39
+ while (true) {
40
+ int my_prefix = atomicAdd(progress, 1);
41
+ if (my_prefix >= num_prefixes) return;
42
+
43
+ uint64_t pp0 = prefixes[my_prefix * 4 + 0];
44
+ uint64_t p0 = prefixes[my_prefix * 4 + 1];
45
+ uint64_t qp0 = prefixes[my_prefix * 4 + 2];
46
+ uint64_t q0 = prefixes[my_prefix * 4 + 3];
47
+
48
+ mark(q0, bitset, max_d);
49
+
50
+ int sp = 0;
51
+ for (int i = num_digits - 1; i >= 0; i--) {
52
+ uint64_t a = digits[i];
53
+ uint64_t q_new = a * q0 + qp0;
54
+ if (q_new > max_d || sp >= MAX_DEPTH) continue;
55
+ stack[sp].p_prev = p0; stack[sp].p = a * p0 + pp0;
56
+ stack[sp].q_prev = q0; stack[sp].q = q_new;
57
+ sp++;
58
+ }
59
+
60
+ while (sp > 0) {
61
+ sp--;
62
+ uint64_t pp = stack[sp].p_prev, p = stack[sp].p;
63
+ uint64_t qp = stack[sp].q_prev, q = stack[sp].q;
64
+ mark(q, bitset, max_d);
65
+ for (int i = num_digits - 1; i >= 0; i--) {
66
+ uint64_t a = digits[i];
67
+ uint64_t q_new = a * q + qp;
68
+ if (q_new > max_d || sp >= MAX_DEPTH) continue;
69
+ stack[sp].p_prev = p; stack[sp].p = a * p + pp;
70
+ stack[sp].q_prev = q; stack[sp].q = q_new;
71
+ sp++;
72
+ }
73
+ }
74
+ }
75
+ }
76
+
77
+ // Count set bits in the bitset
78
+ __global__ void count_marked(uint8_t *bitset, uint64_t max_d, uint64_t *count) {
79
+ uint64_t tid = blockIdx.x * (uint64_t)blockDim.x + threadIdx.x;
80
+ uint64_t max_byte = (max_d + 8) / 8;
81
+ if (tid >= max_byte) return;
82
+ uint8_t b = bitset[tid];
83
+ int bits = __popc((unsigned int)b);
84
+ if (tid == max_byte - 1) {
85
+ int valid_bits = (max_d % 8) + 1;
86
+ bits = __popc((unsigned int)(b & ((1 << valid_bits) - 1)));
87
+ }
88
+ if (bits > 0) atomicAdd((unsigned long long*)count, (unsigned long long)bits);
89
+ }
90
+
91
// C++ host function called from torch binding.
//
// Counts denominators d in [1, max_d] that occur as a continued fraction
// denominator using only partial quotients from h_digits[0..num_digits-1].
//
// Pipeline:
//   1. CPU: DFS the CF tree to PREFIX_DEPTH, emitting (p_prev, p, q_prev, q)
//      prefix records sized so GPU subtrees balance reasonably well.
//   2. GPU: persistent-thread kernel extends each prefix, marking every
//      reachable denominator in a 1-bit-per-denominator bitset.
//   3. CPU: re-walk the shallow part of the tree (depth < PREFIX_DEPTH) and
//      mark those denominators, plus d=1 (empty CF), in a host copy.
//   4. GPU: popcount the merged bitset.
//
// NOTE(review): CUDA API return codes are not checked here; an allocation
// failure (e.g. the bitset for very large max_d) would surface only as a
// wrong result — consider adding a CUDA_CHECK macro.
extern "C" int64_t zaremba_count_representable(int64_t max_d, int *h_digits, int num_digits) {
    // Deeper prefixes for large N -> more, smaller subtrees for load balancing.
    int PREFIX_DEPTH = 8;
    if (max_d >= 1000000000LL) PREFIX_DEPTH = 15;

    int max_prefixes = 20000000;
    uint64_t *h_prefixes = new uint64_t[max_prefixes * 4];
    int np = 0;

    // CPU DFS over the CF tree; nodes at PREFIX_DEPTH become GPU work items.
    struct PfxEntry { uint64_t pp, p, qp, q; int depth; };
    PfxEntry *stk = new PfxEntry[max_prefixes];
    int ssp = 0;
    for (int i = 0; i < num_digits; i++) {
        // Root [0; a_i]: convergent 1/a_i with (p_prev, q_prev) = (0, 1).
        stk[ssp] = {0, 1, 1, (uint64_t)h_digits[i], 1};
        ssp++;
    }
    while (ssp > 0) {
        ssp--;
        uint64_t pp = stk[ssp].pp, p = stk[ssp].p;
        uint64_t qp = stk[ssp].qp, q = stk[ssp].q;
        int dep = stk[ssp].depth;
        if (q > (uint64_t)max_d) continue;
        if (dep >= PREFIX_DEPTH) {
            if (np < max_prefixes) {
                h_prefixes[np*4+0] = pp; h_prefixes[np*4+1] = p;
                h_prefixes[np*4+2] = qp; h_prefixes[np*4+3] = q;
                np++;
            }
        } else {
            for (int i = num_digits - 1; i >= 0; i--) {
                uint64_t qn = (uint64_t)h_digits[i] * q + qp;
                if (qn > (uint64_t)max_d || ssp >= max_prefixes - 1) continue;
                stk[ssp] = {p, (uint64_t)h_digits[i] * p + pp, q, qn, dep + 1};
                ssp++;
            }
        }
    }
    delete[] stk;

    // GPU allocation. Round the bitset size up to a multiple of 4 bytes:
    // mark() issues a 32-bit atomicOr on the aligned word containing each
    // byte, so the final word must lie entirely inside the allocation.
    uint64_t bitset_bytes = ((((uint64_t)max_d + 8) / 8) + 3) & ~(uint64_t)3;
    uint8_t *d_bs;
    cudaMalloc(&d_bs, bitset_bytes);
    cudaMemset(d_bs, 0, bitset_bytes);

    int *d_digits;
    cudaMalloc(&d_digits, num_digits * sizeof(int));
    cudaMemcpy(d_digits, h_digits, num_digits * sizeof(int), cudaMemcpyHostToDevice);

    int *d_progress;
    cudaMalloc(&d_progress, sizeof(int));
    cudaMemset(d_progress, 0, sizeof(int));

    // Only allocate/launch when there is GPU work: for small max_d every
    // denominator can fall below PREFIX_DEPTH, leaving np == 0, and a launch
    // with grid size 0 is an invalid configuration.
    uint64_t *d_prefixes = nullptr;
    if (np > 0) {
        cudaMalloc(&d_prefixes, (uint64_t)np * 4 * sizeof(uint64_t));
        cudaMemcpy(d_prefixes, h_prefixes, (uint64_t)np * 4 * sizeof(uint64_t),
                   cudaMemcpyHostToDevice);

        int block_size = 256;
        int grid_size = (np + block_size - 1) / block_size;
        if (grid_size > 65535) grid_size = 65535;  // persistent threads: cap is fine

        enumerate_persistent<<<grid_size, block_size>>>(
            d_prefixes, np, d_digits, num_digits, d_bs, (uint64_t)max_d, d_progress);
        cudaDeviceSynchronize();
    }

    // Mark shallow denominators (depth < PREFIX_DEPTH) on the CPU, merged
    // into a host copy of the bitset.
    uint8_t *h_bs = new uint8_t[bitset_bytes];
    cudaMemcpy(h_bs, d_bs, bitset_bytes, cudaMemcpyDeviceToHost);
    h_bs[0] |= (1 << 1); // d=1 is always representable (empty CF)

    PfxEntry *cstk = new PfxEntry[500000];
    int csp = 0;
    for (int i = 0; i < num_digits; i++) {
        cstk[csp] = {0, 1, 1, (uint64_t)h_digits[i], 1};
        csp++;
    }
    while (csp > 0) {
        csp--;
        uint64_t q = cstk[csp].q;
        int dep = cstk[csp].depth;
        if (q > (uint64_t)max_d) continue;
        h_bs[q >> 3] |= (1 << (q & 7));
        if (dep >= PREFIX_DEPTH) continue;  // deeper nodes were covered by the GPU
        uint64_t pp = cstk[csp].pp, p = cstk[csp].p, qp = cstk[csp].qp;
        for (int i = 0; i < num_digits; i++) {
            uint64_t qn = (uint64_t)h_digits[i] * q + qp;
            if (qn > (uint64_t)max_d || csp >= 499999) continue;
            cstk[csp] = {p, (uint64_t)h_digits[i] * p + pp, q, qn, dep + 1};
            csp++;
        }
    }
    delete[] cstk;
    cudaMemcpy(d_bs, h_bs, bitset_bytes, cudaMemcpyHostToDevice);

    // Count set bits on the GPU.
    uint64_t *d_count;
    cudaMalloc(&d_count, sizeof(uint64_t));
    cudaMemset(d_count, 0, sizeof(uint64_t));
    {
        uint64_t max_byte = ((uint64_t)max_d + 8) / 8;
        int gd = (int)((max_byte + 255) / 256);
        count_marked<<<gd, 256>>>(d_bs, (uint64_t)max_d, d_count);
        cudaDeviceSynchronize();
    }
    int64_t covered = 0;
    cudaMemcpy(&covered, d_count, sizeof(uint64_t), cudaMemcpyDeviceToHost);

    // Cleanup (cudaFree(nullptr) is a no-op, so the np == 0 path is safe).
    cudaFree(d_count);
    cudaFree(d_bs);
    cudaFree(d_digits);
    cudaFree(d_prefixes);
    cudaFree(d_progress);
    delete[] h_prefixes;
    delete[] h_bs;

    return covered;
}
torch-ext/torch_binding.cpp ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+ #include "torch_binding.h"
3
+
4
+ // Defined in the CUDA kernel
5
+ extern "C" int64_t zaremba_count_representable(int64_t max_d, int *h_digits, int num_digits);
6
+
7
+ torch::Tensor count_representable(int64_t max_d, torch::Tensor digits) {
8
+ TORCH_CHECK(digits.dtype() == torch::kInt32, "digits must be int32");
9
+ TORCH_CHECK(digits.dim() == 1, "digits must be 1-D");
10
+ TORCH_CHECK(digits.is_cpu(), "digits must be on CPU");
11
+ TORCH_CHECK(max_d > 0, "max_d must be positive");
12
+
13
+ int num_digits = digits.size(0);
14
+ int *h_digits = digits.data_ptr<int>();
15
+
16
+ int64_t count = zaremba_count_representable(max_d, h_digits, num_digits);
17
+
18
+ auto result = torch::tensor({count}, torch::dtype(torch::kInt64));
19
+ return result;
20
+ }
21
+
22
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
23
+ m.def("count_representable", &count_representable,
24
+ "Count denominators <= max_d representable with given CF digit set");
25
+ }
torch-ext/torch_binding.h ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
#pragma once

#include <torch/types.h>

// Count denominators d in [1, max_d] whose continued fraction uses only
// partial quotients from `digits` (a 1-D int32 CPU tensor).
// Returns a 1-element int64 CPU tensor; implemented in torch_binding.cpp,
// backed by the CUDA kernel in src/zaremba_density.cu.
torch::Tensor count_representable(int64_t max_d, torch::Tensor digits);