"""Probe residual activation scale for a saved TaoTrain checkpoint.""" from __future__ import annotations import argparse import json import re import sys from pathlib import Path from typing import Any import torch REPO_ROOT = Path(__file__).resolve().parents[2] SRC_ROOT = REPO_ROOT / "src" if str(SRC_ROOT) not in sys.path: sys.path.insert(0, str(SRC_ROOT)) from taoTrain.checkpointing.checkpoint import CheckpointManager from taoTrain.config import ModelConfig from taoTrain.models import get_model def load_sentencepiece(path: Path): import sentencepiece as spm processor = spm.SentencePieceProcessor() processor.load(str(path)) return processor def load_tokens(args: argparse.Namespace) -> tuple[torch.Tensor, int]: tokenizer = load_sentencepiece(Path(args.tokenizer_path)) tokens: list[int] = [] with Path(args.data_path).open("r", encoding="utf-8", errors="replace") as handle: for line in handle: if len(tokens) >= args.max_tokens: break line = line.strip() if not line: continue try: record = json.loads(line) except json.JSONDecodeError: continue text = record.get(args.text_field) if not isinstance(text, str) or not text: continue ids = list(tokenizer.encode(text, out_type=int)) eos_id = tokenizer.eos_id() if eos_id >= 0: ids.append(eos_id) tokens.extend(ids) if len(tokens) < args.seq_len + 2: raise ValueError(f"Need at least {args.seq_len + 2} tokens, got {len(tokens)}") return torch.tensor(tokens[: args.max_tokens], dtype=torch.long), int(tokenizer.vocab_size()) def sample_batch(tokens: torch.Tensor, *, batch_size: int, seq_len: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: max_start = tokens.numel() - seq_len - 1 starts = torch.linspace(0, max_start - 1, steps=batch_size).long() rows = [tokens[int(start) : int(start) + seq_len + 1] for start in starts] batch = torch.stack(rows, dim=0).to(device=device) return batch[:, :-1].contiguous(), batch[:, 1:].contiguous() def tensor_stats(value: torch.Tensor) -> dict[str, float | int]: data = value.detach().float() finite = torch.isfinite(data) finite_count = int(finite.sum().cpu()) numel = data.numel() if finite_count: finite_data = data[finite] rms = float(torch.sqrt(torch.mean(finite_data * finite_data)).cpu()) max_abs = float(finite_data.abs().max().cpu()) else: rms = float("inf") max_abs = float("inf") return { "numel": numel, "finite": finite_count, "rms": rms, "max_abs": max_abs, } def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--checkpoint", required=True) parser.add_argument("--tokenizer-path", required=True) parser.add_argument("--data-path", required=True) parser.add_argument("--text-field", default="text") parser.add_argument("--output", required=True) parser.add_argument("--batch-size", type=int, default=2) parser.add_argument("--seq-len", type=int, default=512) parser.add_argument("--max-tokens", type=int, default=200_000) parser.add_argument("--device", default="cuda") parser.add_argument("--dtype", choices=["float32", "bfloat16", "float16"], default="bfloat16") args = parser.parse_args() device = torch.device(args.device if args.device == "cpu" or torch.cuda.is_available() else "cpu") dtype = { "float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16, }[args.dtype] tokens, _ = load_tokens(args) input_ids, labels = sample_batch(tokens, batch_size=args.batch_size, seq_len=args.seq_len, device=device) attention_mask = torch.ones_like(input_ids) checkpoint_path = Path(args.checkpoint) checkpoint = CheckpointManager(checkpoint_path.parent).load(checkpoint_path, device=device) config_dict = checkpoint.get("config", {}) model_config = ModelConfig(**config_dict.get("model", {})) model = get_model(model_config, device=device) model.load_state_dict(checkpoint["model_state"], strict=False) model.eval() layer_stats: dict[str, dict[str, float | int]] = {} handles = [] layer_pattern = re.compile(r"^(?:model\.)?(?:layers|blocks)\.\d+$") def make_hook(name: str): def hook(_module, _inputs, output): value = output[0] if isinstance(output, tuple) else output if torch.is_tensor(value): layer_stats[name] = tensor_stats(value) return hook for name, module in model.named_modules(): if layer_pattern.match(name): handles.append(module.register_forward_hook(make_hook(name))) device_type = "cuda" if device.type == "cuda" else "cpu" autocast_enabled = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16} with torch.no_grad(), torch.autocast(device_type=device_type, dtype=dtype, enabled=autocast_enabled): outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) for handle in handles: handle.remove() result: dict[str, Any] = { "checkpoint": str(checkpoint_path), "loss": float(outputs["loss"].detach().cpu()), "batch_size": args.batch_size, "seq_len": args.seq_len, "device": str(device), "dtype": str(dtype), "layers": layer_stats, } output_path = Path(args.output) output_path.parent.mkdir(parents=True, exist_ok=True) output_path.write_text(json.dumps(result, indent=2), encoding="utf-8") print(json.dumps(result, indent=2)) if __name__ == "__main__": main()