| """ |
| Component 5: Training pipeline for the 420M code model. |
| |
| Features: |
| - FP16 mixed precision |
| - Gradient checkpointing |
| - Gradient accumulation |
| - 8-bit optimizer attempt with safe fallback |
| - Checkpoint save every N steps |
| - Resume from checkpoint |
| - Early stopping |
| - Live progress with loss, LR, ETA, VRAM |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import math |
| import os |
| import sys |
| import time |
| from pathlib import Path |
| from typing import Any, Dict, Optional, Tuple |
|
|
| import torch |
| import yaml |
| from torch.optim import AdamW |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
|
|
| |
# Directory that contains this file's parent directory (i.e. the repository root).
# It is put at the front of sys.path so the absolute `src.` imports below resolve
# when this file is executed directly as a script.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
if not any(entry == str(PROJECT_ROOT) for entry in sys.path):
    sys.path[:0] = [str(PROJECT_ROOT)]
|
|
| from src.model_architecture.code_transformer import CodeTransformerLM, ModelConfig, get_model_presets |
| from src.training_pipeline.tokenized_dataset import CausalCollator, TokenizedJsonlDataset |
|
|
|
|
def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the training entry point.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``; passing an explicit list lets tests
            and other Python callers drive the parser.

    Returns:
        Namespace with ``config``: path to the training YAML file.
    """
    parser = argparse.ArgumentParser(description="Run Component 5 training.")
    parser.add_argument(
        "--config",
        default="configs/component5_training_config.yaml",
        help="Path to the training YAML config (default: %(default)s).",
    )
    return parser.parse_args(argv)
|
|
|
|
def load_yaml(path: Path) -> Dict[str, Any]:
    """Read a UTF-8 YAML file (via ``yaml.safe_load``) whose top level must be a mapping.

    Raises:
        FileNotFoundError: ``path`` does not exist.
        ValueError: the parsed document is not a dict (e.g. a list, a scalar, or an
            empty document).
    """
    if path.exists():
        with path.open("r", encoding="utf-8") as handle:
            parsed = yaml.safe_load(handle)
        if isinstance(parsed, dict):
            return parsed
        raise ValueError("Invalid YAML format.")
    raise FileNotFoundError(f"Config not found: {path}")
|
|
|
|
def load_model_config(path: Path) -> ModelConfig:
    """Build a ``ModelConfig`` from the YAML file at ``path``.

    The file may name a ``preset`` (a key of ``get_model_presets()``), whose attributes
    are used as defaults, and/or a ``model`` mapping whose entries override them. Without
    a preset, the ``model`` mapping is passed straight to ``ModelConfig``.

    Raises:
        FileNotFoundError / ValueError: propagated from ``load_yaml``.
        ValueError: the preset name is unknown, or ``model`` is present but not a mapping.
    """
    cfg = load_yaml(path)
    preset = cfg.get("preset")

    # A bare `model:` key parses as None; treat it like an absent section rather than
    # failing later with "argument after ** must be a mapping".
    model_cfg = cfg.get("model") or {}
    if not isinstance(model_cfg, dict):
        raise ValueError(
            f"'model' section in {path} must be a mapping, got {type(model_cfg).__name__}"
        )

    if not preset:
        return ModelConfig(**model_cfg)

    presets = get_model_presets()
    if preset not in presets:
        # NOTE(review): assumes presets iterates over its preset names (dict-like) — confirm.
        available = ", ".join(sorted(str(name) for name in presets))
        raise ValueError(f"Unknown model preset: {preset} (available: {available})")

    # Copy the preset's attribute dict so overrides never mutate the shared preset object.
    base = dict(vars(presets[preset]))
    base.update(model_cfg)
    return ModelConfig(**base)
