Buckets:
| """ | |
| Phase 4: Training Loop — Optimized สำหรับ RTX 3090 (24GB VRAM) | |
| Features: | |
| - Mixed precision (BF16/FP16) | |
| - Gradient checkpointing | |
| - Gradient accumulation | |
| - Cosine LR schedule + warmup | |
| - Checkpoint saving + resuming | |
| - WandB logging (optional) | |
| - Flash Attention (ถ้ามี) | |
| """ | |
| import json | |
| import math | |
| import os | |
| import time | |
| from pathlib import Path | |
| from functools import partial | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import DataLoader, random_split | |
| from tokenizers import Tokenizer | |
| from model.config import OmegaConfig, small_config | |
| from model.architecture import OmegaModel | |
| from train.dataset import QADataset, collate_fn | |
| # ─── Config ─────────────────────────────────────────────────────────────────── | |
# Default training hyper-parameters. Trainer reads these keys by name, so the
# key set and spelling are part of the module's interface.
TRAIN_CFG = {
    # Data: plain QA examples plus optional chain-of-thought (CoT) examples.
    "data_path": "data/filtered/clean_qa.jsonl",
    "cot_data_path": "data/filtered/cot_qa.jsonl",  # extra CoT data (if available)
    "tokenizer_path": "data/tokenizer/tokenizer.json",
    "out_dir": "checkpoints",

    # Context length: expanded from 1024 to 2048 to fit reasoning traces.
    "max_seq_len": 2048,

    # Batch: lowered because sequences are longer (VRAM budget unchanged).
    "batch_size": 2,
    "grad_accum": 16,  # effective batch = 2 * 16 = 32

    # Learning rate: initial LR slightly reduced for the longer context.
    "lr": 0.0002,
    "min_lr": 0.00002,
    "warmup_steps": 1000,  # longer warmup for stability

    # Step budget and cadence.
    "max_steps": 80000,  # train longer because the data is more diverse
    "save_every": 1000,
    "eval_every": 500,
    "eval_steps": 50,

    # Precision and regularisation.
    "dtype": "bfloat16",
    "grad_clip": 1.0,
    "weight_decay": 0.1,

    # Extras.
    "use_wandb": False,
    "compile": True,
}
| # ─── LR Schedule ────────────────────────────────────────────────────────────── | |
def get_lr(step: int, cfg: dict) -> float:
    """Return the learning rate to use at optimizer step ``step``.

    The schedule is a linear warmup followed by cosine decay:

    * ``step < cfg["warmup_steps"]``: ramps linearly from 0 up to ``cfg["lr"]``.
    * ``cfg["warmup_steps"] <= step < cfg["max_steps"]``: cosine decay from
      ``cfg["lr"]`` down to ``cfg["min_lr"]``.
    * ``step >= cfg["max_steps"]``: constant ``cfg["min_lr"]``.

    Args:
        step: Zero-based optimizer step.
        cfg: Mapping providing ``warmup_steps``, ``max_steps``, ``lr`` and
            ``min_lr``.

    Returns:
        The learning rate for this step.
    """
    warmup = cfg["warmup_steps"]
    max_steps = cfg["max_steps"]
    max_lr = cfg["lr"]
    min_lr = cfg["min_lr"]

    if step < warmup:
        return max_lr * step / warmup

    # ``>=`` rather than ``>``: the cosine term is already exactly min_lr at
    # step == max_steps, and returning here avoids a 0/0 division when
    # warmup_steps == max_steps.
    if step >= max_steps:
        return min_lr

    progress = (step - warmup) / (max_steps - warmup)
    coeff = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + coeff * (max_lr - min_lr)
