# Source: TaoNet-mini-T2/code/TaoTrain/scripts/diagnostics/sft_sanity_check.py
# Uploaded by StarMist0012 ("Add files using upload-large-folder tool"), commit e2bfccc (verified).
"""Small SFT diagnostics for checkpoint quality and trainability.
This script intentionally bypasses the full trainer so it can answer one narrow
question quickly: can the checkpoint reduce response-only SFT loss on a tiny,
fixed batch?
"""
from __future__ import annotations
import argparse
import json
import math
from pathlib import Path
from typing import Any
import torch
from taoTrain.checkpointing.checkpoint import CheckpointManager
from taoTrain.config import TrainingModeEnum, load_config
from taoTrain.core import create_model
from taoTrain.data.sft_utils import build_sft_sequence_tokens, parse_sft_record
try:
    from taoTrain.data.sft_utils import build_response_only_next_token_labels
except ImportError:
    # Fallback used when the installed taoTrain does not export the helper.

    def build_response_only_next_token_labels(input_ids: list[int], mask: list[int]) -> list[int]:
        """Build next-token labels that supervise only masked-in positions.

        Position ``t`` of the result holds ``input_ids[t + 1]`` when
        ``mask[t + 1]`` is truthy and ``-100`` otherwise; the final position
        is always ``-100`` because it has no next token.

        Args:
            input_ids: Token ids of the full sequence.
            mask: Per-token flags; truthy marks a token that should be trained on.

        Returns:
            A label list with the same length as the paired inputs
            (``zip`` truncates to the shorter of the two). An empty input
            yields an empty list.
        """
        labels = [token_id if mask_value else -100 for token_id, mask_value in zip(input_ids, mask)]
        if not labels:
            # Without this guard the shift below would emit a stray [-100]
            # for an empty sequence, making labels longer than the input.
            return []
        return labels[1:] + [-100]
from taoTrain.data.tokenizer import SentencePieceTokenizerWrapper
from taoTrain.utils import set_seed
def load_tokenizer(tokenizer_path: str):
    """Load a tokenizer from a SentencePiece model file or a Hugging Face path.

    A path whose suffix is ``.model`` is loaded with ``sentencepiece`` and
    wrapped in ``SentencePieceTokenizerWrapper``. Anything else is handed to
    ``transformers.AutoTokenizer.from_pretrained``; if that tokenizer has no
    pad token but has an EOS token, the EOS token is reused as the pad token.

    Args:
        tokenizer_path: File path or identifier of the tokenizer.

    Returns:
        The wrapped SentencePiece tokenizer or the Hugging Face tokenizer.
    """
    model_file = Path(tokenizer_path)

    if model_file.suffix != ".model":
        # Imported lazily so the SentencePiece path does not need transformers.
        from transformers import AutoTokenizer

        hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        if getattr(hf_tokenizer, "pad_token", None) is None:
            if getattr(hf_tokenizer, "eos_token", None):
                hf_tokenizer.pad_token = hf_tokenizer.eos_token
        return hf_tokenizer

    # Imported lazily so the Hugging Face path does not need sentencepiece.
    import sentencepiece as spm

    processor = spm.SentencePieceProcessor()
    processor.Load(str(model_file))
    return SentencePieceTokenizerWrapper(processor)
