# Source: rl4phyx-backup / root_scripts/analyze_gt.py
# Uploaded by YUNTA88 with huggingface_hub (commit 0c7080e, verified).
import json, re, sys
# Unit tables and parser mirrored from the physics_reward parse logic, so this
# analysis classifies ground truths the same way that parser does.
# NOTE(review): this is a copy — keep it in sync with physics_reward by hand.

# Single-character prefixes that may precede a base/compound unit, mapped to
# their multiplier ('u' is the ASCII stand-in for micro).
UNIT_PREFIXES = dict(
    T=1e12, G=1e9, M=1e6, k=1e3, h=1e2,
    c=1e-2, m=1e-3, u=1e-6, n=1e-9, p=1e-12,
)

# Unit symbols accepted verbatim (case-sensitive exact matches).
BASE_UNITS = set(
    "N Pa J W Hz m s g kg m/s m/s^2 V A C F H T Wb K mol cd rad sr dB eV".split()
)
COMPOUND_UNITS = set(
    "km/h km/s cm/s mm/s kg/m^3 g/cm^3 N/m N/m^2 J/s W/m^2 rad/s rpm m/s^2".split()
)

# Leading signed decimal (optional exponent), optional whitespace, then the
# remainder of the string as the unit candidate.
_NUMBER_THEN_UNIT = re.compile(r'^([+-]?\d+\.?\d*(?:[eE][+-]?\d+)?)\s*(.*)$')


def _is_known_unit(unit):
    """Return True if *unit* appears verbatim in either unit table."""
    return unit in BASE_UNITS or unit in COMPOUND_UNITS


def parse_gt(text):
    """Split a ground-truth string into ``(value, unit, status)``.

    status is one of:
      * "no_number"    -> (None, None): the string does not start with a number.
      * "number_only"  -> (value, None): a number with nothing after it.
      * "ok"           -> (value, unit): the unit is listed verbatim.
      * "ok_prefixed"  -> (scaled value, base unit): a one-character prefix
        from UNIT_PREFIXES followed by a listed unit, e.g. "5 kN" -> 5000.0, "N".
      * "unknown_unit" -> (value, raw unit text): anything else after the number.

    Visible limitations of the pattern: numbers without a leading digit
    (".5 m") give "no_number"; text spanning several lines may also give
    "no_number" because `.` does not cross newlines; "1,000 N" parses as 1
    with unit ",000 N"; symbols such as "µm" or "°C" are "unknown_unit".
    """
    match = _NUMBER_THEN_UNIT.match(text.strip())
    if match is None:
        return None, None, "no_number"

    number, unit = match.groups()
    value = float(number)
    unit = unit.strip()
    if unit == "":
        return value, None, "number_only"
    if _is_known_unit(unit):
        return value, unit, "ok"

    # Only try prefix stripping when something remains after the prefix.
    prefix, base = unit[:1], unit[1:]
    if base and prefix in UNIT_PREFIXES and _is_known_unit(base):
        return value * UNIT_PREFIXES[prefix], base, "ok_prefixed"
    return value, unit, "unknown_unit"
# Load the evaluation records (one JSON object per non-blank line). The path
# defaults to the base-model inference results and can be overridden:
#     python analyze_gt.py path/to/inference_results.jsonl
DEFAULT_PATH = "/workspace/rl4phyx/RL4Phyx/SFT/sft_eval_footprint/inference_results_base.jsonl"
MAX_SHOWN = 30  # distinct hard-case GT strings printed per status

path = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_PATH
with open(path, encoding="utf-8") as fh:
    lines = [json.loads(line) for line in fh if line.strip()]

# Bucket every ground truth by its parse_gt status, and keep the ones parse_gt
# could not handle (no leading number, or an unrecognised unit) for inspection.
categories = {}
hard_cases = []
for r in lines:
    gt = str(r.get("ground_truth_value", "")).strip()
    # str(): a JSON null category would otherwise make the ":20s" format
    # spec below raise TypeError.
    cat = str(r.get("category", "unknown"))
    _, _, status = parse_gt(gt)
    categories.setdefault(status, []).append(gt)
    if status in ("no_number", "unknown_unit"):
        hard_cases.append({"gt": gt, "category": cat, "status": status})

# Header reports the actual record count rather than a hard-coded total.
print(f"=== GT Format Analysis ({len(lines)} questions) ===\n")
for status, items in sorted(categories.items(), key=lambda x: -len(x[1])):
    print(f" {status:20s}: {len(items):4d} ({len(items)/len(lines)*100:.1f}%)")

print(f"\n=== Hard Cases ({len(hard_cases)} total) ===\n")
# Group hard cases by status and list distinct GT strings (first occurrence,
# in file order) so one frequently repeated answer does not crowd out others.
for status_type in ["no_number", "unknown_unit"]:
    cases = [c for c in hard_cases if c["status"] == status_type]
    if not cases:
        continue
    print(f"--- {status_type} ({len(cases)}) ---")
    first_seen = {}
    for c in cases:
        first_seen.setdefault(c["gt"], c)
    unique_cases = list(first_seen.values())
    for c in unique_cases[:MAX_SHOWN]:
        print(f" [{c['category']:20s}] {c['gt'][:80]}")
    # Count what was actually hidden: distinct values beyond the shown ones,
    # not raw cases (which may repeat the same GT many times).
    remaining = len(unique_cases) - MAX_SHOWN
    if remaining > 0:
        print(f" ... and {remaining} more distinct values")
    print()