File size: 4,371 Bytes
3bb804c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""

Antti's Spiking Neuron - A Leaky Integrate-and-Fire (LIF) neuron

Transforms input signals into spikes. Can be chained.

Place this file in the 'nodes' folder

"""

import numpy as np
from PyQt6 import QtGui
import cv2

import sys
import os
# --- Fetch BaseNode (and the optional PA_INSTANCE) from the running __main__ module ---
import __main__
BaseNode = __main__.BaseNode
PA_INSTANCE = getattr(__main__, "PA_INSTANCE", None)
# ------------------------------------

class SpikingNeuronNode(BaseNode):
    """Leaky Integrate-and-Fire (LIF) neuron node.

    Each call to :meth:`step` integrates the summed ``signal_in`` input into
    a membrane potential. When the potential reaches the threshold the node
    emits ``1.0`` on ``spike_out`` for that step, resets the potential and
    stays silent for the refractory period.

    Args:
        threshold: Membrane potential at which the neuron fires.
        tau_m: Membrane time constant, in seconds.
        resistance: Membrane resistance; scales the input current.
        refractory_ms: Refractory period. NOTE: despite the name, the value
            is used in **seconds** (it is counted down by ``dt``, which is
            in seconds). The name is kept for backward compatibility.
    """

    NODE_CATEGORY = "Transform"
    NODE_COLOR = QtGui.QColor(220, 120, 40)  # Neural orange

    def __init__(self, threshold=1.0, tau_m=0.1, resistance=5.0, refractory_ms=0.05):
        super().__init__()
        self.node_title = "Spiking Neuron (LIF)"

        self.inputs = {'signal_in': 'signal'}
        self.outputs = {'spike_out': 'signal'}

        # --- Neuron Parameters ---
        # These are configurable (see get_config_options)
        self.V_rest = 0.0
        self.V_threshold = float(threshold)
        self.V_reset = 0.0
        self.tau_m = float(tau_m)                      # Membrane time constant (sec)
        self.R_m = float(resistance)                   # Membrane resistance (scales input)
        self.refractory_period = float(refractory_ms)  # Refractory period (sec)

        # --- Neuron State ---
        self.V_m = self.V_rest        # Current membrane potential
        self.refractory_timer = 0.0   # Countdown timer for refractory period
        self.output_signal = 0.0      # Output spike (1.0 on the firing step, else 0.0)
        self.dt = 1.0 / 30.0          # Assume ~30 FPS step rate

    def step(self):
        """Advance the neuron by one time step of length ``self.dt``."""
        # 1. A spike lasts exactly one step, so clear the previous output.
        self.output_signal = 0.0

        # 2. While refractory: count down, hold the potential at reset and
        #    ignore any input.
        if self.refractory_timer > 0:
            self.refractory_timer -= self.dt
            self.V_m = self.V_reset
            return

        # 3. Total input current. The 'sum' blend mode lets several upstream
        #    connections add their contributions; a falsy result (e.g. None)
        #    is treated as zero input.
        I_in = self.get_blended_input('signal_in', 'sum') or 0.0

        # 4. Leaky Integrate-and-Fire equation (forward Euler):
        #    tau_m * dV/dt = (V_rest - V_m) + R_m * I_in
        drive = (self.V_rest - self.V_m) + self.R_m * I_in
        if self.tau_m > 0:
            self.V_m += (drive / self.tau_m) * self.dt
        else:
            # tau_m is user-configurable; a zero value would divide by zero.
            # Treat a non-positive time constant as an instantaneous membrane
            # that jumps straight to its steady state V_rest + R_m * I_in.
            self.V_m += drive

        # 5. Fire, reset and enter the refractory period.
        if self.V_m >= self.V_threshold:
            self.output_signal = 1.0
            self.V_m = self.V_reset
            self.refractory_timer = self.refractory_period

    def get_output(self, port_name):
        """Return the current spike value for ``'spike_out'``, else ``None``."""
        if port_name == 'spike_out':
            return self.output_signal
        return None

    def get_display_image(self):
        """Render a 64x64 preview of the membrane potential as a ``QImage``.

        Draws a red threshold line, a gray resting-potential line and a
        vertical bar for the membrane potential. The bar is yellow on a
        spiking step, blue while refractory and green otherwise.
        """
        w, h = 64, 64
        img = np.zeros((h, w, 3), dtype=np.uint8)

        # Vertical scale: leave 20% headroom above the threshold so the
        # threshold line is visible. A threshold of 0 would make the scale
        # zero and the divisions below fail, so fall back to a unit scale.
        max_viz_v = self.V_threshold * 1.2
        if max_viz_v == 0:
            max_viz_v = 1.0

        def to_row(voltage):
            # Map a voltage to a pixel row (row 0 is the top of the image).
            return h - int(np.clip(voltage / max_viz_v, 0, 1) * h)

        # Colors below are in BGR order (the image is exported as BGR888).
        # Threshold line (red)
        thresh_y = to_row(self.V_threshold)
        cv2.line(img, (0, thresh_y), (w, thresh_y), (0, 0, 255), 1)

        # Resting line (gray)
        rest_y = to_row(self.V_rest)
        cv2.line(img, (0, rest_y), (w, rest_y), (100, 100, 100), 1)

        # Membrane potential bar
        vm_y = to_row(self.V_m)

        if self.output_signal == 1.0:
            bar_color = (0, 255, 255)  # Yellow
        elif self.refractory_timer > 0:
            bar_color = (255, 100, 0)  # Blue
        else:
            bar_color = (0, 255, 0)    # Green

        cv2.rectangle(img, (w // 2 - 5, vm_y), (w // 2 + 5, h), bar_color, -1)

        img = np.ascontiguousarray(img)
        view = QtGui.QImage(img.data, w, h, 3 * w, QtGui.QImage.Format.Format_BGR888)
        # QImage does not take ownership of the buffer it wraps; `img` is a
        # local that dies when this method returns. Return a deep copy so the
        # image owns its pixels.
        return view.copy()

    def get_config_options(self):
        """Return the configurable parameters.

        Each entry is ``(label, attribute name, current value, None)``.
        """
        return [
            ("Threshold", "V_threshold", self.V_threshold, None),
            ("Leak (tau_m)", "tau_m", self.tau_m, None),
            ("Input (R_m)", "R_m", self.R_m, None),
            ("Refractory (sec)", "refractory_period", self.refractory_period, None),
        ]