|
|
|
|
def make_optimizer(model: torch.nn.Module, train_cfg: Dict[str, Any]) -> Tuple[torch.optim.Optimizer, str]:
    """Create the optimizer described by ``train_cfg`` and return it with a display name.

    Keys read from ``train_cfg``:
      * ``learning_rate`` and ``weight_decay`` (required)
      * ``betas`` (default ``[0.9, 0.95]``, must have exactly two values)
      * ``prefer_8bit_adam`` (default True)

    When 8-bit is preferred, bitsandbytes' ``Adam8bit`` is tried first. If importing or
    constructing it fails, the reason is printed and ``torch.optim.AdamW`` is used with the
    same hyper-parameters. The fallback keeps full-precision optimizer state, which matters
    for the VRAM budget, so it is never silent.

    Returns:
        ``(optimizer, name)`` where name is ``"Adam8bit"`` or ``"AdamW"``.

    Raises:
        ValueError: ``betas`` does not contain exactly two values.
    """
    lr = float(train_cfg["learning_rate"])
    wd = float(train_cfg["weight_decay"])
    betas = tuple(float(x) for x in train_cfg.get("betas", [0.9, 0.95]))
    if len(betas) != 2:
        raise ValueError(f"training.betas must contain exactly 2 values, got {list(betas)}")
    prefer_8bit = bool(train_cfg.get("prefer_8bit_adam", True))

    if prefer_8bit:
        try:
            import bitsandbytes as bnb

            optimizer = bnb.optim.Adam8bit(model.parameters(), lr=lr, betas=betas, weight_decay=wd)
        except Exception as exc:  # bitsandbytes can fail with non-ImportError errors (CUDA setup)
            print(
                f"[optimizer] 8-bit Adam unavailable ({type(exc).__name__}: {exc}); "
                "falling back to torch AdamW."
            )
        else:
            return optimizer, "Adam8bit"

    optimizer = AdamW(model.parameters(), lr=lr, betas=betas, weight_decay=wd)
    return optimizer, "AdamW"
|
|
|
|
def cosine_lr(base_lr: float, step: int, warmup_steps: int, max_steps: int, min_lr_ratio: float) -> float:
    """Learning rate for ``step`` under linear warmup followed by cosine decay.

    * ``step < warmup_steps``: ramps linearly from 0 towards ``base_lr``.
    * Afterwards: cosine-anneals from ``base_lr`` down to ``base_lr * min_lr_ratio`` at
      ``max_steps``, and stays at that floor for any later step.
    """
    if step < warmup_steps:
        return base_lr * (step / max(1, warmup_steps))

    decay_span = max(1, max_steps - warmup_steps)
    # Clamp so steps past max_steps hold the floor instead of cycling back up.
    frac = min(1.0, max(0.0, (step - warmup_steps) / decay_span))
    weight = 0.5 * (1.0 + math.cos(math.pi * frac))
    floor_lr = base_lr * min_lr_ratio
    return floor_lr + (base_lr - floor_lr) * weight
|
|
|
|
def set_optimizer_lr(optimizer: torch.optim.Optimizer, lr: float) -> None:
    """Overwrite, in place, the learning rate of every parameter group of ``optimizer``."""
    for group in optimizer.param_groups:
        group.update(lr=lr)
|
|
|
|
def get_vram_gb() -> float:
    """Memory currently allocated to tensors on the current CUDA device, in GiB.

    Based on ``torch.cuda.memory_allocated()``, so memory held in PyTorch's allocator
    cache but not occupied by tensors is not counted. Returns 0.0 when CUDA is unavailable.
    """
    bytes_per_gib = 1024**3
    return torch.cuda.memory_allocated() / bytes_per_gib if torch.cuda.is_available() else 0.0
|
|
|
|
| def save_checkpoint( |
| ckpt_dir: Path, |
| step: int, |
| model: CodeTransformerLM, |
| optimizer: torch.optim.Optimizer, |
| scaler: Optional[torch.cuda.amp.GradScaler], |
| best_val: float, |
| no_improve_evals: int, |
| config: Dict[str, Any], |
| ) -> Path: |
| ckpt_dir.mkdir(parents=True, exist_ok=True) |
| ckpt_path = ckpt_dir / f"step_{step}.pt" |
| payload = { |
| "step": step, |
| "model_state": model.state_dict(), |
| "optimizer_state": optimizer.state_dict(), |
| "scaler_state": scaler.state_dict() if scaler is not None else None, |
| "best_val": best_val, |
| "no_improve_evals": no_improve_evals, |
| "config": config, |
| } |
| torch.save(payload, ckpt_path) |
| latest = ckpt_dir / "latest.pt" |
| torch.save(payload, latest) |
| return ckpt_path |
|
|
|
|
def load_checkpoint(
    ckpt_path: Path,
    model: CodeTransformerLM,
    optimizer: torch.optim.Optimizer,
    scaler: Optional[torch.cuda.amp.GradScaler],
    device: torch.device,
) -> Tuple[int, float, int]:
    """Restore model, optimizer and (optionally) grad-scaler state from ``ckpt_path``.

    The payload is loaded with ``map_location=device`` and is expected to use the keys
    written by ``save_checkpoint`` in this file. Scaler state is applied only when both a
    scaler is passed and the checkpoint contains non-None scaler state.

    Returns:
        ``(step, best_val, no_improve_evals)``. Missing keys default to ``0``, ``1e9``
        and ``0`` respectively.
    """
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state["model_state"])
    optimizer.load_state_dict(state["optimizer_state"])

    scaler_state = state.get("scaler_state")
    if scaler_state is not None and scaler is not None:
        scaler.load_state_dict(scaler_state)

    return (
        int(state.get("step", 0)),
        float(state.get("best_val", 1e9)),
        int(state.get("no_improve_evals", 0)),
    )
