| | """SFT v2 comprehensive evaluation script.""" |
| | import sys, json, os |
| | sys.path.insert(0, "/PROJECT/0325120031_A/ghong/taketimes/llm-bang") |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from pathlib import Path |
| | from eval.generate import load_model_and_tokenizer, generate |
| |
|
| | CKPT = "/PROJECT/0325120031_A/ghong/taketimes/llm-bang/checkpoints/korean_1b_sft/checkpoint-best" |
| | DEVICE = "cuda:0" |
| | OUTPUT_DIR = "/PROJECT/0325120031_A/ghong/taketimes/llm-bang/eval/sft_v2_eval" |
| |
|
# Ten Korean evaluation prompts covering factual QA, how-to, and open-ended
# questions. NOTE(review): the Korean text in this copy is mojibake-damaged by
# a bad encoding round-trip; literals are preserved as-is (split lines rejoined).
QUESTIONS = [
    "ํ๊ตญ์ ์๋๋ ์ด๋์ธ๊ฐ์?",
    "ํ์ด์ฌ์์ ๋ฆฌ์คํธ๋ฅผ ์ ๋ ฌํ๋ ๋ฐฉ๋ฒ์ ์ค๋ชํด์ฃผ์ธ์.",
    "์ง๊ตฌ์จ๋ํ์ ์ฃผ์ ์์ธ์ ์ค๋ชํ์ธ์.",
    "์ข์ ์๋ฉด ์ต๊ด์ ๋ง๋ค๊ธฐ ์ํ ํ์ ์๋ ค์ฃผ์ธ์.",
    "ํ๊ตญ ์ ํต ์์ ์ค ๊น์น์ ๋ํด ์ค๋ชํด์ฃผ์ธ์.",
    "๋จธ์ ๋ฌ๋๊ณผ ๋ฅ๋ฌ๋์ ์ฐจ์ด์ ์ ๋ฌด์์ธ๊ฐ์?",
    "์คํธ๋ ์ค ํด์ ๋ฐฉ๋ฒ์ ์๋ ค์ฃผ์ธ์.",
    "ํจ๊ณผ์ ์ธ ๊ณต๋ถ ๋ฐฉ๋ฒ์ ์ค๋ชํด์ฃผ์ธ์.",
    "์ธ๊ณต์ง๋ฅ์ ๋ฏธ๋์ ๋ํด ์ด๋ป๊ฒ ์๊ฐํ์๋์?",
    "๊ฑด๊ฐํ ์์ต๊ด์ ์ ์งํ๋ ๋ฐฉ๋ฒ์ ์๋ ค์ฃผ์ธ์.",
]
| |
|
def calc_repetition_rate(text, n=3):
    """Return the fraction of duplicated character n-grams in *text*.

    Computed as ``1 - unique_ngrams / total_ngrams`` over all overlapping
    character n-grams, so the result lies in ``[0, 1)``. Texts shorter than
    ``n`` characters yield 0.0.
    """
    chars = list(text)
    if len(chars) < n:
        return 0.0
    ngrams = [tuple(chars[i:i + n]) for i in range(len(chars) - n + 1)]
    # len(chars) >= n guarantees at least one n-gram, so the division is safe
    # (the original `if not ngrams` guard was dead code and has been removed).
    return 1.0 - len(set(ngrams)) / len(ngrams)
