|
|
| """
|
| Phase 2: Score open-ended inference results.
|
|
|
| Two-stage scoring (adapted from eval_dual.py + MetaPhyX DeepSeek judge):
|
| Stage 1: Rule-based (boxed extraction + normalization + numeric tolerance)
|
| - If CORRECT → done, count as correct
|
| - If WRONG or UNCERTAIN → go to Stage 2
|
| Stage 2: Gemini 2.5 Flash LLM-as-Judge
|
| - Sends model's full response + ground truth to Gemini
|
| - Gemini determines [[YES]] or [[NO]] equivalence
|
|
|
| Usage:
|
| python eval_openended_judge.py [--results_dir PATH] [--api_key KEY]
|
|
|
| Inputs:
|
| inference_results_base.jsonl
|
| inference_results_sft.jsonl
|
|
|
| Outputs:
|
| scored_results_base.jsonl
|
| scored_results_sft.jsonl
|
| comparison_report.json
|
| """
|
| import json, os, re, time, sys, argparse
|
| from collections import defaultdict, Counter
|
|
|
|
|
# Gemini judge configuration.
# SECURITY: never hard-code API keys in source (the previous default here was a
# leaked live key and must be rotated).  Supply the key via the GEMINI_API_KEY
# environment variable or the --api_key CLI flag.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
GEMINI_MODEL = "gemini-2.5-flash"
MAX_RETRIES = 3          # attempts per Gemini API call
RATE_LIMIT_DELAY = 0.5   # seconds to sleep after each judge call
|
|
|
|
|
|
|
|
|
| def _strip_string(string):
|
| """Normalize math string: remove LaTeX formatting, units, whitespace."""
|
| string = string.replace("\n", "")
|
| string = string.replace("\\!", "")
|
| string = string.replace("\\\\", "\\")
|
| string = string.replace("tfrac", "frac")
|
| string = string.replace("dfrac", "frac")
|
| string = string.replace("\\left", "")
|
| string = string.replace("\\right", "")
|
| string = string.replace("^{\\circ}", "")
|
| string = string.replace("^\\circ", "")
|
| string = string.replace("\\$", "")
|
| if "\\text{ " in string:
|
| splits = string.split("\\text{ ")
|
| if len(splits) == 2:
|
| string = splits[0]
|
| string = string.replace("\\%", "")
|
| string = string.replace(" .", " 0.")
|
| string = string.replace("{.", "{0.")
|
| if len(string) == 0:
|
| return string
|
| if string[0] == ".":
|
| string = "0" + string
|
| if len(string.split("=")) == 2:
|
| if len(string.split("=")[0]) <= 2:
|
| string = string.split("=")[1]
|
| string = string.replace(" ", "")
|
| return string
|
|
|
|
|
| def _normalize(expr):
|
| """Normalize answer expression for comparison."""
|
| if expr is None:
|
| return None
|
| m = re.search("^\\\\text\\{(?P<text>.+?)\\}$", expr)
|
| if m is not None:
|
| expr = m.group("text")
|
| expr = expr.replace("\\%", "%")
|
| expr = expr.replace("\\$", "$")
|
| expr = expr.replace("$", "")
|
| expr = expr.replace("%", "")
|
| expr = expr.replace(" or ", " , ")
|
| expr = expr.replace(" and ", " , ")
|
| for unit in ["degree", "cm", "centimeter", "meter", "mile", "second", "minute",
|
| "hour", "day", "week", "month", "year", "foot", "feet", "inch", "yard",
|
| "newton", "joule", "watt", "ampere", "volt", "ohm", "hertz",
|
| "kilogram", "gram", "liter", "mole", "kelvin", "pascal",
|
| "m/s", "km/h", "rad/s", "N", "J", "W", "A", "V", "Hz", "Pa", "kg", "mol"]:
|
| expr = re.sub(f"\\s*{re.escape(unit)}(es)?(s)?\\s*(\\^[0-9]+)?", "", expr, flags=re.IGNORECASE)
|
| if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}":
|
| expr = expr[1:-1]
|
| try:
|
| if "." in expr:
|
| val = float(expr)
|
| if abs(val - int(round(val))) <= 1e-7:
|
| expr = str(int(round(val)))
|
| except:
|
| pass
|
| expr = re.sub("- *", "-", expr)
|
| expr = expr.replace(" ", "")
|
| expr = expr.replace("{", "")
|
| expr = expr.replace("}", "")
|
| expr = expr.lower()
|
| return expr
|
|
|
|
|
def extract_boxed_answer(text):
    """Extract the content of the last ``\\boxed{...}`` (or ``\\fbox{...}``).

    Scans for the last ``\\boxed`` (preferred) or ``\\fbox`` marker, then
    matches braces to find the balanced argument.

    Args:
        text: Raw model output.

    Returns:
        The inner content string, or None if no balanced box is found.
    """
    marker = "\\boxed"
    idx = text.rfind(marker)
    if idx < 0:
        marker = "\\fbox"
        idx = text.rfind(marker)
        if idx < 0:
            return None
    # Walk forward from the marker, tracking brace depth to find the
    # matching closing brace of the box's argument.
    depth = 0
    right_idx = None
    for i in range(idx, len(text)):
        ch = text[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                right_idx = i
                break
    if right_idx is None:
        return None
    boxed = text[idx:right_idx + 1]
    # BUG FIX: the old code only accepted the "\boxed{" prefix, so \fbox
    # matches were found but always returned None.
    left = marker + "{"
    if boxed.startswith(left) and boxed.endswith("}"):
        return boxed[len(left):-1]
    return None
|
|
|
|
|
def extract_answer_from_text(text):
    """Pull a candidate final answer out of a model response.

    Priority order: strip chain-of-thought after ``</think>``, then a
    ``\\boxed{}`` value, then conversational "the answer is ..." patterns
    (English and Chinese).  Returns None when nothing plausible is found.
    """
    # Keep only the text after the chain-of-thought block, if tagged.
    if '<think>' in text and '</think>' in text:
        text = text.split('</think>')[-1]

    boxed = extract_boxed_answer(text)
    if boxed:
        return boxed

    # Conversational answer markers; capture up to the next period or EOL.
    answer_patterns = (
        r'(?:the answer is|answer is|答案是|答案为)[:\s]*(.+?)(?:\.|$)',
        r'(?:therefore|thus|so|hence)[,\s]+(?:the answer is\s+)?(.+?)(?:\.|$)',
    )
    for pattern in answer_patterns:
        match = re.search(pattern, text, re.IGNORECASE)
        if not match:
            continue
        candidate = match.group(1).strip()
        # Guard against capturing a whole sentence or paragraph.
        if len(candidate) < 100:
            return candidate

    return None
|
|
|
|
|
def rule_based_score(prediction, ground_truth):
    """Rule-based scoring: extract answer + normalize + compare.

    Pipeline: extract a candidate answer from *prediction*, normalize both
    sides, then try exact match, numeric match (abs tol 1e-6, then 1%
    relative), and finally verbatim containment for short ground truths.

    Returns:
        (is_correct: bool, reason: str) — *reason* names the matching rule
        or failure mode, recorded in the scored JSONL for inspection.
    """
    model_answer = extract_answer_from_text(prediction)
    if model_answer is None:
        return False, "no_answer_extracted"

    gt_norm = _normalize(ground_truth)
    pred_norm = _normalize(model_answer)

    if gt_norm is None or pred_norm is None:
        return False, "normalize_failed"

    if gt_norm == pred_norm:
        return True, "exact_match"

    # Numeric comparison: absolute tolerance, then 1% relative tolerance.
    try:
        gt_float = float(gt_norm.replace(",", ""))
        pred_float = float(pred_norm.replace(",", ""))
        if abs(gt_float - pred_float) < 1e-6:
            return True, "numeric_match"
        if gt_float != 0 and abs((gt_float - pred_float) / gt_float) < 0.01:
            return True, "numeric_close"
    except ValueError:
        # Not parseable as numbers; fall through to containment.
        # Was a bare `except:` — narrowed so real errors aren't swallowed.
        pass

    # Short ground truths (e.g. "III", "B") may appear verbatim in the text.
    # Skip empty strings: re.escape("") yields a pattern of just r"\b\b",
    # which matches at any word boundary and faked a containment match.
    gt_clean = ground_truth.strip()
    if gt_clean and len(gt_clean) <= 10:
        if re.search(r'\b' + re.escape(gt_clean) + r'\b', prediction, re.IGNORECASE):
            return True, "containment_match"

    return False, f"no_match(pred={pred_norm[:30]},gt={gt_norm[:30]})"
|
|
|
|
|
|
|
|
|
|
|
# System prompt for the Stage-2 LLM judge (ORM = outcome reward model).
# The judge is told to end its reply with the literal token [[YES]] or
# [[NO]], which gemini_judge() parses to get the verdict.
ORM_PROMPT = """You are an expert in verifying if two physics answers are the same.
Your input is a physics question prompt and two answers:
- Answer 1: the model's prediction
- Answer 2: the ground truth answer

Determine if they are equivalent.

Guidelines for equivalence:
- Different forms of the same number (0.5 = 1/2 = 50%)
- Same physical quantity with different units or notation (7.55N = 7.55 N = 7.55 newtons)
- Semantically equivalent descriptions ("point III" and "III", "decreasing" and "the velocity is decreasing")
- Algebraically equivalent expressions (x+1)^2 = x^2+2x+1
- Same choice letter or option name
- Correct numerical value even if formatting differs
- Minor rounding differences within 2% are acceptable

Your output must follow this format:
1) Brief explanation for why the answers are equivalent or not.
2) Final answer: [[YES]] or [[NO]]
"""
|
|
|
|
|
def call_gemini(prompt, api_key):
    """Call the Gemini ``generateContent`` REST API and return the reply text.

    Uses urllib from the standard library (no external deps).  Retries up
    to MAX_RETRIES times: HTTP 429 gets a linear backoff (5s, 10s, 15s),
    other failures a flat 2s pause.

    Args:
        prompt: Full prompt text (judge instructions + user message).
        api_key: Gemini API key.

    Returns:
        The stripped response text, or None when all attempts failed.
    """
    import urllib.request, urllib.error

    # NOTE(review): the key travels as a URL query parameter, so it can end
    # up in proxy/server logs; Gemini also accepts an "x-goog-api-key"
    # request header — consider switching.
    url = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_MODEL}:generateContent?key={api_key}"
    payload = json.dumps({
        "contents": [{"parts": [{"text": prompt}]}],
        "generationConfig": {
            "temperature": 0.0,        # deterministic judging
            "maxOutputTokens": 512,
        }
    }).encode('utf-8')

    req = urllib.request.Request(
        url, data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )

    for attempt in range(MAX_RETRIES):
        try:
            with urllib.request.urlopen(req, timeout=30) as resp:
                result = json.loads(resp.read().decode('utf-8'))
                # Expected shape: candidates[0].content.parts[0].text; a
                # missing key falls into the generic handler below.
                text = result['candidates'][0]['content']['parts'][0]['text']
                return text.strip()
        except urllib.error.HTTPError as e:
            if e.code == 429:
                # Rate limited: back off progressively, don't burn an attempt return.
                wait = (attempt + 1) * 5
                print(f" Rate limited, waiting {wait}s...")
                time.sleep(wait)
            else:
                print(f" HTTP error {e.code}")
                if attempt == MAX_RETRIES - 1:
                    return None
                time.sleep(2)
        except Exception as e:
            # Network errors, timeouts, malformed/blocked responses.
            print(f" Error: {e}")
            if attempt == MAX_RETRIES - 1:
                return None
            time.sleep(2)
    return None
