File size: 2,346 Bytes
9659b2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
import torch.nn as nn
import torch.nn.functional as F

class OLoRALinear(nn.Module):
    """
    Phase 1: The Orthogonal LoRA Wrapper.
    Creates an isolated parallel highway for the memory gate.
    """
    def __init__(self, base_layer: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        
        # 1. The Base Freeze
        self.base_layer = base_layer
        self.base_layer.weight.requires_grad = False
        
        in_features = base_layer.in_features
        out_features = base_layer.out_features
        
        self.rank = rank
        self.scaling = alpha / rank
        
        # 2. The Isolated Memory Matrices
        self.lora_A = nn.Linear(in_features, rank, bias=False)
        self.lora_B = nn.Linear(rank, out_features, bias=False)
        
        self.reset_parameters()

    def reset_parameters(self):
        # Initialize A with strict mathematical orthogonality
        nn.init.orthogonal_(self.lora_A.weight)
        
        # Initialize B as zero so the module starts completely invisible to the network
        nn.init.zeros_(self.lora_B.weight)

    def get_orthogonal_penalty(self):
        """
        The OTITANS Shield. 
        Calculates how much the new memory weights overlap with the frozen base weights.
        We will add this to our loss later to force the memory into empty dimensions.
        """
        # Calculate the full Delta W matrix (B * A)
        delta_W = self.lora_B.weight @ self.lora_A.weight
        
        # Calculate cosine similarity between the base weights and the new memory weights
        # We flatten them to 1D to compare their overall directional vectors
        base_flat = self.base_layer.weight.view(-1)
        delta_flat = delta_W.view(-1)
        
        # The penalty is the absolute cosine similarity (0 = perfectly orthogonal, 1 = total overlap)
        penalty = torch.abs(F.cosine_similarity(base_flat, delta_flat, dim=0))
        return penalty

    def forward(self, x: torch.Tensor):
        # Pass 1: The frozen Gemma English syntax
        base_output = self.base_layer(x)
        
        # Pass 2: The parallel OTITANS memory logic
        lora_output = self.lora_B(self.lora_A(x)) * self.scaling
        
        # Seamlessly merge the two isolated highways
        return base_output + lora_output