#!/usr/bin/env python3 """ Phase 2: Score open-ended inference results. Two-stage scoring (adapted from eval_dual.py + MetaPhyX DeepSeek judge): Stage 1: Rule-based (boxed extraction + normalization + numeric tolerance) - If CORRECT → done, count as correct - If WRONG or UNCERTAIN → go to Stage 2 Stage 2: Gemini 2.5 Flash LLM-as-Judge - Sends model's full response + ground truth to Gemini - Gemini determines [[YES]] or [[NO]] equivalence Usage: python eval_openended_judge.py [--results_dir PATH] [--api_key KEY] Inputs: inference_results_base.jsonl inference_results_sft.jsonl Outputs: scored_results_base.jsonl scored_results_sft.jsonl comparison_report.json """ import json, os, re, time, sys, argparse from collections import defaultdict, Counter # ===================== CONFIG ===================== GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "AIzaSyCXQ9gjVmRhoB1OVSqElnTB6p83GLX4W4w") GEMINI_MODEL = "gemini-2.5-flash" MAX_RETRIES = 3 RATE_LIMIT_DELAY = 0.5 # seconds between Gemini calls # ===================== RULE-BASED SCORING ===================== # Adapted from eval_dual.py (verl/utils/reward_score/utils/utils.py approach) def _strip_string(string): """Normalize math string: remove LaTeX formatting, units, whitespace.""" string = string.replace("\n", "") string = string.replace("\\!", "") string = string.replace("\\\\", "\\") string = string.replace("tfrac", "frac") string = string.replace("dfrac", "frac") string = string.replace("\\left", "") string = string.replace("\\right", "") string = string.replace("^{\\circ}", "") string = string.replace("^\\circ", "") string = string.replace("\\$", "") if "\\text{ " in string: splits = string.split("\\text{ ") if len(splits) == 2: string = splits[0] string = string.replace("\\%", "") string = string.replace(" .", " 0.") string = string.replace("{.", "{0.") if len(string) == 0: return string if string[0] == ".": string = "0" + string if len(string.split("=")) == 2: if len(string.split("=")[0]) <= 2: string = string.split("=")[1] string = string.replace(" ", "") return string def _normalize(expr): """Normalize answer expression for comparison.""" if expr is None: return None m = re.search("^\\\\text\\{(?P.+?)\\}$", expr) if m is not None: expr = m.group("text") expr = expr.replace("\\%", "%") expr = expr.replace("\\$", "$") expr = expr.replace("$", "") expr = expr.replace("%", "") expr = expr.replace(" or ", " , ") expr = expr.replace(" and ", " , ") for unit in ["degree", "cm", "centimeter", "meter", "mile", "second", "minute", "hour", "day", "week", "month", "year", "foot", "feet", "inch", "yard", "newton", "joule", "watt", "ampere", "volt", "ohm", "hertz", "kilogram", "gram", "liter", "mole", "kelvin", "pascal", "m/s", "km/h", "rad/s", "N", "J", "W", "A", "V", "Hz", "Pa", "kg", "mol"]: expr = re.sub(f"\\s*{re.escape(unit)}(es)?(s)?\\s*(\\^[0-9]+)?", "", expr, flags=re.IGNORECASE) if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}": expr = expr[1:-1] try: if "." in expr: val = float(expr) if abs(val - int(round(val))) <= 1e-7: expr = str(int(round(val))) except: pass expr = re.sub("- *", "-", expr) expr = expr.replace(" ", "") expr = expr.replace("{", "") expr = expr.replace("}", "") expr = expr.lower() return expr def extract_boxed_answer(text): """Extract the last \\boxed{} content from text.""" idx = text.rfind("\\boxed") if idx < 0: idx = text.rfind("\\fbox") if idx < 0: return None i = idx num_left = 0 right_idx = None while i < len(text): if text[i] == "{": num_left += 1 if text[i] == "}": num_left -= 1 if num_left == 0: right_idx = i break i += 1 if right_idx is None: return None boxed = text[idx:right_idx + 1] left = "\\boxed{" if boxed.startswith(left) and boxed.endswith("}"): return boxed[len(left):-1] return None def extract_answer_from_text(text): """Try to extract answer: first from \\boxed{}, then from common patterns.""" # Handle ... if '' in text and '' in text: text = text.split('')[-1] # Priority 1: \boxed{} boxed = extract_boxed_answer(text) if boxed: return boxed # Priority 2: Common answer patterns patterns = [ r'(?:the answer is|answer is|答案是|答案为)[:\s]*(.+?)(?:\.|$)', r'(?:therefore|thus|so|hence)[,\s]+(?:the answer is\s+)?(.+?)(?:\.|$)', ] for p in patterns: m = re.search(p, text, re.IGNORECASE) if m: ans = m.group(1).strip() if len(ans) < 100: return ans return None def rule_based_score(prediction, ground_truth): """ Rule-based scoring: extract answer + normalize + compare. Returns: (is_correct: bool, reason: str) """ model_answer = extract_answer_from_text(prediction) if model_answer is None: return False, "no_answer_extracted" gt_norm = _normalize(ground_truth) pred_norm = _normalize(model_answer) if gt_norm is None or pred_norm is None: return False, "normalize_failed" # Direct match after normalization if gt_norm == pred_norm: return True, "exact_match" # Numeric comparison (1% tolerance) try: gt_float = float(gt_norm.replace(",", "")) pred_float = float(pred_norm.replace(",", "")) if abs(gt_float - pred_float) < 1e-6: return True, "numeric_match" if gt_float != 0 and abs((gt_float - pred_float) / gt_float) < 0.01: return True, "numeric_close" except: pass # Short answer containment (e.g., "III", "decreasing") if len(ground_truth.strip()) <= 10: gt_clean = ground_truth.strip() if re.search(r'\b' + re.escape(gt_clean) + r'\b', prediction, re.IGNORECASE): return True, "containment_match" return False, f"no_match(pred={pred_norm[:30]},gt={gt_norm[:30]})" # ===================== GEMINI LLM-AS-JUDGE ===================== # Adapted from eval_dual.py + MetaPhyX deepscaler ORM prompt ORM_PROMPT = """You are an expert in verifying if two physics answers are the same. Your input is a physics question prompt and two answers: - Answer 1: the model's prediction - Answer 2: the ground truth answer Determine if they are equivalent. Guidelines for equivalence: - Different forms of the same number (0.5 = 1/2 = 50%) - Same physical quantity with different units or notation (7.55N = 7.55 N = 7.55 newtons) - Semantically equivalent descriptions ("point III" and "III", "decreasing" and "the velocity is decreasing") - Algebraically equivalent expressions (x+1)^2 = x^2+2x+1 - Same choice letter or option name - Correct numerical value even if formatting differs - Minor rounding differences within 2% are acceptable Your output must follow this format: 1) Brief explanation for why the answers are equivalent or not. 2) Final answer: [[YES]] or [[NO]] """ def call_gemini(prompt, api_key): """Call Gemini API using urllib (no external deps).""" import urllib.request, urllib.error url = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_MODEL}:generateContent?key={api_key}" payload = json.dumps({ "contents": [{"parts": [{"text": prompt}]}], "generationConfig": { "temperature": 0.0, "maxOutputTokens": 512, } }).encode('utf-8') req = urllib.request.Request( url, data=payload, headers={"Content-Type": "application/json"}, method="POST", ) for attempt in range(MAX_RETRIES): try: with urllib.request.urlopen(req, timeout=30) as resp: result = json.loads(resp.read().decode('utf-8')) text = result['candidates'][0]['content']['parts'][0]['text'] return text.strip() except urllib.error.HTTPError as e: if e.code == 429: wait = (attempt + 1) * 5 print(f" Rate limited, waiting {wait}s...") time.sleep(wait) else: print(f" HTTP error {e.code}") if attempt == MAX_RETRIES - 1: return None time.sleep(2) except Exception as e: print(f" Error: {e}") if attempt == MAX_RETRIES - 1: return None time.sleep(2) return None def gemini_judge(prediction, ground_truth, api_key): """Use Gemini to judge if model's prediction matches ground truth.""" user_msg = f""" Model's full response (contains reasoning and answer): {prediction[:2000]} Ground truth answer: {ground_truth} """ response = call_gemini(ORM_PROMPT + "\n\n" + user_msg, api_key) if response is None: return False, "api_error" if "[[YES]]" in response: return True, response[:200] elif "[[NO]]" in response: return False, response[:200] else: lower = response.lower() if "yes" in lower and "no" not in lower: return True, response[:200] return False, response[:200] # ===================== MAIN EVALUATION ===================== def score_model(results, model_name, api_key, output_file): """ Score all results using two-stage approach: 1. Rule-based first → if correct, DONE 2. If rule-based says wrong/uncertain → Gemini fallback """ print(f"\n{'='*60}") print(f" Scoring: {model_name} ({len(results)} samples)") print(f"{'='*60}") rule_correct = 0 rule_wrong_gemini_correct = 0 rule_wrong_gemini_wrong = 0 gemini_errors = 0 total = len(results) cat_stats = defaultdict(lambda: {'total': 0, 'rule_correct': 0, 'gemini_correct': 0, 'final_correct': 0}) for i, r in enumerate(results): cat = r.get('category', 'Unknown') pred = r.get('model_output', '') gt = r.get('ground_truth_value', '') cat_stats[cat]['total'] += 1 # === Stage 1: Rule-based === rule_match, rule_reason = rule_based_score(pred, gt) r['rule_match'] = rule_match r['rule_reason'] = rule_reason if rule_match: # Rule says CORRECT → done rule_correct += 1 cat_stats[cat]['rule_correct'] += 1 cat_stats[cat]['final_correct'] += 1 r['final_correct'] = True r['final_method'] = f"rule:{rule_reason}" r['gemini_called'] = False else: # Rule says WRONG → Gemini fallback r['gemini_called'] = True gemini_match, gemini_reason = gemini_judge(pred, gt, api_key) r['gemini_match'] = gemini_match r['gemini_reason'] = gemini_reason if gemini_match: rule_wrong_gemini_correct += 1 cat_stats[cat]['gemini_correct'] += 1 cat_stats[cat]['final_correct'] += 1 r['final_correct'] = True r['final_method'] = "gemini_override" else: rule_wrong_gemini_wrong += 1 r['final_correct'] = False r['final_method'] = f"wrong:{rule_reason}" time.sleep(RATE_LIMIT_DELAY) # Progress final_correct_so_far = rule_correct + rule_wrong_gemini_correct if (i + 1) % 10 == 0 or (i + 1) == total: acc_so_far = final_correct_so_far / (i + 1) print(f" [{i+1}/{total}] acc={acc_so_far:.1%} " f"(rule✓={rule_correct} gemini✓={rule_wrong_gemini_correct} ✗={rule_wrong_gemini_wrong})", flush=True) # Save scored results with open(output_file, 'w', encoding='utf-8') as f: for r in results: f.write(json.dumps(r, ensure_ascii=False) + '\n') # Summary final_correct = rule_correct + rule_wrong_gemini_correct final_acc = final_correct / total if total > 0 else 0 print(f"\n{'─'*60}") print(f" {model_name} — RESULTS") print(f"{'─'*60}") print(f" Rule-based correct : {rule_correct}/{total} ({100*rule_correct/total:.1f}%)") print(f" Gemini rescued : {rule_wrong_gemini_correct} (rule wrong → Gemini correct)") print(f" Final accuracy : {final_correct}/{total} ({100*final_acc:.1f}%)") print(f" Gemini calls made : {rule_wrong_gemini_correct + rule_wrong_gemini_wrong}") print(f"\n Per-category:") for cat, s in sorted(cat_stats.items()): acc = s['final_correct'] / s['total'] if s['total'] > 0 else 0 print(f" {cat:25s}: {s['final_correct']}/{s['total']} ({acc:.1%})" f" [rule={s['rule_correct']}, gemini+={s['gemini_correct']}]") return { 'model': model_name, 'total': total, 'rule_correct': rule_correct, 'gemini_rescued': rule_wrong_gemini_correct, 'final_correct': final_correct, 'final_acc': round(100 * final_acc, 2), 'category_stats': {cat: dict(s) for cat, s in cat_stats.items()}, } def main(): parser = argparse.ArgumentParser() parser.add_argument('--results_dir', type=str, default=None) parser.add_argument('--api_key', type=str, default=None) args = parser.parse_args() api_key = args.api_key or GEMINI_API_KEY # Find results directory results_dir = args.results_dir if results_dir is None: for d in [os.path.dirname(os.path.abspath(__file__)), '/workspace/rl4phyx/RL4Phyx/SFT/sft_eval_footprint/']: if os.path.exists(os.path.join(d, 'inference_results_base.jsonl')): results_dir = d break if results_dir is None: print("ERROR: Cannot find inference results. Use --results_dir") sys.exit(1) print("=" * 60) print(" OPEN-ENDED EVAL: Rule-based + Gemini 2.5 Flash Judge") print(f" Results dir: {results_dir}") print("=" * 60) # Load test data for context test_file = os.path.join(results_dir, 'test_1533_openended.jsonl') if os.path.exists(test_file): with open(test_file, 'r') as f: test_data = {json.loads(l)['index']: json.loads(l) for l in f if l.strip()} print(f"Test data loaded: {len(test_data)} samples") # Load and score base model base_file = os.path.join(results_dir, 'inference_results_base.jsonl') with open(base_file, 'r') as f: base_results = [json.loads(l) for l in f if l.strip()] base_scored_file = os.path.join(results_dir, 'scored_results_base.jsonl') base_stats = score_model(base_results, "Qwen2.5-VL-3B (Base)", api_key, base_scored_file) # Load and score SFT model sft_file = os.path.join(results_dir, 'inference_results_sft.jsonl') with open(sft_file, 'r') as f: sft_results = [json.loads(l) for l in f if l.strip()] sft_scored_file = os.path.join(results_dir, 'scored_results_sft.jsonl') sft_stats = score_model(sft_results, "Qwen2.5-VL-3B (SFT)", api_key, sft_scored_file) # Comparison delta = sft_stats['final_acc'] - base_stats['final_acc'] report = { 'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'), 'scoring_method': 'rule-based + Gemini 2.5 Flash judge (fallback)', 'base': base_stats, 'sft': sft_stats, 'improvement': f"{delta:+.2f}%", } report_file = os.path.join(results_dir, 'comparison_report.json') with open(report_file, 'w', encoding='utf-8') as f: json.dump(report, f, indent=2, ensure_ascii=False) print(f"\n{'='*60}") print(f" FINAL COMPARISON") print(f"{'='*60}") print(f" Base accuracy: {base_stats['final_acc']}% ({base_stats['final_correct']}/{base_stats['total']})") print(f" SFT accuracy: {sft_stats['final_acc']}% ({sft_stats['final_correct']}/{sft_stats['total']})") print(f" Improvement: {delta:+.2f}%") print(f"\n Per-category:") all_cats = sorted(set(list(base_stats['category_stats'].keys()) + list(sft_stats['category_stats'].keys()))) for cat in all_cats: b = base_stats['category_stats'].get(cat, {'final_correct': 0, 'total': 0}) s = sft_stats['category_stats'].get(cat, {'final_correct': 0, 'total': 0}) b_acc = b['final_correct'] / b['total'] if b['total'] > 0 else 0 s_acc = s['final_correct'] / s['total'] if s['total'] > 0 else 0 print(f" {cat:25s} Base: {b_acc:.1%} SFT: {s_acc:.1%} Δ: {(s_acc-b_acc)*100:+.1f}%") print(f"\n Report: {report_file}") print(f"{'='*60}") if __name__ == '__main__': main()