bbkdevops's picture
download
raw
12.6 kB
"""
Phase 4: Training Loop — optimized for an RTX 3090 (24GB VRAM)
Features:
- Mixed precision (BF16/FP16)
- Gradient checkpointing
- Gradient accumulation
- Cosine LR schedule + warmup
- Checkpoint saving + resuming
- WandB logging (optional)
- Flash Attention (if available)
"""
import json
import math
import os
import time
from pathlib import Path
from functools import partial
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from tokenizers import Tokenizer
from model.config import OmegaConfig, small_config
from model.architecture import OmegaModel
from train.dataset import QADataset, collate_fn
# ─── Config ───────────────────────────────────────────────────────────────────
# Default hyper-parameters for a training run; passed to ``Trainer`` unless
# the caller supplies its own dict with the same keys.
TRAIN_CFG = {
    # --- Data: supports both plain QA and chain-of-thought data (cot_qa.jsonl)
    "data_path": "data/filtered/clean_qa.jsonl",
    "cot_data_path": "data/filtered/cot_qa.jsonl",  # extra CoT data (if present)
    "tokenizer_path": "data/tokenizer/tokenizer.json",
    "out_dir": "checkpoints",

    # --- Context: extended from 1024 to 2048 tokens for reasoning traces
    "max_seq_len": 2048,

    # --- Batch: reduced because sequences are longer (same VRAM budget)
    "batch_size": 2,
    "grad_accum": 16,  # effective batch = 2 * 16 = 32

    # --- Learning rate: slightly lower initial LR for the longer context
    "lr": 2e-4,
    "min_lr": 2e-5,
    "warmup_steps": 1_000,  # longer warmup for stability

    # --- Step counts
    "max_steps": 80_000,  # train longer because the data is more varied
    "save_every": 1_000,
    "eval_every": 500,
    "eval_steps": 50,

    # --- Precision / regularisation
    "dtype": "bfloat16",
    "grad_clip": 1.0,
    "weight_decay": 0.1,

    # --- Extras
    "use_wandb": False,
    "compile": True,
}
# ─── LR Schedule ──────────────────────────────────────────────────────────────
def get_lr(step: int, cfg: dict) -> float:
    """Return the learning rate for optimizer step *step*.

    The schedule is a linear warmup from 0 to ``cfg["lr"]`` over
    ``cfg["warmup_steps"]`` steps, followed by a cosine decay down to
    ``cfg["min_lr"]`` which is reached at ``cfg["max_steps"]`` and held
    afterwards.

    Args:
        step: zero-based optimizer step.
        cfg: mapping providing ``warmup_steps``, ``max_steps``, ``lr`` and
            ``min_lr``.

    Returns:
        The learning rate to apply at *step*.
    """
    warmup = cfg["warmup_steps"]
    max_steps = cfg["max_steps"]
    max_lr = cfg["lr"]
    min_lr = cfg["min_lr"]

    if step < warmup:
        return max_lr * step / warmup

    # ``>=`` rather than ``>``: the cosine term is already 0 at max_steps, so
    # the value is unchanged, but this keeps ``max_steps == warmup_steps``
    # from reaching the division below with a zero denominator.
    if step >= max_steps:
        return min_lr

    # Here warmup <= step < max_steps, hence max_steps - warmup > 0.
    decay = (step - warmup) / (max_steps - warmup)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay))
    return min_lr + coeff * (max_lr - min_lr)
