# Source: catalyst-n1/sdk/neurocore/gpu_simulator.py
"""GPU-accelerated LIF simulator using PyTorch sparse tensors.
Matches the cycle-accurate behavior of simulator.py but runs on CUDA GPU,
achieving 100-1000x speedup for large networks (4K-32K neurons).
All neuron state stored as dense int32 tensors on GPU.
Connectivity stored as sparse CSR float32 matrices: W @ spike_vec = current.
"""
import torch
import numpy as np
from collections import defaultdict
from .backend import Backend
from .compiler import Compiler, CompiledNetwork
from .network import Network, Population, PopulationSlice
from .constants import (
MAX_CORES, NEURONS_PER_CORE, GRADE_SHIFT,
TRACE_MAX, LEARN_SHIFT,
WEIGHT_MAX_STDP, WEIGHT_MIN_STDP,
REWARD_SHIFT, ELIG_DECAY_SHIFT, ELIG_MAX,
DEFAULT_THRESHOLD, DEFAULT_LEAK, DEFAULT_RESTING, DEFAULT_REFRAC,
DEFAULT_DEND_THRESHOLD, DEFAULT_NOISE_CONFIG, DEFAULT_TAU1, DEFAULT_TAU2,
NOISE_LFSR_SEED, NOISE_LFSR_TAPS,
DELAY_QUEUE_BUCKETS,
)
from .microcode import (
execute_program, R_TRACE1, R_TRACE2, R_WEIGHT, R_ELIG, R_CONST,
R_TEMP0, R_TEMP1, R_REWARD, LTD_START, LTD_END, LTP_START, LTP_END,
)
from .exceptions import NeurocoreError
class GpuSimulator(Backend):
"""GPU-accelerated LIF simulator using PyTorch CUDA tensors."""
def __init__(self, device=None):
if device is None:
if torch.cuda.is_available():
# Prefer GPU 1 (20GB 3080) if available, else GPU 0
device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
device = torch.device("cpu")
self.device = device
self._compiler = Compiler()
self._compiled = None
self._n = 0
self._timestep_count = 0
# Neuron state tensors (set by deploy)
self._potential = None
self._refrac = None
self._trace = None
self._trace2 = None
self._ext_current = None
# Per-neuron parameter tensors
self._threshold = None
self._leak = None
self._resting = None
self._refrac_period = None
self._dend_threshold = None
self._noise_config = None
self._tau1 = None
self._tau2 = None
self._lfsr = None
# Sparse weight matrices (CSR, float32, shape (N, N))
# Convention: W[target, source] so W @ spike_vec = accumulated current
self._W_soma = None # compartment 0, delay=0
self._W_dend = [None] * 3 # compartments 1-3, delay=0
# Delay structures
self._has_delays = False
self._delay_buf_soma = None # (64, N) ring buffer
self._delay_buf_dend = None # (3, 64, N) ring buffer
self._delay_src_ids = None # (num_delayed,) source neuron indices
self._delay_tgt_ids = None # (num_delayed,) target neuron indices
self._delay_weights = None # (num_delayed,) weight values
self._delay_comps = None # (num_delayed,) compartment IDs
self._delay_values = None # (num_delayed,) delay tick values
# Spike vectors
self._prev_spike_vec = None # (N,) float32 - payload from previous timestep
self._spike_mask = None # (N,) bool - who spiked this timestep
# Config flags
self._learn_enable = False
self._graded_enable = False
self._dendritic_enable = False
self._three_factor_enable = False
self._noise_enable = False
# Learning state
self._learning_rule = None
self._elig_crow = None # CSR row pointers for eligibility
self._elig_col = None # CSR column indices
self._elig_vals = None # eligibility values (same sparsity as W_soma)
self._reward_value = 0
self._reward_pending = False
# STDP mask: bool tensor over CSR values (True = learnable)
self._stdp_mask = None # None means all connections learnable
# CSR structure cache for STDP (avoids recomputing each timestep)
self._soma_crow = None
self._soma_col = None
self._soma_row_idx = None # expanded row indices (nnz,)
# CPU-side adjacency for microcode fallback and weight export
self._adjacency = None
    def deploy(self, network_or_compiled):
        """Compile (if needed) and initialize GPU state.

        Accepts either a raw Network (compiled here) or a pre-compiled
        CompiledNetwork. Allocates all per-neuron state/parameter tensors,
        builds the sparse weight matrices, and resets learning state.

        Raises:
            TypeError: if the argument is neither Network nor CompiledNetwork.
        """
        if isinstance(network_or_compiled, Network):
            self._compiled = self._compiler.compile(network_or_compiled)
        elif isinstance(network_or_compiled, CompiledNetwork):
            self._compiled = network_or_compiled
        else:
            raise TypeError(f"Expected Network or CompiledNetwork, got {type(network_or_compiled)}")
        n = self._compiled.placement.total_neurons
        self._n = n
        dev = self.device
        # Initialize neuron state tensors
        self._potential = torch.zeros(n, dtype=torch.int32, device=dev)
        self._refrac = torch.zeros(n, dtype=torch.int32, device=dev)
        self._trace = torch.zeros(n, dtype=torch.int32, device=dev)
        self._trace2 = torch.zeros(n, dtype=torch.int32, device=dev)
        self._ext_current = torch.zeros(n, dtype=torch.int32, device=dev)
        # Per-neuron parameters (defaults first; overrides applied below)
        self._threshold = torch.full((n,), DEFAULT_THRESHOLD, dtype=torch.int32, device=dev)
        self._leak = torch.full((n,), DEFAULT_LEAK, dtype=torch.int32, device=dev)
        self._resting = torch.full((n,), DEFAULT_RESTING, dtype=torch.int32, device=dev)
        self._refrac_period = torch.full((n,), DEFAULT_REFRAC, dtype=torch.int32, device=dev)
        self._dend_threshold = torch.full((n,), DEFAULT_DEND_THRESHOLD, dtype=torch.int32, device=dev)
        self._noise_config = torch.full((n,), DEFAULT_NOISE_CONFIG, dtype=torch.int32, device=dev)
        self._tau1 = torch.full((n,), DEFAULT_TAU1, dtype=torch.int32, device=dev)
        self._tau2 = torch.full((n,), DEFAULT_TAU2, dtype=torch.int32, device=dev)
        # LFSR seeds: step one LFSR once per neuron so each neuron starts
        # from a distinct point of the same pseudo-random sequence.
        lfsr_seeds = np.zeros(n, dtype=np.int32)
        lfsr = NOISE_LFSR_SEED
        for gid in range(n):
            lfsr_seeds[gid] = lfsr
            bit = lfsr & 1
            lfsr >>= 1
            if bit:
                lfsr ^= NOISE_LFSR_TAPS
        self._lfsr = torch.from_numpy(lfsr_seeds).to(dev)
        # Apply per-neuron parameter overrides
        for gid, params in self._compiled.neuron_params.items():
            if gid < n:
                self._threshold[gid] = params.threshold
                self._leak[gid] = params.leak
                self._resting[gid] = params.resting
                self._refrac_period[gid] = params.refrac
                self._dend_threshold[gid] = params.dend_threshold
                self._noise_config[gid] = params.noise_config
                self._tau1[gid] = params.tau1
                self._tau2[gid] = params.tau2
        # Build sparse weight matrices from adjacency
        self._adjacency = dict(self._compiled.adjacency)
        self._build_weight_matrices(n)
        # Apply learn config flags from the compiled network
        cfg = self._compiled.learn_config
        self._learn_enable = cfg.get("learn_enable", False)
        self._graded_enable = cfg.get("graded_enable", False)
        self._dendritic_enable = cfg.get("dendritic_enable", False)
        self._noise_enable = cfg.get("noise_enable", False)
        # NOTE(review): _three_factor_enable is NOT taken from learn_config
        # here — it is only set via set_learning(); confirm this is intended.
        # P19 learning rule
        self._learning_rule = self._compiled.learning_rule
        # Spike vectors
        self._prev_spike_vec = torch.zeros(n, dtype=torch.float32, device=dev)
        # Learning state
        self._reward_value = 0
        self._reward_pending = False
        # Initialize eligibility with same sparsity as W_soma
        if self._W_soma is not None and self._W_soma._nnz() > 0:
            self._elig_crow = self._soma_crow
            self._elig_col = self._soma_col
            self._elig_vals = torch.zeros(self._W_soma._nnz(), dtype=torch.float32, device=dev)
        else:
            self._elig_vals = None
        self._timestep_count = 0
