| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | from gguf import GGUFWriter
|
| |
|
class ModelConfig:
    """Hyperparameters for the DeepSeek-Lite demo model."""

    def __init__(self):
        # Core transformer dimensions.
        self.vocab_size = 32000
        self.hidden_size = 768
        self.num_hidden_layers = 4
        self.num_attention_heads = 8
        self.intermediate_size = 3072

        # Self mixture-of-adapters.
        self.num_experts = 4

        # Chunked key/value compression.
        self.chunk_size = 256
        self.compression_ratio = 4

        # CoAT graph memory.
        self.max_graph_nodes = 512
        self.node_dim = self.hidden_size // 4

        # Regularization / initialization.
        self.hidden_dropout_prob = 0.1
        self.initializer_range = 0.02
|
| |
|
class CoATGraphManager(nn.Module):
    """Maintains a per-layer bank of graph "memory" nodes (CoAT).

    Each forward pass mean-pools the sequence, finds the two most similar
    nodes in the current bank by dot product, and blends each of them with
    its learnable base counterpart using a learned sigmoid gate.

    Reads from config: ``max_graph_nodes``, ``node_dim``, ``hidden_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Learnable reference node bank: (max_graph_nodes, node_dim).
        self.base_nodes = nn.Parameter(torch.randn(config.max_graph_nodes, config.node_dim))
        # Projects pooled hidden states into node space.
        self.projection = nn.Linear(config.hidden_size, config.node_dim)

        # Scalar gate in [0, 1]; input is a (current node, base node) pair
        # concatenated along the feature axis.
        self.update_gate = nn.Sequential(
            nn.Linear(config.node_dim * 2, config.hidden_size),
            nn.GELU(),
            nn.Linear(config.hidden_size, 1),
            nn.Sigmoid()
        )

    def forward(self, hidden_states, current_nodes):
        """Return a copy of ``current_nodes`` with its top-2 matches updated.

        hidden_states: (batch, seq, hidden_size)
        current_nodes: (batch, max_graph_nodes, node_dim)
        """
        # NOTE(review): batch_size is computed but never used below.
        batch_size = hidden_states.size(0)

        # Copy so the in-place scatter_ below never mutates the caller's tensor.
        current_nodes = current_nodes.clone()

        # Mean-pool the sequence into one vector per batch element.
        seq_aggregated = hidden_states.mean(dim=1)

        # Query in node space: (batch, node_dim).
        projected = self.projection(seq_aggregated)

        # Unscaled dot-product similarity vs every node: (batch, 1, max_graph_nodes).
        similarity = torch.matmul(projected.unsqueeze(1), current_nodes.transpose(1, 2))

        # Indices of the two best-matching nodes per batch element: (batch, 2).
        _, topk_indices = torch.topk(similarity.squeeze(1), k=2, dim=-1)

        # Gather those nodes from the bank: (batch, 2, node_dim).
        selected_nodes = torch.gather(
            current_nodes,
            1,
            topk_indices.unsqueeze(-1).expand(-1, -1, self.config.node_dim)
        )

        # Pair each selected node with its base counterpart, then compute a
        # per-node blend weight and interpolate between current and base.
        combined = torch.cat([
            selected_nodes,
            self.base_nodes[topk_indices]
        ], dim=-1)
        update_weights = self.update_gate(combined)
        updated_nodes = selected_nodes * update_weights + self.base_nodes[topk_indices] * (1 - update_weights)

        # Write the blended nodes back into their original slots of the local copy.
        current_nodes.scatter_(
            1,
            topk_indices.unsqueeze(-1).expand(-1, -1, self.config.node_dim),
            updated_nodes
        )

        return current_nodes
|
| |
|
class ChunkKVAttention(nn.Module):
    """Multi-head attention with per-chunk key/value compression.

    Queries keep full sequence resolution.  Keys and values are projected,
    split into chunks of ``config.chunk_size`` along the sequence axis, and
    each chunk is linearly compressed to
    ``chunk_size // compression_ratio`` positions before attention.

    Reads from config: ``hidden_size``, ``num_attention_heads``,
    ``chunk_size``, ``compression_ratio``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.head_dim = config.hidden_size // config.num_attention_heads

        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)

        # Compression acts along the sequence (chunk) axis, not the feature axis.
        self.k_compress = nn.Linear(config.chunk_size, config.chunk_size // config.compression_ratio)
        self.v_compress = nn.Linear(config.chunk_size, config.chunk_size // config.compression_ratio)

    def forward(self, hidden_states):
        """Attend full-resolution queries over compressed keys/values.

        hidden_states: (batch, seq, hidden_size)
        Returns: (batch, seq, hidden_size)
        """
        batch_size, seq_len, _ = hidden_states.size()

        q = self.q_proj(hidden_states)
        k = self._process_chunk(self.k_proj, self.k_compress, hidden_states)
        v = self._process_chunk(self.v_proj, self.v_compress, hidden_states)

        # Split heads: (batch, heads, positions, head_dim).
        q = q.view(batch_size, -1, self.config.num_attention_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.config.num_attention_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.config.num_attention_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention.  FIX: use a plain float scale instead
        # of torch.sqrt(torch.tensor(head_dim)), which allocated a fresh CPU
        # tensor on every forward pass.
        attn = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = F.softmax(attn, dim=-1)
        output = torch.matmul(attn, v)

        # Merge heads back: (batch, seq, hidden_size).
        return output.transpose(1, 2).flatten(2)

    def _process_chunk(self, proj, compress, x):
        """Project x, then compress it chunk-by-chunk along the sequence axis."""
        chunk_size = self.config.chunk_size
        chunks = []
        for i in range(0, x.size(1), chunk_size):
            chunk = proj(x[:, i:i + chunk_size])
            # FIX: zero-pad a short trailing chunk so the fixed-width
            # compression layer accepts sequence lengths that are not an
            # exact multiple of chunk_size (the original raised a shape
            # mismatch error there).
            if chunk.size(1) < chunk_size:
                chunk = F.pad(chunk, (0, 0, 0, chunk_size - chunk.size(1)))
            # Transpose so Linear compresses the sequence dimension.
            compressed = compress(chunk.transpose(1, 2)).transpose(1, 2)
            chunks.append(compressed)
        return torch.cat(chunks, dim=1)
|
| |
|
class SelfMoA(nn.Module):
    """Self mixture-of-adapters feed-forward layer.

    Each token is routed through the experts via a hard Gumbel-softmax gate
    (one-hot routing with straight-through gradients).

    NOTE(review): Gumbel-softmax sampling is stochastic even in eval mode,
    so routing differs between forward passes unless the RNG is seeded —
    confirm this is intended for inference.
    """

    def __init__(self, config):
        super().__init__()
        # One two-layer GELU MLP per expert.
        self.experts = nn.ModuleList(
            nn.Sequential(
                nn.Linear(config.hidden_size, config.intermediate_size),
                nn.GELU(),
                nn.Linear(config.intermediate_size, config.hidden_size),
            )
            for _ in range(config.num_experts)
        )
        self.gate = nn.Linear(config.hidden_size, config.num_experts)

    def forward(self, x):
        # Hard one-hot routing weights, sampled once before the experts run
        # so RNG consumption order stays fixed.
        routing = F.gumbel_softmax(self.gate(x), hard=True, dim=-1)
        mixed = torch.zeros_like(x)
        for expert_idx, expert in enumerate(self.experts):
            mixed = mixed + expert(x) * routing[..., expert_idx].unsqueeze(-1)
        return mixed
|
| |
|
class DeepSeekLiteBlock(nn.Module):
    """One model block: pre-norm chunked attention, a CoAT graph-node
    refresh, then a pre-norm mixture-of-adapters feed-forward, each wrapped
    in a dropout residual connection.

    NOTE(review): a single LayerNorm instance is shared by both the
    attention and MoA sub-layers — confirm this weight sharing is intended.
    """

    def __init__(self, config):
        super().__init__()
        self.attention = ChunkKVAttention(config)
        self.moa = SelfMoA(config)
        self.coat = CoATGraphManager(config)
        self.norm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x, nodes):
        """Return (updated hidden states, updated graph nodes)."""
        # Attention sub-layer with residual connection.
        attn_branch = self.attention(self.norm(x))
        x = x + self.dropout(attn_branch)

        # Refresh this layer's graph memory from the new hidden states.
        updated_nodes = self.coat(x, nodes)

        # Expert feed-forward sub-layer with residual connection.
        moa_branch = self.moa(self.norm(x))
        x = x + self.dropout(moa_branch)

        return x, updated_nodes
|
| |
|
class DeepSeekLite(nn.Module):
    """Full model: token embedding, a stack of DeepSeekLiteBlocks each with
    its own learnable graph-node bank, and a final LayerNorm.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([DeepSeekLiteBlock(config) for _ in range(config.num_hidden_layers)])
        self.final_norm = nn.LayerNorm(config.hidden_size)

        # One learnable node bank per layer.  FIX: nn.Parameter already
        # produces a leaf tensor with requires_grad=True, so the original
        # .clone().detach().requires_grad_(True) chain was redundant; the
        # initial values (same torch.randn draw) are unchanged.
        self.graph_nodes = nn.ParameterList([
            nn.Parameter(torch.randn(config.max_graph_nodes, config.node_dim))
            for _ in range(config.num_hidden_layers)
        ])

    def forward(self, input_ids):
        """Map (batch, seq) token ids to (batch, seq, hidden_size) states."""
        x = self.embedding(input_ids)
        batch_size = input_ids.size(0)

        for layer_idx, layer in enumerate(self.layers):
            # Broadcast this layer's node bank across the batch; .clone()
            # keeps CoAT's in-place scatter off the shared parameter.
            nodes = self.graph_nodes[layer_idx].unsqueeze(0).expand(batch_size, -1, -1).clone()
            # NOTE(review): the updated nodes are discarded each call, so
            # graph state does not persist across forward passes — confirm
            # this is intended.
            x, _ = layer(x, nodes)

        return self.final_norm(x)
|
| |
|
def save_gguf(model, filename):
    """Serialize the model's hyperparameters and all parameters to a GGUF file.

    Args:
        model: a DeepSeekLite instance (reads ``model.config`` fields and
            ``model.named_parameters()``).
        filename: destination path for the GGUF file.
    """
    writer = GGUFWriter(filename, "deepseek-lite")
    # FIX: close the writer even if serialization fails, so the file handle
    # is never leaked on error.
    try:
        # Scalar hyperparameters consumed by downstream loaders.
        writer.add_uint32("vocab_size", model.config.vocab_size)
        writer.add_uint32("hidden_size", model.config.hidden_size)
        writer.add_uint32("num_hidden_layers", model.config.num_hidden_layers)
        writer.add_uint32("num_attention_heads", model.config.num_attention_heads)
        writer.add_uint32("num_experts", model.config.num_experts)
        writer.add_uint32("max_graph_nodes", model.config.max_graph_nodes)

        # Every learnable tensor, moved to CPU and converted to numpy.
        for name, param in model.named_parameters():
            writer.add_tensor(name, param.detach().cpu().numpy())

        writer.write_header_to_file()
        writer.write_kv_data_to_file()
        writer.write_tensors_to_file()
    finally:
        writer.close()
|
| |
|
if __name__ == "__main__":
    config = ModelConfig()
    model = DeepSeekLite(config)

    # Smoke test: two sequences of 1024 random token ids.
    sample_ids = torch.randint(0, config.vocab_size, (2, 1024))
    with torch.no_grad():
        hidden = model(sample_ids)

    param_millions = sum(p.numel() for p in model.parameters()) / 1e6
    print(f"Successful execution! Output shape: {hidden.shape}")
    print(f"Parameter count: {param_millions:.1f}M")

    # Export the checked model to GGUF.
    save_gguf(model, "deepseek-lite.gguf")
    print("Model saved in GGUF format")