File size: 9,908 Bytes
15fa009
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
"""

vLLM Batch Inference for Variance-based One-Shot Question Selection

====================================================================

用 vLLM 做批量推理,比 HuggingFace generate 快 5-10x。

对 SFT 模型在 1533 道开放题上各跑 10 次推理,

用精确字符串匹配判断对错,选出方差最大的题目。



用法 (在 Docker 容器中):

  CUDA_VISIBLE_DEVICES=0 python3 -u variance_select_vllm.py

"""

import json
import re
import os
import numpy as np
from pathlib import Path

# ============ Configuration ============
# All paths are absolute under /workspace (the module docstring says the script
# is run inside a Docker container); adjust them when running elsewhere.
# Model directory passed to both vLLM's LLM(...) and AutoProcessor.from_pretrained.
MODEL_PATH = "/workspace/rl4phyx/RL4Phyx/SFT/checkpoints/sft_qwen25vl_3b/merged"
# JSONL test set; each record is read for 'prompt', 'index', 'ground_truth'
# and optionally 'category'.
TEST_FILE = "/workspace/rl4phyx/RL4Phyx/SFT/sft_test.jsonl"
# Images are looked up as IMAGE_DIR/<index>.png; a missing file means text-only.
IMAGE_DIR = "/workspace/rl4phyx/RL4Phyx/SFT/images"
# Destination for per-run JSON, variance_results.json,
# best_question_for_rlvr.json and rlvr_train.parquet.
OUTPUT_DIR = "/workspace/rl4phyx/RL4Phyx/SFT"

NUM_RUNS = 10          # sampled inference passes per question
# NOTE(review): main() sets max_model_len=4096, so prompt + image tokens +
# generated tokens must share 4096; this cap may not be reachable for long prompts.
MAX_NEW_TOKENS = 3072  # max generated tokens per sample (SamplingParams.max_tokens)
TEMPERATURE = 0.7      # sampling temperature; nonzero so repeated runs can disagree
TOP_P = 0.9            # nucleus-sampling threshold
RLVR_COPIES = 128      # identical rows of the selected question in the one-shot RLVR parquet


def extract_boxed(text):
    """Return the stripped content of the last ``\\boxed{...}`` in *text*.

    Braces are matched by depth counting, so arbitrarily nested LaTeX such as
    ``\\boxed{\\frac{\\sqrt{3}}{2}}`` is extracted whole; a regular expression
    can only support a fixed nesting depth. Matches are non-overlapping
    (scanning resumes after a closed box), and a ``\\boxed{`` whose braces
    never close, e.g. in a truncated generation, is skipped.

    Args:
        text: raw model output. ``None`` or ``""`` yields ``None``.

    Returns:
        The stripped content of the last complete ``\\boxed{}``, or ``None``
        when there is none.
    """
    if not text:
        return None
    marker = "\\boxed{"
    last = None
    pos = text.find(marker)
    while pos != -1:
        start = pos + len(marker)
        depth = 1
        cursor = start
        while cursor < len(text) and depth:
            ch = text[cursor]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
            cursor += 1
        if depth == 0:
            # cursor is one past the closing brace of this box.
            last = text[start:cursor - 1]
            pos = text.find(marker, cursor)
        else:
            # Unterminated box: a later \boxed{ may still be complete.
            pos = text.find(marker, pos + 1)
    return last.strip() if last is not None else None


def normalize_answer(ans):
    """Lightly normalize an answer for exact-match comparison.

    The steps are:

    1. Unwrap ``\\text{...}``.
    2. Collapse every whitespace run to a single space.
    3. Strip leading whitespace, then trailing periods and spaces.

    Stripping happens last, so ``"5 m ."`` becomes ``"5 m"`` and
    ``"\\text{ kg}"`` becomes ``"kg"`` rather than keeping stray spaces.

    Args:
        ans: answer value. Non-string values (e.g. a numeric ground truth
            read from JSON) are converted with ``str`` first.

    Returns:
        The normalized string, or ``None`` if *ans* is ``None``.
    """
    if ans is None:
        return None
    ans = str(ans)
    ans = re.sub(r'\\text\{([^}]*)\}', r'\1', ans)
    ans = re.sub(r'\s+', ' ', ans)
    return ans.strip().rstrip('. ')


