"""Rejection sampling (RS) from the C18-2 model (48.5% on GSM8K) to build SFT and DPO datasets."""
| import json, re, asyncio, random |
| from openai import AsyncOpenAI |
|
|
# Instruction prepended to every question (Korean). The previous literal was
# mojibake (UTF-8 Korean mis-decoded as TIS-620), which sent unreadable text to
# the model. In English:
#   "Solve the given math problem step by step and write your answer.
#    You must output the final answer on the last line in the form \boxed{integer}.
#    Example: \boxed{42}"
SP = "주어진 수학 문제를 단계별로 풀고 답변을 작성하세요.\n반드시 최종 답변은 \\boxed{정수} 형식으로 마지막 줄에 출력하세요.\n예시: \\boxed{42}"
|
|
def extract_boxed(text):
    r"""Return the stripped contents of the last non-empty ``\boxed{...}`` in *text*.

    Braces are matched by depth, so nested LaTeX such as ``\boxed{\frac{1}{2}}``
    or ``\boxed{1{,}000}`` is returned whole instead of being cut at the first
    ``}`` (which is what a ``[^}]+`` regex does). An unclosed ``\boxed{`` ends
    the scan; boxes found before it still count.

    Returns None when *text* is None/empty or contains no complete non-empty
    box. None can arrive here because chat-completion ``message.content`` is
    optional.
    """
    if not text:
        return None
    marker = "\\boxed{"
    last = None
    pos = text.find(marker)
    while pos != -1:
        start = pos + len(marker)
        depth = 1
        end = start
        while end < len(text):
            ch = text[end]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    break
            end += 1
        if depth:  # no matching "}" before the end of the text
            break
        inner = text[start:end]
        if inner:  # like the old regex, require at least one character inside
            last = inner.strip()
        pos = text.find(marker, end + 1)
    return last
|
|
def normalize(a):
    """Canonicalize an answer string so equal numbers compare equal as strings.

    Commas and spaces are removed, then the value is parsed:

    * integer strings go through ``int`` (exact, no float rounding for large
      values), e.g. ``"1,000"`` -> ``"1000"``;
    * other numeric strings go through ``float``; integral floats are printed
      as ints (``"42.0"`` -> ``"42"``), others via ``str(float)``
      (``"0.50"`` -> ``"0.5"``);
    * anything non-numeric (or ``inf``/``nan``) is returned as the cleaned
      string.

    Returns None when *a* is None.
    """
    if a is None:
        return None
    s = str(a).replace(",", "").replace(" ", "").strip()
    try:
        return str(int(s))
    except ValueError:
        pass
    try:
        n = float(s)
        # int(inf) raises OverflowError, int(nan) raises ValueError
        return str(int(n)) if n == int(n) else str(n)
    except (ValueError, OverflowError):
        return s
|
|
| |
# Source records; each must provide "question" and "answer" keys (read below).
# JSON is UTF-8, so the encoding is pinned rather than left to the locale.
with open("data/GSM8K_full_qwen3_30b.json", encoding="utf-8") as f:
    data = json.load(f)
|
|
| |
# Map each unique question to its normalized gold answer. The first record for
# a question whose "answer" contains a \boxed{...} wins; records without one
# are skipped, so a later duplicate may still supply the gold answer.
q_to_gold = {}
for record in data:
    question = record["question"]
    if question in q_to_gold:
        continue
    boxed_gold = extract_boxed(record["answer"])
    if boxed_gold:
        q_to_gold[question] = normalize(boxed_gold)
