Spaces:
Runtime error
Runtime error
// TM predict kernel — cell-grouped launch.
//
// Grid: n_cells blocks (one per cell).
// Block: 32 threads (one warp).
//
// Each block iterates over the segments owned by its cell (count in cell_seg_count[cell]).
// For each live segment, it counts active connected/potential synapses against
// cell_active_bits (the active-cell bitmap, presumably from the previous step).
// It then updates the per-segment counters, the cell's predictive bit, the
// col_predicted flag, and the column's best-matching segment in col_best_match.
// Temporal Memory parameters, passed to kernels by value (see tm_predict's
// `TmConfig cfg` parameter).
// NOTE(review): host code must mirror this exact field order and these types;
// reordering or retyping a field silently breaks the kernel argument layout.
struct TmConfig {
    unsigned int activation_threshold;   // tm_predict: min active *connected* synapses for a segment to make its cell predictive
    unsigned int learning_threshold;     // tm_predict: min active *potential* synapses for a segment to compete in col_best_match
    unsigned int cells_per_column;       // tm_predict: column index = cell / cells_per_column
    unsigned int synapses_per_segment;   // tm_predict: synapse-slot stride per segment (syn_base = seg * synapses_per_segment)
    unsigned int n_segments;             // not read by tm_predict; presumably total segment capacity — TODO confirm
    unsigned int n_cells;                // tm_predict: number of cells; blocks with blockIdx.x >= n_cells exit
    unsigned int max_segments_per_cell;  // tm_predict: segment-id stride per cell; also caps cell_seg_count[cell]
    unsigned int max_new_synapses;       // not read by tm_predict; presumably a per-step growth limit — TODO confirm
    int conn_thr_i16;                    // tm_predict: a synapse is "connected" when (int)syn_perm >= conn_thr_i16
    int perm_inc_i16;                    // not read by tm_predict; presumably permanence increment (syn_perm scale) — TODO confirm
    int perm_dec_i16;                    // not read by tm_predict; presumably permanence decrement (syn_perm scale) — TODO confirm
    int predicted_seg_dec_i16;           // not read by tm_predict; presumably punishment decrement — TODO confirm
    int initial_perm_i16;                // not read by tm_predict; presumably permanence for new synapses — TODO confirm
    unsigned int iter_seed;              // not read by tm_predict; presumably per-iteration RNG seed — TODO confirm
    unsigned int n_cols;                 // not read by tm_predict; presumably number of columns — TODO confirm
    unsigned int bits_words;             // not read by tm_predict; presumably 32-bit word count of the cell bitmaps — TODO confirm
};
// tm_predict — Temporal Memory "predict" pass, one block per cell.
//
// Launch contract:
//   grid  = cfg.n_cells blocks; blockIdx.x is the cell index (extra blocks exit).
//   block = 32 threads (exactly one warp). The 32 lanes cooperate on each
//           segment and reduce with full-mask __shfl_down_sync, so launching
//           fewer than 32 threads is NOT supported. Threads beyond the first
//           warp exit immediately; they would only duplicate warp 0's work.
//   No shared memory is used.
//
// Data layout (as implied by the indexing below):
//   - Segment ids of cell c are c * max_segments_per_cell + [0, cell_seg_count[c]),
//     with the count capped at max_segments_per_cell.
//   - Synapses of segment g are at g * synapses_per_segment + [0, seg_syn_count[g]),
//     with the count capped at synapses_per_segment.
//   - cell_active_bits / cell_predictive_bits are bitmaps: bit (c & 31) of word c >> 5.
//   - syn_presyn holds presynaptic cell indices. Ids >= cfg.n_cells are
//     treated as inactive rather than read out of bounds.
//
// For every live segment of the cell it counts:
//   potential = synapses whose presynaptic cell is set in cell_active_bits;
//   connected = those of the above with (int)syn_perm >= cfg.conn_thr_i16.
// Then:
//   - it writes both counts to seg_num_active_connected / seg_num_active_potential;
//   - if connected >= activation_threshold, it ORs the cell's bit into
//     cell_predictive_bits and sets col_predicted[col] = 1;
//   - if potential >= learning_threshold, it atomicMax's the key
//     (min(potential, 2047) << 21) | (seg & 0x1FFFFF) into col_best_match[col].
//     The column therefore keeps the segment with the most active potential
//     synapses; ties go to the larger segment id.
//
// NOTE(review): the key keeps only 21 bits of the segment id, so configurations
// with n_cells * max_segments_per_cell > 2^21 alias segments in col_best_match.
// NOTE(review): cell_predictive_bits, col_predicted and col_best_match are only
// OR-ed / set / max-ed here; presumably the caller clears them before launch —
// TODO confirm.
// seg_cell_id is not read by this kernel; it is kept for interface compatibility.
extern "C" __global__
void tm_predict(
    const unsigned int * __restrict__ seg_cell_id,
    const unsigned int * __restrict__ seg_syn_count,
    const unsigned int * __restrict__ syn_presyn,
    const short * __restrict__ syn_perm,
    const unsigned int * __restrict__ cell_active_bits,
    unsigned int * __restrict__ cell_predictive_bits,
    unsigned char * __restrict__ col_predicted,
    unsigned int * __restrict__ seg_num_active_connected,
    unsigned int * __restrict__ seg_num_active_potential,
    unsigned int * __restrict__ col_best_match,
    const unsigned int * __restrict__ cell_seg_count,
    TmConfig cfg
) {
    const unsigned int tid = threadIdx.x;
    // Only warp 0 contributes results (lane 0 is the sole writer). Whole extra
    // warps leave here, so warp 0's full-mask shuffles remain valid.
    if (tid >= 32u) return;

    const unsigned int cell = blockIdx.x;
    if (cell >= cfg.n_cells) return;
    const unsigned int n_segs_here = min(cell_seg_count[cell], cfg.max_segments_per_cell);
    if (n_segs_here == 0u) return;

    const unsigned int col = cell / cfg.cells_per_column;
    const unsigned int seg_base_id = cell * cfg.max_segments_per_cell;
    // Loop-invariant location of this cell's bit in cell_predictive_bits.
    const unsigned int cell_word = cell >> 5;
    const unsigned int cell_mask = 1u << (cell & 31u);

    for (unsigned int local_seg = 0u; local_seg < n_segs_here; ++local_seg) {
        const unsigned int seg = seg_base_id + local_seg;
        // Cap the count at the slot capacity, like the segment count above.
        // An oversized count would otherwise read the next segment's synapses.
        const unsigned int n_syn = min(seg_syn_count[seg], cfg.synapses_per_segment);
        if (n_syn == 0u) {
            // All lanes read the same count, so this branch is warp-uniform.
            if (tid == 0u) {
                seg_num_active_connected[seg] = 0u;
                seg_num_active_potential[seg] = 0u;
            }
            continue;
        }
        // 64-bit base: seg * synapses_per_segment can exceed 2^32 for large configs.
        const unsigned long long syn_base =
            (unsigned long long)seg * cfg.synapses_per_segment;

        unsigned int local_conn = 0u;
        unsigned int local_pot = 0u;
        // Lane-strided walk: adjacent lanes read adjacent synapse slots.
        for (unsigned int s = tid; s < n_syn; s += 32u) {
            const unsigned int presyn = syn_presyn[syn_base + s];
            // Out-of-range presynaptic ids count as inactive instead of
            // indexing past the active-cell bitmap.
            if (presyn >= cfg.n_cells) continue;
            const unsigned int word = cell_active_bits[presyn >> 5];
            if ((word >> (presyn & 31u)) & 1u) {
                local_pot += 1u;
                // Permanence is only loaded for active synapses.
                if ((int)syn_perm[syn_base + s] >= cfg.conn_thr_i16) {
                    local_conn += 1u;
                }
            }
        }

        // Warp tree reduction; lane 0 ends up holding the segment totals.
        for (int off = 16; off > 0; off >>= 1) {
            local_conn += __shfl_down_sync(0xffffffffu, local_conn, off);
            local_pot  += __shfl_down_sync(0xffffffffu, local_pot,  off);
        }

        if (tid == 0u) {
            seg_num_active_connected[seg] = local_conn;
            seg_num_active_potential[seg] = local_pot;
            if (local_conn >= cfg.activation_threshold) {
                atomicOr(&cell_predictive_bits[cell_word], cell_mask);
                // Benign race: every writer stores the same value.
                col_predicted[col] = 1;
            }
            if (local_pot >= cfg.learning_threshold) {
                // Pack (potential count, segment id) so atomicMax selects the
                // best-matching segment per column: 11 bits count, 21 bits id.
                const unsigned int pot_c = local_pot > 2047u ? 2047u : local_pot;
                const unsigned int key = (pot_c << 21) | (seg & 0x1FFFFFu);
                atomicMax(&col_best_match[col], key);
            }
        }
    }
}