| | """ |
| | generation_task.py — Text generation quality evaluation tasks. |
| | |
| | Top-level functions for ProcessPoolExecutor (spawn) compatibility: |
| | - eval_generation(device) -> dict |
| | - eval_repetition_grid(device) -> dict |
| | |
| | Helper functions (also top-level, used internally): |
| | - top_p_filtering(logits, top_p, top_k) |
| | - generate_one(model, tokenizer, prompt, temperature, ...) |
| | - compute_ngram_rep(text, n) |
| | """ |
| | from __future__ import annotations |
| |
|
| | import logging |
| | import os |
| | import sys |
| | import time |
| | from pathlib import Path |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
# Make the project root importable so `model.*` resolves when this module is
# executed inside a ProcessPoolExecutor (spawn) worker with a fresh sys.path.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))

# Checkpoint and tokenizer locations, overridable via environment variables.
_DEFAULT_CHECKPOINT = str(_PROJECT_ROOT / "checkpoints" / "korean_3b_fp8_run1" / "checkpoint-0057000")
CHECKPOINT = os.environ.get("EVAL_CHECKPOINT", _DEFAULT_CHECKPOINT)
TOKENIZER_PATH = os.environ.get("EVAL_TOKENIZER", str(_PROJECT_ROOT / "tokenizer" / "korean_sp" / "tokenizer.json"))

# When USE_CHAT_TEMPLATE=1, every prompt is wrapped in CHAT_TEMPLATE_FMT
# before generation (see eval_generation / eval_repetition_grid).
USE_CHAT_TEMPLATE = os.environ.get("USE_CHAT_TEMPLATE", "0") == "1"
CHAT_TEMPLATE_FMT = "<|user|>\n{prompt}\n<|assistant|>\n"
DATA_DIR = _PROJECT_ROOT / "data"
# NOTE(review): SEQ_LEN / STRIDE / BATCH_SIZE are not referenced anywhere in
# this module's visible code — presumably shared with sibling eval tasks
# (e.g. a perplexity task); confirm before removing.
SEQ_LEN = 2048
STRIDE = 512
BATCH_SIZE = 32
| |
|
| | |
| | |
| | |
| |
|
# Fifteen open-ended Korean prompts (general knowledge, science, history,
# technology); each is an unfinished sentence the model must continue.
PROMPTS = [
    "대한민국의 수도는",
    "인공지능이란",
    "한국의 전통 음식 중에서",
    "지구 온난화의 주요 원인은",
    "프로그래밍을 배우려면",
    "조선시대에는",
    "물리학에서 에너지란",
    "한국어는 세계에서",
    "경제 성장을 위해서는",
    "우주 탐사의 역사를 보면",
    "머신러닝과 딥러닝의 차이는",
    "한국 문학의 대표적인 작품으로는",
    "양자 컴퓨터란",
    "건강한 식습관을 위해서는",
    "세계 2차 대전 이후",
]

# Sampling temperatures evaluated per prompt in eval_generation (0.0 = greedy).
TEMPERATURES = [0.0, 0.5, 0.8, 1.0]

# Temperature x repetition-penalty grid searched by eval_repetition_grid.
REP_GRID = [
    {"name": "greedy", "temperature": 0.0, "repetition_penalty": 1.0},
    {"name": "t0.5", "temperature": 0.5, "repetition_penalty": 1.0},
    {"name": "t0.5_rep1.1", "temperature": 0.5, "repetition_penalty": 1.1},
    {"name": "t0.7", "temperature": 0.7, "repetition_penalty": 1.0},
    {"name": "t0.7_rep1.1", "temperature": 0.7, "repetition_penalty": 1.1},
    {"name": "t0.7_rep1.2", "temperature": 0.7, "repetition_penalty": 1.2},
    {"name": "t0.7_rep1.3", "temperature": 0.7, "repetition_penalty": 1.3},
    {"name": "t0.9", "temperature": 0.9, "repetition_penalty": 1.0},
    {"name": "t0.9_rep1.1", "temperature": 0.9, "repetition_penalty": 1.1},
    {"name": "t0.9_rep1.2", "temperature": 0.9, "repetition_penalty": 1.2},
    {"name": "t1.0", "temperature": 1.0, "repetition_penalty": 1.0},
    {"name": "t1.0_rep1.1", "temperature": 1.0, "repetition_penalty": 1.1},
]
| |
|
| |
|
| | |
| | |
| | |
| |
|
def _load_model(device: str):
    """Load the FRANKENSTALLM 3B checkpoint onto *device* in bf16 eval mode."""
    from model.transformer import LLM

    # .to() and .eval() both return the module, so the pipeline chains.
    return LLM.from_pretrained(CHECKPOINT).to(device=device, dtype=torch.bfloat16).eval()
