BICORP committed on
Commit
a6e2346
·
verified ·
1 Parent(s): 75af758

Upload main.py

Browse files
Files changed (1) hide show
  1. main.py +231 -0
main.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from gguf import GGUFWriter
5
+
6
class ModelConfig:
    """Hyperparameters for the DeepSeekLite model.

    All values are fixed defaults; ``node_dim`` is derived from
    ``hidden_size`` (a quarter of the model width).
    """

    # Fixed defaults, grouped as in the original layout.
    _DEFAULTS = {
        # Core parameters
        "vocab_size": 32000,
        "hidden_size": 768,
        "num_hidden_layers": 4,
        "num_attention_heads": 8,
        "intermediate_size": 3072,
        # Expert parameters
        "num_experts": 4,
        # Efficiency parameters
        "chunk_size": 256,
        "compression_ratio": 4,
        # Reasoning parameters
        "max_graph_nodes": 512,
        # Regularization
        "hidden_dropout_prob": 0.1,
        "initializer_range": 0.02,
    }

    def __init__(self):
        for key, value in self._DEFAULTS.items():
            setattr(self, key, value)
        # Graph-node embeddings are hidden_size // 4 wide (768 // 4 == 192).
        self.node_dim = self.hidden_size // 4
29
+
30
class CoATGraphManager(nn.Module):
    """Maintains a learnable bank of graph nodes and, per forward pass,
    updates the two nodes most similar to the pooled sequence state
    (a lightweight "chain-of-associated-thoughts" memory).

    Expects ``config`` to provide ``max_graph_nodes``, ``node_dim`` and
    ``hidden_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Node initialization: learnable reference bank of shape
        # [max_graph_nodes, node_dim].
        self.base_nodes = nn.Parameter(torch.randn(config.max_graph_nodes, config.node_dim))
        # Maps a pooled hidden state into node space.
        self.projection = nn.Linear(config.hidden_size, config.node_dim)

        # Update mechanism: produces a scalar gate in (0, 1) per selected
        # node from [current node || base node].
        self.update_gate = nn.Sequential(
            nn.Linear(config.node_dim * 2, config.hidden_size),
            nn.GELU(),
            nn.Linear(config.hidden_size, 1),
            nn.Sigmoid()
        )

    def forward(self, hidden_states, current_nodes):
        """Blend the two most-similar nodes toward their base nodes.

        Args:
            hidden_states: [batch, seq_len, hidden_size] activations.
            current_nodes: [batch, max_graph_nodes, node_dim] working copy
                of the node bank.

        Returns:
            Updated node bank with the same shape as ``current_nodes``;
            the caller's tensor is not mutated (a clone is edited).
        """
        # NOTE(review): batch_size is computed but unused below.
        batch_size = hidden_states.size(0)

        # Clone to prevent in-place errors (scatter_ below mutates).
        current_nodes = current_nodes.clone()

        # Aggregate sequence information by mean pooling over time.
        seq_aggregated = hidden_states.mean(dim=1)  # [batch, hidden_size]

        # Project to node space
        projected = self.projection(seq_aggregated)  # [batch, node_dim]

        # Calculate similarity scores (unnormalized dot products against
        # every node in the bank).
        similarity = torch.matmul(projected.unsqueeze(1), current_nodes.transpose(1, 2))  # [batch, 1, max_nodes]

        # Get top-2 nodes most similar to the pooled state.
        _, topk_indices = torch.topk(similarity.squeeze(1), k=2, dim=-1)  # [batch, 2]

        # Gather relevant nodes
        selected_nodes = torch.gather(
            current_nodes,
            1,
            topk_indices.unsqueeze(-1).expand(-1, -1, self.config.node_dim)
        )  # [batch, 2, node_dim]

        # Calculate updates: gate sees the current node next to its
        # corresponding learnable base node.
        combined = torch.cat([
            selected_nodes,
            self.base_nodes[topk_indices]
        ], dim=-1)
        update_weights = self.update_gate(combined)
        # Convex blend: gate -> keep current node, (1 - gate) -> base node.
        updated_nodes = selected_nodes * update_weights + self.base_nodes[topk_indices] * (1 - update_weights)

        # Safe scatter update: write the two blended rows back into the
        # (already cloned) bank at their original positions.
        current_nodes.scatter_(
            1,
            topk_indices.unsqueeze(-1).expand(-1, -1, self.config.node_dim),
            updated_nodes
        )

        return current_nodes
88
+
89
class ChunkKVAttention(nn.Module):
    """Multi-head attention with chunk-wise key/value compression.

    Keys and values are projected, split into fixed-size chunks along the
    sequence axis, and each chunk is shortened by ``compression_ratio``
    with a learned linear map; queries attend over the compressed K/V
    sequence.

    Fix over the original: a sequence length that is not a multiple of
    ``chunk_size`` no longer crashes — the trailing partial chunk is
    zero-padded before compression. Divisible lengths behave exactly as
    before.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Per-head width; assumes hidden_size % num_attention_heads == 0.
        self.head_dim = config.hidden_size // config.num_attention_heads

        # Q/K/V projections.
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)

        # Learned compression along the chunk (sequence) axis:
        # chunk_size positions -> chunk_size // compression_ratio positions.
        self.k_compress = nn.Linear(config.chunk_size, config.chunk_size // config.compression_ratio)
        self.v_compress = nn.Linear(config.chunk_size, config.chunk_size // config.compression_ratio)

    def forward(self, hidden_states):
        """Attend queries over compressed keys/values.

        Args:
            hidden_states: [batch, seq_len, hidden_size].

        Returns:
            [batch, seq_len, hidden_size] attention output.
        """
        batch_size, seq_len, _ = hidden_states.size()

        # Queries keep full sequence resolution.
        q = self.q_proj(hidden_states)

        # Keys/values are processed chunk-by-chunk and length-compressed.
        k = self._process_chunk(self.k_proj, self.k_compress, hidden_states)
        v = self._process_chunk(self.v_proj, self.v_compress, hidden_states)

        # Reshape to [batch, heads, positions, head_dim].
        q = q.view(batch_size, -1, self.config.num_attention_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.config.num_attention_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.config.num_attention_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention over the shortened K/V axis.
        attn = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = F.softmax(attn, dim=-1)
        output = torch.matmul(attn, v)

        # Back to [batch, seq_len, hidden_size].
        return output.transpose(1, 2).flatten(2)

    def _process_chunk(self, proj, compress, x):
        """Project ``x``, then compress each chunk along the sequence axis.

        The compression Linear requires exactly ``chunk_size`` inputs, so a
        short trailing chunk is zero-padded to full width (bugfix: the
        original raised a shape error for non-divisible sequence lengths).
        """
        chunk_size = self.config.chunk_size
        chunks = []
        for i in range(0, x.size(1), chunk_size):
            chunk = proj(x[:, i:i + chunk_size])
            pad = chunk_size - chunk.size(1)
            if pad > 0:
                # Pad the sequence dimension (dim 1) with zeros.
                chunk = F.pad(chunk, (0, 0, 0, pad))
            compressed = compress(chunk.transpose(1, 2)).transpose(1, 2)
            chunks.append(compressed)
        return torch.cat(chunks, dim=1)
133
+
134
class SelfMoA(nn.Module):
    """Mixture-of-experts feed-forward layer with hard Gumbel-softmax routing.

    Each position is routed to exactly one expert via a one-hot
    (straight-through) Gumbel sample; routing is stochastic on every call.
    """

    def __init__(self, config):
        super().__init__()

        def make_expert():
            # Standard two-layer GELU feed-forward expert.
            return nn.Sequential(
                nn.Linear(config.hidden_size, config.intermediate_size),
                nn.GELU(),
                nn.Linear(config.intermediate_size, config.hidden_size),
            )

        self.experts = nn.ModuleList(make_expert() for _ in range(config.num_experts))
        # Router producing one logit per expert.
        self.gate = nn.Linear(config.hidden_size, config.num_experts)

    def forward(self, x):
        # One-hot routing weights sampled with Gumbel noise.
        routing = F.gumbel_softmax(self.gate(x), hard=True, dim=-1)
        # Weighted sum over experts; all experts run, the one-hot weights
        # zero out all but the selected one.
        mixed = torch.zeros_like(x)
        for idx, expert in enumerate(self.experts):
            mixed = mixed + expert(x) * routing[..., idx].unsqueeze(-1)
        return mixed
149
+
150
class DeepSeekLiteBlock(nn.Module):
    """One transformer block: chunked attention, graph-node refresh, MoE FFN.

    Note: a single LayerNorm instance is shared by both pre-norm paths.
    """

    def __init__(self, config):
        super().__init__()
        self.attention = ChunkKVAttention(config)
        self.moa = SelfMoA(config)
        self.coat = CoATGraphManager(config)
        self.norm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x, nodes):
        """Return ``(hidden_states, updated_graph_nodes)``."""
        # Pre-norm attention with residual connection.
        x = x + self.dropout(self.attention(self.norm(x)))

        # Refresh the reasoning-graph nodes from the post-attention states.
        new_nodes = self.coat(x, nodes)

        # Pre-norm mixture-of-experts with residual connection.
        x = x + self.dropout(self.moa(self.norm(x)))

        return x, new_nodes
172
+
173
class DeepSeekLite(nn.Module):
    """Token embedding -> stacked DeepSeekLiteBlocks -> final LayerNorm.

    Each layer owns a learnable bank of graph nodes. On every forward pass
    a fresh per-batch copy of the bank is handed to the layer; the updated
    nodes the layer returns are intentionally discarded — the persistent
    state is the parameter bank itself.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(DeepSeekLiteBlock(config) for _ in range(config.num_hidden_layers))
        self.final_norm = nn.LayerNorm(config.hidden_size)

        # One learnable node bank per layer. (Bugfix/idiom: the original
        # wrapped the fresh randn tensor in a redundant
        # clone().detach().requires_grad_(True) chain — nn.Parameter already
        # enables gradients, and a fresh randn has nothing to detach.)
        self.graph_nodes = nn.ParameterList(
            nn.Parameter(torch.randn(config.max_graph_nodes, config.node_dim))
            for _ in range(config.num_hidden_layers)
        )

    def forward(self, input_ids):
        """Map ``[batch, seq]`` token ids to ``[batch, seq, hidden_size]``."""
        x = self.embedding(input_ids)
        batch_size = input_ids.size(0)

        for layer_idx, layer in enumerate(self.layers):
            # Per-batch working copy of this layer's node bank; cloned so
            # the in-place scatter inside CoATGraphManager never touches
            # the parameter storage.
            nodes = self.graph_nodes[layer_idx].unsqueeze(0).expand(batch_size, -1, -1).clone()
            x, _ = layer(x, nodes)

        return self.final_norm(x)
197
+
198
def save_gguf(model, filename):
    """Serialize the model's hyperparameters and all weights to a GGUF file.

    Writes the header, then key/value metadata, then tensor data — the
    order the GGUF writer API expects.
    """
    writer = GGUFWriter(filename, "deepseek-lite")

    # Scalar hyperparameters needed to rebuild the architecture.
    cfg = model.config
    for key in ("vocab_size", "hidden_size", "num_hidden_layers",
                "num_attention_heads", "num_experts", "max_graph_nodes"):
        writer.add_uint32(key, getattr(cfg, key))

    # Every learnable tensor, keyed by its state-dict name.
    for name, param in model.named_parameters():
        writer.add_tensor(name, param.detach().cpu().numpy())

    writer.write_header_to_file()
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file()
    writer.close()
217
+
218
if __name__ == "__main__":
    cfg = ModelConfig()
    net = DeepSeekLite(cfg)

    # Smoke test: one forward pass over random token ids.
    token_ids = torch.randint(0, cfg.vocab_size, (2, 1024))
    with torch.no_grad():
        result = net(token_ids)
    n_params = sum(p.numel() for p in net.parameters())
    print(f"Successful execution! Output shape: {result.shape}")
    print(f"Parameter count: {n_params/1e6:.1f}M")

    # Export weights and hyperparameters to disk.
    save_gguf(net, "deepseek-lite.gguf")
    print("Model saved in GGUF format")