| """SFT v2 comprehensive evaluation script.""" |
| import sys, json, os |
| sys.path.insert(0, "/PROJECT/0325120031_A/ghong/taketimes/llm-bang") |
|
|
| import torch |
| import torch.nn.functional as F |
| from pathlib import Path |
| from eval.generate import load_model_and_tokenizer, generate |
|
|
# Checkpoint to evaluate, CUDA device to run on, and the directory that will
# receive report.md / results.json.
# NOTE(review): these are hard-coded absolute cluster paths, and nothing in
# this script creates OUTPUT_DIR (`os` is imported but os.makedirs is never
# called) — the directory must already exist or the final writes will fail.
CKPT = "/PROJECT/0325120031_A/ghong/taketimes/llm-bang/checkpoints/korean_1b_sft/checkpoint-best"
DEVICE = "cuda:0"
OUTPUT_DIR = "/PROJECT/0325120031_A/ghong/taketimes/llm-bang/eval/sft_v2_eval"
|
|
# Ten fixed Korean evaluation prompts (factual, technical, and open-ended)
# used for the generation benchmark in main().
# NOTE(review): the literals below look mojibake-encoded (UTF-8 Korean read
# under the wrong codec) and a few appear split by stray control bytes —
# verify the file's encoding before re-saving; the bytes are preserved
# verbatim here.
QUESTIONS = [
    "ํ๊ตญ์ ์๋๋ ์ด๋์ธ๊ฐ์?",
    "ํ์ด์ฌ์์ ๋ฆฌ์คํธ๋ฅผ ์ ๋ ฌํ๋ ๋ฐฉ๋ฒ์ ์ค๋ช
ํด์ฃผ์ธ์.",
    "์ง๊ตฌ์จ๋ํ์ ์ฃผ์ ์์ธ์ ์ค๋ช
ํ์ธ์.",
    "์ข์ ์๋ฉด ์ต๊ด์ ๋ง๋ค๊ธฐ ์ํ ํ์ ์๋ ค์ฃผ์ธ์.",
    "ํ๊ตญ ์ ํต ์์ ์ค ๊น์น์ ๋ํด ์ค๋ช
ํด์ฃผ์ธ์.",
    "๋จธ์ ๋ฌ๋๊ณผ ๋ฅ๋ฌ๋์ ์ฐจ์ด์ ์ ๋ฌด์์ธ๊ฐ์?",
    "์คํธ๋ ์ค ํด์ ๋ฐฉ๋ฒ์ ์๋ ค์ฃผ์ธ์.",
    "ํจ๊ณผ์ ์ธ ๊ณต๋ถ ๋ฐฉ๋ฒ์ ์ค๋ช
ํด์ฃผ์ธ์.",
    "์ธ๊ณต์ง๋ฅ์ ๋ฏธ๋์ ๋ํด ์ด๋ป๊ฒ ์๊ฐํ์๋์?",
    "๊ฑด๊ฐํ ์์ต๊ด์ ์ ์งํ๋ ๋ฐฉ๋ฒ์ ์๋ ค์ฃผ์ธ์.",
]
|
|
def calc_repetition_rate(text, n=3):
    """Return the fraction of repeated character n-grams in *text*.

    The string is treated as a sequence of characters; all overlapping
    n-grams are formed and the score is ``1 - unique/total``, so 0.0 means
    every n-gram is distinct and values near 1.0 mean heavy repetition.

    Args:
        text: String to score (characters are the token unit).
        n: N-gram size (default 3). Texts shorter than ``n`` score 0.0.

    Returns:
        float in [0.0, 1.0].
    """
    chars = list(text)
    # Too short to form even one n-gram -> no repetition by definition.
    if len(chars) < n:
        return 0.0
    ngrams = [tuple(chars[i:i + n]) for i in range(len(chars) - n + 1)]
    # The length guard above guarantees len(ngrams) >= 1, so the former
    # `if not ngrams` check was dead code and has been removed.
    return 1.0 - len(set(ngrams)) / len(ngrams)