def exact_match(pred_answer, gt_answer):
    """Decide whether a prediction equals the ground truth after normalization.

    Both sides go through ``normalize_answer``. A missing prediction, or any
    side that normalizes to ``None``, counts as a mismatch.

    Returns:
        ``True`` only when the two normalized answers are equal strings.
    """
    if pred_answer is None:
        return False
    normalized = (normalize_answer(pred_answer), normalize_answer(gt_answer))
    if any(value is None for value in normalized):
        return False
    predicted, expected = normalized
    return predicted == expected


def load_test_data(test_file):
    """Load a JSON Lines file into a list of parsed records.

    Blank or whitespace-only lines are skipped. Without this, a stray empty
    line (e.g. a doubled trailing newline) aborts the run with
    ``json.JSONDecodeError``.

    Args:
        test_file: path to a UTF-8 ``.jsonl`` file with one JSON value per line.

    Returns:
        A list of the parsed records, in file order.
    """
    with open(test_file, 'r', encoding='utf-8') as f:
        data = [json.loads(line) for line in f if line.strip()]
    print(f"Loaded {len(data)} test samples")
    return data


def build_vllm_inputs(test_data, processor):
    """Render chat prompts and collect image payloads for vLLM batch inference.

    For each record, an image at ``IMAGE_DIR/<index>.png`` is attached when the
    file exists. The user message holds the image part (if any) followed by
    the text part. It is rendered with ``processor.apply_chat_template``
    (untokenized, with the generation prompt appended).

    Args:
        test_data: records providing at least ``'prompt'`` and ``'index'``.
        processor: object exposing ``apply_chat_template`` (the model's
            processor loaded in ``main``).

    Returns:
        ``(prompts, multi_modal_data_list)``, two parallel lists. Each entry
        of the second list is ``{"image": <RGB PIL image>}``, or ``None`` when
        the record has no image file.
    """
    from PIL import Image

    total = len(test_data)
    prompts, multi_modal_data_list = [], []

    for count, item in enumerate(test_data, start=1):
        prompt_text = item['prompt']
        image_path = os.path.join(IMAGE_DIR, f"{item['index']}.png")
        has_image = os.path.exists(image_path)

        # Image part first (when present), then the question text.
        content = []
        if has_image:
            content.append({"type": "image", "image": f"file://{image_path}"})
        content.append({"type": "text", "text": prompt_text})
        messages = [{"role": "user", "content": content}]

        rendered = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        if has_image:
            payload = {"image": Image.open(image_path).convert("RGB")}
        else:
            payload = None

        prompts.append(rendered)
        multi_modal_data_list.append(payload)

        if count % 200 == 0:
            print(f"  Prepared {count}/{total} inputs...")

    return prompts, multi_modal_data_list


def run_vllm_batch(llm, prompts, multi_modal_data_list, sampling_params, run_id):
    """Run one batched generation pass over every prompt with vLLM.

    Args:
        llm: the ``vllm.LLM`` engine.
        prompts: chat-formatted prompt strings.
        multi_modal_data_list: per-prompt multimodal payload (a dict) or
            ``None``. It must be parallel to *prompts*.
        sampling_params: ``vllm.SamplingParams`` applied to every request.
        run_id: zero-based pass number, used only for the progress banner.

    Returns:
        The list returned by ``llm.generate``. ``main`` indexes it in
        parallel with the inputs, relying on vLLM returning outputs in input
        order.

    Raises:
        ValueError: if *prompts* and *multi_modal_data_list* differ in length.
            ``zip`` would otherwise silently drop requests and misalign
            results.
    """
    if len(prompts) != len(multi_modal_data_list):
        raise ValueError(
            f"{len(prompts)} prompts but {len(multi_modal_data_list)} "
            "multimodal entries"
        )

    print(f"\n{'='*60}")
    print(f"  Run {run_id + 1}/{NUM_RUNS}")
    print(f"{'='*60}")

    # Dict-style prompts; the multimodal key is present only when there is an image.
    requests = []
    for prompt, mm_data in zip(prompts, multi_modal_data_list):
        req = {"prompt": prompt}
        if mm_data:
            req["multi_modal_data"] = mm_data
        requests.append(req)

    print(f"  Submitting {len(requests)} requests to vLLM...")
    outputs = llm.generate(requests, sampling_params)
    print(f"  Got {len(outputs)} outputs")

    return outputs


