# pathcosmos's picture
# Upload folder using huggingface_hub (#29)
# 5b1ff4d
"""SFT v2 comprehensive evaluation script."""
import sys, json, os
sys.path.insert(0, "/PROJECT/0325120031_A/ghong/taketimes/llm-bang")
import torch
import torch.nn.functional as F
from pathlib import Path
from eval.generate import load_model_and_tokenizer, generate
# Checkpoint directory handed to load_model_and_tokenizer().
CKPT = "/PROJECT/0325120031_A/ghong/taketimes/llm-bang/checkpoints/korean_1b_sft/checkpoint-best"
# Device string used for the model and for every input tensor built below.
DEVICE = "cuda:0"
# Directory that receives report.md and results.json.
OUTPUT_DIR = "/PROJECT/0325120031_A/ghong/taketimes/llm-bang/eval/sft_v2_eval"
# Fixed Korean prompts used for the generation evaluation.
# The literals were stored as mojibake (UTF-8 bytes mis-decoded through a
# single-byte code page); they are restored to the intended Korean text here.
QUESTIONS = [
    "한국의 수도는 어디인가요?",  # Where is the capital of Korea?
    "파이썬에서 리스트를 정렬하는 방법을 설명해주세요.",  # How to sort a list in Python
    "지구온난화의 주요 원인을 설명하세요.",  # Main causes of global warming
    "좋은 수면 습관을 만들기 위한 팁을 알려주세요.",  # Tips for good sleep habits
    "한국 전통 음식 중 김치에 대해 설명해주세요.",  # Explain kimchi
    "머신러닝과 딥러닝의 차이점은 무엇인가요?",  # Machine learning vs. deep learning
    "스트레스 해소 방법을 알려주세요.",  # Ways to relieve stress
    "효과적인 공부 방법을 설명해주세요.",  # Effective study methods
    "인공지능의 미래에 대해 어떻게 생각하시나요?",  # Thoughts on the future of AI
    "건강한 식습관을 유지하는 방법을 알려주세요.",  # Maintaining healthy eating habits
]
def calc_repetition_rate(text, n=3):
    """Return the share of duplicated character n-grams in ``text``.

    The text is split into individual characters and scanned with a sliding
    window of width ``n``. The result is ``1 - unique_windows / all_windows``:
    0.0 means every window is distinct, values approaching 1.0 mean the same
    windows keep recurring. Inputs shorter than ``n`` yield 0.0.
    """
    chars = list(text)
    if len(chars) < n:
        return 0.0

    distinct = set()
    total = 0
    for start in range(len(chars) - n + 1):
        distinct.add(tuple(chars[start:start + n]))
        total += 1

    if total == 0:
        return 0.0
    return 1.0 - len(distinct) / total
