File size: 6,013 Bytes
e2bfccc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""Probe residual activation scale for a saved TaoTrain checkpoint."""

from __future__ import annotations

import argparse
import json
import re
import sys
from pathlib import Path
from typing import Any

import torch

# ``parents[2]`` is the directory two levels above the one containing this
# script; its ``src`` subdirectory is pushed to the front of sys.path so the
# ``taoTrain`` imports below resolve to the in-repo sources (presumably so the
# script runs without installing the package — confirm against repo setup).
REPO_ROOT = Path(__file__).resolve().parents[2]
SRC_ROOT = REPO_ROOT / "src"
# Guard against inserting the same entry twice if the module is re-imported.
if str(SRC_ROOT) not in sys.path:
    sys.path.insert(0, str(SRC_ROOT))

from taoTrain.checkpointing.checkpoint import CheckpointManager
from taoTrain.config import ModelConfig
from taoTrain.models import get_model


def load_sentencepiece(path: Path):
    """Build a SentencePiece processor and load the model stored at *path*.

    ``sentencepiece`` is imported inside the function, so importing this
    module does not require the package to be installed.
    """
    from sentencepiece import SentencePieceProcessor

    model_file = str(path)
    sp_model = SentencePieceProcessor()
    sp_model.load(model_file)
    return sp_model


def load_tokens(args: argparse.Namespace) -> tuple[torch.Tensor, int]:
    """Tokenize a JSONL corpus into one flat stream of token ids.

    Each non-empty line of ``args.data_path`` is parsed as JSON. The string
    stored under ``args.text_field`` is encoded with the SentencePiece model at
    ``args.tokenizer_path``. When the model defines an EOS id, that id is
    appended after every document.

    Lines are skipped when they are blank, are not valid JSON, are not JSON
    objects, or lack a non-empty string in the text field. Reading stops once
    at least ``args.max_tokens`` tokens have been collected.

    Args:
        args: parsed CLI namespace; reads ``tokenizer_path``, ``data_path``,
            ``text_field``, ``max_tokens`` and ``seq_len``.

    Returns:
        A 1-D ``torch.long`` tensor of at most ``args.max_tokens`` token ids,
        and the tokenizer's vocabulary size.

    Raises:
        ValueError: if fewer than ``args.seq_len + 2`` tokens remain after
            truncating to ``args.max_tokens``.
    """
    tokenizer = load_sentencepiece(Path(args.tokenizer_path))
    # The EOS id is a property of the model, so query it once rather than per
    # line. A negative value means the model has no EOS piece.
    eos_id = tokenizer.eos_id()
    tokens: list[int] = []
    with Path(args.data_path).open("r", encoding="utf-8", errors="replace") as handle:
        for line in handle:
            if len(tokens) >= args.max_tokens:
                break
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue
            # Valid JSON that is not an object (e.g. a bare list or number)
            # has no fields; skip it instead of crashing on ``.get``.
            if not isinstance(record, dict):
                continue
            text = record.get(args.text_field)
            if not isinstance(text, str) or not text:
                continue
            tokens.extend(tokenizer.encode(text, out_type=int))
            if eos_id >= 0:
                tokens.append(eos_id)
    # Truncate *before* validating, so the length check applies to what is
    # actually returned. The last document read can overshoot max_tokens.
    tokens = tokens[: args.max_tokens]
    min_tokens = args.seq_len + 2
    if len(tokens) < min_tokens:
        raise ValueError(f"Need at least {min_tokens} tokens, got {len(tokens)}")
    return torch.tensor(tokens, dtype=torch.long), int(tokenizer.vocab_size())


def sample_batch(tokens: torch.Tensor, *, batch_size: int, seq_len: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
    """Cut ``batch_size`` evenly spaced next-token windows out of a token stream.

    Each window holds ``seq_len + 1`` consecutive tokens. The inputs are its
    first ``seq_len`` tokens and the labels are the same window shifted by one,
    so ``labels[:, t]`` is the token that follows ``inputs[:, t]``. Window
    starts are spread deterministically with ``torch.linspace``, so repeated
    runs probe the same data.

    Args:
        tokens: flat tensor of token ids (assumed 1-D, as built by
            ``load_tokens``).
        batch_size: number of windows; must be at least 1.
        seq_len: number of input tokens per row.
        device: device the returned tensors are moved to.

    Returns:
        ``(inputs, labels)``, each of shape ``(batch_size, seq_len)``.

    Raises:
        ValueError: if ``batch_size < 1`` or if ``tokens`` holds fewer than
            ``seq_len + 1`` elements.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    window = seq_len + 1
    # Largest start index whose window still fits entirely inside ``tokens``.
    max_start = tokens.numel() - window
    if max_start < 0:
        raise ValueError(f"Need at least {window} tokens for seq_len={seq_len}, got {tokens.numel()}")
    # Starts have always been spread over [0, max_start - 1]; keep that for
    # reproducibility of earlier probes. Clamp at 0 so a stream of exactly
    # seq_len + 1 tokens gives start 0 instead of a negative start, which
    # would slice an empty row and make torch.stack fail.
    last_start = max(max_start - 1, 0)
    starts = torch.linspace(0, last_start, steps=batch_size).long()
    rows = [tokens[int(start) : int(start) + window] for start in starts]
    batch = torch.stack(rows, dim=0).to(device=device)
    return batch[:, :-1].contiguous(), batch[:, 1:].contiguous()


