# plasticity.py
import torch
import torch.nn as nn
from typing import List, Optional


class UnifiedPlasticity(nn.Module):
    """Online Hebbian plasticity with dopamine modulation and slow consolidation.

    Applies reward-modulated Hebbian weight updates to a list of externally
    owned modules. A module participates only if it exposes a ``weight_fp32``
    tensor (full-precision shadow weights); modules with an
    ``update_ternary_weights`` hook are re-quantized after each consolidation.

    Args:
        modules: modules to adapt, indexed by ``layer_idx`` in :meth:`step`.
        lr: Hebbian learning rate.
        consolidation_rate: fraction of fast weights folded into the
            long-term store (and back) on each consolidation.
        forget_rate: fraction of the fast weights decayed on consolidation.
    """

    def __init__(self, modules: List[nn.Module], lr: float = 0.01,
                 consolidation_rate: float = 0.01, forget_rate: float = 0.1):
        super().__init__()
        # NOTE(review): a plain-list attribute named `modules` shadows
        # nn.Module.modules(); kept as-is for backward compatibility with
        # callers that read `self.modules` directly.
        self.modules = modules
        self.lr = lr
        self.consolidation_rate = consolidation_rate
        self.forget_rate = forget_rate
        self.spike_traces = {}          # trace_id -> accumulated activation trace (1-D)
        self.dopamine_trace = 0.0       # exponentially decayed reward signal
        self.dopamine_decay = 0.9
        self.bcm_thresholds = {}        # reserved; not updated in the visible code
        self.step_count = 0

    def step(self, layer_idx: int, pre_activation: torch.Tensor,
             post_activation: torch.Tensor, reward: float = 0.0,
             importance: float = 1.0):
        """Apply one reward-modulated Hebbian update to layer ``layer_idx``.

        Out-of-range ``layer_idx`` is a silent no-op. Every call (even a
        no-update one on a module without ``weight_fp32``) still advances the
        dopamine trace and the step counter; every 100 steps triggers
        consolidation.
        """
        if layer_idx >= len(self.modules):
            return
        module = self.modules[layer_idx]
        # Dopamine: leaky integrator of reward; scales the Hebbian update.
        self.dopamine_trace = reward + self.dopamine_decay * self.dopamine_trace
        modulation = 1.0 + self.dopamine_trace * importance
        pre_spikes = self._compute_spikes(pre_activation, f"pre_{layer_idx}")
        post_spikes = self._compute_spikes(post_activation, f"post_{layer_idx}")
        if hasattr(module, 'weight_fp32'):
            # BUG FIX: spikes are 1-D (see _compute_spikes), so the original
            # einsum('bi,bj->ij', ...) raised at runtime. The intended update
            # is the outer product delta[i, j] = pre[i] * post[j].
            delta_hebb = self.lr * modulation * torch.outer(pre_spikes, post_spikes)
            if delta_hebb.shape == module.weight_fp32.shape[-2:]:
                module.weight_fp32.data += delta_hebb
            elif hasattr(module, 'n_head'):
                # Multi-head weights: broadcast the same 2-D update per head.
                delta_hebb = delta_hebb.unsqueeze(0).expand(module.n_head, -1, -1)
                if delta_hebb.shape == module.weight_fp32.shape:
                    module.weight_fp32.data += delta_hebb
        self.step_count += 1
        if self.step_count % 100 == 0:
            self._consolidate_all()

    def _compute_spikes(self, activation: torch.Tensor, trace_id: str) -> torch.Tensor:
        """Integrate-and-fire over the batch-mean activation.

        Accumulates ``activation.mean(dim=0)`` into a per-id trace; units whose
        trace crosses 0.5 "spike" (1.0) and have 0.5 subtracted from their
        trace. Returns a 1-D float tensor of size ``activation.shape[-1]``.

        NOTE(review): assumes ``activation`` is (batch, features) — a 3-D
        input would make ``mean(dim=0)`` 2-D and mismatch the 1-D trace;
        TODO confirm against callers.
        """
        if trace_id not in self.spike_traces:
            self.spike_traces[trace_id] = torch.zeros(
                activation.shape[-1], device=activation.device)
        trace = self.spike_traces[trace_id]
        trace += activation.mean(dim=0)
        spikes = (trace >= 0.5).float()
        trace -= spikes * 0.5  # soft reset: subtract threshold, keep residual
        self.spike_traces[trace_id] = trace
        return spikes

    def _update_ternary(self):
        """Re-quantize every module that exposes ``update_ternary_weights``."""
        for module in self.modules:
            if hasattr(module, 'update_ternary_weights'):
                module.update_ternary_weights()

    def _consolidate_all(self):
        """Fold fast weights into long-term storage and partially recall them.

        For each module with a ``long_term_weight`` tensor:
        long-term accumulates a fraction of the fast weights, the fast
        weights decay by ``forget_rate``, then recover a fraction of the
        long-term store. Ternary weights are refreshed once afterwards.
        """
        for module in self.modules:
            if hasattr(module, 'long_term_weight'):
                module.long_term_weight += self.consolidation_rate * module.weight_fp32.data
                module.weight_fp32.data *= (1 - self.forget_rate)
                module.weight_fp32.data += self.consolidation_rate * module.long_term_weight
        self._update_ternary()


class Plasticity(nn.Module):
    """Standalone plastic synapse matrix with fast/long-term weight stores.

    Only ``__init__`` is visible in this chunk; it sets up the state used by
    the (unseen) update rules: a fast weight matrix, a long-term matrix,
    accumulator traces, and BCM/dopamine bookkeeping.
    """

    def __init__(self, n_neurons: int = 128):
        super().__init__()
        self.n_neurons = n_neurons
        # NOTE(review): plain tensors, not registered buffers — they will not
        # follow .to(device) or appear in state_dict; confirm that is intended.
        self.w = torch.zeros(n_neurons, n_neurons)            # fast weights
        self.long_term_w = torch.zeros(n_neurons, n_neurons)  # consolidated weights
        self.lr = 0.01
        self.consolidation_rate = 0.01
        self.forget_rate = 0.1
        self.acc_pre = torch.zeros(n_neurons)   # pre-synaptic accumulator
        self.acc_post = torch.zeros(n_neurons)  # post-synaptic accumulator
        self.threshold = 0.5                    # spike threshold
        self.bcm_theta = torch.zeros(n_neurons)  # BCM sliding thresholds
        self.lr_bcm = 0.001
        self.target_activity = 0.5
        self.step_count = 0
        self.dopamine_trace = 0.0
        self.dopamine_decay = 0.9