def _build_weight_matrices(self, n):
"""Build sparse CSR weight matrices from adjacency dict."""
dev = self.device
# Collect COO triplets per compartment, split by delay
rows_imm = [[] for _ in range(4)] # immediate (delay=0)
cols_imm = [[] for _ in range(4)]
vals_imm = [[] for _ in range(4)]
delay_srcs, delay_tgts, delay_wts, delay_comps, delay_vals = [], [], [], [], []
for src_gid, targets in self._adjacency.items():
for entry in targets:
tgt_gid, weight, comp = entry[0], entry[1], entry[2]
delay = entry[3] if len(entry) > 3 else 0
if tgt_gid >= n:
continue
if delay > 0:
delay_srcs.append(src_gid)
delay_tgts.append(tgt_gid)
delay_wts.append(float(weight))
delay_comps.append(comp)
delay_vals.append(delay)
else:
rows_imm[comp].append(tgt_gid)
cols_imm[comp].append(src_gid)
vals_imm[comp].append(float(weight))
# Build CSR for each compartment (immediate delivery)
def _build_csr(rows, cols, vals):
if not rows:
return torch.sparse_csr_tensor(
torch.zeros(n + 1, dtype=torch.int32),
torch.tensor([], dtype=torch.int32),
torch.tensor([], dtype=torch.float32),
size=(n, n),
).to(dev)
indices = torch.tensor([rows, cols], dtype=torch.int64)
values = torch.tensor(vals, dtype=torch.float32)
coo = torch.sparse_coo_tensor(indices, values, (n, n))
# Coalesce to sum duplicates (same src->tgt with different entries)
coo = coo.coalesce()
return coo.to_sparse_csr().to(dev)
self._W_soma = _build_csr(rows_imm[0], cols_imm[0], vals_imm[0])
for d in range(3):
self._W_dend[d] = _build_csr(rows_imm[d + 1], cols_imm[d + 1], vals_imm[d + 1])
# Cache CSR structure for STDP
self._soma_crow = self._W_soma.crow_indices()
self._soma_col = self._W_soma.col_indices()
if self._W_soma._nnz() > 0:
self._soma_row_idx = torch.repeat_interleave(
torch.arange(n, device=dev),
self._soma_crow[1:] - self._soma_crow[:-1]
)
else:
self._soma_row_idx = torch.tensor([], dtype=torch.int64, device=dev)
# Build delay structures
if delay_srcs:
self._has_delays = True
self._delay_src_ids = torch.tensor(delay_srcs, dtype=torch.int64, device=dev)
self._delay_tgt_ids = torch.tensor(delay_tgts, dtype=torch.int64, device=dev)
self._delay_weights = torch.tensor(delay_wts, dtype=torch.float32, device=dev)
self._delay_comps = torch.tensor(delay_comps, dtype=torch.int64, device=dev)
self._delay_values = torch.tensor(delay_vals, dtype=torch.int64, device=dev)
self._delay_buf_soma = torch.zeros(DELAY_QUEUE_BUCKETS, n, dtype=torch.float32, device=dev)
self._delay_buf_dend = torch.zeros(3, DELAY_QUEUE_BUCKETS, n, dtype=torch.float32, device=dev)
else:
self._has_delays = False
def inject(self, target, current):
"""Set external stimulus current for specified neurons."""
if self._compiled is None:
raise NeurocoreError("No network deployed. Call deploy() first.")
resolved = self._resolve_targets(target)
for core, neuron in resolved:
gid = core * NEURONS_PER_CORE + neuron
if gid < self._n:
self._ext_current[gid] = current
def reward(self, value):
"""Set reward signal for 3-factor learning."""
self._reward_value = int(value)
self._reward_pending = True
def run(self, timesteps):
"""Execute timesteps on GPU and return RunResult."""
from .result import RunResult
if self._compiled is None:
raise NeurocoreError("No network deployed. Call deploy() first.")
if getattr(self, '_async_enable', False):
raise NeurocoreError("Async mode not supported on GPU simulator. Use sync mode.")
return self._run_sync(timesteps)
    @torch.no_grad()
    def _run_sync(self, timesteps):
        """Synchronous GPU execution: all neurons updated every timestep.

        Per timestep: drain the delay ring-buffer slot, deliver last step's
        spikes through the sparse weight matrices, add external stimulus,
        update every neuron, record spikes, then (optionally) run plasticity.

        Returns:
            RunResult with total spike count and per-neuron spike trains.
        """
        from .result import RunResult
        n = self._n
        dev = self.device
        spike_trains = defaultdict(list)
        total_spikes = 0
        # Pre-allocate accumulators
        acc_soma = torch.zeros(n, dtype=torch.float32, device=dev)
        acc_dend = [torch.zeros(n, dtype=torch.float32, device=dev) for _ in range(3)]
        # NOTE(review): zero_f appears unused in this method.
        zero_f = torch.zeros(n, dtype=torch.float32, device=dev)
        for t in range(timesteps):
            acc_soma.zero_()
            for d in range(3):
                acc_dend[d].zero_()
            # Drain the ring-buffer slot scheduled for this timestep.
            if self._has_delays:
                bucket = self._timestep_count % DELAY_QUEUE_BUCKETS
                acc_soma.add_(self._delay_buf_soma[bucket])
                self._delay_buf_soma[bucket].zero_()
                for d in range(3):
                    acc_dend[d].add_(self._delay_buf_dend[d, bucket])
                    self._delay_buf_dend[d, bucket].zero_()
            # Deliver the previous timestep's spikes through the weights.
            if self._prev_spike_vec.any():
                spike_col = self._prev_spike_vec.unsqueeze(1)  # (N, 1)
                if self._graded_enable:
                    # Graded: result = (W @ payload_vec) / 128
                    raw = torch.sparse.mm(self._W_soma, spike_col).squeeze(1)
                    acc_soma.add_(torch.div(raw, 128, rounding_mode='trunc'))
                    if self._dendritic_enable:
                        for d in range(3):
                            raw_d = torch.sparse.mm(self._W_dend[d], spike_col).squeeze(1)
                            acc_dend[d].add_(torch.div(raw_d, 128, rounding_mode='trunc'))
                else:
                    # Binary: delivered current is the raw weight, so use a
                    # 0/1 spike vector (the stored payload is 128 when not
                    # graded, which would over-scale a plain matmul).
                    binary_vec = (self._prev_spike_vec > 0).float().unsqueeze(1)
                    acc_soma.add_(torch.sparse.mm(self._W_soma, binary_vec).squeeze(1))
                    if self._dendritic_enable:
                        for d in range(3):
                            acc_dend[d].add_(torch.sparse.mm(self._W_dend[d], binary_vec).squeeze(1))
                # Delayed connections: enqueue into future buckets
                if self._has_delays:
                    self._deliver_delayed()
            # Add external current
            acc_soma.add_(self._ext_current.float())
            spike_vec, spike_mask = self._update_neurons_gpu(acc_soma, acc_dend)
            # Record spikes (small GPU->CPU transfer)
            if spike_mask.any():
                spiking_ids = spike_mask.nonzero(as_tuple=True)[0].cpu().numpy()
                total_spikes += len(spiking_ids)
                for gid in spiking_ids:
                    spike_trains[int(gid)].append(t)
            # Plasticity: 3-factor (eligibility + reward) or plain STDP.
            if self._learn_enable:
                if self._three_factor_enable:
                    self._elig_update_gpu(spike_mask)
                    if self._reward_pending:
                        self._reward_apply_gpu()
                        self._reward_pending = False
                    self._elig_decay_gpu()
                else:
                    self._stdp_update_gpu(spike_mask)
            self._prev_spike_vec = spike_vec.clone()
            self._ext_current.zero_()
            self._timestep_count += 1
        # Update adjacency from GPU weights (for weight export / subsequent runs)
        if self._learn_enable:
            self._sync_weights_to_adjacency()
        return RunResult(
            total_spikes=total_spikes,
            timesteps=timesteps,
            spike_trains=dict(spike_trains),
            placement=self._compiled.placement,
            backend="gpu_simulator",
        )
    @torch.no_grad()
    def run_with_schedule(self, schedule, rest_steps=0, sync_weights=True):
        """Run timesteps with pre-computed per-timestep stimulus, returning spike counts.

        This is much faster than calling inject()+run(1) in a Python loop because:
        - No Python→GPU per-timestep injection overhead
        - Spike counts accumulated on GPU (no per-timestep CPU transfer)

        Args:
            schedule: torch.Tensor of shape (T, N), int32, on self.device.
                schedule[t, gid] = external current for neuron gid at timestep t.
            rest_steps: additional timesteps to run after schedule with no stimulus.
            sync_weights: if True (default), sync GPU weights back to adjacency dict
                after run. Set False during training loops for performance, then
                call _sync_weights_to_adjacency() manually when needed.

        Returns:
            (spike_counts, total_spikes) where spike_counts is a (N,) int32 numpy
            array of per-neuron spike counts across all timesteps.
        """
        if self._compiled is None:
            raise NeurocoreError("No network deployed. Call deploy() first.")
        n = self._n
        dev = self.device
        total_timesteps = schedule.shape[0] + rest_steps
        # Accumulate spike counts on GPU — no per-timestep CPU transfer
        spike_counts = torch.zeros(n, dtype=torch.int32, device=dev)
        total_spikes = 0
        # Pre-allocate accumulators
        acc_soma = torch.zeros(n, dtype=torch.float32, device=dev)
        acc_dend = [torch.zeros(n, dtype=torch.float32, device=dev) for _ in range(3)]
        for t in range(total_timesteps):
            acc_soma.zero_()
            for d in range(3):
                acc_dend[d].zero_()
            # Drain the delay ring-buffer slot scheduled for this timestep.
            if self._has_delays:
                bucket = self._timestep_count % DELAY_QUEUE_BUCKETS
                acc_soma.add_(self._delay_buf_soma[bucket])
                self._delay_buf_soma[bucket].zero_()
                for d in range(3):
                    acc_dend[d].add_(self._delay_buf_dend[d, bucket])
                    self._delay_buf_dend[d, bucket].zero_()
            # Spike delivery
            if self._prev_spike_vec.any():
                spike_col = self._prev_spike_vec.unsqueeze(1)
                if self._graded_enable:
                    # Graded: (W @ payload_vec) / 128 with truncating division.
                    raw = torch.sparse.mm(self._W_soma, spike_col).squeeze(1)
                    acc_soma.add_(torch.div(raw, 128, rounding_mode='trunc'))
                    if self._dendritic_enable:
                        for d in range(3):
                            raw_d = torch.sparse.mm(self._W_dend[d], spike_col).squeeze(1)
                            acc_dend[d].add_(torch.div(raw_d, 128, rounding_mode='trunc'))
                else:
                    # Binary: delivered current equals the raw weight.
                    binary_vec = (self._prev_spike_vec > 0).float().unsqueeze(1)
                    acc_soma.add_(torch.sparse.mm(self._W_soma, binary_vec).squeeze(1))
                    if self._dendritic_enable:
                        for d in range(3):
                            acc_dend[d].add_(torch.sparse.mm(self._W_dend[d], binary_vec).squeeze(1))
                if self._has_delays:
                    self._deliver_delayed()
            # Add scheduled stimulus (or zero during rest)
            if t < schedule.shape[0]:
                acc_soma.add_(schedule[t].float())
            # Neuron update
            spike_vec, spike_mask = self._update_neurons_gpu(acc_soma, acc_dend)
            # Accumulate counts on GPU (no CPU transfer!)
            spike_counts.add_(spike_mask.int())
            # STDP learning
            if self._learn_enable:
                if self._three_factor_enable:
                    self._elig_update_gpu(spike_mask)
                    if self._reward_pending:
                        self._reward_apply_gpu()
                        self._reward_pending = False
                    self._elig_decay_gpu()
                else:
                    self._stdp_update_gpu(spike_mask)
            self._prev_spike_vec = spike_vec.clone()
            self._timestep_count += 1
        # Sync weights after learning (can be deferred for performance)
        if self._learn_enable and sync_weights:
            self._sync_weights_to_adjacency()
        counts_np = spike_counts.cpu().numpy()
        return counts_np, int(spike_counts.sum().item())