|
|
# Unique questions in a reproducible (seed 42) random order.
questions = [*q_to_gold]
random.seed(42)
random.shuffle(questions)
print(f"Total unique questions: {len(questions)}")
|
|
# Async client for the local OpenAI-compatible server that hosts the RS model.
client = AsyncOpenAI(
    api_key="token-abc123",
    base_url="http://localhost:8000/v1",
)
|
|
async def sample_one(question, n=16, *, temperature=0.8, max_tokens=2048):
    """Request *n* sampled completions for one question from the RS model.

    The question is sent as a single user message, prefixed by ``SP`` and a
    blank line.

    Args:
        question: Problem text.
        n: Number of completions requested in one call.
        temperature: Sampling temperature (default 0.8, as before).
        max_tokens: Completion token limit (default 2048, as before).

    Returns:
        List of completion texts. Choices whose ``message.content`` is None or
        empty are dropped (the SDK types content as optional and downstream
        code calls ``len`` on each answer), so the list may be shorter than
        *n*. On any request error the error is printed and ``[]`` is returned
        so a single failure does not abort the whole run.
    """
    messages = [{"role": "user", "content": SP + "\n\n" + question}]
    try:
        resp = await client.chat.completions.create(
            model="outputs/models/c18-2-combined-rs",
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            n=n,
        )
    except Exception as e:  # deliberate best-effort: skip this question only
        print(f" Error: {e}")
        return []
    return [c.message.content for c in resp.choices if c.message.content]
|
|
async def main():
    r"""Rejection-sample every question and write SFT and DPO datasets.

    Questions are processed in batches of 200, with at most 100 requests in
    flight. Each question gets ``sample_one``'s default number of
    completions. A completion counts as correct when the normalized last
    ``\boxed{}`` value equals the question's gold answer in ``q_to_gold``.

    Outputs (UTF-8 JSON under ``outputs/c18_rs/``):
      * ``sft_dataset.json``: one row per question with at least one correct
        completion. It holds the shortest correct completion plus
        ``n_correct`` / ``n_total``.
      * ``dpo_dataset.json``: one row per question with both correct and
        incorrect completions. ``answer`` is the shortest correct completion
        and ``bad_answer`` is the longest incorrect one.
    """
    import os

    sem = asyncio.Semaphore(100)  # max concurrent requests to the server
    sft_data = []
    dpo_data = []
    batch_size = 200  # questions per gather() call / progress line

    async def process(q):
        async with sem:
            return q, await sample_one(q)

    for batch_start in range(0, len(questions), batch_size):
        batch = questions[batch_start:batch_start + batch_size]
        results = await asyncio.gather(*[process(q) for q in batch])

        for q, answers in results:
            gold = q_to_gold[q]
            correct = []
            incorrect = []
            for a in answers:
                pred = normalize(extract_boxed(a))
                if pred and pred == gold:
                    correct.append(a)
                else:
                    incorrect.append(a)

            if not correct:
                continue
            best = min(correct, key=len)  # shortest correct completion
            sft_data.append({
                "question": q, "answer": best,
                "n_correct": len(correct), "n_total": len(answers),
            })

            if incorrect:
                # NOTE(review): the longest incorrect completion is plausibly
                # one cut off at max_tokens. Confirm that this is the intended
                # "rejected" sample for DPO.
                dpo_data.append({
                    "question": q,
                    "answer": best,
                    "bad_answer": max(incorrect, key=len),
                })

        print(f" Batch {batch_start//batch_size + 1}: {len(sft_data)} sft, {len(dpo_data)} dpo")

    os.makedirs("outputs/c18_rs", exist_ok=True)
    # ensure_ascii=False writes raw non-ASCII text, so pin the file encoding
    # instead of relying on the platform locale.
    with open("outputs/c18_rs/sft_dataset.json", "w", encoding="utf-8") as f:
        json.dump(sft_data, f, ensure_ascii=False, indent=2)
    with open("outputs/c18_rs/dpo_dataset.json", "w", encoding="utf-8") as f:
        json.dump(dpo_data, f, ensure_ascii=False, indent=2)

    n4 = sum(1 for d in sft_data if d["n_correct"] >= 4)
    print("\nRS Summary:")
    print(f" SFT: {len(sft_data)} (4+/16 filter: {n4})")
    print(f" DPO: {len(dpo_data)} pairs")
    if sft_data:
        avg_correct = sum(d["n_correct"] for d in sft_data) / len(sft_data)
        print(f" Avg correct: {avg_correct:.1f}/16")
    else:
        # Previously a ZeroDivisionError when no question had a correct answer.
        print(" Avg correct: n/a (no SFT rows)")


asyncio.run(main())
|
|