| | """ |
| | calibration_task.py — Top-k accuracy and entropy calibration evaluation. |
| | |
| | Top-level function for ProcessPoolExecutor (spawn) compatibility: |
| | - eval_calibration(device, n_tokens=50000) -> dict |
| | """ |
| | from __future__ import annotations |
| |
|
| | import sys |
| | import time |
| | from pathlib import Path |
| |
|
| | import os |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | from torch.utils.data import DataLoader, Dataset |
| |
|
# Make the project root importable so `model.*` resolves when this module is
# executed in a spawned ProcessPoolExecutor worker (no package context).
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))

# Checkpoint and tokenizer locations; overridable via environment variables
# (EVAL_CHECKPOINT / EVAL_TOKENIZER) for ad-hoc evaluation runs.
_DEFAULT_CHECKPOINT = str(_PROJECT_ROOT / "checkpoints" / "korean_3b_fp8_run1" / "checkpoint-0057000")
CHECKPOINT = os.environ.get("EVAL_CHECKPOINT", _DEFAULT_CHECKPOINT)
TOKENIZER_PATH = os.environ.get("EVAL_TOKENIZER", str(_PROJECT_ROOT / "tokenizer" / "korean_sp" / "tokenizer.json"))
DATA_DIR = _PROJECT_ROOT / "data"

# Evaluation geometry: windows of SEQ_LEN tokens, consecutive window starts
# advanced by STRIDE tokens; BATCH_SIZE windows per forward pass.
SEQ_LEN = 2048
STRIDE = 512
BATCH_SIZE = 32
| |
|
| |
|
| | |
| | |
| | |
| |
|
class SlidingWindowDataset(Dataset):
    """Sliding-window tokenized dataset for evaluation.

    Splits a flat token stream into windows of ``seq_len`` tokens whose starts
    advance by ``stride``.  Each item is ``(input_ids, targets, loss_mask)``
    where ``targets[p]`` is the token following ``input_ids[p]`` (-100 where
    undefined) and ``loss_mask`` selects only positions whose target token was
    not already scored by the previous window — so, summed over all windows,
    every predictable token in the stream is evaluated exactly once.
    """

    def __init__(self, tokens: np.ndarray, seq_len: int, stride: int) -> None:
        """
        Args:
            tokens: 1-D array of token ids (any integer dtype; cast to int64
                per window).
            seq_len: Window length in tokens.
            stride: Advance between consecutive window starts. Should be
                <= seq_len so consecutive windows overlap enough to score
                every token.
        """
        self.tokens = tokens
        self.seq_len = seq_len
        self.stride = stride
        # Window count that covers the WHOLE stream.  The previous formula,
        # ceil((len - seq_len) / stride), dropped the final (possibly
        # truncated) window — and produced 0 windows when len == seq_len — so
        # trailing tokens were silently excluded from evaluation.
        if len(tokens) < 2:
            # Need at least one (input, target) pair to evaluate anything.
            self.n_windows = 0
        elif len(tokens) <= seq_len:
            self.n_windows = 1
        else:
            self.n_windows = (len(tokens) - seq_len + stride - 1) // stride + 1

    def __len__(self) -> int:
        return self.n_windows

    def __getitem__(self, idx: int):
        """Return (input_ids, targets, loss_mask), each of length seq_len.

        Windows shorter than seq_len (the tail of the stream) are zero-padded;
        padded positions keep target -100 and a False mask.
        """
        start = idx * self.stride
        actual_end = min(start + self.seq_len, len(self.tokens))
        chunk_len = actual_end - start

        input_ids = torch.zeros(self.seq_len, dtype=torch.long)
        targets = torch.full((self.seq_len,), fill_value=-100, dtype=torch.long)
        loss_mask = torch.zeros(self.seq_len, dtype=torch.bool)

        if chunk_len > 1:
            toks = torch.from_numpy(self.tokens[start:actual_end].astype(np.int64))
            input_ids[:chunk_len] = toks
            targets[:chunk_len - 1] = toks[1:]

            # Score only targets NOT covered by the previous window.  A full
            # window idx-1 scores tokens up to start + seq_len - stride - 1,
            # whose prediction sits at position seq_len - stride - 1 in this
            # window.  The previous code used new_start = stride, which
            # re-scored each token in every overlapping window and therefore
            # double-counted tokens in the aggregate metrics.
            new_start = 0 if idx == 0 else self.seq_len - self.stride - 1
            if new_start < chunk_len - 1:
                # Vectorized slice assignment replaces the per-position loop.
                loss_mask[new_start:chunk_len - 1] = True

        return input_ids, targets, loss_mask
| |
|
| |
|
def _load_model(device: str):
    """Load the FRANKENSTALLM 3B checkpoint and place it on *device*.

    The model is cast to bfloat16 and switched to eval mode before return.
    """
    # Imported lazily: keeps module import cheap in spawned worker processes.
    from model.transformer import LLM

    net = LLM.from_pretrained(CHECKPOINT)
    # .to() and .eval() both return the module, so the chain is equivalent to
    # assigning after each call.
    return net.to(device=device, dtype=torch.bfloat16).eval()
| |
|
| |
|
def _load_tokenizer():
    """Return the Korean SentencePiece tokenizer loaded from TOKENIZER_PATH."""
    # Imported lazily: `tokenizers` is only needed by callers of this helper.
    from tokenizers import Tokenizer

    tokenizer_file = TOKENIZER_PATH
    return Tokenizer.from_file(tokenizer_file)
