""" generation_task.py — Text generation quality evaluation tasks. Top-level functions for ProcessPoolExecutor (spawn) compatibility: - eval_generation(device) -> dict - eval_repetition_grid(device) -> dict Helper functions (also top-level, used internally): - top_p_filtering(logits, top_p, top_k) - generate_one(model, tokenizer, prompt, temperature, ...) - compute_ngram_rep(text, n) """ from __future__ import annotations import logging import os import sys import time from pathlib import Path import numpy as np import torch import torch.nn.functional as F logger = logging.getLogger(__name__) _PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent if str(_PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(_PROJECT_ROOT)) _DEFAULT_CHECKPOINT = str(_PROJECT_ROOT / "checkpoints" / "korean_3b_fp8_run1" / "checkpoint-0057000") CHECKPOINT = os.environ.get("EVAL_CHECKPOINT", _DEFAULT_CHECKPOINT) TOKENIZER_PATH = os.environ.get("EVAL_TOKENIZER", str(_PROJECT_ROOT / "tokenizer" / "korean_sp" / "tokenizer.json")) # Chat template support for SFT models USE_CHAT_TEMPLATE = os.environ.get("USE_CHAT_TEMPLATE", "0") == "1" CHAT_TEMPLATE_FMT = "<|user|>\n{prompt}\n<|assistant|>\n" DATA_DIR = _PROJECT_ROOT / "data" SEQ_LEN = 2048 STRIDE = 512 BATCH_SIZE = 32 # --------------------------------------------------------------------------- # Prompt / temperature constants # --------------------------------------------------------------------------- PROMPTS = [ "대한민국의 수도는", "인공지능이란", "한국의 전통 음식 중에서", "지구 온난화의 주요 원인은", "프로그래밍을 배우려면", "조선시대에는", "물리학에서 에너지란", "한국어는 세계에서", "경제 성장을 위해서는", "우주 탐사의 역사를 보면", "머신러닝과 딥러닝의 차이는", "한국 문학의 대표적인 작품으로는", "양자 컴퓨터란", "건강한 식습관을 위해서는", "세계 2차 대전 이후", ] TEMPERATURES = [0.0, 0.5, 0.8, 1.0] REP_GRID = [ {"name": "greedy", "temperature": 0.0, "repetition_penalty": 1.0}, {"name": "t0.5", "temperature": 0.5, "repetition_penalty": 1.0}, {"name": "t0.5_rep1.1", "temperature": 0.5, "repetition_penalty": 1.1}, {"name": "t0.7", "temperature": 0.7, "repetition_penalty": 1.0}, {"name": "t0.7_rep1.1", "temperature": 0.7, "repetition_penalty": 1.1}, {"name": "t0.7_rep1.2", "temperature": 0.7, "repetition_penalty": 1.2}, {"name": "t0.7_rep1.3", "temperature": 0.7, "repetition_penalty": 1.3}, {"name": "t0.9", "temperature": 0.9, "repetition_penalty": 1.0}, {"name": "t0.9_rep1.1", "temperature": 0.9, "repetition_penalty": 1.1}, {"name": "t0.9_rep1.2", "temperature": 0.9, "repetition_penalty": 1.2}, {"name": "t1.0", "temperature": 1.0, "repetition_penalty": 1.0}, {"name": "t1.0_rep1.1", "temperature": 1.0, "repetition_penalty": 1.1}, ] # --------------------------------------------------------------------------- # Shared model utilities # --------------------------------------------------------------------------- def _load_model(device: str): """Load FRANKENSTALLM 3B from checkpoint onto the given device.""" from model.transformer import LLM # type: ignore[import] model = LLM.from_pretrained(CHECKPOINT) model = model.to(device=device, dtype=torch.bfloat16) model.eval() return model def _load_tokenizer(): """Load the Korean SentencePiece tokenizer.""" from tokenizers import Tokenizer # type: ignore[import] return Tokenizer.from_file(TOKENIZER_PATH) # --------------------------------------------------------------------------- # Generation helpers (top-level for pickle compatibility) # --------------------------------------------------------------------------- def top_p_filtering(logits: torch.Tensor, top_p: float = 0.9, top_k: int = 0) -> torch.Tensor: """Apply top-p (nucleus) and/or top-k filtering to a logits tensor. Args: logits: Shape (..., vocab_size). top_p: Nucleus probability threshold in (0, 1). 0 or 1 disables. top_k: Keep only the top-k tokens. 0 disables. Returns: Filtered logits tensor of the same shape. """ if logits.dim() == 1: logits = logits.unsqueeze(0) squeeze = True else: squeeze = False if top_k > 0: k = min(top_k, logits.size(-1)) kth = torch.topk(logits, k, dim=-1).values[:, -1, None] logits = logits.masked_fill(logits < kth, float("-inf")) if 0.0 < top_p < 1.0: sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True) cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) remove = cum_probs - F.softmax(sorted_logits, dim=-1) >= top_p sorted_logits[remove] = float("-inf") logits = torch.zeros_like(logits).scatter_(-1, sorted_idx, sorted_logits) if squeeze: logits = logits.squeeze(0) return logits def generate_one( model, tokenizer, prompt: str, temperature: float, top_p: float = 0.9, top_k: int = 50, max_new_tokens: int = 256, device: str = "cuda:0", repetition_penalty: float = 1.0, ) -> tuple[str, int, bool]: """Generate a single continuation for a prompt using the given model. Args: model: Pre-loaded language model (eval mode). tokenizer: Tokenizer with encode/decode methods. prompt: Input prompt string. temperature: Sampling temperature. 0.0 = greedy. top_p: Nucleus filtering threshold. top_k: Top-k filtering count. max_new_tokens: Maximum number of tokens to generate. device: CUDA device string. repetition_penalty: Penalty > 1.0 discourages token repetition. Returns: Tuple of (generated_text, num_new_tokens, hit_eos). """ input_ids = torch.tensor( [tokenizer.encode(prompt).ids], dtype=torch.long, device=device ) eos_id = tokenizer.token_to_id("") generated = input_ids new_ids: list[int] = [] hit_eos = False for _ in range(max_new_tokens): logits_all, _ = model(generated) logits = logits_all[:, -1, :].clone() if repetition_penalty != 1.0: for tid in set(generated[0].tolist()): if logits[0, tid] > 0: logits[0, tid] /= repetition_penalty else: logits[0, tid] *= repetition_penalty if temperature == 0.0: next_id = logits.argmax(dim=-1, keepdim=True) else: logits = logits / max(temperature, 1e-8) logits = top_p_filtering(logits, top_p=top_p, top_k=top_k) probs = F.softmax(logits, dim=-1) next_id = torch.multinomial(probs, num_samples=1) generated = torch.cat([generated, next_id], dim=-1) new_ids.append(next_id.item()) if eos_id is not None and next_id.item() == eos_id: hit_eos = True break text = tokenizer.decode(new_ids) return text, len(new_ids), hit_eos def compute_ngram_rep(text: str, n: int) -> float: """Compute n-gram repetition rate for a whitespace-tokenized string. Repetition rate = 1 - (unique n-grams / total n-grams). A value of 0 means no repeated n-grams; 1 means all n-grams are repeated. Args: text: Input text (whitespace-tokenized). n: N-gram order (1, 2, 3, 4, ...). Returns: Float in [0, 1]. """ tokens = text.split() if len(tokens) < n: return 0.0 ngrams = [tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)] if not ngrams: return 0.0 return 1.0 - len(set(ngrams)) / len(ngrams) def compute_diversity_metrics(text: str) -> dict: """N-gram 반복률을 보완하는 어휘 다양성 메트릭. - Distinct-n (Li et al., 2016): 고유 n-gram 비율 - Type-Token Ratio: 어휘 풍부도 """ tokens = text.split() n = len(tokens) if n == 0: return {"distinct_1": 0.0, "distinct_2": 0.0, "distinct_3": 0.0, "type_token_ratio": 0.0, "vocab_size": 0, "total_tokens": 0} unigrams = set(tokens) bigrams = set(zip(tokens, tokens[1:])) if n > 1 else set() trigrams = set(zip(tokens, tokens[1:], tokens[2:])) if n > 2 else set() return { "distinct_1": len(unigrams) / n, "distinct_2": len(bigrams) / max(n - 1, 1), "distinct_3": len(trigrams) / max(n - 2, 1), "type_token_ratio": len(unigrams) / n, "vocab_size": len(unigrams), "total_tokens": n, } # --------------------------------------------------------------------------- # Main task functions (must be top-level for pickle / spawn compatibility) # --------------------------------------------------------------------------- def eval_generation(device: str) -> dict: """Evaluate generation quality: 15 prompts x 4 temperatures. For each (prompt, temperature) combination: - Generates up to 256 new tokens - Computes 1-gram through 4-gram repetition rates Args: device: CUDA device string, e.g. "cuda:4". Returns: Dict with keys: - summary: aggregate statistics across all generations - samples: list of per-generation result dicts """ torch.cuda.set_device(int(device.split(":")[-1])) print(f"[GEN {device}] Loading model...") model = _load_model(device) tokenizer = _load_tokenizer() t0 = time.time() results: list[dict] = [] total_combinations = len(PROMPTS) * len(TEMPERATURES) done = 0 if USE_CHAT_TEMPLATE: print(f"[GEN {device}] Chat template ENABLED", flush=True) for prompt in PROMPTS: effective_prompt = CHAT_TEMPLATE_FMT.format(prompt=prompt) if USE_CHAT_TEMPLATE else prompt for temp in TEMPERATURES: with torch.inference_mode(): text, n_tokens, hit_eos = generate_one( model, tokenizer, effective_prompt, temp, device=device ) rep1 = compute_ngram_rep(text, 1) rep2 = compute_ngram_rep(text, 2) rep3 = compute_ngram_rep(text, 3) rep4 = compute_ngram_rep(text, 4) diversity = compute_diversity_metrics(text) entry = { "prompt": prompt, "chat_template": USE_CHAT_TEMPLATE, "effective_prompt": effective_prompt if USE_CHAT_TEMPLATE else prompt, "temperature": temp, "generated_tokens": n_tokens, "hit_eos": hit_eos, "1gram_rep": round(rep1, 4), "2gram_rep": round(rep2, 4), "3gram_rep": round(rep3, 4), "4gram_rep": round(rep4, 4), "distinct_1": round(diversity["distinct_1"], 4), "distinct_2": round(diversity["distinct_2"], 4), "distinct_3": round(diversity["distinct_3"], 4), "type_token_ratio": round(diversity["type_token_ratio"], 4), "text": text[:500], # truncate for readability } results.append(entry) done += 1 label = "greedy" if temp == 0.0 else f"t={temp}" print( f"[GEN {device}] ({done}/{total_combinations}) " f"{prompt[:15]}... ({label}): " f"{n_tokens}tok, 3gram_rep={rep3:.2%}, eos={hit_eos}" ) elapsed = time.time() - t0 # Aggregate stats per temperature group greedy = [r for r in results if r["temperature"] == 0.0] sampled = [r for r in results if r["temperature"] > 0.0] if not greedy: logger.warning("No greedy generation results — all prompts may have failed") if not sampled: logger.warning("No sampled generation results") summary = { "total_generations": len(results), "n_prompts": len(PROMPTS), "temperatures": TEMPERATURES, "greedy_avg_1gram_rep": round(np.mean([r["1gram_rep"] for r in greedy]), 4) if greedy else 0.0, "greedy_avg_2gram_rep": round(np.mean([r["2gram_rep"] for r in greedy]), 4) if greedy else 0.0, "greedy_avg_3gram_rep": round(np.mean([r["3gram_rep"] for r in greedy]), 4) if greedy else 0.0, "greedy_avg_4gram_rep": round(np.mean([r["4gram_rep"] for r in greedy]), 4) if greedy else 0.0, "greedy_eos_rate": round(np.mean([r["hit_eos"] for r in greedy]), 4) if greedy else 0.0, "greedy_avg_tokens": round(np.mean([r["generated_tokens"] for r in greedy]), 1) if greedy else 0.0, "sampled_avg_3gram_rep": round(np.mean([r["3gram_rep"] for r in sampled]), 4) if sampled else 0.0, "sampled_eos_rate": round(np.mean([r["hit_eos"] for r in sampled]), 4) if sampled else 0.0, "sampled_avg_tokens": round(np.mean([r["generated_tokens"] for r in sampled]), 1) if sampled else 0.0, "greedy_avg_distinct_1": round(float(np.mean([r["distinct_1"] for r in greedy])), 4) if greedy else 0.0, "greedy_avg_distinct_2": round(float(np.mean([r["distinct_2"] for r in greedy])), 4) if greedy else 0.0, "greedy_avg_distinct_3": round(float(np.mean([r["distinct_3"] for r in greedy])), 4) if greedy else 0.0, "sampled_avg_distinct_2": round(float(np.mean([r["distinct_2"] for r in sampled])), 4) if sampled else 0.0, "token_count_min": int(np.min([r["generated_tokens"] for r in results])) if results else 0, "token_count_max": int(np.max([r["generated_tokens"] for r in results])) if results else 0, "token_count_p25": int(np.percentile([r["generated_tokens"] for r in results], 25)) if results else 0, "token_count_p75": int(np.percentile([r["generated_tokens"] for r in results], 75)) if results else 0, "elapsed_sec": round(elapsed, 1), } print( f"[GEN {device}] DONE greedy 3gram_rep={summary['greedy_avg_3gram_rep']:.4f}, " f"eos_rate={summary['greedy_eos_rate']:.2%}, {elapsed:.1f}s" ) return {"summary": summary, "samples": results} def eval_repetition_grid(device: str) -> dict: """Grid search over 12 generation parameter combinations x 5 prompts. Evaluates each config (temperature x repetition_penalty) on the first 5 prompts and returns results sorted by average 3-gram repetition rate. Args: device: CUDA device string, e.g. "cuda:5". Returns: Dict with keys: - grid_results: list of per-config dicts, sorted by avg_3gram_rep - best: config with lowest avg_3gram_rep - elapsed_sec: wall-clock time """ torch.cuda.set_device(int(device.split(":")[-1])) print(f"[REP {device}] Loading model...") model = _load_model(device) tokenizer = _load_tokenizer() t0 = time.time() rep_prompts = PROMPTS[:5] # first 5 prompts results: list[dict] = [] total = len(REP_GRID) * len(rep_prompts) done = 0 if USE_CHAT_TEMPLATE: print(f"[REP {device}] Chat template ENABLED", flush=True) for params in REP_GRID: combo_results: list[dict] = [] for prompt in rep_prompts: effective_prompt = CHAT_TEMPLATE_FMT.format(prompt=prompt) if USE_CHAT_TEMPLATE else prompt with torch.inference_mode(): text, n_tokens, hit_eos = generate_one( model, tokenizer, effective_prompt, temperature=params["temperature"], repetition_penalty=params["repetition_penalty"], device=device, max_new_tokens=256, ) combo_results.append( { "prompt": prompt, "n_tokens": n_tokens, "hit_eos": hit_eos, "1gram_rep": compute_ngram_rep(text, 1), "2gram_rep": compute_ngram_rep(text, 2), "3gram_rep": compute_ngram_rep(text, 3), "4gram_rep": compute_ngram_rep(text, 4), } ) done += 1 if not combo_results: logger.warning("All prompts failed for config %s — skipping", params.get("name", "unknown")) continue avg_3gram = float(np.mean([r["3gram_rep"] for r in combo_results])) avg_4gram = float(np.mean([r["4gram_rep"] for r in combo_results])) eos_rate = float(np.mean([r["hit_eos"] for r in combo_results])) avg_tokens = float(np.mean([r["n_tokens"] for r in combo_results])) entry = { "params": params["name"], "temperature": params["temperature"], "repetition_penalty": params["repetition_penalty"], "avg_3gram_rep": round(avg_3gram, 4), "avg_4gram_rep": round(avg_4gram, 4), "eos_rate": round(eos_rate, 4), "avg_tokens": round(avg_tokens, 1), "per_prompt": combo_results, } results.append(entry) print( f"[REP {device}] {params['name']}: " f"3gram={avg_3gram:.2%}, 4gram={avg_4gram:.2%}, " f"eos={eos_rate:.0%}, {avg_tokens:.0f}tok" ) elapsed = time.time() - t0 # Sort by avg 3-gram repetition (ascending = better) sorted_results = sorted(results, key=lambda r: r["avg_3gram_rep"]) best = sorted_results[0] print( f"[REP {device}] DONE best={best['params']} " f"(3gram={best['avg_3gram_rep']:.2%}), {elapsed:.1f}s" ) return { "grid_results": sorted_results, "best": { "params": best["params"], "temperature": best["temperature"], "repetition_penalty": best["repetition_penalty"], "avg_3gram_rep": best["avg_3gram_rep"], "avg_4gram_rep": best["avg_4gram_rep"], }, "elapsed_sec": round(elapsed, 1), }