| | """ |
| | Comprehensive evaluation script for a trained 1B Korean language model. |
| | |
| | Covers: |
| | 1. Multi-source sliding-window perplexity (4 val sets) |
| | 2. Token-level NLL distribution + top-50 highest/lowest-loss tokens |
| | 3. Multi-prompt generation quality (10 diverse prompts) |
| | 4. Repetition analysis (unigram..4-gram repetition ratio) |
| | 5. Greedy vs. sampling comparison (3 prompts × 4 temperature settings) |
| | 6. Calibration check (accuracy@1/5/10, mean prob, mean entropy) |
| | |
| | Usage: |
| | python eval/comprehensive_eval.py \ |
| | --checkpoint checkpoints/korean_1b_fp8_run1/checkpoint-0034000 \ |
| | --device cuda:0 |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import argparse |
| | import math |
| | import sys |
| | import time |
| | from collections import Counter, defaultdict |
| | from pathlib import Path |
| | from typing import Dict, List, Optional, Tuple |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | from torch.utils.data import DataLoader, Dataset |
| |
|
| | |
| | |
| | |
# Make the project root importable (for `model.transformer`) when this
# script is run directly, e.g. `python eval/comprehensive_eval.py`.
_THIS_FILE = Path(__file__).resolve()
_PROJECT_ROOT = _THIS_FILE.parent.parent
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))
| |
|
| | from model.transformer import LLM |
| | from tokenizers import Tokenizer |
| |
|
| |
|
| | |
| | |
| | |
| |
|
def parse_args() -> argparse.Namespace:
    """Build and parse the command-line arguments for this script."""
    parser = argparse.ArgumentParser(
        description="Comprehensive evaluation for a trained Korean LLM."
    )
    add = parser.add_argument

    add(
        "--checkpoint",
        default="checkpoints/korean_1b_fp8_run1/checkpoint-0034000",
        help="Path to the checkpoint directory (default: korean_1b_fp8_run1/checkpoint-0034000).",
    )
    add(
        "--device",
        default="cuda:0",
        help="Torch device string (default: cuda:0).",
    )
    add(
        "--tokenizer",
        default=None,
        help="Path to tokenizer.json. Defaults to <checkpoint>/tokenizer.json, "
        "then tokenizer/korean_sp/tokenizer.json.",
    )
    add(
        "--data_dir",
        default=None,
        help="Directory containing val .bin files. Defaults to <project>/data/.",
    )
    add(
        "--seq_len",
        type=int,
        default=2048,
        help="Sliding-window sequence length for PPL (default: 2048).",
    )
    add(
        "--stride",
        type=int,
        default=512,
        help="Stride for sliding-window PPL (default: 512).",
    )
    add(
        "--batch_size",
        type=int,
        default=4,
        help="Batch size for PPL evaluation (default: 4).",
    )
    add(
        "--max_new_tokens",
        type=int,
        default=200,
        help="Max new tokens for generation (default: 200).",
    )
    add(
        "--calib_tokens",
        type=int,
        default=10000,
        help="Number of tokens used for calibration check (default: 10000).",
    )
    return parser.parse_args()
| |
|
| |
|
| | |
| | |
| | |
| |
|
def load_model(checkpoint_dir: str, device: str) -> LLM:
    """Load the LLM from `checkpoint_dir`, move it to `device` in BF16, set eval mode."""
    directory = Path(checkpoint_dir)
    if not directory.exists():
        raise FileNotFoundError(f"Checkpoint directory not found: {directory}")

    print(f" Loading model weights from: {directory}")
    llm = LLM.from_pretrained(str(directory)).to(device=device, dtype=torch.bfloat16)
    llm.eval()

    total_params = sum(p.numel() for p in llm.parameters())
    print(f" Model parameters: {total_params / 1e6:.1f}M | dtype: {next(llm.parameters()).dtype}")
    return llm
| |
|
| |
|
def load_tokenizer(checkpoint_dir: str, tokenizer_override: Optional[str]) -> Tokenizer:
    """Locate tokenizer.json (override → checkpoint → project default) and load it."""
    search_paths: List[Path] = []
    if tokenizer_override:
        search_paths.append(Path(tokenizer_override))
    search_paths.append(Path(checkpoint_dir) / "tokenizer.json")
    search_paths.append(_PROJECT_ROOT / "tokenizer" / "korean_sp" / "tokenizer.json")

    for candidate in search_paths:
        if candidate.exists():
            print(f" Loading tokenizer from: {candidate}")
            return Tokenizer.from_file(str(candidate))

    raise FileNotFoundError(
        f"tokenizer.json not found. Tried: {[str(c) for c in search_paths]}"
    )
| |
|
| |
|
| | |
| | |
| | |
| |
|
class SlidingWindowDataset(Dataset):
    """Sliding-window dataset over a flat token array.

    Each item is ``(input_ids, targets, loss_mask)``, all of shape
    ``(seq_len,)``:

    - ``input_ids``: the window's tokens (zero-padded for a short tail window),
    - ``targets``: next-token labels, ``-100`` where no label exists,
    - ``loss_mask``: True only at positions whose label has NOT already been
      scored by the previous overlapping window, so every token in the
      stream is scored exactly once.

    Fixes vs. the previous revision:
    - the window count ``ceil((n - seq_len) / stride)`` dropped up to
      ``stride`` tokens at the end of the stream and produced zero windows
      when ``n == seq_len``; it is now ``1 + ceil(max(0, n - seq_len) / stride)``,
    - overlapping windows masked from position ``stride`` instead of
      ``seq_len - stride - 1``, double-counting mid-context targets whenever
      ``stride < seq_len / 2`` (the defaults: 512 vs 2048) and skipping
      targets when ``stride`` was larger.
    """

    def __init__(self, tokens: np.ndarray, seq_len: int, stride: int) -> None:
        self.tokens = tokens
        self.seq_len = seq_len
        self.stride = stride
        n = len(tokens)
        if n < 2:
            # Need at least one (input, target) pair to score anything.
            self.n_windows = 0
        else:
            # One window for the first seq_len tokens, plus one per `stride`
            # step needed to reach the end of the stream (ceil division).
            extra = max(0, n - seq_len)
            self.n_windows = 1 + (extra + stride - 1) // stride

    def __len__(self) -> int:
        return self.n_windows

    def __getitem__(self, idx: int):
        start = idx * self.stride
        actual_end = min(start + self.seq_len, len(self.tokens))
        chunk_len = actual_end - start

        input_ids = torch.zeros(self.seq_len, dtype=torch.long)
        targets = torch.full((self.seq_len,), fill_value=-100, dtype=torch.long)
        loss_mask = torch.zeros(self.seq_len, dtype=torch.bool)

        if chunk_len > 1:
            toks = torch.from_numpy(self.tokens[start:actual_end].astype(np.int64))
            input_ids[:chunk_len] = toks
            targets[:chunk_len - 1] = toks[1:]

            # Window 0 scores every position. A later window scores only
            # positions whose target token index (start + pos + 1) was not
            # already covered by the previous window, i.e. pos >=
            # seq_len - stride - 1 (clamped for stride >= seq_len).
            new_start = 0 if idx == 0 else max(0, self.seq_len - self.stride - 1)
            loss_mask[new_start:chunk_len - 1] = True

        return input_ids, targets, loss_mask
| |
|
| |
|
| | |
| | |
| | |
| |
|
def top_p_filtering(
    logits: torch.Tensor,
    top_p: float = 0.9,
    top_k: int = 0,
    filter_value: float = float("-inf"),
) -> torch.Tensor:
    """Apply top-k and then top-p (nucleus) filtering to logits.

    Args:
        logits: (vocab,) or (batch, vocab) unnormalised scores.
        top_p: keep the smallest prefix of sorted tokens whose cumulative
            probability reaches `top_p`; disabled unless 0 < top_p < 1.
        top_k: keep only the `top_k` highest-scoring tokens; disabled if 0.
        filter_value: value written into removed positions.

    Returns:
        Filtered logits with the same shape as the input.
    """
    squeeze_output = logits.dim() == 1
    if squeeze_output:
        logits = logits.unsqueeze(0)

    if top_k > 0:
        k = min(top_k, logits.size(-1))
        kth_values = torch.topk(logits, k, dim=-1).values[:, -1, None]
        logits = logits.masked_fill(logits < kth_values, filter_value)

    if 0.0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
        # Compute the softmax once (the previous revision evaluated it twice).
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        # Remove a token when the mass *before* it already reaches top_p;
        # subtracting its own probability guarantees the top-1 token survives.
        sorted_indices_to_remove = cumulative_probs - sorted_probs >= top_p
        sorted_logits = sorted_logits.masked_fill(sorted_indices_to_remove, filter_value)
        # Scatter the filtered values back to their original vocab positions.
        logits = torch.zeros_like(logits).scatter_(-1, sorted_indices, sorted_logits)

    if squeeze_output:
        logits = logits.squeeze(0)
    return logits
| |
|
| |
|
@torch.inference_mode()
def generate_text(
    model: LLM,
    tokenizer: Tokenizer,
    prompt: str,
    max_new_tokens: int = 200,
    temperature: float = 0.8,
    top_p: float = 0.9,
    top_k: int = 50,
    device: str = "cuda:0",
) -> str:
    """Autoregressively generate a continuation of `prompt`.

    Note: this does a full forward pass over the whole sequence each step
    (no KV cache), so cost grows quadratically with length.

    Args:
        model: the language model; called as ``model(ids) -> (logits, _)``.
        tokenizer: tokenizer providing ``.encode`` / ``.decode`` / ``.token_to_id``.
        prompt: text to continue.
        max_new_tokens: hard cap on the number of generated tokens.
        temperature: 0.0 selects greedy decoding; otherwise logits are
            divided by the temperature before sampling.
        top_p / top_k: nucleus / top-k filtering (ignored when greedy).
        device: device on which generation tensors are placed.

    Returns:
        Only the newly generated text — the prompt is NOT included.
        (The previous docstring wrongly claimed prompt + generation;
        callers concatenate the prompt themselves.)
    """
    model.eval()
    input_ids = torch.tensor(
        [tokenizer.encode(prompt).ids], dtype=torch.long, device=device
    )
    # Remember the prompt length once instead of re-encoding the prompt
    # after generation (the previous revision tokenised the prompt twice).
    prompt_len = input_ids.shape[1]
    eos_token_id: Optional[int] = tokenizer.token_to_id("</s>")
    generated_ids = input_ids

    for _ in range(max_new_tokens):
        logits_all, _ = model(generated_ids)
        logits: torch.Tensor = logits_all[:, -1, :]

        if temperature == 0.0:
            # Greedy decoding.
            next_token_id = logits.argmax(dim=-1, keepdim=True)
        else:
            logits = logits / max(temperature, 1e-8)
            logits = top_p_filtering(logits, top_p=top_p, top_k=top_k)
            probs = F.softmax(logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)

        generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)

        # Stop as soon as the end-of-sequence token is emitted.
        if eos_token_id is not None and next_token_id.item() == eos_token_id:
            break

    new_ids = generated_ids[0, prompt_len:].tolist()
    return tokenizer.decode(new_ids)
| |
|
| |
|
| | |
| | |
| | |
| |
|
@torch.inference_mode()
def eval_perplexity_on_file(
    model: LLM,
    data_path: Path,
    seq_len: int,
    stride: int,
    batch_size: int,
    device: str,
) -> Tuple[float, float, int]:
    """Sliding-window perplexity over one uint16 .bin token file.

    Returns:
        (perplexity, bits_per_token, n_tokens_evaluated)

    Raises:
        FileNotFoundError: if `data_path` does not exist.
        ValueError: if the file is too short to form a single window.
        RuntimeError: if no positions were scored.
    """
    if not data_path.exists():
        raise FileNotFoundError(f"Data file not found: {data_path}")

    token_stream = np.memmap(str(data_path), dtype="uint16", mode="r")
    n_total = len(token_stream)

    # Cap evaluation cost on very large validation files.
    MAX_EVAL_TOKENS = 2_000_000
    if n_total > MAX_EVAL_TOKENS:
        token_stream = token_stream[:MAX_EVAL_TOKENS]
    print(f" {data_path.name}: {n_total:,} tokens (using {len(token_stream):,})")

    windows = SlidingWindowDataset(token_stream, seq_len=seq_len, stride=stride)
    if len(windows) == 0:
        raise ValueError(f"No windows fit: {n_total} tokens, seq_len={seq_len}")

    loader = DataLoader(
        windows,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )

    nll_sum = 0.0
    n_scored = 0

    for input_ids, labels, mask in loader:
        input_ids = input_ids.to(device)
        labels = labels.to(device)
        mask = mask.to(device)

        logits, _ = model(input_ids)
        bsz, slen, vocab = logits.shape

        per_pos_nll = F.cross_entropy(
            logits.reshape(bsz * slen, vocab),
            labels.reshape(bsz * slen),
            ignore_index=-100,
            reduction="none",
        ).reshape(bsz, slen)

        # Only positions selected by the sliding-window mask contribute.
        nll_sum += (per_pos_nll * mask.float()).sum().item()
        n_scored += mask.sum().item()

    if n_scored == 0:
        raise RuntimeError("No valid positions evaluated.")

    mean_nll = nll_sum / n_scored
    return math.exp(mean_nll), mean_nll / math.log(2), n_scored
| |
|
| |
|
def section_perplexity(
    model: LLM,
    data_dir: Path,
    seq_len: int,
    stride: int,
    batch_size: int,
    device: str,
) -> Dict[str, Tuple[float, float, int]]:
    """Evaluate PPL on every validation set; returns {name: (ppl, bpt, n_tokens)}.

    Missing or unreadable files are reported and recorded as NaN rather
    than aborting the section.
    """
    print_header("1. MULTI-SOURCE PERPLEXITY")
    results: Dict[str, Tuple[float, float, int]] = {}
    for fname in (
        "3b_val.bin",
        "korean_wiki_val.bin",
        "korean_c4_val.bin",
        "korean_namuwiki_val.bin",
    ):
        name = fname.replace(".bin", "")
        print(f" Evaluating {fname} ...")
        try:
            ppl, bpt, n_tok = eval_perplexity_on_file(
                model, data_dir / fname, seq_len, stride, batch_size, device
            )
        except Exception as exc:
            print(f" [SKIPPED] {exc}")
            results[name] = (float("nan"), float("nan"), 0)
        else:
            results[name] = (ppl, bpt, n_tok)
            print(f" PPL = {ppl:.4f} | bits/token = {bpt:.4f} | tokens = {n_tok:,}")

    # Per-section summary table.
    print()
    print(f" {'Dataset':<30} {'PPL':>10} {'bits/tok':>10} {'tokens':>12}")
    print(f" {'-'*30} {'-'*10} {'-'*10} {'-'*12}")
    for name, (ppl, bpt, n_tok) in results.items():
        ppl_txt = f"{ppl:.4f}" if math.isfinite(ppl) else "N/A"
        bpt_txt = f"{bpt:.4f}" if math.isfinite(bpt) else "N/A"
        tok_txt = f"{n_tok:,}" if n_tok else "N/A"
        print(f" {name:<30} {ppl_txt:>10} {bpt_txt:>10} {tok_txt:>12}")
    return results
| |
|
| |
|
| | |
| | |
| | |
| |
|
@torch.inference_mode()
def section_token_analysis(
    model: LLM,
    tokenizer: Tokenizer,
    data_dir: Path,
    seq_len: int,
    batch_size: int,
    device: str,
    max_batches: int = 50,
) -> None:
    """Compute the per-position NLL distribution and print hardest/easiest tokens.

    Improvements vs. the previous revision:
    - per-token NLL accumulation uses `index_add_` / `bincount` instead of
      a Python loop over every token position,
    - batch NLLs are collected as tensors and concatenated once instead of
      extending a Python list of floats,
    - out-of-range target ids (corrupt data) are filtered before the
      per-token tables, as before.
    """
    print_header("2. TOKEN-LEVEL NLL ANALYSIS")

    val_path = data_dir / "3b_val.bin"
    if not val_path.exists():
        print(" [SKIPPED] 3b_val.bin not found.")
        return

    tokens = np.memmap(str(val_path), dtype="uint16", mode="r")
    # Non-overlapping windows (stride == seq_len): each position scored once.
    dataset = SlidingWindowDataset(tokens, seq_len=seq_len, stride=seq_len)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    vocab_size = model.config.vocab_size
    token_nll_sum = torch.zeros(vocab_size, dtype=torch.float64)
    token_nll_count = torch.zeros(vocab_size, dtype=torch.long)
    nll_chunks: List[torch.Tensor] = []

    for n_batches, (batch_input_ids, batch_targets, batch_loss_mask) in enumerate(loader):
        if n_batches >= max_batches:
            break

        batch_input_ids = batch_input_ids.to(device)
        batch_targets_dev = batch_targets.to(device)
        batch_loss_mask_dev = batch_loss_mask.to(device)

        logits, _ = model(batch_input_ids)
        B, S, V = logits.shape

        nll = F.cross_entropy(
            logits.reshape(B * S, V),
            batch_targets_dev.reshape(B * S),
            ignore_index=-100,
            reduction="none",
        ).reshape(B, S)

        # Keep positions that are masked-in AND carry a real label.
        mask = batch_loss_mask_dev & (batch_targets_dev != -100)
        valid_nll = nll[mask].float().cpu()
        valid_tok = batch_targets_dev[mask].long().cpu()

        nll_chunks.append(valid_nll)

        # Vectorised per-token accumulation (replaces a per-position loop).
        in_range = (valid_tok >= 0) & (valid_tok < vocab_size)
        tok_ok = valid_tok[in_range]
        token_nll_sum.index_add_(0, tok_ok, valid_nll[in_range].double())
        token_nll_count += torch.bincount(tok_ok, minlength=vocab_size)

    all_nll = torch.cat(nll_chunks) if nll_chunks else torch.empty(0)
    if all_nll.numel() == 0:
        print(" [SKIPPED] No valid NLL values collected.")
        return

    # Histogram of per-position NLL values.
    bins = [0, 1, 2, 3, 5, 10, float("inf")]
    labels = ["<1", "1-2", "2-3", "3-5", "5-10", ">10"]
    total = len(all_nll)
    print(f" Total token positions analysed: {total:,}")
    print()
    print(f" {'NLL range':<10} {'count':>10} {'percentage':>12}")
    print(f" {'-'*10} {'-'*10} {'-'*12}")
    for i, label in enumerate(labels):
        lo, hi = bins[i], bins[i + 1]
        if hi == float("inf"):
            cnt = int((all_nll >= lo).sum().item())
        else:
            cnt = int(((all_nll >= lo) & (all_nll < hi)).sum().item())
        pct = 100.0 * cnt / total if total > 0 else 0.0
        print(f" {label:<10} {cnt:>10,} {pct:>11.2f}%")

    print()
    print(f" Mean NLL: {all_nll.mean().item():.4f} Std: {all_nll.std().item():.4f}")
    print(f" Median NLL: {all_nll.median().item():.4f}")

    # Per-token-id average NLL; NaN marks ids never seen as targets.
    has_data = token_nll_count > 0
    avg_nll_per_token = torch.where(
        has_data,
        token_nll_sum / token_nll_count.clamp(min=1).double(),
        torch.full_like(token_nll_sum, float("nan")),
    )

    valid_ids = (~torch.isnan(avg_nll_per_token)).nonzero(as_tuple=True)[0]
    if len(valid_ids) == 0:
        print(" [WARNING] No per-token averages computed.")
        return

    order = avg_nll_per_token[valid_ids].argsort(descending=True)
    hardest = valid_ids[order[:50]]
    easiest = valid_ids[order[-50:].flip(0)]

    def _decode(tid: int) -> str:
        # repr() makes whitespace / control characters visible.
        try:
            return repr(tokenizer.decode([tid]))
        except Exception:
            return f"<id={tid}>"

    def _print_table(title: str, ids: torch.Tensor) -> None:
        # Shared renderer for the hardest / easiest token tables.
        print()
        print(title)
        print(f" {'rank':<5} {'token_id':<10} {'avg_nll':>8} {'count':>8} {'decoded'}")
        print(f" {'-'*5} {'-'*10} {'-'*8} {'-'*8} {'-'*30}")
        for rank, tid in enumerate(ids[:50].tolist(), start=1):
            avg = avg_nll_per_token[tid].item()
            cnt = token_nll_count[tid].item()
            print(f" {rank:<5} {tid:<10} {avg:>8.3f} {cnt:>8,} {_decode(tid)}")

    _print_table(" Top-50 HIGHEST-loss tokens (model struggles with):", hardest)
    _print_table(" Top-50 LOWEST-loss tokens (model handles well):", easiest)
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Ten diverse Korean generation prompts covering factual recall, definitions,
# narrative, history, travel directions, Python code, arithmetic, poetry,
# cooking instructions, and a historical figure.
GENERATION_PROMPTS = [
    "한국의 수도는",  # "The capital of Korea is"
    "인공지능이란",  # "Artificial intelligence is"
    "오늘 날씨가 좋아서",  # "The weather is nice today, so"
    "대한민국의 역사에서 가장 중요한 사건은",  # "The most important event in Korean history is"
    "서울에서 부산까지 가는 방법은",  # "The way to get from Seoul to Busan is"
    "다음은 파이썬 코드입니다:\ndef hello():",  # "The following is Python code: def hello():"
    "1 + 1 = 2이고, 2 + 2 =",  # "1 + 1 = 2, and 2 + 2 ="
    "봄이 오면 꽃이 피고",  # "When spring comes, flowers bloom and"
    "맛있는 김치찌개를 만들려면",  # "To make a tasty kimchi stew, you"
    "세종대왕은",  # "King Sejong the Great"
]
| |
|
| |
|
def compute_ngram_repetition(text: str, n: int) -> float:
    """Return the n-gram repetition ratio of whitespace-tokenised `text`.

    The ratio is ``1 - unique_ngrams / total_ngrams``, in [0, 1]:
    0.0 means every n-gram is distinct, 1.0 means everything repeats.
    Texts with fewer than `n` words score 0.0.
    """
    words = text.split()
    total = len(words) - n + 1
    if total <= 0:
        return 0.0
    unique = len({tuple(words[i:i + n]) for i in range(total)})
    return 1.0 - unique / total
| |
|
| |
|
def section_generation(
    model: LLM,
    tokenizer: Tokenizer,
    max_new_tokens: int,
    device: str,
) -> Dict[str, str]:
    """Sample a continuation for every standard prompt.

    Returns {prompt: generated_text}; a failed generation maps to "".
    """
    print_header("3. MULTI-PROMPT GENERATION")
    outputs: Dict[str, str] = {}

    n_prompts = len(GENERATION_PROMPTS)
    for idx, prompt in enumerate(GENERATION_PROMPTS, start=1):
        print(f"\n [{idx:02d}/{n_prompts}] Prompt: {prompt!r}")
        print(" " + "-" * 70)
        started = time.time()
        try:
            completion = generate_text(
                model, tokenizer, prompt,
                max_new_tokens=max_new_tokens,
                temperature=0.8,
                top_p=0.9,
                top_k=50,
                device=device,
            )
        except Exception as exc:
            print(f" [FAILED] {exc}")
            outputs[prompt] = ""
            continue

        took = time.time() - started
        outputs[prompt] = completion
        print(f" {prompt + completion}")
        print(f"\n [generated {len(completion.split()):,} words in {took:.1f}s]")

    return outputs
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Trigram repetition ratio above which a generation is flagged as degenerate.
REPETITION_THRESHOLD = 0.30
| |
|
| |
|
def section_repetition(generated: Dict[str, str]) -> Dict[str, Dict[str, float]]:
    """Analyse n-gram repetition of each generated text.

    Returns {prompt: {"1gram": ratio, ..., "4gram": ratio}} for every prompt
    with a non-empty generation, and flags texts whose trigram repetition
    exceeds REPETITION_THRESHOLD.
    """
    print_header("4. REPETITION ANALYSIS")

    ns = [1, 2, 3, 4]
    header = f" {'Prompt (truncated)':<35}"
    for n in ns:
        # Bug fix: the column label was a plain string literal inside the
        # f-string, so it printed '%rep-{n}gram' verbatim instead of
        # '%rep-1gram' .. '%rep-4gram'. A nested f-string interpolates n.
        header += f" {f'%rep-{n}gram':>12}"
    header += f" {'FLAG':>6}"
    print(header)
    print(" " + "-" * (35 + 12 * len(ns) + 10))

    results: Dict[str, Dict[str, float]] = {}
    for prompt, text in generated.items():
        if not text.strip():
            continue  # failed generations carry an empty string
        row_results = {f"{n}gram": compute_ngram_repetition(text, n) for n in ns}
        results[prompt] = row_results

        prompt_short = (prompt[:32] + "..") if len(prompt) > 34 else prompt
        row = f" {prompt_short:<35}"
        for n in ns:
            row += f" {row_results[f'{n}gram'] * 100:>11.1f}%"
        flag = "[DEGENERATE]" if row_results.get("3gram", 0.0) > REPETITION_THRESHOLD else ""
        row += f" {flag}"
        print(row)

    # Summary: which prompts crossed the degeneracy threshold.
    degenerate = [
        p for p, r in results.items()
        if r.get("3gram", 0.0) > REPETITION_THRESHOLD
    ]
    print()
    if degenerate:
        print(f" WARNING: {len(degenerate)} generation(s) exceed {REPETITION_THRESHOLD*100:.0f}% trigram repetition:")
        for p in degenerate:
            print(f" - {p!r}")
    else:
        print(f" All generations are below the {REPETITION_THRESHOLD*100:.0f}% trigram repetition threshold.")

    return results
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Prompts reused for the greedy-vs-sampling comparison (a subset of
# GENERATION_PROMPTS: capital city, AI definition, spring poem opener).
COMPARISON_PROMPTS = [
    "한국의 수도는",
    "인공지능이란",
    "봄이 오면 꽃이 피고",
]


# Decoding settings as (label, temperature, top_k, top_p).
# temperature == 0.0 selects greedy decoding in generate_text, which then
# ignores top_k / top_p.
TEMPERATURE_CONFIGS = [
    ("Greedy (T=0.0)", 0.0, 1, 0.0),
    ("Low (T=0.3)", 0.3, 50, 0.9),
    ("Normal (T=0.8)", 0.8, 50, 0.9),
    ("High (T=1.2)", 1.2, 50, 0.9),
]
| |
|
| |
|
def section_comparison(
    model: LLM,
    tokenizer: Tokenizer,
    max_new_tokens: int,
    device: str,
) -> None:
    """Print each comparison prompt decoded under all four temperature settings."""
    print_header("5. GREEDY vs. SAMPLING COMPARISON")

    # Keep this section short: cap generation length at 100 tokens.
    budget = min(max_new_tokens, 100)
    for prompt in COMPARISON_PROMPTS:
        print(f"\n Prompt: {prompt!r}")
        print(" " + "=" * 74)
        for label, temperature, top_k, top_p in TEMPERATURE_CONFIGS:
            try:
                completion = generate_text(
                    model, tokenizer, prompt,
                    max_new_tokens=budget,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    device=device,
                )
            except Exception as exc:
                print(f"\n [{label}] FAILED: {exc}")
            else:
                print(f"\n [{label}]")
                print(f" {prompt + completion}")
        print()
| |
|
| |
|
| | |
| | |
| | |
| |
|
@torch.inference_mode()
def section_calibration(
    model: LLM,
    data_dir: Path,
    device: str,
    calib_tokens: int = 10000,
    seq_len: int = 512,
) -> Dict[str, float]:
    """
    Calibration check on the first `calib_tokens` tokens of 3b_val.bin.

    Computes:
        - mean predicted probability of the correct next token
        - mean entropy of the predicted distributions (nats)
        - top-k accuracy at k = 1, 5, 10

    Returns:
        Dict of metric name -> value; empty dict when the data file is
        missing or too short.
    """
    print_header("6. CALIBRATION CHECK")

    val_path = data_dir / "3b_val.bin"
    if not val_path.exists():
        print(" [SKIPPED] 3b_val.bin not found.")
        return {}

    tokens_all = np.memmap(str(val_path), dtype="uint16", mode="r")
    # + seq_len so the final chunk still has a next-token label available.
    n_use = min(calib_tokens + seq_len, len(tokens_all))
    tokens = tokens_all[:n_use]
    print(f" Using first {n_use:,} tokens for calibration.")

    # Running sums; each is divided by n_positions at the end.
    mean_correct_prob = 0.0
    mean_entropy = 0.0
    acc1 = acc5 = acc10 = 0
    n_positions = 0

    n_chunks = (n_use - 1) // seq_len
    if n_chunks == 0:
        print(" [SKIPPED] Not enough tokens for calibration.")
        return {}

    for chunk_idx in range(n_chunks):
        start = chunk_idx * seq_len
        # +1: the extra token is used only as the label of the last position.
        end = start + seq_len + 1
        if end > len(tokens):
            break

        chunk = torch.from_numpy(tokens[start:end].astype(np.int64))
        input_ids = chunk[:-1].unsqueeze(0).to(device)
        target = chunk[1:].to(device)

        logits, _ = model(input_ids)
        logits_2d = logits[0]

        # Softmax in float32 for numerical stability (model runs in BF16).
        probs = F.softmax(logits_2d.float(), dim=-1)

        # Probability the model assigned to the actual next token.
        correct_probs = probs[torch.arange(seq_len, device=device), target]
        mean_correct_prob += correct_probs.sum().item()

        # Predictive entropy per position; clamp avoids log(0).
        log_probs = torch.log(probs.clamp(min=1e-10))
        entropy = -(probs * log_probs).sum(dim=-1)
        mean_entropy += entropy.sum().item()

        # Top-k accuracy: is the target among the k highest logits?
        top10 = logits_2d.topk(10, dim=-1).indices
        target_col = target.unsqueeze(1)
        in_top10 = (top10 == target_col)
        acc1 += in_top10[:, :1].any(dim=1).sum().item()
        acc5 += in_top10[:, :5].any(dim=1).sum().item()
        acc10 += in_top10[:, :10].any(dim=1).sum().item()
        n_positions += seq_len

    if n_positions == 0:
        print(" [SKIPPED] No positions evaluated.")
        return {}

    metrics = {
        "mean_correct_prob": mean_correct_prob / n_positions,
        "mean_entropy_nats": mean_entropy / n_positions,
        "accuracy_at_1": acc1 / n_positions,
        "accuracy_at_5": acc5 / n_positions,
        "accuracy_at_10": acc10 / n_positions,
    }

    print(f" Positions evaluated: {n_positions:,}")
    print(f" Mean correct-token prob: {metrics['mean_correct_prob']:.4f}")
    print(f" Mean predicted entropy: {metrics['mean_entropy_nats']:.4f} nats")
    print(f" Accuracy @1: {metrics['accuracy_at_1']*100:.2f}%")
    print(f" Accuracy @5: {metrics['accuracy_at_5']*100:.2f}%")
    print(f" Accuracy @10: {metrics['accuracy_at_10']*100:.2f}%")
    return metrics
| |
|
| |
|
| | |
| | |
| | |
| |
|
def print_summary(
    ppl_results: Dict[str, Tuple[float, float, int]],
    rep_results: Dict[str, Dict[str, float]],
    calib_results: Dict[str, float],
) -> None:
    """Condense the per-section results into one final roll-up table."""
    print_header("SUMMARY TABLE")

    # Perplexity per validation set.
    print(" [Perplexity]")
    print(f" {'Dataset':<30} {'PPL':>10} {'bits/tok':>10}")
    print(f" {'-'*30} {'-'*10} {'-'*10}")
    for dataset, (ppl, bpt, _) in ppl_results.items():
        ppl_text = f"{ppl:.4f}" if math.isfinite(ppl) else "N/A"
        bpt_text = f"{bpt:.4f}" if math.isfinite(bpt) else "N/A"
        print(f" {dataset:<30} {ppl_text:>10} {bpt_text:>10}")

    # Repetition averages across prompts.
    if rep_results:
        degenerate_count = sum(
            1 for r in rep_results.values()
            if r.get("3gram", 0.0) > REPETITION_THRESHOLD
        )
        print()
        print(" [Repetition (avg over all prompts)]")
        for n in [1, 2, 3, 4]:
            vals = [r.get(f"{n}gram", 0.0) for r in rep_results.values()]
            if vals:
                print(f" {n}-gram avg rep ratio: {np.mean(vals)*100:.1f}%")
        print(f" Degenerate outputs (>30% trigram): {degenerate_count}/{len(rep_results)}")

    # Calibration metrics (percentages for accuracies).
    if calib_results:
        print()
        print(" [Calibration]")
        for key, val in calib_results.items():
            if "accuracy" in key:
                print(f" {key:<30} {val*100:.2f}%")
            else:
                print(f" {key:<30} {val:.4f}")

    print()
    print(" " + "=" * 60)
    print(" Evaluation complete.")
    print(" " + "=" * 60)
| |
|
| |
|
| | |
| | |
| | |
| |
|
def print_header(title: str) -> None:
    """Print `title` between two 72-char '=' rules, preceded by a blank line."""
    rule = "=" * 72
    print(f"\n{rule}\n {title}\n{rule}")
| |
|
| |
|
| | |
| | |
| | |
| |
|
def main() -> None:
    """Run all six evaluation sections in order and print a final summary.

    Sections are independent: each is wrapped in its own try/except so a
    failure (e.g. a missing val file) skips that section instead of
    aborting the run. Only model/tokenizer load failures are fatal.
    """
    args = parse_args()

    # Resolve the checkpoint path relative to the project root so the
    # script works regardless of the current working directory.
    ckpt_path = Path(args.checkpoint)
    if not ckpt_path.is_absolute():
        ckpt_path = _PROJECT_ROOT / ckpt_path

    data_dir = Path(args.data_dir) if args.data_dir else _PROJECT_ROOT / "data"

    print_header("COMPREHENSIVE EVAL — Korean 1B LLM")
    print(f" Checkpoint : {ckpt_path}")
    print(f" Device : {args.device}")
    print(f" Data dir : {data_dir}")
    print(f" seq_len : {args.seq_len} stride={args.stride} batch={args.batch_size}")

    # Model and tokenizer are required by every section — bail out early
    # if either fails to load.
    print_header("LOADING MODEL & TOKENIZER")
    try:
        model = load_model(str(ckpt_path), args.device)
    except Exception as exc:
        print(f" [FATAL] Could not load model: {exc}")
        sys.exit(1)

    try:
        tokenizer = load_tokenizer(str(ckpt_path), args.tokenizer)
    except Exception as exc:
        print(f" [FATAL] Could not load tokenizer: {exc}")
        sys.exit(1)

    # Per-section results collected for print_summary at the end.
    ppl_results: Dict[str, Tuple[float, float, int]] = {}
    rep_results: Dict[str, Dict[str, float]] = {}
    calib_results: Dict[str, float] = {}

    # Section 1: sliding-window perplexity on each validation set.
    try:
        ppl_results = section_perplexity(
            model, data_dir,
            seq_len=args.seq_len,
            stride=args.stride,
            batch_size=args.batch_size,
            device=args.device,
        )
    except Exception as exc:
        print(f" [SECTION 1 FAILED] {exc}")

    # Section 2: token-level NLL distribution + hardest/easiest tokens.
    try:
        section_token_analysis(
            model, tokenizer, data_dir,
            seq_len=args.seq_len,
            batch_size=args.batch_size,
            device=args.device,
        )
    except Exception as exc:
        print(f" [SECTION 2 FAILED] {exc}")

    # Section 3: free-form generation on the 10 standard prompts.
    generated: Dict[str, str] = {}
    try:
        generated = section_generation(
            model, tokenizer,
            max_new_tokens=args.max_new_tokens,
            device=args.device,
        )
    except Exception as exc:
        print(f" [SECTION 3 FAILED] {exc}")

    # Section 4: repetition analysis of the section-3 outputs.
    if generated:
        try:
            rep_results = section_repetition(generated)
        except Exception as exc:
            print(f" [SECTION 4 FAILED] {exc}")
    else:
        print_header("4. REPETITION ANALYSIS")
        print(" [SKIPPED] No generated texts available.")

    # Section 5: greedy vs. sampled decoding side by side.
    try:
        section_comparison(
            model, tokenizer,
            max_new_tokens=args.max_new_tokens,
            device=args.device,
        )
    except Exception as exc:
        print(f" [SECTION 5 FAILED] {exc}")

    # Section 6: calibration metrics (seq_len capped at 512 for speed).
    try:
        calib_results = section_calibration(
            model, data_dir,
            device=args.device,
            calib_tokens=args.calib_tokens,
            seq_len=min(args.seq_len, 512),
        )
    except Exception as exc:
        print(f" [SECTION 6 FAILED] {exc}")

    # Final roll-up of everything collected above.
    try:
        print_summary(ppl_results, rep_results, calib_results)
    except Exception as exc:
        print(f" [SUMMARY FAILED] {exc}")
| |
|
| |
|
# Script entry point.
if __name__ == "__main__":
    main()
| |
|