import argparse
import gc
import math
import os
import time
from pathlib import Path

import torch
import torch.nn as nn
import yaml
from huggingface_hub import hf_hub_download
from torch.amp import GradScaler, autocast

from sft_train import LUNAModel, SFTDataset, cosine_lr, probe_hardware, run_eval_prompts

SEP = "=" * 72


class LoRALinear(nn.Module):
    """Wrap a frozen ``nn.Linear`` with a trainable low-rank (LoRA) update.

    Output is ``base(x) + lora_b(lora_a(dropout(x))) * (alpha / rank)``.
    ``lora_b`` is zero-initialised, so a freshly wrapped layer returns exactly
    what the base layer returns until the adapter is trained.

    Args:
        base_layer: the ``nn.Linear`` to wrap; its parameters are frozen.
        rank: inner dimension of the low-rank update.
        alpha: LoRA scaling numerator (effective scale is ``alpha / rank``).
        dropout: dropout probability applied to the adapter input only.

    Raises:
        TypeError: if ``base_layer`` is not an ``nn.Linear``.
    """

    def __init__(self, base_layer, rank=16, alpha=32, dropout=0.05):
        super().__init__()
        if not isinstance(base_layer, nn.Linear):
            raise TypeError("LoRALinear expects a torch.nn.Linear base layer")
        self.base = base_layer
        self.rank = rank
        self.alpha = alpha
        # max(..., 1) avoids a ZeroDivisionError for a degenerate rank of 0.
        self.scale = alpha / max(rank, 1)
        self.dropout = nn.Dropout(dropout)
        self.lora_a = nn.Linear(base_layer.in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, base_layer.out_features, bias=False)
        # A uses nn.Linear's default init; B starts at zero so the update is initially 0.
        nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_b.weight)
        for parameter in self.base.parameters():
            parameter.requires_grad = False

    def forward(self, x):
        base_out = self.base(x)
        lora_out = self.lora_b(self.lora_a(self.dropout(x))) * self.scale
        return base_out + lora_out


def load_config(config_path):
    """Load the YAML training config and flatten it into a single dict.

    Top-level path/repo keys fall back to the defaults below when absent; the
    ``model``, ``train``, ``optimizer``, ``batch``, ``dataloader``,
    ``hardware`` and ``lora`` sections are required (missing keys raise
    ``KeyError``), except for the few read with ``.get``.
    """
    with open(config_path, encoding="utf-8") as handle:
        raw = yaml.safe_load(handle)
    cfg = {
        "auto_config": raw.get("auto_config", True),
        "hf_model_repo": raw.get("hf_model_repo", "ASTERIZER/LUNA-100M"),
        "hf_model_file": raw.get("hf_model_file", "sft_v1/final/model.pth"),
        "pretrained_ckpt": raw.get("pretrained_ckpt", "Base/out/input_models/luna_sft_v1/model.pth"),
        "train_json": raw.get("train_json", "Base/Datasets/rag_mcp_sft/train.json"),
        "val_json": raw.get("val_json", "Base/Datasets/rag_mcp_sft/val.json"),
        "out_dir": raw.get("out_dir", "Base/out/sft/rag_mcp_lora"),
        "tokenizer_dir": raw.get("tokenizer_dir", "Base/checkpoints/EleutherAI/pythia-160m"),
        "vocab_size": raw["model"]["vocab_size"],
        "seq_len": raw["model"]["seq_len"],
        "n_layer": raw["model"]["n_layer"],
        "n_embd": raw["model"]["n_embd"],
        "n_head": raw["model"]["n_head"],
        "epochs": raw["train"]["epochs"],
        "lr_warmup_steps": raw["train"]["lr_warmup_steps"],
        "save_interval": raw["train"]["save_interval"],
        "log_interval": raw["train"]["log_interval"],
        "eval_interval": raw["train"]["eval_interval"],
        "max_norm": raw["train"]["max_norm"],
        "lr": raw["optimizer"]["lr"],
        "min_lr": raw["optimizer"]["min_lr"],
        "weight_decay": raw["optimizer"]["weight_decay"],
        "betas": tuple(raw["optimizer"]["betas"]),
        "eps": raw["optimizer"]["eps"],
        "global_batch": raw["batch"]["global_batch"],
        "micro_batch": raw["batch"]["micro_batch"],
        "grad_accum": raw["batch"]["grad_accum"],
        "auto_probe_batch": raw["batch"].get("auto_probe_batch", True),
        "probe_safety": raw["batch"].get("probe_safety", 0.94),
        "num_workers": raw["dataloader"]["num_workers"],
        "pin_memory": raw["dataloader"]["pin_memory"],
        "precision": raw["hardware"]["precision"],
        "eval_prompts": raw.get("eval_prompts", []),
        "lora_rank": raw["lora"]["rank"],
        "lora_alpha": raw["lora"]["alpha"],
        "lora_dropout": raw["lora"]["dropout"],
        "target_modules": list(raw["lora"]["target_modules"]),
    }
    return cfg


