// icarus112's picture
// Upload folder using huggingface_hub
// 1c59946 verified
// TM learn (reinforce correctly predicted segments) — cell-grouped launch.
//
// Grid: n_cells.
// For each cell in a predicted, SP-active column: iterate its segments.
// For each segment with num_active_connected >= activation_threshold,
// reinforce its synapses against prev_active_bits.
// Parameter bundle passed by value to tm_learn_reinforce as `cfg`.
// NOTE(review): presumably mirrored field-for-field by a host-side struct;
// confirm before reordering, resizing, or inserting fields.
struct TmConfig {
// tm_learn_reinforce reinforces a segment only when its
// seg_num_active_connected value is >= this threshold.
unsigned int activation_threshold;
// Not read by tm_learn_reinforce; presumably the threshold for "matching"
// segments used by another kernel — TODO confirm.
unsigned int learning_threshold;
// Divisor that maps a cell index to its column index (col = cell / cells_per_column).
unsigned int cells_per_column;
// Number of synapse slots per segment; synapse arrays are indexed as
// seg * synapses_per_segment + s.
unsigned int synapses_per_segment;
// Not read by tm_learn_reinforce; presumably the total number of segment slots — TODO confirm.
unsigned int n_segments;
// Number of cells; blocks whose blockIdx.x is >= n_cells exit immediately.
unsigned int n_cells;
// Number of segment slots per cell; segment ids are cell * max_segments_per_cell + local_seg,
// and cell_seg_count is clamped to this value.
unsigned int max_segments_per_cell;
// Not read by tm_learn_reinforce; presumably a cap on synapses grown per step — TODO confirm.
unsigned int max_new_synapses;
// Not read by tm_learn_reinforce; presumably the "connected" permanence threshold
// on the same 16-bit scale as syn_perm — TODO confirm.
int conn_thr_i16;
// Added to a synapse's permanence when its presynaptic bit was set in
// prev_active_bits (result clamped to at most 32767).
int perm_inc_i16;
// Subtracted from a synapse's permanence when its presynaptic bit was not set
// (result clamped to at least 0).
int perm_dec_i16;
// Not read by tm_learn_reinforce; presumably the punishment decrement for
// wrongly predicted segments — TODO confirm.
int predicted_seg_dec_i16;
// Not read by tm_learn_reinforce; presumably the permanence given to newly
// created synapses — TODO confirm.
int initial_perm_i16;
// Not read by tm_learn_reinforce; presumably a per-iteration RNG seed — TODO confirm.
unsigned int iter_seed;
// Not read by tm_learn_reinforce; presumably the number of columns — TODO confirm.
unsigned int n_cols;
// Not read by tm_learn_reinforce; presumably the length of the activity
// bitset in 32-bit words — TODO confirm.
unsigned int bits_words;
};
// Reinforce the synapses of correctly predicted segments.
//
// Launch layout: one block per cell (blockIdx.x is the cell index; blocks
// with blockIdx.x >= cfg.n_cells return immediately). The threads of a block
// split each qualifying segment's synapses between them with a stride of
// blockDim.x, so any 1-D block size gives each synapse exactly one update
// (a 32-thread block reproduces the previous fixed-stride behaviour).
// No shared memory and no block-level synchronisation are used.
//
// A cell is processed only if its column (cell / cfg.cells_per_column) is
// flagged in both sp_active_mask and col_predicted. Within such a cell, every
// segment whose seg_num_active_connected is >= cfg.activation_threshold has
// each of its synapses updated:
//   - presynaptic bit set in prev_active_bits   -> perm += cfg.perm_inc_i16 (capped at 32767)
//   - presynaptic bit clear in prev_active_bits -> perm -= cfg.perm_dec_i16 (floored at 0)
//
// Parameters:
//   seg_cell_id               not read by this kernel.
//   seg_syn_count             per-segment synapse count (clamped to cfg.synapses_per_segment).
//   syn_presyn                per-synapse presynaptic bit index, laid out as
//                             seg * cfg.synapses_per_segment + s.
//   syn_perm                  per-synapse 16-bit permanence, same layout; updated in place.
//   seg_num_active_connected  per-segment count compared against cfg.activation_threshold.
//   prev_active_bits          bitset, 32 bits per word, indexed by presynaptic bit index.
//   sp_active_mask            per-column flag; zero skips the cell.
//   col_predicted             per-column flag; zero skips the cell.
//   cell_seg_count            per-cell segment count (clamped to cfg.max_segments_per_cell).
//   cfg                       thresholds, strides and permanence deltas.
extern "C" __global__
void tm_learn_reinforce(
    const unsigned int  * __restrict__ seg_cell_id,
    const unsigned int  * __restrict__ seg_syn_count,
    const unsigned int  * __restrict__ syn_presyn,
    short               * __restrict__ syn_perm,
    const unsigned int  * __restrict__ seg_num_active_connected,
    const unsigned int  * __restrict__ prev_active_bits,
    const unsigned char * __restrict__ sp_active_mask,
    const unsigned char * __restrict__ col_predicted,
    const unsigned int  * __restrict__ cell_seg_count,
    TmConfig cfg
) {
    const unsigned int cell = blockIdx.x;
    if (cell >= cfg.n_cells) return;

    // Only cells in columns that are both active and predicted are reinforced.
    const unsigned int col = cell / cfg.cells_per_column;
    if (sp_active_mask[col] == 0) return;
    if (col_predicted[col] == 0) return;

    // Clamp so a corrupt count can never index past this cell's segment slots.
    const unsigned int n_segs_here = min(cell_seg_count[cell], cfg.max_segments_per_cell);
    if (n_segs_here == 0) return;

    const unsigned int tid = threadIdx.x;
    // Stride by the real block width: a fixed stride of 32 makes threads >= 32
    // repeat work already done by lower threads (racy double updates) and makes
    // blocks narrower than 32 threads skip synapses.
    const unsigned int stride = blockDim.x;
    const unsigned int seg_base_id = cell * cfg.max_segments_per_cell;

    for (unsigned int local_seg = 0; local_seg < n_segs_here; local_seg++) {
        const unsigned int seg = seg_base_id + local_seg;
        if (seg_num_active_connected[seg] < cfg.activation_threshold) continue;

        // Clamp for the same reason as the segment count: an oversized value
        // would spill into the next segment's synapse slots.
        const unsigned int n_syn = min(seg_syn_count[seg], cfg.synapses_per_segment);
        if (n_syn == 0) continue;

        const unsigned int syn_base = seg * cfg.synapses_per_segment;
        for (unsigned int s = tid; s < n_syn; s += stride) {
            const unsigned int idx = syn_base + s;
            const unsigned int presyn = syn_presyn[idx];
            // prev_active_bits packs 32 bits per word: word = presyn / 32, bit = presyn % 32.
            const unsigned int word = prev_active_bits[presyn >> 5];
            const unsigned int bit = (word >> (presyn & 31u)) & 1u;

            // Do the arithmetic in int so the clamp happens before narrowing to short.
            const int p = (int)syn_perm[idx];
            int np;
            if (bit) {
                np = p + cfg.perm_inc_i16;
                if (np > 32767) np = 32767;
            } else {
                np = p - cfg.perm_dec_i16;
                if (np < 0) np = 0;
            }
            syn_perm[idx] = (short)np;
        }
    }
}