Spaces:
Sleeping
Sleeping
| # plasticity.py | |
| import torch | |
| import torch.nn as nn | |
| from typing import List, Optional | |
class UnifiedPlasticity(nn.Module):
    """Reward-modulated Hebbian plasticity manager for a list of modules.

    Keeps a per-layer integrate-and-fire spike trace, a global dopamine
    (eligibility) trace, and every 100 steps consolidates fast weights
    (``weight_fp32``) into slow weights (``long_term_weight``) on modules
    that expose those attributes.
    """

    def __init__(self, modules: List[nn.Module], lr: float = 0.01,
                 consolidation_rate: float = 0.01, forget_rate: float = 0.1):
        super().__init__()
        # NOTE(review): storing a plain list under ``self.modules`` shadows
        # nn.Module.modules() on this instance and skips submodule
        # registration.  Kept unchanged because callers may read this
        # attribute; consider an nn.ModuleList under a different name.
        self.modules = modules
        self.lr = lr
        self.consolidation_rate = consolidation_rate
        self.forget_rate = forget_rate
        self.spike_traces = {}      # trace_id -> accumulated activity trace
        self.dopamine_trace = 0.0   # leaky-integrated reward signal
        self.dopamine_decay = 0.9
        self.bcm_thresholds = {}    # reserved; never written in this class
        self.step_count = 0

    def step(self, layer_idx: int, pre_activation: torch.Tensor,
             post_activation: torch.Tensor, reward: float = 0.0,
             importance: float = 1.0):
        """Apply one reward-modulated Hebbian update to ``modules[layer_idx]``.

        ``pre_activation``/``post_activation`` are batched activations of
        shape (batch, features); only their batch-mean feeds the spike
        traces.  No-op when ``layer_idx`` is out of range.
        """
        if layer_idx >= len(self.modules):
            return
        module = self.modules[layer_idx]
        # Leaky integration of reward into the dopamine trace.
        self.dopamine_trace = reward + self.dopamine_decay * self.dopamine_trace
        modulation = 1.0 + self.dopamine_trace * importance
        pre_spikes = self._compute_spikes(pre_activation, f"pre_{layer_idx}")
        post_spikes = self._compute_spikes(post_activation, f"post_{layer_idx}")
        if hasattr(module, 'weight_fp32'):
            # _compute_spikes returns 1-D feature vectors, so the Hebbian
            # term is their outer product.  BUGFIX: the original subscripts
            # 'bi,bj->ij' required 2-D inputs and raised on these 1-D spikes.
            # NOTE(review): delta is (pre_features, post_features) -- confirm
            # this matches the layout of ``weight_fp32`` in its modules.
            delta_hebb = self.lr * modulation * torch.einsum(
                'i,j->ij', pre_spikes, post_spikes)
            if delta_hebb.shape == module.weight_fp32.shape[-2:]:
                # 2-D weight: direct add; a 3-D weight broadcasts over dim 0.
                module.weight_fp32.data += delta_hebb
            elif hasattr(module, 'n_head'):
                # BUGFIX: this branch was an elif of the hasattr check above,
                # where ``delta_hebb`` was unbound (NameError) and
                # ``weight_fp32`` was known absent; it belongs inside the
                # weight_fp32 branch, replicating the update per head.
                delta_hebb = delta_hebb.unsqueeze(0).expand(module.n_head, -1, -1)
                if delta_hebb.shape == module.weight_fp32.shape:
                    module.weight_fp32.data += delta_hebb
        self.step_count += 1
        # Periodic consolidation of fast weights into long-term storage.
        if self.step_count % 100 == 0:
            self._consolidate_all()

    def _compute_spikes(self, activation: torch.Tensor, trace_id: str) -> torch.Tensor:
        """Integrate-and-fire over the batch-mean activation.

        Accumulates ``activation.mean(dim=0)`` into a persistent per-id
        trace; features whose trace reaches 0.5 emit a spike (1.0) and the
        trace is reduced by the 0.5 threshold (soft reset).  Returns a 1-D
        float tensor of 0/1 spikes, one entry per feature.
        """
        if trace_id not in self.spike_traces:
            self.spike_traces[trace_id] = torch.zeros(
                activation.shape[-1], device=activation.device)
        trace = self.spike_traces[trace_id]
        trace += activation.mean(dim=0)
        spikes = (trace >= 0.5).float()
        trace -= spikes * 0.5
        self.spike_traces[trace_id] = trace
        return spikes

    def _update_ternary(self):
        """Ask each module that supports it to refresh its ternary weights."""
        for module in self.modules:
            if hasattr(module, 'update_ternary_weights'):
                module.update_ternary_weights()

    def _consolidate_all(self):
        """Transfer a fraction of fast weights into long-term storage.

        For each module exposing ``long_term_weight``: accumulate the fast
        weights into the slow store, decay the fast weights, then feed a
        fraction of the slow store back into the fast weights (stabilizing
        consolidated memory).  Finally refresh ternary weights everywhere.
        """
        for module in self.modules:
            if hasattr(module, 'long_term_weight'):
                module.long_term_weight += self.consolidation_rate * module.weight_fp32.data
                module.weight_fp32.data *= (1 - self.forget_rate)
                module.weight_fp32.data += self.consolidation_rate * module.long_term_weight
        self._update_ternary()
class Plasticity(nn.Module):
    """Standalone plasticity state for a fixed pool of recurrent neurons.

    NOTE(review): only the class header and ``__init__`` are visible in
    this chunk; the update methods that consume this state presumably
    follow later in the file -- confirm before relying on field semantics.
    """

    def __init__(self, n_neurons: int = 128):
        super().__init__()
        self.n_neurons = n_neurons
        # Fast and long-term all-to-all weight matrices (n_neurons x n_neurons).
        # Plain tensors, not nn.Parameter, so they are not registered for
        # gradient-based training.
        self.w = torch.zeros(n_neurons, n_neurons)
        self.long_term_w = torch.zeros(n_neurons, n_neurons)
        self.lr = 0.01                  # Hebbian learning rate
        self.consolidation_rate = 0.01  # fast -> long-term transfer rate
        self.forget_rate = 0.1          # decay applied to fast weights
        # Per-neuron accumulators; presumably pre-/post-synaptic activity
        # integrated against ``threshold`` -- TODO confirm in update code.
        self.acc_pre = torch.zeros(n_neurons)
        self.acc_post = torch.zeros(n_neurons)
        self.threshold = 0.5
        # Presumably the BCM sliding-threshold state and its learning rate;
        # verify against the (unseen) BCM update method.
        self.bcm_theta = torch.zeros(n_neurons)
        self.lr_bcm = 0.001
        self.target_activity = 0.5
        self.step_count = 0
        # Reward-modulation trace with exponential decay, mirroring the
        # fields of the same name on UnifiedPlasticity.
        self.dopamine_trace = 0.0
        self.dopamine_decay = 0.9