def resolve_checkpoint(cfg):
    """Return a local path to the base checkpoint, downloading it if needed.

    Lookup order:
      1. ``cfg["pretrained_ckpt"]`` if that file exists.
      2. A previous download at ``<dir of pretrained_ckpt>/<hf_model_file>``.
      3. A fresh ``hf_hub_download`` of ``hf_model_file`` from
         ``hf_model_repo`` into the directory of ``pretrained_ckpt``
         (authenticated with ``$HF_TOKEN`` when set).

    Raises:
        FileNotFoundError: if the download did not produce the expected file.
    """
    ckpt_path = Path(cfg["pretrained_ckpt"])
    if ckpt_path.exists():
        return ckpt_path
    # With local_dir set, the Hub client mirrors the repo-relative path, so the
    # file lands at parent / hf_model_file rather than at ckpt_path. Reuse that
    # copy on reruns instead of going back to the network.
    previous_download = ckpt_path.parent / cfg["hf_model_file"]
    if previous_download.exists():
        return previous_download
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)
    downloaded = Path(
        hf_hub_download(
            repo_id=cfg["hf_model_repo"],
            filename=cfg["hf_model_file"],
            local_dir=str(ckpt_path.parent),
            token=os.environ.get("HF_TOKEN"),
        )
    )
    if not downloaded.exists():
        raise FileNotFoundError(f"Expected downloaded checkpoint at {downloaded}")
    return downloaded


def inject_lora(model, target_modules, rank, alpha, dropout):
    """Replace matching ``nn.Linear`` submodules of ``model`` with ``LoRALinear``.

    A module is wrapped when its qualified name ends with any string in
    ``target_modules``. The wrapper is moved to the base layer's device and
    dtype. Returns the list of qualified names that were replaced.

    Raises:
        RuntimeError: if no module matched.
    """
    replaced = []
    # Snapshot the module list first: the tree is mutated while we iterate.
    for module_name, module in list(model.named_modules()):
        if not isinstance(module, nn.Linear):
            continue
        if not any(module_name.endswith(target) for target in target_modules):
            continue
        parent_name, _, child_name = module_name.rpartition(".")
        parent_module = model.get_submodule(parent_name) if parent_name else model
        wrapped = LoRALinear(module, rank=rank, alpha=alpha, dropout=dropout)
        wrapped = wrapped.to(device=module.weight.device, dtype=module.weight.dtype)
        setattr(parent_module, child_name, wrapped)
        replaced.append(module_name)
    if not replaced:
        raise RuntimeError("No target modules matched for LoRA injection")
    return replaced


def get_lora_state_dict(model):
    """Return a CPU copy of only the LoRA A/B weight tensors of ``model``."""
    state_dict = model.state_dict()
    return {
        name: tensor.cpu()
        for name, tensor in state_dict.items()
        if "lora_a.weight" in name or "lora_b.weight" in name
    }


def count_trainable_parameters(model):
    """Return the number of scalar parameters with ``requires_grad=True``."""
    return sum(parameter.numel() for parameter in model.parameters() if parameter.requires_grad)