def _deliver_delayed(self):
"""Scatter delayed spike currents into future ring buffer buckets."""
# Find which delayed synapses have spiking sources
if self._graded_enable:
src_payloads = self._prev_spike_vec[self._delay_src_ids]
else:
src_payloads = (self._prev_spike_vec > 0).float()
src_payloads = src_payloads[self._delay_src_ids]
active = src_payloads > 0
if not active.any():
return
tgts = self._delay_tgt_ids[active]
weights = self._delay_weights[active]
comps = self._delay_comps[active]
delays = self._delay_values[active]
if self._graded_enable:
payloads = src_payloads[active]
delivered = torch.div(weights * payloads, 128, rounding_mode='trunc')
else:
delivered = weights
buckets = (self._timestep_count + delays) % DELAY_QUEUE_BUCKETS
# Scatter by compartment
soma_mask = comps == 0
if soma_mask.any():
self._delay_buf_soma.index_put_(
(buckets[soma_mask], tgts[soma_mask]),
delivered[soma_mask], accumulate=True)
for d in range(3):
d_mask = comps == (d + 1)
if d_mask.any():
self._delay_buf_dend[d].index_put_(
(buckets[d_mask], tgts[d_mask]),
delivered[d_mask], accumulate=True)
    def _update_neurons_gpu(self, acc_soma, acc_dend):
        """Vectorized LIF update for all neurons simultaneously.

        Args:
            acc_soma: (N,) float32 accumulated somatic input current.
            acc_dend: list of three (N,) float32 dendritic accumulators.

        Returns:
            spike_vec: (N,) float32 - payload values for spiking neurons, 0 elsewhere
            spike_mask: (N,) bool - which neurons spiked
        """
        n = self._n
        dev = self.device
        # Dendritic compartment thresholding: only the supra-threshold excess
        # of each dendrite is forwarded into the somatic input.
        total_input = acc_soma.int()
        if self._dendritic_enable:
            dthr = self._dend_threshold
            for d in range(3):
                dval = acc_dend[d].int()
                excess = dval - dthr
                total_input = total_input + torch.where(excess > 0, excess, torch.zeros_like(excess))
        # P14 Noise: vectorized LFSR advance + threshold perturbation
        threshold = self._threshold.clone()
        if self._noise_enable:
            threshold = self._apply_noise(threshold)
        potential = self._potential
        refrac = self._refrac
        leak = self._leak
        resting = self._resting
        # Compute conditions for all neurons simultaneously; the four branch
        # masks below (refractory / spike / integrate / below-leak) are
        # mutually exclusive by construction.
        in_refrac = refrac > 0
        v_plus_input = potential + total_input
        v_minus_leak = v_plus_input - leak
        above_thresh = (~in_refrac) & (v_minus_leak >= threshold)
        above_leak = (~in_refrac) & (~above_thresh) & (v_plus_input > leak)
        below_leak = (~in_refrac) & (~above_thresh) & (~above_leak)
        # Branch 1: Refractory — hold at resting, decrement counter
        self._potential = torch.where(in_refrac, resting, self._potential)
        self._refrac = torch.where(in_refrac, refrac - 1, self._refrac)
        # Branch 2: Spike — reset, enter refractory, set traces to max.
        # Payload is the excess over threshold clamped to 1..255; it becomes
        # the graded spike magnitude.
        excess = v_minus_leak - threshold
        payload = torch.clamp(excess, min=1, max=255)
        self._potential = torch.where(above_thresh, resting, self._potential)
        self._refrac = torch.where(above_thresh, self._refrac_period, self._refrac)
        trace_max_t = torch.full_like(self._trace, TRACE_MAX)
        self._trace = torch.where(above_thresh, trace_max_t, self._trace)
        self._trace2 = torch.where(above_thresh, trace_max_t, self._trace2)
        # Branch 3: Integrate — accumulate input
        self._potential = torch.where(above_leak, v_minus_leak, self._potential)
        # Branch 4: Below leak — reset to resting
        self._potential = torch.where(below_leak, resting, self._potential)
        # Trace decay for all non-spiking neurons, refractory included
        # (P15 dual traces with independent time constants tau1/tau2).
        non_spiking = ~above_thresh
        self._trace = torch.where(non_spiking,
                                  self._decay_trace_vec(self._trace, self._tau1),
                                  self._trace)
        self._trace2 = torch.where(non_spiking,
                                   self._decay_trace_vec(self._trace2, self._tau2),
                                   self._trace2)
        # Build spike vector: graded spikes carry the payload; binary spikes
        # carry the fixed value 128.
        if self._graded_enable:
            spike_vec = torch.where(above_thresh, payload.float(),
                                    torch.zeros(n, dtype=torch.float32, device=dev))
        else:
            spike_vec = torch.where(above_thresh,
                                    torch.full((n,), 128.0, dtype=torch.float32, device=dev),
                                    torch.zeros(n, dtype=torch.float32, device=dev))
        return spike_vec, above_thresh
