File size: 3,247 Bytes
8125804
 
f37be5a
8125804
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f37be5a
 
 
 
 
 
 
 
 
 
 
 
 
8125804
 
 
 
f37be5a
 
 
 
8125804
f37be5a
8125804
 
 
 
f37be5a
8125804
 
 
 
 
 
 
f37be5a
 
 
 
 
8125804
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.model.config import ModelConfig


class PlasticLayer(nn.Module):
    """Short-term plastic adapter on top of stable backbone.
    - Gradient-based update with small LR
    - L2 regularization toward initial state
    - Exponential decay of adapter weights
    Analogy: hippocampal fast learning (complementary learning systems).
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.adapter = nn.Sequential(
            nn.Linear(cfg.d_model, cfg.plastic_hidden),
            nn.GELU(),
            nn.Linear(cfg.plastic_hidden, cfg.d_model),
        )
        self.initial_state: dict[str, torch.Tensor] = {}
        self._save_initial_state()

    def _save_initial_state(self):
        self.initial_state = {
            n: p.data.clone() for n, p in self.adapter.named_parameters()
        }

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.adapter(x)

    def _corrupt(self, x: torch.Tensor) -> torch.Tensor:
        noisy = x.detach()
        if self.cfg.plastic_mask_ratio > 0.0:
            keep = (
                torch.rand(*x.shape[:-1], 1, device=x.device) > self.cfg.plastic_mask_ratio
            ).to(x.dtype)
            noisy = noisy * keep
        if self.cfg.plastic_noise_scale > 0.0:
            noisy = noisy + self.cfg.plastic_noise_scale * torch.randn_like(noisy)
        return noisy

    def adapt_step(self, x: torch.Tensor, lr: float | None = None) -> dict[str, float]:
        """One gradient step of denoising-style self-supervised adaptation."""
        if lr is None:
            lr = self.cfg.plastic_lr

        self.adapter.train()
        self.adapter.zero_grad(set_to_none=True)
        target = x.detach()
        corrupted = self._corrupt(target)
        reconstructed = corrupted + self.adapter(corrupted)

        denoise_loss = F.mse_loss(reconstructed, target)

        l2_loss = torch.tensor(0.0, device=x.device)
        for n, p in self.adapter.named_parameters():
            l2_loss = l2_loss + ((p - self.initial_state[n].to(p.device)) ** 2).mean()
        loss = denoise_loss + self.cfg.plastic_l2_weight * l2_loss

        loss.backward()
        with torch.no_grad():
            for p in self.adapter.parameters():
                if p.grad is not None:
                    p.add_(p.grad, alpha=-lr)
                    p.grad.zero_()
        return {
            "loss": float(loss.item()),
            "denoise_loss": float(denoise_loss.item()),
            "l2_loss": float(l2_loss.item()),
        }

    def apply_decay(self, decay_rate: float | None = None):
        """Exponential decay toward initial state."""
        if decay_rate is None:
            decay_rate = self.cfg.plastic_decay
        with torch.no_grad():
            for n, p in self.adapter.named_parameters():
                init = self.initial_state[n].to(p.device)
                p.data.mul_(decay_rate).add_(init, alpha=1.0 - decay_rate)

    def reset(self):
        """Reset adapter to initial state."""
        with torch.no_grad():
            for n, p in self.adapter.named_parameters():
                p.data.copy_(self.initial_state[n])