Spaces:
Running
Running
File size: 4,371 Bytes
3bb804c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
"""
Antti's Spiking Neuron - a Leaky Integrate-and-Fire (LIF) neuron node.
Converts an input signal into a train of spikes. Neurons can be chained
(the spike_out of one feeding the signal_in of the next).
Place this file in the 'nodes' folder.
"""
import numpy as np
from PyQt6 import QtGui
import cv2
import sys
import os
# --- Host integration ---
# Presumably the host application runs as ``__main__`` and loads this file
# from its nodes folder; the node base class is looked up there.
import __main__

# Required: AttributeError at import time if the host does not define it.
BaseNode = getattr(__main__, "BaseNode")
# Optional host-provided object; None when the host does not define it.
PA_INSTANCE = getattr(__main__, "PA_INSTANCE", None)
# ------------------------
class SpikingNeuronNode(BaseNode):
    """Leaky integrate-and-fire (LIF) neuron node.

    Each ``step()`` integrates the summed ``signal_in`` current into the
    membrane potential ``V_m``. On the step where ``V_m`` reaches
    ``V_threshold`` the node outputs ``1.0`` on ``spike_out`` (``0.0``
    otherwise), resets ``V_m`` to ``V_reset`` and holds it there for
    ``refractory_period`` seconds.
    """

    NODE_CATEGORY = "Transform"
    NODE_COLOR = QtGui.QColor(220, 120, 40)  # Neural orange

    def __init__(self, threshold=1.0, tau_m=0.1, resistance=5.0, refractory_ms=0.05):
        """Create a neuron at its resting potential.

        Args:
            threshold: Membrane potential at which the neuron fires.
            tau_m: Membrane time constant in seconds; larger values make the
                potential change more slowly.
            resistance: Membrane resistance; scales the input current.
            refractory_ms: Refractory period. Despite the ``_ms`` suffix the
                value is used in SECONDS (it is counted down by ``dt``, which
                is in seconds); the name is kept for backward compatibility.
        """
        super().__init__()
        self.node_title = "Spiking Neuron (LIF)"
        self.inputs = {'signal_in': 'signal'}
        self.outputs = {'spike_out': 'signal'}

        # --- Neuron parameters (editable via get_config_options) ---
        self.V_rest = 0.0
        self.V_threshold = float(threshold)
        self.V_reset = 0.0
        self.tau_m = float(tau_m)  # Membrane time constant (sec)
        self.R_m = float(resistance)  # Membrane resistance (scales input)
        self.refractory_period = float(refractory_ms)  # Refractory period (sec)

        # --- Neuron state ---
        self.V_m = self.V_rest  # Current membrane potential
        self.refractory_timer = 0.0  # Seconds of refractory time remaining
        self.output_signal = 0.0  # 1.0 on the step a spike fires, else 0.0
        self.dt = 1.0 / 30.0  # Fixed step size: assumes ~30 steps per second

    def step(self):
        """Advance the neuron by one time step of ``self.dt`` seconds."""
        # 1. A spike lasts exactly one step.
        self.output_signal = 0.0

        # 2. While refractory, ignore input and hold the potential at reset.
        if self.refractory_timer > 0:
            self.refractory_timer -= self.dt
            self.V_m = self.V_reset
            return

        # 3. Total input current. 'sum' blending adds every connected source,
        #    so several neurons can drive this one. No input reads as 0.
        I_in = self.get_blended_input('signal_in', 'sum')
        I_in = 0.0 if I_in is None else float(I_in)

        # 4. LIF dynamics: tau_m * dV/dt = (V_rest - V) + R_m * I_in
        drive = (self.V_rest - self.V_m) + self.R_m * I_in  # steady state minus V_m
        if self.tau_m >= self.dt:
            # Forward-Euler step: dV = drive / tau_m * dt
            self.V_m += (drive / self.tau_m) * self.dt
        else:
            # tau_m is user-configurable. When it is shorter than one step,
            # forward Euler overshoots the steady state (it diverges for
            # tau_m < dt/2 and divides by zero at tau_m == 0), so settle
            # directly at the steady state V_rest + R_m * I_in instead.
            self.V_m += drive

        # 5. Fire, reset, and start the refractory period.
        if self.V_m >= self.V_threshold:
            self.output_signal = 1.0
            self.V_m = self.V_reset
            self.refractory_timer = self.refractory_period

    def get_output(self, port_name):
        """Return the spike value for 'spike_out'; None for any other port."""
        if port_name == 'spike_out':
            return self.output_signal
        return None

    def get_display_image(self):
        """Render a 64x64 thumbnail of the membrane state as a QImage.

        Red line: threshold. Gray line: resting potential. Center bar: current
        V_m, yellow on a spike step, blue while refractory, green otherwise.
        The vertical scale spans 0 .. 1.2 * V_threshold (colors are BGR).
        """
        w, h = 64, 64
        img = np.zeros((h, w, 3), dtype=np.uint8)

        # Top of the scale sits 20% above threshold so the threshold is visible.
        max_viz_v = self.V_threshold * 1.2
        if max_viz_v <= 0:
            # A zero/negative threshold would divide by zero or invert the scale.
            max_viz_v = 1.0

        def to_row(v):
            # Map a voltage in [0, max_viz_v] to a pixel row in [h-1, 0].
            # Row h is outside the canvas, so 0 V must land on row h-1.
            frac = float(np.clip(v / max_viz_v, 0, 1))
            return (h - 1) - int(frac * (h - 1))

        # Threshold line (red)
        thresh_y = to_row(self.V_threshold)
        cv2.line(img, (0, thresh_y), (w, thresh_y), (0, 0, 255), 1)
        # Resting line (gray)
        rest_y = to_row(self.V_rest)
        cv2.line(img, (0, rest_y), (w, rest_y), (100, 100, 100), 1)

        # Membrane potential bar, colored by state.
        vm_y = to_row(self.V_m)
        if self.output_signal == 1.0:
            bar_color = (0, 255, 255)  # Yellow: spiked this step
        elif self.refractory_timer > 0:
            bar_color = (255, 100, 0)  # Blue: refractory
        else:
            bar_color = (0, 255, 0)  # Green: integrating
        cv2.rectangle(img, (w // 2 - 5, vm_y), (w // 2 + 5, h), bar_color, -1)

        img = np.ascontiguousarray(img)
        # Qt's raw-buffer QImage constructor neither copies nor owns the data;
        # copy so the image stays valid after the NumPy array is freed.
        return QtGui.QImage(
            img.data, w, h, 3 * w, QtGui.QImage.Format.Format_BGR888
        ).copy()

    def get_config_options(self):
        """Return the editable parameters.

        Each entry is (label, attribute name, current value, None). The
        meaning of the fourth field is defined by the host (presumably an
        optional choices list) — confirm against BaseNode.
        """
        return [
            ("Threshold", "V_threshold", self.V_threshold, None),
            ("Leak (tau_m)", "tau_m", self.tau_m, None),
            ("Input (R_m)", "R_m", self.R_m, None),
            ("Refractory (sec)", "refractory_period", self.refractory_period, None),
        ]