AndreCosta commited on
Commit
7bcb88f
Β·
verified Β·
1 Parent(s): 4915795

Upload training_loop.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training_loop.py +1164 -0
training_loop.py ADDED
@@ -0,0 +1,1164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ training_loop.py
3
+ ================
4
+ Custom training loop for the MiniLM model.
5
+
6
+ This module is part of the project:
7
+ "A bilingual PT+EN LLM with BPE tokenizer and training loop
8
+ implemented from scratch, with didactic and documented code"
9
+
10
+ Author : AndrΓ© Costa
11
+ License : MIT
12
+
13
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
14
+ THEORETICAL BACKGROUND
15
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
16
+
17
+ The training objective
18
+ -----------------------
19
+ Training an LLM is an optimization problem: we want to find the
20
+ weights ΞΈ that minimize the average loss over the corpus:
21
+
22
+ L(ΞΈ) = -1/N Ξ£ log P(t_i | t_1, ..., t_{i-1}; ΞΈ)
23
+
24
+ In other words: maximize the probability the model assigns to the
25
+ correct next token given the previous context. This is called
26
+ "Language Modeling" or "next-token prediction".
27
+
28
+ The standard metric is Perplexity (PPL):
29
+ PPL = exp(L)
30
+
31
+ Intuitively, perplexity measures "how many words the model considers
32
+ equally likely at each step". PPL = 10 means the model is, on average,
33
+ as uncertain as if it were choosing between 10 equally probable options.
34
+
35
+ Stochastic Gradient Descent (SGD)
36
+ -----------------------------------
37
+ Instead of computing the gradient over the entire corpus (infeasible),
38
+ we use mini-batches: random samples of B sequences per step.
39
+
40
+ ΞΈ ← ΞΈ - Ξ· Γ— βˆ‡_ΞΈ L(batch)
41
+
42
+ where Ξ· is the learning rate.
43
+
44
+ AdamW Optimizer (Loshchilov & Hutter, 2019)
45
+ ---------------------------------------------
46
+ AdamW combines two insights:
47
+ 1. Adam: adaptive per-parameter learning rate using first and
48
+ second order gradient moments
49
+ 2. Decoupled weight decay: L2 regularization applied directly
50
+ to weights, without interfering with Adam's adaptation
51
+
52
+ m_t = Ξ²1 Γ— m_{t-1} + (1-Ξ²1) Γ— g_t (1st order moment)
53
+ v_t = Ξ²2 Γ— v_{t-1} + (1-Ξ²2) Γ— g_tΒ² (2nd order moment)
54
+ ΞΈ_t = ΞΈ_{t-1} - Ξ· Γ— mΜ‚_t / (√vΜ‚_t + Ξ΅) - Ξ· Γ— Ξ» Γ— ΞΈ_{t-1}
55
+
56
+ Typical values: Ξ²1=0.9, Ξ²2=0.95, Ξ΅=1e-8, Ξ»=0.1
57
+
58
+ Cosine Learning Rate Schedule with Warmup
59
+ -------------------------------------------
60
+ The learning rate is not constant β€” it varies throughout training:
61
+
62
+ Phase 1 β€” Linear warmup (first ~2% of steps):
63
+ lr grows linearly from 0 to lr_max
64
+ Avoids instability at the start when weights are random
65
+
66
+ Phase 2 β€” Cosine decay:
67
+ lr decays smoothly from lr_max to lr_min
68
+ lr(t) = lr_min + 0.5 Γ— (lr_max - lr_min) Γ— (1 + cos(Ο€ Γ— t/T))
69
+
70
+ Cosine decay is preferable to linear because:
71
+ - Decays slowly at the start (still much to learn)
72
+ - Decays faster in the middle
73
+ - Stabilizes near the end (fine-grained refinement)
74
+
75
+ Gradient Clipping
76
+ ------------------
77
+ Limits the gradient norm to a maximum value (typically 1.0):
78
+ if ||g|| > max_norm:
79
+ g ← g Γ— max_norm / ||g||
80
+
81
+ Prevents "gradient explosion" β€” situations where the gradient grows
82
+ uncontrollably, causing destructive weight updates.
83
+ Especially important at the start of training.
84
+
85
+ Gradient Accumulation
86
+ ----------------------
87
+ Simulates larger batch sizes without increasing VRAM usage:
88
+ - Instead of one step with batch=32, do 4 steps with batch=8
89
+ - Accumulate gradients across the 4 steps (without optimizer.step())
90
+ - Apply the update after the 4th step
91
+
92
+ effective_batch_size = batch_size Γ— accumulation_steps
93
+
94
+ Useful for the RTX 4060 Ti (16GB), where physical batch size is limited.
95
+
96
+ Mixed Precision Training (bf16)
97
+ ---------------------------------
98
+ Uses bfloat16 (16 bits) instead of float32 to:
99
+ - Reduce VRAM usage by half
100
+ - Speed up computation (bf16 ops are ~2x faster on modern GPUs)
101
+
102
+ bf16 vs fp16:
103
+ - fp16: range 6Γ—10⁻⁡ to 65504 β†’ risk of overflow/underflow
104
+ - bf16: same range as fp32 β†’ more stable, no grad scaling needed
105
+
106
+ The RTX 4060 Ti natively supports bf16 β€” always use it.
107
+
108
+ References:
109
+ - Loshchilov, I., & Hutter, F. (2019). Decoupled weight decay
110
+ regularization. ICLR 2019.
111
+ - Loshchilov, I., & Hutter, F. (2017). SGDR: Stochastic gradient
112
+ descent with warm restarts. ICLR 2017.
113
+ - Micikevicius, P. et al. (2018). Mixed precision training. ICLR 2018.
114
+ """
115
+
116
+ import os
117
+ import math
118
+ import time
119
+ import json
120
+ import logging
121
+ from pathlib import Path
122
+ from dataclasses import dataclass, field
123
+ from typing import Optional
124
+
125
+ import torch
126
+ import torch.nn as nn
127
+ from torch.utils.data import DataLoader
128
+
129
+ # Project modules
130
+ from transformer import MiniLM, ModelConfig
131
+ from data_pipeline import CorpusDataset
132
+
133
+
134
+ # ─────────────────────────────────────────────────────────────
135
+ # Training configuration
136
+ # ��────────────────────────────────────────────────────────────
137
+
138
+ @dataclass
139
+ class TrainingConfig:
140
+ """
141
+ Training hyperparameters and settings.
142
+
143
+ Separating training configuration from model configuration
144
+ allows experimenting with different optimization regimes using
145
+ the same architecture, and vice versa.
146
+
147
+ Fields:
148
+ # Paths
149
+ corpus_dir: Directory of the pre-processed corpus.
150
+ checkpoint_dir: Where to save checkpoints during training.
151
+ model_config_path: Path to save/load the model config.
152
+
153
+ # Optimization
154
+ lr_max: Maximum (peak) learning rate.
155
+ Typical values for LLMs: 3e-4 to 6e-4.
156
+ lr_min: Minimum learning rate (end of cosine decay).
157
+ Typically lr_max / 10.
158
+ weight_decay: Decoupled L2 regularization in AdamW.
159
+ beta1, beta2: Adam moments. Ξ²2=0.95 is more conservative
160
+ than the default 0.999 β€” more stable for LLMs.
161
+ grad_clip: Maximum gradient norm.
162
+
163
+ # Batch and accumulation
164
+ batch_size: Sequences per GPU step.
165
+ accum_steps: Gradient accumulation steps.
166
+ effective_batch = batch_size Γ— accum_steps.
167
+
168
+ # Schedule
169
+ warmup_steps: Linear warmup steps.
170
+ max_steps: Total optimization steps.
171
+ None = train for 1 full epoch.
172
+
173
+ # Logging and checkpoints
174
+ log_interval: How often (in steps) to log metrics.
175
+ eval_interval: How often (in steps) to evaluate on val set.
176
+ save_interval: How often (in steps) to save a checkpoint.
177
+ eval_steps: How many batches to use for evaluation.
178
+
179
+ # Hardware
180
+ dtype: Data type for mixed precision.
181
+ "bfloat16" for RTX 4060 Ti (recommended).
182
+ compile_model: If True, uses torch.compile() for ~20% speedup.
183
+ num_workers: DataLoader workers for parallel data loading.
184
+ """
185
+ # Paths
186
+ corpus_dir: str = "./corpus"
187
+ checkpoint_dir: str = "./checkpoints"
188
+ model_config_path: str = "./model_config.json"
189
+
190
+ # Optimization
191
+ lr_max: float = 3e-4
192
+ lr_min: float = 3e-5
193
+ weight_decay: float = 0.1
194
+ beta1: float = 0.9
195
+ beta2: float = 0.95
196
+ grad_clip: float = 1.0
197
+
198
+ # Batch
199
+ batch_size: int = 8 # adjust according to available VRAM
200
+ accum_steps: int = 4 # effective_batch = 32
201
+
202
+ # Schedule
203
+ warmup_steps: int = 200
204
+ max_steps: Optional[int] = None # None = 1 full epoch
205
+
206
+ # Logging
207
+ log_interval: int = 10
208
+ eval_interval: int = 200
209
+ save_interval: int = 500
210
+ eval_steps: int = 50
211
+
212
+ # Hardware
213
+ dtype: str = "bfloat16"
214
+ compile_model: bool = True
215
+ num_workers: int = 4
216
+
217
+ @property
218
+ def effective_batch_size(self) -> int:
219
+ """Effective batch size after gradient accumulation."""
220
+ return self.batch_size * self.accum_steps
221
+
222
+ def save(self, path: str) -> None:
223
+ with open(path, "w", encoding="utf-8") as f:
224
+ json.dump(self.__dict__, f, indent=2)
225
+
226
+ @classmethod
227
+ def load(cls, path: str) -> "TrainingConfig":
228
+ with open(path, "r", encoding="utf-8") as f:
229
+ data = json.load(f)
230
+ config = cls()
231
+ for key, value in data.items():
232
+ setattr(config, key, value)
233
+ return config
234
+
235
+
236
+ # ─────────────────────────────────────────────────────────────
237
+ # Learning Rate Schedule
238
+ # ─────────────────────────────────────────────────────────────
239
+
240
+ def get_lr(
241
+ step: int,
242
+ warmup_steps: int,
243
+ max_steps: int,
244
+ lr_max: float,
245
+ lr_min: float,
246
+ ) -> float:
247
+ """
248
+ Compute the learning rate for the current step.
249
+
250
+ Implements the standard LLM schedule:
251
+ - Linear warmup from 0 β†’ lr_max over the first `warmup_steps`
252
+ - Cosine decay from lr_max β†’ lr_min until `max_steps`
253
+
254
+ Cosine decay is derived from the work of Loshchilov & Hutter (2017)
255
+ on SGDR (Stochastic Gradient Descent with Restarts).
256
+ Here we use only half a cycle (no restarts).
257
+
258
+ Args:
259
+ step: Current optimization step (starts at 0).
260
+ warmup_steps: Duration of the linear warmup.
261
+ max_steps: Total training steps.
262
+ lr_max: Maximum learning rate (warmup peak).
263
+ lr_min: Minimum learning rate (cosine end).
264
+
265
+ Returns:
266
+ Learning rate for the current step.
267
+
268
+ Example curve (warmup=100, max=1000, lr_max=3e-4, lr_min=3e-5):
269
+ step=0: lr = 0.0
270
+ step=50: lr = 1.5e-4 (midpoint of warmup)
271
+ step=100: lr = 3e-4 (peak)
272
+ step=550: lr = 1.65e-4 (midpoint of cosine)
273
+ step=1000: lr = 3e-5 (end)
274
+ """
275
+ # Phase 1: linear warmup
276
+ if step < warmup_steps:
277
+ return lr_max * (step + 1) / warmup_steps
278
+
279
+ # Beyond max_steps: hold lr_min
280
+ if step >= max_steps:
281
+ return lr_min
282
+
283
+ # Phase 2: cosine decay
284
+ # Normalize progress after warmup to [0, 1]
285
+ progress = (step - warmup_steps) / (max_steps - warmup_steps)
286
+
287
+ # Half-cosine decay formula
288
+ cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
289
+
290
+ return lr_min + cosine_decay * (lr_max - lr_min)
291
+
292
+
293
+ # ─────────────────────────────────────────────────────────────
294
+ # Metrics and logging
295
+ # ─────────────────────────────────────────────────────────────
296
+
297
+ class MetricsTracker:
298
+ """
299
+ Track and record training metrics.
300
+
301
+ Maintains a full history of loss and perplexity for
302
+ post-training analysis and learning curve generation.
303
+
304
+ Perplexity (PPL) is the main metric for LLMs:
305
+ PPL = exp(cross_entropy_loss)
306
+
307
+ Interpretation:
308
+ PPL = 1: perfect model (impossible in practice)
309
+ PPL = 10: good for small models on general text
310
+ PPL = 50: reasonable for very small models
311
+ PPL = 100+: model still learning / difficult corpus
312
+ """
313
+
314
+ def __init__(self, log_dir: str):
315
+ """
316
+ Initialize the tracker and configure the logger.
317
+
318
+ Args:
319
+ log_dir: Directory where logs and metrics will be saved.
320
+ """
321
+ os.makedirs(log_dir, exist_ok=True)
322
+ self.log_dir = log_dir
323
+
324
+ # Full history for post-training analysis
325
+ self.history: list[dict] = []
326
+
327
+ # Accumulators for moving average
328
+ self._loss_accum = 0.0
329
+ self._accum_count = 0
330
+
331
+ # Configure logger to write to both file and console
332
+ self.logger = logging.getLogger("MiniLM")
333
+ self.logger.setLevel(logging.INFO)
334
+
335
+ # File handler
336
+ fh = logging.FileHandler(os.path.join(log_dir, "training.log"))
337
+ fh.setFormatter(logging.Formatter("%(asctime)s | %(message)s"))
338
+
339
+ # Console handler
340
+ ch = logging.StreamHandler()
341
+ ch.setFormatter(logging.Formatter("%(message)s"))
342
+
343
+ self.logger.addHandler(fh)
344
+ self.logger.addHandler(ch)
345
+
346
+ def update(self, loss: float) -> None:
347
+ """Accumulate loss for average computation."""
348
+ self._loss_accum += loss
349
+ self._accum_count += 1
350
+
351
+ def log_step(
352
+ self,
353
+ step: int,
354
+ lr: float,
355
+ tokens_per_sec: float,
356
+ split: str = "train",
357
+ ) -> dict:
358
+ """
359
+ Record metrics for the current step.
360
+
361
+ Args:
362
+ step: Current step.
363
+ lr: Current learning rate.
364
+ tokens_per_sec: Token throughput per second.
365
+ split: "train" or "val".
366
+
367
+ Returns:
368
+ Dictionary with the recorded metrics.
369
+ """
370
+ avg_loss = self._loss_accum / max(self._accum_count, 1)
371
+ ppl = math.exp(min(avg_loss, 20)) # clamp to avoid overflow
372
+
373
+ metrics = {
374
+ "step": step,
375
+ "split": split,
376
+ "loss": round(avg_loss, 4),
377
+ "perplexity": round(ppl, 2),
378
+ "lr": f"{lr:.2e}",
379
+ "tokens_per_sec": int(tokens_per_sec),
380
+ }
381
+
382
+ self.history.append(metrics)
383
+
384
+ # Format log line
385
+ self.logger.info(
386
+ f"step {step:>6} | {split:<5} | "
387
+ f"loss {avg_loss:.4f} | ppl {ppl:.2f} | "
388
+ f"lr {lr:.2e} | {tokens_per_sec:.0f} tok/s"
389
+ )
390
+
391
+ # Reset accumulators
392
+ self._loss_accum = 0.0
393
+ self._accum_count = 0
394
+
395
+ return metrics
396
+
397
+ def save_history(self) -> None:
398
+ """Save the full history to JSON."""
399
+ path = os.path.join(self.log_dir, "metrics_history.json")
400
+ with open(path, "w", encoding="utf-8") as f:
401
+ json.dump(self.history, f, indent=2)
402
+
403
+
404
+ # ─────────────────────────────────────────────────────────────
405
+ # Checkpoint
406
+ # ─────────────────────────────────────────────────────────────
407
+
408
+ def save_checkpoint(
409
+ model: MiniLM,
410
+ optimizer: torch.optim.Optimizer,
411
+ step: int,
412
+ loss: float,
413
+ config: TrainingConfig,
414
+ model_config: ModelConfig,
415
+ is_best: bool = False,
416
+ ) -> None:
417
+ """
418
+ Save a full training state checkpoint.
419
+
420
+ A checkpoint includes everything needed to resume training
421
+ exactly where it left off:
422
+ - Model weights (state_dict)
423
+ - Optimizer state (accumulated Adam moments)
424
+ - Current step and best loss (for comparison)
425
+ - Model and training configurations
426
+
427
+ Checkpoint strategy:
428
+ - Saves a periodic checkpoint every `save_interval` steps
429
+ - Keeps only the 3 most recent checkpoints (saves disk space)
430
+ - Separately saves the "best checkpoint" (lowest val loss)
431
+
432
+ Args:
433
+ model: Model to save.
434
+ optimizer: Optimizer with its internal state.
435
+ step: Current step.
436
+ loss: Current validation loss.
437
+ config: Training configuration.
438
+ model_config: Architecture configuration.
439
+ is_best: If True, also saves as "best_model.pt".
440
+ """
441
+ os.makedirs(config.checkpoint_dir, exist_ok=True)
442
+
443
+ checkpoint = {
444
+ "step": step,
445
+ "loss": loss,
446
+ "model_state": model.state_dict(),
447
+ "optim_state": optimizer.state_dict(),
448
+ "model_config": model_config.__dict__,
449
+ "train_config": {k: v for k, v in config.__dict__.items()
450
+ if not callable(v)},
451
+ }
452
+
453
+ # Periodic checkpoint
454
+ ckpt_path = os.path.join(config.checkpoint_dir, f"ckpt_step_{step:07d}.pt")
455
+ torch.save(checkpoint, ckpt_path)
456
+
457
+ # Keep only the 3 most recent
458
+ ckpts = sorted(Path(config.checkpoint_dir).glob("ckpt_step_*.pt"))
459
+ for old_ckpt in ckpts[:-3]:
460
+ old_ckpt.unlink()
461
+
462
+ # Save as best model if applicable
463
+ if is_best:
464
+ best_path = os.path.join(config.checkpoint_dir, "best_model.pt")
465
+ torch.save(checkpoint, best_path)
466
+ print(f" β†’ New best model saved (loss={loss:.4f})")
467
+
468
+
469
+ def load_checkpoint(
470
+ path: str,
471
+ model: MiniLM,
472
+ optimizer: Optional[torch.optim.Optimizer] = None,
473
+ ) -> dict:
474
+ """
475
+ Load a saved checkpoint.
476
+
477
+ Args:
478
+ path: Path to the checkpoint .pt file.
479
+ model: Model to load weights into.
480
+ optimizer: Optimizer to load state into (optional).
481
+
482
+ Returns:
483
+ Dictionary with checkpoint metadata (step, loss, configs).
484
+ """
485
+ checkpoint = torch.load(path, map_location="cpu", weights_only=True)
486
+
487
+ model.load_state_dict(checkpoint["model_state"])
488
+
489
+ if optimizer is not None and "optim_state" in checkpoint:
490
+ optimizer.load_state_dict(checkpoint["optim_state"])
491
+
492
+ print(f"Checkpoint loaded: step={checkpoint['step']}, "
493
+ f"loss={checkpoint['loss']:.4f}")
494
+
495
+ return checkpoint
496
+
497
+
498
+ # ─────────────────────────────────────────────────────────────
499
+ # Evaluation
500
+ # ─────────────────────────────────────────────────────────────
501
+
502
+ @torch.no_grad()
503
+ def evaluate(
504
+ model: MiniLM,
505
+ val_loader: DataLoader,
506
+ device: torch.device,
507
+ dtype: torch.dtype,
508
+ eval_steps: int,
509
+ ) -> float:
510
+ """
511
+ Evaluate the model on the validation set.
512
+
513
+ Disables gradient computation (@torch.no_grad) to save memory
514
+ and speed up evaluation β€” during evaluation we only need the
515
+ forward pass, not the backward pass.
516
+
517
+ Loss is computed over `eval_steps` random batches from the val
518
+ set, which is sufficient for a reliable estimate without running
519
+ the full val set (which would be slow).
520
+
521
+ Args:
522
+ model: Model to evaluate.
523
+ val_loader: DataLoader for the validation set.
524
+ device: Device (cuda/cpu).
525
+ dtype: Data type for autocast.
526
+ eval_steps: How many batches to evaluate.
527
+
528
+ Returns:
529
+ Average validation loss.
530
+ """
531
+ model.eval()
532
+
533
+ total_loss = 0.0
534
+ steps_done = 0
535
+
536
+ for batch in val_loader:
537
+ if steps_done >= eval_steps:
538
+ break
539
+
540
+ # Prepare input and targets
541
+ # input_ids: all tokens except the last
542
+ # targets: all tokens except the first (shift of 1)
543
+ input_ids = batch[:, :-1].to(device)
544
+ targets = batch[:, 1:].to(device)
545
+
546
+ # Forward pass with autocast
547
+ with torch.autocast(device_type=device.type, dtype=dtype):
548
+ _, loss = model(input_ids, targets)
549
+
550
+ total_loss += loss.item()
551
+ steps_done += 1
552
+
553
+ model.train()
554
+ return total_loss / max(steps_done, 1)
555
+
556
+
557
+ # ─────────────────────────────────────────────────────────────
558
+ # Trainer β€” main class
559
+ # ─────────────────────────────────────────────────────────────
560
+
561
+ class Trainer:
562
+ """
563
+ Orchestrates the full training of MiniLM.
564
+
565
+ Responsibilities:
566
+ - Set up device, dtype and compilation
567
+ - Initialize model, optimizer and LR schedule
568
+ - Run the training loop with gradient accumulation
569
+ - Periodically evaluate on the val set
570
+ - Save checkpoints and metrics
571
+ - Resume training from a checkpoint
572
+
573
+ Basic usage:
574
+ >>> model_config = ModelConfig()
575
+ >>> train_config = TrainingConfig()
576
+ >>> trainer = Trainer(model_config, train_config)
577
+ >>> trainer.train()
578
+
579
+ Resuming training:
580
+ >>> trainer = Trainer(model_config, train_config)
581
+ >>> trainer.train(resume_from="./checkpoints/ckpt_step_0005000.pt")
582
+ """
583
+
584
+ def __init__(self, model_config: ModelConfig, train_config: TrainingConfig):
585
+ """
586
+ Initialize the Trainer.
587
+
588
+ Args:
589
+ model_config: Model architecture configuration.
590
+ train_config: Training configuration.
591
+ """
592
+ self.model_config = model_config
593
+ self.config = train_config
594
+
595
+ # ── Device ────────────────────────────────────────────────────────
596
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
597
+ print(f"Device: {self.device}")
598
+
599
+ if self.device.type == "cuda":
600
+ print(f" GPU: {torch.cuda.get_device_name(0)}")
601
+ print(f" Total VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
602
+
603
+ # ── Data type for mixed precision ──────────────────────────────────
604
+ # bf16 for RTX 4060 Ti (Ampere+), fp16 for older GPUs
605
+ if train_config.dtype == "bfloat16" and torch.cuda.is_bf16_supported():
606
+ self.dtype = torch.bfloat16
607
+ print(" Mixed precision: bfloat16 βœ“")
608
+ elif train_config.dtype == "float16":
609
+ self.dtype = torch.float16
610
+ print(" Mixed precision: float16 βœ“")
611
+ else:
612
+ self.dtype = torch.float32
613
+ print(" Mixed precision: disabled (float32)")
614
+
615
+ # ── Model ──────────────────────────────────────────────────────────
616
+ self.model = MiniLM(model_config).to(self.device)
617
+ print(f"\nModel: {self.model.count_parameters()['total'] / 1e6:.1f}M parameters")
618
+
619
+ # torch.compile() β€” JIT compilation for ~20% speedup
620
+ # Requires PyTorch 2.0+ and may take a few minutes the first time
621
+ if train_config.compile_model and hasattr(torch, "compile"):
622
+ print(" Compiling model with torch.compile()...")
623
+ self.model = torch.compile(self.model)
624
+ print(" torch.compile() βœ“")
625
+
626
+ # ── Optimizer ──────────────────────────────────────────────────────
627
+ # Weight decay should NOT be applied to:
628
+ # - Embeddings (weight decay collapses them)
629
+ # - Bias terms
630
+ # - Normalization parameters (RMSNorm.weight)
631
+ decay_params = []
632
+ no_decay_params = []
633
+
634
+ for name, param in self.model.named_parameters():
635
+ if not param.requires_grad:
636
+ continue
637
+ if param.ndim < 2 or "norm" in name or "bias" in name:
638
+ no_decay_params.append(param)
639
+ else:
640
+ decay_params.append(param)
641
+
642
+ optimizer_groups = [
643
+ {"params": decay_params, "weight_decay": train_config.weight_decay},
644
+ {"params": no_decay_params, "weight_decay": 0.0},
645
+ ]
646
+
647
+ self.optimizer = torch.optim.AdamW(
648
+ optimizer_groups,
649
+ lr=train_config.lr_max,
650
+ betas=(train_config.beta1, train_config.beta2),
651
+ eps=1e-8,
652
+ fused=True if self.device.type == "cuda" else False,
653
+ # fused=True: CUDA fused implementation, ~10% faster
654
+ )
655
+
656
+ # ── DataLoaders ────────────────────────────────────────────────────
657
+ train_dataset = CorpusDataset(
658
+ os.path.join(train_config.corpus_dir, "train")
659
+ )
660
+ val_dataset = CorpusDataset(
661
+ os.path.join(train_config.corpus_dir, "val")
662
+ )
663
+
664
+ self.train_loader = DataLoader(
665
+ train_dataset,
666
+ batch_size=train_config.batch_size,
667
+ shuffle=True,
668
+ num_workers=train_config.num_workers,
669
+ pin_memory=True, # speeds up CPU→GPU transfer
670
+ drop_last=True, # discard incomplete batch at the end
671
+ )
672
+
673
+ self.val_loader = DataLoader(
674
+ val_dataset,
675
+ batch_size=train_config.batch_size,
676
+ shuffle=False,
677
+ num_workers=train_config.num_workers,
678
+ pin_memory=True,
679
+ )
680
+
681
+ # ── Max steps ──────────────────────────────────────────────────────
682
+ if train_config.max_steps is None:
683
+ # 1 epoch = iterate through the full dataset once
684
+ self.max_steps = len(self.train_loader) // train_config.accum_steps
685
+ else:
686
+ self.max_steps = train_config.max_steps
687
+
688
+ print(f" Max steps: {self.max_steps:,}")
689
+ print(f" Effective batch size: {train_config.effective_batch_size}")
690
+ print(f" Steps per epoch: {len(self.train_loader) // train_config.accum_steps:,}")
691
+
692
+ # ── Metrics ────────────────────────────────────────────────────────
693
+ self.metrics = MetricsTracker(train_config.checkpoint_dir)
694
+
695
+ # ── Internal state ─────────────────────────────────────────────────
696
+ self.step = 0
697
+ self.best_loss = float("inf")
698
+
699
+ def _set_lr(self, step: int) -> float:
700
+ """
701
+ Update the learning rate for all optimizer parameter groups.
702
+
703
+ Args:
704
+ step: Current step.
705
+
706
+ Returns:
707
+ Computed learning rate.
708
+ """
709
+ lr = get_lr(
710
+ step=step,
711
+ warmup_steps=self.config.warmup_steps,
712
+ max_steps=self.max_steps,
713
+ lr_max=self.config.lr_max,
714
+ lr_min=self.config.lr_min,
715
+ )
716
+ for group in self.optimizer.param_groups:
717
+ group["lr"] = lr
718
+ return lr
719
+
720
+ def train(self, resume_from: Optional[str] = None) -> None:
721
+ """
722
+ Run the full training loop.
723
+
724
+ Main loop:
725
+ For each batch from train_loader:
726
+ 1. Forward pass β†’ loss
727
+ 2. loss /= accum_steps (scale for accumulation)
728
+ 3. Backward pass (accumulate gradients)
729
+ 4. Every accum_steps:
730
+ a. Gradient clipping
731
+ b. Update weights (optimizer.step)
732
+ c. Zero gradients (optimizer.zero_grad)
733
+ 5. Log metrics periodically
734
+ 6. Evaluate on val set periodically
735
+ 7. Save checkpoint periodically
736
+
737
+ Args:
738
+ resume_from: Path to a checkpoint to resume from (optional).
739
+ """
740
+ # Resume from checkpoint if provided
741
+ if resume_from is not None:
742
+ ckpt = load_checkpoint(resume_from, self.model, self.optimizer)
743
+ self.step = ckpt["step"]
744
+ self.best_loss = ckpt.get("loss", float("inf"))
745
+ print(f"Resuming from step {self.step}")
746
+
747
+ self.model.train()
748
+ self.metrics.logger.info("=" * 60)
749
+ self.metrics.logger.info("Training started")
750
+ self.metrics.logger.info(
751
+ f"max_steps={self.max_steps} | "
752
+ f"batch={self.config.batch_size} | "
753
+ f"accum={self.config.accum_steps} | "
754
+ f"effective_batch={self.config.effective_batch_size}"
755
+ )
756
+ self.metrics.logger.info("=" * 60)
757
+
758
+ # Time tracking for throughput computation
759
+ t_start = time.time()
760
+ tokens_seen = 0
761
+
762
+ # Infinite iterator over the dataset
763
+ # (needed since max_steps may span more than 1 epoch)
764
+ def infinite_loader():
765
+ while True:
766
+ for batch in self.train_loader:
767
+ yield batch
768
+
769
+ loader_iter = infinite_loader()
770
+ accumulated_loss = 0.0
771
+
772
+ while self.step < self.max_steps:
773
+
774
+ # ── Update learning rate ───────────────────────────────────────
775
+ lr = self._set_lr(self.step)
776
+
777
+ # ── Gradient Accumulation Loop ─────────────────────────────────
778
+ # Accumulate gradients over `accum_steps` micro-batches
779
+ # before applying the weight update
780
+ self.optimizer.zero_grad(set_to_none=True)
781
+ # set_to_none=True frees memory instead of zeroing β€” more efficient
782
+
783
+ for _ in range(self.config.accum_steps):
784
+ batch = next(loader_iter)
785
+
786
+ # Prepare input and targets (shift of 1 token)
787
+ input_ids = batch[:, :-1].to(self.device, non_blocking=True)
788
+ targets = batch[:, 1:].to(self.device, non_blocking=True)
789
+
790
+ tokens_seen += input_ids.numel()
791
+
792
+ # Forward with autocast (mixed precision)
793
+ with torch.autocast(
794
+ device_type=self.device.type,
795
+ dtype=self.dtype,
796
+ ):
797
+ _, loss = self.model(input_ids, targets)
798
+
799
+ # Scale the loss by the number of micro-steps so that
800
+ # the accumulated gradient is equivalent to the gradient
801
+ # of a batch of size effective_batch
802
+ loss = loss / self.config.accum_steps
803
+ accumulated_loss += loss.item()
804
+
805
+ # Backward: accumulate gradients (do not zero yet)
806
+ loss.backward()
807
+
808
+ # ── Weight update ──────────────────────────────────────────────
809
+
810
+ # Gradient clipping: prevents gradient explosion
811
+ # Returns the norm before clipping (useful for monitoring)
812
+ grad_norm = nn.utils.clip_grad_norm_(
813
+ self.model.parameters(),
814
+ self.config.grad_clip,
815
+ )
816
+
817
+ # Optimization step
818
+ self.optimizer.step()
819
+
820
+ self.step += 1
821
+
822
+ # ── Logging ────────────────────────────────────────────────────
823
+ self.metrics.update(accumulated_loss)
824
+ accumulated_loss = 0.0
825
+
826
+ if self.step % self.config.log_interval == 0:
827
+ elapsed = time.time() - t_start
828
+ tok_per_sec = tokens_seen / elapsed
829
+ lr_now = self.optimizer.param_groups[0]["lr"]
830
+
831
+ self.metrics.log_step(
832
+ step=self.step,
833
+ lr=lr_now,
834
+ tokens_per_sec=tok_per_sec,
835
+ split="train",
836
+ )
837
+
838
+ # Reset throughput counters
839
+ tokens_seen = 0
840
+ t_start = time.time()
841
+
842
+ # ── Evaluation ─────────────────────────────────────────────────
843
+ if self.step % self.config.eval_interval == 0:
844
+ val_loss = evaluate(
845
+ model=self.model,
846
+ val_loader=self.val_loader,
847
+ device=self.device,
848
+ dtype=self.dtype,
849
+ eval_steps=self.config.eval_steps,
850
+ )
851
+
852
+ self.metrics._loss_accum = val_loss
853
+ self.metrics._accum_count = 1
854
+ self.metrics.log_step(
855
+ step=self.step,
856
+ lr=self.optimizer.param_groups[0]["lr"],
857
+ tokens_per_sec=0,
858
+ split="val",
859
+ )
860
+
861
+ is_best = val_loss < self.best_loss
862
+ if is_best:
863
+ self.best_loss = val_loss
864
+
865
+ save_checkpoint(
866
+ model=self.model,
867
+ optimizer=self.optimizer,
868
+ step=self.step,
869
+ loss=val_loss,
870
+ config=self.config,
871
+ model_config=self.model_config,
872
+ is_best=is_best,
873
+ )
874
+
875
+ # ── Periodic checkpoint ────────────────────────────────────────
876
+ elif self.step % self.config.save_interval == 0:
877
+ save_checkpoint(
878
+ model=self.model,
879
+ optimizer=self.optimizer,
880
+ step=self.step,
881
+ loss=self.best_loss,
882
+ config=self.config,
883
+ model_config=self.model_config,
884
+ is_best=False,
885
+ )
886
+
887
+ # ── End of training ────────────────────────────────────────────────
888
+ self.metrics.logger.info("=" * 60)
889
+ self.metrics.logger.info(
890
+ f"Training complete. "
891
+ f"Best val loss: {self.best_loss:.4f} | "
892
+ f"PPL: {math.exp(self.best_loss):.2f}"
893
+ )
894
+ self.metrics.logger.info("=" * 60)
895
+ self.metrics.save_history()
896
+
897
+ print(f"\nBest model saved to: "
898
+ f"{os.path.join(self.config.checkpoint_dir, 'best_model.pt')}")
899
+
900
+
901
+ # ─────────────────────────────────────────────────────────────
902
+ # HuggingFace export
903
+ # ─────────────────────────────────────────────────────────────
904
+
905
+ def export_to_huggingface(
906
+ checkpoint_path: str,
907
+ output_dir: str,
908
+ tokenizer_path: str,
909
+ ) -> None:
910
+ """
911
+ Export the trained model to HuggingFace format.
912
+
913
+ Saves the model in a format compatible with AutoModel.from_pretrained(),
914
+ allowing anyone to load the model with:
915
+ model = AutoModel.from_pretrained("your-username/your-model")
916
+
917
+ The process:
918
+ 1. Load the trained checkpoint
919
+ 2. Save weights in safetensors (safe and efficient format)
920
+ 3. Create config.json in HuggingFace format
921
+ 4. Copy tokenizer files
922
+ 5. Create the model card (README.md)
923
+
924
+ After this step, use the HuggingFace CLI to publish:
925
+ huggingface-cli upload your-username/minilm ./hf_export
926
+
927
+ Args:
928
+ checkpoint_path: Path to best_model.pt.
929
+ output_dir: Output directory for HF files.
930
+ tokenizer_path: Directory with BPE tokenizer files.
931
+ """
932
+ import shutil
933
+
934
+ os.makedirs(output_dir, exist_ok=True)
935
+ print(f"Exporting to HuggingFace format in '{output_dir}'...")
936
+
937
+ # Load checkpoint
938
+ ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
939
+ model_cfg_dict = ckpt["model_config"]
940
+ # d_head is derived automatically in ModelConfig.__post_init__
941
+ # and must not be passed as a constructor argument
942
+ model_cfg_dict.pop("d_head", None)
943
+ model_config = ModelConfig(**model_cfg_dict)
944
+
945
+ # Instantiate and load weights
946
+ model = MiniLM(model_config)
947
+
948
+ # If the model was trained with torch.compile(), the state_dict keys
949
+ # will have a '_orig_mod.' prefix β€” strip it before loading
950
+ state_dict = ckpt["model_state"]
951
+ if any(k.startswith("_orig_mod.") for k in state_dict):
952
+ state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
953
+
954
+ model.load_state_dict(state_dict)
955
+ model.eval()
956
+
957
+ # Save weights in safetensors (safer than .bin)
958
+ # Note: weight tying means lm_head.weight and token_emb.weight share
959
+ # the same tensor in memory. safetensors does not allow shared tensors,
960
+ # so we save lm_head.weight as an independent copy.
961
+ try:
962
+ from safetensors.torch import save_file
963
+ tensors = {}
964
+ for k, v in model.state_dict().items():
965
+ # Skip complex tensors (e.g. freqs_complex from RoPE) β€”
966
+ # safetensors does not support complex dtypes.
967
+ # These buffers are recomputed automatically on model load.
968
+ if v.is_complex():
969
+ continue
970
+ tensors[k] = v.clone() # clone breaks shared memory references
971
+ save_file(tensors, os.path.join(output_dir, "model.safetensors"))
972
+ print(" Weights saved to model.safetensors")
973
+ except ImportError:
974
+ # Fallback to pytorch_model.bin β€” supports complex tensors
975
+ state_dict = {k: v for k, v in model.state_dict().items()
976
+ if not v.is_complex()}
977
+ torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin"))
978
+ print(" Weights saved to pytorch_model.bin")
979
+ print(" (install safetensors for the recommended format: pip install safetensors)")
980
+
981
+ # Save config.json in HuggingFace format
982
+ hf_config = {
983
+ "model_type": "minilm",
984
+ "architectures": ["MiniLM"],
985
+ "vocab_size": model_config.vocab_size,
986
+ "hidden_size": model_config.d_model,
987
+ "num_hidden_layers": model_config.n_layers,
988
+ "num_attention_heads": model_config.n_heads,
989
+ "intermediate_size": model_config.d_ff,
990
+ "max_position_embeddings": model_config.seq_len,
991
+ "hidden_dropout_prob": model_config.dropout,
992
+ "torch_dtype": "bfloat16",
993
+ "transformers_version": "4.0.0",
994
+ }
995
+ with open(os.path.join(output_dir, "config.json"), "w") as f:
996
+ json.dump(hf_config, f, indent=2)
997
+ print(" config.json saved")
998
+
999
+ # Copy tokenizer files
1000
+ for fname in ["tokenizer.json", "vocab.json"]:
1001
+ src = os.path.join(tokenizer_path, fname)
1002
+ if os.path.exists(src):
1003
+ shutil.copy(src, os.path.join(output_dir, fname))
1004
+ print(" Tokenizer files copied")
1005
+
1006
+ # Create model card (README.md)
1007
+ params_m = model_config.n_params / 1e6
1008
+ readme = f"""---
1009
+ language:
1010
+ - pt
1011
+ - en
1012
+ license: mit
1013
+ tags:
1014
+ - language-model
1015
+ - bilingual
1016
+ - portuguese
1017
+ - english
1018
+ - from-scratch
1019
+ ---
1020
+
1021
+ # MiniLM β€” Bilingual PT+EN Language Model
1022
+
1023
+ A decoder-only Transformer language model trained from scratch,
1024
+ including a BPE tokenizer and training loop implemented without
1025
+ high-level frameworks.
1026
+
1027
+ ## Specifications
1028
+
1029
+ | Attribute | Value |
1030
+ |----------------------|------------------------|
1031
+ | Parameters | {params_m:.0f}M |
1032
+ | Architecture | Transformer Decoder-only |
1033
+ | Normalization | RMSNorm |
1034
+ | Positional Encoding | RoPE |
1035
+ | FFN Activation | SwiGLU |
1036
+ | Vocabulary | {model_config.vocab_size:,} tokens (BPE) |
1037
+ | Max context | {model_config.seq_len} tokens |
1038
+ | Languages | Brazilian Portuguese + English |
1039
+
1040
+ ## Training corpus
1041
+
1042
+ - **TinyStories** (EN): short synthetic stories ~60%
1043
+ - **CulturaX PT** (PT-BR): curated Portuguese web ~40%
1044
+
1045
+ ## How to use
1046
+
1047
+ ```python
1048
+ from bpe_tokenizer import BPETokenizer
1049
+ from transformer import MiniLM, ModelConfig
1050
+ import torch, json
1051
+
1052
+ tokenizer = BPETokenizer.load("./")
1053
+
1054
+ with open("config.json") as f:
1055
+ cfg = json.load(f)
1056
+
1057
+ model_config = ModelConfig(
1058
+ vocab_size=cfg["vocab_size"],
1059
+ d_model=cfg["hidden_size"],
1060
+ n_layers=cfg["num_hidden_layers"],
1061
+ n_heads=cfg["num_attention_heads"],
1062
+ d_ff=cfg["intermediate_size"],
1063
+ seq_len=cfg["max_position_embeddings"],
1064
+ )
1065
+ model = MiniLM(model_config)
1066
+ model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
1067
+ model.eval()
1068
+
1069
+ prompt = "Once upon a time"
1070
+ input_ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long)
1071
+ output = model.generate(input_ids, max_new_tokens=100, temperature=0.8, top_k=50)
1072
+ print(tokenizer.decode(output[0].tolist()))
1073
+ ```
1074
+
1075
+ ## Development
1076
+
1077
+ All training code is available in the repository:
1078
+ - `bpe_tokenizer.py` β€” BPE tokenizer from scratch
1079
+ - `data_pipeline.py` β€” Corpus preparation pipeline
1080
+ - `transformer.py` β€” Model architecture
1081
+ - `training_loop.py` β€” Custom training loop
1082
+
1083
+ ## Citation
1084
+
1085
+ ```
1086
+ @misc{{minilm2025,
1087
+ title={{MiniLM: A bilingual PT+EN language model built from scratch}},
1088
+ author={{AndrΓ© Costa}},
1089
+ year={{2026}},
1090
+ url={{https://huggingface.co/AndreCosta/minilm}}
1091
+ }}
1092
+ ```
1093
+ """
1094
+ with open(os.path.join(output_dir, "README.md"), "w", encoding="utf-8") as f:
1095
+ f.write(readme)
1096
+ print(" README.md (model card) created")
1097
+
1098
+ print(f"\nExport complete!")
1099
+ print(f"To publish on HuggingFace:")
1100
+ print(f" huggingface-cli login")
1101
+ print(f" huggingface-cli upload [your-username]/minilm {output_dir}")
1102
+
1103
+
1104
+ # ─────────────────────────────────────────────────────────────
1105
+ # Entry point
1106
+ # ─────────────────────────────────────────────────────────────
1107
+
1108
+ if __name__ == "__main__":
1109
+ import argparse
1110
+
1111
+ parser = argparse.ArgumentParser(description="MiniLM Training")
1112
+ parser.add_argument("--mode", choices=["train", "export"],
1113
+ default="train", help="Execution mode")
1114
+ parser.add_argument("--resume", type=str, default=None,
1115
+ help="Path to checkpoint to resume from")
1116
+ parser.add_argument("--checkpoint", type=str, default=None,
1117
+ help="Checkpoint to export (export mode)")
1118
+ parser.add_argument("--output-dir", type=str, default="./hf_export",
1119
+ help="Output directory for HF export")
1120
+ parser.add_argument("--tokenizer-path", type=str, default="./tokenizer",
1121
+ help="Path to the BPE tokenizer")
1122
+ parser.add_argument("--small", action="store_true",
1123
+ help="Use Tiny config (~15M params) for quick tests")
1124
+ args = parser.parse_args()
1125
+
1126
+ if args.mode == "train":
1127
+ # Model configuration
1128
+ if args.small:
1129
+ print("Using Tiny configuration (~15M params) for quick test")
1130
+ model_config = ModelConfig(
1131
+ vocab_size=16384,
1132
+ seq_len=512, # must match the seq_len used in data_pipeline.py
1133
+ d_model=256,
1134
+ n_heads=4,
1135
+ n_layers=4,
1136
+ d_ff=768,
1137
+ dropout=0.1,
1138
+ )
1139
+ train_config = TrainingConfig(
1140
+ batch_size=4,
1141
+ accum_steps=2,
1142
+ max_steps=100,
1143
+ log_interval=10,
1144
+ eval_interval=50,
1145
+ save_interval=50,
1146
+ )
1147
+ else:
1148
+ model_config = ModelConfig() # Small (~85M) by default
1149
+ train_config = TrainingConfig()
1150
+
1151
+ print("\nModel configuration:")
1152
+ print(f" {model_config.n_params / 1e6:.1f}M parameters")
1153
+
1154
+ trainer = Trainer(model_config, train_config)
1155
+ trainer.train(resume_from=args.resume)
1156
+
1157
+ elif args.mode == "export":
1158
+ if args.checkpoint is None:
1159
+ args.checkpoint = "./checkpoints/best_model.pt"
1160
+ export_to_huggingface(
1161
+ checkpoint_path=args.checkpoint,
1162
+ output_dir=args.output_dir,
1163
+ tokenizer_path=args.tokenizer_path,
1164
+ )