# Source: frankenstallm/source/eval/tasks/token_nll_task.py
# Uploaded by pathcosmos — "Upload folder using huggingface_hub (#29)", commit 5b1ff4d
"""
token_nll_task.py — Token-level NLL distribution analysis.
Top-level function for ProcessPoolExecutor (spawn) compatibility:
- eval_token_nll(device, n_tokens=50000) -> dict
"""
from __future__ import annotations
import sys
import time
from pathlib import Path
import os
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
# The project root sits three directory levels above this file. It is placed
# at the front of sys.path (once) — presumably so that project-local packages
# such as `model.transformer` can be imported from worker processes.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
if sys.path.count(str(_PROJECT_ROOT)) == 0:
    sys.path.insert(0, str(_PROJECT_ROOT))

# Checkpoint and tokenizer locations; each can be overridden through an
# environment variable (EVAL_CHECKPOINT / EVAL_TOKENIZER).
_DEFAULT_CHECKPOINT = str(
    _PROJECT_ROOT / "checkpoints" / "korean_3b_fp8_run1" / "checkpoint-0057000"
)
CHECKPOINT = os.getenv("EVAL_CHECKPOINT", _DEFAULT_CHECKPOINT)
TOKENIZER_PATH = os.getenv(
    "EVAL_TOKENIZER",
    str(_PROJECT_ROOT / "tokenizer" / "korean_sp" / "tokenizer.json"),
)

# Directory holding the tokenized evaluation data.
DATA_DIR = _PROJECT_ROOT / "data"

