# TaoNet-mini-T2 / code / TaoTrain / scripts / diagnostics / activation_probe.py
# StarMist0012's picture
# Add files using upload-large-folder tool
# e2bfccc verified
"""Probe residual activation scale for a saved TaoTrain checkpoint."""
from __future__ import annotations
import argparse
import json
import re
import sys
from pathlib import Path
from typing import Any
import torch
# Resolve the directory two levels above this script's own directory and put
# its ``src`` folder at the front of ``sys.path`` (once), ahead of the
# ``taoTrain`` imports that follow.
REPO_ROOT = Path(__file__).resolve().parents[2]
SRC_ROOT = REPO_ROOT.joinpath("src")
if sys.path.count(str(SRC_ROOT)) == 0:
    sys.path[:0] = [str(SRC_ROOT)]
from taoTrain.checkpointing.checkpoint import CheckpointManager
from taoTrain.config import ModelConfig
from taoTrain.models import get_model
def load_sentencepiece(path: Path):
    """Load the SentencePiece model stored at ``path`` and return its processor.

    ``sentencepiece`` is imported inside the function, so the package is only
    required once a tokenizer is actually loaded.
    """
    from sentencepiece import SentencePieceProcessor

    sp_model = SentencePieceProcessor()
    sp_model.load(str(path))
    return sp_model
def load_tokens(args: argparse.Namespace) -> tuple[torch.Tensor, int]:
    """Tokenize JSON-lines text into one flat tensor of token ids.

    Reads ``args.data_path`` line by line, parses each non-empty line as JSON,
    and encodes the string found under ``args.text_field`` with the
    SentencePiece model at ``args.tokenizer_path``. The tokenizer's EOS id is
    appended after each record when the tokenizer defines one (id >= 0).
    Reading stops once at least ``args.max_tokens`` ids have been collected.

    Lines that are blank, are not valid JSON, are not JSON objects, or whose
    text field is missing, empty, or not a string are skipped.

    Args:
        args: Parsed CLI namespace; uses ``tokenizer_path``, ``data_path``,
            ``text_field``, ``max_tokens`` and ``seq_len``.

    Returns:
        A ``(tokens, vocab_size)`` pair: a 1-D ``torch.long`` tensor holding
        at most ``args.max_tokens`` ids, and the tokenizer vocabulary size.

    Raises:
        ValueError: If fewer than ``args.seq_len + 2`` tokens remain after
            truncating to ``args.max_tokens``.
    """
    tokenizer = load_sentencepiece(Path(args.tokenizer_path))
    # The EOS id does not change between records, so query it once.
    eos_id = tokenizer.eos_id()
    tokens: list[int] = []
    with Path(args.data_path).open("r", encoding="utf-8", errors="replace") as handle:
        for line in handle:
            if len(tokens) >= args.max_tokens:
                break
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue
            # A valid JSON line may still be a list/number/string; skip it the
            # same way malformed lines are skipped instead of failing on .get.
            if not isinstance(record, dict):
                continue
            text = record.get(args.text_field)
            if not isinstance(text, str) or not text:
                continue
            tokens.extend(tokenizer.encode(text, out_type=int))
            if eos_id >= 0:
                tokens.append(eos_id)
    # Truncate BEFORE validating: the last record can overshoot max_tokens, so
    # checking the untruncated length could let a too-short tensor through.
    tokens = tokens[: args.max_tokens]
    if len(tokens) < args.seq_len + 2:
        raise ValueError(f"Need at least {args.seq_len + 2} tokens, got {len(tokens)}")
    return torch.tensor(tokens, dtype=torch.long), int(tokenizer.vocab_size())
