File size: 3,715 Bytes
24e2849
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""RS sampling from C18-2 model (48.5% GSM8K)"""
import json, re, asyncio, random
from openai import AsyncOpenAI

SP = "์ฃผ์–ด์ง„ ์ˆ˜ํ•™ ๋ฌธ์ œ๋ฅผ ๋‹จ๊ณ„๋ณ„๋กœ ํ’€๊ณ  ๋‹ต๋ณ€์„ ์ž‘์„ฑํ•˜์„ธ์š”.\n๋ฐ˜๋“œ์‹œ ์ตœ์ข… ๋‹ต๋ณ€์„ \\boxed{์ •์ˆ˜} ํ˜•์‹์œผ๋กœ ๋งˆ์ง€๋ง‰ ์ค„์— ์ถœ๋ ฅํ•˜์„ธ์š”.\n์˜ˆ์‹œ: \\boxed{42}"

def extract_boxed(text):
    """Return the content of the LAST ``\\boxed{...}`` in *text*, or None.

    Scans with balanced-brace matching instead of a naive ``[^}]+`` regex,
    so nested answers like ``\\boxed{\\frac{1}{2}}`` are extracted whole
    rather than truncated at the first ``}``.
    """
    found = []
    for m in re.finditer(r'\\boxed\{', text):
        depth = 1
        i = m.end()
        start = i
        # Walk forward until the opening brace is balanced.
        while i < len(text) and depth:
            if text[i] == '{':
                depth += 1
            elif text[i] == '}':
                depth -= 1
            i += 1
        if depth == 0:
            found.append(text[start:i - 1])
    return found[-1].strip() if found else None

def normalize(a):
    """Canonicalize an answer for string comparison.

    Strips commas and spaces, then collapses integer-valued floats
    (``"42.0"`` -> ``"42"``). Non-numeric strings are returned cleaned but
    otherwise unchanged; ``None`` passes through as ``None``.
    """
    if a is None:
        return None
    s = str(a).replace(",", "").replace(" ", "").strip()
    try:
        n = float(s)
    except ValueError:
        # Not numeric (e.g. a symbolic answer) -- compare as cleaned text.
        return s
    try:
        return str(int(n)) if n == int(n) else str(n)
    except (ValueError, OverflowError):
        # nan/inf cannot be cast to int; fall back to the cleaned string.
        return s

# Load the GSM8K training set (the questions we hold gold answers for).
with open("data/GSM8K_full_qwen3_30b.json") as f:
    data = json.load(f)

# Map each unique question to its normalized gold answer, taken from the
# first solution encountered for that question.
q_to_gold = {}
for record in data:
    question = record["question"]
    if question in q_to_gold:
        continue
    boxed = extract_boxed(record["answer"])
    if boxed:
        q_to_gold[question] = normalize(boxed)

questions = list(q_to_gold)
random.seed(42)  # fixed seed: deterministic shuffle across runs
random.shuffle(questions)
# All unique questions participate in rejection sampling.
print(f"Total unique questions: {len(questions)}")

# Locally served OpenAI-compatible endpoint (e.g. vLLM).
client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="token-abc123")

async def sample_one(question, n=16):
    """Draw *n* sampled completions for one question; return their texts.

    Best-effort: any API failure is printed and yields an empty list so a
    single bad request does not abort the whole run.
    """
    prompt = f"{SP}\n\n{question}"
    try:
        response = await client.chat.completions.create(
            model="outputs/models/c18-2-combined-rs",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.8,
            max_tokens=2048,
            n=n,
        )
        return [choice.message.content for choice in response.choices]
    except Exception as e:
        print(f"  Error: {e}")
        return []

async def main():
    """Run rejection sampling over all questions; write SFT/DPO datasets.

    For each question: keep the shortest correct completion as an SFT
    example, and (when both exist) pair the shortest correct with the
    longest incorrect completion as a DPO preference pair.
    """
    import os

    sem = asyncio.Semaphore(100)  # cap concurrent in-flight requests
    sft_data = []
    dpo_data = []
    batch_size = 200

    for batch_start in range(0, len(questions), batch_size):
        batch = questions[batch_start:batch_start + batch_size]

        async def process(q):
            async with sem:
                return q, await sample_one(q)

        results = await asyncio.gather(*[process(q) for q in batch])

        for q, answers in results:
            gold = q_to_gold[q]
            correct = []
            incorrect = []
            for a in answers:
                pred = normalize(extract_boxed(a))
                if pred is not None and pred == gold:
                    correct.append(a)
                else:
                    incorrect.append(a)

            if correct:
                best = min(correct, key=len)  # shortest correct solution
                sft_data.append({
                    "question": q, "answer": best,
                    "n_correct": len(correct), "n_total": len(answers)
                })

            if correct and incorrect:
                dpo_data.append({
                    "question": q,
                    "answer": min(correct, key=len),
                    "bad_answer": max(incorrect, key=len)
                })

        print(f"  Batch {batch_start//batch_size + 1}: {len(sft_data)} sft, {len(dpo_data)} dpo")

    os.makedirs("outputs/c18_rs", exist_ok=True)
    with open("outputs/c18_rs/sft_dataset.json", "w") as f:
        json.dump(sft_data, f, ensure_ascii=False, indent=2)
    with open("outputs/c18_rs/dpo_dataset.json", "w") as f:
        json.dump(dpo_data, f, ensure_ascii=False, indent=2)

    n4 = sum(1 for d in sft_data if d["n_correct"] >= 4)
    print(f"\nRS Summary:")
    print(f"  SFT: {len(sft_data)} (4+/16 filter: {n4})")
    print(f"  DPO: {len(dpo_data)} pairs")
    # Guard against ZeroDivisionError when no question yielded a correct sample.
    if sft_data:
        print(f"  Avg correct: {sum(d['n_correct'] for d in sft_data)/len(sft_data):.1f}/16")

asyncio.run(main())