icarus112's picture
Upload folder using huggingface_hub
1c59946 verified
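// TM activate kernel. TmConfig is duplicated below; see tm_predict.cu for the
// reference definition (presumably the two must stay field-for-field identical).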
// Configuration block handed to the TM kernels by value.
// Sixteen 32-bit fields; the order and types are the argument layout, so do
// not reorder or resize them. Only `n_cols` and `cells_per_column` are read by
// tm_activate in this file; the remaining descriptions are inferred from the
// field names and should be confirmed against tm_predict.cu.
struct TmConfig {
    // --- thresholds (presumably segment activation / learning; confirm) ---
    unsigned activation_threshold;
    unsigned learning_threshold;

    // --- sizes ---
    unsigned cells_per_column;       // tm_activate: number of consecutive cells owned by each column
    unsigned synapses_per_segment;
    unsigned n_segments;
    unsigned n_cells;
    unsigned max_segments_per_cell;
    unsigned max_new_synapses;

    // --- permanence parameters (the `_i16` suffix suggests a 16-bit fixed-point scale; confirm) ---
    int conn_thr_i16;
    int perm_inc_i16;
    int perm_dec_i16;
    int predicted_seg_dec_i16;
    int initial_perm_i16;

    // --- per-launch values ---
    unsigned iter_seed;
    unsigned n_cols;                 // tm_activate: exclusive upper bound on the column index
    unsigned bits_words;             // presumably the word count of the cell bitsets; confirm
};
// Commits the part of one column that lives in bitset word `word_idx`.
// `col_mask` has a 1 for every cell of the column stored in that word.
//   predicted column : cells whose predictive bit is set become active AND winner.
//   bursting column  : every cell in `col_mask` becomes active (winner is chosen
//                      by the caller).
// A zero `col_mask` touches no memory at all.
static __device__ __forceinline__
void tm_activate_commit_word(
    const unsigned int *cell_predictive_bits,
    unsigned int *cell_active_bits,
    unsigned int *cell_winner_bits,
    unsigned int word_idx,
    unsigned int col_mask,
    bool predicted
) {
    if (col_mask == 0u) return;
    if (predicted) {
        // One read of the predictive word covers every cell of the column in it.
        unsigned int hit = col_mask & cell_predictive_bits[word_idx];
        if (hit != 0u) {
            atomicOr(&cell_active_bits[word_idx], hit);
            atomicOr(&cell_winner_bits[word_idx], hit);
        }
    } else {
        atomicOr(&cell_active_bits[word_idx], col_mask);
    }
}

// tm_activate: computes active and winner cells for every SP-active column.
//
// Launch layout: 1-D grid, one thread per column; the grid must cover
// cfg.n_cols threads (excess threads exit immediately). No shared memory.
//
// Bit layout: cell `c` is bit (c & 31) of word (c >> 5) in every cell bitset;
// column `col` owns cells [col * cells_per_column, (col + 1) * cells_per_column).
//
// Parameters:
//   sp_active_mask        one byte per column; zero means the column is inactive
//                         and the thread does nothing.
//   col_predicted         one byte per column; non-zero selects the predicted path.
//   cell_predictive_bits  read-only cell bitset consulted on the predicted path.
//   cell_active_bits      output bitset, updated with atomicOr (bits are only set,
//                         never cleared here).
//   cell_winner_bits      output bitset, updated with atomicOr.
//   unpredicted_count     incremented once per bursting column.
//   burst_cols_flat       receives the index of each bursting column, in
//                         arbitrary order.
//   burst_cols_count      append cursor for burst_cols_flat.
//   cfg                   only n_cols and cells_per_column are read.
//
// NOTE(review): the output bitsets and both counters are presumably cleared by
// the caller before launch, and burst_cols_flat presumably has room for n_cols
// entries (each column appends at most once) -- confirm against the host code.
extern "C" __global__
void tm_activate(
    const unsigned char * __restrict__ sp_active_mask,
    const unsigned char * __restrict__ col_predicted,
    const unsigned int * __restrict__ cell_predictive_bits,
    unsigned int * __restrict__ cell_active_bits,
    unsigned int * __restrict__ cell_winner_bits,
    unsigned int * __restrict__ unpredicted_count,
    unsigned int * __restrict__ burst_cols_flat,
    unsigned int * __restrict__ burst_cols_count,
    TmConfig cfg
) {
    unsigned int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (col >= cfg.n_cols) return;
    if (sp_active_mask[col] == 0) return;

    const unsigned int base_cell = col * cfg.cells_per_column;
    const bool predicted = col_predicted[col] != 0;

    // Accumulate the column's cell bits per 32-bit word and commit each word
    // once, instead of issuing one atomic (and one predictive-word read) per
    // cell. Atomics are still required because columns with fewer than 32
    // cells share words with their neighbours.
    unsigned int cur_word = base_cell >> 5;
    unsigned int col_mask = 0u;
    for (unsigned int k = 0; k < cfg.cells_per_column; k++) {
        unsigned int cell = base_cell + k;
        unsigned int word_idx = cell >> 5;
        if (word_idx != cur_word) {
            tm_activate_commit_word(cell_predictive_bits, cell_active_bits,
                                    cell_winner_bits, cur_word, col_mask, predicted);
            cur_word = word_idx;
            col_mask = 0u;
        }
        col_mask |= 1u << (cell & 31u);
    }
    // Trailing (or only) word; a no-op when cells_per_column is 0.
    tm_activate_commit_word(cell_predictive_bits, cell_active_bits,
                            cell_winner_bits, cur_word, col_mask, predicted);

    if (!predicted) {
        // Bursting column: count it, pick the column's first cell as winner,
        // and append the column index to the burst list.
        atomicAdd(unpredicted_count, 1u);

        unsigned int winner = base_cell;
        atomicOr(&cell_winner_bits[winner >> 5], 1u << (winner & 31u));

        unsigned int slot = atomicAdd(burst_cols_count, 1u);
        burst_cols_flat[slot] = col;
    }
}