File size: 3,889 Bytes
1c59946
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
// TM predict kernel — cell-grouped launch.
//
// Grid: n_cells blocks (one per cell).
// Block: 32 threads (one warp).
//
// Each block iterates the segments owned by its cell (count in cell_seg_count[cell],
// capped at max_segments_per_cell). For each segment that has at least one synapse,
// it counts active connected/potential synapses against cell_active_bits; segments
// with zero synapses get both counters reset to 0. It updates the per-segment
// counters, the cell's bit in cell_predictive_bits, the col_predicted flag, and the
// per-column col_best_match key.

// Parameter block passed by value to the kernel(s) in this file.
// NOTE(review): the field order and types presumably mirror a host-side
// struct used at launch — confirm before reordering or resizing anything.
// Fields marked "not read by tm_predict" are unused by the kernel visible in
// this file; their meaning is inferred from the name only.
struct TmConfig {
    // tm_predict: a segment whose active *connected* synapse count reaches this
    // value marks its cell predictive and its column predicted.
    unsigned int activation_threshold;
    // tm_predict: a segment whose active *potential* synapse count reaches this
    // value competes for the column's col_best_match entry.
    unsigned int learning_threshold;
    // tm_predict: column index is computed as cell / cells_per_column.
    unsigned int cells_per_column;
    // tm_predict: stride (in synapses) between consecutive segments in the
    // syn_presyn / syn_perm arrays.
    unsigned int synapses_per_segment;
    // Not read by tm_predict. Presumably the total number of segment slots.
    unsigned int n_segments;
    // tm_predict: blocks with blockIdx.x >= n_cells exit immediately.
    unsigned int n_cells;
    // tm_predict: stride (in segments) between consecutive cells, and the upper
    // bound applied to cell_seg_count[cell].
    unsigned int max_segments_per_cell;
    // Not read by tm_predict.
    unsigned int max_new_synapses;
    // tm_predict: a synapse counts as connected when its permanence (read as a
    // 16-bit signed value) is >= this threshold.
    int conn_thr_i16;
    // Not read by tm_predict. Presumably permanence increment — TODO confirm.
    int perm_inc_i16;
    // Not read by tm_predict. Presumably permanence decrement — TODO confirm.
    int perm_dec_i16;
    // Not read by tm_predict.
    int predicted_seg_dec_i16;
    // Not read by tm_predict.
    int initial_perm_i16;
    // Not read by tm_predict.
    unsigned int iter_seed;
    // Not read by tm_predict. Presumably the number of columns.
    unsigned int n_cols;
    // Not read by tm_predict. Presumably the word count of the cell bitsets.
    unsigned int bits_words;
};