|
|
|
|
@torch.no_grad()
def evaluate_loss(
    model: torch.nn.Module,
    val_loader: DataLoader,
    device: torch.device,
    use_fp16: bool,
    max_batches: int = 50,
) -> float:
    """Average validation loss over at most ``max_batches`` batches of ``val_loader``.

    ``model(input_ids=..., labels=...)`` must return a mapping with a scalar ``"loss"``.
    The result is the unweighted mean of those per-batch values; batches are not weighted
    by size or token count. fp16 autocast is used only when ``use_fp16`` is set and
    ``device`` is CUDA.

    The model is put in eval mode for the pass. Its previous train/eval mode is restored
    afterwards, also when the forward pass raises.

    Returns:
        The mean loss, or ``1e9`` if the loader yields no batches.
    """
    was_training = model.training
    model.eval()
    amp_enabled = use_fp16 and device.type == "cuda"
    losses = []
    try:
        for i, (input_ids, labels) in enumerate(val_loader):
            if i >= max_batches:
                break
            input_ids = input_ids.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            with torch.amp.autocast("cuda", enabled=amp_enabled, dtype=torch.float16):
                out = model(input_ids=input_ids, labels=labels)
            losses.append(float(out["loss"].item()))
    finally:
        model.train(was_training)
    if not losses:
        return 1e9
    return sum(losses) / len(losses)
|
|
|
|
def _require_positive_int(section: Dict[str, Any], key: str) -> int:
    """Return ``int(section[key])``; raise ValueError if it is below 1.

    Used for intervals/counts that the training loop takes a modulo by. A 0 there would
    otherwise surface as a ZeroDivisionError deep inside the loop.
    """
    value = int(section[key])
    if value < 1:
        raise ValueError(f"training.{key} must be >= 1, got {value}")
    return value


def _build_loaders(train_cfg: Dict[str, Any], data_cfg: Dict[str, Any]) -> Tuple[DataLoader, DataLoader]:
    """Build the (train, val) DataLoaders over the two splits of the tokenized JSONL file."""
    split_kwargs = {
        "path": str(data_cfg["tokenized_jsonl_path"]),
        "val_ratio": float(data_cfg.get("val_ratio", 0.02)),
        "split_seed": int(data_cfg.get("split_seed", 17)),
    }
    train_ds = TokenizedJsonlDataset(split="train", **split_kwargs)
    val_ds = TokenizedJsonlDataset(split="val", **split_kwargs)

    micro_batch_size = int(train_cfg["micro_batch_size"])
    # Padding id defaults to 0; set data.pad_token_id if the tokenizer pads with another id.
    collator = CausalCollator(
        pad_token_id=int(data_cfg.get("pad_token_id", 0)),
        max_seq_len=int(train_cfg["max_seq_len"]),
    )
    train_loader = DataLoader(
        train_ds,
        batch_size=micro_batch_size,
        shuffle=True,
        num_workers=int(data_cfg.get("num_workers", 0)),
        pin_memory=True,
        collate_fn=collator,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=micro_batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        collate_fn=collator,
    )
    return train_loader, val_loader


def _resolve_resume_path(resume_cfg: Dict[str, Any], out_dir: Path) -> Optional[Path]:
    """Return the checkpoint to resume from, or None to start fresh.

    ``resume_from`` is compared case-insensitively:
      * ``none`` (also the default) means start fresh.
      * ``latest`` means ``out_dir/latest.pt`` if present. A missing file just means
        there is nothing to resume yet.
      * Anything else is an explicit path and MUST exist. Silently starting fresh there
        would overwrite the checkpoints in ``out_dir``.
    """
    raw = resume_cfg.get("resume_from", "none")
    mode = str(raw).strip().lower()
    if mode == "none":
        return None
    if mode == "latest":
        latest = out_dir / "latest.pt"
        if latest.exists():
            return latest
        print(f"[resume] checkpoint not found, starting fresh: {latest}")
        return None
    explicit = Path(str(raw).strip())
    if not explicit.exists():
        raise FileNotFoundError(f"[resume] checkpoint not found: {explicit}")
    return explicit


