#!/usr/bin/env python3
"""
Phase 2: Score open-ended inference results.
Two-stage scoring (adapted from eval_dual.py + MetaPhyX DeepSeek judge):
Stage 1: Rule-based (boxed extraction + normalization + numeric tolerance)
- If CORRECT → done, count as correct
- If WRONG or UNCERTAIN → go to Stage 2
Stage 2: Gemini 2.5 Flash LLM-as-Judge
- Sends model's full response + ground truth to Gemini
- Gemini determines [[YES]] or [[NO]] equivalence
Usage:
python eval_judge.py [--results_dir PATH] [--api_key KEY]
Inputs:
inference_results_base.jsonl
inference_results_sft.jsonl
Outputs:
scored_results_base.jsonl
scored_results_sft.jsonl
comparison_report.json
"""
import json, os, re, time, sys, argparse
from collections import defaultdict, Counter
# ===================== CONFIG =====================
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")  # set via env var or --api_key; do not hardcode keys
GEMINI_MODEL = "gemini-2.5-flash"
MAX_RETRIES = 3
RATE_LIMIT_DELAY = 0.5 # seconds between Gemini calls
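# Each line of inference_results_*.jsonl is expected to carry at least the
# fields read in score_model() below: "model_output", "ground_truth_value",
# and "category" (the optional test file additionally provides "index").
# Illustrative record, values made up:
#   {"index": 0, "category": "Mechanics",
#    "model_output": "<think>...</think> The answer is \\boxed{7.5} N",
#    "ground_truth_value": "7.5 N"}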
# ===================== RULE-BASED SCORING =====================
# Adapted from eval_dual.py (verl/utils/reward_score/utils/utils.py approach)
def _strip_string(string):
"""Normalize math string: remove LaTeX formatting, units, whitespace."""
string = string.replace("\n", "")
string = string.replace("\\!", "")
string = string.replace("\\\\", "\\")
string = string.replace("tfrac", "frac")
string = string.replace("dfrac", "frac")
string = string.replace("\\left", "")
string = string.replace("\\right", "")
string = string.replace("^{\\circ}", "")
string = string.replace("^\\circ", "")
string = string.replace("\\$", "")
if "\\text{ " in string:
splits = string.split("\\text{ ")
if len(splits) == 2:
string = splits[0]
string = string.replace("\\%", "")
string = string.replace(" .", " 0.")
string = string.replace("{.", "{0.")
if len(string) == 0:
return string
if string[0] == ".":
string = "0" + string
if len(string.split("=")) == 2:
if len(string.split("=")[0]) <= 2:
string = string.split("=")[1]
string = string.replace(" ", "")
return string
def _normalize(expr):
"""Normalize answer expression for comparison."""
if expr is None:
return None
m = re.search("^\\\\text\\{(?P<text>.+?)\\}$", expr)
if m is not None:
expr = m.group("text")
expr = expr.replace("\\%", "%")
expr = expr.replace("\\$", "$")
expr = expr.replace("$", "")
expr = expr.replace("%", "")
expr = expr.replace(" or ", " , ")
expr = expr.replace(" and ", " , ")
for unit in ["degree", "cm", "centimeter", "meter", "mile", "second", "minute",
"hour", "day", "week", "month", "year", "foot", "feet", "inch", "yard",
"newton", "joule", "watt", "ampere", "volt", "ohm", "hertz",
"kilogram", "gram", "liter", "mole", "kelvin", "pascal",
"m/s", "km/h", "rad/s", "N", "J", "W", "A", "V", "Hz", "Pa", "kg", "mol"]:
expr = re.sub(f"\\s*{re.escape(unit)}(es)?(s)?\\s*(\\^[0-9]+)?", "", expr, flags=re.IGNORECASE)
if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}":
expr = expr[1:-1]
try:
if "." in expr:
val = float(expr)
if abs(val - int(round(val))) <= 1e-7:
expr = str(int(round(val)))
    except (ValueError, OverflowError):
pass
expr = re.sub("- *", "-", expr)
expr = expr.replace(" ", "")
expr = expr.replace("{", "")
expr = expr.replace("}", "")
expr = expr.lower()
return expr
def extract_boxed_answer(text):
"""Extract the last \\boxed{} content from text."""
idx = text.rfind("\\boxed")
if idx < 0:
idx = text.rfind("\\fbox")
if idx < 0:
return None
i = idx
num_left = 0
right_idx = None
while i < len(text):
if text[i] == "{":
num_left += 1
if text[i] == "}":
num_left -= 1
if num_left == 0:
right_idx = i
break
i += 1
if right_idx is None:
return None
boxed = text[idx:right_idx + 1]
left = "\\boxed{"
if boxed.startswith(left) and boxed.endswith("}"):
return boxed[len(left):-1]
return None
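# For example (illustrative), extract_boxed_answer("v = \\boxed{\\frac{1}{2}} m/s")
# returns "\\frac{1}{2}": the brace counter above keeps nested {} balanced, so
# inner groups survive intact.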
def extract_answer_from_text(text):
"""Try to extract answer: first from \\boxed{}, then from common patterns."""
# Handle <think>...</think>
if '<think>' in text and '</think>' in text:
text = text.split('</think>')[-1]
# Priority 1: \boxed{}
boxed = extract_boxed_answer(text)
if boxed:
return boxed
# Priority 2: Common answer patterns
patterns = [
        r'(?:the answer is|answer is|答案是|答案为)[:\s]*(.+?)(?:\.|$)',  # 答案是 / 答案为 = "the answer is" in Chinese
r'(?:therefore|thus|so|hence)[,\s]+(?:the answer is\s+)?(.+?)(?:\.|$)',
]
for p in patterns:
m = re.search(p, text, re.IGNORECASE)
if m:
ans = m.group(1).strip()
if len(ans) < 100:
return ans
return None
def rule_based_score(prediction, ground_truth):
"""
Rule-based scoring: extract answer + normalize + compare.
Returns: (is_correct: bool, reason: str)
"""
model_answer = extract_answer_from_text(prediction)
if model_answer is None:
return False, "no_answer_extracted"
gt_norm = _normalize(ground_truth)
pred_norm = _normalize(model_answer)
if gt_norm is None or pred_norm is None:
return False, "normalize_failed"
# Direct match after normalization
if gt_norm == pred_norm:
return True, "exact_match"
# Numeric comparison (1% tolerance)
try:
gt_float = float(gt_norm.replace(",", ""))
pred_float = float(pred_norm.replace(",", ""))
if abs(gt_float - pred_float) < 1e-6:
return True, "numeric_match"
if gt_float != 0 and abs((gt_float - pred_float) / gt_float) < 0.01:
return True, "numeric_close"
    except ValueError:
pass
# Short answer containment (e.g., "III", "decreasing")
    gt_clean = ground_truth.strip()
    if 0 < len(gt_clean) <= 10:
if re.search(r'\b' + re.escape(gt_clean) + r'\b', prediction, re.IGNORECASE):
return True, "containment_match"
return False, f"no_match(pred={pred_norm[:30]},gt={gt_norm[:30]})"
# ===================== GEMINI LLM-AS-JUDGE =====================
# Adapted from eval_dual.py + MetaPhyX deepscaler ORM prompt
ORM_PROMPT = """You are an expert in verifying if two physics answers are the same.
Your input is a physics question prompt and two answers:
- Answer 1: the model's prediction
- Answer 2: the ground truth answer
Determine if they are equivalent.
Guidelines for equivalence:
- Different forms of the same number (0.5 = 1/2 = 50%)
- Same physical quantity with different units or notation (7.55N = 7.55 N = 7.55 newtons)
- Semantically equivalent descriptions ("point III" and "III", "decreasing" and "the velocity is decreasing")
- Algebraically equivalent expressions, e.g. (x+1)^2 = x^2+2x+1
- Same choice letter or option name
- Correct numerical value even if formatting differs
- Minor rounding differences within 2% are acceptable
Your output must follow this format:
1) Brief explanation for why the answers are equivalent or not.
2) Final answer: [[YES]] or [[NO]]
"""
def call_gemini(prompt, api_key):
"""Call Gemini API using urllib (no external deps)."""
import urllib.request, urllib.error
url = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_MODEL}:generateContent?key={api_key}"
payload = json.dumps({
"contents": [{"parts": [{"text": prompt}]}],
"generationConfig": {
"temperature": 0.0,
"maxOutputTokens": 512,
}
}).encode('utf-8')
req = urllib.request.Request(
url, data=payload,
headers={"Content-Type": "application/json"},
method="POST",
)
for attempt in range(MAX_RETRIES):
try:
with urllib.request.urlopen(req, timeout=30) as resp:
result = json.loads(resp.read().decode('utf-8'))
text = result['candidates'][0]['content']['parts'][0]['text']
return text.strip()
except urllib.error.HTTPError as e:
if e.code == 429:
wait = (attempt + 1) * 5
print(f" Rate limited, waiting {wait}s...")
time.sleep(wait)
else:
print(f" HTTP error {e.code}")
if attempt == MAX_RETRIES - 1:
return None
time.sleep(2)
except Exception as e:
print(f" Error: {e}")
if attempt == MAX_RETRIES - 1:
return None
time.sleep(2)
return None
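# The JSON parsed above is assumed to follow the generateContent response shape,
# roughly {"candidates": [{"content": {"parts": [{"text": "..."}]}}]}. If no
# candidate comes back (e.g. a blocked reply), the resulting KeyError/IndexError
# lands in the generic except branch, the call is retried, and None is returned
# after MAX_RETRIES attempts.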
def gemini_judge(prediction, ground_truth, api_key):
"""Use Gemini to judge if model's prediction matches ground truth."""
user_msg = f"""
Model's full response (contains reasoning and answer):
{prediction[:2000]}
Ground truth answer: {ground_truth}
"""
response = call_gemini(ORM_PROMPT + "\n\n" + user_msg, api_key)
if response is None:
return False, "api_error"
if "[[YES]]" in response:
return True, response[:200]
elif "[[NO]]" in response:
return False, response[:200]
else:
lower = response.lower()
if "yes" in lower and "no" not in lower:
return True, response[:200]
return False, response[:200]
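# Standalone use looks like this (illustrative; assumes `api_key` holds a valid
# key). In this script the judge only runs when rule_based_score() returns False:
#
#     ok, detail = gemini_judge("... therefore v = 1/2 m/s", "0.5 m/s", api_key)
#     # ok is True when Gemini's verdict contains [[YES]]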
# ===================== MAIN EVALUATION =====================
def score_model(results, model_name, api_key, output_file):
"""
Score all results using two-stage approach:
1. Rule-based first → if correct, DONE
2. If rule-based says wrong/uncertain → Gemini fallback
"""
print(f"\n{'='*60}")
print(f" Scoring: {model_name} ({len(results)} samples)")
print(f"{'='*60}")
rule_correct = 0
rule_wrong_gemini_correct = 0
rule_wrong_gemini_wrong = 0
gemini_errors = 0
total = len(results)
cat_stats = defaultdict(lambda: {'total': 0, 'rule_correct': 0, 'gemini_correct': 0, 'final_correct': 0})
for i, r in enumerate(results):
cat = r.get('category', 'Unknown')
pred = r.get('model_output', '')
gt = r.get('ground_truth_value', '')
cat_stats[cat]['total'] += 1
# === Stage 1: Rule-based ===
rule_match, rule_reason = rule_based_score(pred, gt)
r['rule_match'] = rule_match
r['rule_reason'] = rule_reason
if rule_match:
# Rule says CORRECT → done
rule_correct += 1
cat_stats[cat]['rule_correct'] += 1
cat_stats[cat]['final_correct'] += 1
r['final_correct'] = True
r['final_method'] = f"rule:{rule_reason}"
r['gemini_called'] = False
else:
# Rule says WRONG → Gemini fallback
r['gemini_called'] = True
gemini_match, gemini_reason = gemini_judge(pred, gt, api_key)
r['gemini_match'] = gemini_match
r['gemini_reason'] = gemini_reason
if gemini_match:
rule_wrong_gemini_correct += 1
cat_stats[cat]['gemini_correct'] += 1
cat_stats[cat]['final_correct'] += 1
r['final_correct'] = True
r['final_method'] = "gemini_override"
else:
rule_wrong_gemini_wrong += 1
r['final_correct'] = False
r['final_method'] = f"wrong:{rule_reason}"
time.sleep(RATE_LIMIT_DELAY)
# Progress
final_correct_so_far = rule_correct + rule_wrong_gemini_correct
if (i + 1) % 10 == 0 or (i + 1) == total:
acc_so_far = final_correct_so_far / (i + 1)
print(f" [{i+1}/{total}] acc={acc_so_far:.1%} "
f"(rule✓={rule_correct} gemini✓={rule_wrong_gemini_correct} ✗={rule_wrong_gemini_wrong})",
flush=True)
# Save scored results
with open(output_file, 'w', encoding='utf-8') as f:
for r in results:
f.write(json.dumps(r, ensure_ascii=False) + '\n')
# Summary
final_correct = rule_correct + rule_wrong_gemini_correct
final_acc = final_correct / total if total > 0 else 0
print(f"\n{'─'*60}")
print(f" {model_name} — RESULTS")
print(f"{'─'*60}")
print(f" Rule-based correct : {rule_correct}/{total} ({100*rule_correct/total:.1f}%)")
print(f" Gemini rescued : {rule_wrong_gemini_correct} (rule wrong → Gemini correct)")
print(f" Final accuracy : {final_correct}/{total} ({100*final_acc:.1f}%)")
print(f" Gemini calls made : {rule_wrong_gemini_correct + rule_wrong_gemini_wrong}")
print(f"\n Per-category:")
for cat, s in sorted(cat_stats.items()):
acc = s['final_correct'] / s['total'] if s['total'] > 0 else 0
print(f" {cat:25s}: {s['final_correct']}/{s['total']} ({acc:.1%})"
f" [rule={s['rule_correct']}, gemini+={s['gemini_correct']}]")
return {
'model': model_name,
'total': total,
'rule_correct': rule_correct,
'gemini_rescued': rule_wrong_gemini_correct,
'final_correct': final_correct,
'final_acc': round(100 * final_acc, 2),
'category_stats': {cat: dict(s) for cat, s in cat_stats.items()},
}
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--results_dir', type=str, default=None)
parser.add_argument('--api_key', type=str, default=None)
args = parser.parse_args()
    api_key = args.api_key or GEMINI_API_KEY
    if not api_key:
        print("ERROR: No Gemini API key. Set GEMINI_API_KEY or pass --api_key.")
        sys.exit(1)
# Find results directory
results_dir = args.results_dir
if results_dir is None:
for d in [os.path.dirname(os.path.abspath(__file__)),
'/workspace/rl4phyx/RL4Phyx/SFT/sft_eval_footprint/']:
if os.path.exists(os.path.join(d, 'inference_results_base.jsonl')):
results_dir = d
break
if results_dir is None:
print("ERROR: Cannot find inference results. Use --results_dir")
sys.exit(1)
print("=" * 60)
print(" OPEN-ENDED EVAL: Rule-based + Gemini 2.5 Flash Judge")
print(f" Results dir: {results_dir}")
print("=" * 60)
# Load test data for context
test_file = os.path.join(results_dir, 'test_1533_openended.jsonl')
if os.path.exists(test_file):
with open(test_file, 'r') as f:
            test_data = {rec['index']: rec for rec in (json.loads(l) for l in f if l.strip())}
print(f"Test data loaded: {len(test_data)} samples")
# Load and score base model
base_file = os.path.join(results_dir, 'inference_results_base.jsonl')
with open(base_file, 'r') as f:
base_results = [json.loads(l) for l in f if l.strip()]
base_scored_file = os.path.join(results_dir, 'scored_results_base.jsonl')
base_stats = score_model(base_results, "Qwen2.5-VL-3B (Base)", api_key, base_scored_file)
# Load and score SFT model
sft_file = os.path.join(results_dir, 'inference_results_sft.jsonl')
with open(sft_file, 'r') as f:
sft_results = [json.loads(l) for l in f if l.strip()]
sft_scored_file = os.path.join(results_dir, 'scored_results_sft.jsonl')
sft_stats = score_model(sft_results, "Qwen2.5-VL-3B (SFT)", api_key, sft_scored_file)
# Comparison
delta = sft_stats['final_acc'] - base_stats['final_acc']
report = {
'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
'scoring_method': 'rule-based + Gemini 2.5 Flash judge (fallback)',
'base': base_stats,
'sft': sft_stats,
'improvement': f"{delta:+.2f}%",
}
report_file = os.path.join(results_dir, 'comparison_report.json')
with open(report_file, 'w', encoding='utf-8') as f:
json.dump(report, f, indent=2, ensure_ascii=False)
print(f"\n{'='*60}")
print(f" FINAL COMPARISON")
print(f"{'='*60}")
print(f" Base accuracy: {base_stats['final_acc']}% ({base_stats['final_correct']}/{base_stats['total']})")
print(f" SFT accuracy: {sft_stats['final_acc']}% ({sft_stats['final_correct']}/{sft_stats['total']})")
print(f" Improvement: {delta:+.2f}%")
print(f"\n Per-category:")
all_cats = sorted(set(list(base_stats['category_stats'].keys()) + list(sft_stats['category_stats'].keys())))
for cat in all_cats:
b = base_stats['category_stats'].get(cat, {'final_correct': 0, 'total': 0})
s = sft_stats['category_stats'].get(cat, {'final_correct': 0, 'total': 0})
b_acc = b['final_correct'] / b['total'] if b['total'] > 0 else 0
s_acc = s['final_correct'] / s['total'] if s['total'] > 0 else 0
print(f" {cat:25s} Base: {b_acc:.1%} SFT: {s_acc:.1%} Δ: {(s_acc-b_acc)*100:+.1f}%")
print(f"\n Report: {report_file}")
print(f"{'='*60}")
if __name__ == '__main__':
main()