| |
| |
| |
| |
| |
| |
|
|
// Parameter block for the temporal-memory kernels; passed to kernels by value
// (see the `TmConfig cfg` kernel argument). The struct is 16 consecutive
// 32-bit integers. Field order and types are kept fixed because the host side
// presumably fills an identically laid-out structure — do not reorder.
struct TmConfig {
    // Segment thresholds. activation_threshold is compared against a segment's
    // active-connected synapse count; learning_threshold is presumably the
    // analogous threshold for learning-eligible segments (not used here).
    unsigned int activation_threshold, learning_threshold;

    // Topology and capacity of the cell / segment / synapse pools.
    unsigned int cells_per_column, synapses_per_segment,
                 n_segments, n_cells,
                 max_segments_per_cell, max_new_synapses;

    // Signed permanence quantities; assumed to share the int16 fixed-point
    // scale of the stored synapse permanences (TODO confirm for each field).
    int conn_thr_i16, perm_inc_i16, perm_dec_i16,
        predicted_seg_dec_i16, initial_perm_i16;

    // Per-iteration seed, column count, and presumably the number of 32-bit
    // words in the packed active-cell bitset.
    unsigned int iter_seed, n_cols, bits_words;
};
|
|
// Reinforce the synapses of active segments on cells whose column is both
// active (sp_active_mask) and predicted (col_predicted).
//
// Launch layout:
//   - One block per cell: blockIdx.x is the cell index; blocks with
//     blockIdx.x >= cfg.n_cells exit immediately.
//   - The threads of a block stride over a segment's synapses by blockDim.x,
//     so any block size is correct. A multiple of 32 keeps warps full.
//   - No shared memory, no intra-block synchronization.
// All early exits depend only on blockIdx, so a block exits as a whole.
//
// For each of the cell's first min(cell_seg_count[cell], max_segments_per_cell)
// segments whose seg_num_active_connected count is >= cfg.activation_threshold,
// every synapse permanence is:
//   - raised by cfg.perm_inc_i16 if its presynaptic bit is set in prev_active_bits,
//   - lowered by cfg.perm_dec_i16 otherwise,
// and then saturated to [0, 32767].
//
// Layout assumptions (fixed by the indexing below):
//   segment id    = cell * cfg.max_segments_per_cell + local_seg
//   synapse index = segment id * cfg.synapses_per_segment + s
//   prev_active_bits: packed bitset, bit (presyn & 31) of word (presyn >> 5).
// Each synapse slot is written by exactly one thread of exactly one block.
// NOTE(review): presyn is not checked against cfg.bits_words; callers must
// guarantee every stored presyn index lies inside the bitset.
// seg_cell_id is accepted for interface compatibility but not read here.
extern "C" __global__
void tm_learn_reinforce(
    const unsigned int * __restrict__ seg_cell_id,
    const unsigned int * __restrict__ seg_syn_count,
    const unsigned int * __restrict__ syn_presyn,
    short * __restrict__ syn_perm,
    const unsigned int * __restrict__ seg_num_active_connected,
    const unsigned int * __restrict__ prev_active_bits,
    const unsigned char * __restrict__ sp_active_mask,
    const unsigned char * __restrict__ col_predicted,
    const unsigned int * __restrict__ cell_seg_count,
    TmConfig cfg
) {
    const unsigned int cell = blockIdx.x;
    if (cell >= cfg.n_cells) return;

    // Only cells in columns that are both active and predicted are reinforced.
    const unsigned int col = cell / cfg.cells_per_column;
    if (sp_active_mask[col] == 0 || col_predicted[col] == 0) return;

    // Never walk past the per-cell segment capacity, whatever the count says.
    const unsigned int n_segs_here = min(cell_seg_count[cell], cfg.max_segments_per_cell);
    if (n_segs_here == 0) return;

    const int kPermMin = 0;
    const int kPermMax = 32767;  // largest value representable in a short
    const unsigned int seg_base_id = cell * cfg.max_segments_per_cell;

    for (unsigned int local_seg = 0; local_seg < n_segs_here; ++local_seg) {
        const unsigned int seg = seg_base_id + local_seg;
        if (seg_num_active_connected[seg] < cfg.activation_threshold) continue;

        // Clamp to the per-segment capacity, mirroring the segment-count clamp
        // above, so a bad count cannot spill into the next segment's slots.
        const unsigned int n_syn = min(seg_syn_count[seg], cfg.synapses_per_segment);
        if (n_syn == 0) continue;

        // 64-bit base: seg * synapses_per_segment can exceed 2^32 for large pools.
        const unsigned long long syn_base =
            (unsigned long long)seg * cfg.synapses_per_segment;

        // Stride by the actual block size. A hard-coded 32 would make threads
        // >= 32 revisit (and race on) the same synapses, and would skip
        // synapses when fewer than 32 threads are launched.
        for (unsigned int s = threadIdx.x; s < n_syn; s += blockDim.x) {
            const unsigned long long idx = syn_base + s;
            const unsigned int presyn = syn_presyn[idx];
            const bool was_active =
                ((prev_active_bits[presyn >> 5] >> (presyn & 31u)) & 1u) != 0u;

            const int p = (int)syn_perm[idx];
            int np = was_active ? p + cfg.perm_inc_i16 : p - cfg.perm_dec_i16;
            // Saturate so the narrowing store back to short never wraps.
            np = max(kPermMin, min(np, kPermMax));
            syn_perm[idx] = (short)np;
        }
    }
}
|
|