File size: 6,228 Bytes
1c59946
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
// TM grow+reinforce kernel.
//
// For each bursting column:
//   If col_best_match[col] is non-zero (i.e. at least one matching segment
//     with num_active_potential >= learning_threshold exists on cells in this col):
//     Target = that matching segment.
//     Reinforce its existing synapses: +inc if presyn in prev_active, -dec otherwise.
//     Grow up to min(max_new, synapses_per_segment - current_syn_count) additional
//     synapses to prev_winners (skipping cells already presynaptic to the segment).
//   Else:
//     Allocate a fresh segment slot on winner cell (cell 0 of col).
//     Grow up to max_new synapses to prev_winners (no reinforce needed — new seg).
//
// This mirrors the CPU TM burst logic.

// Parameter bundle passed by value to the tm_grow kernel.
// NOTE(review): presumably mirrored by a host-side struct — field order and
// types would then have to stay in sync with it; confirm against the launcher.
struct TmConfig {
    // Not read by tm_grow; presumably the active-synapse count needed for a
    // segment to become active — TODO confirm against the other TM kernels.
    unsigned int activation_threshold;
    // Not read by tm_grow; the file header describes matching segments as those
    // with num_active_potential >= learning_threshold.
    unsigned int learning_threshold;
    // Cells per column; tm_grow uses col * cells_per_column as the winner cell.
    unsigned int cells_per_column;
    // Capacity of one segment and the per-segment stride into syn_presyn/syn_perm.
    unsigned int synapses_per_segment;
    // Not read by tm_grow; presumably the total number of segment slots — TODO confirm.
    unsigned int n_segments;
    // Cell indices >= n_cells found in the prev_winner bitset are ignored.
    unsigned int n_cells;
    // Segment slots per cell; segment id = cell * max_segments_per_cell + slot.
    unsigned int max_segments_per_cell;
    // Upper bound on synapses grown onto one segment per tm_grow call.
    unsigned int max_new_synapses;
    // Not read by tm_grow; presumably the connected-permanence threshold — TODO confirm.
    int conn_thr_i16;
    // Added to the permanence of synapses whose presynaptic cell was previously active
    // (result clamped to 32767).
    int perm_inc_i16;
    // Subtracted from the permanence of synapses whose presynaptic cell was not
    // previously active (result clamped to 0).
    int perm_dec_i16;
    // Not read by tm_grow; presumably the punishment decrement for wrongly
    // predicted segments — TODO confirm.
    int predicted_seg_dec_i16;
    // Permanence assigned to newly grown synapses (stored as short).
    int initial_perm_i16;
    // Mixed into the per-block starting word of the prev_winner scan.
    unsigned int iter_seed;
    // Not read by tm_grow; presumably the number of columns — TODO confirm.
    unsigned int n_cols;
    // Number of 32-bit words of prev_winner_bits that tm_grow scans.
    unsigned int bits_words;
};

