Spaces:
Sleeping
Sleeping
File size: 3,509 Bytes
4e0e690 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 | # plasticity.py
import torch
import torch.nn as nn
from typing import List, Optional
class UnifiedPlasticity(nn.Module):
    """Reward-modulated Hebbian weight updates for a list of layers.

    ``step`` turns one layer's pre- and post-activations into binary spike
    vectors, forms their outer product, scales it by ``lr`` and a modulation
    factor derived from a shared reward trace, and adds it in place to that
    layer's ``weight_fp32``.  Every ``consolidation_interval`` calls to
    ``step``, layers that also expose ``long_term_weight`` have their fast and
    long-term weights mixed, and every layer's ``update_ternary_weights`` hook
    (if present) is run.

    The spike traces and weight updates maintained by this class are computed
    outside autograd.

    Args:
        modules: Layers to update, indexed by ``layer_idx`` in ``step``.  They
            are kept in a plain list and are not registered as submodules.
        lr: Scale of the Hebbian update.
        consolidation_rate: Mixing coefficient between ``weight_fp32`` and
            ``long_term_weight`` during consolidation.
        forget_rate: Fraction by which ``weight_fp32`` decays at each
            consolidation.
        consolidation_interval: Number of ``step`` calls between
            consolidations (previously fixed at 100).

    Raises:
        ValueError: If ``consolidation_interval`` is less than 1.
    """

    def __init__(self, modules: List[nn.Module], lr: float = 0.01,
                 consolidation_rate: float = 0.01, forget_rate: float = 0.1,
                 consolidation_interval: int = 100):
        super().__init__()
        if consolidation_interval < 1:
            raise ValueError(
                f"consolidation_interval must be >= 1, got {consolidation_interval}")
        # NOTE(review): this plain-list attribute shadows nn.Module.modules(),
        # so ``self.modules()`` cannot be called on this object.  The name is
        # kept because callers may index ``.modules`` directly.
        self.modules = modules
        self.lr = lr
        self.consolidation_rate = consolidation_rate
        self.forget_rate = forget_rate
        self.consolidation_interval = consolidation_interval
        # Per-feature spike accumulators, keyed "pre_<idx>" / "post_<idx>".
        self.spike_traces = {}
        # Reward trace shared by all layers; decays once per step() call.
        self.dopamine_trace = 0.0
        self.dopamine_decay = 0.9
        # Not read by any method of this class.
        self.bcm_thresholds = {}
        self.step_count = 0

    def step(self, layer_idx: int, pre_activation: torch.Tensor,
             post_activation: torch.Tensor, reward: float = 0.0,
             importance: float = 1.0):
        """Apply one Hebbian update to ``self.modules[layer_idx]``.

        Args:
            layer_idx: Index into ``self.modules``.  Indices
                ``>= len(self.modules)`` make the call a complete no-op
                (reward trace and step counter are left untouched).
            pre_activation: Input-side activations, feature dimension last.
            post_activation: Output-side activations, feature dimension last.
            reward: Added to the decayed shared reward trace.
            importance: Scales the reward trace inside the modulation factor
                ``1 + dopamine_trace * importance``.
        """
        if layer_idx >= len(self.modules):
            return
        module = self.modules[layer_idx]
        # One shared trace for all layers: it decays on every step() call,
        # whichever layer the call is for.
        self.dopamine_trace = reward + self.dopamine_decay * self.dopamine_trace
        modulation = 1.0 + self.dopamine_trace * importance
        pre_spikes = self._compute_spikes(pre_activation, f"pre_{layer_idx}")
        post_spikes = self._compute_spikes(post_activation, f"post_{layer_idx}")
        weight = getattr(module, 'weight_fp32', None)
        if weight is not None:
            with torch.no_grad():
                # Spikes are 1-D per-feature vectors, so the Hebbian term is
                # their outer product with shape (n_pre, n_post).
                delta_hebb = self.lr * modulation * torch.outer(pre_spikes, post_spikes)
                # A weight whose trailing two dims match gets the update
                # broadcast over any leading dims (e.g. a stack of heads).
                # Mismatched shapes are silently skipped.
                # NOTE(review): an nn.Linear-style (out, in) layout would need
                # a transpose here — confirm weight_fp32's layout.
                if delta_hebb.shape == weight.shape[-2:]:
                    weight.data += delta_hebb
        self.step_count += 1
        if self.step_count % self.consolidation_interval == 0:
            self._consolidate_all()

    def _compute_spikes(self, activation: torch.Tensor, trace_id: str) -> torch.Tensor:
        """Convert activations into a binary per-feature spike vector.

        The accumulator stored under ``trace_id`` (one float per feature,
        created lazily as zeros) is increased by the mean activation.  For
        inputs with two or more dims the mean is taken over every leading
        dim.  A 1-D input is reduced to its scalar mean, which is then added
        to every feature.  Features whose accumulator reaches 0.5 emit 1.0
        and have 0.5 subtracted; all others emit 0.0.

        The activation is detached first, so the stored accumulator never
        keeps an autograd graph alive across steps.
        """
        activation = activation.detach()
        if activation.dim() > 2:
            # Fold batch/sequence/... dims together so the rate is per feature.
            activation = activation.reshape(-1, activation.shape[-1])
        trace = self.spike_traces.get(trace_id)
        if trace is None:
            trace = torch.zeros(activation.shape[-1], device=activation.device)
        elif trace.device != activation.device:
            trace = trace.to(activation.device)
        trace += activation.mean(dim=0)
        spikes = (trace >= 0.5).float()
        trace -= spikes * 0.5
        self.spike_traces[trace_id] = trace
        return spikes

    def _update_ternary(self):
        """Call ``update_ternary_weights()`` on every module that defines it."""
        for module in self.modules:
            if hasattr(module, 'update_ternary_weights'):
                module.update_ternary_weights()

    def _consolidate_all(self):
        """Mix fast and long-term weights, then run the ternary-update hooks.

        For every module exposing both ``long_term_weight`` and
        ``weight_fp32`` (in this order)::

            long_term_weight += consolidation_rate * weight_fp32
            weight_fp32      *= 1 - forget_rate
            weight_fp32      += consolidation_rate * long_term_weight

        Modules lacking either attribute are skipped.  The updates run under
        ``torch.no_grad()`` so Parameters can be modified in place.
        """
        with torch.no_grad():
            for module in self.modules:
                if not (hasattr(module, 'long_term_weight')
                        and hasattr(module, 'weight_fp32')):
                    continue
                fast = module.weight_fp32.data
                module.long_term_weight += self.consolidation_rate * fast
                fast *= (1 - self.forget_rate)
                fast += self.consolidation_rate * module.long_term_weight
        self._update_ternary()