// tm_predict — per-cell segment scoring.
//
// Launch contract: one block per cell (blockIdx.x is the cell index) with a
// single 32-thread warp per block. The synapse loop strides by 32 and the
// reduction uses a full-warp shuffle mask, so a block of fewer than 32
// threads is not supported.
//
// For every segment slot owned by the cell (the first
// min(cell_seg_count[cell], cfg.max_segments_per_cell) slots starting at
// cell * cfg.max_segments_per_cell) the warp counts:
//   - potential: synapses whose presynaptic cell bit is set in cell_active_bits
//   - connected: the subset of those with permanence >= cfg.conn_thr_i16
// and lane 0 publishes the results.
//
// Parameters:
//   seg_cell_id               not read by this kernel.
//   seg_syn_count             per-segment synapse count; clamped here to
//                             cfg.synapses_per_segment.
//   syn_presyn, syn_perm      per-synapse presynaptic cell index / permanence,
//                             laid out as seg * cfg.synapses_per_segment + s.
//   cell_active_bits          bitset, 32 cells per word, indexed by presyn.
//   cell_predictive_bits      bitset, 32 cells per word; this cell's bit is
//                             OR-ed in when a segment reaches
//                             cfg.activation_threshold.
//   col_predicted             one byte per column, set to 1 in the same case.
//   seg_num_active_connected,
//   seg_num_active_potential  per-segment counters, overwritten for every
//                             visited segment.
//   col_best_match            per-column key raised with atomicMax when a
//                             segment reaches cfg.learning_threshold.
//   cell_seg_count            number of segments in use for each cell.
//   cfg                       see TmConfig.
//
// This kernel only sets bits / raises keys; it never clears
// cell_predictive_bits, col_predicted or col_best_match. Resetting them
// between steps is presumably done by the caller — TODO confirm.
extern "C" __global__
void tm_predict(
    const unsigned int * __restrict__ seg_cell_id,
    const unsigned int * __restrict__ seg_syn_count,
    const unsigned int * __restrict__ syn_presyn,
    const short        * __restrict__ syn_perm,
    const unsigned int * __restrict__ cell_active_bits,
    unsigned int       * __restrict__ cell_predictive_bits,
    unsigned char      * __restrict__ col_predicted,
    unsigned int       * __restrict__ seg_num_active_connected,
    unsigned int       * __restrict__ seg_num_active_potential,
    unsigned int       * __restrict__ col_best_match,
    const unsigned int * __restrict__ cell_seg_count,
    TmConfig             cfg
) {
    const unsigned int cell = blockIdx.x;
    if (cell >= cfg.n_cells) return;

    // Both early exits are block-uniform (they depend only on blockIdx.x), so
    // the full-mask shuffles below are always reached by the whole warp.
    const unsigned int n_segs_here = min(cell_seg_count[cell], cfg.max_segments_per_cell);
    if (n_segs_here == 0) return;

    const unsigned int tid = threadIdx.x;
    const unsigned int col = cell / cfg.cells_per_column;
    const unsigned int seg_base_id = cell * cfg.max_segments_per_cell;

    for (unsigned int local_seg = 0; local_seg < n_segs_here; local_seg++) {
        const unsigned int seg = seg_base_id + local_seg;

        // Clamp the stored count to the per-segment capacity, mirroring the
        // clamp applied to cell_seg_count above: a corrupt or stale count must
        // not make the warp read the next segment's synapses or run off the
        // end of the synapse arrays.
        const unsigned int n_syn = min(seg_syn_count[seg], cfg.synapses_per_segment);
        if (n_syn == 0) {
            // Warp-uniform branch: every lane read the same n_syn.
            if (tid == 0) {
                seg_num_active_connected[seg] = 0;
                seg_num_active_potential[seg] = 0;
            }
            continue;
        }

        // 64-bit offset: seg * synapses_per_segment can exceed 2^32 for large
        // pools and would silently wrap in 32-bit unsigned arithmetic.
        const unsigned long long syn_base =
            (unsigned long long)seg * (unsigned long long)cfg.synapses_per_segment;

        // Each lane accumulates a partial count over synapses tid, tid+32, ...
        unsigned int local_conn = 0;
        unsigned int local_pot = 0;
        for (unsigned int s = tid; s < n_syn; s += 32u) {
            const unsigned long long syn_idx = syn_base + s;
            unsigned int presyn = syn_presyn[syn_idx];
            unsigned int word = cell_active_bits[presyn >> 5];
            unsigned int bit  = (word >> (presyn & 31u)) & 1u;
            if (bit) {
                local_pot += 1u;
                int p = (int)syn_perm[syn_idx];
                if (p >= cfg.conn_thr_i16) {
                    local_conn += 1u;
                }
            }
        }

        // Warp tree reduction; after the loop lane 0 holds the segment totals.
        for (int off = 16; off > 0; off >>= 1) {
            local_conn += __shfl_down_sync(0xffffffffu, local_conn, off);
            local_pot  += __shfl_down_sync(0xffffffffu, local_pot,  off);
        }

        if (tid == 0) {
            seg_num_active_connected[seg] = local_conn;
            seg_num_active_potential[seg] = local_pot;
            if (local_conn >= cfg.activation_threshold) {
                // Other blocks (cells) share this 32-bit word, hence atomicOr.
                unsigned int word_idx = cell >> 5;
                unsigned int bit_mask = 1u << (cell & 31u);
                atomicOr(&cell_predictive_bits[word_idx], bit_mask);
                // Plain store is sufficient: every writer stores the same 1.
                col_predicted[col] = 1;
            }
            if (local_pot >= cfg.learning_threshold) {
                // Key layout: [31:21] potential count saturated at 2047,
                // [20:0] low 21 bits of the segment id. atomicMax therefore
                // keeps the segment with the highest potential count, ties
                // broken by the larger (masked) segment id.
                // NOTE(review): segment ids >= 2^21 alias in the key.
                unsigned int pot_c = local_pot > 2047u ? 2047u : local_pot;
                unsigned int key = (pot_c << 21) | (seg & 0x1FFFFFu);
                atomicMax(&col_best_match[col], key);
            }
        }
    }
}