# Evaluation windowing: window length, step between window starts, and the
# number of windows per forward pass.
SEQ_LEN = 2048
STRIDE = 512
BATCH_SIZE = 32
# ---------------------------------------------------------------------------
# Shared dataset / model utilities
# ---------------------------------------------------------------------------
class SlidingWindowDataset(Dataset):
    """Sliding-window view over a 1-D token stream for next-token evaluation.

    Window ``idx`` starts at token ``idx * stride`` and holds up to
    ``seq_len`` tokens; a shorter final window is right-padded. Each item is
    a tuple ``(input_ids, targets, loss_mask)`` of 1-D tensors of length
    ``seq_len``:

    * ``input_ids`` – the window's tokens (int64), zero-padded on the right.
    * ``targets``   – ``input_ids`` shifted left by one; positions without a
      successor inside the window hold ``-100``.
    * ``loss_mask`` – True only at positions whose target has not been scored
      by an earlier window, so that (for ``stride < seq_len``) every token
      after the first is scored exactly once.

    Args:
        tokens: 1-D array of token ids.
        seq_len: Window length in tokens.
        stride: Distance in tokens between consecutive window starts.
    """

    def __init__(self, tokens: np.ndarray, seq_len: int, stride: int) -> None:
        self.tokens = tokens
        self.seq_len = seq_len
        self.stride = stride

        n_total = len(tokens)
        if n_total < 2:
            # A lone token has no successor to predict.
            self.n_windows = 0
        else:
            # Windows required for the last one to reach the end of the
            # stream (ceil division); at least one window even when the
            # stream is shorter than seq_len.
            to_reach_end = 1 + max(0, -(-(n_total - seq_len) // stride))
            # Windows whose start still leaves an (input, target) pair; only
            # binding when stride >= seq_len.
            with_a_pair = 1 + (n_total - 2) // stride
            self.n_windows = min(to_reach_end, with_a_pair)

    def __len__(self) -> int:
        return self.n_windows

    def __getitem__(self, idx: int):
        start = idx * self.stride
        stop = min(start + self.seq_len, len(self.tokens))
        chunk_len = stop - start

        input_ids = torch.zeros(self.seq_len, dtype=torch.long)
        targets = torch.full((self.seq_len,), fill_value=-100, dtype=torch.long)
        loss_mask = torch.zeros(self.seq_len, dtype=torch.bool)

        if chunk_len > 1:
            chunk = torch.from_numpy(self.tokens[start:stop].astype(np.int64))
            input_ids[:chunk_len] = chunk
            targets[:chunk_len - 1] = chunk[1:]

            # The previous window already scored targets up to stream token
            # (start - stride + seq_len - 1). The first unscored target is
            # therefore at local position seq_len - stride - 1 (the target at
            # position p is stream token start + p + 1).
            if idx == 0:
                first_new = 0
            else:
                first_new = max(0, self.seq_len - self.stride - 1)
            loss_mask[first_new:chunk_len - 1] = True

        return input_ids, targets, loss_mask
def _load_model(device: str):
    """Build the FRANKENSTALLM 3B model from ``CHECKPOINT`` for evaluation.

    The model is moved to ``device``, cast to bfloat16 and switched to eval
    mode before being returned.
    """
    # Imported lazily: the package lives under the project root that this
    # module adds to sys.path.
    from model.transformer import LLM  # type: ignore[import]

    llm = LLM.from_pretrained(CHECKPOINT).to(device=device, dtype=torch.bfloat16)
    llm.eval()
    return llm
def _load_tokenizer():
    """Return the Korean SentencePiece tokenizer stored at ``TOKENIZER_PATH``."""
    from tokenizers import Tokenizer  # type: ignore[import]

    tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
    return tokenizer
# ---------------------------------------------------------------------------
# Main task function (must be top-level for pickle / spawn compatibility)
# ---------------------------------------------------------------------------
def _masked_token_nll(logits, targets, mask) -> np.ndarray:
    """Return float32 NLL values for the positions selected by ``mask``.

    ``targets`` and ``mask`` share one 2-D shape; ``logits`` must be viewable
    as that shape plus a trailing vocabulary axis. Only the selected rows are
    up-cast to float32 (one batch row at a time, to bound memory), so the loss
    is not rounded to the model's lower-precision output dtype.
    """
    logits = logits.view(*mask.shape, logits.size(-1))
    per_row = []
    for row_logits, row_targets, row_mask in zip(logits, targets, mask):
        selected = row_logits[row_mask]
        if selected.size(0) == 0:
            continue
        per_row.append(
            F.cross_entropy(
                selected.float(),
                row_targets[row_mask],
                reduction="none",
                ignore_index=-100,
            )
        )
    if not per_row:
        return np.empty(0, dtype=np.float32)
    return torch.cat(per_row).cpu().numpy()


def _summarize_nll(nll_arr: np.ndarray) -> dict:
    """Compute unrounded summary statistics of a 1-D NLL array (zeros if empty)."""
    if nll_arr.size == 0:
        return {
            "mean": 0.0,
            "std": 0.0,
            "median": 0.0,
            "percentiles": {"p5": 0.0, "p25": 0.0, "p75": 0.0, "p95": 0.0, "p99": 0.0},
            "high_loss_5": 0.0,
            "high_loss_10": 0.0,
        }
    p5, p25, p75, p95, p99 = np.percentile(nll_arr, [5, 25, 75, 95, 99])
    return {
        "mean": float(np.mean(nll_arr)),
        "std": float(np.std(nll_arr)),
        "median": float(np.median(nll_arr)),
        "percentiles": {
            "p5": float(p5),
            "p25": float(p25),
            "p75": float(p75),
            "p95": float(p95),
            "p99": float(p99),
        },
        "high_loss_5": float(np.mean(nll_arr > 5.0)),
        "high_loss_10": float(np.mean(nll_arr > 10.0)),
    }


def eval_token_nll(device: str, n_tokens: int = 50000) -> dict:
    """Analyse the per-token NLL distribution on 3b_val.bin.

    Collects the NLL of every valid (unmasked) token and computes summary
    statistics and percentile breakdowns, as well as the fraction of
    "high-loss" tokens that may indicate out-of-distribution content.

    Args:
        device: Torch device string, e.g. "cuda:6".
        n_tokens: Number of tokens to process (first n_tokens of 3b_val.bin).

    Returns:
        Dict with keys:
          - n_eval_tokens:         number of tokens included in stats
          - nll_mean:              mean token NLL
          - nll_std:               standard deviation of token NLL
          - nll_median:            50th-percentile NLL
          - nll_percentiles:       dict mapping percentile label to value
                                   (keys: p5, p25, p75, p95, p99)
          - high_loss_fraction_5:  fraction of tokens with NLL > 5.0
          - high_loss_fraction_10: fraction of tokens with NLL > 10.0
          - elapsed_sec:           wall-clock time of the evaluation loop

    Raises:
        ValueError: If ``n_tokens`` is not positive or the validation file
            contains no tokens.
        FileNotFoundError: If 3b_val.bin is missing from the data directory.
    """
    if n_tokens <= 0:
        # A negative value would otherwise slice from the END of the file.
        raise ValueError(f"n_tokens must be positive, got {n_tokens}")

    # Validate and read the data before the (expensive) model load so that a
    # missing or empty file fails fast.
    val_path = DATA_DIR / "3b_val.bin"
    if not val_path.exists():
        raise FileNotFoundError(f"Validation file not found: {val_path}")
    # count= limits the read to the prefix that is actually evaluated.
    tokens = np.fromfile(str(val_path), dtype=np.uint16, count=n_tokens)
    if len(tokens) == 0:
        raise ValueError(f"Validation file is empty (0 tokens): {val_path}")

    # Only pin the process to a GPU when an explicit CUDA index was given;
    # this keeps "cuda" and "cpu" device strings usable.
    target = torch.device(device)
    if target.type == "cuda" and target.index is not None:
        torch.cuda.set_device(target)

    print(f"[NLL {device}] Loading model...")
    model = _load_model(device)
    print(f"[NLL {device}] Using {len(tokens):,} tokens from 3b_val.bin")

    ds = SlidingWindowDataset(tokens, SEQ_LEN, STRIDE)
    dl = DataLoader(
        ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=4,
        pin_memory=target.type == "cuda",
    )

    all_nlls = []
    n_collected = 0
    t0 = time.time()
    with torch.inference_mode():
        for batch_idx, (inp, tgt, mask) in enumerate(dl):
            inp = inp.to(device)
            tgt = tgt.to(device)
            mask = mask.to(device)

            # The model returns a tuple; its first element is the logits.
            logits, _ = model(inp)
            valid_nll = _masked_token_nll(logits, tgt, mask)
            if valid_nll.size > 0:
                all_nlls.append(valid_nll)
                n_collected += valid_nll.size

            if (batch_idx + 1) % 50 == 0:
                elapsed = time.time() - t0
                print(
                    f"[NLL {device}] batch {batch_idx + 1}/{len(dl)}, "
                    f"tokens collected={n_collected:,}, {elapsed:.0f}s"
                )
    elapsed = time.time() - t0

    if all_nlls:
        nll_arr = np.concatenate(all_nlls)
    else:
        nll_arr = np.array([], dtype=np.float32)
    n_eval = len(nll_arr)
    if n_eval == 0:
        # All-zero statistics would otherwise look like a perfect model.
        print(f"[NLL {device}] WARNING: no tokens were evaluated; statistics are all zero")

    stats = _summarize_nll(nll_arr)
    nll_mean = stats["mean"]
    nll_std = stats["std"]
    nll_median = stats["median"]
    high_loss_5 = stats["high_loss_5"]
    high_loss_10 = stats["high_loss_10"]

    result: dict = {
        "n_eval_tokens": int(n_eval),
        "nll_mean": round(nll_mean, 4),
        "nll_std": round(nll_std, 4),
        "nll_median": round(nll_median, 4),
        "nll_percentiles": {k: round(v, 4) for k, v in stats["percentiles"].items()},
        "high_loss_fraction_5": round(high_loss_5, 6),
        "high_loss_fraction_10": round(high_loss_10, 6),
        "elapsed_sec": round(elapsed, 1),
    }
    print(
        f"[NLL {device}] DONE n={n_eval:,}, "
        f"mean={nll_mean:.4f}, std={nll_std:.4f}, "
        f"median={nll_median:.4f}, "
        f"high_loss(>5)={high_loss_5:.2%}, "
        f"high_loss(>10)={high_loss_10:.2%}, "
        f"{elapsed:.1f}s"
    )
    return result