waghmareps12 committed
Commit 81c887b · verified · 1 Parent(s): 0c721a5

Upload 4 files

Files changed (4)
  1. README.md +79 -2
  2. inference.py +86 -0
  3. model.py +200 -0
  4. train.py +201 -0
README.md CHANGED
@@ -1,3 +1,80 @@
- ---
- license: mit
+ # 🚀 Small Language Model Training
+
+ This project implements a **125M parameter language model** optimized for training on consumer hardware with limited VRAM (4GB+). It includes efficient training with gradient accumulation and length-based batch scheduling.
+
+ ## 📂 Project Structure
+ ```
+ │── model.py         # Transformer-based language model (125M params)
+ │── train.py         # Training script with memory optimizations
+ │── inference.py     # Text generation script
+ │── requirements.txt # Required dependencies
+ │── README.md        # Project documentation
+ ```
+
+ ## 📌 Features
+ - **Memory-Efficient Transformer Model** (~125M parameters)
+ - **Length-Based Batch Scheduling** for efficient training
+ - **Gradient Accumulation** for effective larger batch sizes (see the sketch after this list)
+ - **Autoregressive Text Generation**
+ - **Wikitext-2 Dataset Integration**
+
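The gradient-accumulation pattern used here, in miniature (an illustrative sketch only; the tiny linear model stands in for the real transformer, and the full loop lives in `train.py` below):

```python
import torch
import torch.nn.functional as F

model = torch.nn.Linear(8, 1)  # stand-in for the 125M-parameter model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
accum_steps = 4                # micro-batches per optimizer step

optimizer.zero_grad()
for step in range(8):
    x, y = torch.randn(4, 8), torch.randn(4, 1)  # micro-batch of 4
    loss = F.mse_loss(model(x), y)
    (loss / accum_steps).backward()   # gradients accumulate across calls
    if (step + 1) % accum_steps == 0:
        optimizer.step()              # effective batch size: 4 x 4 = 16
        optimizer.zero_grad()
```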
+ ## 🛠 Installation
+ Install dependencies:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ## 🎯 Training the Model
+ Run the training script:
+ ```bash
+ python train.py
+ ```
+
+ The training process includes (the batching trick is sketched after this list):
+ - Automatic GPU/CPU device selection
+ - Dynamic batch scheduling by sequence length
+ - Gradient accumulation (effective batch size: 16)
+ - Automatic checkpointing after every epoch
+ - Cosine learning rate scheduling
+
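The "dynamic batch scheduling" above works by grouping sequences of identical length so every batch can be stacked without padding. A minimal sketch of the idea with toy sequences (`train.py` implements the real version in `TextDataset` and `BatchSchedulerSampler`):

```python
from collections import defaultdict

# Toy token sequences of varying lengths.
sequences = [[1, 2, 3], [4, 5], [6, 7, 8], [9, 10], [11, 12, 13]]

groups = defaultdict(list)
for seq in sequences:
    groups[len(seq)].append(seq)  # bucket by exact length

# Each bucket yields padding-free batches of same-shape tensors.
for length, bucket in sorted(groups.items()):
    print(f"length {length}: {len(bucket)} sequence(s)")
# length 2: 2 sequence(s)
# length 3: 3 sequence(s)
```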
+ ## 📝 Inference
+ Generate text using the trained model:
+ ```bash
+ python inference.py
+ ```
+
+ ## 🏗 Model Architecture
+ - **Layers:** 12 transformer blocks
+ - **Attention Heads:** 12 heads
+ - **Embedding Dimension:** 768
+ - **Context Window:** 512 tokens
+ - **Total Parameters:** ~125M (a back-of-the-envelope check follows this list)
+ - **Activation:** GELU
+ - **Layer Normalization:** Pre-norm architecture
+
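A back-of-the-envelope check of the ~125M figure, assuming the GPT-2 vocabulary (50,257 tokens) and the weight tying used in `model.py` (biases and layer norms add well under 1%):

```python
vocab, ctx, d, layers = 50257, 512, 768, 12

embeddings = vocab * d + ctx * d         # token + position embeddings
per_block = 4 * d * d + 8 * d * d        # attention (QKV + proj) + 4x MLP
total = embeddings + layers * per_block  # lm_head is tied to the token embedding

print(f"~{total / 1e6:.1f}M parameters")  # ~123.9M
```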
+ ## ⚡ Performance Optimizations
+ - ✅ Length-based batch scheduling
+ - ✅ Gradient accumulation (4 steps)
+ - ✅ Efficient memory usage
+ - ✅ Optimized for 4GB VRAM GPUs
+ - ✅ Same-length batches, so no padding overhead
+
+ ## 🔧 Training Configuration
+ - **Batch Size:** 4 (16 effective with gradient accumulation)
+ - **Learning Rate:** 3e-4 with cosine decay (formula after this list)
+ - **Weight Decay:** 0.1
+ - **Training Data:** Wikitext-2
+ - **Epochs:** 3
+
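For reference, `CosineAnnealingLR` decays the learning rate from $\eta_{\max} = 3 \times 10^{-4}$ toward $\eta_{\min} = 10^{-5}$ over $T_{\max}$ optimizer steps (this is the standard PyTorch schedule, shown with the values configured in `train.py`):

$$
\eta_t \;=\; \eta_{\min} + \tfrac{1}{2}\,(\eta_{\max} - \eta_{\min})\Bigl(1 + \cos\tfrac{t\,\pi}{T_{\max}}\Bigr)
$$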
+ ## 📊 Memory Usage
+ - **GPU VRAM:** ~3.5GB peak
+ - **Recommended GPU:** 4GB+ VRAM
+ - **CPU RAM:** ~8GB recommended
+
+ ## 📜 License
+ This project is licensed under the MIT License.
+
  ---
+ 🚀 Happy Training! Feel free to contribute or raise issues. 🎯
inference.py ADDED
@@ -0,0 +1,86 @@
+ import torch
+ from transformers import AutoTokenizer
+ from model import SmallLanguageModel, ModelConfig
+
+ def create_model_config(vocab_size):
+     """Create model configuration matching training"""
+     return ModelConfig(
+         vocab_size=vocab_size,
+         block_size=512,
+         n_layer=12,
+         n_head=12,
+         n_embd=768,
+         dropout=0.1,
+         bias=True
+     )
+
+ def generate_text(prompt, model, tokenizer, max_length=100, temperature=0.8, top_k=50):
+     model.eval()
+     device = next(model.parameters()).device  # infer device from the model
+     input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+
+     with torch.no_grad():
+         for _ in range(max_length):
+             # Crop the context to the model's block size before each step
+             idx_cond = input_ids[:, -model.config.block_size:]
+
+             # Get model predictions
+             outputs = model(idx_cond)
+             next_token_logits = outputs[:, -1, :] / temperature
+
+             # Apply top-k filtering
+             top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)
+             next_token_logits[0, :] = float('-inf')
+             next_token_logits[0, top_k_indices[0]] = top_k_logits[0]
+
+             # Sample from the filtered distribution
+             probs = torch.softmax(next_token_logits, dim=-1)
+             next_token = torch.multinomial(probs, num_samples=1)
+
+             # Append to input_ids
+             input_ids = torch.cat([input_ids, next_token], dim=-1)
+
+             # Stop if we generate the EOS token
+             if next_token[0].item() == tokenizer.eos_token_id:
+                 break
+
+     return tokenizer.decode(input_ids[0], skip_special_tokens=True)
+
+ if __name__ == "__main__":
+     # Load tokenizer
+     tokenizer = AutoTokenizer.from_pretrained("gpt2")
+     tokenizer.pad_token = tokenizer.eos_token
+
+     # Setup device
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     print(f"Using device: {device}")
+
+     # Create and load model
+     config = create_model_config(tokenizer.vocab_size)
+     model = SmallLanguageModel(config).to(device)
+
+     # Load trained weights
+     try:
+         checkpoint = torch.load("small_language_model.pt", map_location=device)
+         model.load_state_dict(checkpoint)
+         print("Loaded model from small_language_model.pt")
+     except FileNotFoundError:
+         print("No saved model found. Please train the model first.")
+         exit(1)
+
+     # Generate some example texts
+     prompts = [
+         "Once upon a time",
+         "The meaning of life is",
+         "In the distant future",
+         "The best way to learn programming is"
+     ]
+
+     print("\nGenerating text samples:\n")
+     for prompt in prompts:
+         print(f"Prompt: {prompt}")
+         generated_text = generate_text(
+             prompt,
+             model,
+             tokenizer,
+             max_length=100,
+             temperature=0.8,
+             top_k=50
+         )
+         print(f"Generated: {generated_text}\n")
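To make the top-k masking in `generate_text` concrete, here is a tiny standalone illustration with toy numbers (not model output):

```python
import torch

logits = torch.tensor([[2.0, -1.0, 0.5, 3.0, 0.0, 1.5]])  # toy 6-token vocab
top_k = 3

top_logits, top_idx = torch.topk(logits, top_k, dim=-1)
filtered = torch.full_like(logits, float('-inf'))  # mask everything...
filtered[0, top_idx[0]] = top_logits[0]            # ...then restore the top k

probs = torch.softmax(filtered, dim=-1)
print(probs)  # non-zero probability only at indices 3, 0, 5
```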
model.py ADDED
@@ -0,0 +1,200 @@
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class LayerNorm(nn.Module):
+     def __init__(self, ndim, bias=True):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(ndim))
+         self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
+
+     def forward(self, x):
+         return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5)
+
+ class MultiHeadAttention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.n_head = config.n_head
+         self.head_dim = config.n_embd // config.n_head
+
+         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
+         self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
+         self.attn_dropout = nn.Dropout(config.dropout)
+         self.resid_dropout = nn.Dropout(config.dropout)
+
+         # Causal mask: lower-triangular so each position attends only to the past
+         self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
+                                      .view(1, 1, config.block_size, config.block_size))
+
+     def forward(self, x):
+         B, T, C = x.size()  # batch, sequence length, embedding dim
+
+         # calculate query, key, values
+         q, k, v = self.c_attn(x).split(self.config.n_embd, dim=2)
+         k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
+         q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
+         v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
+
+         # causal self-attention
+         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+         att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
+         att = F.softmax(att, dim=-1)
+         att = self.attn_dropout(att)
+         y = att @ v
+         y = y.transpose(1, 2).contiguous().view(B, T, C)
+
+         return self.resid_dropout(self.c_proj(y))
+
+ class MLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
+         self.gelu = nn.GELU()
+         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
+         self.dropout = nn.Dropout(config.dropout)
+
+     def forward(self, x):
+         x = self.c_fc(x)
+         x = self.gelu(x)
+         x = self.c_proj(x)
+         x = self.dropout(x)
+         return x
+
+ class Block(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
+         self.attn = MultiHeadAttention(config)
+         self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
+         self.mlp = MLP(config)
+
+     def forward(self, x):
+         # Pre-norm residual blocks
+         x = x + self.attn(self.ln_1(x))
+         x = x + self.mlp(self.ln_2(x))
+         return x
+
+ class ModelConfig:
+     def __init__(self, vocab_size=50257, block_size=1024, n_layer=24, n_head=16,
+                  n_embd=1024, dropout=0.1, bias=True):
+         self.vocab_size = vocab_size
+         self.block_size = block_size
+         self.n_layer = n_layer
+         self.n_head = n_head
+         self.n_embd = n_embd
+         self.dropout = dropout
+         self.bias = bias
+
+ def count_parameters(model):
+     """Count number of trainable parameters in the model"""
+     total = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+     # Calculate parameters for each component
+     embedding_params = model.transformer.wte.weight.numel() + model.transformer.wpe.weight.numel()
+
+     attention_params = 0
+     mlp_params = 0
+     layer_norm_params = 0
+
+     for block in model.transformer.h:
+         # Attention parameters
+         attention_params += (
+             block.attn.c_attn.weight.numel() +
+             (block.attn.c_attn.bias.numel() if block.attn.c_attn.bias is not None else 0) +
+             block.attn.c_proj.weight.numel() +
+             (block.attn.c_proj.bias.numel() if block.attn.c_proj.bias is not None else 0)
+         )
+
+         # MLP parameters
+         mlp_params += (
+             block.mlp.c_fc.weight.numel() +
+             (block.mlp.c_fc.bias.numel() if block.mlp.c_fc.bias is not None else 0) +
+             block.mlp.c_proj.weight.numel() +
+             (block.mlp.c_proj.bias.numel() if block.mlp.c_proj.bias is not None else 0)
+         )
+
+         # Layer norm parameters
+         layer_norm_params += (
+             block.ln_1.weight.numel() +
+             (block.ln_1.bias.numel() if block.ln_1.bias is not None else 0) +
+             block.ln_2.weight.numel() +
+             (block.ln_2.bias.numel() if block.ln_2.bias is not None else 0)
+         )
+
+     # Final layer norm
+     layer_norm_params += (
+         model.transformer.ln_f.weight.numel() +
+         (model.transformer.ln_f.bias.numel() if model.transformer.ln_f.bias is not None else 0)
+     )
+
+     # Print detailed breakdown
+     print("\nParameter Count Breakdown:")
+     print(f"Embeddings: {embedding_params:,} parameters")
+     print(f"Attention Layers: {attention_params:,} parameters")
+     print(f"MLP Layers: {mlp_params:,} parameters")
+     print(f"Layer Normalization: {layer_norm_params:,} parameters")
+     print(f"Total: {total:,} parameters")
+
+     return total
+
+ class SmallLanguageModel(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+
+         self.transformer = nn.ModuleDict(dict(
+             wte = nn.Embedding(config.vocab_size, config.n_embd),
+             wpe = nn.Embedding(config.block_size, config.n_embd),
+             drop = nn.Dropout(config.dropout),
+             h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+             ln_f = LayerNorm(config.n_embd, bias=config.bias),
+         ))
+
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+         self.transformer.wte.weight = self.lm_head.weight  # weight tying
+
+         # Initialize weights
+         self.apply(self._init_weights)
+
+         print("\nModel Configuration:")
+         print(f"Layers: {config.n_layer}")
+         print(f"Heads: {config.n_head}")
+         print(f"Embedding Dimension: {config.n_embd}")
+         print(f"Context Window: {config.block_size}")
+
+         count_parameters(self)
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def forward(self, input_ids, targets=None):
+         device = input_ids.device
+         b, t = input_ids.size()
+         pos = torch.arange(0, t, dtype=torch.long, device=device)
+
+         # forward the model
+         tok_emb = self.transformer.wte(input_ids)
+         pos_emb = self.transformer.wpe(pos)
+         x = self.transformer.drop(tok_emb + pos_emb)
+
+         for block in self.transformer.h:
+             x = block(x)
+
+         x = self.transformer.ln_f(x)
+         logits = self.lm_head(x)
+
+         if targets is not None:
+             # Reshape logits and targets for loss calculation
+             logits = logits.reshape(-1, logits.size(-1))
+             targets = targets.reshape(-1)
+             loss = F.cross_entropy(logits, targets)
+             return logits, loss
+
+         return logits
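A quick smoke test of the model above, using a deliberately tiny configuration so it runs on CPU in seconds (the sizes here are illustrative, not the training config):

```python
import torch
from model import SmallLanguageModel, ModelConfig

config = ModelConfig(vocab_size=100, block_size=32, n_layer=2, n_head=2, n_embd=64)
model = SmallLanguageModel(config)

tokens = torch.randint(0, 100, (1, 16))  # batch of 1, sequence of 16
logits = model(tokens)                   # no targets -> logits only
print(logits.shape)                      # torch.Size([1, 16, 100])

_, loss = model(tokens[:, :-1], tokens[:, 1:])  # shifted targets -> loss
print(loss.item())                              # roughly ln(100) ~ 4.6 at init
```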
train.py ADDED
@@ -0,0 +1,201 @@
+ import torch
+ import torch.optim as optim
+ from torch.utils.data import DataLoader, Dataset
+ from transformers import AutoTokenizer
+ from datasets import load_dataset
+ from model import SmallLanguageModel, ModelConfig
+ import random
+
+ def create_model_config(vocab_size):
+     """Create a ~125M parameter model configuration"""
+     return ModelConfig(
+         vocab_size=vocab_size,
+         block_size=512,   # Reduced from 1024
+         n_layer=12,       # Reduced from 24
+         n_head=12,        # Reduced from 16
+         n_embd=768,       # Reduced from 1024
+         dropout=0.1,
+         bias=True
+     )
+
+ def setup_training():
+     # Load tokenizer
+     tokenizer = AutoTokenizer.from_pretrained("gpt2")
+     tokenizer.pad_token = tokenizer.eos_token
+
+     # Create model configuration
+     config = create_model_config(tokenizer.vocab_size)
+
+     # Initialize model
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model = SmallLanguageModel(config).to(device)
+
+     return model, tokenizer, device
+
+ class TextDataset(Dataset):
+     def __init__(self, tokenized_texts, block_size, tokenizer):
+         self.examples = []
+         self.block_size = block_size
+         self.tokenizer = tokenizer
+
+         # Group texts by exact length
+         self.length_groups = {}  # Keep as instance variable
+
+         for text in tokenized_texts["input_ids"]:
+             if len(text) > 1:  # Ensure text is at least 2 tokens
+                 # Truncate if longer than block_size + 1
+                 if len(text) > block_size + 1:
+                     text = text[:block_size + 1]
+
+                 length = len(text)
+                 if length not in self.length_groups:
+                     self.length_groups[length] = []
+                 self.length_groups[length].append(torch.tensor(text, dtype=torch.long))
+
+         # Sort lengths for more efficient batching
+         self.lengths = sorted(self.length_groups.keys())
+
+         # Create index mapping
+         self.length_to_idx = {}
+         start_idx = 0
+         for length in self.lengths:
+             group = self.length_groups[length]
+             self.length_to_idx[length] = (start_idx, start_idx + len(group))
+             start_idx += len(group)
+             self.examples.extend(group)
+
+         print(f"Created {len(self.examples)} sequences across {len(self.lengths)} different lengths")
+
+     def __len__(self):
+         return len(self.examples)
+
+     def __getitem__(self, idx):
+         return self.examples[idx]
+
+ class BatchSchedulerSampler(torch.utils.data.Sampler):
+     """Samples batches according to sequence length"""
+     def __init__(self, dataset, batch_size):
+         super().__init__(dataset)
+         self.dataset = dataset
+         self.batch_size = batch_size
+
+         # Create batches for each length
+         self.batches = []
+         for length in dataset.lengths:
+             start_idx, end_idx = dataset.length_to_idx[length]
+             # Create batches of indices for this length
+             indices = list(range(start_idx, end_idx))
+             for i in range(0, len(indices), batch_size):
+                 self.batches.append(indices[i:i + batch_size])
+
+     def __iter__(self):
+         # Shuffle batches
+         random.shuffle(self.batches)
+         for batch in self.batches:
+             yield batch
+
+     def __len__(self):
+         return len(self.batches)
+
+ def prepare_dataset(tokenizer, block_size):
+     # Load and tokenize dataset
+     dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
+
+     def tokenize_function(examples):
+         # Drop empty/whitespace-only strings before tokenizing
+         texts = [text for text in examples["text"] if len(text.strip()) > 0]
+         return tokenizer(texts, truncation=False, padding=False)
+
+     tokenized_dataset = dataset.map(
+         tokenize_function,
+         batched=True,
+         remove_columns=dataset["train"].column_names,
+         desc="Tokenizing texts"
+     )
+
+     # Create training dataset with tokenizer
+     train_dataset = TextDataset(tokenized_dataset["train"], block_size=block_size, tokenizer=tokenizer)
+     print(f"Created dataset with {len(train_dataset)} examples")
+     return train_dataset
+
+ def collate_batch(batch):
+     # All tensors in a batch are the same length, so stacking needs no padding
+     return torch.stack(batch)
+
+ def train_model(model, train_loader, optimizer, scheduler, device, num_epochs=3, gradient_accumulation_steps=4):
+     model.train()
+     for epoch in range(num_epochs):
+         total_loss = 0
+         optimizer.zero_grad()  # Zero gradients at start of epoch
+
+         for batch_idx, batch in enumerate(train_loader):
+             batch = batch.to(device)
+
+             # Get input_ids and targets (shifted by one token)
+             input_ids = batch[:, :-1].contiguous()
+             targets = batch[:, 1:].contiguous()
+
+             # Forward pass
+             logits, loss = model(input_ids, targets)
+
+             # Scale loss for gradient accumulation
+             loss = loss / gradient_accumulation_steps
+             loss.backward()
+
+             # Update weights every gradient_accumulation_steps
+             if (batch_idx + 1) % gradient_accumulation_steps == 0:
+                 torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+                 optimizer.step()
+                 scheduler.step()
+                 optimizer.zero_grad()
+
+             total_loss += loss.item() * gradient_accumulation_steps
+
+             if batch_idx % 10 == 0:
+                 print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item() * gradient_accumulation_steps:.4f}, LR: {scheduler.get_last_lr()[0]:.6f}")
+
+         avg_loss = total_loss / len(train_loader)
+         print(f"Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}")
+
+         # Save checkpoint
+         torch.save({
+             'epoch': epoch,
+             'model_state_dict': model.state_dict(),
+             'optimizer_state_dict': optimizer.state_dict(),
+             'loss': avg_loss,
+         }, f'checkpoint_epoch_{epoch+1}.pt')
+
+ def main():
+     # Setup
+     model, tokenizer, device = setup_training()
+
+     # Prepare dataset
+     train_dataset = prepare_dataset(tokenizer, model.config.block_size)
+
+     # Use custom batch sampler instead of shuffle
+     train_loader = DataLoader(
+         train_dataset,
+         batch_sampler=BatchSchedulerSampler(train_dataset, batch_size=4),  # Reduced batch size from 8 to 4
+         collate_fn=collate_batch,
+         num_workers=4
+     )
+
+     # Training setup with gradient accumulation
+     optimizer = optim.AdamW(model.parameters(),
+                             lr=3e-4,
+                             weight_decay=0.1)
+
+     # Learning rate scheduler: T_max counts optimizer steps, which happen
+     # once every 4 batches because of gradient accumulation
+     scheduler = optim.lr_scheduler.CosineAnnealingLR(
+         optimizer,
+         T_max=(len(train_loader) * 3) // 4,  # 3 epochs of accumulated steps
+         eta_min=1e-5
+     )
+
+     # Train the model
+     train_model(model, train_loader, optimizer, scheduler, device)
+
+     # Save the final model
+     torch.save(model.state_dict(), "small_language_model.pt")
+
+ if __name__ == "__main__":
+     main()
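The per-epoch checkpoints store optimizer state alongside the weights, so an interrupted run can be resumed. A minimal resume sketch, assuming `checkpoint_epoch_1.pt` was written by a previous run:

```python
import torch
from train import setup_training

model, tokenizer, device = setup_training()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)

# Restore weights and optimizer state from the saved checkpoint dict.
checkpoint = torch.load("checkpoint_epoch_1.pt", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

start_epoch = checkpoint['epoch'] + 1
print(f"Resuming from epoch {start_epoch}, last avg loss {checkpoint['loss']:.4f}")
```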