def sample_batch(tokens: torch.Tensor, *, batch_size: int, seq_len: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
    """Cut ``batch_size`` evenly spaced windows out of ``tokens``.

    Each window holds ``seq_len + 1`` consecutive ids; the first ``seq_len``
    become the inputs and the last ``seq_len`` (shifted by one position) the
    labels. Window starts are spread with ``torch.linspace`` from 0 up to the
    second-to-last valid start and truncated to integers, so the selection is
    deterministic.

    Args:
        tokens: Token ids, sliced along the first dimension (assumed 1-D).
        batch_size: Number of windows to draw.
        seq_len: Length of each returned input/label row.
        device: Device the stacked batch is moved to.

    Returns:
        ``(input_ids, labels)``, two contiguous tensors of shape
        ``(batch_size, seq_len)``.

    Raises:
        ValueError: If ``tokens`` holds fewer than ``seq_len + 1`` ids.
    """
    window = seq_len + 1
    max_start = tokens.numel() - window
    if max_start < 0:
        raise ValueError(f"Need at least {window} tokens to sample a window, got {tokens.numel()}")
    # With exactly one valid start (max_start == 0) the original upper bound
    # of max_start - 1 went negative and produced ragged slices; clamp it so
    # every row falls back to start 0. Larger inputs are unaffected.
    last_start = max(max_start - 1, 0)
    starts = torch.linspace(0, last_start, steps=batch_size).long()
    rows = [tokens[int(start) : int(start) + window] for start in starts]
    batch = torch.stack(rows, dim=0).to(device=device)
    return batch[:, :-1].contiguous(), batch[:, 1:].contiguous()
def tensor_stats(value: torch.Tensor) -> dict[str, float | int]:
    """Summarise the magnitude of ``value`` over its finite entries.

    The tensor is detached and cast to float32. NaN and +/-inf entries are
    excluded from the statistics but still counted in ``numel``.

    Returns:
        A dict with ``numel`` (total element count), ``finite`` (number of
        finite elements), ``rms`` (root mean square of the finite elements)
        and ``max_abs`` (largest absolute finite value). When no element is
        finite, ``rms`` and ``max_abs`` are both ``inf``.
    """
    as_float = value.detach().float()
    finite_mask = torch.isfinite(as_float)
    n_finite = int(finite_mask.sum().cpu())

    if n_finite == 0:
        # Nothing usable to measure: report the scale as infinite.
        rms = max_abs = float("inf")
    else:
        kept = as_float[finite_mask]
        mean_square = kept.mul(kept).mean()
        rms = float(mean_square.sqrt().cpu())
        max_abs = float(kept.abs().max().cpu())

    stats: dict[str, float | int] = {}
    stats["numel"] = as_float.numel()
    stats["finite"] = n_finite
    stats["rms"] = rms
    stats["max_abs"] = max_abs
    return stats
def main() -> None:
    """Run the activation probe from the command line.

    Steps:
      1. Parse the CLI arguments.
      2. Build one ``(input_ids, labels)`` batch from the JSON-lines data.
      3. Rebuild the model from the checkpoint's ``config["model"]`` entry and
         load its ``model_state`` weights.
      4. Attach a forward hook to every module whose name matches
         ``[model.]layers.<n>`` or ``[model.]blocks.<n>`` and run a single
         no-grad forward pass (under autocast on CUDA for 16-bit dtypes).
      5. Write the loss and the per-layer output statistics as JSON to
         ``--output`` and echo the same JSON to stdout.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--tokenizer-path", required=True)
    parser.add_argument("--data-path", required=True)
    parser.add_argument("--text-field", default="text")
    parser.add_argument("--output", required=True)
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--seq-len", type=int, default=512)
    parser.add_argument("--max-tokens", type=int, default=200_000)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--dtype", choices=["float32", "bfloat16", "float16"], default="bfloat16")
    args = parser.parse_args()

    # Any request other than "cpu" falls back to CPU when CUDA is unavailable.
    device = torch.device(args.device if args.device == "cpu" or torch.cuda.is_available() else "cpu")
    dtype = {
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
    }[args.dtype]

    # The tokenizer vocabulary size returned alongside the tokens is unused here.
    tokens, _ = load_tokens(args)
    input_ids, labels = sample_batch(tokens, batch_size=args.batch_size, seq_len=args.seq_len, device=device)
    attention_mask = torch.ones_like(input_ids)

    checkpoint_path = Path(args.checkpoint)
    checkpoint = CheckpointManager(checkpoint_path.parent).load(checkpoint_path, device=device)
    config_dict = checkpoint.get("config", {})
    model_config = ModelConfig(**config_dict.get("model", {}))
    model = get_model(model_config, device=device)
    # NOTE(review): assuming this is torch.nn.Module.load_state_dict,
    # strict=False accepts missing/unexpected keys without raising, so a
    # partially matching checkpoint would be probed silently - confirm intent.
    model.load_state_dict(checkpoint["model_state"], strict=False)
    model.eval()

    layer_stats: dict[str, dict[str, float | int]] = {}
    handles = []
    layer_pattern = re.compile(r"^(?:model\.)?(?:layers|blocks)\.\d+$")

    def make_hook(name: str):
        # Bind ``name`` per module so each hook records under its own key.
        def hook(_module, _inputs, output):
            # Modules returning a tuple are summarised by their first element.
            value = output[0] if isinstance(output, tuple) else output
            if torch.is_tensor(value):
                layer_stats[name] = tensor_stats(value)

        return hook

    device_type = "cuda" if device.type == "cuda" else "cpu"
    autocast_enabled = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}
    try:
        for name, module in model.named_modules():
            if layer_pattern.match(name):
                handles.append(module.register_forward_hook(make_hook(name)))
        with torch.no_grad(), torch.autocast(device_type=device_type, dtype=dtype, enabled=autocast_enabled):
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    finally:
        # Detach the hooks even when registration or the forward pass raises.
        for handle in handles:
            handle.remove()

    result: dict[str, Any] = {
        "checkpoint": str(checkpoint_path),
        "loss": float(outputs["loss"].detach().cpu()),
        "batch_size": args.batch_size,
        "seq_len": args.seq_len,
        "device": str(device),
        "dtype": str(dtype),
        "layers": layer_stats,
    }

    # Serialise once and reuse the text for both the file and stdout.
    payload = json.dumps(result, indent=2)
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(payload, encoding="utf-8")
    print(payload)
# Run the probe only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()