def _decay_trace_vec(self, trace, tau):
"""Vectorized P15 exponential trace decay with min-step-1 guarantee."""
positive = trace > 0
decay = torch.max(torch.ones_like(trace), trace >> tau)
new_trace = torch.clamp(trace - decay, min=0)
return torch.where(positive, new_trace, trace)
def _apply_noise(self, threshold):
"""Vectorized P14 LFSR advance and threshold perturbation."""
# Advance Galois LFSR: bit = lfsr & 1; lfsr >>= 1; if bit: lfsr ^= taps
lfsr = self._lfsr
bit = lfsr & 1
lfsr_shifted = lfsr >> 1
lfsr_xored = lfsr_shifted ^ NOISE_LFSR_TAPS
self._lfsr = torch.where(bit.bool(), lfsr_xored, lfsr_shifted)
mantissa = self._noise_config & 0x0F
exponent = (self._noise_config >> 4) & 0x0F
has_noise = mantissa > 0
noise_mask = mantissa << exponent
noise_val = (self._lfsr & noise_mask) - (noise_mask >> 1)
return torch.where(has_noise, threshold + noise_val, threshold)
    def _stdp_update_gpu(self, spike_mask):
        """Vectorized 2-factor STDP over the soma CSR values.

        LTD: synapses whose SOURCE spiked lose post_trace >> LEARN_SHIFT.
        LTP: synapses whose TARGET spiked gain pre_trace >> LEARN_SHIFT.
        Custom P19 rules are delegated to the CPU microcode interpreter.
        """
        if self._learning_rule is not None:
            self._microcode_learn_gpu(spike_mask, three_factor=False)
            return
        if not spike_mask.any() or self._W_soma._nnz() == 0:
            return
        spike_f = spike_mask.float()
        crow = self._soma_crow
        col = self._soma_col        # source neuron per stored synapse
        row_idx = self._soma_row_idx  # target neuron per stored synapse
        val = self._W_soma.values().clone()
        trace_shifted = (self._trace >> LEARN_SHIFT).float()
        zero = torch.zeros_like(val)
        # LTD: source spiked → weight -= post_trace[target] >> LEARN_SHIFT
        ltd_active = spike_f[col] > 0
        ltd_delta = trace_shifted[row_idx]
        delta_ltd = torch.where(ltd_active, ltd_delta, zero)
        # LTP: target spiked → weight += pre_trace[source] >> LEARN_SHIFT
        ltp_active = spike_f[row_idx] > 0
        ltp_delta = trace_shifted[col]
        delta_ltp = torch.where(ltp_active, ltp_delta, zero)
        # Apply mask: only update learnable connections
        if self._stdp_mask is not None:
            delta_ltd = delta_ltd * self._stdp_mask.float()
            delta_ltp = delta_ltp * self._stdp_mask.float()
        val_new = val - delta_ltd + delta_ltp
        # Clamp only learnable connections (preserve fixed inhibitory weights)
        clamped = torch.clamp(val_new, min=WEIGHT_MIN_STDP, max=WEIGHT_MAX_STDP)
        if self._stdp_mask is not None:
            val_new = torch.where(self._stdp_mask, clamped, val)
        else:
            val_new = clamped
        # Rebuild CSR (structure unchanged, only values updated)
        self._W_soma = torch.sparse_csr_tensor(crow, col, val_new, (self._n, self._n))
    def _elig_update_gpu(self, spike_mask):
        """3-factor stage 1: STDP correlation accumulates into eligibility.

        Same LTD/LTP rule as _stdp_update_gpu, but the deltas land in the
        eligibility trace instead of the weights; weights only change when a
        reward is applied (_reward_apply_gpu).
        """
        if self._learning_rule is not None:
            self._microcode_learn_gpu(spike_mask, three_factor=True)
            return
        if not spike_mask.any() or self._elig_vals is None:
            return
        spike_f = spike_mask.float()
        col = self._soma_col          # source neuron per stored synapse
        row_idx = self._soma_row_idx  # target neuron per stored synapse
        trace_shifted = (self._trace >> LEARN_SHIFT).float()
        # LTD: source spiked → elig -= post_trace[target] >> LEARN_SHIFT
        ltd_active = spike_f[col] > 0
        ltd_delta = trace_shifted[row_idx]
        self._elig_vals = self._elig_vals - torch.where(ltd_active, ltd_delta,
                                                        torch.zeros_like(self._elig_vals))
        # LTP: target spiked → elig += pre_trace[source] >> LEARN_SHIFT
        ltp_active = spike_f[row_idx] > 0
        ltp_delta = trace_shifted[col]
        self._elig_vals = self._elig_vals + torch.where(ltp_active, ltp_delta,
                                                        torch.zeros_like(self._elig_vals))
        # Keep eligibility bounded to the hardware range.
        self._elig_vals = torch.clamp(self._elig_vals, min=-ELIG_MAX, max=ELIG_MAX)