def _probe_step_fits(model, optimizer, device, dtype, batch_size, seq_len, vocab_size, grad_accum_sim):
    """Run one simulated accumulated optimizer step; return False on CUDA OOM.

    Kept as a separate function so that all tensors created by a failed trial
    (inputs, loss, autograd graph) are released when the frame exits, before
    the caller empties the CUDA cache. Non-OOM errors are re-raised.
    """
    try:
        optimizer.zero_grad(set_to_none=True)
        for _ in range(grad_accum_sim):
            input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
            loss_mask = torch.ones_like(input_ids)
            with autocast(device_type="cuda", dtype=dtype):
                _, loss = model(input_ids, targets=input_ids, loss_mask=loss_mask, return_logits=False)
            loss = loss / grad_accum_sim
            loss.backward()
            del input_ids, loss_mask, loss
        optimizer.step()
        return True
    except (torch.cuda.OutOfMemoryError, RuntimeError) as error:
        if "out of memory" not in str(error).lower() and not isinstance(error, torch.cuda.OutOfMemoryError):
            raise
        return False
    finally:
        optimizer.zero_grad(set_to_none=True)


def probe_max_micro_batch_lora(model, trainable_parameters, device, dtype, seq_len, vocab_size,
                               safety=0.94, grad_accum_sim=2):
    """Binary-search the largest micro-batch (1..512) that fits in VRAM.

    Each trial runs ``grad_accum_sim`` forward/backward passes on random
    tokens followed by an AdamW step, so optimizer-state memory is included.
    The trainable parameters are snapshotted beforehand and restored
    afterwards, so the probe leaves the adapter weights untouched (in
    particular, ``lora_b`` stays at its zero initialisation).

    Returns:
        ``max(1, int(best * safety))``; always 1 on non-CUDA devices.
    """
    if device.type != "cuda":
        return 1
    trainable_parameters = list(trainable_parameters)
    # The probe performs real optimizer steps on random data; keep a copy of
    # the weights so training starts from the intended initialisation.
    snapshot = [parameter.detach().clone() for parameter in trainable_parameters]
    optimizer = torch.optim.AdamW(trainable_parameters, lr=1e-4)
    lo, hi, best = 1, 512, 1
    try:
        while lo <= hi:
            mid = (lo + hi) // 2
            torch.cuda.empty_cache()
            gc.collect()
            if _probe_step_fits(model, optimizer, device, dtype, mid, seq_len, vocab_size, grad_accum_sim):
                best = mid
                lo = mid + 1
            else:
                torch.cuda.empty_cache()
                gc.collect()
                hi = mid - 1
    finally:
        optimizer.zero_grad(set_to_none=True)
        with torch.no_grad():
            for parameter, saved in zip(trainable_parameters, snapshot):
                parameter.copy_(saved)
        del optimizer, snapshot
        torch.cuda.empty_cache()
        gc.collect()
    safe = max(1, int(best * safety))
    print(f" LoRA batch probe: max_micro_batch={best}, using {safe} ({int(safety * 100)}% safety)")
    return safe


