|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import math
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
|
|
| from model.config import PyCraftConfig, get_config_120m, get_config_tiny
|
| from model.attention import RMSNorm
|
| from model.transformer import TransformerBlock
|
|
|
|
|
| class PyCraftModel(nn.Module):
|
| def __init__(self, config: PyCraftConfig):
|
| super().__init__()
|
| self.config = config
|
|
|
|
|
| self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
|
|
|
|
|
| self.blocks = nn.ModuleList([
|
| TransformerBlock(config) for _ in range(config.n_layers)
|
| ])
|
|
|
|
|
| self.norm_final = RMSNorm(config.d_model)
|
|
|
|
|
| self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
|
|
|
|
|
|
| if config.weight_tying:
|
| self.lm_head.weight = self.token_embedding.weight
|
|
|
|
|
| self._init_weights()
|
|
|
| def _init_weights(self):
|
| """
|
| GPT-2 style initialisation:
|
| - Embeddings: N(0, 0.02)
|
| - Linear layers: N(0, 0.02)
|
| - Residual projections scaled by 1/sqrt(2 * n_layers)
|
| to keep activations stable as depth increases.
|
| """
|
| std = 0.02
|
| residual_scale = std / math.sqrt(2 * self.config.n_layers)
|
|
|
| for name, module in self.named_modules():
|
| if isinstance(module, nn.Linear):
|
|
|
|
|
| if "wo" in name or "down_proj" in name:
|
| nn.init.normal_(module.weight, mean=0.0,
|
| std=residual_scale)
|
| else:
|
| nn.init.normal_(module.weight, mean=0.0, std=std)
|
| if module.bias is not None:
|
| nn.init.zeros_(module.bias)
|
| elif isinstance(module, nn.Embedding):
|
| nn.init.normal_(module.weight, mean=0.0, std=std)
|
|
|
| def forward(
|
| self,
|
| input_ids: torch.Tensor,
|
|
|
| targets: torch.Tensor | None = None,
|
| ) -> tuple[torch.Tensor, torch.Tensor | None]:
|
| """
|
| Args:
|
| input_ids: token indices, shape (batch, seq_len)
|
| targets: next-token targets for loss computation.
|
| If None, returns logits only (inference mode).
|
|
|
| Returns:
|
| (logits, loss)
|
| logits: (batch, seq_len, vocab_size)
|
| loss: scalar cross-entropy loss, or None if targets not given
|
| """
|
|
|
| x = self.token_embedding(input_ids)
|
|
|
|
|
| for block in self.blocks:
|
| x = block(x)
|
|
|
|
|
| x = self.norm_final(x)
|
|
|
|
|
| logits = self.lm_head(x)
|
|
|
|
|
| loss = None
|
| if targets is not None:
|
|
|
|
|
|
|
| loss = F.cross_entropy(
|
| logits.view(-1, self.config.vocab_size),
|
| targets.view(-1),
|
| ignore_index=-1,
|
| )
|
|
|
| return logits, loss
|
|
|
| @torch.no_grad()
|
| def generate(
|
| self,
|
| input_ids: torch.Tensor,
|
| max_new_tokens: int = 128,
|
| temperature: float = 0.8,
|
| top_k: int = 50,
|
| ) -> torch.Tensor:
|
| """
|
| Simple greedy / top-k generation for testing.
|
| Not for production — use a proper sampler later.
|
| """
|
| self.eval()
|
| for _ in range(max_new_tokens):
|
|
|
| context = input_ids[:, -self.config.max_seq_len:]
|
| logits, _ = self(context)
|
|
|
|
|
| next_logits = logits[:, -1, :] / temperature
|
|
|
|
|
| if top_k > 0:
|
| top_vals, _ = torch.topk(next_logits, top_k)
|
| threshold = top_vals[:, -1].unsqueeze(-1)
|
| next_logits = next_logits.masked_fill(
|
| next_logits < threshold, float('-inf'))
|
|
|
| probs = torch.softmax(next_logits, dim=-1)
|
| next_token = torch.multinomial(probs, num_samples=1)
|
| input_ids = torch.cat([input_ids, next_token], dim=1)
|
|
|
| return input_ids
|
|
|
| def param_count(self) -> dict:
|
| total = sum(p.numel() for p in self.parameters())
|
| trainable = sum(p.numel()
|
| for p in self.parameters() if p.requires_grad)
|
| return {"total": total, "trainable": trainable}
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build both model sizes, run one forward/backward pass on
    # random tokens for each, then sample a few tokens from the larger model.
    torch.manual_seed(42)
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"

    print("=" * 50)
    print("PyCraft-1 Full Model Test")
    print("=" * 50)

    # [1] Tiny configuration: parameter count, forward pass, loss, backward.
    print("\n[1] Testing PyCraft-tiny...")
    cfg_tiny = get_config_tiny()
    model_tiny = PyCraftModel(cfg_tiny).to(device)
    counts = model_tiny.param_count()
    print(f" Params: {counts['total'] / 1e6:.2f}M total, "
          f"{counts['trainable'] / 1e6:.2f}M trainable")

    batch, seq = 2, 128
    ids = torch.randint(0, cfg_tiny.vocab_size, (batch, seq), device=device)
    targets = torch.randint(0, cfg_tiny.vocab_size,
                            (batch, seq), device=device)

    logits, loss = model_tiny(ids, targets)
    print(f" Logits shape: {tuple(logits.shape)}")
    # Random targets under a near-uniform predictive distribution give a
    # loss of about ln(vocab_size).
    print(
        f" Loss: {loss.item():.4f} (expect ~{math.log(cfg_tiny.vocab_size):.2f} for random init)")

    loss.backward()
    print(" Backward pass: OK")

    # [2] 120M configuration: one forward/backward pass plus memory report.
    print("\n[2] Testing PyCraft-1 (120M)...")
    cfg = get_config_120m()
    model = PyCraftModel(cfg).to(device)
    counts = model.param_count()
    print(f" Params: {counts['total'] / 1e6:.2f}M total")

    # CUDA allocator statistics are only meaningful when running on a GPU.
    mem_before = 0.0
    if use_cuda:
        torch.cuda.empty_cache()
        mem_before = torch.cuda.memory_allocated() / 1e6

    ids_full = torch.randint(0, cfg.vocab_size, (1, 256), device=device)
    tgt_full = torch.randint(0, cfg.vocab_size, (1, 256), device=device)
    logits_full, loss_full = model(ids_full, tgt_full)
    loss_full.backward()

    if use_cuda:
        mem_after = torch.cuda.memory_allocated() / 1e6
        print(f" GPU memory used: {mem_after:.1f} MB")
        # Growth attributable to this forward/backward pass alone.
        print(f" GPU memory added by forward/backward: "
              f"{mem_after - mem_before:.1f} MB")
    print(f" Loss: {loss_full.item():.4f}")
    print(f" Logits shape: {tuple(logits_full.shape)}")

    # [3] Sampling from a random prompt with the 120M model.
    print("\n[3] Testing generation...")
    model.eval()
    if use_cuda:
        torch.cuda.empty_cache()
    prompt = torch.randint(0, cfg.vocab_size, (1, 10), device=device)
    generated = model.generate(
        prompt, max_new_tokens=20, temperature=1.0, top_k=50)
    print(
        f" Prompt len: {prompt.shape[1]}, Generated len: {generated.shape[1]}")

    print("\n" + "=" * 50)
    print("All tests PASSED. PyCraft-1 architecture is ready.")
    print("=" * 50)
|
|
|