def train() -> None:
    """Train ``CodeTransformerLM`` as configured by the YAML passed via ``--config``.

    Config sections read:
      * ``training``: optimizer, schedule, loop intervals, output_dir.
      * ``data``: tokenized dataset and loader options.
      * ``model.model_config_path``.
      * optional ``resume.resume_from``: ``none`` | ``latest`` | explicit checkpoint path.

    Loop semantics:
      * One "step" is one optimizer update, i.e. ``grad_accum_steps`` micro-batches.
      * The LR follows ``cosine_lr`` and is set right before each update.
      * The progress-bar loss is the mean micro-batch loss since the previous log.
      * Evaluation runs before the periodic checkpoint of the same step, so the saved
        ``best_val`` / ``no_improve_evals`` include that step's result.
      * Early stopping ends the run at the current step; the final checkpoint records
        the real step.
      * A non-positive ``grad_clip_norm`` disables gradient clipping.

    Raises:
        RuntimeError: no CUDA device, VRAM threshold exceeded, or empty training set.
        FileNotFoundError: an explicit ``resume_from`` path does not exist.
        ValueError: ``grad_accum_steps``/``log_every``/``eval_every``/``save_every`` < 1.
    """
    args = parse_args()
    cfg = load_yaml(Path(args.config))
    train_cfg = cfg["training"]
    data_cfg = cfg["data"]
    resume_cfg = cfg.get("resume") or {}

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type != "cuda":
        raise RuntimeError("CUDA GPU is required for this training setup.")

    # Read and validate loop settings before the expensive model/dataset setup.
    max_steps = int(train_cfg["max_steps"])
    grad_accum = _require_positive_int(train_cfg, "grad_accum_steps")
    log_every = _require_positive_int(train_cfg, "log_every")
    eval_every = _require_positive_int(train_cfg, "eval_every")
    save_every = _require_positive_int(train_cfg, "save_every")
    warmup_steps = int(train_cfg["warmup_steps"])
    min_lr_ratio = float(train_cfg["min_lr_ratio"])
    grad_clip = float(train_cfg["grad_clip_norm"])
    max_vram_gb = float(train_cfg.get("max_vram_gb", 7.0))
    patience = int(train_cfg.get("early_stopping_patience_evals", 20))
    min_delta = float(train_cfg.get("early_stopping_min_delta", 5e-4))
    base_lr = float(train_cfg["learning_rate"])

    model_cfg = load_model_config(Path(cfg["model"]["model_config_path"]))
    model_cfg.max_seq_len = int(train_cfg["max_seq_len"])
    model_cfg.gradient_checkpointing = bool(train_cfg.get("use_gradient_checkpointing", True))

    model = CodeTransformerLM(model_cfg)
    model.enable_gradient_checkpointing(model_cfg.gradient_checkpointing)
    model = model.to(device)

    use_fp16 = bool(train_cfg.get("use_fp16", True))
    amp_enabled = use_fp16 and device.type == "cuda"
    scaler = torch.amp.GradScaler("cuda", enabled=use_fp16)

    optimizer, optimizer_name = make_optimizer(model, train_cfg)
    train_loader, val_loader = _build_loaders(train_cfg, data_cfg)

    out_dir = Path(train_cfg["output_dir"])
    out_dir.mkdir(parents=True, exist_ok=True)

    global_step, best_val, no_improve = 0, 1e9, 0
    resume_path = _resolve_resume_path(resume_cfg, out_dir)
    if resume_path is not None:
        global_step, best_val, no_improve = load_checkpoint(
            ckpt_path=resume_path,
            model=model,
            optimizer=optimizer,
            scaler=scaler,
            device=device,
        )
        print(f"[resume] loaded checkpoint {resume_path} at step {global_step}")

    model.train()
    start_step = global_step
    start_time = time.time()
    # Loss is summed on-device; calling .item() per micro-batch would force a host/GPU
    # synchronisation on every micro-batch instead of once per log interval.
    window_loss = torch.zeros((), device=device)
    window_micro_batches = 0
    micro_step = 0  # micro-batches since start; never reset, drives accumulation boundaries
    last_saved: Optional[Tuple[int, Path]] = None
    stop_early = False

    pbar = tqdm(total=max_steps, initial=global_step, desc="train", dynamic_ncols=True)
    try:
        while global_step < max_steps and not stop_early:
            batches_this_epoch = 0
            for input_ids, labels in train_loader:
                batches_this_epoch += 1
                input_ids = input_ids.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

                with torch.amp.autocast("cuda", enabled=amp_enabled, dtype=torch.float16):
                    out = model(input_ids=input_ids, labels=labels)
                    loss = out["loss"] / grad_accum
                scaler.scale(loss).backward()

                window_loss += loss.detach().float() * grad_accum
                window_micro_batches += 1
                micro_step += 1
                if micro_step % grad_accum != 0:
                    continue  # still accumulating gradients for this step

                # ---- optimizer update ----
                current_lr = cosine_lr(base_lr, global_step, warmup_steps, max_steps, min_lr_ratio)
                set_optimizer_lr(optimizer, current_lr)
                scaler.unscale_(optimizer)  # clip on true (unscaled) gradients
                if grad_clip > 0:  # max_norm=0 would zero every gradient
                    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
                global_step += 1
                pbar.update(1)

                vram = get_vram_gb()
                if vram > max_vram_gb:
                    raise RuntimeError(
                        f"VRAM safety threshold exceeded: {vram:.2f} GB > {max_vram_gb:.2f} GB. "
                        "Reduce max_seq_len or grad_accum/micro_batch settings."
                    )

                if global_step % log_every == 0:
                    avg_loss = window_loss.item() / max(1, window_micro_batches)
                    window_loss.zero_()
                    window_micro_batches = 0
                    # Rate over this session only: after a resume, global_step also counts
                    # steps that took no time here and would make the ETA far too small.
                    sec_per_step = (time.time() - start_time) / max(1, global_step - start_step)
                    eta_sec = sec_per_step * max(0, max_steps - global_step)
                    pbar.set_postfix(
                        {
                            "loss": f"{avg_loss:.4f}",
                            "lr": f"{current_lr:.2e}",
                            "vram_gb": f"{vram:.2f}",
                            "eta_min": f"{eta_sec/60.0:.1f}",
                        }
                    )

                if global_step % eval_every == 0:
                    val_loss = evaluate_loss(model, val_loader, device, use_fp16=use_fp16)
                    tqdm.write(f"[eval] step={global_step} val_loss={val_loss:.4f} best={best_val:.4f}")
                    if val_loss < (best_val - min_delta):
                        best_val = val_loss
                        no_improve = 0
                    else:
                        no_improve += 1
                        stop_early = no_improve >= patience

                if global_step % save_every == 0:
                    saved_path = save_checkpoint(
                        ckpt_dir=out_dir,
                        step=global_step,
                        model=model,
                        optimizer=optimizer,
                        scaler=scaler,
                        best_val=best_val,
                        no_improve_evals=no_improve,
                        config=cfg,
                    )
                    last_saved = (global_step, saved_path)
                    tqdm.write(f"[checkpoint] saved {saved_path}")

                if stop_early:
                    tqdm.write(
                        f"[early_stop] no improvement for {no_improve} evals "
                        f"(patience={patience}). Stopping training."
                    )
                    break
                if global_step >= max_steps:
                    break

            if batches_this_epoch == 0:
                # Without this guard an empty loader would spin the while-loop forever.
                raise RuntimeError(
                    "Training DataLoader yielded no batches; check data.tokenized_jsonl_path "
                    "and data.val_ratio."
                )
    finally:
        pbar.close()

    if last_saved is not None and last_saved[0] == global_step:
        # State has not changed since that save; skip rewriting an identical checkpoint.
        final_ckpt = last_saved[1]
    else:
        final_ckpt = save_checkpoint(
            ckpt_dir=out_dir,
            step=global_step,
            model=model,
            optimizer=optimizer,
            scaler=scaler,
            best_val=best_val,
            no_improve_evals=no_improve,
            config=cfg,
        )
    print("Training completed.")
    print(f"Optimizer used: {optimizer_name}")
    print(f"Final checkpoint: {final_ckpt}")
|
|
|
|
def main() -> None:
    """Script entry point: run ``train()`` and turn any failure into exit status 1.

    On failure, the traceback, the exception type and message, and a short fix hint are
    written to stderr. KeyboardInterrupt and other non-``Exception`` errors propagate
    unchanged.
    """
    try:
        train()
    except Exception as exc:
        import traceback

        traceback.print_exc()
        print("Component 5 training failed.", file=sys.stderr)
        # Include the type: str() of e.g. a KeyError is only the missing key.
        print(f"What went wrong: {type(exc).__name__}: {exc}", file=sys.stderr)
        print(
            "Fix suggestion: lower max_seq_len, keep micro_batch_size=1, "
            "increase grad_accum_steps, and verify checkpoint/output paths.",
            file=sys.stderr,
        )
        raise SystemExit(1) from exc


if __name__ == "__main__":
    main()
|
|
|
|