Spaces:
Running
Running
| """ | |
| Antti's Spiking Neuron - A Leaky Integrate-and-Fire (LIF) neuron | |
| Transforms input signals into spikes. Can be chained. | |
| Place this file in the 'nodes' folder | |
| """ | |
| import numpy as np | |
| from PyQt6 import QtGui | |
| import cv2 | |
| import sys | |
| import os | |
| # --- Host bindings: BaseNode (required) and PA_INSTANCE (optional, None if absent) are read from the running __main__ module --- | |
| import __main__ | |
| BaseNode = __main__.BaseNode | |
| PA_INSTANCE = getattr(__main__, "PA_INSTANCE", None) | |
| # ------------------------------------ | |
class SpikingNeuronNode(BaseNode):
    """Leaky integrate-and-fire (LIF) neuron node.

    Every call to :meth:`step` advances the membrane potential ``V_m`` by one
    explicit-Euler step of::

        tau_m * dV/dt = (V_rest - V) + R_m * I_in

    When ``V_m`` reaches ``V_threshold`` the node emits ``1.0`` on
    ``spike_out`` for that step, resets ``V_m`` to ``V_reset`` and ignores its
    input until the refractory timer has run out.

    Args:
        threshold: Membrane potential at which the neuron fires.
        tau_m: Membrane time constant in seconds.
        resistance: Membrane resistance; scales the input current.
        refractory_ms: Refractory period. NOTE: despite the ``_ms`` suffix the
            value is used as *seconds* (it is counted down by ``dt`` seconds
            per step). The name is kept for backward compatibility.
    """

    NODE_CATEGORY = "Transform"
    NODE_COLOR = QtGui.QColor(220, 120, 40)  # Neural orange

    def __init__(self, threshold=1.0, tau_m=0.1, resistance=5.0, refractory_ms=0.05):
        super().__init__()
        self.node_title = "Spiking Neuron (LIF)"
        self.inputs = {'signal_in': 'signal'}
        self.outputs = {'spike_out': 'signal'}

        # --- Neuron parameters (the last four are exposed by get_config_options) ---
        self.V_rest = 0.0
        self.V_threshold = float(threshold)
        self.V_reset = 0.0
        self.tau_m = float(tau_m)                      # Membrane time constant (sec)
        self.R_m = float(resistance)                   # Membrane resistance (scales input)
        self.refractory_period = float(refractory_ms)  # Refractory period (sec)

        # --- Neuron state ---
        self.V_m = self.V_rest        # Current membrane potential
        self.refractory_timer = 0.0   # Countdown (sec) while the neuron is refractory
        self.output_signal = 0.0      # 1.0 on the step a spike is emitted, else 0.0
        self.dt = 1.0 / 30.0          # Assume ~30 FPS step rate

    def step(self):
        """Advance the neuron by one time step of ``self.dt`` seconds."""
        # A spike lasts exactly one step, so clear the output first.
        self.output_signal = 0.0

        # While refractory: count down, hold the potential at reset, ignore input.
        if self.refractory_timer > 0:
            self.refractory_timer -= self.dt
            self.V_m = self.V_reset
            return

        # 'sum' blending lets several upstream nodes add their currents.
        I_in = self.get_blended_input('signal_in', 'sum') or 0.0

        # Distance from the steady-state potential (V_rest + R_m * I_in).
        drive = (self.V_rest - self.V_m) + self.R_m * I_in
        if self.tau_m > self.dt:
            # Regular explicit-Euler update.
            dV = (drive / self.tau_m) * self.dt
        else:
            # tau_m <= dt (including 0 or negative values entered in the config):
            # the Euler step would divide by zero or overshoot the steady state,
            # so move straight to the steady state instead.
            dV = drive
        self.V_m += dV

        if self.V_m >= self.V_threshold:
            self.output_signal = 1.0                        # Fire!
            self.V_m = self.V_reset                         # Reset potential
            self.refractory_timer = self.refractory_period  # Start refractory timer

    def get_output(self, port_name):
        """Return the current spike value for ``'spike_out'``; ``None`` for any other port."""
        if port_name == 'spike_out':
            return self.output_signal
        return None

    def get_display_image(self):
        """Render a 64x64 preview of the neuron state.

        Draws a red threshold line, a gray resting line and a vertical bar for
        the membrane potential. The bar is yellow on a spike step, blue while
        refractory and green otherwise.

        Returns:
            A ``QImage`` (BGR888) that owns its pixel data.
        """
        w, h = 64, 64
        bottom = h - 1      # last valid pixel row
        min_bar_px = 3      # keep the bar visible even at zero potential
        img = np.zeros((h, w, 3), dtype=np.uint8)

        # Leave some headroom above the threshold; fall back to a unit scale
        # when the threshold is zero/negative so the division below is safe.
        max_viz_v = self.V_threshold * 1.2
        if not max_viz_v > 0:
            max_viz_v = 1.0

        def row_for(voltage):
            # Map a voltage to a pixel row inside [0, bottom]; 0 V is the
            # bottom row (not row `h`, which would fall outside the image).
            frac = float(np.clip(voltage / max_viz_v, 0.0, 1.0))
            return bottom - int(frac * bottom)

        # Threshold line (red)
        thresh_y = row_for(self.V_threshold)
        cv2.line(img, (0, thresh_y), (w - 1, thresh_y), (0, 0, 255), 1)

        # Resting line (gray)
        rest_y = row_for(self.V_rest)
        cv2.line(img, (0, rest_y), (w - 1, rest_y), (100, 100, 100), 1)

        # Membrane potential bar. step() has already reset V_m on the spike
        # step, so show the bar at threshold height there; otherwise the
        # yellow flash would have zero height.
        shown_v = self.V_m
        if self.output_signal == 1.0:
            bar_color = (0, 255, 255)   # Yellow
            shown_v = max(self.V_m, self.V_threshold)
        elif self.refractory_timer > 0:
            bar_color = (255, 100, 0)   # Blue
        else:
            bar_color = (0, 255, 0)     # Green

        bar_top = min(row_for(shown_v), h - min_bar_px)
        cv2.rectangle(img, (w // 2 - 5, bar_top), (w // 2 + 5, bottom), bar_color, -1)

        img = np.ascontiguousarray(img)
        # QImage does not copy the buffer it is constructed from, and `img` is
        # released when this method returns; .copy() gives the QImage its own data.
        return QtGui.QImage(
            img.data, w, h, 3 * w, QtGui.QImage.Format.Format_BGR888
        ).copy()

    def get_config_options(self):
        """Return the editable parameters as ``(label, attribute, value, options)`` tuples."""
        return [
            ("Threshold", "V_threshold", self.V_threshold, None),
            ("Leak (tau_m)", "tau_m", self.tau_m, None),
            ("Input (R_m)", "R_m", self.R_m, None),
            ("Refractory (sec)", "refractory_period", self.refractory_period, None),
        ]