| |
|
def main():
    """Run the SFT v2 evaluation end-to-end.

    Generates answers for QUESTIONS with sampling-based decoding, measures
    repetition rate / EOS-stop rate / token counts, computes validation loss
    on up to 100 samples of val.jsonl, and writes a Markdown report plus a
    JSON results file into OUTPUT_DIR.
    """
    print("Loading model...")
    model, tokenizer = load_model_and_tokenizer(CKPT, DEVICE)
    eos_id = tokenizer.token_to_id("</s>")
    # May be None if the chat-template token is absent from the vocabulary.
    user_token_id = tokenizer.token_to_id("<|user|>")

    results = []

    print("\n=== Generation Evaluation ===\n")
    for i, q in enumerate(QUESTIONS):
        prompt = f"<|user|>\n{q}\n<|assistant|>\n"

        gen_tokens = []
        stopped_eos = False

        input_ids = torch.tensor([tokenizer.encode(prompt).ids], dtype=torch.long, device=DEVICE)
        generated_ids = input_ids.clone()

        for _ in range(200):  # max_new_tokens=200
            logits_all, _ = model(generated_ids)
            logits = logits_all[:, -1, :].float()

            # temperature=0.7
            logits = logits / 0.7

            # repetition_penalty=1.1, applied to every token seen so far
            # (prompt tokens included).
            for token_id in set(generated_ids[0].tolist()):
                if logits[0, token_id] > 0:
                    logits[0, token_id] /= 1.1
                else:
                    logits[0, token_id] *= 1.1

            # no_repeat_ngram_size=3: ban any token that would complete a
            # trigram already present in the sequence.
            if generated_ids.shape[1] >= 3:
                last_2 = tuple(generated_ids[0, -2:].tolist())
                for j in range(generated_ids.shape[1] - 2):
                    if tuple(generated_ids[0, j:j+2].tolist()) == last_2:
                        blocked = generated_ids[0, j+2].item()
                        logits[0, blocked] = float('-inf')

            # top_p=0.9 nucleus sampling; the shifted mask (cumsum minus the
            # token's own probability) always keeps at least the top token,
            # so the renormalizing sum below is never zero.
            probs = F.softmax(logits, dim=-1)
            sorted_probs, sorted_idx = torch.sort(probs, descending=True)
            cumsum = torch.cumsum(sorted_probs, dim=-1)
            mask = cumsum - sorted_probs >= 0.9
            sorted_probs[mask] = 0.0
            sorted_probs = sorted_probs / sorted_probs.sum()

            next_token = sorted_idx[0, torch.multinomial(sorted_probs[0], 1)]
            generated_ids = torch.cat([generated_ids, next_token.view(1, 1)], dim=-1)

            tid = next_token.item()
            gen_tokens.append(tid)

            if tid == eos_id:
                stopped_eos = True
                break
            # BUGFIX: compare with `is not None` — the original truthiness
            # test (`if user_token_id and ...`) would silently disable this
            # stop condition if the token id happened to be 0.
            if user_token_id is not None and tid == user_token_id:
                stopped_eos = True
                break

        full_text = tokenizer.decode(gen_tokens)

        # Strip any run-on turn and the EOS marker from the decoded answer.
        if "<|user|>" in full_text:
            full_text = full_text[:full_text.index("<|user|>")]
        full_text = full_text.replace("</s>", "").strip()

        rep_rate = calc_repetition_rate(full_text)

        result = {
            "question": q,
            "answer": full_text,
            "repetition_rate": rep_rate,
            "stopped_eos": stopped_eos,
            "num_tokens": len(gen_tokens),
        }
        results.append(result)

        print(f"[{i+1}] {q}")
        print(f" ๋ฐ๋ณต๋ฅ : {rep_rate*100:.1f}% | EOS: {stopped_eos} | ํ ํฐ: {len(gen_tokens)}")
        print(f" ๋ต๋ณ: {full_text[:100]}...")
        print()

    # Aggregate metrics over all questions (percentages).
    avg_rep = sum(r["repetition_rate"] for r in results) / len(results) * 100
    eos_rate = sum(1 for r in results if r["stopped_eos"]) / len(results) * 100

    print(f"\n=== ์์ฝ ===")
    print(f"ํ๊ท ๋ฐ๋ณต๋ฅ : {avg_rep:.1f}%")
    print(f"์์ฐ ์ข๋ฃ์จ: {eos_rate:.1f}%")

    # --- Validation loss over (at most) the first 100 samples ---
    print("\n=== Val Loss ๊ณ์ฐ ===")
    val_path = Path("/PROJECT/0325120031_A/ghong/taketimes/llm-bang/data/sft/val.jsonl")
    if val_path.exists():
        val_data = []
        with open(val_path) as f:
            for i, line in enumerate(f):
                if i >= 100:
                    break
                val_data.append(json.loads(line))

        total_loss = 0.0
        count = 0
        model.eval()
        with torch.no_grad():
            for item in val_data:
                # Rebuild the chat-formatted training text from whichever
                # schema the sample uses (conversations / instruction / text).
                if "conversations" in item:
                    convs = item["conversations"]
                    text = ""
                    for c in convs:
                        role = c.get("role", c.get("from", ""))
                        content = c.get("content", c.get("value", ""))
                        if role in ("user", "human"):
                            text += f"<|user|>\n{content}\n"
                        elif role in ("assistant", "gpt"):
                            text += f"<|assistant|>\n{content}\n"
                elif "instruction" in item and "output" in item:
                    text = f"<|user|>\n{item['instruction']}\n<|assistant|>\n{item['output']}\n"
                elif "text" in item:
                    text = item["text"]
                else:
                    continue

                ids = tokenizer.encode(text).ids
                if len(ids) < 2:
                    continue
                ids = ids[:512]  # truncate to the 512-token context window

                input_t = torch.tensor([ids], dtype=torch.long, device=DEVICE)
                logits, _ = model(input_t)

                # Next-token cross-entropy over the full sequence (no
                # prompt/answer masking here).
                loss = F.cross_entropy(
                    logits[0, :-1].float().contiguous().view(-1, logits.shape[-1]),
                    input_t[0, 1:].contiguous().view(-1),
                    reduction="mean"
                )
                total_loss += loss.item()
                count += 1

        avg_loss = total_loss / max(count, 1)
        print(f"Val loss (100 samples): {avg_loss:.4f}")
    else:
        # Fall back to the val_loss recorded during training.
        avg_loss = 2.2062
        print(f"val.jsonl not found, using training val_loss: {avg_loss}")

    # --- Markdown report ---
    report = f"""# SFT v2 ์ฒดํฌํฌ์ธํธ ์ขํฉ ํ๊ฐ ๋ณด๊ณ ์

์ฒดํฌํฌ์ธํธ: `checkpoints/korean_1b_sft/checkpoint-best`
ํ๊ฐ์ผ์: 2026-02-27

## ํต์ฌ ์งํ

| ํญ๋ชฉ | Pretrain | SFT v1 (buggy ํฌ๋งท) | SFT v1 (์ฌ๋ฐ๋ฅธ ํฌ๋งท) | **SFT v2 (์ด๋ฒ)** |
|------|----------|--------------------|--------------------|-----------------|
| ๋ฐ๋ณต๋ฅ | 69.4% | 57.1% | 17.7% | **{avg_rep:.1f}%** |
| val_loss | - | 2.69 | - | **{avg_loss:.4f}** |
| ์์ฐ ์ข๋ฃ์จ | - | - | - | **{eos_rate:.1f}%** |

## ๋ชฉํ ๋ฌ์ฑ ์ฌ๋ถ

- ๋ฐ๋ณต๋ฅ <5%: {"โ ๋ฌ์ฑ" if avg_rep < 5 else "โ ๋ฏธ๋ฌ์ฑ"} ({avg_rep:.1f}%)
- val_loss <2.2: {"โ ๋ฌ์ฑ" if avg_loss < 2.2 else "โ ๋ฏธ๋ฌ์ฑ"} ({avg_loss:.4f})

## ์์ฑ ํ๋ผ๋ฏธํฐ

- temperature=0.7, top_p=0.9
- repetition_penalty=1.1, no_repeat_ngram_size=3
- max_new_tokens=200
- ํ๋กฌํํธ ํฌ๋งท: `<|user|>\\n{{์ง๋ฌธ}}\\n<|assistant|>\\n`

## ์์ฑ ์ํ ์ ๋ฌธ

"""
    for i, r in enumerate(results):
        report += f"""### [{i+1}] {r['question']}
- ๋ฐ๋ณต๋ฅ : {r['repetition_rate']*100:.1f}%
- ์ข๋ฃ: {"EOS" if r['stopped_eos'] else "max_tokens"}
- ํ ํฐ ์: {r['num_tokens']}

```
{r['answer']}
```

"""

    report += f"""## ๊ฐ์ ๋ ๋ถ์

- Pretrain โ SFT v2: {69.4 - avg_rep:.1f}%p ๊ฐ์
- SFT v1 (buggy) โ SFT v2: {57.1 - avg_rep:.1f}%p ๊ฐ์
- SFT v1 (์ฌ๋ฐ๋ฅธ ํฌ๋งท) โ SFT v2: {17.7 - avg_rep:.1f}%p ๊ฐ์

## ๊ถ์ฅ ๋ค์ ๋จ๊ณ

"""
    # Recommendation tier is chosen by repetition-rate band: <5 / <15 / else.
    if avg_rep < 5:
        report += """- โ ๋ฐ๋ณต๋ฅ ๋ชฉํ ๋ฌ์ฑ - 1B SFT ๊ธฐ๋ณธ ์๋ฃ
- ORPO/DPO ์ ํธ๋ ํ์ต์ผ๋ก ์๋ต ํ์ง ํฅ์
- 3B ๋ชจ๋ธ๋ก ์ค์ผ์ผ์ ๊ณ ๋ ค
- ๋ ๋ค์ํ ๋ฒค์น๋งํฌ (KoBEST, KLUE ๋ฑ) ํ๊ฐ
"""
    elif avg_rep < 15:
        report += """- ๋ฐ๋ณต๋ฅ ์ด ์์ง ๋ชฉํ ๋ฏธ๋ฌ์ด์ง๋ง ์๋นํ ๊ฐ์ ๋จ
- ๋ฐ์ดํฐ ๋ค์์ฑ ์ฆ๊ฐ (๋ ๋ง์ SFT ๋ฐ์ดํฐ)
- repetition_penalty ์กฐ์ ์คํ
- ORPO๋ก ๋ฐ๋ณต ํจํด ์ถ๊ฐ ๊ต์ ๊ฐ๋ฅ
"""
    else:
        report += """- ๋ฐ๋ณต๋ฅ ์ด ์ฌ์ ํ ๋์ - ์ถ๊ฐ SFT ํ์
- ํ์ต ๋ฐ์ดํฐ ํ์ง ์ ๊ฒ
- ํ์ต๋ฅ /์ํฌํฌ ์กฐ์
- ๋ฐ์ดํฐ ์ฆ๊ฐ ๊ณ ๋ ค
"""

    # BUGFIX: the output directory may not exist yet — create it before
    # writing (`os` was imported but never used in the original).
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Explicit UTF-8 so the Korean report is not at the mercy of the locale.
    with open(f"{OUTPUT_DIR}/report.md", "w", encoding="utf-8") as f:
        f.write(report)

    with open(f"{OUTPUT_DIR}/results.json", "w", encoding="utf-8") as f:
        json.dump({"results": results, "avg_rep": avg_rep, "eos_rate": eos_rate, "val_loss": avg_loss}, f, ensure_ascii=False, indent=2)

    print(f"\n๋ณด๊ณ ์ ์ ์ฅ: {OUTPUT_DIR}/report.md")
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|