|
|
|
|
|
def gemini_judge(prediction, ground_truth, api_key):
    """Ask Gemini whether the model's prediction matches the ground truth.

    Sends the (2000-char-truncated) model response plus the ground truth
    to the judge and parses the [[YES]]/[[NO]] verdict.

    Returns:
        (is_correct: bool, reason: str) — *reason* is "api_error" or a
        200-char excerpt of the judge's explanation.
    """
    user_msg = f"""
Model's full response (contains reasoning and answer):
{prediction[:2000]}

Ground truth answer: {ground_truth}
"""
    response = call_gemini(ORM_PROMPT + "\n\n" + user_msg, api_key)

    if response is None:
        return False, "api_error"

    if "[[YES]]" in response:
        return True, response[:200]
    if "[[NO]]" in response:
        return False, response[:200]

    # Fallback when the judge omitted the bracketed verdict.  Match whole
    # words: the old substring test (`"no" not in lower`) also hit words
    # like "note"/"not" and wrongly vetoed affirmative answers.
    lower = response.lower()
    if re.search(r"\byes\b", lower) and not re.search(r"\bno\b", lower):
        return True, response[:200]
    return False, response[:200]
|
|
|
|
|
|
|
|
|
def score_model(results, model_name, api_key, output_file):
    """Score all results with the two-stage pipeline and write scored JSONL.

    Stage 1: rule_based_score() — if it says correct, the sample is done.
    Stage 2: otherwise gemini_judge() gets the final word (one API call
    plus RATE_LIMIT_DELAY of sleep per sample).

    Each dict in *results* is annotated in place (rule_match, rule_reason,
    gemini_* fields, final_correct, final_method) and dumped one-per-line
    to *output_file*.

    Returns:
        Summary dict: totals, per-stage counts, and per-category stats.
    """
    print(f"\n{'='*60}")
    print(f" Scoring: {model_name} ({len(results)} samples)")
    print(f"{'='*60}")

    rule_correct = 0               # stage-1 hits (no judge call needed)
    rule_wrong_gemini_correct = 0  # stage-1 miss rescued by the judge
    rule_wrong_gemini_wrong = 0    # both stages say wrong
    gemini_errors = 0              # NOTE(review): never incremented — dead counter
    total = len(results)

    # Per-category breakdown, keyed by each sample's 'category' field.
    cat_stats = defaultdict(lambda: {'total': 0, 'rule_correct': 0, 'gemini_correct': 0, 'final_correct': 0})

    for i, r in enumerate(results):
        cat = r.get('category', 'Unknown')
        pred = r.get('model_output', '')
        gt = r.get('ground_truth_value', '')
        cat_stats[cat]['total'] += 1

        # Stage 1: cheap rule-based check.
        rule_match, rule_reason = rule_based_score(pred, gt)
        r['rule_match'] = rule_match
        r['rule_reason'] = rule_reason

        if rule_match:
            # Rule-based says correct → accept without calling the judge.
            rule_correct += 1
            cat_stats[cat]['rule_correct'] += 1
            cat_stats[cat]['final_correct'] += 1
            r['final_correct'] = True
            r['final_method'] = f"rule:{rule_reason}"
            r['gemini_called'] = False
        else:
            # Stage 2: LLM-judge fallback.
            r['gemini_called'] = True
            gemini_match, gemini_reason = gemini_judge(pred, gt, api_key)
            r['gemini_match'] = gemini_match
            r['gemini_reason'] = gemini_reason

            if gemini_match:
                rule_wrong_gemini_correct += 1
                cat_stats[cat]['gemini_correct'] += 1
                cat_stats[cat]['final_correct'] += 1
                r['final_correct'] = True
                r['final_method'] = "gemini_override"
            else:
                rule_wrong_gemini_wrong += 1
                r['final_correct'] = False
                r['final_method'] = f"wrong:{rule_reason}"

            # Throttle judge calls to stay under the API rate limit.
            time.sleep(RATE_LIMIT_DELAY)

        # Progress line every 10 samples (and at the very end).
        final_correct_so_far = rule_correct + rule_wrong_gemini_correct
        if (i + 1) % 10 == 0 or (i + 1) == total:
            acc_so_far = final_correct_so_far / (i + 1)
            print(f" [{i+1}/{total}] acc={acc_so_far:.1%} "
                  f"(rule✓={rule_correct} gemini✓={rule_wrong_gemini_correct} ✗={rule_wrong_gemini_wrong})",
                  flush=True)

    # Persist the annotated results as JSONL.
    with open(output_file, 'w', encoding='utf-8') as f:
        for r in results:
            f.write(json.dumps(r, ensure_ascii=False) + '\n')

    final_correct = rule_correct + rule_wrong_gemini_correct
    final_acc = final_correct / total if total > 0 else 0

    # NOTE(review): the percentage prints below divide by `total` without the
    # zero guard used for final_acc — an empty results list would raise here.
    print(f"\n{'─'*60}")
    print(f" {model_name} — RESULTS")
    print(f"{'─'*60}")
    print(f" Rule-based correct : {rule_correct}/{total} ({100*rule_correct/total:.1f}%)")
    print(f" Gemini rescued : {rule_wrong_gemini_correct} (rule wrong → Gemini correct)")
    print(f" Final accuracy : {final_correct}/{total} ({100*final_acc:.1f}%)")
    print(f" Gemini calls made : {rule_wrong_gemini_correct + rule_wrong_gemini_wrong}")
    print(f"\n Per-category:")
    for cat, s in sorted(cat_stats.items()):
        acc = s['final_correct'] / s['total'] if s['total'] > 0 else 0
        print(f" {cat:25s}: {s['final_correct']}/{s['total']} ({acc:.1%})"
              f" [rule={s['rule_correct']}, gemini+={s['gemini_correct']}]")

    return {
        'model': model_name,
        'total': total,
        'rule_correct': rule_correct,
        'gemini_rescued': rule_wrong_gemini_correct,
        'final_correct': final_correct,
        'final_acc': round(100 * final_acc, 2),
        'category_stats': {cat: dict(s) for cat, s in cat_stats.items()},
    }
