"""RS sampling from C18-2 model (48.5% GSM8K)""" import json, re, asyncio, random from openai import AsyncOpenAI SP = "주어진 수학 문제를 단계별로 풀고 답변을 작성하세요.\n반드시 최종 답변을 \\boxed{정수} 형식으로 마지막 줄에 출력하세요.\n예시: \\boxed{42}" def extract_boxed(text): m = re.findall(r'\\boxed\{([^}]+)\}', text) return m[-1].strip() if m else None def normalize(a): if a is None: return None s = str(a).replace(",","").replace(" ","").strip() try: n = float(s) return str(int(n)) if n == int(n) else str(n) except: return s # Load GSM8K train questions (the ones we have gold answers for) with open("data/GSM8K_full_qwen3_30b.json") as f: data = json.load(f) # Get unique questions with their gold answers q_to_gold = {} for d in data: q = d["question"] if q not in q_to_gold: # Extract gold from the existing correct solutions gold = extract_boxed(d["answer"]) if gold: q_to_gold[q] = normalize(gold) questions = list(q_to_gold.keys()) random.seed(42) random.shuffle(questions) # Sample a subset for RS (use all unique questions) print(f"Total unique questions: {len(questions)}") client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="token-abc123") async def sample_one(question, n=16): messages = [{"role": "user", "content": SP + "\n\n" + question}] try: resp = await client.chat.completions.create( model="outputs/models/c18-2-combined-rs", messages=messages, temperature=0.8, max_tokens=2048, n=n ) return [c.message.content for c in resp.choices] except Exception as e: print(f" Error: {e}") return [] async def main(): sem = asyncio.Semaphore(100) sft_data = [] dpo_data = [] batch_size = 200 for batch_start in range(0, len(questions), batch_size): batch = questions[batch_start:batch_start+batch_size] async def process(q): async with sem: return q, await sample_one(q) results = await asyncio.gather(*[process(q) for q in batch]) for q, answers in results: gold = q_to_gold[q] correct = [] incorrect = [] for a in answers: pred = normalize(extract_boxed(a)) if pred and pred == gold: correct.append(a) else: incorrect.append(a) if correct: best = min(correct, key=len) # shortest correct sft_data.append({ "question": q, "answer": best, "n_correct": len(correct), "n_total": len(answers) }) if correct and incorrect: dpo_data.append({ "question": q, "answer": min(correct, key=len), "bad_answer": max(incorrect, key=len) }) print(f" Batch {batch_start//batch_size + 1}: {len(sft_data)} sft, {len(dpo_data)} dpo") import os os.makedirs("outputs/c18_rs", exist_ok=True) with open("outputs/c18_rs/sft_dataset.json", "w") as f: json.dump(sft_data, f, ensure_ascii=False, indent=2) with open("outputs/c18_rs/dpo_dataset.json", "w") as f: json.dump(dpo_data, f, ensure_ascii=False, indent=2) n4 = sum(1 for d in sft_data if d["n_correct"] >= 4) print(f"\nRS Summary:") print(f" SFT: {len(sft_data)} (4+/16 filter: {n4})") print(f" DPO: {len(dpo_data)} pairs") print(f" Avg correct: {sum(d['n_correct'] for d in sft_data)/len(sft_data):.1f}/16") asyncio.run(main())