def _reward_apply_gpu(self):
"""Apply reward to weights via eligibility: W += (elig * reward) >> REWARD_SHIFT."""
if self._reward_value == 0 or self._elig_vals is None:
return
delta = torch.div(self._elig_vals * self._reward_value, 1 << REWARD_SHIFT,
rounding_mode='trunc')
val = self._W_soma.values() + delta
val = torch.clamp(val, min=WEIGHT_MIN_STDP, max=WEIGHT_MAX_STDP)
self._W_soma = torch.sparse_csr_tensor(
self._soma_crow, self._soma_col, val, (self._n, self._n))
self._reward_value = 0
    def _elig_decay_gpu(self):
        """Exponential eligibility decay: elig -= sign(elig) * max(1, |elig| >> ELIG_DECAY_SHIFT)."""
        if self._elig_vals is None:
            return
        abs_vals = self._elig_vals.abs()
        nonzero = abs_vals > 0
        # Minimum decay step of 1 guarantees traces reach zero in finite time.
        decay = torch.max(torch.ones_like(self._elig_vals),
                          torch.div(abs_vals, 1 << ELIG_DECAY_SHIFT, rounding_mode='trunc'))
        sign = self._elig_vals.sign()
        new_vals = self._elig_vals - sign * decay
        # If the decay step overshot and flipped the sign, snap to zero.
        crossed_zero = (self._elig_vals * new_vals) < 0
        new_vals = torch.where(crossed_zero, torch.zeros_like(new_vals), new_vals)
        # Entries that were already zero keep their original value (sign is 0
        # there, so this guard is defensive rather than a behavioral change).
        new_vals = torch.where(nonzero, new_vals, self._elig_vals)
        self._elig_vals = new_vals