| # ─── Trainer ────────────────────────────────────────────────────────────────── | |
class Trainer:
    """Training driver: data loading, model/optimizer setup, train loop, checkpoints.

    Typical use::

        Trainer().train()                       # fresh run
        Trainer().train(resume_from="checkpoints/omega_latest.pt")

    Args:
        train_cfg: Training hyper-parameters; see ``TRAIN_CFG`` for the keys.
        model_cfg: Model configuration. Falls back to ``small_config()`` when
            omitted (or falsy).
    """

    def __init__(self, train_cfg: dict = TRAIN_CFG, model_cfg: OmegaConfig | None = None):
        self.cfg = train_cfg
        self.model_cfg = model_cfg or small_config()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Any dtype string other than "bfloat16" selects float16.
        self.dtype = torch.bfloat16 if train_cfg["dtype"] == "bfloat16" else torch.float16
        # Autocast is only enabled on CUDA.
        self.use_amp = self.device.type == "cuda"
        # parents=True so a nested out_dir such as "runs/exp1/ckpt" works.
        Path(train_cfg["out_dir"]).mkdir(parents=True, exist_ok=True)
        self.step = 0
        self.best_val_loss = float("inf")
        # True once setup() has built tokenizer, data, model and optimizer.
        self._is_setup = False
        # The wandb module once _init_wandb() succeeded, otherwise None.
        self._wandb = None

    def setup(self):
        """Build tokenizer, data loaders, model and optimizer (and wandb if enabled).

        Calling this again rebuilds everything from scratch; ``train()`` and
        ``load_checkpoint()`` only call it when it has not run yet.
        """
        print(f"Device: {self.device} | dtype: {self.dtype}")
        self._load_tokenizer()
        self._load_data()
        self._build_model()
        self._build_optimizer()
        if self.cfg.get("use_wandb"):
            self._init_wandb()
        self._is_setup = True

    def _load_tokenizer(self):
        """Load the tokenizer from ``cfg["tokenizer_path"]``.

        Raises:
            FileNotFoundError: If the tokenizer file does not exist.
        """
        tok_path = self.cfg["tokenizer_path"]
        if not Path(tok_path).exists():
            raise FileNotFoundError(f"Tokenizer not found: {tok_path}\nรัน: python data/tokenizer/train_tokenizer.py")
        self.tokenizer = Tokenizer.from_file(tok_path)
        print(f"Tokenizer loaded: {self.tokenizer.get_vocab_size():,} tokens")

    def _load_data(self):
        """Build the train/validation loaders from the QA (and optional CoT) data.

        About 2% of the examples (at least one) are held out for validation.

        Raises:
            FileNotFoundError: If ``cfg["data_path"]`` does not exist.
            ValueError: If fewer than two examples are available, since both
                splits must be non-empty.
        """
        from torch.utils.data import ConcatDataset

        data_path = self.cfg["data_path"]
        if not Path(data_path).exists():
            raise FileNotFoundError(f"Dataset not found: {data_path}\nรัน: python run_phase1.py")
        max_len = self.cfg["max_seq_len"]
        datasets_to_merge = [QADataset(data_path, self.tokenizer, max_len)]

        # Merge in the CoT data when the file is present.
        cot_path = self.cfg.get("cot_data_path", "")
        if cot_path and Path(cot_path).exists():
            cot_ds = QADataset(cot_path, self.tokenizer, max_len)
            datasets_to_merge.append(cot_ds)
            print(f"CoT data: +{len(cot_ds):,} examples")

        full_ds: QADataset | ConcatDataset = (
            datasets_to_merge[0] if len(datasets_to_merge) == 1
            else ConcatDataset(datasets_to_merge)
        )

        total = len(full_ds)
        if total < 2:
            raise ValueError(
                f"Need at least 2 examples for a train/val split, got {total}"
            )
        n_val = max(1, int(total * 0.02))
        n_train = total - n_val
        # Seeded split: a resumed run must see the same validation set,
        # otherwise validation examples leak into training and best_val_loss
        # is not comparable across restarts.
        split_gen = torch.Generator().manual_seed(self.cfg.get("split_seed", 42))
        train_ds, val_ds = random_split(full_ds, [n_train, n_val], generator=split_gen)

        collate = partial(collate_fn, pad_id=self.model_cfg.pad_token_id)
        # Pinned host memory only helps for CUDA transfers.
        pin = self.device.type == "cuda"
        self.train_loader = DataLoader(
            train_ds, batch_size=self.cfg["batch_size"],
            shuffle=True, collate_fn=collate, num_workers=4, pin_memory=pin,
        )
        self.val_loader = DataLoader(
            val_ds, batch_size=self.cfg["batch_size"],
            shuffle=False, collate_fn=collate, num_workers=2, pin_memory=pin,
        )
        print(f"Data: {len(train_ds):,} train | {len(val_ds):,} val")

    def _build_model(self):
        """Instantiate the model on the target device and optionally compile it."""
        self.model = OmegaModel(self.model_cfg).to(self.device)
        print(f"Model: {self.model.count_params()}")
        # Gradient checkpointing (original author's estimate: ~40% VRAM saved).
        self.model.enable_grad_checkpointing()
        if self.cfg.get("compile") and hasattr(torch, "compile"):
            print("Compiling model with torch.compile ...")
            self.model = torch.compile(self.model)  # type: ignore

    def _build_optimizer(self):
        """Create AdamW with weight decay on >=2-D parameters only, plus the grad scaler."""
        decay_params, no_decay_params = [], []
        for _name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            # Parameters with fewer than 2 dims get no weight decay.
            (decay_params if param.dim() >= 2 else no_decay_params).append(param)
        param_groups = [
            {"params": decay_params, "weight_decay": self.cfg["weight_decay"]},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]
        self.optimizer = torch.optim.AdamW(param_groups, lr=self.cfg["lr"],
                                           betas=(0.9, 0.95), eps=1e-8)
        # Loss scaling is only enabled for float16 under autocast.
        self.scaler = torch.amp.GradScaler(enabled=(self.use_amp and self.dtype == torch.float16))

    def _init_wandb(self):
        """Start a wandb run if the package is installed; otherwise just say so."""
        try:
            import wandb
        except ImportError:
            print("wandb not installed — skipping")
            return
        wandb.init(project="tinymind-omega", config={**self.cfg, **vars(self.model_cfg)})
        self._wandb = wandb

    # ─── Training Step ────────────────────────────────────────────────────────
    def _forward_loss(self, batch: dict) -> torch.Tensor:
        """Move one collated batch to the device and return the model's loss tensor."""
        input_ids = batch["input_ids"].to(self.device, non_blocking=True)
        labels = batch["labels"].to(self.device, non_blocking=True)
        attention_mask = batch["attention_mask"].to(self.device, non_blocking=True)
        with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype, enabled=self.use_amp):
            out = self.model(input_ids, attention_mask=attention_mask, labels=labels)
        return out["loss"]

    def train_step(self, batch: dict) -> float:
        """Run forward/backward for one micro-batch and accumulate gradients.

        The loss is divided by ``cfg["grad_accum"]`` before ``backward`` so
        that the accumulated gradient is the mean over micro-batches.

        Returns:
            The un-divided loss of this micro-batch as a Python float.
        """
        self.model.train()
        loss = self._forward_loss(batch)
        self.scaler.scale(loss / self.cfg["grad_accum"]).backward()
        return loss.item()

    @torch.no_grad()  # validation must not build autograd graphs
    def evaluate(self) -> float:
        """Return the mean loss over at most ``cfg["eval_steps"]`` validation batches.

        Leaves the model in eval mode; ``train_step`` switches it back.
        """
        self.model.eval()
        total_loss = 0.0
        count = 0
        for i, batch in enumerate(self.val_loader):
            if i >= self.cfg["eval_steps"]:
                break
            total_loss += self._forward_loss(batch).item()
            count += 1
        return total_loss / max(count, 1)

    def save_checkpoint(self, tag: str = "latest"):
        """Write ``<out_dir>/omega_<tag>.pt`` with model, optimizer, scaler and progress."""
        path = Path(self.cfg["out_dir"]) / f"omega_{tag}.pt"
        # torch.compile wraps the module; save the underlying one so the
        # state-dict keys do not depend on whether compilation was enabled.
        raw_model = getattr(self.model, "_orig_mod", self.model)
        torch.save({
            "step": self.step,
            "model_state": raw_model.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "scaler_state": self.scaler.state_dict(),
            "model_cfg": self.model_cfg,
            "best_val_loss": self.best_val_loss,
        }, path)
        print(f" Checkpoint saved → {path}")

    def load_checkpoint(self, path: str):
        """Restore model, optimizer, scaler and progress counters from ``path``.

        Runs ``setup()`` first if it has not run yet, so this may be called
        right after construction.
        """
        if not self._is_setup:
            self.setup()
        # SECURITY: the checkpoint stores the pickled model_cfg object, so it
        # cannot be read with weights_only=True (the default since
        # PyTorch 2.6). Full unpickling can execute arbitrary code — only
        # load checkpoints you created yourself.
        ckpt = torch.load(path, map_location=self.device, weights_only=False)
        raw_model = getattr(self.model, "_orig_mod", self.model)
        raw_model.load_state_dict(ckpt["model_state"])
        self.optimizer.load_state_dict(ckpt["optimizer_state"])
        # Older checkpoints have no scaler state; a disabled scaler saves {}.
        scaler_state = ckpt.get("scaler_state")
        if scaler_state:
            self.scaler.load_state_dict(scaler_state)
        self.step = ckpt["step"]
        self.best_val_loss = ckpt.get("best_val_loss", float("inf"))
        print(f"Resumed from step {self.step}")

    # ─── Main Loop ────────────────────────────────────────────────────────────
    def train(self, resume_from: str | None = None):
        """Run the training loop until ``cfg["max_steps"]`` optimizer steps.

        Args:
            resume_from: Optional checkpoint path to restore before training.
                A checkpoint loaded earlier through ``load_checkpoint`` is
                also honoured, because setup is not repeated.
        """
        if not self._is_setup:
            self.setup()
        if resume_from:
            self.load_checkpoint(resume_from)

        grad_accum = self.cfg["grad_accum"]
        log_every = 10
        data_iter = iter(self.train_loader)
        self.optimizer.zero_grad()
        t0 = time.time()
        running_loss = 0.0
        micro_batches = 0  # micro-batches folded into running_loss so far
        print(f"\nTraining for {self.cfg['max_steps']:,} steps")
        print(f"Effective batch size: {self.cfg['batch_size'] * grad_accum}\n")

        while self.step < self.cfg["max_steps"]:
            # Update LR
            lr = get_lr(self.step, self.cfg)
            for pg in self.optimizer.param_groups:
                pg["lr"] = lr

            # Gradient accumulation
            for _ in range(grad_accum):
                try:
                    batch = next(data_iter)
                except StopIteration:
                    # Epoch finished: start a fresh (reshuffled) pass.
                    data_iter = iter(self.train_loader)
                    batch = next(data_iter)
                running_loss += self.train_step(batch)
                micro_batches += 1

            # Clip gradients (unscale first so the threshold applies to true grads)
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg["grad_clip"])
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad()
            self.step += 1

            # Logging
            if self.step % log_every == 0:
                dt = time.time() - t0
                # running_loss holds one loss per micro-batch, so average over
                # micro-batches rather than over optimizer steps.
                avg_loss = running_loss / max(micro_batches, 1)
                # Upper bound: assumes every sequence is max_seq_len long.
                tok_per_sec = (micro_batches * self.cfg["batch_size"]
                               * self.cfg["max_seq_len"] / dt)
                running_loss = 0.0
                micro_batches = 0
                print(f"step {self.step:5d} | loss {avg_loss:.4f} | lr {lr:.2e} "
                      f"| {tok_per_sec:,.0f} tok/s | {dt:.1f}s")
                t0 = time.time()
                if self._wandb is not None:
                    try:
                        self._wandb.log({"train_loss": avg_loss, "lr": lr, "step": self.step})
                    except Exception as exc:
                        # Best-effort: a logging failure must not stop training.
                        print(f"wandb.log failed: {exc}")

            # Eval
            if self.step % self.cfg["eval_every"] == 0:
                val_loss = self.evaluate()
                print(f"\n Val loss: {val_loss:.4f}")
                if val_loss < self.best_val_loss:
                    self.best_val_loss = val_loss
                    self.save_checkpoint("best")
                    print(f" New best! val_loss={val_loss:.4f}")

            # Save
            if self.step % self.cfg["save_every"] == 0:
                self.save_checkpoint(f"step{self.step}")

        self.save_checkpoint("final")
        print(f"\nTraining complete! Best val loss: {self.best_val_loss:.4f}")
if __name__ == "__main__":
    # Script entry point: train with the module-level defaults, spelled out
    # explicitly (model_cfg=None selects the Trainer's fallback config).
    trainer = Trainer(train_cfg=TRAIN_CFG, model_cfg=None)
    trainer.train()
Xet Storage Details
- Size:
- 12.6 kB
- Xet hash:
- 1b2c2c7810908f6d14fbd51442dfb9d7ab39dbfcd69e1eaf19cccaa6382a0257
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.