pycraft-1 / model /pycraft_model.py
imshadow0's picture
Upload model/pycraft_model.py with huggingface_hub
fdc8c89 verified
Raw
History Blame Contribute Delete
8.25 kB
# model/pycraft_model.py
# PyCraft-1: full autoregressive language model.
#
# Architecture summary:
# Token embedding
# → N × TransformerBlock (RMSNorm + GQA/QK-Norm/RoPE + SwiGLU)
# → Final RMSNorm
# → Linear output projection (vocab logits)
#
# Training objective: causal language modelling (next-token prediction)
# + Fill-in-the-Middle (FIM) on 50% of batches (handled in data pipeline).
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.config import PyCraftConfig, get_config_120m, get_config_tiny
from model.attention import RMSNorm
from model.transformer import TransformerBlock
class PyCraftModel(nn.Module):
    """Decoder-only autoregressive language model for PyCraft-1.

    Pipeline: token embedding -> ``config.n_layers`` x ``TransformerBlock``
    -> final ``RMSNorm`` -> bias-free linear projection to vocabulary logits.

    Args:
        config: model hyper-parameters. This class reads ``vocab_size``,
            ``d_model``, ``n_layers``, ``weight_tying`` and ``max_seq_len``.
    """

    def __init__(self, config: PyCraftConfig):
        super().__init__()
        self.config = config

        # Token embedding table: vocab_size -> d_model.
        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)

        # Stack of transformer blocks, all built from the same config.
        self.blocks = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.n_layers)]
        )

        # Final norm applied before the output projection.
        self.norm_final = RMSNorm(config.d_model)

        # Output projection: d_model -> vocab_size (no bias).
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Optional weight tying: lm_head reuses the embedding matrix, saving
        # vocab_size * d_model parameters at the cost of some flexibility.
        if config.weight_tying:
            self.lm_head.weight = self.token_embedding.weight

        self._init_weights()

    def _init_weights(self) -> None:
        """Apply GPT-2 style initialisation to every submodule.

        - ``nn.Embedding`` weights: N(0, 0.02).
        - ``nn.Linear`` weights: N(0, 0.02); biases (if any) zeroed.
        - ``nn.Linear`` modules whose qualified name contains ``"wo"`` or
          ``"down_proj"`` (projections feeding the residual stream) use
          std = 0.02 / sqrt(2 * n_layers) to keep activations stable as
          depth increases.

        Other module types (e.g. ``RMSNorm``) keep their own defaults.
        With weight tying the shared tensor is visited twice (as embedding,
        then as ``lm_head``); both draws use the same std.
        """
        std = 0.02
        residual_std = std / math.sqrt(2 * self.config.n_layers)
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear):
                # NOTE(review): substring match on the qualified module name;
                # confirm no other Linear in the attention/transformer modules
                # has "wo" in its path.
                is_residual_proj = "wo" in name or "down_proj" in name
                nn.init.normal_(
                    module.weight,
                    mean=0.0,
                    std=residual_std if is_residual_proj else std,
                )
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=std)

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run the model and optionally compute the next-token loss.

        Args:
            input_ids: integer token indices, shape (batch, seq_len).
            targets: next-token targets, shape (batch, seq_len); positions
                equal to -1 are excluded from the loss. May be a
                non-contiguous view such as ``data[:, 1:]``. If None, only
                logits are returned (inference mode).

        Returns:
            (logits, loss): logits of shape (batch, seq_len, vocab_size) and
            the mean cross-entropy loss as a scalar tensor, or None when
            ``targets`` is None.
        """
        x = self.token_embedding(input_ids)  # (batch, seq_len, d_model)
        for block in self.blocks:
            x = block(x)
        x = self.norm_final(x)
        logits = self.lm_head(x)  # (batch, seq_len, vocab_size)

        loss = None
        if targets is not None:
            # Flatten to (batch*seq_len, vocab) / (batch*seq_len,). reshape
            # (not view) so sliced, non-contiguous targets are accepted.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=-1,  # -1 marks positions masked from the loss
            )
        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 128,
        temperature: float = 0.8,
        top_k: int = 50,
    ) -> torch.Tensor:
        """Autoregressively append ``max_new_tokens`` tokens to ``input_ids``.

        Reference sampler for testing: there is no KV cache, so the whole
        (cropped) context is re-encoded at every step. Not for production.

        Args:
            input_ids: prompt token indices, shape (batch, prompt_len).
            max_new_tokens: number of tokens to generate.
            temperature: softmax temperature. ``0`` selects greedy (argmax)
                decoding; values > 0 sample from the temperature-scaled
                distribution.
            top_k: if > 0, sample only among the ``top_k`` highest logits
                (clamped to the vocabulary size). Unused when greedy.

        Returns:
            Tensor of shape (batch, prompt_len + max_new_tokens): the prompt
            followed by the generated tokens.

        Raises:
            ValueError: if ``temperature`` is negative.

        The module's train/eval mode is restored before returning.
        """
        if temperature < 0:
            raise ValueError(f"temperature must be >= 0, got {temperature}")

        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Only the most recent max_seq_len tokens fit in the context.
                context = input_ids[:, -self.config.max_seq_len:]
                logits, _ = self(context)
                next_logits = logits[:, -1, :]  # (batch, vocab_size)

                if temperature == 0:
                    # Greedy decoding: avoids dividing by zero.
                    next_token = next_logits.argmax(dim=-1, keepdim=True)
                else:
                    next_logits = next_logits / temperature
                    if top_k > 0:
                        # torch.topk raises if k exceeds the vocab size.
                        k = min(top_k, next_logits.size(-1))
                        kth_best = torch.topk(next_logits, k).values[:, -1:]
                        next_logits = next_logits.masked_fill(
                            next_logits < kth_best, float("-inf")
                        )
                    probs = torch.softmax(next_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)

                input_ids = torch.cat([input_ids, next_token], dim=1)
        finally:
            # Don't leave a training model stuck in eval mode.
            self.train(was_training)
        return input_ids

    def param_count(self) -> dict[str, int]:
        """Return ``{"total": ..., "trainable": ...}`` parameter counts.

        ``self.parameters()`` yields each tensor once, so tied weights are
        counted a single time.
        """
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(
            p.numel() for p in self.parameters() if p.requires_grad
        )
        return {"total": total, "trainable": trainable}
# ------------------------------------------------------------------ #
# Full model self-test
# ------------------------------------------------------------------ #
if __name__ == "__main__":
    # Smoke test: builds the tiny and 120M configs, runs forward/backward,
    # and samples a few tokens. Assertions make "PASSED" meaningful.
    torch.manual_seed(42)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_cuda = device == "cuda"
    print("=" * 50)
    print("PyCraft-1 Full Model Test")
    print("=" * 50)

    # [1] Tiny config first (fast): forward, loss and backward.
    print("\n[1] Testing PyCraft-tiny...")
    cfg_tiny = get_config_tiny()
    model_tiny = PyCraftModel(cfg_tiny).to(device)
    counts = model_tiny.param_count()
    print(f" Params: {counts['total'] / 1e6:.2f}M total, "
          f"{counts['trainable'] / 1e6:.2f}M trainable")
    batch, seq = 2, 128
    ids = torch.randint(0, cfg_tiny.vocab_size, (batch, seq), device=device)
    targets = torch.randint(0, cfg_tiny.vocab_size,
                            (batch, seq), device=device)
    logits, loss = model_tiny(ids, targets)
    assert logits.shape == (batch, seq, cfg_tiny.vocab_size), logits.shape
    assert loss is not None and torch.isfinite(loss), loss
    print(f" Logits shape: {tuple(logits.shape)}")
    print(
        f" Loss: {loss.item():.4f} (expect ~{math.log(cfg_tiny.vocab_size):.2f} for random init)")
    loss.backward()
    print(" Backward pass: OK")

    # [2] Full 120M config.
    print("\n[2] Testing PyCraft-1 (120M)...")
    cfg = get_config_120m()
    model = PyCraftModel(cfg).to(device)
    counts = model.param_count()
    print(f" Params: {counts['total'] / 1e6:.2f}M total")

    # Memory check: CUDA allocator stats are only meaningful on a GPU.
    mem_before = 0.0
    if use_cuda:
        torch.cuda.empty_cache()
        mem_before = torch.cuda.memory_allocated() / 1e6
    ids_full = torch.randint(0, cfg.vocab_size, (1, 256), device=device)
    tgt_full = torch.randint(0, cfg.vocab_size, (1, 256), device=device)
    logits_full, loss_full = model(ids_full, tgt_full)
    assert logits_full.shape == (1, 256, cfg.vocab_size), logits_full.shape
    assert loss_full is not None and torch.isfinite(loss_full), loss_full
    loss_full.backward()
    if use_cuda:
        mem_after = torch.cuda.memory_allocated() / 1e6
        print(f" GPU memory used: {mem_after:.1f} MB "
              f"(+{mem_after - mem_before:.1f} MB after forward/backward)")
    else:
        print(" GPU memory used: n/a (running on CPU)")
    print(f" Loss: {loss_full.item():.4f}")
    print(f" Logits shape: {tuple(logits_full.shape)}")

    # [3] Sampling.
    print("\n[3] Testing generation...")
    model.eval()
    if use_cuda:
        torch.cuda.empty_cache()
    prompt = torch.randint(0, cfg.vocab_size, (1, 10), device=device)
    max_new = 20
    generated = model.generate(
        prompt, max_new_tokens=max_new, temperature=1.0, top_k=50)
    assert generated.shape == (1, prompt.shape[1] + max_new), generated.shape
    assert torch.equal(generated[:, :prompt.shape[1]], prompt)
    print(
        f" Prompt len: {prompt.shape[1]}, Generated len: {generated.shape[1]}")

    print("\n" + "=" * 50)
    print("All tests PASSED. PyCraft-1 architecture is ready.")
    print("=" * 50)