YUNTA88 commited on
Commit
d491800
·
verified ·
1 Parent(s): 3eee49d

Upload scripts/variance_select.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/variance_select.py +312 -0
scripts/variance_select.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Variance-based One-Shot Question Selection
3
+ ===========================================
4
+ 对 SFT 模型在 1533 道开放题上各跑 10 次推理,
5
+ 用精确字符串匹配判断对错,选出方差最大的题目。
6
+
7
+ 用法:
8
+ docker exec rl4phyx_env python3 /workspace/rl4phyx/RL4Phyx/SFT/variance_select.py
9
+
10
+ 输出:
11
+ - variance_results.json: 每道题的正确率和方差
12
+ - best_question_for_rlvr.json: 方差最大的题目信息
13
+ - rlvr_train.parquet: 转好的训练数据 (1题 × 128行)
14
+ """
15
+
16
+ import json
17
+ import re
18
+ import os
19
+ import torch
20
+ import numpy as np
21
+ from pathlib import Path
22
+
23
+ # ============ 配置 ============
24
+ MODEL_PATH = "/workspace/rl4phyx/RL4Phyx/SFT/checkpoints/sft_qwen25vl_3b/merged"
25
+ TEST_FILE = "/workspace/rl4phyx/RL4Phyx/SFT/sft_test.jsonl"
26
+ IMAGE_DIR = "/workspace/rl4phyx/RL4Phyx/SFT/images"
27
+ OUTPUT_DIR = "/workspace/rl4phyx/RL4Phyx/SFT"
28
+
29
+ NUM_RUNS = 10 # 每题推理次数
30
+ MAX_NEW_TOKENS = 1024 # 最大生成长度 (缩短加速,只需提取boxed答案)
31
+ TEMPERATURE = 0.7 # 采样温度
32
+ BATCH_SIZE = 4 # 推理 batch size (根据显存调)
33
+ RLVR_COPIES = 128 # one-shot 复制次数
34
+
35
+
36
+ def extract_boxed(text):
37
+ """从模型输出中提取 \\boxed{} 内的答案"""
38
+ # 找最后一个 \boxed{}
39
+ matches = re.findall(r'\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}', text)
40
+ if matches:
41
+ return matches[-1].strip()
42
+ return None
43
+
44
+
45
+ def normalize_answer(ans):
46
+ """轻度归一化:去空格、去末尾句号"""
47
+ if ans is None:
48
+ return None
49
+ ans = ans.strip()
50
+ ans = ans.rstrip('.')
51
+ # 去掉 \text{} 包裹
52
+ ans = re.sub(r'\\text\{([^}]*)\}', r'\1', ans)
53
+ # 去多余空格
54
+ ans = re.sub(r'\s+', ' ', ans)
55
+ return ans
56
+
57
+
58
+ def exact_match(pred_answer, gt_answer):
59
+ """精确字符串匹配"""
60
+ if pred_answer is None:
61
+ return False
62
+ pred_norm = normalize_answer(pred_answer)
63
+ gt_norm = normalize_answer(gt_answer)
64
+ if pred_norm is None or gt_norm is None:
65
+ return False
66
+ return pred_norm == gt_norm
67
+
68
+
69
+ def load_test_data(test_file):
70
+ """加载测试数据"""
71
+ data = []
72
+ with open(test_file, 'r', encoding='utf-8') as f:
73
+ for line in f:
74
+ r = json.loads(line.strip())
75
+ data.append(r)
76
+ print(f"Loaded {len(data)} test samples")
77
+ return data
78
+
79
+
80
+ def run_inference(model, processor, test_data, run_id):
81
+ """对所有题目跑一次推理"""
82
+ print(f"\n{'='*60}")
83
+ print(f" Run {run_id + 1}/{NUM_RUNS}")
84
+ print(f"{'='*60}")
85
+
86
+ results = []
87
+
88
+ for i, item in enumerate(test_data):
89
+ if i % 10 == 0:
90
+ print(f" Processing {i}/{len(test_data)}...", flush=True)
91
+
92
+ prompt_text = item['prompt']
93
+ image_path = os.path.join(IMAGE_DIR, f"{item['index']}.png")
94
+
95
+ # 构建消息
96
+ messages = [{"role": "user", "content": []}]
97
+
98
+ # 添加图片(如果存在)
99
+ if os.path.exists(image_path):
100
+ messages[0]["content"].append({
101
+ "type": "image",
102
+ "image": f"file://{image_path}"
103
+ })
104
+
105
+ messages[0]["content"].append({
106
+ "type": "text",
107
+ "text": prompt_text
108
+ })
109
+
110
+ # 处理输入
111
+ text = processor.apply_chat_template(
112
+ messages, tokenize=False, add_generation_prompt=True
113
+ )
114
+
115
+ from qwen_vl_utils import process_vision_info
116
+ image_inputs, video_inputs = process_vision_info(messages)
117
+
118
+ inputs = processor(
119
+ text=[text],
120
+ images=image_inputs,
121
+ videos=video_inputs,
122
+ padding=True,
123
+ return_tensors="pt"
124
+ ).to(model.device)
125
+
126
+ # 生成
127
+ with torch.no_grad():
128
+ output_ids = model.generate(
129
+ **inputs,
130
+ max_new_tokens=MAX_NEW_TOKENS,
131
+ temperature=TEMPERATURE,
132
+ do_sample=True,
133
+ top_p=0.9,
134
+ )
135
+
136
+ # 解码 (只取生成部分)
137
+ input_len = inputs['input_ids'].shape[1]
138
+ generated = output_ids[0][input_len:]
139
+ prediction = processor.decode(generated, skip_special_tokens=True)
140
+
141
+ # 提取答案
142
+ pred_answer = extract_boxed(prediction)
143
+ gt_answer = item['ground_truth']
144
+ is_correct = exact_match(pred_answer, gt_answer)
145
+
146
+ results.append({
147
+ 'index': item['index'],
148
+ 'pred_answer': pred_answer,
149
+ 'gt_answer': gt_answer,
150
+ 'correct': is_correct,
151
+ })
152
+
153
+ correct_count = sum(1 for r in results if r['correct'])
154
+ print(f" Run {run_id + 1} accuracy: {correct_count}/{len(results)} "
155
+ f"({100*correct_count/len(results):.1f}%)")
156
+
157
+ return results
158
+
159
+
160
+ def compute_variance(all_runs, test_data):
161
+ """计算每道题的正确率方差"""
162
+ n_questions = len(test_data)
163
+ stats = []
164
+
165
+ for qi in range(n_questions):
166
+ # 收集这道题在 10 次 run 中的对错情况
167
+ correct_flags = [all_runs[run_id][qi]['correct'] for run_id in range(NUM_RUNS)]
168
+ n_correct = sum(correct_flags)
169
+ p = n_correct / NUM_RUNS # 正确率
170
+ variance = p * (1 - p) # 伯努利方差
171
+
172
+ stats.append({
173
+ 'index': test_data[qi]['index'],
174
+ 'category': test_data[qi].get('category', ''),
175
+ 'ground_truth': test_data[qi]['ground_truth'],
176
+ 'n_correct': n_correct,
177
+ 'accuracy': p,
178
+ 'variance': variance,
179
+ 'correct_flags': correct_flags,
180
+ 'pred_answers': [all_runs[run_id][qi]['pred_answer'] for run_id in range(NUM_RUNS)],
181
+ })
182
+
183
+ # 按方差降序排列
184
+ stats.sort(key=lambda x: x['variance'], reverse=True)
185
+
186
+ return stats
187
+
188
+
189
+ def convert_to_training_format(question_item, copies=RLVR_COPIES):
190
+ """将选中的题目转成 RLVR 训练 parquet 格式"""
191
+ import pandas as pd
192
+
193
+ prompt_text = question_item['prompt']
194
+ image_path = f"{question_item['index']}.png"
195
+
196
+ # 构建 veRL 格式的 prompt
197
+ prompt_messages = [{"role": "user", "content": prompt_text}]
198
+
199
+ records = []
200
+ for _ in range(copies):
201
+ records.append({
202
+ 'prompt': prompt_messages,
203
+ 'answer': question_item['ground_truth'],
204
+ 'image_path': image_path,
205
+ 'data_source': 'deepscaler',
206
+ 'category': question_item.get('category', 'Physics'),
207
+ 'index': question_item['index'],
208
+ })
209
+
210
+ df = pd.DataFrame(records)
211
+ return df
212
+
213
+
214
+ def main():
215
+ print("=" * 60)
216
+ print(" Variance-based One-Shot Question Selection")
217
+ print("=" * 60)
218
+
219
+ # 1. 加载模型
220
+ print("\n[1/4] Loading SFT model...")
221
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
222
+
223
+ import os
224
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
225
+
226
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
227
+ MODEL_PATH,
228
+ torch_dtype=torch.bfloat16,
229
+ ).to("cuda")
230
+ processor = AutoProcessor.from_pretrained(MODEL_PATH)
231
+ model.eval()
232
+ print(f" Model loaded: {sum(p.numel() for p in model.parameters())/1e6:.0f}M params")
233
+
234
+ # 2. 加载测试数据
235
+ print("\n[2/4] Loading test data...")
236
+ test_data = load_test_data(TEST_FILE)
237
+
238
+ # 3. 跑 10 次推理
239
+ print(f"\n[3/4] Running {NUM_RUNS} inference passes on {len(test_data)} questions...")
240
+ all_runs = []
241
+ for run_id in range(NUM_RUNS):
242
+ run_results = run_inference(model, processor, test_data, run_id)
243
+ all_runs.append(run_results)
244
+
245
+ # 每次 run 后保存中间结果
246
+ interim_path = os.path.join(OUTPUT_DIR, f"variance_run_{run_id}.json")
247
+ with open(interim_path, 'w', encoding='utf-8') as f:
248
+ json.dump(run_results, f, ensure_ascii=False, indent=2)
249
+ print(f" Saved interim results to {interim_path}")
250
+
251
+ # 4. 计算方差并选题
252
+ print(f"\n[4/4] Computing variance and selecting best question...")
253
+ stats = compute_variance(all_runs, test_data)
254
+
255
+ # 保存完整统计
256
+ stats_path = os.path.join(OUTPUT_DIR, "variance_results.json")
257
+ with open(stats_path, 'w', encoding='utf-8') as f:
258
+ json.dump(stats, f, ensure_ascii=False, indent=2)
259
+ print(f" Saved all variance stats to {stats_path}")
260
+
261
+ # 打印 Top 20 最高方差题目
262
+ print(f"\n{'='*60}")
263
+ print(f" TOP 20 HIGHEST VARIANCE QUESTIONS")
264
+ print(f"{'='*60}")
265
+ for i, s in enumerate(stats[:20]):
266
+ print(f" #{i+1}: idx={s['index']} | gt={s['ground_truth'][:30]:30s} | "
267
+ f"correct={s['n_correct']}/{NUM_RUNS} | var={s['variance']:.4f} | "
268
+ f"cat={s['category']}")
269
+ print(f" preds: {s['pred_answers'][:5]}")
270
+
271
+ # 选方差最大的题
272
+ best = stats[0]
273
+ print(f"\n{'='*60}")
274
+ print(f" SELECTED QUESTION FOR ONE-SHOT RLVR")
275
+ print(f"{'='*60}")
276
+ print(f" Index: {best['index']}")
277
+ print(f" Category: {best['category']}")
278
+ print(f" Ground Truth: {best['ground_truth']}")
279
+ print(f" Accuracy: {best['n_correct']}/{NUM_RUNS} ({best['accuracy']*100:.0f}%)")
280
+ print(f" Variance: {best['variance']:.4f}")
281
+ print(f" Pred Answers: {best['pred_answers']}")
282
+
283
+ # 保存选中题目
284
+ best_idx = int(best['index'])
285
+ best_item = None
286
+ for item in test_data:
287
+ if int(item['index']) == best_idx:
288
+ best_item = item
289
+ break
290
+
291
+ best_path = os.path.join(OUTPUT_DIR, "best_question_for_rlvr.json")
292
+ with open(best_path, 'w', encoding='utf-8') as f:
293
+ json.dump({
294
+ 'selected_question': best_item,
295
+ 'stats': best,
296
+ }, f, ensure_ascii=False, indent=2)
297
+ print(f" Saved best question to {best_path}")
298
+
299
+ # 转成训练 parquet
300
+ if best_item:
301
+ df = convert_to_training_format(best_item)
302
+ parquet_path = os.path.join(OUTPUT_DIR, "rlvr_train.parquet")
303
+ df.to_parquet(parquet_path, index=False)
304
+ print(f" Saved training parquet ({len(df)} rows) to {parquet_path}")
305
+
306
+ print(f"\n{'='*60}")
307
+ print(f" DONE!")
308
+ print(f"{'='*60}")
309
+
310
+
311
+ if __name__ == "__main__":
312
+ main()