def _dump_json(obj, path):
    """Write *obj* to *path* as indented UTF-8 JSON, keeping non-ASCII text."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)


def _score_run(outputs, test_data):
    """Grade one pass by exact-matching each sample's last boxed answer to its ground truth.

    Raises:
        RuntimeError: if the number of outputs differs from the number of
            questions, since results would otherwise be misaligned.
    """
    if len(outputs) != len(test_data):
        raise RuntimeError(
            f"vLLM returned {len(outputs)} outputs for {len(test_data)} questions"
        )
    run_results = []
    for item, output in zip(test_data, outputs):
        # Only the first completion of each request is scored.
        pred_answer = extract_boxed(output.outputs[0].text)
        gt_answer = item['ground_truth']
        run_results.append({
            'index': item['index'],
            'pred_answer': pred_answer,
            'gt_answer': gt_answer,
            'correct': exact_match(pred_answer, gt_answer),
        })
    return run_results


def _variance_stats(test_data, all_runs):
    """Build per-question correctness stats across all runs, in test-file order.

    ``variance`` is the Bernoulli variance p * (1 - p) of the per-run
    correctness, where p is the fraction of runs answered correctly.
    """
    n_runs = len(all_runs)
    stats = []
    for qi, item in enumerate(test_data):
        correct_flags = [run[qi]['correct'] for run in all_runs]
        n_correct = sum(correct_flags)
        p = n_correct / n_runs
        stats.append({
            'index': item['index'],
            'category': item.get('category', ''),
            'ground_truth': item['ground_truth'],
            'n_correct': n_correct,
            'accuracy': p,
            'variance': p * (1 - p),
            'correct_flags': correct_flags,
            'pred_answers': [run[qi]['pred_answer'] for run in all_runs],
        })
    return stats


def _write_rlvr_parquet(item, path, copies):
    """Write *copies* identical one-shot RLVR training rows for *item* and return the row count."""
    import pandas as pd

    prompt_messages = [{"role": "user", "content": item['prompt']}]
    records = [{
        'prompt': prompt_messages,
        'answer': item['ground_truth'],
        # NOTE(review): written even when no image file exists for this index.
        'image_path': f"{item['index']}.png",
        # NOTE(review): presumably routes the reward function downstream — confirm.
        'data_source': 'deepscaler',
        'category': item.get('category', 'Physics'),
        'index': item['index'],
    } for _ in range(copies)]

    df = pd.DataFrame(records)
    df.to_parquet(path, index=False)
    return len(df)


def main():
    """Pick the highest-variance question and export it for one-shot RLVR.

    The pipeline has five steps:

    1. Load the model in vLLM, plus its processor (used for the chat template).
    2. Load ``TEST_FILE``.
    3. Build the chat prompts and image payloads.
    4. Sample every question ``NUM_RUNS`` times and exact-match the last
       boxed answer. Each run is saved to ``OUTPUT_DIR/variance_run_<k>.json``.
    5. Compute the per-question variance p * (1 - p) and save all stats
       sorted by variance (descending). The top question goes to
       ``best_question_for_rlvr.json``, and ``RLVR_COPIES`` copies of it go
       to ``rlvr_train.parquet``.
    """
    print("=" * 60)
    print("  vLLM Batch Variance Selection")
    print("=" * 60)

    # 1. Load the vLLM engine and the processor.
    print("\n[1/5] Loading vLLM model...")
    from vllm import LLM, SamplingParams
    from transformers import AutoProcessor

    llm = LLM(
        model=MODEL_PATH,
        dtype="bfloat16",
        trust_remote_code=True,
        max_model_len=4096,
        gpu_memory_utilization=0.85,
        limit_mm_per_prompt={"image": 1},
    )

    sampling_params = SamplingParams(
        temperature=TEMPERATURE,
        top_p=TOP_P,
        max_tokens=MAX_NEW_TOKENS,
    )

    processor = AutoProcessor.from_pretrained(MODEL_PATH)
    print("  Model loaded!")

    # 2. Load the test set.
    print("\n[2/5] Loading test data...")
    test_data = load_test_data(TEST_FILE)
    if not test_data:
        # Accuracy percentages and stats[0] below are undefined without data.
        print("  No test samples found; nothing to select.")
        return

    # 3. Pre-build prompts and image payloads once; they are reused by every run.
    print("\n[3/5] Building vLLM inputs...")
    prompts, multi_modal_data_list = build_vllm_inputs(test_data, processor)
    print(f"  Built {len(prompts)} prompts")

    # 4. NUM_RUNS sampled passes over the whole test set.
    print(f"\n[4/5] Running {NUM_RUNS} inference passes...")
    all_runs = []

    for run_id in range(NUM_RUNS):
        outputs = run_vllm_batch(llm, prompts, multi_modal_data_list, sampling_params, run_id)
        run_results = _score_run(outputs, test_data)
        all_runs.append(run_results)

        correct_count = sum(r['correct'] for r in run_results)
        print(f"  Run {run_id+1} accuracy: {correct_count}/{len(test_data)} "
              f"({100*correct_count/len(test_data):.1f}%)")

        # Save after every run so partial progress survives a crash.
        interim_path = os.path.join(OUTPUT_DIR, f"variance_run_{run_id}.json")
        _dump_json(run_results, interim_path)
        print(f"  Saved to {interim_path}")

    # 5. Variance per question, then selection.
    print("\n[5/5] Computing variance...")
    unsorted_stats = _variance_stats(test_data, all_runs)
    # sorted(..., reverse=True) is stable, so ties keep test-file order. Keeping
    # positions lets us fetch the selected record directly instead of
    # re-searching test_data by int(index).
    order = sorted(range(len(unsorted_stats)),
                   key=lambda qi: unsorted_stats[qi]['variance'], reverse=True)
    stats = [unsorted_stats[qi] for qi in order]

    stats_path = os.path.join(OUTPUT_DIR, "variance_results.json")
    _dump_json(stats, stats_path)
    print(f"  Saved variance stats to {stats_path}")

    print(f"\n{'='*60}")
    print("  TOP 20 HIGHEST VARIANCE QUESTIONS")
    print(f"{'='*60}")
    for i, s in enumerate(stats[:20]):
        # str(): ground truth may be non-string (e.g. a number) in the JSONL.
        print(f"  #{i+1}: idx={s['index']} | gt={str(s['ground_truth'])[:30]:30s} | "
              f"correct={s['n_correct']}/{NUM_RUNS} | var={s['variance']:.4f} | "
              f"cat={s['category']}")

    best = stats[0]
    best_item = test_data[order[0]]
    print(f"\n{'='*60}")
    print("  SELECTED QUESTION FOR ONE-SHOT RLVR")
    print(f"{'='*60}")
    print(f"  Index:       {best['index']}")
    print(f"  Category:    {best['category']}")
    print(f"  Ground Truth: {best['ground_truth']}")
    print(f"  Accuracy:    {best['n_correct']}/{NUM_RUNS} ({best['accuracy']*100:.0f}%)")
    print(f"  Variance:    {best['variance']:.4f}")
    if best['variance'] == 0:
        print("  WARNING: no question had mixed outcomes across runs "
              "(all variances are 0); this pick carries no variance signal.")

    best_path = os.path.join(OUTPUT_DIR, "best_question_for_rlvr.json")
    _dump_json({'selected_question': best_item, 'stats': best}, best_path)
    print(f"  Saved to {best_path}")

    parquet_path = os.path.join(OUTPUT_DIR, "rlvr_train.parquet")
    n_rows = _write_rlvr_parquet(best_item, parquet_path, RLVR_COPIES)
    print(f"  Saved training parquet ({n_rows} rows) to {parquet_path}")

    print(f"\n{'='*60}")
    print("  DONE!")
    print(f"{'='*60}")


if __name__ == "__main__":
    main()