def load_base_weights(model, checkpoint_path, device):
    """Load base-model weights from ``checkpoint_path`` into ``model`` (strict).

    Accepts either a raw state dict or a dict wrapping it under ``"model"``.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    state_dict = checkpoint["model"] if isinstance(checkpoint, dict) and "model" in checkpoint else checkpoint
    model.load_state_dict(state_dict, strict=True)


def _evaluate_val_loss(model, val_loader, device, dtype, max_batches=50):
    """Return the mean loss over at most ``max_batches`` validation batches.

    Leaves the model in eval mode; the caller restores train mode.
    """
    model.eval()
    val_loss_sum = 0.0
    val_count = 0
    with torch.no_grad():
        for val_ids, val_mask in val_loader:
            val_ids = val_ids.to(device, non_blocking=True)
            val_mask = val_mask.to(device, non_blocking=True)
            with autocast(device_type=device.type, dtype=dtype, enabled=(device.type == "cuda")):
                _, val_loss = model(val_ids, targets=val_ids, loss_mask=val_mask, return_logits=False)
            val_loss_sum += val_loss.item()
            val_count += 1
            if val_count >= max_batches:
                break
    return val_loss_sum / max(val_count, 1)


def train(cfg):
    """Run LoRA supervised fine-tuning described by ``cfg``.

    Loads the base checkpoint and freezes it, then injects LoRA adapters into
    ``cfg["target_modules"]``. Optionally it probes the largest micro-batch
    that fits in VRAM, which overwrites ``cfg["micro_batch"]`` and
    ``cfg["grad_accum"]`` in place.

    Training uses gradient accumulation and the ``cosine_lr`` schedule.
    Outputs written under ``cfg["out_dir"]``:

    - ``step-XXXXXX/adapter_model.pt`` every ``save_interval`` steps.
    - ``latest.pt``, used to resume.
    - ``best_adapter_model.pt``, the adapter with the best validation loss.
    - ``final/adapter_model.pt`` and ``final/adapter_bundle.pt``.
    """
    hw = probe_hardware()
    device = torch.device(hw["device"])
    if cfg["auto_config"]:
        dtype = hw.get("dtype", torch.float32)
    else:
        dtype = {
            "bf16": torch.bfloat16,
            "fp16": torch.float16,
            "fp32": torch.float32,
        }.get(cfg["precision"], torch.float32)

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(cfg["tokenizer_dir"])
    ckpt_path = resolve_checkpoint(cfg)

    model = LUNAModel(
        vocab_size=cfg["vocab_size"],
        block_size=cfg["seq_len"],
        n_layer=cfg["n_layer"],
        n_embd=cfg["n_embd"],
        n_head=cfg["n_head"],
    ).to(device)
    load_base_weights(model, ckpt_path, device)
    # Freeze the whole base model; only the LoRA layers injected below train.
    for parameter in model.parameters():
        parameter.requires_grad = False
    replaced = inject_lora(
        model,
        target_modules=cfg["target_modules"],
        rank=cfg["lora_rank"],
        alpha=cfg["lora_alpha"],
        dropout=cfg["lora_dropout"],
    )
    trainable_params = count_trainable_parameters(model)
    total_params = sum(parameter.numel() for parameter in model.parameters())
    trainable_parameters = [parameter for parameter in model.parameters() if parameter.requires_grad]

    if cfg["auto_config"] and device.type == "cuda" and cfg["auto_probe_batch"]:
        print(" Probing LoRA micro_batch against available VRAM...")
        cfg["micro_batch"] = probe_max_micro_batch_lora(
            model,
            trainable_parameters=trainable_parameters,
            device=device,
            dtype=dtype,
            seq_len=cfg["seq_len"],
            vocab_size=cfg["vocab_size"],
            safety=cfg["probe_safety"],
        )
        cfg["grad_accum"] = max(1, math.ceil(cfg["global_batch"] / cfg["micro_batch"]))
        torch.cuda.reset_peak_memory_stats(device)
    effective_batch = cfg["micro_batch"] * cfg["grad_accum"]

    train_dataset = SFTDataset(cfg["train_json"], tokenizer, max_len=cfg["seq_len"])
    val_dataset = (
        SFTDataset(cfg["val_json"], tokenizer, max_len=cfg["seq_len"])
        if Path(cfg["val_json"]).exists()
        else None
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg["micro_batch"],
        shuffle=True,
        num_workers=cfg["num_workers"],
        pin_memory=cfg["pin_memory"],
        drop_last=True,
        prefetch_factor=4 if cfg["num_workers"] > 0 else None,
        persistent_workers=cfg["num_workers"] > 0,
    )
    val_loader = None
    if val_dataset is not None:
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=cfg["micro_batch"],
            shuffle=False,
            num_workers=min(2, cfg["num_workers"]),
            pin_memory=cfg["pin_memory"],
            drop_last=False,
        )

    optimizer = torch.optim.AdamW(
        trainable_parameters,
        lr=cfg["lr"],
        weight_decay=cfg["weight_decay"],
        betas=cfg["betas"],
        eps=cfg["eps"],
    )
    # Loss scaling is only needed for fp16 on CUDA; otherwise the scaler is a no-op.
    scaler = GradScaler(enabled=(device.type == "cuda" and dtype == torch.float16))
    steps_per_epoch = max(1, len(train_loader) // cfg["grad_accum"])
    total_steps = steps_per_epoch * cfg["epochs"]
    warmup_steps = min(cfg["lr_warmup_steps"], max(1, total_steps // 5))
    # Micro-batches per epoch that belong to a complete optimizer step; any
    # trailing partial group is dropped so its gradients cannot leak into the
    # first step of the next epoch.
    micro_per_epoch = steps_per_epoch * cfg["grad_accum"]

    out_dir = Path(cfg["out_dir"])
    out_dir.mkdir(parents=True, exist_ok=True)
    best_val_loss = float("inf")
    step = 0
    latest_path = out_dir / "latest.pt"
    if latest_path.exists():
        checkpoint = torch.load(latest_path, map_location=device, weights_only=True)
        # strict=False: the checkpoint holds only the LoRA tensors.
        model.load_state_dict(checkpoint["adapter"], strict=False)
        optimizer.load_state_dict(checkpoint["optimizer"])
        step = checkpoint["step"]

    print(SEP)
    print(" LUNA 100M - LoRA SFT")
    print(SEP)
    print(f" Base checkpoint : {ckpt_path}")
    print(f" Train dataset : {cfg['train_json']}")
    print(f" Val dataset : {cfg['val_json']}")
    print(f" Output dir : {out_dir}")
    print(f" Device : {hw['gpu_name']} ({hw['vram_gb']:.1f} GB)")
    print(f" Precision : {cfg['precision']} dtype={dtype}")
    print(f" LoRA modules : {', '.join(replaced)}")
    print(f" Trainable params: {trainable_params:,} / {total_params:,}")
    print(f" micro_batch : {cfg['micro_batch']}")
    print(f" grad_accum : {cfg['grad_accum']}")
    print(f" effective_batch : {effective_batch}")
    print(f" Train samples : {len(train_dataset):,}")
    print(f" Val samples : {len(val_dataset):,}" if val_dataset is not None else " Val samples : 0")
    print(SEP)

    if cfg["eval_prompts"] and step == 0:
        run_eval_prompts(model, tokenizer, cfg["eval_prompts"], device, 0, out_dir)
    model.train()
    run_t0 = time.perf_counter()

    for epoch in range(cfg["epochs"]):
        # Resume: an epoch whose optimizer steps are all done needs no data at all.
        if (epoch + 1) * steps_per_epoch <= step:
            continue
        optimizer.zero_grad(set_to_none=True)
        micro_step = 0
        step_start = time.perf_counter()
        for input_ids, loss_mask in train_loader:
            if micro_step >= micro_per_epoch:
                break
            current_global_step = epoch * steps_per_epoch + (micro_step // cfg["grad_accum"])
            if current_global_step >= total_steps:
                break
            # Resume: skip every micro-batch of already-completed optimizer steps,
            # without running backward (which would accumulate stale gradients).
            if current_global_step < step:
                micro_step += 1
                continue
            # Time the whole accumulation group so tok/s and ETA match effective_batch.
            if micro_step % cfg["grad_accum"] == 0:
                step_start = time.perf_counter()

            input_ids = input_ids.to(device, non_blocking=True)
            loss_mask = loss_mask.to(device, non_blocking=True)
            with autocast(device_type=device.type, dtype=dtype, enabled=(device.type == "cuda")):
                _, loss = model(input_ids, targets=input_ids, loss_mask=loss_mask, return_logits=False)
            loss = loss / cfg["grad_accum"]
            scaler.scale(loss).backward()
            micro_step += 1
            if micro_step % cfg["grad_accum"] != 0:
                continue

            # Unscale before clipping so max_norm applies to the true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(trainable_parameters, cfg["max_norm"])
            lr_now = cosine_lr(step, warmup_steps, total_steps, cfg["lr"], cfg["min_lr"])
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr_now
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            if device.type == "cuda":
                torch.cuda.synchronize()
            dt = time.perf_counter() - step_start
            step += 1

            if step % cfg["log_interval"] == 0 or step <= 3:
                tokens_step = effective_batch * cfg["seq_len"]
                tps = tokens_step / max(dt, 1e-6)
                vram = torch.cuda.max_memory_allocated() / 1024**3 if device.type == "cuda" else 0
                eta_h = (total_steps - step) * dt / 3600
                print(
                    f" step {step:6d}/{total_steps} | epoch {epoch + 1}/{cfg['epochs']} | "
                    f"loss {loss.item() * cfg['grad_accum']:.4f} | lr {lr_now:.2e} | "
                    f"{tps:,.0f} tok/s | VRAM {vram:.1f}GB | ETA {eta_h:.1f}h"
                )

            if step % cfg["save_interval"] == 0 or step == total_steps:
                step_dir = out_dir / f"step-{step:06d}"
                step_dir.mkdir(parents=True, exist_ok=True)
                adapter_state = get_lora_state_dict(model)
                torch.save(adapter_state, step_dir / "adapter_model.pt")
                torch.save(
                    {
                        "step": step,
                        "adapter": adapter_state,
                        "optimizer": optimizer.state_dict(),
                        "epoch": epoch,
                        "loss": loss.item() * cfg["grad_accum"],
                    },
                    latest_path,
                )
                print(f" Saved -> {step_dir}")

            if step % cfg["eval_interval"] == 0 or step == total_steps:
                if val_loader is not None:
                    avg_val = _evaluate_val_loss(model, val_loader, device, dtype)
                    print(f" Val loss: {avg_val:.4f}")
                    if avg_val < best_val_loss:
                        best_val_loss = avg_val
                        torch.save(get_lora_state_dict(model), out_dir / "best_adapter_model.pt")
                        print(" New best! Saved best_adapter_model.pt")
                if cfg["eval_prompts"]:
                    run_eval_prompts(model, tokenizer, cfg["eval_prompts"], device, step, out_dir)
                # The validation pass switches to eval mode, and run_eval_prompts
                # may too; restore train mode so LoRA dropout stays active.
                model.train()

    final_dir = out_dir / "final"
    final_dir.mkdir(parents=True, exist_ok=True)
    final_adapter = get_lora_state_dict(model)
    torch.save(final_adapter, final_dir / "adapter_model.pt")
    torch.save(
        {
            "step": step,
            "adapter": final_adapter,
            "lora_rank": cfg["lora_rank"],
            "lora_alpha": cfg["lora_alpha"],
            "lora_dropout": cfg["lora_dropout"],
            "target_modules": cfg["target_modules"],
            "base_checkpoint": str(ckpt_path),
        },
        final_dir / "adapter_bundle.pt",
    )

    total_h = (time.perf_counter() - run_t0) / 3600
    print(SEP)
    print(f" LoRA SFT complete in {total_h:.2f}h -> {final_dir}")
    print(f" Best val loss: {best_val_loss:.4f}")
    print(SEP)


def parse_args():
    """Parse CLI arguments; path/epoch flags override values from ``--config``."""
    parser = argparse.ArgumentParser(description="LUNA 100M - LoRA SFT")
    parser.add_argument("--config", default="rag_mcp_lora_config.yaml")
    parser.add_argument("--pretrained_ckpt", default=None)
    parser.add_argument("--train_json", default=None)
    parser.add_argument("--val_json", default=None)
    parser.add_argument("--out_dir", default=None)
    parser.add_argument("--epochs", type=int, default=None)
    return parser.parse_args()


def main():
    """Entry point: load config, apply CLI overrides, and train."""
    args = parse_args()
    cfg = load_config(args.config)
    # Empty or missing path overrides keep the config value.
    for key in ("pretrained_ckpt", "train_json", "val_json", "out_dir"):
        value = getattr(args, key)
        if value:
            cfg[key] = value
    if args.epochs is not None:
        cfg["epochs"] = args.epochs
    train(cfg)


if __name__ == "__main__":
    main()