def read_jsonl_records(path: str, limit: int) -> list[dict[str, Any]]:
    """Read up to ``limit`` JSON objects from the start of a JSONL file.

    Lines that are empty or whitespace-only are skipped and do not count
    toward the limit.

    Args:
        path: Path of the UTF-8 encoded JSONL file.
        limit: Maximum number of records to return. A value of zero or less
            yields an empty list (the file is still opened, so a missing
            file is reported either way).

    Returns:
        The parsed records, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records: list[dict[str, Any]] = []
    with open(path, "r", encoding="utf-8") as handle:
        for line in handle:
            # Check before appending so that limit <= 0 never returns a record.
            if len(records) >= limit:
                break
            line = line.strip()
            if not line:
                continue
            records.append(json.loads(line))
    return records
def build_batch(config, tokenizer, records: list[dict[str, Any]], device: torch.device) -> dict[str, torch.Tensor]:
    """Tokenize SFT records into one fixed diagnostic batch.

    Each record is parsed with ``parse_sft_record``; records that yield no
    turns are dropped. The remaining ones are tokenized with
    ``build_sft_sequence_tokens`` and given response-only next-token labels.

    Args:
        config: Training config; ``config.model.max_seq_length`` is required,
            ``user_token`` / ``assistant_token`` fall back to ``"<user>"`` and
            ``"<assistant>"`` when absent.
        tokenizer: Tokenizer passed through to ``build_sft_sequence_tokens``.
        records: Raw SFT records.
        device: Device for the model input tensors.

    Returns:
        A dict with ``input_ids``, ``attention_mask`` and ``labels`` on
        ``device``, plus ``train_tokens`` (count of labels not equal to -100
        per kept record), which is created without a device argument.

    Raises:
        ValueError: If no record produced any turns.
    """
    # NOTE(review): torch.tensor needs equal-length rows, so this presumably
    # relies on build_sft_sequence_tokens padding to max_seq_length — confirm.
    id_rows = []
    mask_rows = []
    target_rows = []
    supervised_counts = []

    for raw_record in records:
        turns, _unused = parse_sft_record(raw_record, config)
        if turns:
            token_ids, attn_mask, response_mask = build_sft_sequence_tokens(
                turns=turns,
                tokenizer=tokenizer,
                user_token=getattr(config, "user_token", "<user>"),
                assistant_token=getattr(config, "assistant_token", "<assistant>"),
                max_seq_length=config.model.max_seq_length,
            )
            targets = build_response_only_next_token_labels(token_ids, response_mask)

            id_rows.append(token_ids)
            mask_rows.append(attn_mask)
            target_rows.append(targets)
            supervised_counts.append(len([label for label in targets if label != -100]))

    if len(id_rows) == 0:
        raise ValueError("No valid SFT records found for the diagnostic batch")

    batch: dict[str, torch.Tensor] = {}
    batch["input_ids"] = torch.tensor(id_rows, dtype=torch.long, device=device)
    batch["attention_mask"] = torch.tensor(mask_rows, dtype=torch.long, device=device)
    batch["labels"] = torch.tensor(target_rows, dtype=torch.long, device=device)
    batch["train_tokens"] = torch.tensor(supervised_counts, dtype=torch.long)
    return batch
@torch.no_grad()
def score_batch(model, batch: dict[str, torch.Tensor], dtype: torch.dtype) -> float:
model.eval()
device_type = "cuda" if batch["input_ids"].is_cuda else "cpu"
enabled = device_type == "cuda" and dtype in (torch.float16, torch.bfloat16)
with torch.autocast(device_type=device_type, dtype=dtype, enabled=enabled):
outputs = model(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"],
)
return float(outputs["loss"].detach().cpu())
def grad_l2_norm(parameters) -> float:
    """Return the global L2 norm of the gradients of ``parameters``.

    Parameters whose ``grad`` is ``None`` are ignored. Squares are computed
    in float32 and accumulated in a Python float.

    Args:
        parameters: Iterable of objects exposing a ``grad`` tensor or ``None``.

    Returns:
        The square root of the summed squared gradient entries (0.0 when no
        parameter has a gradient).
    """
    sum_of_squares = 0.0
    for param in parameters:
        if parameter_has_grad := param.grad is not None:
            grad_fp32 = param.grad.detach().float()
            squared = grad_fp32 * grad_fp32
            sum_of_squares += float(torch.sum(squared).cpu())
    return math.sqrt(sum_of_squares)
def grad_summary(named_parameters, max_items: int = 12) -> dict[str, Any]:
    """Summarize gradient magnitudes and non-finite entries by parameter group.

    Parameters without a gradient are skipped. A name containing
    ``".layers."`` is grouped as ``"layer_<component after 'layers'>"``;
    any other name is grouped by its first dotted component.

    Args:
        named_parameters: Iterable of ``(name, parameter)`` pairs.
        max_items: Maximum number of entries in ``worst_tensors`` and
            ``nonfinite_tensors``.

    Returns:
        A dict with:
        ``groups`` — per group ``numel``, ``finite`` (count of finite
        entries), ``nonfinite_tensors`` (count of tensors with any
        non-finite entry) and ``max_abs_grad``;
        ``worst_tensors`` — tensors with the largest ``max_abs_grad``;
        ``nonfinite_tensors`` — names of tensors with non-finite entries;
        ``nonfinite_tensor_count`` — total number of such tensors.

        ``max_abs_grad`` is the largest absolute finite entry, ``inf`` for a
        tensor with no finite entry at all, and 0.0 for a zero-element tensor.
    """
    groups: dict[str, dict[str, Any]] = {}
    worst = []
    nonfinite = []

    for name, parameter in named_parameters:
        if parameter.grad is None:
            continue

        grad = parameter.grad.detach().float()
        finite = torch.isfinite(grad)
        finite_count = int(finite.sum().cpu())
        numel = grad.numel()

        if finite_count:
            finite_abs_max = float(grad[finite].abs().max().cpu())
        elif numel == 0:
            # An empty gradient has nothing non-finite in it; reporting inf
            # would wrongly rank it as the worst tensor.
            finite_abs_max = 0.0
        else:
            # Every entry is NaN/inf: flag the tensor with an infinite magnitude.
            finite_abs_max = float("inf")

        has_nonfinite = finite_count != numel
        if has_nonfinite:
            nonfinite.append(name)

        if ".layers." in name:
            # ".layers." in the name guarantees a "layers" component followed
            # by at least one more component, so this lookup cannot fail.
            parts = name.split(".")
            group = "layer_" + parts[parts.index("layers") + 1]
        else:
            group = name.split(".", 1)[0]

        entry = groups.setdefault(group, {
            "numel": 0,
            "finite": 0,
            "nonfinite_tensors": 0,
            "max_abs_grad": 0.0,
        })
        entry["numel"] += numel
        entry["finite"] += finite_count
        entry["nonfinite_tensors"] += int(has_nonfinite)
        entry["max_abs_grad"] = max(entry["max_abs_grad"], finite_abs_max)
        worst.append((finite_abs_max, name))

    # Stable sort: tensors with equal magnitude keep their input order.
    worst.sort(reverse=True, key=lambda item: item[0])
    return {
        "groups": groups,
        "worst_tensors": [{"name": name, "max_abs_grad": value} for value, name in worst[:max_items]],
        "nonfinite_tensors": nonfinite[:max_items],
        "nonfinite_tensor_count": len(nonfinite),
    }
def freeze_ssm_core_parameters(model) -> int:
    """Disable gradients for parameters that belong to the SSM core.

    A parameter is selected when its qualified name contains ``".ssm_lanes."``
    or ``".ssm."`` (leading dot included, so a top-level ``ssm.*`` name does
    not match).

    Args:
        model: Object exposing ``named_parameters()``.

    Returns:
        Total number of elements across the parameters that were frozen.
    """
    selected = [
        param
        for param_name, param in model.named_parameters()
        if ".ssm_lanes." in param_name or ".ssm." in param_name
    ]

    frozen_elements = 0
    for param in selected:
        param.requires_grad_(False)
        frozen_elements += param.numel()
    return frozen_elements
def main() -> None:
    """Run the SFT sanity check from the command line and write a JSON report.

    Steps, in order: parse arguments, seed RNGs, load the SFT config (with
    optional CLI overrides of model flags), build one fixed batch from the
    first ``--samples`` records, create the model and load the checkpoint,
    optionally freeze the SSM core, score the batch, run ``--steps`` AdamW
    updates on that same batch (gradient norm clipped to 1.0 unless
    ``--no-clip``), score again, and write the result to ``--output`` as well
    as stdout.

    The report records the initial/final loss, their difference, a logged
    history (step 1, every ``--log-every`` steps, and the last step) with
    gradient summaries, and the state-dict keys that did not match when the
    checkpoint was loaded non-strictly.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--output", required=True)
    parser.add_argument("--samples", type=int, default=2)
    parser.add_argument("--steps", type=int, default=80)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--log-every", type=int, default=10)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--dtype", choices=["config", "float32", "float16", "bfloat16"], default="config")
    parser.add_argument("--no-clip", action="store_true")
    parser.add_argument("--freeze-ssm-core", action="store_true")
    parser.add_argument("--ssm-branch-rms-norm", action="store_true")
    parser.add_argument("--ssm-branch-clip-value", type=float, default=None)
    parser.add_argument("--block-residual-rms-norm", action="store_true")
    parser.add_argument("--block-residual-rms-target", type=float, default=None)
    parser.add_argument("--seed", type=int, default=123)
    args = parser.parse_args()

    # Seed before the model is created so the run is reproducible.
    set_seed(args.seed)

    config = load_config(args.config, TrainingModeEnum.SFT)
    # CLI flags only ever switch these model options on / override them;
    # when a flag is absent the config value is left untouched.
    if args.ssm_branch_rms_norm:
        config.model.ssm_branch_rms_norm = True
    if args.ssm_branch_clip_value is not None:
        config.model.ssm_branch_clip_value = args.ssm_branch_clip_value
    if args.block_residual_rms_norm:
        config.model.block_residual_rms_norm = True
    if args.block_residual_rms_target is not None:
        config.model.block_residual_rms_target = args.block_residual_rms_target

    # Fall back to CPU when a non-CPU device is requested but CUDA is unavailable.
    device = torch.device(args.device if args.device == "cpu" or torch.cuda.is_available() else "cpu")

    explicit_dtypes = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    if args.dtype in explicit_dtypes:
        dtype = explicit_dtypes[args.dtype]
    else:
        # NOTE(review): only a bfloat16 config maps to bfloat16; every other
        # config dtype (including float16) resolves to float32 — confirm intent.
        config_is_bf16 = str(config.dtype) in ("DataTypeEnum.BFLOAT16", "bfloat16")
        dtype = torch.bfloat16 if config_is_bf16 else torch.float32

    tokenizer = load_tokenizer(config.dataset.tokenizer_path)
    records = read_jsonl_records(config.dataset.jsonl_path, args.samples)
    batch = build_batch(config, tokenizer, records, device)

    model = create_model(config, device)
    checkpoint = CheckpointManager(config.checkpoint_dir).load(args.checkpoint, device=device)
    # Non-strict loading tolerates key mismatches; keep the mismatch report so
    # a partially loaded checkpoint is visible in the diagnostic output.
    load_report = model.load_state_dict(checkpoint["model_state"], strict=False)
    missing_keys = list(getattr(load_report, "missing_keys", None) or [])
    unexpected_keys = list(getattr(load_report, "unexpected_keys", None) or [])

    frozen_params = freeze_ssm_core_parameters(model) if args.freeze_ssm_core else 0
    initial_loss = score_batch(model, batch, dtype)

    trainable_params = [parameter for parameter in model.parameters() if parameter.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, weight_decay=0.0)

    history = []
    device_type = "cuda" if device.type == "cuda" else "cpu"
    autocast_enabled = device_type == "cuda" and dtype in (torch.float16, torch.bfloat16)

    # score_batch left the model in eval mode; switch back for the updates.
    model.train()
    for step in range(1, args.steps + 1):
        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device_type=device_type, dtype=dtype, enabled=autocast_enabled):
            outputs = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
            )
        loss = outputs["loss"]
        loss.backward()

        # Measured before clipping, so the logged norm is the raw one.
        grad_norm = grad_l2_norm(trainable_params)

        # Guard the modulo so --log-every 0 means "first and last step only"
        # instead of raising ZeroDivisionError.
        periodic = args.log_every != 0 and step % args.log_every == 0
        should_log = step == 1 or periodic or step == args.steps

        # The summary must be taken before clipping and before optimizer.step().
        stats = grad_summary(model.named_parameters()) if should_log else None

        if not args.no_clip:
            torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
        optimizer.step()

        if should_log:
            history.append({
                "step": step,
                "loss": float(loss.detach().cpu()),
                "grad_l2_norm": grad_norm,
                "grad_summary": stats,
            })

    final_loss = score_batch(model, batch, dtype)

    result = {
        "checkpoint": str(Path(args.checkpoint)),
        "config": str(Path(args.config)),
        "dataset": config.dataset.jsonl_path,
        "samples": len(records),
        "sequence_length": config.model.max_seq_length,
        "train_tokens_per_sample": batch["train_tokens"].tolist(),
        "lr": args.lr,
        "steps": args.steps,
        "clip_grad_norm": not args.no_clip,
        "freeze_ssm_core": args.freeze_ssm_core,
        "ssm_branch_rms_norm": config.model.ssm_branch_rms_norm,
        "ssm_branch_clip_value": config.model.ssm_branch_clip_value,
        "block_residual_rms_norm": config.model.block_residual_rms_norm,
        "block_residual_rms_target": config.model.block_residual_rms_target,
        "frozen_params": frozen_params,
        "trainable_params": sum(parameter.numel() for parameter in trainable_params),
        "initial_loss": initial_loss,
        "final_loss": final_loss,
        "loss_delta": final_loss - initial_loss,
        "history": history,
        "device": str(device),
        "dtype": str(dtype),
        "missing_keys": missing_keys,
        "unexpected_keys": unexpected_keys,
    }

    # Serialize once and reuse the text for both the file and stdout.
    report = json.dumps(result, indent=2)
    output = Path(args.output)
    output.parent.mkdir(parents=True, exist_ok=True)
    output.write_text(report, encoding="utf-8")
    print(report)
# Script entry point: run the diagnostic only when this file is executed directly.
if __name__ == "__main__":
    main()