def _microcode_learn_gpu(self, spike_mask, three_factor=False):
"""P19 microcode learning: CPU fallback for custom rules.
Transfers spiking neuron data to CPU, runs interpreter, transfers back.
"""
if not spike_mask.any() or self._W_soma._nnz() == 0:
return
program = self._learning_rule.get_program()
spiking_ids = spike_mask.nonzero(as_tuple=True)[0].cpu().numpy()
trace_cpu = self._trace.cpu().numpy()
trace2_cpu = self._trace2.cpu().numpy()
# Pull weight values to CPU
crow_cpu = self._soma_crow.cpu().numpy()
col_cpu = self._soma_col.cpu().numpy()
val_cpu = self._W_soma.values().cpu().numpy().copy()
# Pull eligibility if 3-factor
elig_cpu = self._elig_vals.cpu().numpy().copy() if self._elig_vals is not None else None
for spike_gid in spiking_ids:
row_start = crow_cpu[spike_gid]
row_end = crow_cpu[spike_gid + 1]
for idx in range(row_start, row_end):
pass
# Full adjacency iteration for microcode learning
adj = self._adjacency
weights_dict = {}
# Build mutable weight dict from adjacency
for src, targets in adj.items():
weights_dict[src] = list(targets)
for spike_gid in spiking_ids:
spike_gid = int(spike_gid)
# LTD: pre spiked
if spike_gid in weights_dict:
updated = []
for entry in weights_dict[spike_gid]:
tgt, w, c = entry[0], entry[1], entry[2]
rest = entry[3:]
if tgt < self._n:
post_t1 = int(trace_cpu[tgt])
post_t2 = int(trace2_cpu[tgt])
elig_key = self._get_elig_index(spike_gid, tgt)
elig = int(elig_cpu[elig_key]) if elig_cpu is not None and elig_key is not None else 0
regs = [post_t1, post_t2, w, elig, 0, 0, 0, self._reward_value]
result = execute_program(program, LTD_START, LTD_END + 1, regs)
if three_factor:
if result["elig_written"] and elig_key is not None:
elig_cpu[elig_key] = max(-ELIG_MAX, min(ELIG_MAX, result["elig"]))
else:
if result["weight_written"]:
w = max(WEIGHT_MIN_STDP, min(WEIGHT_MAX_STDP, result["weight"]))
updated.append((tgt, w, c, *rest))
weights_dict[spike_gid] = updated
# LTP: post spiked
for src, targets in weights_dict.items():
if src == spike_gid:
continue
updated = []
for entry in targets:
tgt, w, c = entry[0], entry[1], entry[2]
rest = entry[3:]
if tgt == spike_gid:
pre_t1 = int(trace_cpu[src])
pre_t2 = int(trace2_cpu[src])
elig_key = self._get_elig_index(src, tgt)
elig = int(elig_cpu[elig_key]) if elig_cpu is not None and elig_key is not None else 0
regs = [pre_t1, pre_t2, w, elig, 0, 0, 0, self._reward_value]
result = execute_program(program, LTP_START, LTP_END + 1, regs)
if three_factor:
if result["elig_written"] and elig_key is not None:
elig_cpu[elig_key] = max(-ELIG_MAX, min(ELIG_MAX, result["elig"]))
else:
if result["weight_written"]:
w = max(WEIGHT_MIN_STDP, min(WEIGHT_MAX_STDP, result["weight"]))
updated.append((tgt, w, c, *rest))
weights_dict[src] = updated
# Sync back to GPU
self._adjacency = weights_dict
self._rebuild_weight_matrices_from_adjacency()
if elig_cpu is not None and self._elig_vals is not None:
self._elig_vals = torch.from_numpy(elig_cpu).to(self.device)
def _get_elig_index(self, src_gid, tgt_gid):
"""Find the CSR value index for synapse (src_gid, tgt_gid) in W_soma.
W_soma is (target, source) CSR, so row=tgt_gid, and we search
for col=src_gid within that row.
"""
if self._soma_crow is None:
return None
crow_cpu = self._soma_crow.cpu()
col_cpu = self._soma_col.cpu()
row_start = int(crow_cpu[tgt_gid])
row_end = int(crow_cpu[tgt_gid + 1])
for idx in range(row_start, row_end):
if int(col_cpu[idx]) == src_gid:
return idx
return None
    def _rebuild_weight_matrices_from_adjacency(self):
        """Rebuild GPU weight matrices from the CPU adjacency dict.

        Used after a CPU-side microcode learning pass mutates weights; the
        matrices (and cached CSR skeleton) are reconstructed from scratch.
        """
        self._build_weight_matrices(self._n)