| |
|
| |
|
def _load_tokenizer():
    """Load the Korean SentencePiece tokenizer from TOKENIZER_PATH."""
    from tokenizers import Tokenizer

    tok = Tokenizer.from_file(TOKENIZER_PATH)
    return tok
| |
|
| |
|
| | |
| | |
| | |
| |
|
def top_p_filtering(logits: torch.Tensor, top_p: float = 0.9, top_k: int = 0) -> torch.Tensor:
    """Apply top-p (nucleus) and/or top-k filtering to a logits tensor.

    Tokens outside the top-k set, or outside the smallest prefix of the
    probability-sorted vocabulary whose cumulative mass reaches ``top_p``,
    are set to ``-inf`` so a subsequent softmax assigns them zero
    probability.  The highest-probability token is always kept, so sampling
    from the filtered distribution can never fail.

    Args:
        logits: Shape (..., vocab_size); any number of leading batch dims
            (the previous implementation's top-k path was only correct for
            1-D/2-D input).
        top_p: Nucleus probability threshold in (0, 1). 0 or 1 disables.
        top_k: Keep only the top-k tokens. 0 disables.

    Returns:
        Filtered logits tensor of the same shape.  A new tensor is returned;
        the input is not modified in place.
    """
    if top_k > 0:
        k = min(top_k, logits.size(-1))
        # k-th largest value along the vocab axis; `...` keeps this correct
        # for any rank, unlike `[:, -1, None]` which assumed 2-D.
        kth = torch.topk(logits, k, dim=-1).values[..., -1:]
        logits = logits.masked_fill(logits < kth, float("-inf"))

    if 0.0 < top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
        probs = F.softmax(sorted_logits, dim=-1)
        cum_probs = torch.cumsum(probs, dim=-1)
        # Exclusive cumulative probability: mass strictly before each token.
        # Dropping a token only when the mass *before* it already reaches
        # top_p guarantees the top-1 token always survives.
        remove = cum_probs - probs >= top_p
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        # Scatter back to original vocab order; sorted_idx is a full
        # permutation, so every position is written.
        logits = torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)

    return logits
| |
|
| |
|
def generate_one(
    model,
    tokenizer,
    prompt: str,
    temperature: float,
    top_p: float = 0.9,
    top_k: int = 50,
    max_new_tokens: int = 256,
    device: str = "cuda:0",
    repetition_penalty: float = 1.0,
) -> tuple[str, int, bool]:
    """Autoregressively generate one continuation for *prompt*.

    Stops after ``max_new_tokens`` steps or as soon as the ``</s>`` token is
    produced.  Note there is no KV cache: the full sequence is re-encoded by
    the model at every step.

    Args:
        model: Pre-loaded language model (eval mode); called as
            ``model(ids)`` and expected to return ``(logits, _)``.
        tokenizer: Tokenizer with encode/decode/token_to_id methods.
        prompt: Input prompt string.
        temperature: Sampling temperature. 0.0 = greedy decoding.
        top_p: Nucleus filtering threshold.
        top_k: Top-k filtering count.
        max_new_tokens: Maximum number of tokens to generate.
        device: CUDA device string.
        repetition_penalty: Penalty > 1.0 discourages token repetition.

    Returns:
        Tuple of (generated_text, num_new_tokens, hit_eos).
    """
    prompt_ids = tokenizer.encode(prompt).ids
    seq = torch.tensor([prompt_ids], dtype=torch.long, device=device)
    eos_id = tokenizer.token_to_id("</s>")
    sampled: list[int] = []
    ended_on_eos = False

    for _step in range(max_new_tokens):
        all_logits, _ = model(seq)
        # Clone so the penalty below never mutates the model's output.
        step_logits = all_logits[:, -1, :].clone()

        if repetition_penalty != 1.0:
            # Penalize every token already in the sequence (prompt included):
            # divide positive scores, multiply negative ones (CTRL-style).
            for token_id in set(seq[0].tolist()):
                score = step_logits[0, token_id]
                step_logits[0, token_id] = (
                    score / repetition_penalty if score > 0 else score * repetition_penalty
                )

        if temperature == 0.0:
            chosen = step_logits.argmax(dim=-1, keepdim=True)
        else:
            scaled = step_logits / max(temperature, 1e-8)
            filtered = top_p_filtering(scaled, top_p=top_p, top_k=top_k)
            chosen = torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)

        seq = torch.cat([seq, chosen], dim=-1)
        token = chosen.item()
        sampled.append(token)

        if eos_id is not None and token == eos_id:
            ended_on_eos = True
            break

    return tokenizer.decode(sampled), len(sampled), ended_on_eos