class Plasticity(nn.Module):
    """State holder for an ``n_neurons`` x ``n_neurons`` plasticity model.

    Only construction is defined here.  The constructor allocates fast (``w``)
    and long-term (``long_term_w``) weight matrices, the per-neuron vectors
    ``acc_pre``, ``acc_post`` and ``bcm_theta``, and scalar hyper-parameters.
    No method of this class reads them.

    The tensors are registered as non-persistent buffers, so ``.to()``,
    ``.cuda()`` and ``.double()`` move and cast them with the module.  They add
    no entries to ``state_dict()``, so checkpoint contents are unchanged.

    Args:
        n_neurons: Size of each neuron dimension.
        lr, consolidation_rate, forget_rate, threshold, lr_bcm,
        target_activity, dopamine_decay: Keyword-only hyper-parameters.  They
            default to the values previously hard-coded in this constructor.
    """

    def __init__(self, n_neurons: int = 128, *, lr: float = 0.01,
                 consolidation_rate: float = 0.01, forget_rate: float = 0.1,
                 threshold: float = 0.5, lr_bcm: float = 0.001,
                 target_activity: float = 0.5, dopamine_decay: float = 0.9):
        super().__init__()
        self.n_neurons = n_neurons
        # Non-persistent buffers: they follow device/dtype moves of the module
        # but stay out of state_dict().
        self.register_buffer("w", torch.zeros(n_neurons, n_neurons), persistent=False)
        self.register_buffer("long_term_w", torch.zeros(n_neurons, n_neurons), persistent=False)
        self.register_buffer("acc_pre", torch.zeros(n_neurons), persistent=False)
        self.register_buffer("acc_post", torch.zeros(n_neurons), persistent=False)
        self.register_buffer("bcm_theta", torch.zeros(n_neurons), persistent=False)
        self.lr = lr
        self.consolidation_rate = consolidation_rate
        self.forget_rate = forget_rate
        self.threshold = threshold
        self.lr_bcm = lr_bcm
        self.target_activity = target_activity
        self.step_count = 0
        self.dopamine_trace = 0.0
        self.dopamine_decay = dopamine_decay