def _sync_weights_to_adjacency(self):
"""Sync GPU weight matrix values back to CPU adjacency dict.
Only updates weights for compartment-0 immediate connections (the learnable ones).
"""
if self._W_soma is None or self._W_soma._nnz() == 0:
return
val_cpu = self._W_soma.values().cpu().numpy()
crow_cpu = self._soma_crow.cpu().numpy()
col_cpu = self._soma_col.cpu().numpy()
# Build a lookup: (tgt, src) -> new_weight
weight_updates = {}
for tgt in range(self._n):
start = int(crow_cpu[tgt])
end = int(crow_cpu[tgt + 1])
for idx in range(start, end):
src = int(col_cpu[idx])
weight_updates[(src, tgt)] = int(round(val_cpu[idx]))
# Update adjacency
for src, targets in self._adjacency.items():
updated = []
for entry in targets:
tgt, w, c = entry[0], entry[1], entry[2]
rest = entry[3:]
delay = rest[0] if rest else 0
if delay == 0 and c == 0:
key = (src, tgt)
if key in weight_updates:
w = weight_updates[key]
updated.append((tgt, w, c, *rest))
self._adjacency[src] = updated
def set_learning(self, learn=False, graded=False, dendritic=False,
async_mode=False, three_factor=False, noise=False):
"""Configure feature flags."""
self._learn_enable = learn
self._graded_enable = graded
self._dendritic_enable = dendritic
self._three_factor_enable = three_factor
self._noise_enable = noise
if async_mode:
raise NeurocoreError("Async mode not supported on GPU simulator.")
if three_factor and not learn:
self._learn_enable = True
def set_stdp_mask(self, learnable_source_gids):
"""Mark which connections are STDP-learnable by source neuron ID.
Only connections FROM neurons in learnable_source_gids will be updated
by STDP. All other connections remain fixed. This is essential for
networks where only some connections should learn (e.g., input→excitatory
in Diehl & Cook architecture).
Args:
learnable_source_gids: set or list of global neuron IDs whose
outgoing connections should be STDP-learnable.
"""
if self._W_soma is None or self._W_soma._nnz() == 0:
return
src_set = set(learnable_source_gids)
col = self._soma_col.cpu().numpy()
mask = torch.tensor([int(c) in src_set for c in col],
dtype=torch.bool, device=self.device)
self._stdp_mask = mask
def reset_state(self):
"""Reset all neuron state to initial values. Call between training images."""
self._potential.zero_()
self._refrac.zero_()
self._trace.zero_()
self._trace2.zero_()
self._ext_current.zero_()
self._prev_spike_vec.zero_()
if self._has_delays and self._delay_buf_soma is not None:
self._delay_buf_soma.zero_()
self._delay_buf_dend.zero_()
@torch.no_grad()
def randomize_learnable_weights(self, low=10.0, high=400.0, seed=42):
"""Randomize STDP-masked connection weights on GPU.
Useful for breaking symmetry before competitive learning.
Only modifies entries where self._stdp_mask is True.
"""
if self._stdp_mask is None or self._W_soma._nnz() == 0:
return
nnz = int(self._W_soma._nnz())
rng = np.random.RandomState(seed)
rand_vals = torch.from_numpy(
rng.uniform(low, high, size=nnz).astype(np.float32)
).to(self.device)
val = self._W_soma.values().clone()
val_new = torch.where(self._stdp_mask, rand_vals, val)
self._W_soma = torch.sparse_csr_tensor(
self._soma_crow, self._soma_col, val_new, (self._n, self._n))
@torch.no_grad()
def competitive_update(self, winner_gids, pixel_intensity, pixel_gids,
eta_ltp=0.05, eta_ltd=0.01, w_max=2000.0):
"""GPU-native competitive weight update on W_soma CSR values.
Uses scale-invariant EMA: the target is scaled to match each winner
neuron's current weight magnitude, so eta truly represents the
fractional movement toward the input pattern.
Winner: w += eta_ltp * (x_pre * scale_i - w)
where scale_i = sum(w_i) / sum(x_pre_i) for neuron i.
Loser: w -= eta_ltd * w * x_pre
Anti-Hebbian for active pixels.
Args:
winner_gids: (K,) int64 tensor of winner GIDs on GPU
pixel_intensity: (n_input,) float32 tensor of pixel values [0,1] on GPU
pixel_gids: (n_input,) int64 tensor of input neuron GIDs on GPU
eta_ltp: learning rate for winners (default: 0.05)
eta_ltd: learning rate for losers (default: 0.01)
w_max: clamp ceiling for final weights
"""
if self._stdp_mask is None or self._W_soma._nnz() == 0:
return
dev = self.device
val = self._W_soma.values()
col = self._soma_col
row_idx = self._soma_row_idx.long()
learnable = self._stdp_mask
# Pixel intensity lookup: only input neuron GIDs have nonzero values
pixel_lookup = torch.zeros(self._n, dtype=torch.float32, device=dev)
pixel_lookup[pixel_gids] = pixel_intensity
x_pre = pixel_lookup[col] # (nnz,) pixel intensity per source
# Winner lookup
winner_full = torch.zeros(self._n, dtype=torch.bool, device=dev)
winner_full[winner_gids] = True
is_winner = winner_full[row_idx] # (nnz,)
winner_mask = learnable & is_winner
# Compute per-neuron adaptive scale so target has same magnitude as
# current weights (scale = w_sum / x_sum per winner neuron)
w_per_tgt = torch.zeros(self._n, dtype=torch.float32, device=dev)
w_per_tgt.scatter_add_(0, row_idx,
torch.where(winner_mask, val.clamp(min=0), torch.zeros_like(val)))
x_per_tgt = torch.zeros(self._n, dtype=torch.float32, device=dev)
x_per_tgt.scatter_add_(0, row_idx,
torch.where(winner_mask, x_pre, torch.zeros_like(x_pre)))
scale = torch.where(x_per_tgt > 1e-6, w_per_tgt / x_per_tgt,
torch.ones(self._n, dtype=torch.float32, device=dev))
entry_scale = scale[row_idx] # (nnz,) per-entry scale
# Winner: scale-invariant EMA toward input pattern
target = x_pre * entry_scale
dw_winner = eta_ltp * (target - val)
# Loser: anti-Hebbian for active pixels
active = x_pre > 0.01
loser_mask = learnable & (~is_winner) & active
dw_loser = eta_ltd * val * x_pre
val_new = val.clone()
val_new = torch.where(winner_mask, val + dw_winner, val_new)
val_new = torch.where(loser_mask, val - dw_loser, val_new)
# Clamp learnable only, preserve fixed weights
val_clamped = torch.clamp(val_new, min=0.0, max=w_max)
val_final = torch.where(learnable, val_clamped, val)
self._W_soma = torch.sparse_csr_tensor(
self._soma_crow, self._soma_col, val_final, (self._n, self._n))
@torch.no_grad()
def normalize_learnable_weights(self, target_sum, target_gids=None):
"""GPU-native per-target weight normalization for learnable connections.
Scales learnable incoming weights for each target neuron so their sum
equals target_sum. Non-learnable weights are preserved.
Args:
target_sum: desired sum of learnable weights per target neuron
target_gids: (M,) int64 tensor of target GIDs on GPU, or None for all
"""
if self._stdp_mask is None or self._W_soma._nnz() == 0:
return
dev = self.device
val = self._W_soma.values().clone()
row_idx = self._soma_row_idx.long()
learnable = self._stdp_mask
# Entry mask: learnable connections to specified targets
if target_gids is not None:
tgt_mask = torch.zeros(self._n, dtype=torch.bool, device=dev)
tgt_mask[target_gids] = True
entry_mask = tgt_mask[row_idx] & learnable
else:
entry_mask = learnable
# Sum positive weights per target (only masked entries)
masked_vals = torch.where(entry_mask, val.clamp(min=0), torch.zeros_like(val))
per_tgt_sum = torch.zeros(self._n, dtype=torch.float32, device=dev)
per_tgt_sum.scatter_add_(0, row_idx, masked_vals)
# Per-target scale factor
scale = torch.where(per_tgt_sum > 0,
float(target_sum) / per_tgt_sum,
torch.ones(self._n, dtype=torch.float32, device=dev))
entry_scale = scale[row_idx]
# Apply scale only to masked entries
val_scaled = torch.where(entry_mask, val * entry_scale, val)
val_final = torch.where(learnable,
val_scaled.clamp(min=0, max=float(WEIGHT_MAX_STDP)),
val)
self._W_soma = torch.sparse_csr_tensor(
self._soma_crow, self._soma_col, val_final, (self._n, self._n))
def status(self):
return {"state": 0, "timestep_count": self._timestep_count}
def close(self):
"""Release GPU memory."""
self._W_soma = None
self._W_dend = [None] * 3
self._potential = None
self._delay_buf_soma = None
self._delay_buf_dend = None
if torch.cuda.is_available():
torch.cuda.empty_cache()
def _resolve_targets(self, target):
"""Convert Population/PopulationSlice to [(core, neuron)] pairs."""
if isinstance(target, list):
return target
placement = self._compiled.placement
if isinstance(target, PopulationSlice):
return [
placement.neuron_map[(target.population.id, i)]
for i in target.indices
]
if isinstance(target, Population):
return [
placement.neuron_map[(target.id, i)]
for i in range(target.size)
]
raise TypeError(f"Cannot resolve target of type {type(target)}")
def get_weights(self):
"""Export current weights as adjacency dict (CPU)."""
if self._learn_enable:
self._sync_weights_to_adjacency()
return dict(self._adjacency) if self._adjacency else {}