| |
|
| |
|
def compute_ngram_rep(text: str, n: int) -> float:
    """Compute the n-gram repetition rate of a whitespace-tokenized string.

    Repetition rate = 1 - (unique n-grams / total n-grams).
    A value of 0 means no repeated n-grams; values near 1 mean heavy
    repetition.

    Args:
        text: Input text; tokenized with str.split() (any whitespace).
        n: N-gram order (1, 2, 3, 4, ...). Values < 1 yield 0.0.

    Returns:
        Float in [0, 1]. Returns 0.0 when the text has fewer than n tokens.
    """
    if n < 1:
        return 0.0
    tokens = text.split()
    if len(tokens) < n:
        return 0.0
    # zip over n shifted views yields every contiguous n-gram in one pass;
    # len(tokens) >= n guarantees at least one, so division is safe.
    ngrams = list(zip(*(tokens[i:] for i in range(n))))
    return 1.0 - len(set(ngrams)) / len(ngrams)
| |
|
| |
|
def compute_diversity_metrics(text: str) -> dict:
    """Lexical-diversity metrics that complement the n-gram repetition rate.

    - Distinct-n (Li et al., 2016): ratio of unique n-grams to total n-grams.
    - Type-Token Ratio: vocabulary richness (unique tokens / total tokens).
    """
    tokens = text.split()
    total = len(tokens)
    if total == 0:
        return {
            "distinct_1": 0.0,
            "distinct_2": 0.0,
            "distinct_3": 0.0,
            "type_token_ratio": 0.0,
            "vocab_size": 0,
            "total_tokens": 0,
        }

    uniq_1 = set(tokens)
    uniq_2 = set(zip(tokens, tokens[1:])) if total > 1 else set()
    uniq_3 = set(zip(tokens, tokens[1:], tokens[2:])) if total > 2 else set()

    ttr = len(uniq_1) / total
    return {
        "distinct_1": ttr,
        "distinct_2": len(uniq_2) / max(total - 1, 1),
        "distinct_3": len(uniq_3) / max(total - 2, 1),
        "type_token_ratio": ttr,  # identical to distinct_1 by definition
        "vocab_size": len(uniq_1),
        "total_tokens": total,
    }
