import torch
import torch.nn as nn
import torch.nn.functional as F

from src.model.config import ModelConfig


class PlasticLayer(nn.Module):
    """Short-term plastic adapter on top of a stable backbone.

    - Gradient-based update with a small learning rate
    - L2 regularization toward the initial state
    - Exponential decay of adapter weights

    Analogy: hippocampal fast learning (complementary learning systems).
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.adapter = nn.Sequential(
            nn.Linear(cfg.d_model, cfg.plastic_hidden),
            nn.GELU(),
            nn.Linear(cfg.plastic_hidden, cfg.d_model),
        )
        self.initial_state: dict[str, torch.Tensor] = {}
        self._save_initial_state()

    def _save_initial_state(self):
        self.initial_state = {
            n: p.data.clone() for n, p in self.adapter.named_parameters()
        }

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.adapter(x)

    def _corrupt(self, x: torch.Tensor) -> torch.Tensor:
        # Mask whole feature vectors at random positions and/or add Gaussian noise.
        noisy = x.detach()
        if self.cfg.plastic_mask_ratio > 0.0:
            keep = (
                torch.rand(*x.shape[:-1], 1, device=x.device) > self.cfg.plastic_mask_ratio
            ).to(x.dtype)
            noisy = noisy * keep
        if self.cfg.plastic_noise_scale > 0.0:
            noisy = noisy + self.cfg.plastic_noise_scale * torch.randn_like(noisy)
        return noisy

    def adapt_step(self, x: torch.Tensor, lr: float | None = None) -> dict[str, float]:
        """One gradient step of denoising-style self-supervised adaptation."""
        if lr is None:
            lr = self.cfg.plastic_lr
        self.adapter.train()
        self.adapter.zero_grad(set_to_none=True)

        target = x.detach()
        corrupted = self._corrupt(target)
        reconstructed = corrupted + self.adapter(corrupted)

        denoise_loss = F.mse_loss(reconstructed, target)
        l2_loss = torch.tensor(0.0, device=x.device)
        for n, p in self.adapter.named_parameters():
            l2_loss = l2_loss + ((p - self.initial_state[n].to(p.device)) ** 2).mean()
        loss = denoise_loss + self.cfg.plastic_l2_weight * l2_loss

        loss.backward()
        # Manual SGD step on the adapter parameters only.
        with torch.no_grad():
            for p in self.adapter.parameters():
                if p.grad is not None:
                    p.add_(p.grad, alpha=-lr)
                    p.grad.zero_()

        return {
            "loss": float(loss.item()),
            "denoise_loss": float(denoise_loss.item()),
            "l2_loss": float(l2_loss.item()),
        }

    def apply_decay(self, decay_rate: float | None = None):
        """Exponentially decay adapter weights toward the initial state."""
        if decay_rate is None:
            decay_rate = self.cfg.plastic_decay
        with torch.no_grad():
            for n, p in self.adapter.named_parameters():
                init = self.initial_state[n].to(p.device)
                p.data.mul_(decay_rate).add_(init, alpha=1.0 - decay_rate)

    def reset(self):
        """Reset adapter weights to the initial state."""
        with torch.no_grad():
            for n, p in self.adapter.named_parameters():
                p.data.copy_(self.initial_state[n])
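

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It assumes ModelConfig
# is keyword-constructible (e.g. a dataclass) and exposes the fields referenced
# above: d_model, plastic_hidden, plastic_mask_ratio, plastic_noise_scale,
# plastic_lr, plastic_l2_weight, plastic_decay. The values below are
# illustrative only.
if __name__ == "__main__":
    cfg = ModelConfig(
        d_model=256,
        plastic_hidden=512,
        plastic_mask_ratio=0.15,
        plastic_noise_scale=0.01,
        plastic_lr=1e-3,
        plastic_l2_weight=0.1,
        plastic_decay=0.99,
    )
    layer = PlasticLayer(cfg)
    hidden = torch.randn(2, 16, cfg.d_model)  # (batch, seq, d_model) backbone activations

    # A few fast-weight adaptation steps on the current context.
    for _ in range(3):
        stats = layer.adapt_step(hidden)

    out = layer(hidden)   # residual adapter applied on top of the backbone features
    layer.apply_decay()   # drift adapter weights back toward their initial state
    layer.reset()         # or discard the short-term adaptation entirely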