"""
Full GPT model for FrawdLLM.

This is the complete model that:
1. Takes token IDs as input
2. Converts them to embeddings (token + position)
3. Passes them through N transformer blocks
4. Predicts the next token at every position

Architecture:
    Token IDs [batch, seq]
            ↓
    Embeddings [batch, seq, n_embd]
            ↓
    Transformer Block × N
            ↓
    Final LayerNorm
            ↓
    Output Head → [batch, seq, vocab_size]
            ↓
    Logits (unnormalized scores for each vocabulary token)
"""
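
# Minimal usage sketch (mirrors the __main__ demo at the bottom of this file;
# get_config("tiny") is the config helper exercised there):
#
#   config = get_config("tiny")
#   model = FrawdLLM(config)
#   ids = torch.randint(0, config.vocab_size, (2, 16))   # [batch, seq]
#   logits, loss = model(ids, targets=ids)                # [2, 16, vocab_size], scalar
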
import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import ModelConfig
from .embeddings import Embeddings
from .block import TransformerBlock


class FrawdLLM(nn.Module):
    """
    The complete FrawdLLM model.

    Input: token_ids [batch_size, seq_len]
    Output: logits [batch_size, seq_len, vocab_size]
    """

    def __init__(self, config: ModelConfig):
        super().__init__()

        self.config = config

        # Token + positional embeddings
        self.embeddings = Embeddings(config)

        # Stack of N transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layer)
        ])

        # Final layer norm before the output head
        self.ln_f = nn.LayerNorm(config.n_embd)

        # Output head: projects hidden states to vocabulary logits
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: the output head reuses the token embedding matrix,
        # which removes a large duplicate parameter block and is standard in
        # GPT-style models.
        self.lm_head.weight = self.embeddings.token_emb.weight

        # Initialize all weights recursively
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights for better training."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        token_ids: torch.Tensor,
        targets: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Forward pass through the model.

        Args:
            token_ids: [batch_size, seq_len] - input token IDs
            targets: [batch_size, seq_len] - target token IDs (for computing loss)

        Returns:
            logits: [batch_size, seq_len, vocab_size] - prediction scores
            loss: scalar tensor if targets are provided, else None
        """
        # Token IDs -> embeddings: [batch, seq] -> [batch, seq, n_embd]
        x = self.embeddings(token_ids)

        # Run through every transformer block in order
        for block in self.blocks:
            x = block(x)

        # Final layer norm
        x = self.ln_f(x)

        # Project to vocabulary logits: [batch, seq, n_embd] -> [batch, seq, vocab_size]
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            # cross_entropy expects [N, C] logits and [N] class indices, so flatten
            # the batch and sequence dimensions; padding positions are ignored.
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=self.config.pad_token_id,
            )

        return logits, loss
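
    # A rough training-step sketch using the loss returned above (the optimizer,
    # e.g. AdamW over model.parameters(), is assumed and not defined in this file):
    #
    #   logits, loss = model(input_ids, target_ids)
    #   optimizer.zero_grad()
    #   loss.backward()
    #   optimizer.step()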

    @torch.no_grad()
    def generate(
        self,
        token_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: int | None = None,
    ) -> torch.Tensor:
        """
        Generate new tokens autoregressively.

        Args:
            token_ids: [batch_size, seq_len] - starting tokens (prompt)
            max_new_tokens: How many new tokens to generate
            temperature: Higher = more random, lower = more deterministic (must be > 0)
            top_k: If set, only sample from the top k most likely tokens

        Returns:
            [batch_size, seq_len + max_new_tokens] - original + generated tokens
        """
        for _ in range(max_new_tokens):
            # Crop the running sequence to the model's maximum context length
            context = token_ids[:, -self.config.context_length:]

            # Forward pass without targets, so only logits are needed
            logits, _ = self.forward(context)

            # Keep only the logits for the last position: [batch, vocab_size]
            logits = logits[:, -1, :]

            # Temperature scaling
            logits = logits / temperature
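            # e.g. dividing logits [2.0, 1.0, 0.0] by temperature 0.5 gives
            # [4.0, 2.0, 0.0] (sharper distribution), while temperature 2.0 gives
            # [1.0, 0.5, 0.0] (flatter, more random sampling).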

            # Top-k filtering: keep only the k highest logits, mask the rest to -inf
            if top_k is not None:
                top_values, _ = torch.topk(logits, top_k, dim=-1)
                min_top_value = top_values[:, -1].unsqueeze(-1)
                logits = torch.where(
                    logits < min_top_value,
                    torch.full_like(logits, float('-inf')),
                    logits,
                )
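                # e.g. with top_k=2 and logits [1.0, 3.0, 2.0, 0.5], only 3.0 and 2.0
                # survive; the rest become -inf and get zero probability after softmax.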

            # Convert logits to probabilities
            probs = F.softmax(logits, dim=-1)

            # Sample the next token from the distribution
            next_token = torch.multinomial(probs, num_samples=1)

            # Append it to the running sequence
            token_ids = torch.cat([token_ids, next_token], dim=1)

            # Stop early once every sequence in the batch has emitted EOS
            if (next_token == self.config.eos_token_id).all():
                break

        return token_ids

    def count_parameters(self) -> int:
        """Count total trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


if __name__ == "__main__":
    from .config import get_config

    print("Testing FrawdLLM...")
    print("=" * 50)

    config = get_config("tiny")
    print(f"Config: vocab={config.vocab_size}, n_embd={config.n_embd}, "
          f"n_layer={config.n_layer}, n_head={config.n_head}")

    model = FrawdLLM(config)

    # Parameter count
    num_params = model.count_parameters()
    print(f"Total parameters: {num_params:,} ({num_params/1e6:.1f}M)")

    # Forward pass on random token IDs (with targets, so a loss is returned)
    batch_size, seq_len = 2, 16
    token_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    targets = torch.randint(0, config.vocab_size, (batch_size, seq_len))

    print(f"\nInput shape: {token_ids.shape}")

    logits, loss = model(token_ids, targets)

    print(f"Output logits shape: {logits.shape}")
    print(f"Loss: {loss.item():.4f}")

    # Generation from a single BOS token
    prompt = torch.tensor([[config.bos_token_id]])
    generated = model.generate(prompt, max_new_tokens=10)
    print(f"\nGenerated shape: {generated.shape}")
    print(f"Generated tokens: {generated[0].tolist()}")

    print("\nFrawdLLM working!")
|