| |
|
| |
|
| | |
| | |
| | |
| |
|
def eval_generation(device: str) -> dict:
    """Evaluate generation quality: 15 prompts x 4 temperatures.

    For each (prompt, temperature) combination:
    - Generates up to 256 new tokens
    - Computes 1-gram through 4-gram repetition rates and diversity metrics

    Args:
        device: CUDA device string, e.g. "cuda:4".

    Returns:
        Dict with keys:
        - summary: aggregate statistics across all generations (plain
          Python floats/ints so the result is JSON-serializable)
        - samples: list of per-generation result dicts
    """
    torch.cuda.set_device(int(device.split(":")[-1]))
    print(f"[GEN {device}] Loading model...")
    model = _load_model(device)
    tokenizer = _load_tokenizer()
    t0 = time.time()

    results: list[dict] = []
    total_combinations = len(PROMPTS) * len(TEMPERATURES)
    done = 0

    if USE_CHAT_TEMPLATE:
        print(f"[GEN {device}] Chat template ENABLED", flush=True)

    for prompt in PROMPTS:
        effective_prompt = CHAT_TEMPLATE_FMT.format(prompt=prompt) if USE_CHAT_TEMPLATE else prompt
        for temp in TEMPERATURES:
            with torch.inference_mode():
                text, n_tokens, hit_eos = generate_one(
                    model, tokenizer, effective_prompt, temp, device=device
                )
            rep1 = compute_ngram_rep(text, 1)
            rep2 = compute_ngram_rep(text, 2)
            rep3 = compute_ngram_rep(text, 3)
            rep4 = compute_ngram_rep(text, 4)
            diversity = compute_diversity_metrics(text)

            results.append({
                "prompt": prompt,
                "chat_template": USE_CHAT_TEMPLATE,
                "effective_prompt": effective_prompt if USE_CHAT_TEMPLATE else prompt,
                "temperature": temp,
                "generated_tokens": n_tokens,
                "hit_eos": hit_eos,
                "1gram_rep": round(rep1, 4),
                "2gram_rep": round(rep2, 4),
                "3gram_rep": round(rep3, 4),
                "4gram_rep": round(rep4, 4),
                "distinct_1": round(diversity["distinct_1"], 4),
                "distinct_2": round(diversity["distinct_2"], 4),
                "distinct_3": round(diversity["distinct_3"], 4),
                "type_token_ratio": round(diversity["type_token_ratio"], 4),
                "text": text[:500],  # truncate to keep the payload small
            })
            done += 1

            label = "greedy" if temp == 0.0 else f"t={temp}"
            print(
                f"[GEN {device}] ({done}/{total_combinations}) "
                f"{prompt[:15]}... ({label}): "
                f"{n_tokens}tok, 3gram_rep={rep3:.2%}, eos={hit_eos}"
            )

    elapsed = time.time() - t0

    greedy = [r for r in results if r["temperature"] == 0.0]
    sampled = [r for r in results if r["temperature"] > 0.0]

    if not greedy:
        logger.warning("No greedy generation results — all prompts may have failed")
    if not sampled:
        logger.warning("No sampled generation results")

    def _avg(rows: list[dict], key: str, ndigits: int = 4) -> float:
        # Rounded mean as a plain Python float: np.mean yields np.float64,
        # which json.dumps rejects (previously only some entries were
        # wrapped in float()). Returns 0.0 for an empty row list.
        return round(float(np.mean([r[key] for r in rows])), ndigits) if rows else 0.0

    tok_counts = [r["generated_tokens"] for r in results]
    summary = {
        "total_generations": len(results),
        "n_prompts": len(PROMPTS),
        "temperatures": TEMPERATURES,
        "greedy_avg_1gram_rep": _avg(greedy, "1gram_rep"),
        "greedy_avg_2gram_rep": _avg(greedy, "2gram_rep"),
        "greedy_avg_3gram_rep": _avg(greedy, "3gram_rep"),
        "greedy_avg_4gram_rep": _avg(greedy, "4gram_rep"),
        "greedy_eos_rate": _avg(greedy, "hit_eos"),
        "greedy_avg_tokens": _avg(greedy, "generated_tokens", 1),
        "sampled_avg_3gram_rep": _avg(sampled, "3gram_rep"),
        "sampled_eos_rate": _avg(sampled, "hit_eos"),
        "sampled_avg_tokens": _avg(sampled, "generated_tokens", 1),
        "greedy_avg_distinct_1": _avg(greedy, "distinct_1"),
        "greedy_avg_distinct_2": _avg(greedy, "distinct_2"),
        "greedy_avg_distinct_3": _avg(greedy, "distinct_3"),
        "sampled_avg_distinct_2": _avg(sampled, "distinct_2"),
        "token_count_min": min(tok_counts) if tok_counts else 0,
        "token_count_max": max(tok_counts) if tok_counts else 0,
        "token_count_p25": int(np.percentile(tok_counts, 25)) if tok_counts else 0,
        "token_count_p75": int(np.percentile(tok_counts, 75)) if tok_counts else 0,
        "elapsed_sec": round(elapsed, 1),
    }

    print(
        f"[GEN {device}] DONE greedy 3gram_rep={summary['greedy_avg_3gram_rep']:.4f}, "
        f"eos_rate={summary['greedy_eos_rate']:.2%}, {elapsed:.1f}s"
    )
    return {"summary": summary, "samples": results}
