// TM grow+reinforce kernel (tm_grow).
//
// For each bursting column col = burst_cols_flat[i], i < burst_cols_count[0]:
//   If col_best_match[col] is non-zero (a matching segment was recorded for a
//   cell in this column; its segment id is read from the low 21 bits):
//     Target = that matching segment.
//     Reinforce its existing synapses: +perm_inc_i16 if the presynaptic cell
//     is set in prev_active_bits, -perm_dec_i16 otherwise, clamped to
//     [0, 32767].
//     Grow up to min(max_new_synapses, synapses_per_segment - current_syn_count)
//     new synapses to prev_winner cells not already on the segment.
//   Else:
//     Allocate a segment slot on cell 0 of the column (slots wrap modulo
//     max_segments_per_cell once the cell has used them all) and grow up to
//     min(max_new_synapses, synapses_per_segment) synapses to prev_winner
//     cells. No reinforcement is needed for a fresh segment.
//
// Intended to mirror the CPU TM burst logic.
// NOTE(review): confirm the growth budget above matches the CPU reference;
// an earlier description of this kernel said (max_new - current_syn_count).
//
// Launch layout:
//   - Grid: any number of 1-D blocks. Blocks walk the burst-column list with
//     stride gridDim.x, so the grid may be smaller than the burst count
//     (launching n_cols blocks also works; surplus blocks exit immediately).
//   - Block: any 1-D blockDim.x >= 1.
//   - Shared memory: static only (five words); no dynamic smem required.
// Bitsets (prev_winner_bits, prev_active_bits) hold one bit per cell, 32
// cells per word (bit = cell & 31 of word cell >> 5); cfg.bits_words words
// are scanned.
struct TmConfig {
    unsigned int activation_threshold;   // not read by tm_grow
    unsigned int learning_threshold;     // not read by tm_grow
    unsigned int cells_per_column;       // cell 0 of col is col * cells_per_column
    unsigned int synapses_per_segment;   // synapse slots per segment
    unsigned int n_segments;             // not read by tm_grow
    unsigned int n_cells;                // winner bits >= n_cells are ignored
    unsigned int max_segments_per_cell;  // segment slots per cell
    unsigned int max_new_synapses;       // per-column growth cap
    int conn_thr_i16;                    // not read by tm_grow
    int perm_inc_i16;                    // reinforcement for active presyn
    int perm_dec_i16;                    // punishment for inactive presyn
    int predicted_seg_dec_i16;           // not read by tm_grow
    int initial_perm_i16;                // permanence of newly grown synapses
    unsigned int iter_seed;              // rotates the winner-scan start word
    unsigned int n_cols;                 // not read by tm_grow
    unsigned int bits_words;             // 32-bit words in each cell bitset
};