def main():
print("Loading model...")
model, tokenizer = load_model_and_tokenizer(CKPT, DEVICE)
eos_id = tokenizer.token_to_id("</s>")
# Get stop token for <|user|>
user_token_id = tokenizer.token_to_id("<|user|>")
results = []
print("\n=== Generation Evaluation ===\n")
for i, q in enumerate(QUESTIONS):
prompt = f"<|user|>\n{q}\n<|assistant|>\n"
# Collect generated text
gen_tokens = []
full_text = ""
stopped_eos = False
# Use modified generation with repetition penalty and no_repeat_ngram
input_ids = torch.tensor([tokenizer.encode(prompt).ids], dtype=torch.long, device=DEVICE)
generated_ids = input_ids.clone()
for step in range(200):
logits_all, _ = model(generated_ids)
logits = logits_all[:, -1, :].float()
# Temperature
logits = logits / 0.7
# Repetition penalty
for token_id in set(generated_ids[0].tolist()):
if logits[0, token_id] > 0:
logits[0, token_id] /= 1.1
else:
logits[0, token_id] *= 1.1
# No repeat 3-gram
if generated_ids.shape[1] >= 3:
last_2 = tuple(generated_ids[0, -2:].tolist())
for j in range(generated_ids.shape[1] - 2):
if tuple(generated_ids[0, j:j+2].tolist()) == last_2:
blocked = generated_ids[0, j+2].item()
logits[0, blocked] = float('-inf')
# Top-p sampling
probs = F.softmax(logits, dim=-1)
sorted_probs, sorted_idx = torch.sort(probs, descending=True)
cumsum = torch.cumsum(sorted_probs, dim=-1)
mask = cumsum - sorted_probs >= 0.9
sorted_probs[mask] = 0.0
sorted_probs = sorted_probs / sorted_probs.sum()
next_token = sorted_idx[0, torch.multinomial(sorted_probs[0], 1)] # [1]
generated_ids = torch.cat([generated_ids, next_token.view(1, 1)], dim=-1)
tid = next_token.item()
gen_tokens.append(tid)
if tid == eos_id:
stopped_eos = True
break
if user_token_id and tid == user_token_id:
stopped_eos = True
break
full_text = tokenizer.decode(gen_tokens)
# Clean up
if "<|user|>" in full_text:
full_text = full_text[:full_text.index("<|user|>")]
full_text = full_text.replace("</s>", "").strip()
rep_rate = calc_repetition_rate(full_text)
result = {
"question": q,
"answer": full_text,
"repetition_rate": rep_rate,
"stopped_eos": stopped_eos,
"num_tokens": len(gen_tokens),
}
results.append(result)
print(f"[{i+1}] {q}")
print(f" ๋ฐ˜๋ณต๋ฅ : {rep_rate*100:.1f}% | EOS: {stopped_eos} | ํ† ํฐ: {len(gen_tokens)}")
print(f" ๋‹ต๋ณ€: {full_text[:100]}...")
print()
avg_rep = sum(r["repetition_rate"] for r in results) / len(results) * 100
eos_rate = sum(1 for r in results if r["stopped_eos"]) / len(results) * 100
print(f"\n=== ์š”์•ฝ ===")
print(f"ํ‰๊ท  ๋ฐ˜๋ณต๋ฅ : {avg_rep:.1f}%")
print(f"์ž์—ฐ ์ข…๋ฃŒ์œจ: {eos_rate:.1f}%")
# Val loss calculation
print("\n=== Val Loss ๊ณ„์‚ฐ ===")
val_path = Path("/PROJECT/0325120031_A/ghong/taketimes/llm-bang/data/sft/val.jsonl")
if val_path.exists():
val_data = []
with open(val_path) as f:
for i, line in enumerate(f):
if i >= 100:
break
val_data.append(json.loads(line))
total_loss = 0.0
count = 0
model.eval()
with torch.no_grad():
for item in val_data:
# Format as SFT
if "conversations" in item:
convs = item["conversations"]
text = ""
for c in convs:
role = c.get("role", c.get("from", ""))
content = c.get("content", c.get("value", ""))
if role in ("user", "human"):
text += f"<|user|>\n{content}\n"
elif role in ("assistant", "gpt"):
text += f"<|assistant|>\n{content}\n"
elif "instruction" in item and "output" in item:
text = f"<|user|>\n{item['instruction']}\n<|assistant|>\n{item['output']}\n"
elif "text" in item:
text = item["text"]
else:
continue
ids = tokenizer.encode(text).ids
if len(ids) < 2:
continue
ids = ids[:512] # truncate
input_t = torch.tensor([ids], dtype=torch.long, device=DEVICE)
logits, _ = model(input_t)
# Cross entropy on all tokens
loss = F.cross_entropy(
logits[0, :-1].float().contiguous().view(-1, logits.shape[-1]),
input_t[0, 1:].contiguous().view(-1),
reduction="mean"
)
total_loss += loss.item()
count += 1
avg_loss = total_loss / max(count, 1)
print(f"Val loss (100 samples): {avg_loss:.4f}")
else:
avg_loss = 2.2062 # from training
print(f"val.jsonl not found, using training val_loss: {avg_loss}")
# Write report
report = f"""# SFT v2 ์ฒดํฌํฌ์ธํŠธ ์ข…ํ•ฉ ํ‰๊ฐ€ ๋ณด๊ณ ์„œ
์ฒดํฌํฌ์ธํŠธ: `checkpoints/korean_1b_sft/checkpoint-best`
ํ‰๊ฐ€์ผ์‹œ: 2026-02-27
## ํ•ต์‹ฌ ์ง€ํ‘œ
| ํ•ญ๋ชฉ | Pretrain | SFT v1 (buggy ํฌ๋งท) | SFT v1 (์˜ฌ๋ฐ”๋ฅธ ํฌ๋งท) | **SFT v2 (์ด๋ฒˆ)** |
|------|----------|--------------------|--------------------|-----------------|
| ๋ฐ˜๋ณต๋ฅ  | 69.4% | 57.1% | 17.7% | **{avg_rep:.1f}%** |
| val_loss | - | 2.69 | - | **{avg_loss:.4f}** |
| ์ž์—ฐ ์ข…๋ฃŒ์œจ | - | - | - | **{eos_rate:.1f}%** |
## ๋ชฉํ‘œ ๋‹ฌ์„ฑ ์—ฌ๋ถ€
- ๋ฐ˜๋ณต๋ฅ  <5%: {"โœ… ๋‹ฌ์„ฑ" if avg_rep < 5 else "โŒ ๋ฏธ๋‹ฌ์„ฑ"} ({avg_rep:.1f}%)
- val_loss <2.2: {"โœ… ๋‹ฌ์„ฑ" if avg_loss < 2.2 else "โŒ ๋ฏธ๋‹ฌ์„ฑ"} ({avg_loss:.4f})
## ์ƒ์„ฑ ํŒŒ๋ผ๋ฏธํ„ฐ
- temperature=0.7, top_p=0.9
- repetition_penalty=1.1, no_repeat_ngram_size=3
- max_new_tokens=200
- ํ”„๋กฌํ”„ํŠธ ํฌ๋งท: `<|user|>\\n{{์งˆ๋ฌธ}}\\n<|assistant|>\\n`
## ์ƒ์„ฑ ์ƒ˜ํ”Œ ์ „๋ฌธ
"""
for i, r in enumerate(results):
report += f"""### [{i+1}] {r['question']}
- ๋ฐ˜๋ณต๋ฅ : {r['repetition_rate']*100:.1f}%
- ์ข…๋ฃŒ: {"EOS" if r['stopped_eos'] else "max_tokens"}
- ํ† ํฐ ์ˆ˜: {r['num_tokens']}
```
{r['answer']}
```
"""
report += f"""## ๊ฐœ์„ ๋„ ๋ถ„์„
- Pretrain โ†’ SFT v2: {69.4 - avg_rep:.1f}%p ๊ฐœ์„ 
- SFT v1 (buggy) โ†’ SFT v2: {57.1 - avg_rep:.1f}%p ๊ฐœ์„ 
- SFT v1 (์˜ฌ๋ฐ”๋ฅธ ํฌ๋งท) โ†’ SFT v2: {17.7 - avg_rep:.1f}%p ๊ฐœ์„ 
## ๊ถŒ์žฅ ๋‹ค์Œ ๋‹จ๊ณ„
"""
if avg_rep < 5:
report += """- โœ… ๋ฐ˜๋ณต๋ฅ  ๋ชฉํ‘œ ๋‹ฌ์„ฑ - 1B SFT ๊ธฐ๋ณธ ์™„๋ฃŒ
- ORPO/DPO ์„ ํ˜ธ๋„ ํ•™์Šต์œผ๋กœ ์‘๋‹ต ํ’ˆ์งˆ ํ–ฅ์ƒ
- 3B ๋ชจ๋ธ๋กœ ์Šค์ผ€์ผ์—… ๊ณ ๋ ค
- ๋” ๋‹ค์–‘ํ•œ ๋ฒค์น˜๋งˆํฌ (KoBEST, KLUE ๋“ฑ) ํ‰๊ฐ€
"""
elif avg_rep < 15:
report += """- ๋ฐ˜๋ณต๋ฅ ์ด ์•„์ง ๋ชฉํ‘œ ๋ฏธ๋‹ฌ์ด์ง€๋งŒ ์ƒ๋‹นํžˆ ๊ฐœ์„ ๋จ
- ๋ฐ์ดํ„ฐ ๋‹ค์–‘์„ฑ ์ฆ๊ฐ€ (๋” ๋งŽ์€ SFT ๋ฐ์ดํ„ฐ)
- repetition_penalty ์กฐ์ • ์‹คํ—˜
- ORPO๋กœ ๋ฐ˜๋ณต ํŒจํ„ด ์ถ”๊ฐ€ ๊ต์ • ๊ฐ€๋Šฅ
"""
else:
report += """- ๋ฐ˜๋ณต๋ฅ ์ด ์—ฌ์ „ํžˆ ๋†’์Œ - ์ถ”๊ฐ€ SFT ํ•„์š”
- ํ•™์Šต ๋ฐ์ดํ„ฐ ํ’ˆ์งˆ ์ ๊ฒ€
- ํ•™์Šต๋ฅ /์—ํฌํฌ ์กฐ์ •
- ๋ฐ์ดํ„ฐ ์ฆ๊ฐ• ๊ณ ๋ ค
"""
with open(f"{OUTPUT_DIR}/report.md", "w") as f:
f.write(report)
# Save raw results
with open(f"{OUTPUT_DIR}/results.json", "w") as f:
json.dump({"results": results, "avg_rep": avg_rep, "eos_rate": eos_rate, "val_loss": avg_loss}, f, ensure_ascii=False, indent=2)
print(f"\n๋ณด๊ณ ์„œ ์ €์žฅ: {OUTPUT_DIR}/report.md")
# Run the evaluation only when this file is executed as a script, so that
# importing the module does not trigger it.
if __name__ == "__main__":
    main()