"""Small SFT diagnostics for checkpoint quality and trainability. This script intentionally bypasses the full trainer so it can answer one narrow question quickly: can the checkpoint reduce response-only SFT loss on a tiny, fixed batch? """ from __future__ import annotations import argparse import json import math from pathlib import Path from typing import Any import torch from taoTrain.checkpointing.checkpoint import CheckpointManager from taoTrain.config import TrainingModeEnum, load_config from taoTrain.core import create_model from taoTrain.data.sft_utils import build_sft_sequence_tokens, parse_sft_record try: from taoTrain.data.sft_utils import build_response_only_next_token_labels except ImportError: def build_response_only_next_token_labels(input_ids: list[int], mask: list[int]) -> list[int]: labels = [token_id if mask_value else -100 for token_id, mask_value in zip(input_ids, mask)] return labels[1:] + [-100] from taoTrain.data.tokenizer import SentencePieceTokenizerWrapper from taoTrain.utils import set_seed def load_tokenizer(tokenizer_path: str): path = Path(tokenizer_path) if path.suffix == ".model": import sentencepiece as spm sp = spm.SentencePieceProcessor() sp.Load(str(path)) return SentencePieceTokenizerWrapper(sp) from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) if getattr(tokenizer, "pad_token", None) is None and getattr(tokenizer, "eos_token", None): tokenizer.pad_token = tokenizer.eos_token return tokenizer def read_jsonl_records(path: str, limit: int) -> list[dict[str, Any]]: records = [] with open(path, "r", encoding="utf-8") as handle: for line in handle: line = line.strip() if not line: continue records.append(json.loads(line)) if len(records) >= limit: break return records def build_batch(config, tokenizer, records: list[dict[str, Any]], device: torch.device) -> dict[str, torch.Tensor]: input_rows = [] attention_rows = [] label_rows = [] train_tokens = [] for record in records: turns, _ = parse_sft_record(record, config) if not turns: continue input_ids, attention_mask, mask = build_sft_sequence_tokens( turns=turns, tokenizer=tokenizer, user_token=getattr(config, "user_token", ""), assistant_token=getattr(config, "assistant_token", ""), max_seq_length=config.model.max_seq_length, ) labels = build_response_only_next_token_labels(input_ids, mask) input_rows.append(input_ids) attention_rows.append(attention_mask) label_rows.append(labels) train_tokens.append(sum(1 for value in labels if value != -100)) if not input_rows: raise ValueError("No valid SFT records found for the diagnostic batch") return { "input_ids": torch.tensor(input_rows, dtype=torch.long, device=device), "attention_mask": torch.tensor(attention_rows, dtype=torch.long, device=device), "labels": torch.tensor(label_rows, dtype=torch.long, device=device), "train_tokens": torch.tensor(train_tokens, dtype=torch.long), } @torch.no_grad() def score_batch(model, batch: dict[str, torch.Tensor], dtype: torch.dtype) -> float: model.eval() device_type = "cuda" if batch["input_ids"].is_cuda else "cpu" enabled = device_type == "cuda" and dtype in (torch.float16, torch.bfloat16) with torch.autocast(device_type=device_type, dtype=dtype, enabled=enabled): outputs = model( input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"], ) return float(outputs["loss"].detach().cpu()) def grad_l2_norm(parameters) -> float: total = 0.0 for parameter in parameters: if parameter.grad is None: continue grad = parameter.grad.detach() total += float(torch.sum(grad.float() * grad.float()).cpu()) return math.sqrt(total) def grad_summary(named_parameters, max_items: int = 12) -> dict[str, Any]: groups: dict[str, dict[str, Any]] = {} worst = [] nonfinite = [] for name, parameter in named_parameters: if parameter.grad is None: continue grad = parameter.grad.detach().float() finite = torch.isfinite(grad) finite_count = int(finite.sum().cpu()) numel = grad.numel() finite_abs_max = float(grad[finite].abs().max().cpu()) if finite_count else float("inf") has_nonfinite = finite_count != numel if has_nonfinite: nonfinite.append(name) if ".layers." in name: parts = name.split(".") try: idx = parts.index("layers") group = "layer_" + parts[idx + 1] except (ValueError, IndexError): group = "layers" else: group = name.split(".", 1)[0] entry = groups.setdefault(group, { "numel": 0, "finite": 0, "nonfinite_tensors": 0, "max_abs_grad": 0.0, }) entry["numel"] += numel entry["finite"] += finite_count entry["nonfinite_tensors"] += int(has_nonfinite) entry["max_abs_grad"] = max(entry["max_abs_grad"], finite_abs_max) worst.append((finite_abs_max, name)) worst.sort(reverse=True, key=lambda item: item[0]) return { "groups": groups, "worst_tensors": [{"name": name, "max_abs_grad": value} for value, name in worst[:max_items]], "nonfinite_tensors": nonfinite[:max_items], "nonfinite_tensor_count": len(nonfinite), } def freeze_ssm_core_parameters(model) -> int: frozen = 0 markers = ( ".ssm_lanes.", ".ssm.", ) for name, parameter in model.named_parameters(): if any(marker in name for marker in markers): parameter.requires_grad_(False) frozen += parameter.numel() return frozen def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--config", required=True) parser.add_argument("--checkpoint", required=True) parser.add_argument("--output", required=True) parser.add_argument("--samples", type=int, default=2) parser.add_argument("--steps", type=int, default=80) parser.add_argument("--lr", type=float, default=3e-4) parser.add_argument("--log-every", type=int, default=10) parser.add_argument("--device", default="cuda") parser.add_argument("--dtype", choices=["config", "float32", "float16", "bfloat16"], default="config") parser.add_argument("--no-clip", action="store_true") parser.add_argument("--freeze-ssm-core", action="store_true") parser.add_argument("--ssm-branch-rms-norm", action="store_true") parser.add_argument("--ssm-branch-clip-value", type=float, default=None) parser.add_argument("--block-residual-rms-norm", action="store_true") parser.add_argument("--block-residual-rms-target", type=float, default=None) parser.add_argument("--seed", type=int, default=123) args = parser.parse_args() set_seed(args.seed) config = load_config(args.config, TrainingModeEnum.SFT) if args.ssm_branch_rms_norm: config.model.ssm_branch_rms_norm = True if args.ssm_branch_clip_value is not None: config.model.ssm_branch_clip_value = args.ssm_branch_clip_value if args.block_residual_rms_norm: config.model.block_residual_rms_norm = True if args.block_residual_rms_target is not None: config.model.block_residual_rms_target = args.block_residual_rms_target device = torch.device(args.device if args.device == "cpu" or torch.cuda.is_available() else "cpu") if args.dtype == "float32": dtype = torch.float32 elif args.dtype == "float16": dtype = torch.float16 elif args.dtype == "bfloat16": dtype = torch.bfloat16 else: dtype = torch.bfloat16 if str(config.dtype) == "DataTypeEnum.BFLOAT16" or str(config.dtype) == "bfloat16" else torch.float32 tokenizer = load_tokenizer(config.dataset.tokenizer_path) records = read_jsonl_records(config.dataset.jsonl_path, args.samples) batch = build_batch(config, tokenizer, records, device) model = create_model(config, device) checkpoint = CheckpointManager(config.checkpoint_dir).load(args.checkpoint, device=device) model.load_state_dict(checkpoint["model_state"], strict=False) frozen_params = freeze_ssm_core_parameters(model) if args.freeze_ssm_core else 0 initial_loss = score_batch(model, batch, dtype) trainable_params = [parameter for parameter in model.parameters() if parameter.requires_grad] optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, weight_decay=0.0) history = [] device_type = "cuda" if device.type == "cuda" else "cpu" autocast_enabled = device_type == "cuda" and dtype in (torch.float16, torch.bfloat16) model.train() for step in range(1, args.steps + 1): optimizer.zero_grad(set_to_none=True) with torch.autocast(device_type=device_type, dtype=dtype, enabled=autocast_enabled): outputs = model( input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"], ) loss = outputs["loss"] loss.backward() grad_norm = grad_l2_norm(trainable_params) stats = None if step == 1 or step % args.log_every == 0 or step == args.steps: stats = grad_summary(model.named_parameters()) if not args.no_clip: torch.nn.utils.clip_grad_norm_(trainable_params, 1.0) optimizer.step() if step == 1 or step % args.log_every == 0 or step == args.steps: item = { "step": step, "loss": float(loss.detach().cpu()), "grad_l2_norm": grad_norm, } if stats is not None: item["grad_summary"] = stats history.append(item) final_loss = score_batch(model, batch, dtype) result = { "checkpoint": str(Path(args.checkpoint)), "config": str(Path(args.config)), "dataset": config.dataset.jsonl_path, "samples": len(records), "sequence_length": config.model.max_seq_length, "train_tokens_per_sample": batch["train_tokens"].tolist(), "lr": args.lr, "steps": args.steps, "clip_grad_norm": not args.no_clip, "freeze_ssm_core": args.freeze_ssm_core, "ssm_branch_rms_norm": config.model.ssm_branch_rms_norm, "ssm_branch_clip_value": config.model.ssm_branch_clip_value, "block_residual_rms_norm": config.model.block_residual_rms_norm, "block_residual_rms_target": config.model.block_residual_rms_target, "frozen_params": frozen_params, "trainable_params": sum(parameter.numel() for parameter in trainable_params), "initial_loss": initial_loss, "final_loss": final_loss, "loss_delta": final_loss - initial_loss, "history": history, "device": str(device), "dtype": str(dtype), } output = Path(args.output) output.parent.mkdir(parents=True, exist_ok=True) output.write_text(json.dumps(result, indent=2), encoding="utf-8") print(json.dumps(result, indent=2)) if __name__ == "__main__": main()