def tensor_stats(value: torch.Tensor) -> dict[str, float | int]:
    data = value.detach().float()
    finite = torch.isfinite(data)
    finite_count = int(finite.sum().cpu())
    numel = data.numel()
    if finite_count:
        finite_data = data[finite]
        rms = float(torch.sqrt(torch.mean(finite_data * finite_data)).cpu())
        max_abs = float(finite_data.abs().max().cpu())
    else:
        rms = float("inf")
        max_abs = float("inf")
    return {
        "numel": numel,
        "finite": finite_count,
        "rms": rms,
        "max_abs": max_abs,
    }


def main() -> None:
    """Probe per-layer output scale of a saved checkpoint with one forward pass.

    The probe runs in these steps:

    1. Parse CLI arguments.
    2. Build one deterministic batch from the JSONL corpus.
    3. Restore the model described by the checkpoint's ``config["model"]``.
    4. Register forward hooks on every module whose qualified name matches
       ``layers.<i>`` or ``blocks.<i>`` (optionally prefixed with ``model.``).
    5. Run a single no-grad forward pass.

    The loss, run settings, any state-dict key mismatches, and the
    ``tensor_stats`` of each hooked module's output are written as JSON to
    ``--output`` and echoed to stdout. Warnings go to stderr, so stdout stays
    parseable JSON.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--tokenizer-path", required=True)
    parser.add_argument("--data-path", required=True)
    parser.add_argument("--text-field", default="text")
    parser.add_argument("--output", required=True)
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--seq-len", type=int, default=512)
    parser.add_argument("--max-tokens", type=int, default=200_000)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--dtype", choices=["float32", "bfloat16", "float16"], default="bfloat16")
    args = parser.parse_args()

    # When CUDA is unavailable, any requested device other than "cpu" falls
    # back to CPU. The device actually used is recorded in the result.
    device = torch.device(args.device if args.device == "cpu" or torch.cuda.is_available() else "cpu")
    dtype = {
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
    }[args.dtype]

    tokens, _ = load_tokens(args)
    input_ids, labels = sample_batch(tokens, batch_size=args.batch_size, seq_len=args.seq_len, device=device)
    attention_mask = torch.ones_like(input_ids)

    checkpoint_path = Path(args.checkpoint)
    checkpoint = CheckpointManager(checkpoint_path.parent).load(checkpoint_path, device=device)
    config_dict = checkpoint.get("config", {})
    model_config = ModelConfig(**config_dict.get("model", {}))
    model = get_model(model_config, device=device)
    # strict=False tolerates key drift between checkpoint and model code. A
    # silent mismatch, however, would make the probe measure freshly
    # initialised weights, so report it on stderr and in the JSON result.
    incompatible = model.load_state_dict(checkpoint["model_state"], strict=False)
    missing_keys = list(getattr(incompatible, "missing_keys", []))
    unexpected_keys = list(getattr(incompatible, "unexpected_keys", []))
    if missing_keys or unexpected_keys:
        print(
            f"warning: state_dict mismatch: {len(missing_keys)} missing, "
            f"{len(unexpected_keys)} unexpected keys",
            file=sys.stderr,
        )
    model.eval()

    layer_stats: dict[str, dict[str, float | int]] = {}
    handles = []
    layer_pattern = re.compile(r"^(?:model\.)?(?:layers|blocks)\.\d+$")

    def make_hook(name: str):
        def hook(_module, _inputs, output):
            # When a hooked module returns a tuple, only its first element is measured.
            value = output[0] if isinstance(output, tuple) else output
            if torch.is_tensor(value):
                layer_stats[name] = tensor_stats(value)

        return hook

    for name, module in model.named_modules():
        if layer_pattern.match(name):
            handles.append(module.register_forward_hook(make_hook(name)))
    if not handles:
        print(f"warning: no modules matched {layer_pattern.pattern!r}; no layer stats will be recorded", file=sys.stderr)

    device_type = "cuda" if device.type == "cuda" else "cpu"
    # Autocast is only enabled on CUDA with a reduced-precision dtype. In every
    # other case --dtype is not applied; the "autocast" field records this.
    autocast_enabled = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}
    try:
        with torch.no_grad(), torch.autocast(device_type=device_type, dtype=dtype, enabled=autocast_enabled):
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    finally:
        # Detach the hooks even if the forward pass raises.
        for handle in handles:
            handle.remove()

    result: dict[str, Any] = {
        "checkpoint": str(checkpoint_path),
        "loss": float(outputs["loss"].detach().cpu()),
        "batch_size": args.batch_size,
        "seq_len": args.seq_len,
        "device": str(device),
        "dtype": str(dtype),
        "autocast": autocast_enabled,
        "missing_keys": missing_keys,
        "unexpected_keys": unexpected_keys,
        "layers": layer_stats,
    }
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
    print(json.dumps(result, indent=2))


if __name__ == "__main__":
    # Run the probe only when executed as a script, not when imported.
    main()