| """HYDRA SFT — instruction fine-tune the pretrained 7.5M-param base. | |
| Mode selection: | |
| MODE=resume_from_pretrain iff ~/.cache/autoresearch/pretrain_final.pt | |
| exists AND loads cleanly into a fresh model. | |
| MODE=from_scratch otherwise (degraded fallback). | |
| Data: int16 shards written by `scripts/download_sft_data.py`, paired with | |
| uint8 loss-mask shards (1 on assistant tokens, 0 on user-prompt tokens). | |
| At runtime we pack consecutive examples into fixed-length rows; prompt | |
| positions get target=-1 so CE's `ignore_index=-1` drops them. | |
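
Illustration (the exact per-token mask over special tokens is decided by
`scripts/download_sft_data.py`; this shows the intended shape): for a turn
rendered as "<|user|>\nHi\n<|assistant|>\nHello<|end|>", the mask is 0
across the user span and 1 across "Hello<|end|>", so loss lands only on
the assistant's reply.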

Env vars (with defaults tuned for RTX 3060 6GB, 7.5M params):
    HYDRA_SFT_TIME_BUDGET    10800  SFT wall-clock budget (3h)
    HYDRA_SFT_SEQ_LEN        512    sequence length during SFT
    HYDRA_BATCH_SIZE         4      per-step device batch
    HYDRA_TOTAL_BATCH        8192   effective batch (grad-accum derived)
    HYDRA_SFT_LR_MULT        0.10   multiply pretrain LRs by this
    HYDRA_SFT_EVAL_INTERVAL  500    steps between sample generations
    HYDRA_SFT_CKPT_INTERVAL  2000   steps between interim checkpoints

CLI:
    --dry-run    load model+data, run 1 step, exit (validation path)
    --eval-only  load `sft_final.pt`, run sample gen, exit
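
Example invocations (assuming this file lives at scripts/sft.py):
    python scripts/sft.py --dry-run
    HYDRA_SFT_TIME_BUDGET=600 python scripts/sft.py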
| """ | |

from __future__ import annotations

import argparse
import json
import math
import os
import sys
import time
from dataclasses import asdict
from pathlib import Path

import numpy as np
import torch

# Repo root on path
_REPO_ROOT = Path(__file__).resolve().parent.parent
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))

# Must import hydra.config BEFORE touching torch.cuda for CUDA env setup
from hydra.config import (
    ADAM_BETAS, D_MODEL, D_STATE, DEVICE_BATCH_SIZE, EMBEDDING_LR,
    ENGRAM_KEY_DIM, ENGRAM_LAYER_IDX, ENGRAM_N_COLUMNS, EXPAND,
    FINAL_LR_FRAC, HEADDIM, MATRIX_LR, N_HEADS, N_LAYER,
    PostSemClawConfig, SCALAR_LR, SEED, TOTAL_BATCH_SIZE,
    UNEMBEDDING_LR, WEIGHT_DECAY,
)
from hydra.model import PostSemClawModel
from prepare import Tokenizer

# Use line-buffered stdout
try:
    sys.stdout.reconfigure(line_buffering=True)
except Exception:
    pass

# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
CACHE_DIR = Path.home() / ".cache" / "autoresearch"
PRETRAIN_CKPT = CACHE_DIR / "pretrain_final.pt"
SFT_FINAL_CKPT = CACHE_DIR / "sft_final.pt"
SFT_INTERIM_CKPT = CACHE_DIR / "sft_interim.pt"
SFT_DATA_DIR = _REPO_ROOT / "data" / "sft"

# ---------------------------------------------------------------------------
# Env vars for SFT
# ---------------------------------------------------------------------------
SFT_TIME_BUDGET = int(os.environ.get("HYDRA_SFT_TIME_BUDGET", "10800"))
SFT_SEQ_LEN = int(os.environ.get("HYDRA_SFT_SEQ_LEN", "512"))
SFT_LR_MULT = float(os.environ.get("HYDRA_SFT_LR_MULT", "0.10"))
SFT_EVAL_INTERVAL = int(os.environ.get("HYDRA_SFT_EVAL_INTERVAL", "500"))
SFT_CKPT_INTERVAL = int(os.environ.get("HYDRA_SFT_CKPT_INTERVAL", "2000"))


# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------
def _load_meta() -> dict:
    meta_path = SFT_DATA_DIR / "meta.json"
    if not meta_path.exists():
        raise FileNotFoundError(
            f"SFT meta not found at {meta_path}. Run "
            f"`python scripts/download_sft_data.py` first."
        )
    with open(meta_path) as f:
        return json.load(f)


def _load_shards():
    """Load all shard_XXX.bin + mask_XXX.bin as big flat arrays.

    Returns: (tokens: np.int64, mask: np.uint8)
    Both arrays are 1-D and the same length. Total len ~= target_tokens.
    """
    tok_shards = sorted(SFT_DATA_DIR.glob("shard_*.bin"))
    mask_shards = sorted(SFT_DATA_DIR.glob("mask_*.bin"))
    if not tok_shards:
        raise FileNotFoundError(f"No SFT shards in {SFT_DATA_DIR}")
    assert len(tok_shards) == len(mask_shards), (
        f"shard/mask count mismatch: {len(tok_shards)} vs {len(mask_shards)}"
    )
    tok_parts = []
    mask_parts = []
    for t, m in zip(tok_shards, mask_shards):
        tok_parts.append(np.fromfile(str(t), dtype=np.int16).astype(np.int64))
        mask_parts.append(np.fromfile(str(m), dtype=np.uint8))
    tokens = np.concatenate(tok_parts)
    mask = np.concatenate(mask_parts)
    assert tokens.shape == mask.shape
    # Guard against negative int16 values (unlikely with vocab=8192 but defensive)
    if tokens.min() < 0 or tokens.max() >= 8192:
        raise ValueError(
            f"Token IDs out of range: min={tokens.min()} max={tokens.max()}"
        )
    return tokens, mask


def make_sft_dataloader(tokens: np.ndarray, mask: np.ndarray, B: int, T: int,
                        device: torch.device, seed: int = 0):
    """Yield (x, y, epoch) forever.

    Each row is a slice of length T+1 sampled at a random start. We produce:
        x = slice[:-1]                  (B, T) int64 on device
        y = slice[1:] with mask=0 -> -1 (B, T) int64 on device
    The mask applies to target positions (y), not inputs. This way the
    chunked CE loss in model.forward sees ignore_index=-1 for prompt tokens.
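
    Worked example (window length T+1 = 5; u* prompt tokens, a* assistant):
        window_tok  = [u1, u2, a1, a2, end]   window_mask = [0, 0, 1, 1, 1]
        x = [u1, u2, a1, a2]
        y = [u2, a1, a2, end], masked by window_mask[1:] -> [-1, a1, a2, end]
    so the model learns to predict a1 from the prompt but never u2 from u1.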
| """ | |
    N = tokens.shape[0]
    rng = np.random.default_rng(seed)
    # Pin CPU tensors; copy to GPU non-blocking.
    cpu_x = torch.empty(B, T, dtype=torch.long, pin_memory=True)
    cpu_y = torch.empty(B, T, dtype=torch.long, pin_memory=True)
    epoch = 1
    samples_drawn = 0
    samples_per_epoch = max(1, N // (T + 1))
    # Minimum loss-positions per window. If a sampled window has fewer than
    # this many assistant tokens, resample. Guards against all-prompt windows
    # producing NaN from 0/0 in the chunked CE loss.
    min_loss_positions = max(1, T // 32)
    max_resample = 8
    while True:
        for b in range(B):
            # Sample a starting index with a light rejection filter to ensure
            # the window contains enough assistant (mask=1) positions.
            if N <= T + 1:
                start = 0
            else:
                start = int(rng.integers(0, N - T - 1))
                for _ in range(max_resample):
                    loss_in_window = int(mask[start + 1:start + T + 1].sum())
                    if loss_in_window >= min_loss_positions:
                        break
                    start = int(rng.integers(0, N - T - 1))
            window_tok = tokens[start:start + T + 1]
            window_mask = mask[start:start + T + 1]
            # x = window[:-1], y = window[1:]
            cpu_x[b].copy_(torch.from_numpy(window_tok[:-1].astype(np.int64)))
            y_slice = window_tok[1:].astype(np.int64).copy()
            # Apply mask to targets: mask=0 (prompt) -> target=-1 (ignore)
            y_slice[window_mask[1:] == 0] = -1
            # Final guard: if no loss positions survived, force at least 1
            # valid target so the batch doesn't produce NaN (it's rare with
            # the rejection filter but defensive is cheap).
            if (y_slice != -1).sum() == 0:
                y_slice[-1] = int(window_tok[-1])
            cpu_y[b].copy_(torch.from_numpy(y_slice))
        x = cpu_x.to(device, non_blocking=True)
        y = cpu_y.to(device, non_blocking=True)
        samples_drawn += B
        if samples_drawn >= samples_per_epoch:
            epoch += 1
            samples_drawn = 0
        yield x, y, epoch


# ---------------------------------------------------------------------------
# Model init + checkpoint load
# ---------------------------------------------------------------------------
def _peek_pretrain_config(vocab_size: int) -> PostSemClawConfig | None:
    """If the pretrain checkpoint exists, return its saved config so we build
    the SFT model with matching architecture. Returns None if unavailable."""
    if not PRETRAIN_CKPT.exists():
        return None
    try:
        ckpt = torch.load(str(PRETRAIN_CKPT), map_location="cpu",
                          weights_only=False)
        cfg_dict = ckpt.get("config")
        if cfg_dict is None:
            return None
        # Override sequence_len to SFT's (shorter context) — architecture
        # is independent of sequence_len since Mamba3 is recurrent.
        cfg_dict = dict(cfg_dict)
        cfg_dict["sequence_len"] = SFT_SEQ_LEN
        cfg_dict["vocab_size"] = vocab_size
        cfg = PostSemClawConfig(**cfg_dict)
        return cfg
    except Exception as e:
        print(f"[model] could not peek pretrain config: {type(e).__name__}: {e}",
              flush=True)
        return None


def build_model(vocab_size: int, device: torch.device) -> PostSemClawModel:
    # Prefer checkpoint-derived config if available (guards against env-var drift)
    config = _peek_pretrain_config(vocab_size)
    if config is None:
        config = PostSemClawConfig(
            sequence_len=SFT_SEQ_LEN,
            vocab_size=vocab_size,
            n_layer=N_LAYER,
            d_model=D_MODEL,
            d_state=D_STATE,
            headdim=HEADDIM,
            n_heads=N_HEADS,
            expand=EXPAND,
            engram_n_columns=ENGRAM_N_COLUMNS,
            engram_key_dim=ENGRAM_KEY_DIM,
            engram_layer_idx=ENGRAM_LAYER_IDX,
        )
        print(f"[model] config (from env, no ckpt): {asdict(config)}", flush=True)
    else:
        print(f"[model] config (from pretrain ckpt): {asdict(config)}", flush=True)
| with torch.device("meta"): | |
| model = PostSemClawModel(config) | |
| model.to_empty(device=device) | |
| model.init_weights() | |
| return model | |


def try_load_pretrain(model: PostSemClawModel) -> tuple[bool, str]:
    """Attempt to load pretrain checkpoint into model. Returns (loaded, msg)."""
    if not PRETRAIN_CKPT.exists():
        return False, f"no checkpoint at {PRETRAIN_CKPT}"
    try:
        ckpt = torch.load(str(PRETRAIN_CKPT), map_location="cuda",
                          weights_only=False)
        state = ckpt.get("model_state_dict", ckpt)
        # Use strict=False in case SDR/HTM params are excluded from state_dict
        # by torch.compile wrappers or similar.
        missing, unexpected = model.load_state_dict(state, strict=False)
        msg = (f"loaded {PRETRAIN_CKPT} — missing={len(missing)} "
               f"unexpected={len(unexpected)}")
        if missing:
            # Log first few missing keys to help diagnose architecture skew
            msg += f" first_missing={missing[:3]}"
        return True, msg
    except Exception as e:
        return False, f"load failed: {type(e).__name__}: {e}"


# ---------------------------------------------------------------------------
# Sample generation (for in-training eval prints)
# ---------------------------------------------------------------------------
_SAMPLE_PROMPTS = [
    "What is the capital of France?",
    "Write a haiku about winter.",
    "List three colors.",
    "How are you?",
    "Explain why the sky is blue in one sentence.",
]


@torch.no_grad()  # generation needs no autograd graph; saves VRAM at eval time
def sample_once(model, tokenizer, meta: dict, prompt: str,
                max_new: int = 64, temperature: float = 0.8,
                top_k: int = 40) -> str:
    """Generate a chat-formatted reply. Stops on <|end|> or max_new tokens."""
    bos = meta["special_tokens"]["bos"]
    user = meta["special_tokens"]["user"]
    assistant = meta["special_tokens"]["assistant"]
    end = meta["special_tokens"]["end"]
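    # Rendered prompt: <|bos|><|user|>\n{prompt}\n<|assistant|>\n. Generation
    # begins exactly where the assistant reply would start, which should
    # mirror the chat layout the SFT shards were rendered with.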
    prompt_ids = [bos, user] + tokenizer.encode("\n" + prompt.strip())
    prompt_ids += tokenizer.encode("\n")
    prompt_ids.append(assistant)
    prompt_ids += tokenizer.encode("\n")
    ctx = torch.tensor([prompt_ids], device="cuda", dtype=torch.long)
    generated: list[int] = []
    for _ in range(max_new):
        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits = model(ctx, targets=None)
        last = logits[0, -1].float()
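        # Top-k filter: take the value of the k-th largest logit and push
        # everything below it to -1e9, so softmax assigns those tokens
        # effectively zero probability mass.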
        if top_k and top_k < last.shape[-1]:
            kth = torch.topk(last, top_k).values[-1]
            last = torch.where(last < kth, torch.full_like(last, -1e9), last)
        probs = torch.softmax(last / max(temperature, 1e-6), dim=-1)
        next_id = int(torch.multinomial(probs, num_samples=1).item())
        generated.append(next_id)
        if next_id == end:
            break
        ctx = torch.cat(
            [ctx, torch.tensor([[next_id]], device="cuda", dtype=torch.long)],
            dim=1,
        )
        # Hard cap on ctx length (model was trained at 2048, SFT at 512,
        # but inference could theoretically go longer)
        if ctx.size(1) >= 2048:
            break
    try:
        text = tokenizer.decode(generated)
    except Exception:
        text = "<decode error>"
    return text


def run_samples(model, tokenizer, meta: dict, step: int):
    model.eval()
    print(f"\n=== SFT samples @ step {step} ===", flush=True)
    for p in _SAMPLE_PROMPTS:
        try:
            resp = sample_once(model, tokenizer, meta, p)
        except Exception as e:
            resp = f"<sample failed: {type(e).__name__}: {e}>"
        # Sanitize newlines for log readability
        resp_clean = resp.replace("\n", " ⏎ ").replace("\r", " ")
        print(f" prompt: {p!r}")
        print(f" reply: {resp_clean!r}")
    print("=== end samples ===\n", flush=True)
    model.train()


# ---------------------------------------------------------------------------
# Checkpoint save
# ---------------------------------------------------------------------------
def save_ckpt(model, step: int, smoothed_loss: float, path: Path,
              mode: str, meta: dict):
    try:
        CACHE_DIR.mkdir(parents=True, exist_ok=True)
        payload = {
            "model_state_dict": model.state_dict(),
            "step": step,
            "smoothed_loss": smoothed_loss,
            "mode": mode,
            "sft_meta": meta,
        }
        torch.save(payload, str(path))
        print(f"[ckpt] saved {path} (step={step})", flush=True)
    except Exception as e:
        print(f"[ckpt] SAVE FAILED {path}: {type(e).__name__}: {e}", flush=True)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--dry-run", action="store_true",
                    help="Load model+data, run 1 step, exit.")
    ap.add_argument("--eval-only", action="store_true",
                    help="Load sft_final.pt and run sample gen.")
    args = ap.parse_args()
    t_start = time.time()
    torch.manual_seed(SEED + 1)  # +1 so SFT draws different RNG than pretrain
    torch.cuda.manual_seed(SEED + 1)
    torch.set_float32_matmul_precision("high")
    device = torch.device("cuda")
    autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
    # --- Tokenizer ---
    tokenizer = Tokenizer.from_directory()
    vocab_size = tokenizer.get_vocab_size()
    print(f"[init] vocab: {vocab_size}", flush=True)
    # --- Data meta ---
    meta = _load_meta()
    print(f"[data] meta: {meta}", flush=True)
    # --- Model ---
    model = build_model(vocab_size, device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"[model] params: {n_params:,}", flush=True)
    loaded, msg = try_load_pretrain(model)
    mode = "resume_from_pretrain" if loaded else "from_scratch"
    print(f"[init] MODE={mode} :: {msg}", flush=True)
    # --- Eval-only path ---
    if args.eval_only:
        if SFT_FINAL_CKPT.exists():
            ckpt = torch.load(str(SFT_FINAL_CKPT), map_location=device,
                              weights_only=False)
            state = ckpt.get("model_state_dict", ckpt)
            model.load_state_dict(state, strict=False)
            print(f"[eval-only] loaded {SFT_FINAL_CKPT}", flush=True)
        else:
            print("[eval-only] no SFT checkpoint — running on current weights",
                  flush=True)
        run_samples(model, tokenizer, meta, step=-1)
        return
    # --- Dataloader ---
    print("[data] loading shards ...", flush=True)
    tokens, mask = _load_shards()
    print(f"[data] tokens: {len(tokens):,} loss-positions: {int(mask.sum()):,}",
          flush=True)
    B = DEVICE_BATCH_SIZE
    T = SFT_SEQ_LEN
    tokens_per_fwdbwd = B * T
    assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0, (
        f"TOTAL_BATCH_SIZE={TOTAL_BATCH_SIZE} not divisible by B*T={tokens_per_fwdbwd}"
    )
    grad_accum = TOTAL_BATCH_SIZE // tokens_per_fwdbwd
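    # With the documented defaults (B=4, T=512, TOTAL_BATCH=8192): each
    # micro-batch covers 4*512 = 2,048 tokens, so grad_accum = 8192/2048 = 4
    # forward/backward passes per optimizer step.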
| print(f"[train] B={B} T={T} accum={grad_accum} effective_batch={TOTAL_BATCH_SIZE}", | |
| flush=True) | |
| loader = make_sft_dataloader(tokens, mask, B, T, device, seed=SEED + 7) | |
| x, y, epoch = next(loader) | |
    # --- Optimizer (scaled LRs) ---
    matrix_lr = MATRIX_LR * SFT_LR_MULT
    embed_lr = EMBEDDING_LR * SFT_LR_MULT
    unembed_lr = UNEMBEDDING_LR * SFT_LR_MULT
    scalar_lr = SCALAR_LR * SFT_LR_MULT
    print(f"[opt] LRs scaled by {SFT_LR_MULT}: matrix={matrix_lr:.5f} "
          f"embed={embed_lr:.5f} unembed={unembed_lr:.6f}", flush=True)
    optimizer = model.setup_optimizer(
        unembedding_lr=unembed_lr,
        embedding_lr=embed_lr,
        scalar_lr=scalar_lr,
        adam_betas=ADAM_BETAS,
        matrix_lr=matrix_lr,
        weight_decay=WEIGHT_DECAY,
    )
    # --- Dry-run path (validation) ---
    if args.dry_run:
        print("[dry-run] running 1 step ...", flush=True)
        with autocast_ctx:
            loss = model(x, y)
        loss_f = float(loss.item())
        print(f"[dry-run] step0 loss={loss_f:.4f}", flush=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        model.zero_grad(set_to_none=True)
        if math.isnan(loss_f) or loss_f > 100:
            print("[dry-run] FAILED (NaN / huge loss)", flush=True)
            sys.exit(1)
        print("[dry-run] OK", flush=True)
        return
    # --- Training loop ---
    print(f"[train] budget={SFT_TIME_BUDGET}s eval_every={SFT_EVAL_INTERVAL} "
          f"ckpt_every={SFT_CKPT_INTERVAL}", flush=True)
    smooth_loss = 0.0
    step = 0
    total_train_secs = 0.0
    # Warmup schedule for SFT: linear 0->1 over first 5% of budget, then cosine.
    sft_warmup_frac = 0.05

    def lr_mult(progress: float) -> float:
        if progress < sft_warmup_frac:
            return progress / sft_warmup_frac if sft_warmup_frac > 0 else 1.0
        decay = (progress - sft_warmup_frac) / (1.0 - sft_warmup_frac)
        return FINAL_LR_FRAC + 0.5 * (1.0 - FINAL_LR_FRAC) * \
            (1 + math.cos(math.pi * decay))
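    # Shape of the schedule: mult rises linearly to 1.0 at progress=0.05, then
    # cosine-decays to FINAL_LR_FRAC at progress=1.0. For instance, if
    # FINAL_LR_FRAC were 0.1, progress=0.525 (cosine midpoint) gives mult=0.55.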

    while True:
        torch.cuda.synchronize()
        t0 = time.time()
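        # One optimizer step = grad_accum micro-batches: each contributes
        # loss/grad_accum to the accumulated gradients, then we clip and step.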
        for _ in range(grad_accum):
            with autocast_ctx:
                loss = model(x, y)
            train_loss_val = loss.detach()
            (loss / grad_accum).backward()
            x, y, epoch = next(loader)
        progress = min(total_train_secs / SFT_TIME_BUDGET, 1.0)
        mult = lr_mult(progress)
        for group in optimizer.param_groups:
            group["lr"] = group["initial_lr"] * mult
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        model.zero_grad(set_to_none=True)
        loss_f = float(train_loss_val.item())
        if math.isnan(loss_f) or loss_f > 100:
            print(f"[FAIL] step={step} loss={loss_f} — aborting", flush=True)
            save_ckpt(model, step, smooth_loss, SFT_INTERIM_CKPT, mode, meta)
            sys.exit(1)
        torch.cuda.synchronize()
        dt = time.time() - t0
        if step > 3:
            total_train_secs += dt
        # EMA loss (debiased)
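        # Bias correction as in Adam: at step 0, smooth = 0.1 * loss, and
        # dividing by (1 - 0.9**1) recovers the raw loss exactly; the
        # correction fades as beta**(step+1) approaches 0.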
        beta = 0.9
        smooth_loss = beta * smooth_loss + (1 - beta) * loss_f
        debiased = smooth_loss / (1 - beta ** (step + 1))
        bpt = debiased / math.log(2)
        tps = int(TOTAL_BATCH_SIZE / dt) if dt > 0 else 0
        vram_mib = torch.cuda.memory_allocated() / 1024 / 1024
        lr_now = optimizer.param_groups[0]["lr"]
        remaining = max(0, SFT_TIME_BUDGET - total_train_secs)
        print(
            f"sft_step={step:05d} loss={debiased:.4f} bpt={bpt:.3f} "
            f"tps={tps} dt_ms={dt*1000:.0f} lr={lr_now:.2e} "
            f"vram={vram_mib:.0f}MiB pct={100*progress:.1f} "
            f"epoch={epoch} remaining={remaining:.0f}s",
            flush=True,
        )
        if step > 0 and step % SFT_EVAL_INTERVAL == 0:
            run_samples(model, tokenizer, meta, step)
        if step > 0 and step % SFT_CKPT_INTERVAL == 0:
            save_ckpt(model, step, smooth_loss, SFT_INTERIM_CKPT, mode, meta)
        step += 1
        if step > 5 and total_train_secs >= SFT_TIME_BUDGET:
            break

    # Final samples + save
    run_samples(model, tokenizer, meta, step)
    save_ckpt(model, step, smooth_loss, SFT_FINAL_CKPT, mode, meta)
    total_secs = time.time() - t_start
    print("---", flush=True)
    print(f"SFT_COMPLETE mode={mode} step={step} "
          f"smoothed_loss={smooth_loss:.4f} total_seconds={total_secs:.0f} "
          f"train_seconds={total_train_secs:.0f}", flush=True)
| if __name__ == "__main__": | |
| try: | |
| main() | |
| except SystemExit: | |
| raise | |
| except Exception as e: | |
| import traceback | |
| print(f"SFT_FAILED {type(e).__name__}: {e}", flush=True) | |
| traceback.print_exc() | |
| sys.exit(1) | |