"""
Grid search over generation parameters to mitigate repetition degeneration.

Tests a range of decoding strategies and measures the repetition rate of each:
- Sampling (temperature, top_p, top_k, repetition_penalty)
- no_repeat_ngram_size
- Contrastive search
- Stop sequences (### 답변:, ### 질문:)

Usage:
    cd /PROJECT/0325120031_A/ghong/taketimes/llm-bang
    python eval/test_generation_params.py \
        --checkpoint checkpoints/korean_1b_sft/checkpoint-0005000 \
        --device cuda:0
"""

from __future__ import annotations

import argparse
import json
import sys
import time
from pathlib import Path
from collections import Counter

import torch
import torch.nn.functional as F

# Make the project root importable so `model.transformer` resolves when the
# script is run from any working directory.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))

from model.transformer import LLM
from tokenizers import Tokenizer


# Prompts in the chat template the model was fine-tuned on (<|user|> ... <|assistant|>).
SFT_PROMPTS = [
    "<|user|>\n한국의 수도는 어디인가요?\n<|assistant|>\n",
    "<|user|>\n파이썬에서 리스트를 정렬하는 방법을 설명해주세요.\n<|assistant|>\n",
    "<|user|>\n지구온난화의 주요 원인을 설명하세요.\n<|assistant|>\n",
    "<|user|>\n좋은 수면 습관을 만들기 위한 팁을 알려주세요.\n<|assistant|>\n",
    "<|user|>\n한국 전통 음식 중 김치에 대해 설명해주세요.\n<|assistant|>\n",
]

# Prompts in an instruction format the model was not fine-tuned on (### 질문 / ### 답변).
WRONG_FORMAT_PROMPTS = [
    "### 질문: 한국의 수도는 어디인가요?\n### 답변:",
    "### 질문: 파이썬에서 리스트를 정렬하는 방법을 설명해주세요.\n### 답변:",
    "### 질문: 지구온난화의 주요 원인을 설명하세요.\n### 답변:",
]


def find_stop_token_ids(tokenizer: Tokenizer, stop_strings: list[str]) -> list[list[int]]:
    """Find token IDs for stop sequences."""
    results = []
    for s in stop_strings:
        ids = tokenizer.encode(s).ids
        results.append(ids)
        print(f" Stop sequence '{s}' → token IDs: {ids}")
    return results


def check_stop_sequences(generated_ids: list[int], stop_sequences: list[list[int]]) -> int | None:
    """Check if generated_ids ends with any stop sequence. Returns index to truncate at, or None."""
    for seq in stop_sequences:
        seq_len = len(seq)
        if len(generated_ids) >= seq_len and generated_ids[-seq_len:] == seq:
            return len(generated_ids) - seq_len
    return None
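
# These two helpers mirror the inline stop handling in generate_with_params below.
# A minimal sketch of how they compose (variable names here are illustrative):
#   stops = find_stop_token_ids(tokenizer, ["### 질문:", "### 답변:"])
#   cut = check_stop_sequences(output_ids, stops)
#   if cut is not None:
#       output_ids = output_ids[:cut]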


def compute_ngram_repetition(text: str, n: int) -> float:
    """Fraction of n-grams (over whitespace tokens) that repeat an earlier n-gram."""
    tokens = text.split()
    if len(tokens) < n:
        return 0.0
    ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
    if not ngrams:
        return 0.0
    return 1.0 - len(set(ngrams)) / len(ngrams)


def compute_all_repetition_metrics(text: str) -> dict:
    """Repetition rates for n = 1..4."""
    return {
        f"{n}gram_rep": compute_ngram_repetition(text, n)
        for n in [1, 2, 3, 4]
    }
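
# Worked example of the metric (hand-checked): "a b c a b c a b c a b c" has
# 10 trigrams but only 3 distinct ones, so compute_ngram_repetition(text, 3)
# returns 1 - 3/10 = 0.7; fully repetitive text approaches 1.0.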


def top_p_filtering(logits, top_p=0.9, top_k=0):
    """Mask logits outside the top-k / nucleus (top-p) candidate set with -inf."""
    if logits.dim() == 1:
        logits = logits.unsqueeze(0)
        squeeze = True
    else:
        squeeze = False

    # Top-k: drop everything below the k-th largest logit.
    if top_k > 0:
        k = min(top_k, logits.size(-1))
        kth = torch.topk(logits, k, dim=-1).values[:, -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))

    # Nucleus: drop a token once the cumulative probability of the tokens ranked
    # above it already reaches top_p (the top-1 token is always kept).
    if 0.0 < top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cum_probs - F.softmax(sorted_logits, dim=-1) >= top_p
        sorted_logits[remove] = float("-inf")
        logits = torch.zeros_like(logits).scatter_(-1, sorted_idx, sorted_logits)

    if squeeze:
        logits = logits.squeeze(0)
    return logits
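
# Masked entries are -inf, so after softmax they get exactly zero probability.
# Typical call (mirrors the sampling branch below):
#   probs = F.softmax(top_p_filtering(logits / temperature, top_p=0.9, top_k=50), dim=-1)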


@torch.inference_mode()
def generate_with_params(
    model, tokenizer, prompt, params, device="cuda:0", max_new_tokens=200
):
    """Generate a continuation with a flexible set of decoding parameters."""
    model.eval()
    input_ids = torch.tensor([tokenizer.encode(prompt).ids], dtype=torch.long, device=device)
    eos_id = tokenizer.token_to_id("</s>")

    # Decoding parameters; anything not given falls back to these defaults.
    temperature = params.get("temperature", 0.8)
    top_p = params.get("top_p", 0.9)
    top_k = params.get("top_k", 50)
    repetition_penalty = params.get("repetition_penalty", 1.0)
    no_repeat_ngram = params.get("no_repeat_ngram_size", 0)
    use_contrastive = params.get("contrastive_search", False)
    penalty_alpha = params.get("penalty_alpha", 0.6)
    contrastive_k = params.get("contrastive_k", 4)

    # Pre-encode stop strings so the tail of the generation can be matched token-for-token.
    stop_strings = params.get("stop_strings", [])
    stop_seqs = [tokenizer.encode(s).ids for s in stop_strings]

    generated_ids = input_ids
    new_token_ids = []

    for step in range(max_new_tokens):
        logits_all, _ = model(generated_ids)
        logits = logits_all[:, -1, :].clone()

        # Repetition penalty: dampen logits of tokens that were already generated
        # (divide positive logits, multiply negative ones).
        if repetition_penalty != 1.0:
            for token_id in set(generated_ids[0].tolist()):
                if logits[0, token_id] > 0:
                    logits[0, token_id] /= repetition_penalty
                else:
                    logits[0, token_id] *= repetition_penalty

        # n-gram blocking: if the last (n-1) tokens match an earlier (n-1)-gram,
        # forbid the token that followed that earlier occurrence.
        if no_repeat_ngram > 0 and len(new_token_ids) >= no_repeat_ngram - 1:
            all_ids = generated_ids[0].tolist()
            last_prefix = tuple(all_ids[-(no_repeat_ngram - 1):])
            for i in range(len(all_ids) - no_repeat_ngram + 1):
                if tuple(all_ids[i:i + no_repeat_ngram - 1]) == last_prefix:
                    logits[0, all_ids[i + no_repeat_ngram - 1]] = float("-inf")

        if use_contrastive:
            # Simplified contrastive search: the standard formulation penalizes
            # candidates by their max hidden-state cosine similarity to previous
            # tokens; here the degeneration penalty is a binary flag for candidates
            # already seen among the last 20 generated tokens.
            top_k_logits, top_k_ids = torch.topk(logits[0], contrastive_k)
            probs = F.softmax(top_k_logits, dim=-1)

            if step > 0:
                # score = (1 - alpha) * model confidence - alpha * repetition flag
                best_idx = 0
                best_score = float("-inf")
                for ki in range(contrastive_k):
                    confidence = probs[ki].item()
                    token = top_k_ids[ki].item()
                    penalty = 1.0 if token in set(new_token_ids[-20:]) else 0.0
                    score = (1 - penalty_alpha) * confidence - penalty_alpha * penalty
                    if score > best_score:
                        best_score = score
                        best_idx = ki
                next_token_id = top_k_ids[best_idx].unsqueeze(0).unsqueeze(0)
            else:
                # Nothing generated yet: take the most likely candidate.
                next_token_id = top_k_ids[0].unsqueeze(0).unsqueeze(0)
        else:
            # Plain sampling: greedy when temperature == 0, otherwise
            # temperature -> top-k/top-p filtering -> multinomial draw.
            if temperature == 0.0:
                next_token_id = logits.argmax(dim=-1, keepdim=True)
            else:
                logits = logits / max(temperature, 1e-8)
                logits = top_p_filtering(logits, top_p=top_p, top_k=top_k)
                probs = F.softmax(logits, dim=-1)
                next_token_id = torch.multinomial(probs, num_samples=1)

        generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
        new_token_ids.append(next_token_id.item())

        # Stop on the end-of-sequence token.
        if eos_id is not None and next_token_id.item() == eos_id:
            break

        # Stop (and truncate) when the generation ends with one of the stop sequences.
        for seq in stop_seqs:
            if len(new_token_ids) >= len(seq) and new_token_ids[-len(seq):] == seq:
                new_token_ids = new_token_ids[:-len(seq)]
                return tokenizer.decode(new_token_ids)

    return tokenizer.decode(new_token_ids)
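
# A minimal sketch of a single call (parameter values here are illustrative):
#   text = generate_with_params(
#       model, tokenizer, SFT_PROMPTS[0],
#       {"temperature": 0.7, "repetition_penalty": 1.2,
#        "no_repeat_ngram_size": 3, "stop_strings": ["<|user|>"]},
#       device="cuda:0", max_new_tokens=128,
#   )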


# Decoding configurations to evaluate. Keys not set here fall back to the
# defaults in generate_with_params.
PARAM_GRID = [
    # Baseline sampling.
    {"name": "baseline", "temperature": 0.8, "top_p": 0.9, "top_k": 50},

    # Repetition-penalty sweep.
    {"name": "rep_1.1", "temperature": 0.8, "top_p": 0.9, "top_k": 50, "repetition_penalty": 1.1},
    {"name": "rep_1.2", "temperature": 0.8, "top_p": 0.9, "top_k": 50, "repetition_penalty": 1.2},
    {"name": "rep_1.3", "temperature": 0.8, "top_p": 0.9, "top_k": 50, "repetition_penalty": 1.3},
    {"name": "rep_1.5", "temperature": 0.8, "top_p": 0.9, "top_k": 50, "repetition_penalty": 1.5},

    # n-gram blocking.
    {"name": "no_rep_3gram", "temperature": 0.8, "top_p": 0.9, "top_k": 50, "no_repeat_ngram_size": 3},
    {"name": "no_rep_4gram", "temperature": 0.8, "top_p": 0.9, "top_k": 50, "no_repeat_ngram_size": 4},

    # Repetition penalty combined with n-gram blocking.
    {"name": "rep1.2+no3gram", "temperature": 0.8, "top_p": 0.9, "top_k": 50,
     "repetition_penalty": 1.2, "no_repeat_ngram_size": 3},

    # Temperature sweep.
    {"name": "temp_0.5", "temperature": 0.5, "top_p": 0.9, "top_k": 50},
    {"name": "temp_0.7", "temperature": 0.7, "top_p": 0.9, "top_k": 50},
    {"name": "temp_1.0", "temperature": 1.0, "top_p": 0.9, "top_k": 50},

    # Contrastive search.
    {"name": "contrastive_a0.6_k4", "contrastive_search": True, "penalty_alpha": 0.6, "contrastive_k": 4},
    {"name": "contrastive_a0.4_k6", "contrastive_search": True, "penalty_alpha": 0.4, "contrastive_k": 6},

    # Stop sequences for the ### 질문/답변 format.
    {"name": "stop_seq", "temperature": 0.8, "top_p": 0.9, "top_k": 50,
     "stop_strings": ["### 답변:", "### 질문:", "\n\n###"]},
    {"name": "rep1.2+stop", "temperature": 0.8, "top_p": 0.9, "top_k": 50,
     "repetition_penalty": 1.2, "stop_strings": ["### 답변:", "### 질문:", "\n\n###"]},

    # Everything combined.
    {"name": "best_combo", "temperature": 0.7, "top_p": 0.9, "top_k": 50,
     "repetition_penalty": 1.2, "no_repeat_ngram_size": 3,
     "stop_strings": ["### 답변:", "### 질문:", "\n\n###", "<|user|>"]},

    # Same combination, but stopping only on SFT-format markers.
    {"name": "sft_format_stop", "temperature": 0.7, "top_p": 0.9, "top_k": 50,
     "repetition_penalty": 1.2, "no_repeat_ngram_size": 3,
     "stop_strings": ["<|user|>", "</s>"]},
]
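
# To try another configuration, append a dict with a unique "name", e.g.:
#   PARAM_GRID.append({"name": "temp_0.9+rep1.1", "temperature": 0.9,
#                      "top_p": 0.9, "top_k": 50, "repetition_penalty": 1.1})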


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="checkpoints/korean_1b_sft/checkpoint-0005000")
    parser.add_argument("--device", default="cuda:0")
    parser.add_argument("--max_new_tokens", type=int, default=200)
    parser.add_argument("--output", default="eval/repetition_param_search_results.json")
    args = parser.parse_args()

    # Resolve the checkpoint path relative to the project root if needed.
    ckpt = Path(args.checkpoint)
    if not ckpt.is_absolute():
        ckpt = _PROJECT_ROOT / ckpt

    print(f"Loading model from {ckpt}...")
    model = LLM.from_pretrained(str(ckpt)).to(device=args.device, dtype=torch.bfloat16)
    model.eval()

    # Prefer the tokenizer saved with the checkpoint; fall back to the project tokenizer.
    tok_path = ckpt / "tokenizer.json"
    if not tok_path.exists():
        tok_path = _PROJECT_ROOT / "tokenizer" / "korean_sp" / "tokenizer.json"
    tokenizer = Tokenizer.from_file(str(tok_path))

    print(f"Model loaded. Params: {sum(p.numel() for p in model.parameters())/1e6:.1f}M")

    # Show how each candidate stop string tokenizes (stop detection is exact token matching).
    print("\n=== Stop Sequence Token IDs ===")
    for s in ["### 답변:", "### 질문:", "<|user|>", "<|assistant|>", "</s>", "\n\n###"]:
        ids = tokenizer.encode(s).ids
        print(f" '{s}' → {ids}")

    all_results = {}

    # Run every decoding config against both prompt formats and record repetition.
    for format_name, prompts in [("sft_format", SFT_PROMPTS), ("wrong_format", WRONG_FORMAT_PROMPTS)]:
        print(f"\n{'='*70}")
        print(f" Testing with {format_name}")
        print(f"{'='*70}")

        for params in PARAM_GRID:
            name = params["name"]
            key = f"{format_name}/{name}"
            print(f"\n--- {key} ---")

            rep_scores = []
            generations = []
            for prompt in prompts:
                t0 = time.time()
                text = generate_with_params(
                    model, tokenizer, prompt, params,
                    device=args.device, max_new_tokens=args.max_new_tokens,
                )
                elapsed = time.time() - t0
                metrics = compute_all_repetition_metrics(text)
                rep_scores.append(metrics["3gram_rep"])
                generations.append({
                    "prompt": prompt[:50] + "...",
                    "generation": text[:200],
                    "3gram_rep": metrics["3gram_rep"],
                    "time": round(elapsed, 2),
                })

            avg_rep = sum(rep_scores) / len(rep_scores) if rep_scores else 0
            print(f" Avg 3-gram repetition: {avg_rep*100:.1f}%")

            all_results[key] = {
                "params": {k: v for k, v in params.items() if k != "name"},
                "avg_3gram_rep": round(avg_rep, 4),
                "generations": generations,
            }

    # Rank configurations from least to most repetitive.
    print(f"\n{'='*70}")
    print(" RESULTS RANKED BY REPETITION RATE")
    print(f"{'='*70}")
    print(f" {'Config':<35} {'Avg 3gram Rep':>15}")
    print(f" {'-'*35} {'-'*15}")
    for key, res in sorted(all_results.items(), key=lambda x: x[1]["avg_3gram_rep"]):
        print(f" {key:<35} {res['avg_3gram_rep']*100:>14.1f}%")

    # Persist the full results for later inspection.
    out_path = _PROJECT_ROOT / args.output
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, ensure_ascii=False, indent=2)
    print(f"\nResults saved to {out_path}")


if __name__ == "__main__":
    main()