File size: 3,509 Bytes
4e0e690
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# plasticity.py
import torch
import torch.nn as nn
from typing import List, Optional

class UnifiedPlasticity(nn.Module):
    """Reward-modulated Hebbian plasticity applied to a list of external modules.

    Each :meth:`step` call converts pre- and post-synaptic activations into
    binary spikes using per-neuron integrate-and-fire traces. It then adds the
    outer product of those spikes to the target module's ``weight_fp32``,
    scaled by ``lr`` and a dopamine-like modulation factor. Every 100 steps,
    modules exposing ``long_term_weight`` are consolidated, and any
    ``update_ternary_weights`` hooks are invoked.

    Args:
        modules: Target modules, indexed by ``layer_idx`` in :meth:`step`.
        lr: Hebbian learning rate.
        consolidation_rate: Mixing rate between fast and long-term weights.
        forget_rate: Fraction of the fast weight removed at each consolidation.
    """

    def __init__(self, modules: List[nn.Module], lr: float = 0.01,
                 consolidation_rate: float = 0.01, forget_rate: float = 0.1):
        super().__init__()
        # NOTE(review): this plain list shadows nn.Module.modules(), so calling
        # self.modules() on this object fails. The targets are also not
        # registered as submodules, so .to()/.parameters() here do not reach
        # them. Presumably intentional (modules are owned elsewhere); kept
        # unchanged for backward compatibility.
        self.modules = modules
        self.lr = lr
        self.consolidation_rate = consolidation_rate
        self.forget_rate = forget_rate
        # Per-neuron integrate-and-fire traces keyed by "pre_<idx>" / "post_<idx>".
        self.spike_traces = {}
        # Exponentially decaying running sum of rewards.
        self.dopamine_trace = 0.0
        self.dopamine_decay = 0.9
        # Not read anywhere in this class; kept for attribute compatibility.
        self.bcm_thresholds = {}
        # Counts step() calls that hit a valid layer; drives consolidation.
        self.step_count = 0

    @torch.no_grad()
    def step(self, layer_idx: int, pre_activation: torch.Tensor,
             post_activation: torch.Tensor, reward: float = 0.0,
             importance: float = 1.0):
        """Apply one reward-modulated Hebbian update to ``self.modules[layer_idx]``.

        Args:
            layer_idx: Index into ``self.modules``. Indices ``>= len(modules)``
                are ignored; nothing changes, including the dopamine trace
                and step count.
            pre_activation: Activations whose last dim indexes pre-synaptic
                neurons. Leading dims are averaged out.
            post_activation: Activations whose last dim indexes post-synaptic
                neurons. Leading dims are averaged out.
            reward: Reward added into the dopamine trace before the update.
            importance: Scales the dopamine contribution to the modulation.

        The update ``lr * (1 + dopamine * importance) * outer(pre, post)`` is
        added to ``module.weight_fp32`` only when the weight's trailing two
        dims equal ``(n_pre, n_post)``. Any leading dims (e.g. per-head copies)
        receive the same update by broadcasting. A weight of any other shape is
        left untouched.
        """
        if layer_idx >= len(self.modules):
            return
        module = self.modules[layer_idx]

        self.dopamine_trace = reward + self.dopamine_decay * self.dopamine_trace
        modulation = 1.0 + self.dopamine_trace * importance

        # Traces advance even when the module has no fast weight, so their
        # state evolves identically regardless of module type.
        pre_spikes = self._compute_spikes(pre_activation, f"pre_{layer_idx}")
        post_spikes = self._compute_spikes(post_activation, f"post_{layer_idx}")

        if hasattr(module, 'weight_fp32'):
            # Spikes are 1-D per-neuron vectors, so the Hebbian term is their
            # outer product, shape (n_pre, n_post). einsum('bi,bj->ij') would
            # reject these 1-D operands.
            delta_hebb = self.lr * modulation * torch.outer(pre_spikes, post_spikes)
            if delta_hebb.shape == module.weight_fp32.shape[-2:]:
                module.weight_fp32.data += delta_hebb

        self.step_count += 1
        if self.step_count % 100 == 0:
            self._consolidate_all()

    def _compute_spikes(self, activation: torch.Tensor, trace_id: str) -> torch.Tensor:
        """Integrate-and-fire conversion of an activation into binary spikes.

        The activation is averaged over every dim except the last (the neuron
        dim) and accumulated into ``self.spike_traces[trace_id]``. Neurons whose
        trace reaches 0.5 emit a spike (1.0), and 0.5 is subtracted from their
        trace.

        Args:
            activation: Tensor whose last dim indexes neurons. A 1-D tensor is
                treated as a single sample.
            trace_id: Key of the persistent trace in ``self.spike_traces``.

        Returns:
            Float tensor of shape ``(activation.shape[-1],)`` holding 0/1 values.
        """
        n_neurons = activation.shape[-1]
        samples = activation.unsqueeze(0) if activation.dim() == 1 else activation
        # Detach so the persistent trace never joins the autograd graph;
        # otherwise the graph would keep growing across steps.
        mean_activation = samples.detach().flatten(0, -2).mean(dim=0)

        if trace_id not in self.spike_traces:
            self.spike_traces[trace_id] = torch.zeros(n_neurons, device=activation.device)
        trace = self.spike_traces[trace_id]
        trace += mean_activation
        spikes = (trace >= 0.5).float()
        trace -= spikes * 0.5
        return spikes

    def _update_ternary(self):
        """Call ``update_ternary_weights()`` on every target module that has it."""
        for module in self.modules:
            if hasattr(module, 'update_ternary_weights'):
                module.update_ternary_weights()

    @torch.no_grad()
    def _consolidate_all(self):
        """Blend fast and long-term weights, then refresh ternary weights.

        For every module exposing both ``long_term_weight`` and ``weight_fp32``::

            long_term += consolidation_rate * fast
            fast      *= (1 - forget_rate)
            fast      += consolidation_rate * long_term

        Runs under ``no_grad`` so in-place updates are allowed even when
        ``long_term_weight`` is a Parameter.
        """
        for module in self.modules:
            if hasattr(module, 'long_term_weight') and hasattr(module, 'weight_fp32'):
                module.long_term_weight += self.consolidation_rate * module.weight_fp32.data
                module.weight_fp32.data *= (1 - self.forget_rate)
                module.weight_fp32.data += self.consolidation_rate * module.long_term_weight
        self._update_ternary()

class Plasticity(nn.Module):
    def __init__(self, n_neurons: int = 128):
        super().__init__()
        self.n_neurons = n_neurons
        self.w = torch.zeros(n_neurons, n_neurons)
        self.long_term_w = torch.zeros(n_neurons, n_neurons)
        self.lr = 0.01
        self.consolidation_rate = 0.01
        self.forget_rate = 0.1
        self.acc_pre = torch.zeros(n_neurons)
        self.acc_post = torch.zeros(n_neurons)
        self.threshold = 0.5
        self.bcm_theta = torch.zeros(n_neurons)
        self.lr_bcm = 0.001
        self.target_activity = 0.5
        self.step_count = 0
        self.dopamine_trace = 0.0
        self.dopamine_decay = 0.9