# --- START OF FILE architecture.py ---

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Phi3Config, Phi3ForCausalLM
from typing import Optional, Dict

# --- BUILDING BLOCK 1: VectorMemoryHead (No changes needed here, it inherits dtype correctly) ---
class VectorMemoryHead(nn.Module):
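    """Compresses a short sequence of vectors into a fixed number of memory slots.

    A single transformer encoder layer processes the input, learned queries
    cross-attend to the encoded sequence to produce `num_memory_slots` compressed
    vectors, and a decoder path reconstructs the input from those slots so the
    caller can compare reconstruction against the original input.
    """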
    def __init__(self, hidden_dim: int, num_memory_slots: int, num_heads: int, ff_dim: int, device=None, dtype=None):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_memory_slots = num_memory_slots
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads, dim_feedforward=ff_dim, dropout=0.1, batch_first=True,
            device=device, dtype=dtype
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
        self.memory_queries = nn.Parameter(torch.randn(1, num_memory_slots, hidden_dim, device=device, dtype=dtype))
        self.memory_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, dropout=0.1, batch_first=True,
            device=device, dtype=dtype
        )
        self.memory_layernorm = nn.LayerNorm(hidden_dim, device=device, dtype=dtype)
        self.decoder_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, dropout=0.1, batch_first=True,
            device=device, dtype=dtype
        )
        self.decoder_layernorm = nn.LayerNorm(hidden_dim, device=device, dtype=dtype)
        self.decoder_ffn = nn.Sequential(
            nn.Linear(hidden_dim, ff_dim, device=device, dtype=dtype),
            nn.ReLU(),
            nn.Linear(ff_dim, hidden_dim, device=device, dtype=dtype)
        )

    def forward(self, memory_input_sequence: torch.Tensor):
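        """Compress and reconstruct a batch of vector sequences.

        Args:
            memory_input_sequence: (batch, seq_len, hidden_dim) tensor.

        Returns:
            compressed_memory: (batch, num_memory_slots, hidden_dim) tensor.
            reconstructed_vectors: (batch, seq_len, hidden_dim) tensor.
        """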
        batch_size = memory_input_sequence.shape[0]
        # Encode the raw input vectors.
        encoded_vectors = self.encoder(memory_input_sequence)
        # Learned queries attend over the encoded sequence to form the compressed memory.
        queries = self.memory_queries.expand(batch_size, -1, -1)
        compressed_memory, _ = self.memory_attention(query=queries, key=encoded_vectors, value=encoded_vectors)
        compressed_memory = self.memory_layernorm(compressed_memory + queries)
        # Reconstruct the encoded sequence from the compressed memory.
        reconstructed, _ = self.decoder_attention(query=encoded_vectors, key=compressed_memory, value=compressed_memory)
        reconstructed_vectors = self.decoder_layernorm(reconstructed + encoded_vectors)
        reconstructed_vectors = self.decoder_ffn(reconstructed_vectors)
        return compressed_memory, reconstructed_vectors

# --- BUILDING BLOCK 2: The Custom Layer (Corrected for dtype consistency) ---
class GCVectorMemoryLayer(nn.Module):
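    """Wraps an existing nn.Linear and applies a memory-conditioned correction.

    For every token, the layer projects the local activation and the model's input
    embeddings into a shared memory space, compresses them with a VectorMemoryHead,
    and uses the result to produce a per-feature gate and additive term that
    modulate the output of the original linear projection.
    """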
    def __init__(self, original_layer: nn.Linear, global_input_dim: int,
                 memory_dim: int, num_memory_slots: int, memory_num_heads: int,
                 global_state_storage: Dict):
        super().__init__()
        self.input_dim = original_layer.in_features
        self.output_dim = original_layer.out_features
        self.memory_dim = memory_dim
        self.global_state_storage = global_state_storage
        self.linear = original_layer

        device, dtype = self.linear.weight.device, self.linear.weight.dtype

        # This part is correct: initialize with the correct dtype
        self.local_state_proj = nn.Linear(self.input_dim, memory_dim, device=device, dtype=dtype)
        self.global_state_proj = nn.Linear(global_input_dim, memory_dim, device=device, dtype=dtype)
        self.memory_head = VectorMemoryHead(
            hidden_dim=memory_dim, num_memory_slots=num_memory_slots,
            num_heads=memory_num_heads, ff_dim=memory_dim * 2, device=device, dtype=dtype
        )
        self.correction_head = nn.Linear(memory_dim, 2 * self.output_dim, device=device, dtype=dtype)

        self.last_corrected_activation: Optional[torch.Tensor] = None
        self.last_additive_correction: Optional[torch.Tensor] = None
        self.last_memory_input: Optional[torch.Tensor] = None
        self.last_reconstructed_from_memory: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor):
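        """Apply the wrapped linear layer and, when global context is available,
        a memory-conditioned gated correction on top of its output."""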
        base_output = self.linear(x)
        # Global embeddings are stored by a forward hook on the embedding layer;
        # without them, behave exactly like the wrapped linear.
        if 'embeds' not in self.global_state_storage:
            return base_output
        global_embeds = self.global_state_storage['embeds']
        # Align cached embeddings with the current (possibly shorter) input chunk.
        if global_embeds.shape[1] != x.shape[1]:
            global_embeds = global_embeds[:, -x.shape[1]:, :]
        B, S, _ = x.shape

        # THE DEFINITIVE FIX: REMOVE ALL HARD-CODED DTYPE CASTING
        # Let the layers operate on the native dtype (bfloat16) they receive.
        with torch.enable_grad():
            # Project the per-token activation and the global embedding into the
            # shared memory dimension.
            proj_local = self.local_state_proj(x)
            proj_global = self.global_state_proj(global_embeds)

            # Each token becomes a length-2 sequence (global, local) that the
            # memory head compresses into memory slots.
            memory_input = torch.stack([proj_global, proj_local], dim=2)
            memory_input_flat = memory_input.view(B * S, 2, self.memory_dim)
            compressed_mem_flat, recon_flat = self.memory_head(memory_input_flat)
            aggregated_thought_flat = compressed_mem_flat.mean(dim=1)
            aggregated_thought = aggregated_thought_flat.view(B, S, self.memory_dim)

            # The correction head emits a gate and an additive term, applied as a
            # gated residual on top of the original projection.
            raw_correction = self.correction_head(aggregated_thought)
            gate, value = torch.chunk(raw_correction, 2, dim=-1)
            corrected_activation = base_output * torch.sigmoid(gate.to(x.dtype)) + value.to(x.dtype)

        if self.training:
            # Cache intermediates so the training loop can add auxiliary objectives
            # (e.g. reconstructing the memory input) alongside the LM loss.
            self.last_corrected_activation, self.last_additive_correction = corrected_activation, value
            self.last_memory_input, self.last_reconstructed_from_memory = memory_input_flat, recon_flat
        return corrected_activation

# --- BUILDING BLOCK 3: The Full Custom Model ---
class Phi3WithVectorMemoryForCausalLM(Phi3ForCausalLM):
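    """Phi-3 causal LM with one MLP projection replaced by a GCVectorMemoryLayer.

    A forward hook on the embedding layer captures the input embeddings on every
    pass, and the layer at `target_layer_path` is swapped for a memory-augmented
    wrapper that conditions its correction on those embeddings.
    """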
    def __init__(self, config):
        super().__init__(config)
        self.global_state_storage = {}
        self.target_layer_path = "model.layers.15.mlp.gate_up_proj"

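        # Capture the input embeddings on every forward pass; the custom layer
        # reads them from global_state_storage to build its global context.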
        self.model.embed_tokens.register_forward_hook(
            lambda module, input, output: self.global_state_storage.update({'embeds': output.detach()})
        )

        try:
            original_layer = self.get_submodule(self.target_layer_path)
            custom_layer = GCVectorMemoryLayer(
                original_layer=original_layer, global_input_dim=config.hidden_size,
                memory_dim=64, num_memory_slots=8, memory_num_heads=4,
                global_state_storage=self.global_state_storage
            )
            parent_path = ".".join(self.target_layer_path.split('.')[:-1])
            child_name = self.target_layer_path.split('.')[-1]
            setattr(self.get_submodule(parent_path), child_name, custom_layer)
            print(f"Successfully replaced '{self.target_layer_path}' with GCVectorMemoryLayer.")
        except AttributeError:
            print(f"Could not find target layer '{self.target_layer_path}'. Model remains unmodified.")
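
# A minimal usage sketch (not part of the original training code). The checkpoint
# name and dtype below are assumptions; any Phi-3 checkpoint with at least 16
# decoder layers exposes the targeted "model.layers.15.mlp.gate_up_proj". Expect
# from_pretrained to warn that the replaced layer's new parameters are freshly
# initialized, since they do not exist in the base checkpoint.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_name = "microsoft/Phi-3-mini-4k-instruct"  # assumed base checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = Phi3WithVectorMemoryForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
    model.eval()

    inputs = tokenizer("The memory layer modulates this projection.", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    print(outputs.logits.shape)  # (1, seq_len, vocab_size)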
# --- END OF FILE architecture.py ---