| |
| """Left-join score columns onto a base JSONL by `id`.""" |
|
|
| import argparse |
| import json |
| import os |
| from typing import Dict |
|
|
|
|
def load_by_id(path: str) -> Dict[str, dict]:
    """Read a JSONL file into a dict keyed by each record's ``id`` value.

    Lines that are empty or whitespace-only are skipped. When the same ``id``
    appears more than once, the last record in the file wins.

    Args:
        path: Path of the JSONL file to read.

    Returns:
        Mapping from ``record["id"]`` to the full decoded record.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``"id"`` key.
    """
    records: Dict[str, dict] = {}
    # "utf-8-sig" reads plain UTF-8 unchanged but also drops a leading BOM
    # (commonly added by Windows editors). With plain "utf-8" the BOM stays
    # glued to the first line -- str.strip() does not remove it -- and
    # json.loads() rejects that line.
    with open(path, encoding="utf-8-sig") as f:
        for raw_line in f:
            text = raw_line.strip()
            if not text:
                continue
            record = json.loads(text)
            records[record["id"]] = record
    return records
|
|
|
|
def _require_file(path: str, flag: str, optional: bool = True) -> None:
    """Exit with a readable message unless *path* is an existing regular file."""
    if os.path.isfile(path):
        return
    message = f"{flag} file not found: {path!r}\n"
    if optional:
        message += (
            f" Fix the path, or omit {flag} until you have that scores JSONL "
            f"(e.g. merge only Ada: --base ... --ada ... -o ...)."
        )
    else:
        # Suggesting "omit --base" would be wrong: that flag is mandatory.
        message += f" Fix the path; {flag} is required."
    raise SystemExit(message)


def _merge_line(line: str, score_maps: Dict[str, Dict[str, dict]]) -> str:
    """Decode one base JSONL line, copy in non-null scores, re-encode it."""
    row = json.loads(line.strip())
    row_id = row["id"]
    for column, by_id in score_maps.items():
        if row_id not in by_id:
            continue
        score = by_id[row_id].get(column)
        # A missing or null score never overwrites what the base row has.
        if score is not None:
            row[column] = score
    return json.dumps(row, ensure_ascii=False) + "\n"


def main(argv=None):
    """Left-join ``ada_score`` / ``l2d_score`` / ``sup_score`` onto a base JSONL.

    Every non-blank row of ``--base`` is written to ``--output`` in its
    original order. For each score file given, the row whose ``id`` matches
    contributes its score column when that value is not null. Base rows with
    no match are written unchanged.

    Args:
        argv: Argument list to parse; ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Raises:
        SystemExit: If ``--base`` or a supplied score file does not exist,
            or on an argparse usage error.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a row has no ``"id"`` key.
    """
    p = argparse.ArgumentParser(
        epilog="Omit --l2d until you have a real JSONL from L2D (documentation paths like path\\\\to\\\\... are not valid).",
    )
    p.add_argument("--base", required=True, help="Base JSONL (must include id, label, text, ...)")
    p.add_argument("-o", "--output", required=True)
    p.add_argument("--ada", help="JSONL with ada_score")
    p.add_argument("--l2d", help="JSONL with l2d_score")
    p.add_argument("--sup", help="JSONL with sup_score")
    args = p.parse_args(argv)

    # (score column, CLI flag, path) in the order the columns are merged.
    sources = (
        ("ada_score", "--ada", args.ada),
        ("l2d_score", "--l2d", args.l2d),
        ("sup_score", "--sup", args.sup),
    )

    # Check every path before loading anything so a typo fails fast.
    _require_file(args.base, "--base", optional=False)
    for _column, flag, path in sources:
        if path:
            _require_file(path, flag)

    score_maps: Dict[str, Dict[str, dict]] = {
        column: load_by_id(path) for column, _flag, path in sources if path
    }

    # samefile() compares the files themselves, so symlinks, hard links and
    # differently spelled paths to the base are all recognised. An output
    # that does not exist yet cannot be the (existing) base file.
    in_place = os.path.exists(args.output) and os.path.samefile(args.base, args.output)

    # "utf-8-sig" reads plain UTF-8 unchanged and drops a leading BOM, which
    # would otherwise make json.loads() reject the first row.
    if in_place:
        # Opening the output truncates the base, so read and merge every row
        # first; a malformed row then aborts before anything is overwritten.
        with open(args.base, encoding="utf-8-sig") as fin:
            merged = [_merge_line(line, score_maps) for line in fin if line.strip()]
        with open(args.output, "w", encoding="utf-8") as fout:
            fout.writelines(merged)
    else:
        # Distinct files: stream row by row to keep memory flat.
        with open(args.base, encoding="utf-8-sig") as fin:
            with open(args.output, "w", encoding="utf-8") as fout:
                fout.writelines(
                    _merge_line(line, score_maps) for line in fin if line.strip()
                )

    print(f"Wrote {args.output}")
|
|
|
|
# Script entry point: run the merge only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|