# aur/plasticity.py
# Provenance: Hugging Face upload "Create plasticity.py" by Andrewstivan
# (commit 4e0e690, verified) — page metadata converted to comments so the
# file parses as Python.
# plasticity.py
import torch
import torch.nn as nn
from typing import List, Optional
class UnifiedPlasticity(nn.Module):
    """Reward-modulated Hebbian plasticity applied to a list of modules.

    Keeps a per-layer integrate-and-fire spike trace, a global dopamine
    eligibility trace, and every 100 steps consolidates fast weights
    (``weight_fp32``) into long-term storage (``long_term_weight``).

    Args:
        modules: layers to adapt; updates target those exposing a
            ``weight_fp32`` attribute (presumably a float copy of a
            quantized weight — TODO confirm against the module classes).
        lr: Hebbian learning rate.
        consolidation_rate: fraction blended into long-term weights.
        forget_rate: fraction of fast weights decayed at consolidation.
    """

    def __init__(self, modules: List[nn.Module], lr: float = 0.01,
                 consolidation_rate: float = 0.01, forget_rate: float = 0.1):
        super().__init__()
        # NOTE(review): this shadows nn.Module.modules() and does NOT
        # register the layers as submodules; kept as-is so existing
        # callers reading `self.modules` keep working.
        self.modules = modules
        self.lr = lr
        self.consolidation_rate = consolidation_rate
        self.forget_rate = forget_rate
        self.spike_traces = {}        # trace_id -> persistent 1-D accumulator
        self.dopamine_trace = 0.0     # leaky-integrated reward signal
        self.dopamine_decay = 0.9
        self.bcm_thresholds = {}      # reserved; unused within this class
        self.step_count = 0

    def step(self, layer_idx: int, pre_activation: torch.Tensor,
             post_activation: torch.Tensor, reward: float = 0.0,
             importance: float = 1.0):
        """Apply one reward-modulated Hebbian update to modules[layer_idx].

        No-op when ``layer_idx`` is out of range. Every 100th call also
        triggers consolidation across all modules.
        """
        if layer_idx >= len(self.modules):
            return
        module = self.modules[layer_idx]
        # Leaky integration of reward into the dopamine eligibility trace.
        self.dopamine_trace = reward + self.dopamine_decay * self.dopamine_trace
        modulation = 1.0 + self.dopamine_trace * importance
        pre_spikes = self._compute_spikes(pre_activation, f"pre_{layer_idx}")
        post_spikes = self._compute_spikes(post_activation, f"post_{layer_idx}")
        if hasattr(module, 'weight_fp32'):
            # _compute_spikes returns 1-D tensors of shape (features,), so the
            # Hebbian delta is an outer product. (BUGFIX: the original
            # einsum 'bi,bj->ij' assumed batched 2-D spikes and raises on the
            # 1-D tensors actually produced.)
            delta_hebb = self.lr * modulation * torch.outer(pre_spikes, post_spikes)
            if delta_hebb.shape == module.weight_fp32.shape[-2:]:
                module.weight_fp32.data += delta_hebb
            elif hasattr(module, 'n_head'):
                # Multi-head case: replicate the update per head. (BUGFIX:
                # nested under the weight_fp32 branch so delta_hebb is always
                # bound — the flat original could hit a NameError here.)
                delta_hebb = delta_hebb.unsqueeze(0).expand(module.n_head, -1, -1)
                if delta_hebb.shape == module.weight_fp32.shape:
                    module.weight_fp32.data += delta_hebb
        self.step_count += 1
        if self.step_count % 100 == 0:
            self._consolidate_all()

    def _compute_spikes(self, activation: torch.Tensor, trace_id: str) -> torch.Tensor:
        """Integrate-and-fire over the batch-mean activation.

        Accumulates ``activation.mean(dim=0)`` into a persistent per-id
        trace, emits binary spikes where the trace crosses 0.5, then
        soft-resets firing positions by subtracting the threshold.
        Returns a 1-D float tensor of shape (features,).
        """
        if trace_id not in self.spike_traces:
            self.spike_traces[trace_id] = torch.zeros(
                activation.shape[-1], device=activation.device)
        trace = self.spike_traces[trace_id]
        trace += activation.mean(dim=0)  # in-place; persists across calls
        spikes = (trace >= 0.5).float()
        trace -= spikes * 0.5            # soft reset: keep residual charge
        self.spike_traces[trace_id] = trace
        return spikes

    def _update_ternary(self):
        # Delegate re-quantization to any module that supports it.
        for module in self.modules:
            if hasattr(module, 'update_ternary_weights'):
                module.update_ternary_weights()

    def _consolidate_all(self):
        """Blend fast weights into long-term storage, then partially forget.

        Assumes any module exposing ``long_term_weight`` also exposes
        ``weight_fp32`` — TODO confirm against the module classes.
        """
        for module in self.modules:
            if hasattr(module, 'long_term_weight'):
                module.long_term_weight += self.consolidation_rate * module.weight_fp32.data
                module.weight_fp32.data *= (1 - self.forget_rate)
                module.weight_fp32.data += self.consolidation_rate * module.long_term_weight
        self._update_ternary()
class Plasticity(nn.Module):
    """Standalone Hebbian/BCM plasticity state over a fixed neuron pool.

    Holds fast and long-term weight matrices plus the traces and rate
    constants used by the update rules. (Only the constructor is visible
    in this chunk; further methods presumably follow.)
    """

    def __init__(self, n_neurons: int = 128):
        super().__init__()
        self.n_neurons = n_neurons

        # Fast (plastic) and consolidated weight matrices, neuron x neuron.
        self.w = torch.zeros(n_neurons, n_neurons)
        self.long_term_w = torch.zeros(n_neurons, n_neurons)

        # Rate constants governing learning, consolidation and decay.
        self.lr = 0.01
        self.consolidation_rate = 0.01
        self.forget_rate = 0.1

        # Pre-/post-synaptic activity accumulators and spiking threshold.
        self.acc_pre = torch.zeros(n_neurons)
        self.acc_post = torch.zeros(n_neurons)
        self.threshold = 0.5

        # BCM sliding-threshold state and its learning rate.
        self.bcm_theta = torch.zeros(n_neurons)
        self.lr_bcm = 0.001
        self.target_activity = 0.5

        # Bookkeeping: step counter and dopamine eligibility trace.
        self.step_count = 0
        self.dopamine_trace = 0.0
        self.dopamine_decay = 0.9