| | """ |
| | token_nll_task.py — Token-level NLL distribution analysis. |
| | |
| | Top-level function for ProcessPoolExecutor (spawn) compatibility: |
| | - eval_token_nll(device, n_tokens=50000) -> dict |
| | """ |
| | from __future__ import annotations |
| |
|
| | import sys |
| | import time |
| | from pathlib import Path |
| |
|
| | import os |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | from torch.utils.data import DataLoader, Dataset |
| |
|
# Make the project root importable (for `model.*` etc.) even when this module
# is loaded in a freshly spawned worker process with a bare sys.path.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))


# Checkpoint and tokenizer locations; overridable via the EVAL_CHECKPOINT /
# EVAL_TOKENIZER environment variables.
_DEFAULT_CHECKPOINT = str(_PROJECT_ROOT / "checkpoints" / "korean_3b_fp8_run1" / "checkpoint-0057000")
CHECKPOINT = os.environ.get("EVAL_CHECKPOINT", _DEFAULT_CHECKPOINT)
TOKENIZER_PATH = os.environ.get("EVAL_TOKENIZER", str(_PROJECT_ROOT / "tokenizer" / "korean_sp" / "tokenizer.json"))
DATA_DIR = _PROJECT_ROOT / "data"  # expected to contain 3b_val.bin (uint16 token ids)
SEQ_LEN = 2048  # evaluation context window (tokens per forward pass)
STRIDE = 512  # step between consecutive sliding windows
BATCH_SIZE = 32
| |
|
| |
|
| | |
| | |
| | |
| |
|
class SlidingWindowDataset(Dataset):
    """Sliding-window tokenized dataset for evaluation.

    Splits a 1-D token stream into windows of ``seq_len`` tokens advancing by
    ``stride``.  Each item is ``(input_ids, targets, loss_mask)`` where
    ``loss_mask`` selects exactly the target positions that are *new* relative
    to the previous window, so every predictable token in the stream (all but
    the very first, which has no context) is evaluated exactly once.
    """

    def __init__(self, tokens: np.ndarray, seq_len: int, stride: int) -> None:
        self.tokens = tokens
        self.seq_len = seq_len
        self.stride = stride
        # Number of windows needed to cover every predictable token once:
        # window 0 covers targets 1..seq_len-1, each later window adds
        # `stride` new targets.  The previous formula
        # `(len - seq_len + stride - 1) // stride` dropped the tail of the
        # stream and produced 0 windows when len(tokens) == seq_len.
        n = len(tokens)
        if n <= 1:
            self.n_windows = 0
        else:
            self.n_windows = 1 + max(0, -(-(n - seq_len) // stride))  # ceil div

    def __len__(self) -> int:
        return self.n_windows

    def __getitem__(self, idx: int):
        start = idx * self.stride
        actual_end = min(start + self.seq_len, len(self.tokens))
        chunk_len = actual_end - start

        input_ids = torch.zeros(self.seq_len, dtype=torch.long)
        targets = torch.full((self.seq_len,), fill_value=-100, dtype=torch.long)
        loss_mask = torch.zeros(self.seq_len, dtype=torch.bool)

        if chunk_len > 1:
            toks = torch.from_numpy(self.tokens[start:actual_end].astype(np.int64))
            input_ids[:chunk_len] = toks
            targets[:chunk_len - 1] = toks[1:]

            # Mask only targets not already scored by the previous window.
            # Window idx-1's last target is global token start+seq_len-stride-1,
            # so this window's new targets begin at local position
            # seq_len-stride-1 (clamped to 0 for stride >= seq_len).  The old
            # value (`stride`) re-scored most positions ~seq_len/stride times.
            new_start = 0 if idx == 0 else max(0, self.seq_len - self.stride - 1)
            loss_mask[new_start:chunk_len - 1] = True

        return input_ids, targets, loss_mask
| |
|
| |
|
def _load_model(device: str):
    """Load FRANKENSTALLM 3B from CHECKPOINT, cast to bf16, and set eval mode."""
    from model.transformer import LLM

    llm = LLM.from_pretrained(CHECKPOINT).to(device=device, dtype=torch.bfloat16)
    llm.eval()
    return llm
| |
|
| |
|
def _load_tokenizer():
    """Return the Korean SentencePiece tokenizer read from TOKENIZER_PATH."""
    from tokenizers import Tokenizer

    tok = Tokenizer.from_file(TOKENIZER_PATH)
    return tok
| |
|
| |
|
| | |
| | |
| | |
| |
|
def _summarize_nll(nll_arr: np.ndarray) -> tuple[float, float, float, dict, float, float]:
    """Return (mean, std, median, percentiles, frac>5, frac>10) for a 1-D NLL array.

    An empty array yields all-zero statistics so the caller can still emit a
    well-formed result dict.
    """
    if len(nll_arr) == 0:
        zero_pcts = {"p5": 0.0, "p25": 0.0, "p75": 0.0, "p95": 0.0, "p99": 0.0}
        return 0.0, 0.0, 0.0, zero_pcts, 0.0, 0.0
    # One vectorized percentile call instead of five separate full-array scans.
    p5, p25, p75, p95, p99 = (float(v) for v in np.percentile(nll_arr, [5, 25, 75, 95, 99]))
    return (
        float(np.mean(nll_arr)),
        float(np.std(nll_arr)),
        float(np.median(nll_arr)),
        {"p5": p5, "p25": p25, "p75": p75, "p95": p95, "p99": p99},
        float(np.mean(nll_arr > 5.0)),   # "high-loss" fraction, NLL > 5
        float(np.mean(nll_arr > 10.0)),  # extreme-loss fraction, NLL > 10
    )


def eval_token_nll(device: str, n_tokens: int = 50000) -> dict:
    """Analyse the per-token NLL distribution on 3b_val.bin.

    Collects the NLL of every valid (unmasked) token and computes summary
    statistics and percentile breakdowns, as well as the fraction of
    "high-loss" tokens that may indicate out-of-distribution content.

    Args:
        device: Device string, e.g. "cuda:6".  A bare "cuda" or "cpu" is
            also accepted; torch.cuda.set_device is only called when an
            explicit CUDA index is present.
        n_tokens: Number of tokens to process (first n_tokens of 3b_val.bin).

    Returns:
        Dict with keys:
          - n_eval_tokens: number of tokens included in stats
          - nll_mean: mean token NLL
          - nll_std: standard deviation of token NLL
          - nll_median: 50th-percentile NLL
          - nll_percentiles: dict mapping percentile label to value
            (keys: p5, p25, p75, p95, p99)
          - high_loss_fraction_5: fraction of tokens with NLL > 5.0
          - high_loss_fraction_10: fraction of tokens with NLL > 10.0
          - elapsed_sec: wall-clock time

    Raises:
        FileNotFoundError: if data/3b_val.bin does not exist.
        ValueError: if the validation file contains zero tokens.
    """
    # Guarded: the previous unconditional int(device.split(":")[-1]) raised
    # ValueError for "cpu" or a bare "cuda" device string.
    if device.startswith("cuda") and ":" in device:
        torch.cuda.set_device(int(device.split(":", 1)[1]))
    print(f"[NLL {device}] Loading model...")
    model = _load_model(device)

    val_path = DATA_DIR / "3b_val.bin"
    if not val_path.exists():
        raise FileNotFoundError(f"Validation file not found: {val_path}")
    tokens = np.fromfile(str(val_path), dtype=np.uint16)
    if len(tokens) == 0:
        raise ValueError(f"Validation file is empty (0 tokens): {val_path}")
    tokens = tokens[:n_tokens]  # numpy slicing clamps to the array length
    print(f"[NLL {device}] Using {len(tokens):,} tokens from 3b_val.bin")

    ds = SlidingWindowDataset(tokens, SEQ_LEN, STRIDE)
    dl = DataLoader(
        ds,
        batch_size=BATCH_SIZE,
        shuffle=False,  # deterministic pass over the validation stream
        num_workers=4,
        pin_memory=True,
    )

    all_nlls: list[np.ndarray] = []
    t0 = time.time()

    with torch.inference_mode():
        for batch_idx, (inp, tgt, mask) in enumerate(dl):
            inp = inp.to(device)
            tgt = tgt.to(device)
            mask = mask.to(device)

            # Model forward returns (logits, aux); only logits are used here.
            logits, _ = model(inp)

            # Per-position NLL.  Positions with target -100 (padding/context)
            # yield 0 via ignore_index but are excluded by `mask` anyway.
            per_token_nll = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                tgt.view(-1),
                reduction="none",
                ignore_index=-100,
            ).view(mask.shape)

            valid_nll = per_token_nll[mask].float().cpu().numpy()
            if len(valid_nll) > 0:
                all_nlls.append(valid_nll)

            if (batch_idx + 1) % 50 == 0:
                n_collected = sum(len(a) for a in all_nlls)
                elapsed = time.time() - t0
                print(
                    f"[NLL {device}] batch {batch_idx + 1}/{len(dl)}, "
                    f"tokens collected={n_collected:,}, {elapsed:.0f}s"
                )

    elapsed = time.time() - t0

    nll_arr = np.concatenate(all_nlls) if all_nlls else np.array([], dtype=np.float32)
    n_eval = len(nll_arr)
    nll_mean, nll_std, nll_median, percentiles, high_loss_5, high_loss_10 = _summarize_nll(nll_arr)

    result: dict = {
        "n_eval_tokens": int(n_eval),
        "nll_mean": round(nll_mean, 4),
        "nll_std": round(nll_std, 4),
        "nll_median": round(nll_median, 4),
        "nll_percentiles": {k: round(v, 4) for k, v in percentiles.items()},
        "high_loss_fraction_5": round(high_loss_5, 6),
        "high_loss_fraction_10": round(high_loss_10, 6),
        "elapsed_sec": round(elapsed, 1),
    }

    print(
        f"[NLL {device}] DONE n={n_eval:,}, "
        f"mean={nll_mean:.4f}, std={nll_std:.4f}, "
        f"median={nll_median:.4f}, "
        f"high_loss(>5)={high_loss_5:.2%}, "
        f"high_loss(>10)={high_loss_10:.2%}, "
        f"{elapsed:.1f}s"
    )
    return result
| |
|