|
|
|
|
|
def main():
    """CLI entry point: locate inference results, score base & SFT runs, write report."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--results_dir', type=str, default=None)
    parser.add_argument('--api_key', type=str, default=None)
    args = parser.parse_args()

    # CLI flag wins over the environment/module-level default.
    api_key = args.api_key or GEMINI_API_KEY

    # Auto-discover the results directory when not given explicitly:
    # first the script's own directory, then a known workspace path.
    results_dir = args.results_dir
    if results_dir is None:
        for d in [os.path.dirname(os.path.abspath(__file__)),
                  '/workspace/rl4phyx/RL4Phyx/SFT/sft_eval_footprint/']:
            if os.path.exists(os.path.join(d, 'inference_results_base.jsonl')):
                results_dir = d
                break
    if results_dir is None:
        print("ERROR: Cannot find inference results. Use --results_dir")
        sys.exit(1)

    print("=" * 60)
    print(" OPEN-ENDED EVAL: Rule-based + Gemini 2.5 Flash Judge")
    print(f" Results dir: {results_dir}")
    print("=" * 60)

    # NOTE(review): `test_data` is loaded but never used below, and each
    # line is json.loads()-ed twice.  Kept as-is; candidate for removal.
    test_file = os.path.join(results_dir, 'test_1533_openended.jsonl')
    if os.path.exists(test_file):
        with open(test_file, 'r') as f:
            test_data = {json.loads(l)['index']: json.loads(l) for l in f if l.strip()}
        print(f"Test data loaded: {len(test_data)} samples")

    # Score the base model's inference results.
    base_file = os.path.join(results_dir, 'inference_results_base.jsonl')
    with open(base_file, 'r') as f:
        base_results = [json.loads(l) for l in f if l.strip()]
    base_scored_file = os.path.join(results_dir, 'scored_results_base.jsonl')
    base_stats = score_model(base_results, "Qwen2.5-VL-3B (Base)", api_key, base_scored_file)

    # Score the SFT model's inference results.
    sft_file = os.path.join(results_dir, 'inference_results_sft.jsonl')
    with open(sft_file, 'r') as f:
        sft_results = [json.loads(l) for l in f if l.strip()]
    sft_scored_file = os.path.join(results_dir, 'scored_results_sft.jsonl')
    sft_stats = score_model(sft_results, "Qwen2.5-VL-3B (SFT)", api_key, sft_scored_file)

    # Build and persist the side-by-side comparison report.
    delta = sft_stats['final_acc'] - base_stats['final_acc']
    report = {
        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
        'scoring_method': 'rule-based + Gemini 2.5 Flash judge (fallback)',
        'base': base_stats,
        'sft': sft_stats,
        'improvement': f"{delta:+.2f}%",
    }

    report_file = os.path.join(results_dir, 'comparison_report.json')
    with open(report_file, 'w', encoding='utf-8') as f:
        json.dump(report, f, indent=2, ensure_ascii=False)

    print(f"\n{'='*60}")
    print(f" FINAL COMPARISON")
    print(f"{'='*60}")
    print(f" Base accuracy: {base_stats['final_acc']}% ({base_stats['final_correct']}/{base_stats['total']})")
    print(f" SFT accuracy: {sft_stats['final_acc']}% ({sft_stats['final_correct']}/{sft_stats['total']})")
    print(f" Improvement: {delta:+.2f}%")
    print(f"\n Per-category:")
    # Union of categories seen by either model (one may lack a category).
    all_cats = sorted(set(list(base_stats['category_stats'].keys()) + list(sft_stats['category_stats'].keys())))
    for cat in all_cats:
        b = base_stats['category_stats'].get(cat, {'final_correct': 0, 'total': 0})
        s = sft_stats['category_stats'].get(cat, {'final_correct': 0, 'total': 0})
        b_acc = b['final_correct'] / b['total'] if b['total'] > 0 else 0
        s_acc = s['final_correct'] / s['total'] if s['total'] > 0 else 0
        print(f" {cat:25s} Base: {b_acc:.1%} SFT: {s_acc:.1%} Δ: {(s_acc-b_acc)*100:+.1f}%")
    print(f"\n Report: {report_file}")
    print(f"{'='*60}")
|
|
|
# Script entry point: python eval_openended_judge.py [--results_dir PATH] [--api_key KEY]
if __name__ == '__main__':
    main()
|
|
|