File size: 6,228 Bytes
1c59946 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 | // TM grow+reinforce kernel.
//
// For each bursting column:
// If col_best_match[col] is non-zero (i.e. at least one matching segment
// with num_active_potential >= learning_threshold exists on cells in this col):
// Target = that matching segment.
// Reinforce its existing synapses: +inc if presyn in prev_active, -dec otherwise.
// Grow up to min(max_new_synapses, synapses_per_segment - current_syn_count) additional synapses to prev_winners.
// Else:
// Allocate a fresh segment slot on winner cell (cell 0 of col).
// Grow up to max_new synapses to prev_winners (no reinforce needed — new seg).
//
// This mirrors the CPU TM burst logic.
// Parameter block passed by value to the TM kernels. Field order and types are
// part of the host/device interface; comments below describe how tm_grow uses
// each field. Fields marked "not read by tm_grow" are presumably consumed by
// other kernels — confirm against their sources.
struct TmConfig {
// Not read by tm_grow.
unsigned int activation_threshold;
// Not read by tm_grow; the file header describes a "matching" segment as one
// with num_active_potential >= learning_threshold.
unsigned int learning_threshold;
// Cells per column; tm_grow uses col * cells_per_column as the winner cell.
unsigned int cells_per_column;
// Synapse slots per segment; also the stride of syn_presyn / syn_perm per segment.
unsigned int synapses_per_segment;
// Not read by tm_grow.
unsigned int n_segments;
// Upper bound (exclusive) on cell indices decoded from prev_winner_bits.
unsigned int n_cells;
// Segment slots per cell; segment id = cell * max_segments_per_cell + slot.
unsigned int max_segments_per_cell;
// Cap on the number of synapses grown on one segment per tm_grow call.
unsigned int max_new_synapses;
// Not read by tm_grow.
int conn_thr_i16;
// Added to a synapse's permanence when its presynaptic cell is in prev_active_bits.
int perm_inc_i16;
// Subtracted from a synapse's permanence when its presynaptic cell is not in prev_active_bits.
int perm_dec_i16;
// Not read by tm_grow.
int predicted_seg_dec_i16;
// Permanence assigned to newly grown synapses (stored as short).
int initial_perm_i16;
// Mixed into the per-block starting offset of the prev_winner_bits scan.
unsigned int iter_seed;
// Not read by tm_grow.
unsigned int n_cols;
// Number of 32-bit words scanned in prev_winner_bits.
unsigned int bits_words;
};
// tm_grow: per bursting column, reinforce + grow on the column's best matching
// segment, or allocate a fresh segment on the column's first cell and grow on it.
//
// Launch layout: one block per bursting column (blockIdx.x indexes
// burst_cols_flat). Blocks with blockIdx.x >= burst_cols_count[0] exit at once.
// Threads stride by blockDim.x, so any 1-D block size is valid.
// Shared memory: five statically declared 32-bit words.
//
// Parameters:
//   seg_cell_id      [out]    owning cell of each segment; written for newly allocated segments.
//   seg_syn_count    [in/out] synapse count per segment; reset for new segments, updated after growth.
//   syn_presyn       [in/out] presynaptic cell per synapse, indexed seg_id * synapses_per_segment + s.
//   syn_perm         [in/out] permanence per synapse (same indexing as syn_presyn).
//   cell_seg_count   [in/out] per-cell counter, atomically incremented to pick a segment slot.
//   burst_cols_flat  [in]     column index handled by each block.
//   burst_cols_count [in]     element 0 holds the number of valid entries in burst_cols_flat.
//   prev_winner_bits [in]     bitset of cfg.bits_words words; bit i of word w is cell w * 32 + i.
//   prev_active_bits [in]     bitset indexed by presynaptic cell id.
//   col_best_match   [in]     per-column key; 0 means no matching segment, otherwise the
//                             low 21 bits are the segment id.
//   cfg                       kernel parameters (see TmConfig).
extern "C" __global__
void tm_grow(
    unsigned int * __restrict__ seg_cell_id,
    unsigned int * __restrict__ seg_syn_count,
    unsigned int * __restrict__ syn_presyn,
    short * __restrict__ syn_perm,
    unsigned int * __restrict__ cell_seg_count,
    const unsigned int * __restrict__ burst_cols_flat,
    const unsigned int * __restrict__ burst_cols_count,
    const unsigned int * __restrict__ prev_winner_bits,
    const unsigned int * __restrict__ prev_active_bits,
    const unsigned int * __restrict__ col_best_match,
    TmConfig cfg
) {
    const unsigned int b = blockIdx.x;
    const unsigned int n_burst_cols = burst_cols_count[0];
    // Uniform across the block, so returning before the barriers below is safe.
    if (b >= n_burst_cols) return;

    const unsigned int tid = threadIdx.x;
    // Stride by the real block width instead of assuming exactly 32 threads.
    const unsigned int n_threads = blockDim.x;
    const unsigned int col = burst_cols_flat[b];

    __shared__ unsigned int shared_seg_id;
    __shared__ unsigned int shared_existing_syn_count;
    __shared__ unsigned int shared_grown;
    __shared__ unsigned int shared_is_new;
    __shared__ unsigned int shared_start_offset;

    if (tid == 0) {
        unsigned int match_key = col_best_match[col];
        if (match_key != 0u) {
            // Reuse matching segment (segment id lives in the low 21 bits).
            unsigned int seg_id = match_key & 0x1FFFFFu;
            shared_seg_id = seg_id;
            shared_existing_syn_count = seg_syn_count[seg_id];
            shared_is_new = 0u;
        } else {
            // Allocate new segment on winner cell (cell 0 of col).
            unsigned int winner_cell = col * cfg.cells_per_column;
            unsigned int slot = atomicAdd(&cell_seg_count[winner_cell], 1u);
            // Counter past capacity: wrap around and recycle an older slot.
            if (slot >= cfg.max_segments_per_cell) {
                slot = slot % cfg.max_segments_per_cell;
            }
            unsigned int seg_id = winner_cell * cfg.max_segments_per_cell + slot;
            seg_cell_id[seg_id] = winner_cell;
            seg_syn_count[seg_id] = 0;
            shared_seg_id = seg_id;
            shared_existing_syn_count = 0u;
            shared_is_new = 1u;
        }
        shared_grown = 0u;
        // Per-block pseudo-random scan start. Guard the modulo: with
        // bits_words == 0 the scan loop below never runs, but % 0 is undefined.
        shared_start_offset = (cfg.bits_words != 0u)
            ? (b * 2654435761u + cfg.iter_seed) % cfg.bits_words
            : 0u;
    }
    __syncthreads();

    const unsigned int seg_id = shared_seg_id;
    // 64-bit base so seg_id * synapses_per_segment cannot wrap at 32 bits.
    const unsigned long long seg_base =
        (unsigned long long)seg_id * cfg.synapses_per_segment;
    const unsigned int existing_syn = shared_existing_syn_count;
    const unsigned int is_new = shared_is_new;
    const unsigned int start = shared_start_offset;

    // PHASE 1: If reusing, reinforce existing synapses. Each thread owns the
    // synapses s = tid, tid + n_threads, ... so no slot is touched twice.
    if (!is_new) {
        for (unsigned int s = tid; s < existing_syn; s += n_threads) {
            unsigned int presyn = syn_presyn[seg_base + s];
            unsigned int word = prev_active_bits[presyn >> 5];
            unsigned int bit = (word >> (presyn & 31u)) & 1u;
            int p = (int)syn_perm[seg_base + s];
            if (bit) {
                // Presynaptic cell was active: strengthen, saturating at 32767.
                int np = p + cfg.perm_inc_i16;
                if (np > 32767) np = 32767;
                syn_perm[seg_base + s] = (short)np;
            } else {
                // Presynaptic cell was inactive: weaken, floored at 0.
                int np = p - cfg.perm_dec_i16;
                if (np < 0) np = 0;
                syn_perm[seg_base + s] = (short)np;
            }
        }
    }
    // Barrier kept outside the conditional so every thread reaches it unconditionally.
    __syncthreads();

    // PHASE 2: Grow up to `max_new_synapses` (or room) synapses to prev_winners
    // that aren't already presynaptic to this segment.
    const unsigned int room = (cfg.synapses_per_segment > existing_syn)
        ? (cfg.synapses_per_segment - existing_syn) : 0u;
    const unsigned int max_grow = (cfg.max_new_synapses < room) ? cfg.max_new_synapses : room;

    for (unsigned int w_off = 0; w_off < cfg.bits_words; w_off += n_threads) {
        // Best-effort early exit; the authoritative cap is the atomicAdd slot check below.
        if (shared_grown >= max_grow) break;

        // Each word of prev_winner_bits must be visited at most once, otherwise
        // the same cell could be grown twice (the duplicate check below only
        // covers the pre-existing synapses). Threads whose relative index falls
        // past the end are done: later iterations only move further out.
        const unsigned int rel = w_off + tid;
        if (rel >= cfg.bits_words) break;

        // start < bits_words and rel < bits_words, so one conditional subtract
        // is equivalent to the modulo.
        unsigned int widx = start + rel;
        if (widx >= cfg.bits_words) widx -= cfg.bits_words;

        unsigned int word = prev_winner_bits[widx];
        while (word != 0u) {
            if (shared_grown >= max_grow) break;
            // Pop the lowest set bit.
            unsigned int bit_pos = __ffs(word) - 1u;
            word &= ~(1u << bit_pos);
            unsigned int cell = widx * 32u + bit_pos;
            if (cell >= cfg.n_cells) continue;

            // Skip if already presynaptic (O(existing_syn) scan; usually small).
            bool exists = false;
            for (unsigned int s = 0; s < existing_syn; s++) {
                if (syn_presyn[seg_base + s] == cell) { exists = true; break; }
            }
            if (exists) continue;

            // Reserve an output slot; the counter may overshoot max_grow and
            // is clamped when the final count is written.
            unsigned int slot = atomicAdd(&shared_grown, 1u);
            if (slot >= max_grow) break;
            unsigned int write_idx = existing_syn + slot;
            if (write_idx >= cfg.synapses_per_segment) break;
            syn_presyn[seg_base + write_idx] = cell;
            syn_perm[seg_base + write_idx] = (short)cfg.initial_perm_i16;
        }
    }
    __syncthreads();

    // Publish the new synapse count for the segment.
    if (tid == 0) {
        unsigned int grown = shared_grown;
        if (grown > max_grow) grown = max_grow;
        unsigned int new_count = existing_syn + grown;
        if (new_count > cfg.synapses_per_segment) new_count = cfg.synapses_per_segment;
        seg_syn_count[seg_id] = new_count;
    }
}
|