# Source listing metadata (from the repository viewer): 1,912 bytes, commit 661c54a.
# score_results.py
import argparse
import json
import os
import re
from typing import Any, Dict, List
def normalize(s: str) -> str:
    """Canonicalize text before comparing a prediction with its ground truth.

    Every triple-backtick fence marker becomes a space, the text is lowercased,
    and each run of whitespace collapses to a single space. Leading and
    trailing whitespace is dropped, so indentation/alignment differences do
    not affect the comparison.
    """
    defenced = s.replace("```", " ")
    # str.split() with no separator splits on runs of whitespace and discards
    # empty leading/trailing pieces, so joining with " " both trims and
    # collapses in one step.
    words = defenced.lower().split()
    return " ".join(words)
# Supported comparison rules for --match_mode; "pred_in_gt" is the historical default.
MATCH_MODES = ("pred_in_gt", "gt_in_pred", "exact")


def _as_text(value: Any) -> str:
    """Coerce a JSON field to text: ``None`` becomes ``""``, anything else uses ``str()``."""
    return "" if value is None else str(value)


def _is_match(npred: str, ngt: str, mode: str = "pred_in_gt") -> bool:
    """Decide whether a normalized prediction matches a normalized ground truth.

    Args:
        npred: Normalized model output.
        ngt: Normalized ground truth (callers only pass non-empty values).
        mode: One of ``MATCH_MODES``:
            - ``"pred_in_gt"``: the prediction occurs as a substring of the
              ground truth (the script's original rule).
            - ``"gt_in_pred"``: the ground truth occurs inside the prediction.
            - ``"exact"``: both normalized strings are identical.

    Returns:
        True on a match. An empty prediction never matches: Python treats
        ``""`` as a substring of every string, so without this guard a blank
        model output would be scored as correct for every sample.

    Raises:
        ValueError: if *mode* is not one of ``MATCH_MODES``.
    """
    if not npred:
        return False
    if mode == "exact":
        return npred == ngt
    if mode == "gt_in_pred":
        return ngt in npred
    if mode == "pred_in_gt":
        return npred in ngt
    raise ValueError(f"unknown match mode: {mode!r}")


def _score_items(preds: List[Dict[str, Any]], mode: str = "pred_in_gt"):
    """Score every prediction record.

    Args:
        preds: Records read from the eval output; each may carry ``id``,
            ``ground_truth`` and ``model_output``.
        mode: Comparison rule forwarded to ``_is_match``.

    Returns:
        ``(rows, hit, total)``: per-record detail dicts, the number of matches,
        and the number of records that had a usable ground truth.
    """
    rows: List[Dict[str, Any]] = []
    hit, total = 0, 0
    for item in preds:
        gt = item.get("ground_truth", "")
        pred = item.get("model_output", "")
        ngt = normalize(_as_text(gt))
        # Only samples with a ground truth that has real content are scored.
        # None, "", and whitespace/fence-only strings normalize to "" and
        # cannot be compared meaningfully.
        if not ngt:
            rows.append({
                "id": item.get("id"),
                "match": None,
                "reason": "missing_ground_truth",
                "ground_truth": gt,
                "model_output": pred,
            })
            continue
        total += 1
        # model_output may be null in the JSON; treat that as an empty answer.
        match = _is_match(normalize(_as_text(pred)), ngt, mode)
        if match:
            hit += 1
        rows.append({
            "id": item.get("id"),
            "match": match,
            "ground_truth": gt,
            "model_output": pred,
        })
    return rows, hit, total


def main(argv=None) -> None:
    """Command-line entry point: score an eval prediction file and write the details.

    Reads a JSON list of records (``id``, ``ground_truth``, ``model_output``)
    from ``--pred_path``. Writes ``{"summary": ..., "details": ...}`` to
    ``--out_path``, creating its parent directory if needed, and prints a
    one-line summary.

    Args:
        argv: Argument list to parse; ``None`` (the default) uses ``sys.argv``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--pred_path", type=str, required=True, help="eval 的输出 JSON")
    ap.add_argument("--out_path", type=str, default="./valid_clean/valid.json", help="评分明细输出 JSON")
    ap.add_argument(
        "--match_mode",
        choices=MATCH_MODES,
        default="pred_in_gt",
        help="how the normalized output is compared with the ground truth (default: pred_in_gt)",
    )
    args = ap.parse_args(argv)

    with open(args.pred_path, "r", encoding="utf-8") as f:
        preds = json.load(f)
    if not isinstance(preds, list) or not all(isinstance(item, dict) for item in preds):
        ap.error(f"{args.pred_path}: expected a JSON list of prediction objects")

    rows, hit, total = _score_items(preds, args.match_mode)
    summary = {
        "total_with_gt": total,
        "matched": hit,
        "accuracy": (hit / total) if total > 0 else None,
    }
    out = {"summary": summary, "details": rows}

    # The default output path lives in ./valid_clean/, which may not exist yet.
    out_dir = os.path.dirname(args.out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.out_path, "w", encoding="utf-8") as f:
        json.dump(out, f, ensure_ascii=False, indent=2)

    if total:
        print(f"[SUMMARY] matched {hit}/{total} = {summary['accuracy']:.4f}")
    else:
        print("[SUMMARY] no GT")
if __name__ == "__main__":
    # Script entry point. main() returns None, so this exits with status 0
    # on success; errors still propagate as exceptions.
    raise SystemExit(main())
# End of listing.