# Test-3 / generator.py
# Renamed from main.py by BICORP (commit 5e1243f, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
from gguf import GGUFWriter
class ModelConfig:
    """Hyper-parameter container for the DeepSeek-Lite model."""

    def __init__(self):
        # Core transformer dimensions.
        self.vocab_size = 32000
        self.hidden_size = 768
        self.num_hidden_layers = 4
        self.num_attention_heads = 8
        self.intermediate_size = 3072
        # Mixture-of-experts routing.
        self.num_experts = 4
        # Chunked-KV attention efficiency knobs.
        self.chunk_size = 256
        self.compression_ratio = 4
        # CoAT reasoning-graph sizing.
        self.max_graph_nodes = 512
        self.node_dim = self.hidden_size // 4  # 192 when hidden_size is 768
        # Regularization / initialization.
        self.hidden_dropout_prob = 0.1
        self.initializer_range = 0.02
class CoATGraphManager(nn.Module):
    """Sparse graph-memory module (CoAT-style reasoning graph).

    Keeps a bank of ``max_graph_nodes`` learnable base node embeddings and,
    on each forward pass, blends the two current nodes most similar to the
    pooled sequence state with their base embeddings via a learned gate.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        # Node initialization: learnable base embedding for every graph node.
        self.base_nodes = nn.Parameter(torch.randn(config.max_graph_nodes, config.node_dim))
        # Maps the pooled hidden state into node space for similarity scoring.
        self.projection = nn.Linear(config.hidden_size, config.node_dim)
        # Update mechanism: maps a [current, base] node pair to a blend
        # weight in (0, 1) via a small gated MLP.
        self.update_gate = nn.Sequential(
            nn.Linear(config.node_dim * 2, config.hidden_size),
            nn.GELU(),
            nn.Linear(config.hidden_size, 1),
            nn.Sigmoid()
        )
    def forward(self, hidden_states, current_nodes):
        """Update the two most relevant nodes and return the new node tensor.

        Args:
            hidden_states: [batch, seq_len, hidden_size] token representations.
            current_nodes: [batch, max_graph_nodes, node_dim] node states
                (assumed already batched -- TODO confirm against caller).

        Returns:
            [batch, max_graph_nodes, node_dim] tensor in which the two
            selected nodes per sample are replaced by gated blends.
        """
        batch_size = hidden_states.size(0)  # NOTE(review): unused below
        # Clone to prevent in-place errors
        current_nodes = current_nodes.clone()
        # Aggregate sequence information
        seq_aggregated = hidden_states.mean(dim=1)  # [batch, hidden_size]
        # Project to node space
        projected = self.projection(seq_aggregated)  # [batch, node_dim]
        # Calculate similarity scores (dot product against every node)
        similarity = torch.matmul(projected.unsqueeze(1), current_nodes.transpose(1, 2))  # [batch, 1, max_nodes]
        # Get top-2 nodes
        _, topk_indices = torch.topk(similarity.squeeze(1), k=2, dim=-1)  # [batch, 2]
        # Gather relevant nodes
        selected_nodes = torch.gather(
            current_nodes,
            1,
            topk_indices.unsqueeze(-1).expand(-1, -1, self.config.node_dim)
        )  # [batch, 2, node_dim]
        # Calculate updates: the gate sees each selected node paired with
        # its learnable base embedding.
        combined = torch.cat([
            selected_nodes,
            self.base_nodes[topk_indices]
        ], dim=-1)
        update_weights = self.update_gate(combined)
        # Convex blend: weight keeps the current node, (1 - weight) pulls it
        # back toward the base embedding.
        updated_nodes = selected_nodes * update_weights + self.base_nodes[topk_indices] * (1 - update_weights)
        # Safe scatter update (writes only the two selected rows per sample,
        # into the clone made above).
        current_nodes.scatter_(
            1,
            topk_indices.unsqueeze(-1).expand(-1, -1, self.config.node_dim),
            updated_nodes
        )
        return current_nodes
class ChunkKVAttention(nn.Module):
    """Multi-head attention with chunk-wise key/value length compression.

    Keys and values are projected per chunk of ``chunk_size`` tokens, and each
    chunk's token axis is compressed by ``compression_ratio`` through a learned
    linear map, so K/V end up ``compression_ratio``x shorter than the query
    sequence.

    NOTE(review): no attention mask is applied, so attention is bidirectional
    over the compressed K/V positions -- preserved from the original.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.head_dim = config.hidden_size // config.num_attention_heads
        # Q/K/V projections.
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
        # Learned compression over the token axis of each chunk.
        self.k_compress = nn.Linear(config.chunk_size, config.chunk_size // config.compression_ratio)
        self.v_compress = nn.Linear(config.chunk_size, config.chunk_size // config.compression_ratio)

    def forward(self, hidden_states):
        """Attend queries over compressed K/V.

        Args:
            hidden_states: [batch, seq_len, hidden_size] input tensor.

        Returns:
            [batch, seq_len, hidden_size] attention output.
        """
        batch_size, seq_len, _ = hidden_states.size()
        q = self.q_proj(hidden_states)
        # Keys/values are projected and length-compressed chunk by chunk.
        k = self._process_chunk(self.k_proj, self.k_compress, hidden_states)
        v = self._process_chunk(self.v_proj, self.v_compress, hidden_states)
        # Reshape to [batch, heads, tokens, head_dim].
        q = q.view(batch_size, -1, self.config.num_attention_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.config.num_attention_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.config.num_attention_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention. A plain float scale avoids allocating
        # a fresh tensor on every call (original used torch.sqrt(torch.tensor(...))).
        attn = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = F.softmax(attn, dim=-1)
        output = torch.matmul(attn, v)
        return output.transpose(1, 2).flatten(2)

    def _process_chunk(self, proj, compress, x):
        """Project then length-compress ``x`` chunk by chunk along dim 1."""
        chunk_size = self.config.chunk_size
        chunks = []
        for i in range(0, x.size(1), chunk_size):
            chunk = proj(x[:, i:i + chunk_size])
            # Bug fix: the compression Linear expects exactly `chunk_size`
            # positions on its input axis, so zero-pad a short final chunk
            # instead of crashing when seq_len is not a multiple of chunk_size.
            if chunk.size(1) < chunk_size:
                chunk = F.pad(chunk, (0, 0, 0, chunk_size - chunk.size(1)))
            compressed = compress(chunk.transpose(1, 2)).transpose(1, 2)
            chunks.append(compressed)
        return torch.cat(chunks, dim=1)
class SelfMoA(nn.Module):
def __init__(self, config):
super().__init__()
self.experts = nn.ModuleList([
nn.Sequential(
nn.Linear(config.hidden_size, config.intermediate_size),
nn.GELU(),
nn.Linear(config.intermediate_size, config.hidden_size)
) for _ in range(config.num_experts)
])
self.gate = nn.Linear(config.hidden_size, config.num_experts)
def forward(self, x):
gate = F.gumbel_softmax(self.gate(x), hard=True, dim=-1)
return sum(expert(x) * gate[..., i].unsqueeze(-1) for i, expert in enumerate(self.experts))
class DeepSeekLiteBlock(nn.Module):
    """One model block: chunked attention, graph-node update, then MoE.

    NOTE(review): a single LayerNorm instance is shared by both sublayers --
    unusual, but preserved from the original implementation.
    """
    def __init__(self, config):
        super().__init__()
        self.attention = ChunkKVAttention(config)
        self.moa = SelfMoA(config)
        self.coat = CoATGraphManager(config)
        self.norm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x, nodes):
        """Return (hidden_states, updated_nodes) for one block."""
        # Pre-norm attention sublayer with residual connection.
        x = x + self.dropout(self.attention(self.norm(x)))
        # Refresh the reasoning-graph nodes from post-attention states.
        nodes = self.coat(x, nodes)
        # Pre-norm mixture-of-experts sublayer with residual connection.
        x = x + self.dropout(self.moa(self.norm(x)))
        return x, nodes
class DeepSeekLite(nn.Module):
    """Full model: embedding -> stacked blocks with per-layer graph nodes -> norm."""
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([DeepSeekLiteBlock(config) for _ in range(config.num_hidden_layers)])
        self.final_norm = nn.LayerNorm(config.hidden_size)
        # One learnable node bank per layer, expanded per batch in forward.
        self.graph_nodes = nn.ParameterList([
            nn.Parameter(torch.randn(config.max_graph_nodes, config.node_dim))
            for _ in range(config.num_hidden_layers)
        ])

    def forward(self, input_ids):
        """Map token ids [batch, seq] to normalized states [batch, seq, hidden]."""
        hidden = self.embedding(input_ids)
        batch_size = input_ids.size(0)
        for node_bank, block in zip(self.graph_nodes, self.layers):
            # Broadcast the shared node bank to the batch, cloning so the
            # block's in-place scatter never touches the parameter itself.
            layer_nodes = node_bank.unsqueeze(0).expand(batch_size, -1, -1).clone()
            # NOTE(review): updated nodes are discarded between layers,
            # matching the original implementation.
            hidden, _ = block(hidden, layer_nodes)
        return self.final_norm(hidden)
def save_gguf(model, filename):
    """Serialize the model's hyper-parameters and all weights to a GGUF file.

    Args:
        model: a DeepSeekLite instance (reads ``model.config`` and all
            named parameters).
        filename: destination path for the .gguf file.

    NOTE(review): the three write_* calls below follow the gguf-py API's
    required header -> kv-data -> tensors order; confirm against the
    installed gguf version before reordering.
    """
    writer = GGUFWriter(filename, "deepseek-lite")
    # Add model configuration as uint32 key/value metadata.
    writer.add_uint32("vocab_size", model.config.vocab_size)
    writer.add_uint32("hidden_size", model.config.hidden_size)
    writer.add_uint32("num_hidden_layers", model.config.num_hidden_layers)
    writer.add_uint32("num_attention_heads", model.config.num_attention_heads)
    writer.add_uint32("num_experts", model.config.num_experts)
    writer.add_uint32("max_graph_nodes", model.config.max_graph_nodes)
    # Add all parameters as tensors, keyed by their PyTorch parameter names.
    # (Assumes float32 parameters -- bfloat16 would fail .numpy(); TODO confirm.)
    for name, param in model.named_parameters():
        writer.add_tensor(name, param.detach().cpu().numpy())
    writer.write_header_to_file()
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file()
    writer.close()
def main():
    """Smoke-test a forward pass, report size, and export to GGUF."""
    config = ModelConfig()
    model = DeepSeekLite(config)
    # Test forward pass on a small random token batch (no gradients needed).
    inputs = torch.randint(0, config.vocab_size, (2, 1024))
    with torch.no_grad():
        outputs = model(inputs)
    print(f"Successful execution! Output shape: {outputs.shape}")
    print(f"Parameter count: {sum(p.numel() for p in model.parameters())/1e6:.1f}M")
    # Save model
    save_gguf(model, "deepseek-lite.gguf")
    print("Model saved in GGUF format")

if __name__ == "__main__":
    main()