lemms committed
Commit d253f09 · verified · 1 Parent(s): ea3c33b

Upload model.py with huggingface_hub

Files changed (1)
  1. model.py +665 -0
model.py ADDED
@@ -0,0 +1,665 @@
#!/usr/bin/env python3
# Copyright (C) 2024 Louis Chua Bean Chong
#
# This file is part of OpenLLM.
#
# OpenLLM is dual-licensed:
# 1. For open source use: GNU General Public License v3.0
# 2. For commercial use: Commercial License (contact for details)
#
# See LICENSE and docs/LICENSES.md for full license information.

"""
GPT-style Language Model Architecture

This module implements a standard GPT (Generative Pre-trained Transformer)
architecture in pure PyTorch. The model is a decoder-only transformer designed
for autoregressive language modeling (next-token prediction).

ARCHITECTURE OVERVIEW:
- Token Embedding: Maps token IDs to dense vectors
- Positional Embedding: Adds position information to token embeddings
- Transformer Blocks: Stack of multi-head attention + feed-forward layers
- Layer Normalization: Pre-norm placement for training stability
- Output Head: Linear projection to vocabulary for next-token prediction

FEATURES:
- Configurable model size (small/medium/large)
- Dropout for regularization
- Causal (autoregressive) attention masking
- Compatible with our SentencePiece tokenizer
- Memory-efficient implementation for training on limited hardware

Usage:
    from model import GPTConfig, GPTModel

    config = GPTConfig(vocab_size=32000, n_layer=12, n_head=12, n_embd=768)
    model = GPTModel(config)

    # Forward pass (the model returns a (logits, loss) tuple)
    logits, loss = model(input_ids)  # logits: (batch_size, seq_len, vocab_size)
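
    # Sampling (a minimal sketch; generate() is defined later in this file)
    tokens = model.generate(input_ids, max_new_tokens=20, temperature=0.8, top_k=50)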

Hardware Requirements:
    - Small Model (25M params): 4-8GB RAM, CPU/integrated GPU
    - Medium Model (117M params): 8-16GB RAM, dedicated GPU recommended
    - Large Model (350M params): 16GB+ RAM, high-end GPU required

Author: Louis Chua Bean Chong
License: GPLv3
"""

import math
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint  # explicit import; used for gradient checkpointing in GPTModel.forward


@dataclass
class GPTConfig:
    """
    Configuration class for GPT model hyperparameters.

    This class defines all the architectural parameters needed to instantiate
    a GPT model. Use the provided class methods to get pre-configured setups
    for different model sizes.
    """

    # Model architecture
    vocab_size: int = 32000  # Vocabulary size (from tokenizer)
    n_layer: int = 12  # Number of transformer layers
    n_head: int = 12  # Number of attention heads
    n_embd: int = 768  # Embedding dimension

    # Sequence and context
    block_size: int = 1024  # Maximum sequence length

    # Training hyperparameters
    dropout: float = 0.1  # Dropout probability
    bias: bool = True  # Use bias in linear layers

    # Model size identifier
    model_name: str = "gpt-medium"  # Human-readable model identifier

    @classmethod
    def small(cls) -> "GPTConfig":
        """Small model configuration (~25M parameters) - Good for CPU training"""
        return cls(
            vocab_size=32000,
            n_layer=6,
            n_head=8,
            n_embd=512,
            block_size=1024,
            dropout=0.1,
            model_name="gpt-small",
        )

    @classmethod
    def medium(cls) -> "GPTConfig":
        """Medium model configuration (~117M parameters) - Balanced performance"""
        return cls(
            vocab_size=32000,
            n_layer=12,
            n_head=12,
            n_embd=768,
            block_size=2048,
            dropout=0.1,
            model_name="gpt-medium",
        )

    @classmethod
    def large(cls) -> "GPTConfig":
        """Large model configuration (~350M parameters) - High performance"""
        return cls(
            vocab_size=32000,
            n_layer=24,
            n_head=16,
            n_embd=1024,
            block_size=2048,
            dropout=0.1,
            model_name="gpt-large",
        )

    def estimate_parameters(self) -> int:
        """
        Estimate the total number of trainable parameters.

        Returns:
            int: Estimated parameter count
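
        Example (rough arithmetic for the medium config; the instantiated
        model is smaller because the LM head is weight-tied to the token
        embeddings):
            token_emb = 32000 * 768                   ≈ 24.6M
            pos_emb   = 2048 * 768                    ≈ 1.6M
            layers    = 12 * (12 * 768**2 + 4 * 768)  ≈ 85.0M
            lm_head   = 32000 * 768                   ≈ 24.6M
            total                                     ≈ 135.7M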
        """
        # Token embeddings
        token_emb = self.vocab_size * self.n_embd

        # Position embeddings
        pos_emb = self.block_size * self.n_embd

        # Transformer layers
        # Each layer: attention (4 * n_embd^2) + mlp (8 * n_embd^2) + layer_norms
        layer_params = self.n_layer * (12 * self.n_embd**2 + 4 * self.n_embd)

        # Output head (counted separately here; GPTModel ties these weights to
        # the token embeddings, so the actual parameter count is lower)
        output_head = self.vocab_size * self.n_embd

        total = token_emb + pos_emb + layer_params + output_head
        return total


class CausalSelfAttention(nn.Module):
    """
    Multi-head causal self-attention mechanism.

    This implements the core attention mechanism of the transformer, with causal
    masking to ensure autoregressive behavior (tokens can only attend to previous
    tokens, not future ones).
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        assert (
            config.n_embd % config.n_head == 0
        ), "Embedding dim must be divisible by number of heads"

        self.config = config
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head

        # Key, query, value projections for all heads (batched)
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)

        # Output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        # Dropout
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        # Causal mask - lower triangular matrix
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.block_size, config.block_size)).view(
                1, 1, config.block_size, config.block_size
            ),
        )
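
        # For block_size=4 the resulting mask is lower triangular:
        #   [[1, 0, 0, 0],
        #    [1, 1, 0, 0],
        #    [1, 1, 1, 0],
        #    [1, 1, 1, 1]]
        # i.e. position i may attend to positions j <= i only.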

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of causal self-attention.

        This method implements the scaled dot-product attention mechanism with
        causal masking. The attention mechanism allows each token to attend to
        all previous tokens in the sequence, but not to future tokens,
        maintaining the autoregressive property essential for language modeling.

        Mathematical formulation:
            Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V
            where Q, K, V are query, key, value matrices derived from input x

        Implementation details:
        - Uses batch matrix multiplication for efficiency
        - Applies causal mask to prevent future token attention
        - Implements multi-head attention by reshaping and parallel processing
        - Applies dropout for regularization during training

        Args:
            x: Input tensor of shape (batch_size, seq_len, n_embd)
               Contains embedded token representations from previous layer

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len, n_embd)
        """
        # Extract tensor dimensions for clear variable naming and validation
        # B = batch size (number of sequences processed in parallel)
        # T = sequence length (number of tokens in each sequence)
        # C = embedding dimensionality (n_embd from config)
        B, T, C = x.size()

        # Generate query, key, and value projections for all attention heads
        # The c_attn linear layer outputs 3 * n_embd features, split into Q, K, V
        # This batched approach is more efficient than separate linear layers
        # Input shape: (B, T, C) -> Output shape: (B, T, 3*C) -> Split to 3x (B, T, C)
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        # Reshape tensors for multi-head attention computation
        # Transform from (B, T, C) to (B, nh, T, hs) where:
        # - nh = number of heads (self.n_head)
        # - hs = head size (self.head_dim = C // nh)
        # transpose(1, 2) moves the head dimension before the sequence dimension
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)

        # Compute scaled dot-product attention scores
        # Q @ K^T gives attention affinities between all token pairs
        # Scaling by 1/sqrt(head_dim) prevents softmax saturation for large dimensions
        # Shape: (B, nh, T, hs) @ (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))

        # Apply causal masking to enforce the autoregressive property
        # Token i can only attend to tokens j where j <= i; this prevents the
        # model from "cheating" by looking at future tokens during training.
        # Masked positions get -inf so they become 0 after softmax.
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))

        # Convert attention scores to probabilities using softmax
        # Each row now sums to 1: a distribution over which tokens to attend to
        att = F.softmax(att, dim=-1)

        # Apply dropout to attention weights for regularization
        att = self.attn_dropout(att)

        # Apply attention weights to value vectors
        # Shape: (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)
        # Each output position is a weighted sum of all value vectors
        y = att @ v

        # Concatenate multi-head outputs back to the original embedding dimension
        # Transform from (B, nh, T, hs) back to (B, T, C) where C = nh * hs
        # contiguous() ensures an efficient memory layout after the transpose
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # Final output projection and residual dropout
        # The projection learns how to best combine multi-head information
        y = self.resid_dropout(self.c_proj(y))
        return y
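
    # A fused equivalent of the block above, shown as a sketch only (not wired
    # into forward()); F.scaled_dot_product_attention is available in
    # PyTorch >= 2.0 and applies the causal mask internally:
    #
    #   y = F.scaled_dot_product_attention(
    #       q, k, v,
    #       dropout_p=self.config.dropout if self.training else 0.0,
    #       is_causal=True,
    #   )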


class MLP(nn.Module):
    """
    Multi-Layer Perceptron (Feed-Forward Network) for Transformer.

    This implements the position-wise feed-forward network that appears in each
    transformer layer. The MLP provides additional non-linear transformation
    capacity beyond what attention provides.

    Architecture:
        Input -> Linear(n_embd -> 4*n_embd) -> GELU -> Linear(4*n_embd -> n_embd) -> Dropout -> Output

    Design rationale:
    - 4x expansion is standard in transformers (from "Attention Is All You Need")
    - GELU activation provides smoother gradients than ReLU for language modeling
    - Dropout prevents overfitting in the feed-forward layers
    - Two linear layers allow complex non-linear transformations of attention outputs

    Parameters:
    - First linear layer: n_embd * 4*n_embd parameters (expansion)
    - Second linear layer: 4*n_embd * n_embd parameters (projection back)
    - Total: 8 * n_embd^2 parameters (significant portion of model size)
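      (for n_embd=768 that is 8 * 768**2 ≈ 4.7M parameters per layer, before biases)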
    """

    def __init__(self, config: GPTConfig):
        super().__init__()

        # First linear layer: expand embedding dimension by 4x
        # This expansion gives the network more representational capacity
        # The 4x factor is a standard choice that balances capacity vs efficiency
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)

        # GELU (Gaussian Error Linear Unit) activation function
        # GELU provides smoother gradients than ReLU and works well for language modeling
        # Approximately: GELU(x) = x * Φ(x) where Φ is the CDF of the standard normal
        self.gelu = nn.GELU()

        # Second linear layer: project back to original embedding dimension
        # This projection combines information from the expanded representation
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)

        # Dropout for regularization in the feed-forward network
        # Applied after the final projection to prevent overfitting
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the feed-forward network.

        This method applies a two-layer MLP with GELU activation to transform
        the attention outputs. The MLP operates independently on each position
        in the sequence, providing position-wise non-linear transformations.

        Mathematical operation:
            MLP(x) = Dropout(Linear₂(GELU(Linear₁(x))))
            where Linear₁: R^n_embd -> R^(4*n_embd) and Linear₂: R^(4*n_embd) -> R^n_embd

        Args:
            x: Input tensor of shape (batch_size, seq_len, n_embd)
               Contains attended representations from the attention layer

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len, n_embd)
                Contains transformed representations ready for the residual connection
        """
        # First linear transformation: expand from n_embd to 4*n_embd dimensions
        # Shape: (batch_size, seq_len, n_embd) -> (batch_size, seq_len, 4*n_embd)
        x = self.c_fc(x)

        # Apply GELU activation for non-linearity
        # GELU is smoother than ReLU and differentiable everywhere
        x = self.gelu(x)

        # Second linear transformation: project back to n_embd dimensions
        # Shape: (batch_size, seq_len, 4*n_embd) -> (batch_size, seq_len, n_embd)
        x = self.c_proj(x)

        # Apply dropout for regularization before the residual connection
        # Particularly important here since the feed-forward layers hold many parameters
        x = self.dropout(x)

        return x


class Block(nn.Module):
    """
    Single Transformer block.

    Consists of:
    1. Layer normalization
    2. Multi-head causal self-attention
    3. Residual connection
    4. Layer normalization
    5. MLP (feed-forward network)
    6. Residual connection

    Uses pre-norm architecture for better training stability.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of transformer block.

        Args:
            x: Input tensor of shape (batch_size, seq_len, n_embd)

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len, n_embd)
        """
        # Pre-norm attention with residual connection
        x = x + self.attn(self.ln_1(x))

        # Pre-norm MLP with residual connection
        x = x + self.mlp(self.ln_2(x))

        return x


class GPTModel(nn.Module):
    """
    Complete GPT Language Model.

    This is the main model class that combines all components:
    - Token and positional embeddings
    - Stack of transformer blocks
    - Final layer normalization
    - Language modeling head

    The model can be used for:
    - Training from scratch on text data
    - Fine-tuning on downstream tasks
    - Text generation (inference)
    """

    def __init__(self, config: GPTConfig, use_checkpoint: bool = True):
        super().__init__()
        assert config.vocab_size is not None, "vocab_size must be specified"
        assert config.block_size is not None, "block_size must be specified"

        self.config = config
        self.use_checkpoint = use_checkpoint

        # Embeddings
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.vocab_size, config.n_embd),  # Token embeddings
                wpe=nn.Embedding(config.block_size, config.n_embd),  # Position embeddings
                drop=nn.Dropout(config.dropout),
                h=nn.ModuleList(
                    [Block(config) for _ in range(config.n_layer)]
                ),  # Transformer blocks
                ln_f=nn.LayerNorm(config.n_embd),  # Final layer norm
            )
        )

        # Language modeling head (maps hidden states to vocabulary)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Tie weights between token embeddings and output head (common practice)
        self.transformer.wte.weight = self.lm_head.weight

        # Initialize weights
        self.apply(self._init_weights)

        # Report parameter count
        print(f"Model initialized: {self.config.model_name}")
        print(f"Parameters: {self.get_num_params():,}")
        print(f"Estimated: {self.config.estimate_parameters():,}")

    def _init_weights(self, module):
        """Initialize model weights using standard practices."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def get_num_params(self, non_embedding: bool = False) -> int:
        """
        Count the number of parameters in the model.

        Args:
            non_embedding: If True, subtract embedding parameters

        Returns:
            int: Number of parameters
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
            n_params -= self.transformer.wte.weight.numel()
        return n_params

    def forward(
        self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass of the GPT model.

        Args:
            idx: Input token indices of shape (batch_size, seq_len)
            targets: Optional target tokens for loss calculation (batch_size, seq_len)

        Returns:
            Tuple containing:
            - logits: Output logits of shape (batch_size, seq_len, vocab_size)
            - loss: Cross-entropy loss if targets provided, None otherwise
        """
        device = idx.device
        b, t = idx.size()
        assert (
            t <= self.config.block_size
        ), f"Sequence length {t} exceeds block size {self.config.block_size}"

        # Token embeddings
        tok_emb = self.transformer.wte(idx)  # (b, t, n_embd)

        # Position embeddings
        pos = torch.arange(0, t, dtype=torch.long, device=device)  # (t,)
        pos_emb = self.transformer.wpe(pos)  # (t, n_embd), broadcast over the batch

        # Combine embeddings and apply dropout
        x = self.transformer.drop(tok_emb + pos_emb)

        # Pass through transformer blocks with optional gradient checkpointing
        if self.use_checkpoint and self.training:
            # Gradient checkpointing trades compute for memory: activations are
            # recomputed during the backward pass instead of being stored.
            # use_reentrant=False selects the non-reentrant implementation
            # recommended by recent PyTorch releases.
            for block in self.transformer.h:
                x = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
        else:
            # Standard forward pass
            for block in self.transformer.h:
                x = block(x)

        # Final layer normalization
        x = self.transformer.ln_f(x)

        # Language modeling head
        # Always compute full logits for training and evaluation
        logits = self.lm_head(x)

        if targets is not None:
            # Compute cross-entropy loss over the vocabulary; positions with
            # target -1 are ignored (useful for padding)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
            )
        else:
            # If no targets, no loss computation
            loss = None

        return logits, loss
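
    # Training-step sketch (illustrative only; `optimizer`, `input_ids`, and
    # `labels` are assumed to exist in the caller's training loop):
    #
    #   logits, loss = model(input_ids, targets=labels)
    #   loss.backward()
    #   optimizer.step()
    #   optimizer.zero_grad()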

    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Generate new tokens autoregressively.

        Args:
            idx: Starting token indices (batch_size, seq_len)
            max_new_tokens: Maximum number of new tokens to generate
            temperature: Sampling temperature (higher = more random)
            top_k: If set, only sample from the top-k most likely tokens

        Returns:
            torch.Tensor: Generated sequence (batch_size, seq_len + max_new_tokens)
        """
        was_training = self.training  # Remember the mode so we can restore it
        self.eval()
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Crop sequence if it exceeds block size
                idx_cond = (
                    idx
                    if idx.size(1) <= self.config.block_size
                    else idx[:, -self.config.block_size :]
                )

                # Forward pass
                logits, _ = self(idx_cond)

                # Get logits for the last token and apply temperature
                logits = logits[:, -1, :] / temperature

                # Optionally crop to top-k most likely tokens
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float("inf")

                # Apply softmax and sample
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)

                # Append to sequence
                idx = torch.cat((idx, idx_next), dim=1)

        if was_training:
            self.train()  # Restore training mode only if the model was training before
        return idx
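
    # Usage sketch (`prompt_ids` is an assumed (1, T) LongTensor of token ids):
    #
    #   out = model.generate(prompt_ids, max_new_tokens=50, temperature=0.8, top_k=40)
    #   # out has shape (1, T + 50); decode it with the project's tokenizer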

    def estimate_memory_usage(self, batch_size: int = 1, seq_len: Optional[int] = None) -> dict:
        """
        Estimate memory usage for training and inference.

        Args:
            batch_size: Batch size for estimation
            seq_len: Sequence length (defaults to block_size)

        Returns:
            dict: Memory usage estimates in MB
        """
        if seq_len is None:
            seq_len = self.config.block_size

        # Model parameters (weights)
        param_memory = self.get_num_params() * 4 / (1024**2)  # 4 bytes per float32

        # Activations (rough estimate)
        activation_memory = (
            batch_size * seq_len * self.config.n_embd * self.config.n_layer * 8  # Rough estimate
        ) / (1024**2)

        # Gradients (same size as parameters during training)
        gradient_memory = param_memory

        return {
            "parameters_mb": param_memory,
            "activations_mb": activation_memory,
            "gradients_mb": gradient_memory,
            "total_training_mb": param_memory + activation_memory + gradient_memory,
            "total_inference_mb": param_memory + activation_memory * 0.5,  # No gradients needed
        }
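
    # Note: these figures exclude optimizer state. Adam-style optimizers keep
    # two extra buffers per parameter, so budget roughly 2x parameters_mb on
    # top of total_training_mb when planning hardware.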


def create_model(model_size: str = "medium") -> GPTModel:
    """
    Factory function to create a GPT model with predefined configurations.

    Args:
        model_size: Size of model to create ("small", "medium", "large")

    Returns:
        GPTModel: Initialized model
    """
    configs = {
        "small": GPTConfig.small(),
        "medium": GPTConfig.medium(),
        "large": GPTConfig.large(),
    }

    if model_size not in configs:
        raise ValueError(f"Unknown model size: {model_size}. Choose from {list(configs.keys())}")

    config = configs[model_size]
    model = GPTModel(config)

    return model


if __name__ == "__main__":
    # Example usage
    print("🧠 GPT Model Architecture")
    print("=" * 50)

    # Create models of different sizes
    for size in ["small", "medium", "large"]:
        print(f"\n{size.upper()} MODEL:")
        model = create_model(size)

        # Show memory estimates
        memory = model.estimate_memory_usage(batch_size=4, seq_len=512)
        print(
            f"Memory (4 batch, 512 seq): {memory['total_training_mb']:.1f}MB training, "
            f"{memory['total_inference_mb']:.1f}MB inference"
        )

        # Test forward pass
        x = torch.randint(0, 32000, (2, 64))  # Batch size 2, sequence length 64
        with torch.no_grad():
            logits, _ = model(x)
        print(f"Test forward pass: {x.shape} -> {logits.shape} ✓")