# rl4phyx-backup / scripts/variance_select_vllm.py
# Uploaded by YUNTA88 with huggingface_hub ("Upload scripts/variance_select_vllm.py")
# Commit 15fa009 (verified)
"""
vLLM Batch Inference for Variance-based One-Shot Question Selection
====================================================================
用 vLLM 做批量推理,比 HuggingFace generate 快 5-10x。
对 SFT 模型在 1533 道开放题上各跑 10 次推理,
用精确字符串匹配判断对错,选出方差最大的题目。
用法 (在 Docker 容器中):
CUDA_VISIBLE_DEVICES=0 python3 -u variance_select_vllm.py
"""
import json
import re
import os
import numpy as np
from pathlib import Path
# ============ Configuration ============
# Model checkpoint handed to both vLLM and the HuggingFace processor.
MODEL_PATH = "/workspace/rl4phyx/RL4Phyx/SFT/checkpoints/sft_qwen25vl_3b/merged"
# JSONL file with one test question per line.
TEST_FILE = "/workspace/rl4phyx/RL4Phyx/SFT/sft_test.jsonl"
# Directory searched for "<index>.png" images belonging to each question.
IMAGE_DIR = "/workspace/rl4phyx/RL4Phyx/SFT/images"
# Directory that receives per-run results, variance stats and the parquet file.
OUTPUT_DIR = "/workspace/rl4phyx/RL4Phyx/SFT"
NUM_RUNS = 10 # number of inference passes per question
MAX_NEW_TOKENS = 3072 # maximum number of generated tokens
TEMPERATURE = 0.7 # sampling temperature
TOP_P = 0.9  # nucleus-sampling threshold passed to SamplingParams
RLVR_COPIES = 128 # number of copies of the selected question in the one-shot training set
def extract_boxed(text):
    """Extract the answer inside the last ``\\boxed{...}`` of a model output.

    Braces are matched by counting nesting depth, so arbitrarily nested LaTeX
    such as ``\\boxed{\\frac{1}{\\sqrt{2}}}`` is captured in full (a regex
    limited to one nesting level would miss it entirely).

    Occurrences are scanned left to right without overlapping, and the last
    balanced one wins.

    Args:
        text: Raw model output to search.

    Returns:
        The whitespace-stripped content of the last balanced
        ``\\boxed{...}``, or ``None`` when no balanced occurrence exists.
    """
    marker = '\\boxed{'
    last_content = None
    search_from = 0
    while True:
        start = text.find(marker, search_from)
        if start == -1:
            break
        body_start = start + len(marker)

        # Walk forward until the brace opened by the marker is closed.
        depth = 1
        close = -1
        for pos in range(body_start, len(text)):
            ch = text[pos]
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
                if depth == 0:
                    close = pos
                    break

        if close == -1:
            # This marker is never closed; a later \boxed{...} nested inside
            # the unterminated tail may still be balanced, so keep looking.
            search_from = body_start
        else:
            last_content = text[body_start:close]
            search_from = close + 1

    if last_content is None:
        return None
    return last_content.strip()
def normalize_answer(ans):
    """Lightly normalize an answer string before exact comparison.

    Steps, in order: trim surrounding whitespace, drop trailing periods,
    unwrap ``\\text{...}`` to its content, collapse whitespace runs to a
    single space, and trim again.

    The final trim matters: removing a trailing period or unwrapping
    ``\\text{ N }`` can expose leading/trailing spaces, which would otherwise
    make equivalent answers such as ``"5 ."`` and ``"5"`` compare unequal.

    Args:
        ans: Answer string, or ``None``.

    Returns:
        The normalized string, or ``None`` if *ans* is ``None``.
    """
    if ans is None:
        return None
    ans = ans.strip().rstrip('.')
    ans = re.sub(r'\\text\{([^}]*)\}', r'\1', ans)
    ans = re.sub(r'\s+', ' ', ans)
    return ans.strip()