| |
|
| |
|
def eval_repetition_grid(device: str) -> dict:
    """Grid search over 12 generation parameter combinations x 5 prompts.

    Evaluates each config (temperature x repetition_penalty) on the first 5
    prompts and returns results sorted by average 3-gram repetition rate.

    Args:
        device: CUDA device string, e.g. "cuda:5".

    Returns:
        Dict with keys:
        - grid_results: list of per-config dicts, sorted by avg_3gram_rep
        - best: config with lowest avg_3gram_rep (None if every config
          failed to produce results)
        - elapsed_sec: wall-clock time
    """
    torch.cuda.set_device(int(device.split(":")[-1]))
    print(f"[REP {device}] Loading model...")
    model = _load_model(device)
    tokenizer = _load_tokenizer()
    t0 = time.time()

    rep_prompts = PROMPTS[:5]
    results: list[dict] = []

    total = len(REP_GRID) * len(rep_prompts)
    done = 0

    if USE_CHAT_TEMPLATE:
        print(f"[REP {device}] Chat template ENABLED", flush=True)

    for params in REP_GRID:
        combo_results: list[dict] = []
        for prompt in rep_prompts:
            effective_prompt = CHAT_TEMPLATE_FMT.format(prompt=prompt) if USE_CHAT_TEMPLATE else prompt
            with torch.inference_mode():
                text, n_tokens, hit_eos = generate_one(
                    model,
                    tokenizer,
                    effective_prompt,
                    temperature=params["temperature"],
                    repetition_penalty=params["repetition_penalty"],
                    device=device,
                    max_new_tokens=256,
                )
            combo_results.append(
                {
                    "prompt": prompt,
                    "n_tokens": n_tokens,
                    "hit_eos": hit_eos,
                    "1gram_rep": compute_ngram_rep(text, 1),
                    "2gram_rep": compute_ngram_rep(text, 2),
                    "3gram_rep": compute_ngram_rep(text, 3),
                    "4gram_rep": compute_ngram_rep(text, 4),
                }
            )
            done += 1

        if not combo_results:
            logger.warning("All prompts failed for config %s — skipping", params.get("name", "unknown"))
            continue

        avg_3gram = float(np.mean([r["3gram_rep"] for r in combo_results]))
        avg_4gram = float(np.mean([r["4gram_rep"] for r in combo_results]))
        eos_rate = float(np.mean([r["hit_eos"] for r in combo_results]))
        avg_tokens = float(np.mean([r["n_tokens"] for r in combo_results]))

        results.append({
            "params": params["name"],
            "temperature": params["temperature"],
            "repetition_penalty": params["repetition_penalty"],
            "avg_3gram_rep": round(avg_3gram, 4),
            "avg_4gram_rep": round(avg_4gram, 4),
            "eos_rate": round(eos_rate, 4),
            "avg_tokens": round(avg_tokens, 1),
            "per_prompt": combo_results,
        })
        print(
            f"[REP {device}] {params['name']}: "
            f"3gram={avg_3gram:.2%}, 4gram={avg_4gram:.2%}, "
            f"eos={eos_rate:.0%}, {avg_tokens:.0f}tok"
        )

    elapsed = time.time() - t0

    # Guard: previously `sorted_results[0]` raised IndexError when every
    # config was skipped (the per-config skip path above makes this possible).
    if not results:
        logger.error("Repetition grid produced no results on %s", device)
        return {"grid_results": [], "best": None, "elapsed_sec": round(elapsed, 1)}

    sorted_results = sorted(results, key=lambda r: r["avg_3gram_rep"])
    best = sorted_results[0]

    print(
        f"[REP {device}] DONE best={best['params']} "
        f"(3gram={best['avg_3gram_rep']:.2%}), {elapsed:.1f}s"
    )
    return {
        "grid_results": sorted_results,
        "best": {
            "params": best["params"],
            "temperature": best["temperature"],
            "repetition_penalty": best["repetition_penalty"],
            "avg_3gram_rep": best["avg_3gram_rep"],
            "avg_4gram_rep": best["avg_4gram_rep"],
        },
        "elapsed_sec": round(elapsed, 1),
    }
| |
|