Andrewstivan commited on
Commit
4e0e690
·
verified ·
1 Parent(s): 627aea7

Create plasticity.py

Browse files
Files changed (1) hide show
  1. plasticity.py +82 -0
plasticity.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # plasticity.py
2
+ import torch
3
+ import torch.nn as nn
4
+ from typing import List, Optional
5
+
6
+ class UnifiedPlasticity(nn.Module):
7
+ def __init__(self, modules: List[nn.Module], lr: float = 0.01,
8
+ consolidation_rate: float = 0.01, forget_rate: float = 0.1):
9
+ super().__init__()
10
+ self.modules = modules
11
+ self.lr = lr
12
+ self.consolidation_rate = consolidation_rate
13
+ self.forget_rate = forget_rate
14
+ self.spike_traces = {}
15
+ self.dopamine_trace = 0.0
16
+ self.dopamine_decay = 0.9
17
+ self.bcm_thresholds = {}
18
+ self.step_count = 0
19
+
20
+ def step(self, layer_idx: int, pre_activation: torch.Tensor,
21
+ post_activation: torch.Tensor, reward: float = 0.0,
22
+ importance: float = 1.0):
23
+ if layer_idx >= len(self.modules):
24
+ return
25
+ module = self.modules[layer_idx]
26
+ self.dopamine_trace = reward + self.dopamine_decay * self.dopamine_trace
27
+ modulation = 1.0 + self.dopamine_trace * importance
28
+ pre_spikes = self._compute_spikes(pre_activation, f"pre_{layer_idx}")
29
+ post_spikes = self._compute_spikes(post_activation, f"post_{layer_idx}")
30
+ if hasattr(module, 'weight_fp32'):
31
+ delta_hebb = self.lr * modulation * torch.einsum('bi,bj->ij', pre_spikes, post_spikes)
32
+ if delta_hebb.shape == module.weight_fp32.shape[-2:]:
33
+ module.weight_fp32.data += delta_hebb
34
+ elif hasattr(module, 'n_head'):
35
+ delta_hebb = delta_hebb.unsqueeze(0).expand(module.n_head, -1, -1)
36
+ if delta_hebb.shape == module.weight_fp32.shape:
37
+ module.weight_fp32.data += delta_hebb
38
+ self.step_count += 1
39
+ if self.step_count % 100 == 0:
40
+ self._consolidate_all()
41
+
42
+ def _compute_spikes(self, activation: torch.Tensor, trace_id: str) -> torch.Tensor:
43
+ if trace_id not in self.spike_traces:
44
+ self.spike_traces[trace_id] = torch.zeros(activation.shape[-1], device=activation.device)
45
+ trace = self.spike_traces[trace_id]
46
+ trace += activation.mean(dim=0)
47
+ spikes = (trace >= 0.5).float()
48
+ trace -= spikes * 0.5
49
+ self.spike_traces[trace_id] = trace
50
+ return spikes
51
+
52
+ def _update_ternary(self):
53
+ for module in self.modules:
54
+ if hasattr(module, 'update_ternary_weights'):
55
+ module.update_ternary_weights()
56
+
57
+ def _consolidate_all(self):
58
+ for module in self.modules:
59
+ if hasattr(module, 'long_term_weight'):
60
+ module.long_term_weight += self.consolidation_rate * module.weight_fp32.data
61
+ module.weight_fp32.data *= (1 - self.forget_rate)
62
+ module.weight_fp32.data += self.consolidation_rate * module.long_term_weight
63
+ self._update_ternary()
64
+
65
+ class Plasticity(nn.Module):
66
+ def __init__(self, n_neurons: int = 128):
67
+ super().__init__()
68
+ self.n_neurons = n_neurons
69
+ self.w = torch.zeros(n_neurons, n_neurons)
70
+ self.long_term_w = torch.zeros(n_neurons, n_neurons)
71
+ self.lr = 0.01
72
+ self.consolidation_rate = 0.01
73
+ self.forget_rate = 0.1
74
+ self.acc_pre = torch.zeros(n_neurons)
75
+ self.acc_post = torch.zeros(n_neurons)
76
+ self.threshold = 0.5
77
+ self.bcm_theta = torch.zeros(n_neurons)
78
+ self.lr_bcm = 0.001
79
+ self.target_activity = 0.5
80
+ self.step_count = 0
81
+ self.dopamine_trace = 0.0
82
+ self.dopamine_decay = 0.9