| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
class OLoRALinear(nn.Module):
    """
    Phase 1: The Orthogonal LoRA Wrapper.
    Creates an isolated parallel highway for the memory gate.

    Wraps a frozen ``nn.Linear`` with a trainable low-rank adapter
    (``delta_W = B.weight @ A.weight``, scaled by ``alpha / rank``) whose
    overlap with the frozen base weights can be penalized via
    :meth:`get_orthogonal_penalty`.
    """

    def __init__(self, base_layer: nn.Linear, rank: int = 8, alpha: float = 16.0):
        """
        Args:
            base_layer: the pretrained linear layer to wrap; all of its
                parameters are frozen in place.
            rank: LoRA bottleneck dimension (must be >= 1).
            alpha: LoRA scaling numerator; effective scale is ``alpha / rank``.

        Raises:
            ValueError: if ``rank`` is less than 1.
        """
        super().__init__()
        if rank < 1:
            raise ValueError(f"rank must be >= 1, got {rank}")

        self.base_layer = base_layer
        # BUGFIX: freeze *all* base parameters, not just the weight.
        # The original only set weight.requires_grad = False, which left
        # an existing bias trainable — the "frozen" base would drift.
        for param in self.base_layer.parameters():
            param.requires_grad = False

        in_features = base_layer.in_features
        out_features = base_layer.out_features

        self.rank = rank
        self.scaling = alpha / rank

        # Low-rank factors: delta_W = lora_B.weight @ lora_A.weight,
        # shape (out_features, in_features) — same as the base weight.
        self.lora_A = nn.Linear(in_features, rank, bias=False)
        self.lora_B = nn.Linear(rank, out_features, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize A orthogonally and B to zero.

        With B = 0, delta_W starts at 0, so the wrapped layer initially
        reproduces the base layer's output exactly.
        """
        nn.init.orthogonal_(self.lora_A.weight)
        nn.init.zeros_(self.lora_B.weight)

    def get_orthogonal_penalty(self):
        """
        The OTITANS Shield.
        Calculates how much the new memory weights overlap with the frozen
        base weights. We will add this to our loss later to force the
        memory into empty dimensions.

        Returns:
            Scalar tensor in [0, 1]: the absolute cosine similarity between
            the flattened base weight and the flattened low-rank update.
            (0 at init, since B is zero and cosine_similarity's eps floor
            keeps the division finite.)
        """
        delta_W = self.lora_B.weight @ self.lora_A.weight

        base_flat = self.base_layer.weight.view(-1)
        delta_flat = delta_W.view(-1)

        # abs() so positive and negative alignment are penalized equally.
        penalty = torch.abs(F.cosine_similarity(base_flat, delta_flat, dim=0))
        return penalty

    def forward(self, x: torch.Tensor):
        """Return frozen base output plus the scaled low-rank correction."""
        base_output = self.base_layer(x)

        lora_output = self.lora_B(self.lora_A(x)) * self.scaling

        return base_output + lora_output
| |
|