def exact_match(pred_answer, gt_answer):
    """Return True when prediction and ground truth are equal after normalization.

    A missing prediction (``None``) never matches and short-circuits before
    the ground truth is touched. Both sides go through ``normalize_answer``;
    if either normalizes to ``None`` the result is False.
    """
    if pred_answer is None:
        return False

    candidate = normalize_answer(pred_answer)
    reference = normalize_answer(gt_answer)
    return (
        candidate is not None
        and reference is not None
        and candidate == reference
    )
def load_test_data(test_file):
    """Load test samples from a JSON Lines file.

    Blank or whitespace-only lines (e.g. a stray empty line in the middle or
    at the end of the file) are skipped instead of crashing ``json.loads``.

    Args:
        test_file: Path to a UTF-8 encoded ``.jsonl`` file, one JSON value
            per line.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(test_file, 'r', encoding='utf-8') as f:
        stripped_lines = (raw.strip() for raw in f)
        data = [json.loads(line) for line in stripped_lines if line]
    print(f"Loaded {len(data)} test samples")
    return data
def build_vllm_inputs(test_data, processor):
    """Prepare chat-formatted prompts and image payloads for vLLM.

    For every sample, a single-turn user message is assembled (an image entry
    first when ``IMAGE_DIR/<index>.png`` exists, then the text prompt) and
    rendered to a prompt string with the processor's chat template.

    Args:
        test_data: Sized sequence of dicts providing ``'prompt'`` and
            ``'index'`` keys.
        processor: Object exposing ``apply_chat_template``.

    Returns:
        A ``(prompts, multi_modal_data_list)`` tuple of equal-length lists.
        Each entry of the second list is ``{"image": <RGB PIL image>}`` or
        ``None`` when the sample has no image file.
    """
    from PIL import Image

    total = len(test_data)
    prompts = []
    multi_modal_data_list = []

    for done, item in enumerate(test_data, start=1):
        prompt_text = item['prompt']
        image_path = os.path.join(IMAGE_DIR, f"{item['index']}.png")
        has_image = os.path.exists(image_path)

        # Assemble the user turn: optional image part followed by the text part.
        content = []
        if has_image:
            content.append({"type": "image", "image": f"file://{image_path}"})
        content.append({"type": "text", "text": prompt_text})
        messages = [{"role": "user", "content": content}]

        # Render the chat template to plain prompt text (no tokenization).
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Load the image only for samples that actually have one.
        if has_image:
            mm_data = {"image": Image.open(image_path).convert("RGB")}
        else:
            mm_data = None

        prompts.append(text)
        multi_modal_data_list.append(mm_data)

        if done % 200 == 0:
            print(f" Prepared {done}/{total} inputs...")

    return prompts, multi_modal_data_list
def run_vllm_batch(llm, prompts, multi_modal_data_list, sampling_params, run_id):
    """Run one batched vLLM inference pass over all prompts.

    Args:
        llm: Object exposing ``generate(requests, sampling_params)``.
        prompts: Prompt strings, one per question.
        multi_modal_data_list: Per-prompt multimodal payload (a dict) or a
            falsy value when the prompt has no image; paired with *prompts*
            by position.
        sampling_params: Sampling configuration forwarded to ``generate``.
        run_id: Zero-based index of this pass, used only for the progress
            banner.

    Returns:
        Whatever ``llm.generate`` returns for the submitted requests.
    """
    print(f"\n{'='*60}")
    print(f" Run {run_id + 1}/{NUM_RUNS}")
    print(f"{'='*60}")

    # Build one dict-style request per prompt; attach the multimodal payload
    # only when the prompt actually has one.
    requests = []
    for prompt, mm_data in zip(prompts, multi_modal_data_list):
        req = {"prompt": prompt}
        if mm_data:
            req["multi_modal_data"] = mm_data
        requests.append(req)

    # Submit everything at once and let vLLM schedule the batch.
    print(f" Submitting {len(requests)} requests to vLLM...")
    outputs = llm.generate(requests, sampling_params)
    print(f" Got {len(outputs)} outputs")
    return outputs
def _compute_variance_stats(test_data, all_runs):
    """Aggregate per-run correctness into per-question variance statistics.

    Args:
        test_data: Sequence of question dicts with ``'index'`` and
            ``'ground_truth'`` keys (``'category'`` optional).
        all_runs: One list per inference pass; entry ``qi`` of each list is a
            dict with ``'correct'`` and ``'pred_answer'`` for question ``qi``.

    Returns:
        One stats dict per question, sorted by Bernoulli variance
        ``p * (1 - p)`` in descending order (ties keep input order).
    """
    num_runs = len(all_runs)
    stats = []
    for qi, item in enumerate(test_data):
        per_run = [run[qi] for run in all_runs]
        correct_flags = [result['correct'] for result in per_run]
        n_correct = sum(correct_flags)
        p = n_correct / num_runs
        stats.append({
            'index': item['index'],
            'category': item.get('category', ''),
            'ground_truth': item['ground_truth'],
            'n_correct': n_correct,
            'accuracy': p,
            'variance': p * (1 - p),
            'correct_flags': correct_flags,
            'pred_answers': [result['pred_answer'] for result in per_run],
        })
    stats.sort(key=lambda s: s['variance'], reverse=True)
    return stats


def main():
    """Select the highest-variance question and export it for one-shot RLVR.

    Pipeline: load the model, load the test set, build prompts, run
    ``NUM_RUNS`` sampling passes scored by exact match, then rank questions by
    the variance of their correctness across passes.

    Files written to ``OUTPUT_DIR``:
        * ``variance_run_<run_id>.json`` - per-pass results
        * ``variance_results.json`` - per-question stats, sorted by variance
        * ``best_question_for_rlvr.json`` - the selected question and its stats
        * ``rlvr_train.parquet`` - the selected question repeated
          ``RLVR_COPIES`` times

    Raises:
        ValueError: If the test file contains no samples.
        RuntimeError: If a pass returns a different number of outputs than
            there are test samples.
    """
    print("=" * 60)
    print(" vLLM Batch Variance Selection")
    print("=" * 60)

    # Make sure results can be written before spending GPU time on inference.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # 1. Load the vLLM model
    print("\n[1/5] Loading vLLM model...")
    from vllm import LLM, SamplingParams
    from transformers import AutoProcessor
    llm = LLM(
        model=MODEL_PATH,
        dtype="bfloat16",
        trust_remote_code=True,
        max_model_len=4096,
        gpu_memory_utilization=0.85,
        limit_mm_per_prompt={"image": 1},
    )
    sampling_params = SamplingParams(
        temperature=TEMPERATURE,
        top_p=TOP_P,
        max_tokens=MAX_NEW_TOKENS,
    )
    processor = AutoProcessor.from_pretrained(MODEL_PATH)
    print(" Model loaded!")

    # 2. Load the test data
    print("\n[2/5] Loading test data...")
    test_data = load_test_data(TEST_FILE)
    if not test_data:
        # Without this guard the accuracy report below divides by zero.
        raise ValueError(f"No test samples found in {TEST_FILE}")

    # 3. Pre-build the inputs once; they are reused by every pass
    print("\n[3/5] Building vLLM inputs...")
    prompts, multi_modal_data_list = build_vllm_inputs(test_data, processor)
    print(f" Built {len(prompts)} prompts")

    # 4. Run the sampling passes
    print(f"\n[4/5] Running {NUM_RUNS} inference passes...")
    all_runs = []
    for run_id in range(NUM_RUNS):
        outputs = run_vllm_batch(llm, prompts, multi_modal_data_list, sampling_params, run_id)
        if len(outputs) != len(test_data):
            # Outputs are paired with questions by position, so a length
            # mismatch would silently mis-score or fail later with IndexError.
            raise RuntimeError(
                f"Run {run_id + 1}: expected {len(test_data)} outputs, "
                f"got {len(outputs)}"
            )

        # Score each output against its ground truth
        run_results = []
        correct_count = 0
        for item, output in zip(test_data, outputs):
            prediction = output.outputs[0].text
            pred_answer = extract_boxed(prediction)
            gt_answer = item['ground_truth']
            is_correct = exact_match(pred_answer, gt_answer)
            if is_correct:
                correct_count += 1
            run_results.append({
                'index': item['index'],
                'pred_answer': pred_answer,
                'gt_answer': gt_answer,
                'correct': is_correct,
            })
        all_runs.append(run_results)
        print(f" Run {run_id+1} accuracy: {correct_count}/{len(test_data)} "
              f"({100*correct_count/len(test_data):.1f}%)")

        # Persist each pass immediately so a later crash does not lose it
        interim_path = os.path.join(OUTPUT_DIR, f"variance_run_{run_id}.json")
        with open(interim_path, 'w', encoding='utf-8') as f:
            json.dump(run_results, f, ensure_ascii=False, indent=2)
        print(f" Saved to {interim_path}")

    # 5. Compute variance and pick the question
    print(f"\n[5/5] Computing variance...")
    stats = _compute_variance_stats(test_data, all_runs)

    # Save the full statistics
    stats_path = os.path.join(OUTPUT_DIR, "variance_results.json")
    with open(stats_path, 'w', encoding='utf-8') as f:
        json.dump(stats, f, ensure_ascii=False, indent=2)
    print(f" Saved variance stats to {stats_path}")

    # Top 20
    print(f"\n{'='*60}")
    print(f" TOP 20 HIGHEST VARIANCE QUESTIONS")
    print(f"{'='*60}")
    for i, s in enumerate(stats[:20]):
        print(f" #{i+1}: idx={s['index']} | gt={s['ground_truth'][:30]:30s} | "
              f"correct={s['n_correct']}/{NUM_RUNS} | var={s['variance']:.4f} | "
              f"cat={s['category']}")

    # Pick the question with the highest variance (first one on ties)
    best = stats[0]
    print(f"\n{'='*60}")
    print(f" SELECTED QUESTION FOR ONE-SHOT RLVR")
    print(f"{'='*60}")
    print(f" Index: {best['index']}")
    print(f" Category: {best['category']}")
    print(f" Ground Truth: {best['ground_truth']}")
    print(f" Accuracy: {best['n_correct']}/{NUM_RUNS} ({best['accuracy']*100:.0f}%)")
    print(f" Variance: {best['variance']:.4f}")

    # Save the selected question together with its stats
    best_idx = int(best['index'])
    best_item = next(item for item in test_data if int(item['index']) == best_idx)
    best_path = os.path.join(OUTPUT_DIR, "best_question_for_rlvr.json")
    with open(best_path, 'w', encoding='utf-8') as f:
        json.dump({'selected_question': best_item, 'stats': best}, f, ensure_ascii=False, indent=2)
    print(f" Saved to {best_path}")

    # Export the training parquet: the same record repeated RLVR_COPIES times
    import pandas as pd
    prompt_messages = [{"role": "user", "content": best_item['prompt']}]
    records = [{
        'prompt': prompt_messages,
        'answer': best_item['ground_truth'],
        'image_path': f"{best_item['index']}.png",
        # NOTE(review): the meaning of 'deepscaler' to the downstream trainer
        # is not visible here - confirm before changing.
        'data_source': 'deepscaler',
        'category': best_item.get('category', 'Physics'),
        'index': best_item['index'],
    } for _ in range(RLVR_COPIES)]
    df = pd.DataFrame(records)
    parquet_path = os.path.join(OUTPUT_DIR, "rlvr_train.parquet")
    df.to_parquet(parquet_path, index=False)
    print(f" Saved training parquet ({len(df)} rows) to {parquet_path}")

    print(f"\n{'='*60}")
    print(f" DONE!")
    print(f"{'='*60}")
# Run the selection pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()