|
|
def main():
    """Run the SFT v2 evaluation end to end.

    Pipeline:
      1. Load the checkpoint + tokenizer.
      2. For each QUESTIONS prompt, sample up to 200 tokens with
         temperature 0.7, repetition penalty 1.1, a no-repeat-3-gram
         block, and top-p (0.9) nucleus sampling.
      3. Score each answer: character 3-gram repetition rate and whether
         generation stopped naturally (EOS / <|user|> token).
      4. Compute mean cross-entropy over up to 100 validation samples.
      5. Write a markdown report and a JSON results dump to OUTPUT_DIR.
    """
    print("Loading model...")
    model, tokenizer = load_model_and_tokenizer(CKPT, DEVICE)
    eos_id = tokenizer.token_to_id("</s>")

    # NOTE(review): tokenizers typically return None for unknown tokens —
    # if "</s>" or "<|user|>" is missing from the vocab the stop checks
    # below silently never fire; confirm both tokens exist.
    user_token_id = tokenizer.token_to_id("<|user|>")

    results = []

    # NOTE(review): model.eval() is only called before the val-loss pass
    # below, not before generation — if the model has dropout, sampling here
    # runs in training mode. Confirm load_model_and_tokenizer sets eval mode.
    print("\n=== Generation Evaluation ===\n")
    for i, q in enumerate(QUESTIONS):
        # Chat prompt format expected by the SFT checkpoint.
        prompt = f"<|user|>\n{q}\n<|assistant|>\n"

        gen_tokens = []
        full_text = ""
        stopped_eos = False

        input_ids = torch.tensor([tokenizer.encode(prompt).ids], dtype=torch.long, device=DEVICE)
        generated_ids = input_ids.clone()

        # Greedy-loop sampling, max 200 new tokens. The full sequence is
        # re-run through the model every step (no KV cache) — O(len^2) but
        # acceptable for a 10-question eval.
        for step in range(200):
            logits_all, _ = model(generated_ids)
            logits = logits_all[:, -1, :].float()

            # Temperature 0.7 (sharpens the distribution).
            logits = logits / 0.7

            # Repetition penalty 1.1, applied sign-aware as in CTRL/HF:
            # divide positive logits, multiply negative ones, for every
            # token already present in the sequence (prompt included).
            for token_id in set(generated_ids[0].tolist()):
                if logits[0, token_id] > 0:
                    logits[0, token_id] /= 1.1
                else:
                    logits[0, token_id] *= 1.1

            # no_repeat_ngram_size=3: ban any token that would complete a
            # 3-gram already seen (scan all earlier bigrams for a match
            # with the last two tokens; O(len) per step).
            if generated_ids.shape[1] >= 3:
                last_2 = tuple(generated_ids[0, -2:].tolist())
                for j in range(generated_ids.shape[1] - 2):
                    if tuple(generated_ids[0, j:j+2].tolist()) == last_2:
                        blocked = generated_ids[0, j+2].item()
                        logits[0, blocked] = float('-inf')

            # Nucleus (top-p=0.9) sampling: keep the smallest prefix of the
            # sorted distribution whose cumulative mass reaches 0.9
            # (`cumsum - sorted_probs` keeps the token that crosses the
            # threshold), zero the rest, renormalize, and sample.
            probs = F.softmax(logits, dim=-1)
            sorted_probs, sorted_idx = torch.sort(probs, descending=True)
            cumsum = torch.cumsum(sorted_probs, dim=-1)
            mask = cumsum - sorted_probs >= 0.9
            sorted_probs[mask] = 0.0
            sorted_probs = sorted_probs / sorted_probs.sum()

            next_token = sorted_idx[0, torch.multinomial(sorted_probs[0], 1)]
            generated_ids = torch.cat([generated_ids, next_token.view(1, 1)], dim=-1)

            tid = next_token.item()
            gen_tokens.append(tid)

            # Stop on EOS or when the model starts a new user turn.
            if tid == eos_id:
                stopped_eos = True
                break
            # NOTE(review): truthiness bug — if user_token_id is 0 (a valid
            # vocab id) or None, this branch is skipped; should be
            # `user_token_id is not None`.
            if user_token_id and tid == user_token_id:
                stopped_eos = True
                break

        full_text = tokenizer.decode(gen_tokens)

        # Strip any trailing turn marker / EOS text left in the decode.
        if "<|user|>" in full_text:
            full_text = full_text[:full_text.index("<|user|>")]
        full_text = full_text.replace("</s>", "").strip()

        # Character-level 3-gram repetition rate (see calc_repetition_rate).
        rep_rate = calc_repetition_rate(full_text)

        result = {
            "question": q,
            "answer": full_text,
            "repetition_rate": rep_rate,
            "stopped_eos": stopped_eos,
            "num_tokens": len(gen_tokens),
        }
        results.append(result)

        print(f"[{i+1}] {q}")
        print(f"  ๋ฐ๋ณต๋ฅ : {rep_rate*100:.1f}% | EOS: {stopped_eos} | ํ ํฐ: {len(gen_tokens)}")
        print(f"  ๋ต๋ณ: {full_text[:100]}...")
        print()

    # Aggregate metrics over all questions (percentages).
    avg_rep = sum(r["repetition_rate"] for r in results) / len(results) * 100
    eos_rate = sum(1 for r in results if r["stopped_eos"]) / len(results) * 100

    print(f"\n=== ์์ฝ ===")
    print(f"ํ๊ท  ๋ฐ๋ณต๋ฅ : {avg_rep:.1f}%")
    print(f"์์ฐ ์ข
๋ฃ์จ: {eos_rate:.1f}%")

    # --- Validation loss over (at most) the first 100 val.jsonl samples ---
    print("\n=== Val Loss ๊ณ์ฐ ===")
    val_path = Path("/PROJECT/0325120031_A/ghong/taketimes/llm-bang/data/sft/val.jsonl")
    if val_path.exists():
        val_data = []
        with open(val_path) as f:
            for i, line in enumerate(f):
                if i >= 100:
                    break
                val_data.append(json.loads(line))

        total_loss = 0.0
        count = 0
        model.eval()
        with torch.no_grad():
            for item in val_data:
                # Accept three record schemas: chat-style "conversations"
                # (role/content or from/value keys), Alpaca-style
                # instruction/output, or a raw "text" field.
                if "conversations" in item:
                    convs = item["conversations"]
                    text = ""
                    for c in convs:
                        role = c.get("role", c.get("from", ""))
                        content = c.get("content", c.get("value", ""))
                        if role in ("user", "human"):
                            text += f"<|user|>\n{content}\n"
                        elif role in ("assistant", "gpt"):
                            text += f"<|assistant|>\n{content}\n"
                elif "instruction" in item and "output" in item:
                    text = f"<|user|>\n{item['instruction']}\n<|assistant|>\n{item['output']}\n"
                elif "text" in item:
                    text = item["text"]
                else:
                    continue

                ids = tokenizer.encode(text).ids
                if len(ids) < 2:
                    continue
                # Truncate to the model context budget used here (512).
                ids = ids[:512]

                input_t = torch.tensor([ids], dtype=torch.long, device=DEVICE)
                logits, _ = model(input_t)

                # Next-token cross-entropy over the WHOLE sequence.
                # NOTE(review): prompt (user) tokens are not masked out, so
                # this is full-sequence LM loss, not assistant-only SFT loss
                # — confirm this matches how training val_loss was computed
                # before comparing against the 2.2062 fallback below.
                loss = F.cross_entropy(
                    logits[0, :-1].float().contiguous().view(-1, logits.shape[-1]),
                    input_t[0, 1:].contiguous().view(-1),
                    reduction="mean"
                )
                total_loss += loss.item()
                count += 1

        # max(count, 1) guards against every sample being skipped.
        avg_loss = total_loss / max(count, 1)
        print(f"Val loss (100 samples): {avg_loss:.4f}")
    else:
        # Fall back to the last known training-time validation loss.
        avg_loss = 2.2062
        print(f"val.jsonl not found, using training val_loss: {avg_loss}")

    # --- Markdown report (Korean; literals preserved byte-for-byte) ---
    report = f"""# SFT v2 ์ฒดํฌํฌ์ธํธ ์ข
ํฉ ํ๊ฐ ๋ณด๊ณ ์ |

์ฒดํฌํฌ์ธํธ: `checkpoints/korean_1b_sft/checkpoint-best`
ํ๊ฐ์ผ์: 2026-02-27

## ํต์ฌ ์งํ

| ํญ๋ชฉ | Pretrain | SFT v1 (buggy ํฌ๋งท) | SFT v1 (์ฌ๋ฐ๋ฅธ ํฌ๋งท) | **SFT v2 (์ด๋ฒ)** |
|------|----------|--------------------|--------------------|-----------------|
| ๋ฐ๋ณต๋ฅ  | 69.4% | 57.1% | 17.7% | **{avg_rep:.1f}%** |
| val_loss | - | 2.69 | - | **{avg_loss:.4f}** |
| ์์ฐ ์ข
๋ฃ์จ | - | - | - | **{eos_rate:.1f}%** |

## ๋ชฉํ ๋ฌ์ฑ ์ฌ๋ถ

- ๋ฐ๋ณต๋ฅ  <5%: {"โ
 ๋ฌ์ฑ" if avg_rep < 5 else "โ ๋ฏธ๋ฌ์ฑ"} ({avg_rep:.1f}%)
- val_loss <2.2: {"โ
 ๋ฌ์ฑ" if avg_loss < 2.2 else "โ ๋ฏธ๋ฌ์ฑ"} ({avg_loss:.4f})

## ์์ฑ ํ๋ผ๋ฏธํฐ

- temperature=0.7, top_p=0.9
- repetition_penalty=1.1, no_repeat_ngram_size=3
- max_new_tokens=200
- ํ๋กฌํํธ ํฌ๋งท: `<|user|>\\n{{์ง๋ฌธ}}\\n<|assistant|>\\n`

## ์์ฑ ์ํ ์ ๋ฌธ

"""
    # Append one section per question with its metrics and full answer.
    for i, r in enumerate(results):
        report += f"""### [{i+1}] {r['question']}
- ๋ฐ๋ณต๋ฅ : {r['repetition_rate']*100:.1f}%
- ์ข
๋ฃ: {"EOS" if r['stopped_eos'] else "max_tokens"}
- ํ ํฐ ์: {r['num_tokens']}

```
{r['answer']}
```

"""

    # Improvement deltas vs. the three baselines hard-coded in the table.
    report += f"""## ๊ฐ์ ๋ ๋ถ์

- Pretrain โ SFT v2: {69.4 - avg_rep:.1f}%p ๊ฐ์
- SFT v1 (buggy) โ SFT v2: {57.1 - avg_rep:.1f}%p ๊ฐ์
- SFT v1 (์ฌ๋ฐ๋ฅธ ํฌ๋งท) โ SFT v2: {17.7 - avg_rep:.1f}%p ๊ฐ์

## ๊ถ์ฅ ๋ค์ ๋จ๊ณ

"""
    # Tiered next-step recommendations keyed off the repetition-rate goal.
    if avg_rep < 5:
        report += """- โ
 ๋ฐ๋ณต๋ฅ  ๋ชฉํ ๋ฌ์ฑ - 1B SFT ๊ธฐ๋ณธ ์๋ฃ
- ORPO/DPO ์ ํธ๋ ํ์ต์ผ๋ก ์๋ต ํ์ง ํฅ์
- 3B ๋ชจ๋ธ๋ก ์ค์ผ์ผ์
 ๊ณ ๋ ค
- ๋ ๋ค์ํ ๋ฒค์น๋งํฌ (KoBEST, KLUE ๋ฑ) ํ๊ฐ
"""
    elif avg_rep < 15:
        report += """- ๋ฐ๋ณต๋ฅ ์ด ์์ง ๋ชฉํ ๋ฏธ๋ฌ์ด์ง๋ง ์๋นํ ๊ฐ์ ๋จ
- ๋ฐ์ดํฐ ๋ค์์ฑ ์ฆ๊ฐ (๋ ๋ง์ SFT ๋ฐ์ดํฐ)
- repetition_penalty ์กฐ์  ์คํ
- ORPO๋ก ๋ฐ๋ณต ํจํด ์ถ๊ฐ ๊ต์  ๊ฐ๋ฅ
"""
    else:
        report += """- ๋ฐ๋ณต๋ฅ ์ด ์ฌ์ ํ ๋์ - ์ถ๊ฐ SFT ํ์
- ํ์ต ๋ฐ์ดํฐ ํ์ง ์ ๊ฒ
- ํ์ต๋ฅ /์ํฌํฌ ์กฐ์ 
- ๋ฐ์ดํฐ ์ฆ๊ฐ ๊ณ ๋ ค
"""

    # NOTE(review): assumes OUTPUT_DIR exists (see constants above).
    with open(f"{OUTPUT_DIR}/report.md", "w") as f:
        f.write(report)

    # Machine-readable dump of everything the report summarizes.
    with open(f"{OUTPUT_DIR}/results.json", "w") as f:
        json.dump({"results": results, "avg_rep": avg_rep, "eos_rate": eos_rate, "val_loss": avg_loss}, f, ensure_ascii=False, indent=2)

    print(f"\n๋ณด๊ณ ์ ์ ์ฅ: {OUTPUT_DIR}/report.md")
|
|
# Script entry point: run the full evaluation only when executed directly.
if __name__ == "__main__":
    main()
|
|