// tm_grow: reinforce and/or grow one distal segment per bursting column.
//
// Launch layout:
//   * One block per bursting column: blockIdx.x indexes burst_cols_flat. Blocks
//     with blockIdx.x >= burst_cols_count[0] return immediately.
//   * Threads are distinguished by threadIdx.x only, so a 1-D block is expected.
//     Any 1-D block size works; loops stride by blockDim.x.
//   * Uses five statically declared __shared__ unsigned ints; no dynamic shared memory.
//
// Parameters:
//   seg_cell_id       Written with the owning cell when a new segment is allocated.
//   seg_syn_count     Per-segment synapse count; read when reusing a segment, reset
//                     to 0 for a new one, and updated with the grown total at the end.
//   syn_presyn        Presynaptic cell ids, indexed seg_id * synapses_per_segment + s.
//   syn_perm          Permanences (short), same indexing as syn_presyn.
//   cell_seg_count    Per-cell counter, atomically incremented on allocation; the slot
//                     used is the old value modulo max_segments_per_cell.
//   burst_cols_flat   Column ids of the bursting columns.
//   burst_cols_count  Element 0 holds the number of valid entries in burst_cols_flat.
//   prev_winner_bits  Bitset over cells (32 cells per word, cfg.bits_words words);
//                     set bits are candidates for new synapses.
//   prev_active_bits  Bitset over cells, looked up by presynaptic cell id when
//                     reinforcing.
//   col_best_match    Per-column key; 0 means "no matching segment", otherwise the
//                     low 21 bits are the id of the segment to reuse.
//   cfg               Thresholds, sizes and permanence deltas (see TmConfig).
extern "C" __global__
void tm_grow(
    unsigned int       * __restrict__ seg_cell_id,
    unsigned int       * __restrict__ seg_syn_count,
    unsigned int       * __restrict__ syn_presyn,
    short              * __restrict__ syn_perm,
    unsigned int       * __restrict__ cell_seg_count,
    const unsigned int * __restrict__ burst_cols_flat,
    const unsigned int * __restrict__ burst_cols_count,
    const unsigned int * __restrict__ prev_winner_bits,
    const unsigned int * __restrict__ prev_active_bits,
    const unsigned int * __restrict__ col_best_match,
    TmConfig             cfg
) {
    const unsigned int b = blockIdx.x;
    const unsigned int n_burst_cols = burst_cols_count[0];
    // Block-uniform exit: either every thread of the block leaves here or none
    // does, so the barriers below are never reached by a partial block.
    if (b >= n_burst_cols) return;
    const unsigned int tid = threadIdx.x;
    // Stride by the real block width instead of assuming exactly 32 threads.
    const unsigned int n_threads = blockDim.x;

    const unsigned int col = burst_cols_flat[b];

    __shared__ unsigned int shared_seg_id;
    __shared__ unsigned int shared_existing_syn_count;
    __shared__ unsigned int shared_grown;
    __shared__ unsigned int shared_is_new;
    __shared__ unsigned int shared_start_offset;

    // Thread 0 picks (or allocates) the target segment and publishes it to the block.
    if (tid == 0) {
        unsigned int match_key = col_best_match[col];
        if (match_key != 0u) {
            // Reuse matching segment: the segment id lives in the low 21 bits.
            unsigned int seg_id = match_key & 0x1FFFFFu;
            shared_seg_id = seg_id;
            shared_existing_syn_count = seg_syn_count[seg_id];
            shared_is_new = 0u;
        } else {
            // Allocate new segment on winner cell (cell 0 of col).
            unsigned int winner_cell = col * cfg.cells_per_column;
            unsigned int slot = atomicAdd(&cell_seg_count[winner_cell], 1u);
            // Once every slot has been handed out, wrap around and recycle one.
            if (slot >= cfg.max_segments_per_cell) {
                slot = slot % cfg.max_segments_per_cell;
            }
            unsigned int seg_id = winner_cell * cfg.max_segments_per_cell + slot;
            seg_cell_id[seg_id] = winner_cell;
            seg_syn_count[seg_id] = 0;
            shared_seg_id = seg_id;
            shared_existing_syn_count = 0u;
            shared_is_new = 1u;
        }
        shared_grown = 0u;
        // Scramble the block index with the iteration seed so different blocks and
        // iterations start scanning the winner bitset at different words. Guard
        // the modulo: with an empty bitset there is nothing to scan anyway.
        shared_start_offset = (cfg.bits_words != 0u)
            ? (b * 2654435761u + cfg.iter_seed) % cfg.bits_words
            : 0u;
    }
    __syncthreads();

    const unsigned int seg_id = shared_seg_id;
    const unsigned int seg_base = seg_id * cfg.synapses_per_segment;
    const unsigned int existing_syn = shared_existing_syn_count;
    const unsigned int is_new = shared_is_new;
    const unsigned int start = shared_start_offset;

    // PHASE 1: If reusing, reinforce existing synapses.
    // Each synapse is handled by exactly one thread (block-stride over s).
    if (!is_new) {
        for (unsigned int s = tid; s < existing_syn; s += n_threads) {
            unsigned int presyn = syn_presyn[seg_base + s];
            unsigned int word = prev_active_bits[presyn >> 5];
            unsigned int bit = (word >> (presyn & 31u)) & 1u;
            int p = (int)syn_perm[seg_base + s];
            if (bit) {
                // Presynaptic cell was active: strengthen, saturating at short max.
                int np = p + cfg.perm_inc_i16;
                if (np > 32767) np = 32767;
                syn_perm[seg_base + s] = (short)np;
            } else {
                // Presynaptic cell was inactive: weaken, saturating at zero.
                int np = p - cfg.perm_dec_i16;
                if (np < 0) np = 0;
                syn_perm[seg_base + s] = (short)np;
            }
        }
    }
    // Kept outside the branch so the barrier is unconditionally block-uniform.
    __syncthreads();

    // PHASE 2: Grow up to `max_new_synapses` (or room) synapses to prev_winners
    // that aren't already presynaptic to this segment.
    const unsigned int room = (cfg.synapses_per_segment > existing_syn)
        ? (cfg.synapses_per_segment - existing_syn) : 0u;
    const unsigned int max_grow = (cfg.max_new_synapses < room) ? cfg.max_new_synapses : room;

    // The early-exit checks below only poll the counter; the authoritative
    // decision is the atomicAdd result. Read through a volatile pointer so the
    // poll is a real shared-memory load each time.
    volatile unsigned int *grown_probe = &shared_grown;

    for (unsigned int w_off = 0; w_off < cfg.bits_words; w_off += n_threads) {
        if (*grown_probe >= max_grow) break;
        const unsigned int off = w_off + tid;
        // Tail guard: without it, threads past the end of the bitset would wrap
        // around and rescan words another thread already handled, which grows
        // duplicate synapses to the same presynaptic cell. `off` only increases
        // in later iterations, so this thread is done.
        if (off >= cfg.bits_words) break;
        unsigned int widx = (start + off) % cfg.bits_words;
        unsigned int word = prev_winner_bits[widx];
        while (word != 0u) {
            if (*grown_probe >= max_grow) break;
            // Pop the lowest set bit of the word.
            unsigned int bit_pos = __ffs(word) - 1u;
            word &= ~(1u << bit_pos);
            unsigned int cell = widx * 32u + bit_pos;
            if (cell >= cfg.n_cells) continue;

            // Skip if already presynaptic (O(existing_syn) scan; usually small).
            bool exists = false;
            for (unsigned int s = 0; s < existing_syn; s++) {
                if (syn_presyn[seg_base + s] == cell) { exists = true; break; }
            }
            if (exists) continue;

            // Claim an output slot; claims beyond max_grow are discarded and the
            // over-count is clamped when the final count is written.
            unsigned int slot = atomicAdd(&shared_grown, 1u);
            if (slot >= max_grow) break;
            unsigned int write_idx = existing_syn + slot;
            if (write_idx >= cfg.synapses_per_segment) break;
            syn_presyn[seg_base + write_idx] = cell;
            syn_perm[seg_base + write_idx] = (short)cfg.initial_perm_i16;
        }
    }
    __syncthreads();

    // Thread 0 publishes the new synapse count, clamped to what was really written.
    if (tid == 0) {
        unsigned int grown = shared_grown;
        if (grown > max_grow) grown = max_grow;
        unsigned int new_count = existing_syn + grown;
        if (new_count > cfg.synapses_per_segment) new_count = cfg.synapses_per_segment;
        seg_syn_count[seg_id] = new_count;
    }
}