# ─── Trainer ──────────────────────────────────────────────────────────────────
class Trainer:
    """Training driver for ``OmegaModel`` on a single device.

    Typical use is ``Trainer().train()``; ``train`` calls ``setup`` itself.
    Pass ``resume_from`` to ``train`` to continue from a saved checkpoint.

    Attributes set in ``__init__``:
        cfg: training hyper-parameters (``TRAIN_CFG`` lists the keys read).
        model_cfg: model configuration; defaults to ``small_config()``.
        device: CUDA when available, otherwise CPU.
        dtype: autocast dtype — bfloat16 when ``cfg["dtype"] == "bfloat16"``,
            float16 for any other value.
        use_amp: autocast is enabled only on CUDA.
        step: number of optimizer steps completed so far.
        best_val_loss: lowest validation loss seen so far.
    """

    def __init__(self, train_cfg: dict = TRAIN_CFG, model_cfg: OmegaConfig | None = None):
        """Store the configuration, pick device/dtype and create the output directory."""
        self.cfg = train_cfg
        self.model_cfg = model_cfg or small_config()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch.bfloat16 if train_cfg["dtype"] == "bfloat16" else torch.float16
        self.use_amp = self.device.type == "cuda"
        # parents=True so a nested out_dir such as "runs/exp1/ckpt" also works.
        Path(train_cfg["out_dir"]).mkdir(parents=True, exist_ok=True)
        self.step = 0
        self.best_val_loss = float("inf")
        # Set by _init_wandb() only when the import succeeded.
        self._wandb = None

    def setup(self):
        """Load tokenizer and data, then build the model and the optimizer."""
        print(f"Device: {self.device} | dtype: {self.dtype}")
        self._load_tokenizer()
        self._load_data()
        self._build_model()
        self._build_optimizer()
        if self.cfg.get("use_wandb"):
            self._init_wandb()

    def _load_tokenizer(self):
        """Load the tokenizer file named by ``cfg["tokenizer_path"]``.

        Raises:
            FileNotFoundError: if the tokenizer file does not exist.
        """
        tok_path = self.cfg["tokenizer_path"]
        if not Path(tok_path).exists():
            raise FileNotFoundError(f"Tokenizer not found: {tok_path}\nรัน: python data/tokenizer/train_tokenizer.py")
        self.tokenizer = Tokenizer.from_file(tok_path)
        print(f"Tokenizer loaded: {self.tokenizer.get_vocab_size():,} tokens")

    def _load_data(self):
        """Build the train and validation ``DataLoader`` objects.

        The main QA dataset is required; the CoT dataset is concatenated to it
        when ``cfg["cot_data_path"]`` points at an existing file. About 2% of
        the examples (at least one) are held out for validation.

        Raises:
            FileNotFoundError: if ``cfg["data_path"]`` does not exist.
            ValueError: if there are too few examples to leave a train split.
        """
        data_path = self.cfg["data_path"]
        if not Path(data_path).exists():
            raise FileNotFoundError(f"Dataset not found: {data_path}\nรัน: python run_phase1.py")
        max_len = self.cfg["max_seq_len"]
        datasets_to_merge = [QADataset(data_path, self.tokenizer, max_len)]

        # Merge the CoT data when it is available.
        cot_path = self.cfg.get("cot_data_path", "")
        if cot_path and Path(cot_path).exists():
            cot_ds = QADataset(cot_path, self.tokenizer, max_len)
            datasets_to_merge.append(cot_ds)
            print(f"CoT data: +{len(cot_ds):,} examples")

        from torch.utils.data import ConcatDataset
        full_ds: QADataset | ConcatDataset = (
            datasets_to_merge[0] if len(datasets_to_merge) == 1
            else ConcatDataset(datasets_to_merge)
        )

        n_val = min(max(1, int(len(full_ds) * 0.02)), max(len(full_ds) - 1, 1))
        n_train = len(full_ds) - n_val
        if n_train < 1:
            # Fail here with a clear message instead of deep inside
            # random_split / the DataLoader sampler.
            raise ValueError(
                f"Need at least 2 examples to split into train/val, got {len(full_ds)}"
            )
        # A seeded generator keeps the train/val split identical across runs,
        # so a resumed run does not move validation examples into training.
        split_gen = torch.Generator().manual_seed(self.cfg.get("seed", 1337))
        train_ds, val_ds = random_split(full_ds, [n_train, n_val], generator=split_gen)

        pad_id = self.model_cfg.pad_token_id
        collate = partial(collate_fn, pad_id=pad_id)
        # Pinned host memory only helps host-to-GPU copies.
        pin = self.device.type == "cuda"
        self.train_loader = DataLoader(
            train_ds, batch_size=self.cfg["batch_size"],
            shuffle=True, collate_fn=collate, num_workers=4, pin_memory=pin,
        )
        self.val_loader = DataLoader(
            val_ds, batch_size=self.cfg["batch_size"],
            shuffle=False, collate_fn=collate, num_workers=2, pin_memory=pin,
        )
        print(f"Data: {len(train_ds):,} train | {len(val_ds):,} val")

    def _build_model(self):
        """Instantiate the model on ``self.device`` and optionally compile it."""
        self.model = OmegaModel(self.model_cfg).to(self.device)
        print(f"Model: {self.model.count_params()}")
        # Gradient checkpointing — saves roughly 40% VRAM (per the original note).
        self.model.enable_grad_checkpointing()
        if self.cfg.get("compile") and hasattr(torch, "compile"):
            print("Compiling model with torch.compile ...")
            self.model = torch.compile(self.model)  # type: ignore

    def _build_optimizer(self):
        """Create the AdamW optimizer and the gradient scaler.

        Trainable parameters with two or more dimensions get
        ``cfg["weight_decay"]``; those with fewer dimensions get none.
        """
        trainable = [p for p in self.model.parameters() if p.requires_grad]
        decay_params = [p for p in trainable if p.dim() >= 2]
        no_decay_params = [p for p in trainable if p.dim() < 2]
        param_groups = [
            {"params": decay_params, "weight_decay": self.cfg["weight_decay"]},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]
        self.optimizer = torch.optim.AdamW(param_groups, lr=self.cfg["lr"],
                                           betas=(0.9, 0.95), eps=1e-8)
        # The scaler is active only for float16 autocast; otherwise its
        # methods are pass-throughs.
        self.scaler = torch.amp.GradScaler(enabled=(self.use_amp and self.dtype == torch.float16))

    def _init_wandb(self):
        """Start a WandB run if the package is installed; otherwise print a note."""
        try:
            import wandb
        except ImportError:
            print("wandb not installed — skipping")
            return
        wandb.init(project="tinymind-omega", config={**self.cfg, **vars(self.model_cfg)})
        # Keep the module handle so the training loop logs only when init succeeded.
        self._wandb = wandb

    # ─── Training Step ────────────────────────────────────────────────────────
    def train_step(self, batch: dict) -> float:
        """Run forward and backward on one micro-batch.

        The loss is divided by ``cfg["grad_accum"]`` before ``backward`` so
        that gradients accumulated over a full optimizer step are averaged.

        Args:
            batch: mapping with ``input_ids``, ``labels`` and
                ``attention_mask`` tensors.

        Returns:
            The unscaled (not divided by ``grad_accum``) loss of this micro-batch.
        """
        self.model.train()
        input_ids = batch["input_ids"].to(self.device)
        labels = batch["labels"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype, enabled=self.use_amp):
            out = self.model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = out["loss"] / self.cfg["grad_accum"]
        self.scaler.scale(loss).backward()
        return loss.item() * self.cfg["grad_accum"]

    @torch.no_grad()
    def evaluate(self) -> float:
        """Return the mean loss over at most ``cfg["eval_steps"]`` validation batches.

        Leaves the model in eval mode; ``train_step`` switches it back.
        """
        self.model.eval()
        total_loss = 0.0
        count = 0
        for i, batch in enumerate(self.val_loader):
            if i >= self.cfg["eval_steps"]:
                break
            input_ids = batch["input_ids"].to(self.device)
            labels = batch["labels"].to(self.device)
            attention_mask = batch["attention_mask"].to(self.device)
            with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype, enabled=self.use_amp):
                out = self.model(input_ids, attention_mask=attention_mask, labels=labels)
            total_loss += out["loss"].item()
            count += 1
        # max(..., 1) avoids dividing by zero when no batch was evaluated.
        return total_loss / max(count, 1)

    def save_checkpoint(self, tag: str = "latest"):
        """Write ``omega_<tag>.pt`` into ``cfg["out_dir"]``.

        The file holds the step counter, model/optimizer/scaler state, the
        model config object and the best validation loss.
        """
        path = Path(self.cfg["out_dir"]) / f"omega_{tag}.pt"
        # torch.compile wraps the module; save the underlying one if present.
        raw_model = getattr(self.model, "_orig_mod", self.model)
        torch.save({
            "step": self.step,
            "model_state": raw_model.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "scaler_state": self.scaler.state_dict(),
            "model_cfg": self.model_cfg,
            "best_val_loss": self.best_val_loss,
        }, path)
        print(f" Checkpoint saved → {path}")

    def load_checkpoint(self, path: str):
        """Restore model, optimizer, scaler and counters from *path*.

        Must be called after ``setup`` (the model and optimizer must exist).
        ``train(resume_from=...)`` does this in the right order.

        SECURITY: the checkpoint contains a pickled config object, so it is
        loaded with ``weights_only=False``. Only load checkpoints you trust.
        """
        # weights_only=False is required because "model_cfg" is a pickled
        # Python object; newer PyTorch defaults to weights_only=True.
        ckpt = torch.load(path, map_location=self.device, weights_only=False)
        raw_model = getattr(self.model, "_orig_mod", self.model)
        raw_model.load_state_dict(ckpt["model_state"])
        self.optimizer.load_state_dict(ckpt["optimizer_state"])
        # Older checkpoints have no scaler state, and a disabled scaler saves
        # an empty dict; skip loading in both cases.
        scaler_state = ckpt.get("scaler_state")
        if scaler_state:
            self.scaler.load_state_dict(scaler_state)
        self.step = ckpt["step"]
        self.best_val_loss = ckpt.get("best_val_loss", float("inf"))
        print(f"Resumed from step {self.step}")

    # ─── Main Loop ────────────────────────────────────────────────────────────
    def train(self, resume_from: str | None = None):
        """Run the training loop until ``cfg["max_steps"]`` optimizer steps.

        Args:
            resume_from: optional checkpoint path. When given, it is loaded
                right after ``setup`` and training continues from its step.
        """
        self.setup()
        if resume_from is not None:
            self.load_checkpoint(resume_from)

        grad_accum = self.cfg["grad_accum"]
        log_every = self.cfg.get("log_every", 10)

        data_iter = iter(self.train_loader)
        self.optimizer.zero_grad()
        t0 = time.time()
        running_loss = 0.0
        steps_since_log = 0
        print(f"\nTraining for {self.cfg['max_steps']:,} steps")
        print(f"Effective batch size: {self.cfg['batch_size'] * grad_accum}\n")

        while self.step < self.cfg["max_steps"]:
            # Update LR
            lr = get_lr(self.step, self.cfg)
            for pg in self.optimizer.param_groups:
                pg["lr"] = lr

            # Gradient accumulation
            for _ in range(grad_accum):
                try:
                    batch = next(data_iter)
                except StopIteration:
                    # End of epoch: start a new (reshuffled) pass over the data.
                    data_iter = iter(self.train_loader)
                    batch = next(data_iter)
                running_loss += self.train_step(batch)

            # Clip gradients (unscale first so the norm is measured on true gradients)
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg["grad_clip"])
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad()
            self.step += 1
            steps_since_log += 1

            # Logging
            if self.step % log_every == 0:
                dt = time.time() - t0
                # running_loss holds one loss per micro-batch, so average over
                # steps * grad_accum, not over steps alone.
                avg_loss = running_loss / (steps_since_log * grad_accum)
                # Nominal throughput: counts every sequence as max_seq_len tokens.
                tok_per_sec = (steps_since_log * self.cfg["batch_size"] * grad_accum
                               * self.cfg["max_seq_len"] / dt)
                running_loss = 0.0
                steps_since_log = 0
                print(f"step {self.step:5d} | loss {avg_loss:.4f} | lr {lr:.2e} "
                      f"| {tok_per_sec:,.0f} tok/s | {dt:.1f}s")
                t0 = time.time()
                if self._wandb is not None:
                    try:
                        self._wandb.log({"train_loss": avg_loss, "lr": lr, "step": self.step})
                    except Exception as exc:
                        # Metrics logging is best-effort: report, keep training.
                        print(f"wandb.log failed: {exc}")

            # Eval
            if self.step % self.cfg["eval_every"] == 0:
                val_loss = self.evaluate()
                print(f"\n Val loss: {val_loss:.4f}")
                if val_loss < self.best_val_loss:
                    self.best_val_loss = val_loss
                    self.save_checkpoint("best")
                    print(f" New best! val_loss={val_loss:.4f}")

            # Save
            if self.step % self.cfg["save_every"] == 0:
                self.save_checkpoint(f"step{self.step}")

        self.save_checkpoint("final")
        print(f"\nTraining complete! Best val loss: {self.best_val_loss:.4f}")
if __name__ == "__main__":
    # Script entry point: build a trainer with the default configuration
    # and run the full training loop.
    trainer = Trainer()
    trainer.train()

Xet Storage Details

Size:
12.6 kB
·
Xet hash:
1b2c2c7810908f6d14fbd51442dfb9d7ab39dbfcd69e1eaf19cccaa6382a0257

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.