File size: 31,733 Bytes
1c59946
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
// Fused HTM megakernel — SP + TM, all T timesteps in a single launch.
//
// Design rationale:
//   - Global top-K column selection requires cross-block synchronization at
//     every timestep (grid.sync is unreliable on WSL2/sm_86 without rdc=true).
//   - Replace with per-column threshold activation using local lateral
//     inhibition: column c activates if overlap[c]*boost[c] > threshold[c].
//     Threshold is a per-column running-EMA learned scalar that steers the
//     column's long-run activation rate toward the global sparsity target.
//   - This is biologically grounded (GABAergic local inhibition) and supported
//     by HTM theory (duty-cycle boost already drives this loop; we just
//     change which lever the EMA pulls).
//
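// Launch shape:
//   grid  = min(device SM count, 16)  // hard cap — see below
//   block = blockDim.x threads, a multiple of 32. The kernels below declare
//           __launch_bounds__(256, 2), so a block may have at most 256 threads
//           (8 warps); the earlier "1024 threads = 32 warps" figure does not
//           match that bound, and a 1024-thread launch would be rejected.
//   Each warp owns a contiguous slice of about n_columns / total_warps columns.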
//
// Cross-block coherence:
//   - Ping-pong buffers for cell_active/cell_winner: write _a at even t,
//     read _b; reversed at odd t.
//   - sm_90+ (Hopper): hardware cluster barrier (cluster.sync()); the cluster
//     is expected to span all blocks of a region.
//   - Older GPUs: cooperative launch + cooperative-groups whole-grid sync.
//   The former software 3-slot rotating grid barrier has been removed, so
//   there is no longer a fallback for non-cooperative launches.
//
// 2026-04-16: grid_dim was reduced from 28 to 16 after a deadlock root-cause
// analysis. The previous cap of 28 relied on all blocks being concurrently
// resident on a 30-SM RTX 3060 Laptop. Under thermal throttling, effective
// residency dropped to ~20-24, leaving scheduled blocks spinning on the
// (since removed) software grid barrier, waiting for peer blocks that would
// never run. 16 blocks is below any realistic residency floor. With the
// 32-warp blocks assumed at the time, this kept 16*32 = 512 warps to saturate
// memory bandwidth in the spatial-pooler stage. With 256-thread blocks it is
// 16*8 = 128 warps.
//
// Kernel signature uses struct-by-value for pointers and config to stay
// inside cudarc's launch-arg count limit.

#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>

namespace cg = cooperative_groups;

// Maximum number of columns whose hot SP state one cluster block caches in
// shared memory (the s_inhib_thr / s_boost / s_active_duty arrays of the
// kernel body, sm_90+ only). This supports n_columns up to
// COLS_PER_CLUSTER_BLOCK_MAX * cluster_size; at cluster_size = 16 that is
// 256 * 16 = 4096 columns.
// Each array costs 256 * 4 = 1024 bytes, so the three arrays take 3072 bytes
// per block, well under the 228 KB H200 shared-memory cap.
#define COLS_PER_CLUSTER_BLOCK_MAX 256u

// Size in bytes of the per-block input staging tile (s_input_tile, sm_90+
// only). Inputs are one byte per bit, so this is also the largest
// cfg.input_bits that is staged. Wider inputs fall back to direct
// global-memory reads in Stage A. Per the original sizing note, this covers
// the production SDR width (16384 bits) with 2x headroom.
// Shared memory per block on sm_90+: 32768 (tile) + 3072 (column caches)
// = ~35 KB. NOTE(review): the kernels request __launch_bounds__(256, 2). If
// two blocks are co-resident on an SM, that is ~70 KB per SM, still under
// the H200 limit.
//
// The staging copy is issued per thread block (cg::memcpy_async on
// cg::this_thread_block()), not as a cluster-scope multicast. Each block
// therefore still fetches the timestep's input slice from global memory once;
// the saving comes from Stage A re-reading the tile instead of global memory.
// NOTE(review): the earlier estimate ("1 TMA DMA per timestep", "~16x DRAM
// reduction on input reads", "-20 to -40 ms wall clock") assumed a multicast
// that this code does not perform. Re-measure before relying on it.
#define INPUT_BITS_MAX 32768u

extern "C" {

// Device pointers for one HTM region, passed by value to the kernels as raw
// 64-bit addresses. Packing them into one struct keeps the launch inside
// cudarc's kernel-argument count limit. The kernel body casts each field to
// its element type:
//   SP     : syn_bit (u32), syn_perm (f32); boost, active_duty and
//            inhibition_threshold (f32, indexed by column)
//   TM     : seg_cell_id, seg_syn_count, syn_presyn, cell_seg_count (u32);
//            tm_syn_perm (i16)
//   bitsets: cell_active_a/_b, cell_winner_a/_b (u32 words, one bit per cell,
//            swapped by timestep parity)
//   I/O    : inputs (u8, indexed t * input_bits + bit); cols_out (u8, indexed
//            t * n_columns + column); anom_out (f32, indexed by t)
//   misc   : barrier_counters (u32; forwarded to the grid barrier, which
//            ignores it); step_scratch (u32[2]: active / bursting column
//            counters for the current timestep)
// Field order and size are the host ABI. Do not reorder or retype fields.
struct FusedPtrs {
    unsigned long long syn_bit;
    unsigned long long syn_perm;
    unsigned long long boost;
    unsigned long long active_duty;
    unsigned long long inhibition_threshold;
    unsigned long long seg_cell_id;
    unsigned long long seg_syn_count;
    unsigned long long syn_presyn;
    unsigned long long tm_syn_perm;
    unsigned long long cell_seg_count;
    unsigned long long cell_active_a;
    unsigned long long cell_active_b;
    unsigned long long cell_winner_a;
    unsigned long long cell_winner_b;
    unsigned long long inputs;
    unsigned long long cols_out;
    unsigned long long anom_out;
    unsigned long long barrier_counters;
    unsigned long long step_scratch;
};

// The host fills this struct byte-for-byte. Lock the layout so that adding,
// removing or retyping a field fails at compile time instead of silently
// shifting every pointer that follows it.
static_assert(sizeof(FusedPtrs) == 19u * sizeof(unsigned long long),
              "FusedPtrs must stay 19 packed 64-bit device addresses (host ABI)");

// Scalar configuration shared by every block of a launch (passed by value).
// Field order and size are the host ABI. Do not reorder or retype fields.
struct FusedConfig {
    // ---- SP constants ----
    unsigned int input_bits;        // input bytes per timestep (byte != 0 => bit on)
    unsigned int n_columns;         // number of minicolumns
    unsigned int synapses_per_col;  // proximal synapses per column (syn_bit/syn_perm stride)
    float        conn_thr;          // proximal synapse counts when perm >= conn_thr
    float        sp_inc;            // Hebbian increment for synapses on active inputs
    float        sp_dec;            // Hebbian decrement for synapses on inactive inputs
    float        sparsity_target;   // duty-cycle level the threshold adaptation steers toward
    float        duty_alpha;        // EMA weight of the newest activity sample
    float        thr_adapt_rate;    // threshold step: thr += rate * (duty - target) * 100
    // ---- TM constants ----
    unsigned int cells_per_column;
    unsigned int n_cells;               // cell ids >= n_cells are never grown as presynaptic
    unsigned int bits_words;            // u32 words in each cell bitset
    unsigned int max_segments_per_cell;
    unsigned int synapses_per_segment;  // synapse capacity of each segment
    unsigned int activation_threshold;  // connected active synapses for a predictive segment
    unsigned int learning_threshold;    // potential active synapses for a matching segment
    unsigned int max_new_synapses;      // synapses grown per bursting column per step
    int          conn_thr_i16;          // TM permanences are i16 in [0, 32767]
    int          perm_inc_i16;
    int          perm_dec_i16;
    int          predicted_seg_dec_i16; // NOTE(review): not read by the kernel body in this file
    int          initial_perm_i16;      // permanence given to newly grown synapses
    // ---- Loop constants ----
    unsigned int T;                     // timesteps processed per launch
    unsigned int learn;                 // non-zero enables SP and TM learning
    unsigned int iter_seed;             // offsets the bitset scan start for synapse growth
    unsigned int cooperative_grid_sync; // forwarded to the grid barrier, which ignores it
};

// The host fills this struct byte-for-byte. Lock the layout so that a field
// change fails at compile time instead of silently shifting later fields.
static_assert(sizeof(FusedConfig) == 26u * 4u,
              "FusedConfig must stay 26 packed 32-bit fields (host ABI)");

// Barrier used between the phases of the fused kernel body.
//
//   sm_90+ : cg::this_cluster().sync(), a hardware cluster barrier. It orders
//            only the blocks of the calling block's thread-block cluster, so
//            the cluster is expected to cover every block of a region.
//   older  : grid.sync(), the cooperative-groups whole-grid barrier. It
//            requires the kernel to be launched cooperatively.
//
// The flags / expected / phase / cooperative_grid_sync parameters belonged to
// the former software barrier. They are kept only so existing call sites
// compile and are ignored. NOTE(review): because cooperative_grid_sync is
// ignored, a non-cooperative launch on pre-sm_90 hardware has no fallback.
// Confirm the host always launches cooperatively there.
__device__ static inline void fused_grid_barrier(cg::grid_group grid,
                                                 unsigned int * /* flags (unused) */,
                                                 unsigned int /* expected (unused) */,
                                                 unsigned int /* phase (unused) */,
                                                 unsigned int /* cooperative_grid_sync (unused) */) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    // Hopper and newer: one hardware barrier across the cluster.
    cg::this_cluster().sync();
#else
    // Pre-Hopper (and the host compilation pass): whole-grid cooperative sync.
    grid.sync();
#endif
}

// Shuffle-down tree reduction across one full warp (mask 0xffffffff, so all
// 32 lanes must call it). Lane 0 ends up with the sum of every lane's v.
// Other lanes hold partial sums, so callers broadcast lane 0's value with
// __shfl_sync(..., 0) before using it.
__device__ static inline unsigned int warp_sum_u32(unsigned int v) {
    unsigned int acc = v;
    for (unsigned int stride = 16u; stride != 0u; stride /= 2u) {
        const unsigned int from_peer = __shfl_down_sync(0xffffffffu, acc, stride);
        acc += from_peer;
    }
    return acc;
}

// Core kernel body, shared by the single-region and batched entry points.
//
// Single-region: the caller passes its one FusedPtrs struct.
// Batched: each block selects its region's FusedPtrs via blockIdx.y before
// calling this. Every region owns independent GPU buffers.
//
// The only cross-block ordering primitive is fused_grid_barrier():
//   - sm_90+ : cluster.sync(), which orders only the blocks of the calling
//              block's cluster. NOTE(review): Stage B reads prev_active /
//              prev_winner words written by any block of the region, and
//              block 0 aggregates step_scratch for the whole region. This path
//              therefore assumes the cluster spans all gridDim.x blocks of a
//              region. Confirm the host sets the cluster dimension to match.
//   - older  : grid.sync(), which spans every block in the grid, including
//              other regions' blocks in a batched launch (harmless over-sync).
//
// Per timestep t (ping-pong: even t writes the *_a bitsets and reads *_b as
// the previous step; odd t does the reverse):
//   Phase 0  clear this warp's cells in curr_active / curr_winner; block 0
//            zeroes step_scratch[0..1]                          -> BARRIER 1
//   Stage A  spatial pooler: overlap, threshold activation, Hebbian SP
//            learning, duty-cycle EMA and threshold adaptation  -> BARRIER 2
//   Stage B  temporal memory on active columns: prediction, segment
//            reinforcement, bursting and synapse growth         -> BARRIER 3
//   Then block 0 / thread 0 writes anom_out[t] = bursting / active columns.
//
// Inside the timestep loop each warp reads and writes the per-column SP state
// and the TM segments of its own columns [col_lo, col_hi) only.
__device__ static inline
void htm_fused_step_body(const FusedPtrs& P, const FusedConfig& cfg) {
    cg::grid_group grid = cg::this_grid();

    // Reinterpret the raw 64-bit addresses in P as typed device pointers.
    const unsigned int  * __restrict__ syn_bit               = (const unsigned int*)P.syn_bit;
    float               * __restrict__ syn_perm              = (float*)P.syn_perm;
    float               * __restrict__ boost                 = (float*)P.boost;
    float               * __restrict__ active_duty           = (float*)P.active_duty;
    float               * __restrict__ inhibition_threshold  = (float*)P.inhibition_threshold;
    unsigned int        * __restrict__ seg_cell_id           = (unsigned int*)P.seg_cell_id;
    unsigned int        * __restrict__ seg_syn_count         = (unsigned int*)P.seg_syn_count;
    unsigned int        * __restrict__ syn_presyn            = (unsigned int*)P.syn_presyn;
    short               * __restrict__ tm_syn_perm           = (short*)P.tm_syn_perm;
    unsigned int        * __restrict__ cell_seg_count        = (unsigned int*)P.cell_seg_count;
    unsigned int        * __restrict__ cell_active_a         = (unsigned int*)P.cell_active_a;
    unsigned int        * __restrict__ cell_active_b         = (unsigned int*)P.cell_active_b;
    unsigned int        * __restrict__ cell_winner_a         = (unsigned int*)P.cell_winner_a;
    unsigned int        * __restrict__ cell_winner_b         = (unsigned int*)P.cell_winner_b;
    const unsigned char * __restrict__ inputs                = (const unsigned char*)P.inputs;
    unsigned char       * __restrict__ cols_out              = (unsigned char*)P.cols_out;
    float               * __restrict__ anom_out              = (float*)P.anom_out;
    unsigned int        * __restrict__ barrier_counters      = (unsigned int*)P.barrier_counters;
    unsigned int        * __restrict__ step_scratch          = (unsigned int*)P.step_scratch;

    const unsigned int tid             = threadIdx.x;
    const unsigned int lane            = tid & 31u;
    const unsigned int warp            = tid >> 5;
    const unsigned int warps_per_block = blockDim.x >> 5;
    const unsigned int gwarp           = blockIdx.x * warps_per_block + warp;
    const unsigned int n_warps         = gridDim.x * warps_per_block;

    // Column range [col_lo, col_hi) owned by this warp. Every warp-collective
    // call below uses the full 0xffffffff mask, so this assumes blockDim.x is
    // a multiple of 32.
    const unsigned int n_cols = cfg.n_columns;
    const unsigned int col_lo = (gwarp * n_cols) / n_warps;
    const unsigned int col_hi = ((gwarp + 1) * n_cols) / n_warps;

    const unsigned int S   = cfg.synapses_per_col;
    const unsigned int cpc = cfg.cells_per_column;
    const unsigned int SPS = cfg.synapses_per_segment;
    const unsigned int MSC = cfg.max_segments_per_cell;

    // Cells belonging to this warp's columns (loop-invariant).
    const unsigned int my_cell_lo = col_lo * cpc;
    const unsigned int my_cell_hi = col_hi * cpc;

    // Barrier sequence number, forwarded to fused_grid_barrier (which does not use it).
    unsigned int phase = 0u;

#if __CUDA_ARCH__ >= 900
    // ---- Hopper+: cluster-distributed shared-memory (DSMEM) column cache ----
    //
    // Each block of the cluster caches a contiguous slice of the per-column
    // SP state (inhibition_threshold, boost, active_duty) in its own
    // __shared__ arrays. Any block reaches a peer's slice through
    // cluster.map_shared_rank(). Global memory is still written on every
    // update, so state persists across launches.
    auto cluster = cg::this_cluster();
    const unsigned int cluster_block_rank = cluster.block_rank();
    const unsigned int cluster_sz         = cluster.num_blocks();

    const unsigned int cols_per_block =
        (n_cols + cluster_sz - 1u) / cluster_sz;                  // ceil div
    const unsigned int my_col_start = cluster_block_rank * cols_per_block;
    const unsigned int my_col_end =
        (my_col_start + cols_per_block < n_cols)
            ? (my_col_start + cols_per_block) : n_cols;           // clamp

    // The cache arrays hold COLS_PER_CLUSTER_BLOCK_MAX entries each. When a
    // block's slice is larger (n_cols > COLS_PER_CLUSTER_BLOCK_MAX *
    // cluster_sz), indexing them would run past the end of the __shared__
    // arrays. In that case, bypass the cache and use global memory directly,
    // as the pre-Hopper path does. The flag is uniform across the cluster.
    const bool use_dsmem = (cols_per_block <= COLS_PER_CLUSTER_BLOCK_MAX);

    __shared__ float s_inhib_thr  [COLS_PER_CLUSTER_BLOCK_MAX];
    __shared__ float s_boost      [COLS_PER_CLUSTER_BLOCK_MAX];
    __shared__ float s_active_duty[COLS_PER_CLUSTER_BLOCK_MAX];

    // Per-block input staging tile (one byte per input bit). Pre-Hopper
    // builds omit it: per the original design note, a 32 KB tile does not fit
    // the per-block shared-memory budget of the cooperative launch there.
    __shared__ __align__(16) unsigned char s_input_tile[INPUT_BITS_MAX];

    // Load this block's slice of the state persisted by the previous launch.
    if (use_dsmem) {
        for (unsigned int c = my_col_start + tid; c < my_col_end; c += blockDim.x) {
            const unsigned int off = c - my_col_start;
            s_inhib_thr  [off] = inhibition_threshold[c];
            s_boost      [off] = boost[c];
            s_active_duty[off] = active_duty[c];
        }
    }
    // Every block must finish loading before any block peer-reads inside the
    // timestep loop.
    cluster.sync();
#else
    // Pre-Hopper: no shared-memory cache. Per-column state lives in global
    // memory. Synchronize once before entering the timestep loop.
    grid.sync();
#endif

    // Main timestep loop.
    for (unsigned int t = 0u; t < cfg.T; t++) {
        // 64-bit offsets: t * input_bits and t * n_cols can exceed 2^32 for
        // long sequences and would silently wrap in 32-bit arithmetic.
        const unsigned long long inp_off      = (unsigned long long)t * cfg.input_bits;
        const unsigned long long col_base_out = (unsigned long long)t * n_cols;

        // Ping-pong bitsets: even t writes *_a and reads *_b; odd t reverses.
        unsigned int * curr_active = (t & 1u) ? cell_active_b : cell_active_a;
        unsigned int * prev_active = (t & 1u) ? cell_active_a : cell_active_b;
        unsigned int * curr_winner = (t & 1u) ? cell_winner_b : cell_winner_a;
        unsigned int * prev_winner = (t & 1u) ? cell_winner_a : cell_winner_b;

        // ---- Phase 0: clear curr bitsets for my cell range ----
        if (cpc == 32u) {
            // Fast path: column c owns exactly bitset word c, so plain stores
            // cannot collide with other warps.
            for (unsigned int c = col_lo + lane; c < col_hi; c += 32u) {
                curr_active[c] = 0u;
                curr_winner[c] = 0u;
            }
        } else {
            // General path: a word may be shared with columns owned by other
            // warps, so clear bit-by-bit atomically.
            for (unsigned int cell = my_cell_lo + lane; cell < my_cell_hi; cell += 32u) {
                const unsigned int w = cell >> 5;
                const unsigned int m = 1u << (cell & 31u);
                atomicAnd(&curr_active[w], ~m);
                atomicAnd(&curr_winner[w], ~m);
            }
        }

        // Block 0 / thread 0 resets the per-step counters:
        // step_scratch[0] = active columns, step_scratch[1] = bursting columns.
        if (blockIdx.x == 0u && tid == 0u) {
            step_scratch[0] = 0u;
            step_scratch[1] = 0u;
        }

        // ---- BARRIER 1 ----
        // Fence: make the bitset clears and scratch resets globally visible
        // before peer blocks pass the barrier.
        __threadfence();
        fused_grid_barrier(grid, barrier_counters, 0u, phase++, cfg.cooperative_grid_sync);

        // ---- Input staging (Hopper+) ----
        // Each block asynchronously copies this timestep's input slice
        // (input_bits bytes) into its own s_input_tile with cg::memcpy_async
        // on the thread block, then waits for completion. This is a per-block
        // copy, not a cluster multicast: every block still fetches the slice
        // from global memory once, and Stage A re-reads it from shared memory.
        // Inputs wider than INPUT_BITS_MAX fall back to global-memory reads.
#if __CUDA_ARCH__ >= 900
        const bool use_input_tile = (cfg.input_bits <= INPUT_BITS_MAX);
        if (use_input_tile) {
            auto tb = cg::this_thread_block();
            cg::memcpy_async(tb, s_input_tile, inputs + inp_off, cfg.input_bits);
            cg::wait(tb);
            cluster.sync();
        }
#endif

        // =========================================================
        // STAGE A: Spatial Pooler
        // =========================================================
        for (unsigned int c = col_lo; c < col_hi; c++) {
            const unsigned int base = c * S;

            // Overlap = number of connected synapses (perm >= conn_thr) whose
            // input byte is non-zero. Lanes stride over the S synapses.
            unsigned int local = 0u;
            for (unsigned int s = lane; s < S; s += 32u) {
                const unsigned int b = syn_bit[base + s];
                const float p = syn_perm[base + s];
#if __CUDA_ARCH__ >= 900
                const unsigned int inp_byte = use_input_tile
                    ? (unsigned int)s_input_tile[b]
                    : (unsigned int)inputs[inp_off + b];
#else
                const unsigned int inp_byte = (unsigned int)inputs[inp_off + b];
#endif
                local += ((inp_byte != 0u) && (p >= cfg.conn_thr)) ? 1u : 0u;
            }
            unsigned int overlap = warp_sum_u32(local);
            overlap = __shfl_sync(0xffffffffu, overlap, 0);   // broadcast lane 0's total

            // Per-column boost and activation threshold.
#if __CUDA_ARCH__ >= 900
            const unsigned int owner_block  = c / cols_per_block;
            const unsigned int owner_offset = c - owner_block * cols_per_block;
            float boost_val;
            float thr;
            if (use_dsmem) {
                boost_val = cluster.map_shared_rank(s_boost,     owner_block)[owner_offset];
                thr       = cluster.map_shared_rank(s_inhib_thr, owner_block)[owner_offset];
            } else {
                boost_val = boost[c];
                thr       = inhibition_threshold[c];
            }
#else
            const float boost_val = boost[c];
            const float thr       = inhibition_threshold[c];
#endif

            const float boosted = (float)overlap * boost_val;
            const unsigned int is_active = (boosted > thr) ? 1u : 0u;

            if (lane == 0) {
                cols_out[col_base_out + c] = (unsigned char)is_active;
                if (is_active) {
                    atomicAdd(&step_scratch[0], 1u);
                }
            }

            // SP learning (Hebbian) on active columns: synapses on non-zero
            // inputs gain sp_inc, the rest lose sp_dec; clamped to [0, 1].
            if (cfg.learn && is_active) {
                for (unsigned int s = lane; s < S; s += 32u) {
                    const unsigned int b = syn_bit[base + s];
                    float p = syn_perm[base + s];
#if __CUDA_ARCH__ >= 900
                    const unsigned int inp_byte = use_input_tile
                        ? (unsigned int)s_input_tile[b]
                        : (unsigned int)inputs[inp_off + b];
#else
                    const unsigned int inp_byte = (unsigned int)inputs[inp_off + b];
#endif
                    if (inp_byte != 0u) {
                        p += cfg.sp_inc;
                        if (p > 1.0f) p = 1.0f;
                    } else {
                        p -= cfg.sp_dec;
                        if (p < 0.0f) p = 0.0f;
                    }
                    syn_perm[base + s] = p;
                }
            }

            // Duty-cycle EMA + threshold adaptation (lane 0 only). Results go
            // to global memory for persistence and, when the DSMEM cache is in
            // use, to the owner block's shared memory for the next timestep.
            if (lane == 0) {
#if __CUDA_ARCH__ >= 900
                float ad = use_dsmem
                    ? cluster.map_shared_rank(s_active_duty, owner_block)[owner_offset]
                    : active_duty[c];
#else
                float ad = active_duty[c];
#endif
                const float sample = is_active ? 1.0f : 0.0f;
                ad = (1.0f - cfg.duty_alpha) * ad + cfg.duty_alpha * sample;
#if __CUDA_ARCH__ >= 900
                if (use_dsmem) {
                    cluster.map_shared_rank(s_active_duty, owner_block)[owner_offset] = ad;
                }
#endif
                active_duty[c] = ad;

                // Raise the threshold when the column fires more often than
                // sparsity_target and lower it when less often; clamp to [0.1, 1000].
                const float err = ad - cfg.sparsity_target;
                float new_thr = thr + cfg.thr_adapt_rate * err * 100.0f;
                if (new_thr < 0.1f) new_thr = 0.1f;
                if (new_thr > 1000.0f) new_thr = 1000.0f;
#if __CUDA_ARCH__ >= 900
                if (use_dsmem) {
                    cluster.map_shared_rank(s_inhib_thr, owner_block)[owner_offset] = new_thr;
                }
#endif
                inhibition_threshold[c] = new_thr;
            }
        }

        // ---- DSMEM WRITEBACK SYNC ----
        // On Hopper, make this timestep's peer shared-memory writes visible to
        // every block of the cluster before Stage B and the next timestep.
        // Pre-Hopper performs no peer shared-memory writes, so the barrier
        // below suffices.
#if __CUDA_ARCH__ >= 900
        cluster.sync();
#endif

        // ---- BARRIER 2: SP active mask must be visible before TM reads ----
        // Fence: flush cols_out, active_duty, inhibition_threshold and
        // step_scratch writes to global memory before peers pass the barrier.
        __threadfence();
        fused_grid_barrier(grid, barrier_counters, 0u, phase++, cfg.cooperative_grid_sync);

        // =========================================================
        // STAGE B: Temporal Memory
        // =========================================================
        for (unsigned int c = col_lo; c < col_hi; c++) {
            const unsigned int col_active = cols_out[col_base_out + c];
            if (col_active == 0u) continue;

            const unsigned int base_cell = c * cpc;
            unsigned int any_predicted = 0u;
            unsigned int best_seg_id_for_grow = 0xFFFFFFFFu;
            unsigned int best_pot_count = 0u;

            for (unsigned int k = 0u; k < cpc; k++) {
                const unsigned int cell = base_cell + k;
                // cell_seg_count keeps growing when slots are recycled
                // (slot % MSC below), so clamp it to the segment capacity.
                unsigned int n_segs_here = cell_seg_count[cell];
                if (n_segs_here > MSC) n_segs_here = MSC;
                if (n_segs_here == 0u) continue;

                const unsigned int seg_base_id = cell * MSC;
                unsigned int cell_is_predictive = 0u;

                for (unsigned int ls = 0u; ls < n_segs_here; ls++) {
                    const unsigned int seg = seg_base_id + ls;
                    // Clamp to the segment's synapse capacity so an
                    // out-of-range count cannot index into the next segment.
                    unsigned int n_syn = seg_syn_count[seg];
                    if (n_syn > SPS) n_syn = SPS;
                    if (n_syn == 0u) continue;
                    const unsigned int syn_base = seg * SPS;

                    // Count potential (presynaptic cell active at t-1) and
                    // connected (potential and perm >= conn_thr) synapses.
                    unsigned int l_conn = 0u;
                    unsigned int l_pot  = 0u;
                    for (unsigned int s = lane; s < n_syn; s += 32u) {
                        const unsigned int presyn = syn_presyn[syn_base + s];
                        const unsigned int w = prev_active[presyn >> 5];
                        const unsigned int bit = (w >> (presyn & 31u)) & 1u;
                        if (bit) {
                            l_pot += 1u;
                            const int p = (int)tm_syn_perm[syn_base + s];
                            if (p >= cfg.conn_thr_i16) l_conn += 1u;
                        }
                    }
                    unsigned int tot_conn = warp_sum_u32(l_conn);
                    unsigned int tot_pot  = warp_sum_u32(l_pot);
                    tot_conn = __shfl_sync(0xffffffffu, tot_conn, 0);
                    tot_pot  = __shfl_sync(0xffffffffu, tot_pot, 0);

                    if (tot_conn >= cfg.activation_threshold) cell_is_predictive = 1u;
                    if (tot_pot >= cfg.learning_threshold && tot_pot > best_pot_count) {
                        best_pot_count = tot_pot;
                        best_seg_id_for_grow = seg;
                    }

                    // Reinforce predicted-and-correct segment.
                    if (cfg.learn && tot_conn >= cfg.activation_threshold) {
                        for (unsigned int s = lane; s < n_syn; s += 32u) {
                            const unsigned int presyn = syn_presyn[syn_base + s];
                            const unsigned int w = prev_active[presyn >> 5];
                            const unsigned int bit = (w >> (presyn & 31u)) & 1u;
                            const int p = (int)tm_syn_perm[syn_base + s];
                            if (bit) {
                                int np = p + cfg.perm_inc_i16;
                                if (np > 32767) np = 32767;
                                tm_syn_perm[syn_base + s] = (short)np;
                            } else {
                                int np = p - cfg.perm_dec_i16;
                                if (np < 0) np = 0;
                                tm_syn_perm[syn_base + s] = (short)np;
                            }
                        }
                    }
                }

                if (cell_is_predictive) {
                    any_predicted = 1u;
                    if (lane == 0) {
                        const unsigned int w = cell >> 5;
                        const unsigned int m = 1u << (cell & 31u);
                        atomicOr(&curr_active[w], m);
                        atomicOr(&curr_winner[w], m);
                    }
                }
            }

            // BURST if no cell was predicted: every cell of the column becomes
            // active. NOTE(review): the winner is always cell 0 of the column
            // (base_cell), even when the reused matching segment belongs to
            // another cell. This is a deliberate simplification? Confirm.
            if (!any_predicted) {
                if (lane == 0) {
                    for (unsigned int k = 0u; k < cpc; k++) {
                        const unsigned int cell = base_cell + k;
                        const unsigned int w = cell >> 5;
                        const unsigned int m = 1u << (cell & 31u);
                        atomicOr(&curr_active[w], m);
                    }
                    const unsigned int win = base_cell;
                    const unsigned int ww = win >> 5;
                    const unsigned int wm = 1u << (win & 31u);
                    atomicOr(&curr_winner[ww], wm);
                    atomicAdd(&step_scratch[1], 1u);
                }

                if (cfg.learn) {
                    unsigned int target_seg;
                    unsigned int existing_syn;
                    if (best_seg_id_for_grow != 0xFFFFFFFFu) {
                        // Reuse best matching segment.
                        target_seg = best_seg_id_for_grow;
                        existing_syn = seg_syn_count[target_seg];
                        if (existing_syn > SPS) existing_syn = SPS;   // same capacity clamp as above
                        target_seg = __shfl_sync(0xffffffffu, target_seg, 0);
                        existing_syn = __shfl_sync(0xffffffffu, existing_syn, 0);

                        // Reinforce its existing synapses.
                        const unsigned int syn_base = target_seg * SPS;
                        for (unsigned int s = lane; s < existing_syn; s += 32u) {
                            const unsigned int presyn = syn_presyn[syn_base + s];
                            const unsigned int w = prev_active[presyn >> 5];
                            const unsigned int bit = (w >> (presyn & 31u)) & 1u;
                            const int p = (int)tm_syn_perm[syn_base + s];
                            if (bit) {
                                int np = p + cfg.perm_inc_i16;
                                if (np > 32767) np = 32767;
                                tm_syn_perm[syn_base + s] = (short)np;
                            } else {
                                int np = p - cfg.perm_dec_i16;
                                if (np < 0) np = 0;
                                tm_syn_perm[syn_base + s] = (short)np;
                            }
                        }
                    } else {
                        // Allocate a new segment on the winner cell (cell 0 of
                        // the column). Slots are recycled round-robin once the
                        // cell holds MSC segments.
                        unsigned int new_seg = 0u;
                        if (lane == 0) {
                            const unsigned int winner_cell = base_cell;
                            unsigned int slot = atomicAdd(&cell_seg_count[winner_cell], 1u);
                            if (slot >= MSC) slot = slot % MSC;
                            new_seg = winner_cell * MSC + slot;
                            seg_cell_id[new_seg] = winner_cell;
                            seg_syn_count[new_seg] = 0u;
                        }
                        target_seg = __shfl_sync(0xffffffffu, new_seg, 0);
                        existing_syn = 0u;
                    }

                    // Grow synapses to prev_winner cells (lane 0, serialized).
                    const unsigned int room = (SPS > existing_syn) ? (SPS - existing_syn) : 0u;
                    const unsigned int max_grow =
                        (cfg.max_new_synapses < room) ? cfg.max_new_synapses : room;
                    // bits_words == 0 leaves nothing to scan and would make the
                    // modulo below divide by zero.
                    if (lane == 0 && max_grow > 0u && cfg.bits_words > 0u) {
                        const unsigned int syn_base = target_seg * SPS;
                        unsigned int grown = 0u;
                        // Hash-based scan start so columns do not all grow
                        // toward the same low-index winners.
                        const unsigned int start_off =
                            (c * 2654435761u + cfg.iter_seed + t) % cfg.bits_words;
                        for (unsigned int w_off = 0u;
                             w_off < cfg.bits_words && grown < max_grow;
                             w_off++) {
                            const unsigned int widx = (start_off + w_off) % cfg.bits_words;
                            unsigned int word = prev_winner[widx];
                            while (word != 0u && grown < max_grow) {
                                const unsigned int bit_pos = __ffs(word) - 1u;
                                word &= ~(1u << bit_pos);
                                const unsigned int cell_id = widx * 32u + bit_pos;
                                if (cell_id >= cfg.n_cells) continue;
                                bool exists = false;
                                for (unsigned int es = 0u; es < existing_syn + grown; es++) {
                                    if (syn_presyn[syn_base + es] == cell_id) { exists = true; break; }
                                }
                                if (exists) continue;
                                const unsigned int write_idx = existing_syn + grown;
                                if (write_idx >= SPS) break;
                                syn_presyn[syn_base + write_idx] = cell_id;
                                tm_syn_perm[syn_base + write_idx] = (short)cfg.initial_perm_i16;
                                grown++;
                            }
                        }
                        if (grown > 0u) {
                            seg_syn_count[target_seg] = existing_syn + grown;
                        }
                    }
                }
            }
        }

        // ---- BARRIER 3: TM writes complete before anomaly + next-step read ----
        // Fence: flush curr_active/curr_winner bitsets, tm_syn_perm,
        // seg_syn_count and syn_presyn before peers pass the barrier and
        // consume them as prev_active/prev_winner at t+1.
        __threadfence();
        fused_grid_barrier(grid, barrier_counters, 0u, phase++, cfg.cooperative_grid_sync);

        // Anomaly for step t = bursting columns / active columns (0 if none active).
        if (blockIdx.x == 0u && tid == 0u) {
            const unsigned int total = step_scratch[0];
            const unsigned int bad   = step_scratch[1];
            anom_out[t] = (total > 0u) ? ((float)bad / (float)total) : 0.0f;
        }
    }
}

// Single-region kernel (legacy call site): the whole grid works on the one
// region described by P.
// NOTE(review): __launch_bounds__(256, 2) caps blockDim.x at 256 threads.
// Confirm the host launches with <= 256 threads per block; a larger block
// would be rejected at launch.
__global__ __launch_bounds__(256, 2)
void htm_fused_step(FusedPtrs P, FusedConfig cfg) {
    htm_fused_step_body(P, cfg);
}

// Batched kernel: a single launch covers B regions (grid.y = B, grid.x =
// blocks per region). Each block copies its own region's FusedPtrs out of the
// device array P_arr (indexed by blockIdx.y) and runs the shared kernel body
// on it.
// NOTE(review): __launch_bounds__(256, 2) caps blockDim.x at 256 threads.
__global__ __launch_bounds__(256, 2)
void htm_fused_step_batched(const FusedPtrs* __restrict__ P_arr, FusedConfig cfg) {
    const unsigned int region = blockIdx.y;
    const FusedPtrs region_ptrs = P_arr[region];
    htm_fused_step_body(region_ptrs, cfg);
}

} // extern "C"