extern "C" __global__ void tm_grow(
    unsigned int * __restrict__ seg_cell_id,
    unsigned int * __restrict__ seg_syn_count,
    unsigned int * __restrict__ syn_presyn,
    short * __restrict__ syn_perm,
    unsigned int * __restrict__ cell_seg_count,
    const unsigned int * __restrict__ burst_cols_flat,
    const unsigned int * __restrict__ burst_cols_count,
    const unsigned int * __restrict__ prev_winner_bits,
    const unsigned int * __restrict__ prev_active_bits,
    const unsigned int * __restrict__ col_best_match,
    TmConfig cfg
) {
    const unsigned int tid = threadIdx.x;
    const unsigned int n_threads = blockDim.x;
    const unsigned int n_burst_cols = burst_cols_count[0];

    // Per-column state decided by thread 0 and broadcast to the block.
    __shared__ unsigned int shared_seg_id;
    __shared__ unsigned int shared_existing_syn_count;
    __shared__ unsigned int shared_grown;
    __shared__ unsigned int shared_is_new;
    __shared__ unsigned int shared_start_offset;

    // Block-stride loop over burst columns. The trip count depends only on
    // blockIdx/gridDim, so it is uniform across the block and the
    // __syncthreads() calls below are reached by every thread.
    for (unsigned int b = blockIdx.x; b < n_burst_cols; b += gridDim.x) {
        const unsigned int col = burst_cols_flat[b];

        if (tid == 0) {
            const unsigned int match_key = col_best_match[col];
            if (match_key != 0u) {
                // Reuse the matching segment.
                // NOTE(review): assumes the producer packs the segment id into
                // the low 21 bits of match_key -- confirm against that kernel.
                const unsigned int seg_id = match_key & 0x1FFFFFu;
                unsigned int existing = seg_syn_count[seg_id];
                // Guard phase 1 against a count larger than the slot array.
                if (existing > cfg.synapses_per_segment) {
                    existing = cfg.synapses_per_segment;
                }
                shared_seg_id = seg_id;
                shared_existing_syn_count = existing;
                shared_is_new = 0u;
            } else {
                // Allocate a new segment on cell 0 of the column.
                const unsigned int winner_cell = col * cfg.cells_per_column;
                unsigned int slot = atomicAdd(&cell_seg_count[winner_cell], 1u);
                // Once all slots are used, wrap and overwrite an older segment.
                // NOTE(review): cell_seg_count itself keeps growing past
                // max_segments_per_cell; readers must clamp it.
                if (slot >= cfg.max_segments_per_cell) {
                    slot = slot % cfg.max_segments_per_cell;
                }
                const unsigned int seg_id = winner_cell * cfg.max_segments_per_cell + slot;
                seg_cell_id[seg_id] = winner_cell;
                seg_syn_count[seg_id] = 0u;
                shared_seg_id = seg_id;
                shared_existing_syn_count = 0u;
                shared_is_new = 1u;
            }
            shared_grown = 0u;
            // Pseudo-random rotation of the winner scan so different columns
            // and iterations do not always prefer low-numbered cells.
            shared_start_offset = (cfg.bits_words != 0u)
                ? (b * 2654435761u + cfg.iter_seed) % cfg.bits_words
                : 0u;
        }
        __syncthreads();

        const unsigned int seg_id = shared_seg_id;
        const unsigned int seg_base = seg_id * cfg.synapses_per_segment;
        const unsigned int existing_syn = shared_existing_syn_count;
        const unsigned int is_new = shared_is_new;
        const unsigned int start = shared_start_offset;

        // PHASE 1: if reusing a segment, reinforce its existing synapses.
        // is_new comes from shared memory, so this branch is block-uniform.
        if (!is_new) {
            for (unsigned int s = tid; s < existing_syn; s += n_threads) {
                const unsigned int presyn = syn_presyn[seg_base + s];
                const unsigned int word = prev_active_bits[presyn >> 5];
                const unsigned int bit = (word >> (presyn & 31u)) & 1u;
                const int p = (int)syn_perm[seg_base + s];
                int np;
                if (bit) {
                    np = p + cfg.perm_inc_i16;
                    if (np > 32767) np = 32767;
                } else {
                    np = p - cfg.perm_dec_i16;
                    if (np < 0) np = 0;
                }
                syn_perm[seg_base + s] = (short)np;
            }
            __syncthreads();
        }

        // PHASE 2: grow up to min(max_new_synapses, free slots) synapses to
        // prev_winner cells that are not already presynaptic to this segment.
        const unsigned int room = (cfg.synapses_per_segment > existing_syn)
            ? (cfg.synapses_per_segment - existing_syn) : 0u;
        const unsigned int max_grow = (cfg.max_new_synapses < room)
            ? cfg.max_new_synapses : room;

        // Each bitset word index k in [0, bits_words) is owned by exactly one
        // thread, and its rotated word (start + k) % bits_words is therefore
        // visited exactly once. This keeps newly grown presyn cells unique,
        // even when bits_words is smaller than, or not a multiple of, blockDim.x.
        for (unsigned int base = 0u; base < cfg.bits_words; base += n_threads) {
            // Early exit only; the atomicAdd slot check below is authoritative.
            if (shared_grown >= max_grow) break;
            const unsigned int k = base + tid;
            if (k >= cfg.bits_words) break;  // later k values are larger still
            const unsigned int widx = (start + k) % cfg.bits_words;
            unsigned int word = prev_winner_bits[widx];
            while (word != 0u) {
                if (shared_grown >= max_grow) break;
                const unsigned int bit_pos = __ffs(word) - 1u;
                word &= word - 1u;  // clear the lowest set bit
                const unsigned int cell = widx * 32u + bit_pos;
                if (cell >= cfg.n_cells) continue;

                // Skip cells already presynaptic (O(existing_syn) scan).
                bool exists = false;
                for (unsigned int s = 0u; s < existing_syn; ++s) {
                    if (syn_presyn[seg_base + s] == cell) { exists = true; break; }
                }
                if (exists) continue;

                const unsigned int slot = atomicAdd(&shared_grown, 1u);
                if (slot >= max_grow) break;
                const unsigned int write_idx = existing_syn + slot;
                if (write_idx >= cfg.synapses_per_segment) break;
                syn_presyn[seg_base + write_idx] = cell;
                syn_perm[seg_base + write_idx] = (short)cfg.initial_perm_i16;
            }
        }
        __syncthreads();

        if (tid == 0) {
            // shared_grown can overshoot max_grow because of racing atomicAdds.
            unsigned int grown = shared_grown;
            if (grown > max_grow) grown = max_grow;
            unsigned int new_count = existing_syn + grown;
            if (new_count > cfg.synapses_per_segment) new_count = cfg.synapses_per_segment;
            seg_syn_count[seg_id] = new_count;
        }
        // Thread 0 rewrites the shared per-column state at the top of the next
        // iteration; make sure every thread is done with this column first.
        __syncthreads();
    }
}