File size: 4,085 Bytes
30fb9c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1eb493c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30fb9c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1eb493c
30fb9c5
 
 
 
 
 
 
 
 
 
 
 
 
 
1eb493c
 
30fb9c5
 
 
 
 
1eb493c
 
30fb9c5
 
 
 
 
 
1eb493c
30fb9c5
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#!/usr/bin/env python3
"""
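Patch the traj_summary_orig_ext (orig-analysis) HF dataset to add
question/correct_answer/correct by joining with eval result files on query_id.
When an eval file has no question text, the question is taken from the
BrowseComp JSONL instead.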

Dataset: timchen0618/browsecomp-plus-selected-tools-orig-analysis-v1  (826 rows)
Eval dir: evals/bcp/Qwen3-Embedding-8B/full/gpt-oss-120b/
          traj_summary_orig_ext_selected_tools_gpt-oss-120b_seed0  (832 eval files)

Python env: /scratch/hc3337/envs/raca-py312/bin/python
"""
from __future__ import annotations
import argparse, json, sys, os
from pathlib import Path

# Default the Hugging Face cache location. setdefault leaves an HF_HOME that
# the caller already exported untouched.
os.environ.setdefault("HF_HOME", "/scratch/hc3337/.cache/huggingface")

# Hub dataset repo that main() loads (train split) and pushes the patched rows back to.
REPO = "timchen0618/browsecomp-plus-selected-tools-orig-analysis-v1"
# Directory scanned for per-query *_eval.json judge results.
EVAL_DIR = Path("/scratch/hc3337/projects/BrowseComp-Plus/evals/bcp/Qwen3-Embedding-8B/full/gpt-oss-120b/traj_summary_orig_ext_selected_tools_gpt-oss-120b_seed0")
# BrowseComp queries JSONL, used as the question fallback when an eval file has none.
BC_JSONL = Path("/scratch/hc3337/projects/BrowseComp-Plus/data/browsecomp_plus_decrypted_test300.jsonl")


def load_browsecomp_questions(jsonl_path: Path) -> dict:
    """Build a ``query_id -> question text`` map from a BrowseComp JSONL file.

    Each non-blank line is parsed as a JSON object. The question text comes
    from ``query``, falling back to ``question``. Query ids made only of
    digits are converted to ``int`` so they join with the ids used elsewhere
    in this script; other ids are kept as stripped strings.

    Args:
        jsonl_path: Path to the JSONL file.

    Returns:
        Mapping of normalized query id to question text. The map is empty
        (with a warning) when the file does not exist. Lines that fail to
        parse, or that lack an id or a question, are skipped.
    """
    qmap: dict = {}
    if not jsonl_path.exists():
        # The file is only a fallback source, so a missing file is not fatal.
        # Say so anyway; otherwise rows silently end up without a question.
        print(f"warning: {jsonl_path} not found; no fallback questions loaded",
              file=sys.stderr)
        return qmap
    with jsonl_path.open("r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                d = json.loads(line)
                qid_raw = str(d.get("query_id", "")).strip()
                qid = int(qid_raw) if qid_raw.isdigit() else qid_raw
                q = d.get("query") or d.get("question") or ""
            except (ValueError, AttributeError) as e:
                # ValueError covers bad JSON and int() on odd digit strings;
                # AttributeError covers lines that are JSON but not objects.
                print(f"warning: skipping line {lineno}: {e}", file=sys.stderr)
                continue
            # Compare with "" explicitly: a truthiness test would drop the
            # legitimate integer id 0.
            if qid != "" and q:
                qmap[qid] = q
    print(f"Loaded {len(qmap)} questions from {jsonl_path}", file=sys.stderr)
    return qmap


def _coerce_correct(value):
    """Interpret a judge ``correct`` field as ``bool``, or ``None`` if absent.

    ``bool("false")`` is ``True`` in Python, so string verdicts are matched
    against known spellings instead of being passed straight to ``bool()``.
    NOTE(review): the judge output format is not visible in this script;
    strings are handled defensively in case a verdict was stored as text.
    Any other value falls back to ``bool(value)``.
    """
    if value is None:
        return None
    if isinstance(value, str):
        text = value.strip().lower()
        if text in {"true", "yes", "correct", "1"}:
            return True
        if text in {"false", "no", "incorrect", "0", ""}:
            return False
    return bool(value)


def load_eval_data(eval_dir: Path) -> dict:
    """Load per-query judge results from ``*_eval.json`` files in *eval_dir*.

    Args:
        eval_dir: Directory containing the ``*_eval.json`` files.

    Returns:
        Mapping of normalized query id to a dict. The id is an ``int`` when
        it is all digits, otherwise the stripped string. Each dict has:

        - ``question`` and ``correct_answer``: strings, ``""`` when absent.
        - ``correct``: a ``bool``, or ``None`` when the file has no judge
          verdict.

        Files that cannot be read or parsed are skipped with a warning. When
        several files share a query id, the last one in sorted filename order
        wins, and a warning is printed.
    """
    eval_map: dict = {}
    # Sort so the surviving entry for a duplicated id does not depend on
    # directory listing order.
    for p in sorted(eval_dir.glob("*_eval.json")):
        try:
            with p.open("r", encoding="utf-8") as f:
                d = json.load(f)
            qid_raw = str(d.get("query_id", "")).strip()
            qid = int(qid_raw) if qid_raw.isdigit() else qid_raw
            jr = d.get("judge_result") or {}
            correct = _coerce_correct(jr.get("correct"))
        except (OSError, ValueError, AttributeError) as e:
            # OSError: unreadable file; ValueError: bad JSON/encoding or id;
            # AttributeError: JSON that is not an object where one is expected.
            print(f"warning: skipping {p.name}: {e}", file=sys.stderr)
            continue
        if qid in eval_map:
            print(f"warning: duplicate query_id {qid!r} in {p.name}; "
                  f"replacing earlier entry", file=sys.stderr)
        eval_map[qid] = {
            "question": str(d.get("question") or ""),
            "correct_answer": str(d.get("correct_answer") or ""),
            "correct": correct,
        }
    print(f"Loaded {len(eval_map)} eval entries from {eval_dir}", file=sys.stderr)
    return eval_map


def main(argv: list[str] | None = None) -> None:
    """Join eval results into the HF dataset and push the patched split.

    For every row of ``REPO``'s ``train`` split, this fills ``question``,
    ``correct_answer`` and ``correct`` from the eval file with the same
    ``query_id``. Values are chosen in this order:

    - ``question``: the eval file, then the BrowseComp JSONL, then the value
      already in the row.
    - ``correct_answer`` and ``correct``: the eval file, then the value
      already in the row.

    This means re-running the patch never blanks out data written earlier.

    Args:
        argv: Command-line arguments (defaults to ``sys.argv[1:]``).
            ``--dry-run`` reports the join statistics without pushing.

    Raises:
        SystemExit: If no dataset row matched any eval entry. Nothing is
            pushed in that case.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="report the join statistics without pushing to the Hub",
    )
    args = parser.parse_args(argv)

    from datasets import load_dataset, Dataset

    eval_map = load_eval_data(EVAL_DIR)
    bc_questions = load_browsecomp_questions(BC_JSONL)

    print(f"Loading {REPO}...", file=sys.stderr)
    ds = load_dataset(REPO, split="train")
    print(f"Loaded {len(ds)} rows. Columns: {ds.column_names}", file=sys.stderr)

    rows = []
    matched = 0
    correct_count = 0
    for row in ds:
        # Same id normalization as the loaders, so "123" and 123 join.
        qid_raw = str(row["query_id"]).strip()
        qid = int(qid_raw) if qid_raw.isdigit() else qid_raw
        ev = eval_map.get(qid, {})
        if ev:
            matched += 1
            # Count accuracy from eval data only, not from preserved row values.
            if ev.get("correct") is True:
                correct_count += 1
        r = dict(row)
        # Existing row values are the last resort. They are non-empty only if
        # the patch ran before, and keeping them stops unmatched rows from
        # being blanked out on a re-run.
        r["question"] = (
            ev.get("question") or bc_questions.get(qid) or r.get("question") or ""
        )
        r["correct_answer"] = ev.get("correct_answer") or r.get("correct_answer") or ""
        correct = ev.get("correct")
        r["correct"] = correct if correct is not None else r.get("correct")
        rows.append(r)

    print(f"Matched {matched}/{len(rows)} rows with eval data.", file=sys.stderr)
    no_question = sum(1 for r in rows if not r.get("question"))
    print(f"Rows missing question: {no_question}", file=sys.stderr)
    if not matched:
        # Most likely a wrong EVAL_DIR or an id-format mismatch. Pushing would
        # rewrite the Hub dataset without adding anything from the eval files.
        sys.exit(f"error: no dataset rows matched eval data in {EVAL_DIR}; not pushing.")
    print(f"Accuracy: {correct_count}/{matched} ({correct_count / matched:.1%})",
          file=sys.stderr)

    if args.dry_run:
        print("Dry run: skipping push_to_hub.", file=sys.stderr)
        return

    ds_new = Dataset.from_list(rows)
    ds_new.push_to_hub(REPO, split="train",
                       commit_message="Fix missing questions via BrowseComp JSONL fallback")
    print(f"Pushed {len(rows)} rows to {REPO}.")


if __name__ == "__main__":
    # Run the patch only when executed as a script, not when imported.
    main()