icarus112's picture
Upload folder using huggingface_hub
1c59946 verified
// TM grow+reinforce kernel.
//
// For each bursting column:
//   If col_best_match[col] is non-zero (i.e. at least one matching segment
//   with num_active_potential >= learning_threshold exists on cells in this col):
//     Target = that matching segment.
//     Reinforce its existing synapses: +inc if presyn in prev_active, -dec otherwise.
//     Grow up to min(max_new, synapses_per_segment - current_syn_count)
//     additional synapses to prev_winners.
//   Else:
//     Allocate a fresh segment slot on winner cell (cell 0 of col).
//     Grow up to max_new synapses (capped at synapses_per_segment) to
//     prev_winners (no reinforce needed — new seg).
//
// This mirrors the CPU TM burst logic.
// Configuration block passed BY VALUE to the tm_grow kernel.
// NOTE(review): the launcher presumably fills an identically laid-out struct
// on the host side, so field order and types should be treated as ABI —
// confirm against the host code before reordering or resizing anything.
struct TmConfig {
// Not read by tm_grow. Presumably the active-synapse count at which a
// segment becomes active — confirm against the kernels that consume it.
unsigned int activation_threshold;
// Not read by tm_grow directly; per the file header, a segment counts as
// "matching" when its num_active_potential >= learning_threshold.
unsigned int learning_threshold;
// Cells per column; tm_grow uses col * cells_per_column as the winner cell
// (cell 0 of the column) when it allocates a new segment.
unsigned int cells_per_column;
// Synapse capacity of one segment; also the per-segment stride into the
// syn_presyn / syn_perm arrays.
unsigned int synapses_per_segment;
// Not read by tm_grow. Presumably the total number of segment slots — TODO confirm.
unsigned int n_segments;
// Exclusive upper bound on valid cell ids; bits at or above it in the
// prev_winner bitset are ignored when growing synapses.
unsigned int n_cells;
// Segment slots per cell; a cell's segments live at
// cell * max_segments_per_cell + slot, and slot indices wrap modulo this value.
unsigned int max_segments_per_cell;
// Upper bound on synapses grown onto one segment per tm_grow invocation.
unsigned int max_new_synapses;
// Not read by tm_grow. Presumably the "connected" permanence threshold in
// the same i16 scale as the other *_i16 fields — TODO confirm.
int conn_thr_i16;
// Added to the permanence of a synapse whose presynaptic cell is set in
// prev_active_bits (result saturates at 32767).
int perm_inc_i16;
// Subtracted from the permanence of a synapse whose presynaptic cell is NOT
// set in prev_active_bits (result saturates at 0).
int perm_dec_i16;
// Not read by tm_grow. Presumably a decrement applied to wrongly predicting
// segments elsewhere — TODO confirm.
int predicted_seg_dec_i16;
// Permanence assigned to every newly grown synapse (narrowed to short on store).
int initial_perm_i16;
// Per-launch seed mixed with the block index to pick where each block starts
// scanning the prev_winner bitset.
unsigned int iter_seed;
// Not read by tm_grow. Presumably the number of columns — TODO confirm.
unsigned int n_cols;
// Number of 32-bit words scanned in the prev_winner bitset.
unsigned int bits_words;
};
// Temporal-memory grow + reinforce kernel.
//
// Launch layout: one block per bursting column. blockIdx.x indexes
// burst_cols_flat; blocks with blockIdx.x >= burst_cols_count[0] exit at once.
// All threads of a block cooperate on ONE segment. The block must be 1-D
// (only threadIdx.x is used); any blockDim.x works because every cooperative
// loop strides by blockDim.x. Uses five words of static shared memory.
//
// Per block:
//   1. Thread 0 picks the target segment: the segment encoded in the low 21
//      bits of col_best_match[col] when that key is non-zero, otherwise a
//      freshly claimed slot on cell 0 of the column (slot index wraps modulo
//      max_segments_per_cell, so an old segment is overwritten when full).
//   2. If the segment is reused, its existing synapses are reinforced:
//      +perm_inc_i16 when the presynaptic cell is set in prev_active_bits,
//      -perm_dec_i16 otherwise, saturating to [0, 32767].
//   3. Up to min(max_new_synapses, free capacity) synapses are grown to cells
//      set in prev_winner_bits that are not already presynaptic to the segment.
//      Which winners are taken (and in what order) depends on thread arrival
//      order at the shared counter, so it is not deterministic across runs.
//   4. Thread 0 publishes the new synapse count.
//
// Parameters:
//   seg_cell_id       per-segment owning cell; written when a segment is allocated.
//   seg_syn_count     per-segment synapse count; read for reused segments,
//                     written for the target segment.
//   syn_presyn        presynaptic cell ids, seg_id * synapses_per_segment + s.
//   syn_perm          permanences (i16), same indexing as syn_presyn.
//   cell_seg_count    per-cell allocation counter, atomically incremented; it
//                     is NOT clamped and may exceed max_segments_per_cell.
//   burst_cols_flat   column index for each bursting column.
//   burst_cols_count  burst_cols_count[0] = number of valid burst_cols_flat entries.
//   prev_winner_bits  bitset over cell ids (32 cells per word), cfg.bits_words words.
//   prev_active_bits  bitset over cell ids (32 cells per word).
//   col_best_match    per-column match key; 0 means "no matching segment".
//   cfg               launch configuration, passed by value.
extern "C" __global__
void tm_grow(
    unsigned int * __restrict__ seg_cell_id,
    unsigned int * __restrict__ seg_syn_count,
    unsigned int * __restrict__ syn_presyn,
    short * __restrict__ syn_perm,
    unsigned int * __restrict__ cell_seg_count,
    const unsigned int * __restrict__ burst_cols_flat,
    const unsigned int * __restrict__ burst_cols_count,
    const unsigned int * __restrict__ prev_winner_bits,
    const unsigned int * __restrict__ prev_active_bits,
    const unsigned int * __restrict__ col_best_match,
    TmConfig cfg
) {
    const unsigned int b = blockIdx.x;
    const unsigned int n_burst_cols = burst_cols_count[0];
    // Block-uniform exit: either every thread of the block returns or none
    // does, so the barriers below are always reached by the whole block.
    if (b >= n_burst_cols) return;

    const unsigned int tid = threadIdx.x;
    // Stride by the real block width instead of assuming a 32-thread launch;
    // a fixed 32 made wider blocks process the same element more than once.
    const unsigned int stride = blockDim.x;
    const unsigned int col = burst_cols_flat[b];

    __shared__ unsigned int shared_seg_id;
    __shared__ unsigned int shared_existing_syn_count;
    __shared__ unsigned int shared_grown;
    __shared__ unsigned int shared_is_new;
    __shared__ unsigned int shared_start_offset;

    if (tid == 0) {
        const unsigned int match_key = col_best_match[col];
        if (match_key != 0u) {
            // Reuse matching segment; its id lives in the low 21 bits of the key.
            const unsigned int seg_id = match_key & 0x1FFFFFu;
            unsigned int count = seg_syn_count[seg_id];
            // Never trust a stored count beyond the segment's capacity:
            // an oversized value would walk into the next segment's synapses.
            if (count > cfg.synapses_per_segment) count = cfg.synapses_per_segment;
            shared_seg_id = seg_id;
            shared_existing_syn_count = count;
            shared_is_new = 0u;
        } else {
            // Allocate new segment on winner cell (cell 0 of col).
            const unsigned int winner_cell = col * cfg.cells_per_column;
            unsigned int slot = atomicAdd(&cell_seg_count[winner_cell], 1u);
            if (slot >= cfg.max_segments_per_cell) {
                // Cell is full: recycle a slot round-robin.
                slot = slot % cfg.max_segments_per_cell;
            }
            const unsigned int seg_id = winner_cell * cfg.max_segments_per_cell + slot;
            seg_cell_id[seg_id] = winner_cell;
            seg_syn_count[seg_id] = 0;
            shared_seg_id = seg_id;
            shared_existing_syn_count = 0u;
            shared_is_new = 1u;
        }
        shared_grown = 0u;
        // Pseudo-random per-block starting word (Knuth multiplicative hash of
        // the block index plus the per-launch seed). Guard the modulo so an
        // empty bitset does not divide by zero.
        shared_start_offset = (cfg.bits_words != 0u)
            ? (b * 2654435761u + cfg.iter_seed) % cfg.bits_words
            : 0u;
    }
    __syncthreads();

    const unsigned int seg_id = shared_seg_id;
    // 64-bit base so seg_id * synapses_per_segment cannot wrap in 32 bits.
    const size_t seg_base = (size_t)seg_id * cfg.synapses_per_segment;
    const unsigned int existing_syn = shared_existing_syn_count;
    const unsigned int is_new = shared_is_new;
    const unsigned int start = shared_start_offset;

    // PHASE 1: If reusing, reinforce existing synapses.
    if (!is_new) {
        for (unsigned int s = tid; s < existing_syn; s += stride) {
            const unsigned int presyn = syn_presyn[seg_base + s];
            const unsigned int word = prev_active_bits[presyn >> 5];
            const unsigned int bit = (word >> (presyn & 31u)) & 1u;
            int np = (int)syn_perm[seg_base + s];
            if (bit) {
                np += cfg.perm_inc_i16;
                if (np > 32767) np = 32767;
            } else {
                np -= cfg.perm_dec_i16;
                if (np < 0) np = 0;
            }
            syn_perm[seg_base + s] = (short)np;
        }
    }
    // Kept outside the conditional so the barrier is unconditionally reached.
    __syncthreads();

    // PHASE 2: Grow up to `max_new_synapses` (or room) synapses to prev_winners
    // that aren't already presynaptic to this segment.
    const unsigned int room = (cfg.synapses_per_segment > existing_syn)
        ? (cfg.synapses_per_segment - existing_syn) : 0u;
    const unsigned int max_grow = (cfg.max_new_synapses < room) ? cfg.max_new_synapses : room;

    // Other threads bump shared_grown concurrently; read it through a volatile
    // pointer so the early-exit checks observe fresh values instead of a
    // register-cached copy.
    volatile unsigned int *grown_now = &shared_grown;

    // Words left before the rotated scan wraps back to word 0.
    const unsigned int until_wrap = cfg.bits_words - start;

    // Each logical offset w in [0, bits_words) is owned by exactly one thread,
    // so every bitset word is scanned once. (Rounding the trip count up to a
    // multiple of the block width and wrapping with `%` re-visited words and
    // could grow the same presynaptic cell twice.)
    for (unsigned int w = tid; w < cfg.bits_words; w += stride) {
        if (*grown_now >= max_grow) break;
        const unsigned int widx = (w < until_wrap) ? (start + w) : (w - until_wrap);
        unsigned int word = prev_winner_bits[widx];
        while (word != 0u) {
            if (*grown_now >= max_grow) break;
            const unsigned int bit_pos = __ffs(word) - 1u;
            word &= ~(1u << bit_pos);
            const unsigned int cell = widx * 32u + bit_pos;
            // Padding bits past the last real cell are not valid targets.
            if (cell >= cfg.n_cells) continue;
            // Skip if already presynaptic (O(existing_syn) scan; usually small).
            bool exists = false;
            for (unsigned int s = 0; s < existing_syn; s++) {
                if (syn_presyn[seg_base + s] == cell) { exists = true; break; }
            }
            if (exists) continue;
            // Claim a unique output slot; claims past the budget are dropped
            // (the counter is clamped when the count is published).
            const unsigned int slot = atomicAdd(&shared_grown, 1u);
            if (slot >= max_grow) break;
            const unsigned int write_idx = existing_syn + slot;
            if (write_idx >= cfg.synapses_per_segment) break;
            syn_presyn[seg_base + write_idx] = cell;
            syn_perm[seg_base + write_idx] = (short)cfg.initial_perm_i16;
        }
    }
    __syncthreads();

    if (tid == 0) {
        // shared_grown may overshoot max_grow because losers of the slot race
        // still incremented it; only the first max_grow claims were written.
        unsigned int grown = shared_grown;
        if (grown > max_grow) grown = max_grow;
        unsigned int new_count = existing_syn + grown;
        if (new_count > cfg.synapses_per_segment) new_count = cfg.synapses_per_segment;
        seg_syn_count[seg_id] = new_count;
    }
}