| |
|
| |
|
| | |
| | |
| | |
| |
|
def eval_calibration(device: str, n_tokens: int = 50000) -> dict:
    """Compute top-k accuracy and entropy calibration on 3b_val.bin.

    Measures how well the model's probability distribution is calibrated:
    - Top-1/5/10 next-token prediction accuracy
    - Mean probability assigned to the correct next token
    - Mean Shannon entropy (nats) of the predictive distribution

    Args:
        device: CUDA device string, e.g. "cuda:3".  A bare "cuda" uses the
            current default device.
        n_tokens: Number of tokens to evaluate (first n_tokens of 3b_val.bin).

    Returns:
        Dict with keys: n_eval_tokens, top1_accuracy, top5_accuracy,
        top10_accuracy, mean_correct_prob, mean_entropy, elapsed_sec.

    Raises:
        FileNotFoundError: If 3b_val.bin is missing.
        ValueError: If 3b_val.bin contains zero tokens.
    """
    # Guard the device parse: int("cuda") would raise for a bare "cuda".
    if ":" in device:
        torch.cuda.set_device(int(device.split(":")[-1]))
    print(f"[CALIB {device}] Loading model...")
    model = _load_model(device)

    val_path = DATA_DIR / "3b_val.bin"
    if not val_path.exists():
        raise FileNotFoundError(f"Validation file not found: {val_path}")
    tokens = np.fromfile(str(val_path), dtype=np.uint16)
    if len(tokens) == 0:
        raise ValueError(f"Validation file is empty (0 tokens): {val_path}")
    tokens = tokens[: min(n_tokens, len(tokens))]
    print(f"[CALIB {device}] Using {len(tokens):,} tokens from 3b_val.bin")

    ds = SlidingWindowDataset(tokens, SEQ_LEN, STRIDE)
    dl = DataLoader(
        ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
    )

    top1_correct = 0
    top5_correct = 0
    top10_correct = 0
    total_entropy = 0.0
    total_prob = 0.0
    total_count = 0
    t0 = time.time()

    with torch.inference_mode():
        for batch_idx, (inp, tgt, mask) in enumerate(dl):
            # non_blocking pairs with pin_memory=True for async H2D copies.
            inp = inp.to(device, non_blocking=True)
            tgt = tgt.to(device, non_blocking=True)
            mask = mask.to(device, non_blocking=True)

            logits, _ = model(inp)

            valid = mask & (tgt != -100)
            n_valid = int(valid.sum().item())
            if n_valid == 0:
                continue

            # Select only scored positions FIRST, then upcast to fp32 so the
            # softmax/entropy math is not done in bf16.  The previous version
            # ran softmax over the entire (B, T, V) logits tensor (masked
            # positions included) and used log(clamp(softmax(...))) instead of
            # the numerically stabler log_softmax.
            flat_logits = logits[valid].float()
            flat_tgt = tgt[valid]

            # One top-10 call covers all three cutoffs: topk returns indices
            # sorted by descending logit, so prefixes are the top-1/top-5 sets.
            top10_pred = flat_logits.topk(10, dim=-1).indices
            hits = top10_pred == flat_tgt.unsqueeze(-1)
            top1_correct += hits[:, :1].any(dim=-1).sum().item()
            top5_correct += hits[:, :5].any(dim=-1).sum().item()
            top10_correct += hits.any(dim=-1).sum().item()

            log_probs = F.log_softmax(flat_logits, dim=-1)
            flat_probs = log_probs.exp()

            # Probability mass assigned to the true next token.
            correct_probs = flat_probs.gather(-1, flat_tgt.unsqueeze(-1)).squeeze(-1)
            total_prob += correct_probs.sum().item()

            # Shannon entropy in nats.  Where flat_probs underflows to 0 its
            # log_prob is still finite (logits are finite), so p * log p is an
            # exact 0 contribution — no clamp needed.
            entropy = -(flat_probs * log_probs).sum(dim=-1)
            total_entropy += entropy.sum().item()

            total_count += n_valid

            if (batch_idx + 1) % 50 == 0:
                elapsed = time.time() - t0
                print(
                    f"[CALIB {device}] batch {batch_idx + 1}/{len(dl)}, "
                    f"tokens so far={total_count:,}, {elapsed:.0f}s"
                )

    elapsed = time.time() - t0
    # Single guard for all divisions: if nothing was evaluated every numerator
    # is 0, so dividing by 1 reproduces the old per-key 0.0 fallbacks.
    denom = max(total_count, 1)
    result: dict = {
        "n_eval_tokens": int(total_count),
        "top1_accuracy": round(top1_correct / denom, 4),
        "top5_accuracy": round(top5_correct / denom, 4),
        "top10_accuracy": round(top10_correct / denom, 4),
        "mean_correct_prob": round(total_prob / denom, 4),
        "mean_entropy": round(total_entropy / denom, 4),
        "elapsed_sec": round(elapsed, 1),
    }
    print(
        f"[CALIB {device}] DONE top1={result['top1_accuracy']:.4f}, "
        f"top5={result['top5_accuracy']:.4f}, "
        f"top10={result['top10_accuracy']:.4f}, "
        f"entropy={result['mean_entropy']:.4f}, {elapsed:.1f}s"
    )
    return result
| |
|