Spaces:
Running
Running
| """Build a chronological train / val / test holdout split for photon-route's | |
| relevance set. | |
| Why this exists: the live `relevance.json` is six hand-curated multi-positive | |
| queries. `expand_titles.py` adds 20 single-positive title-as-query entries. | |
| Together that's 26 queries, but the trainer (`space/train.py`) and the eval | |
| (`eval/run.py`) both consume them as a single pool. So the train-on-everything | |
| + test-on-everything cycle silently overfits — the photon-route memory note | |
| captures the symptom exactly: | |
| "6q×20doc eval too small; train==test gives nDCG 0.747, holdout collapses | |
| to 0.071" | |
| This script splits the relevance set chronologically by the arXiv year of the | |
| target paper, writing three sibling JSON files: | |
| relevance_train.json arXiv year ≤ TRAIN_CUTOFF | |
| relevance_val.json TRAIN_CUTOFF < year ≤ VAL_CUTOFF | |
| relevance_test.json year > VAL_CUTOFF | |
| A chronological split is meaningful here: photonic ML / quantum NLP | |
| terminology drifted between e.g. 2010 and 2024, so training on older work and | |
| evaluating on newer work captures generalisation, not memorisation. | |
| Workflow: | |
| # 1. expand the relevance set if you haven't (one-shot, requires net) | |
| python -m photon_route.eval.expand_titles \ | |
| --out eval/relevance_expanded.json | |
| # 2. split chronologically (offline, just JSON math) | |
| python -m photon_route.eval.split_holdout \ | |
| --in eval/relevance_expanded.json \ | |
| --train-cutoff 2018 --val-cutoff 2020 | |
| # 3. retrain on train + early-stop on val | |
| python -m space.train \ | |
| --relevance eval/relevance_train.json \ | |
| --val-relevance eval/relevance_val.json \ | |
| --out weights_holdout.npz | |
| # 4. evaluate ON TEST ONLY — this is the number you report | |
| python -m eval.run \ | |
| --weights weights_holdout.npz \ | |
| --relevance eval/relevance_test.json | |
| Reusable: the same script works for whatever expanded relevance set you | |
| build next — just feed it in via --in. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import re | |
| from pathlib import Path | |
# arXiv identifiers come in two schemes:
#   old (before Apr 2007): archive/YYMMNNN, e.g. math.GT/0512345,
#       hep-th/0501123 -- not parsed here; year_of() returns None for them.
#   new (Apr 2007 onwards): YYMM.NNNN or YYMM.NNNNN, e.g. 0704.1234,
#       1306.5358, 2304.12717 -- the YY prefix gives the year.
# New-style IDs are routinely written with an "arXiv:" prefix and/or a version
# suffix ("arXiv:2304.12717v2"). Both are tolerated so that such IDs are not
# mistaken for undated ones. Group 1 is YY and group 2 is MM.
ARXIV_NEW_RE = re.compile(
    r"^(?:arxiv:\s*)?(\d{2})(\d{2})\.\d{4,5}(?:v\d+)?$",
    re.IGNORECASE,
)


def year_of(arxiv_id: str) -> int | None:
    """Return the four-digit year encoded in a new-style arXiv ID.

    Args:
        arxiv_id: An identifier such as ``"2304.12717"``. Surrounding
            whitespace, an ``arXiv:`` prefix (any case) and a ``vN`` version
            suffix are accepted.

    Returns:
        ``2000 + YY`` for a well-formed new-style ID, or ``None`` when the
        value is not a string, does not match the new-style pattern
        (including old ``archive/YYMMNNN`` IDs), has a month outside 01-12,
        or has a ``YY`` prefix below 07.
    """
    # IDs loaded from JSON may be numbers or null; treat those as undated
    # instead of crashing on .strip().
    if not isinstance(arxiv_id, str):
        return None
    m = ARXIV_NEW_RE.match(arxiv_id.strip())
    if not m:
        return None
    yy, mm = int(m.group(1)), int(m.group(2))
    # A month outside 01-12 cannot be a real YYMM stamp -- treat it as a typo.
    if not 1 <= mm <= 12:
        return None
    # The new scheme started in Apr 2007, so a "00"-"06" prefix is a typo too.
    return 2000 + yy if yy >= 7 else None
def split_query_by_year(q: dict, train_cutoff: int, val_cutoff: int):
    """Pick the split for one query from its most recent relevant paper.

    The bucket is decided by the newest year found among the query's
    ``relevant_ids``. Using the newest (not the oldest) year keeps a query
    whose relevant papers reach past the training cutoff out of the training
    set, so later labels do not leak into training.

    Returns a ``(bucket, q)`` pair where ``bucket`` is ``"train"``, ``"val"``
    or ``"test"`` and ``q`` is the query passed in, unmodified.
    """
    newest = max(
        (yr for yr in map(year_of, q.get("relevant_ids", [])) if yr is not None),
        default=None,
    )
    # Queries with no datable ID at all land in the oldest bucket.
    if newest is None or newest <= train_cutoff:
        return "train", q
    bucket = "val" if newest <= val_cutoff else "test"
    return bucket, q
def main(argv: list[str] | None = None) -> int:
    """Split a relevance file into train / val / test JSON files by year.

    Reads the ``queries`` list from the input JSON, assigns each query to a
    bucket with ``split_query_by_year`` and writes
    ``relevance_{train,val,test}.json`` into the output directory, carrying
    over the source ``schema_version`` and ``description``.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``.

    Returns:
        0 on success, 1 when the input file does not exist.

    Raises:
        SystemExit: Raised by argparse (status 2) for unparsable arguments or
            when ``--train-cutoff`` is greater than ``--val-cutoff``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--in", dest="inp", type=Path,
                    default=Path(__file__).parent / "relevance_expanded.json",
                    help="source relevance file (default: relevance_expanded.json)")
    ap.add_argument("--out-dir", type=Path, default=Path(__file__).parent,
                    help="where to write the three split files")
    ap.add_argument("--train-cutoff", type=int, default=2018,
                    help="queries whose youngest relevant doc is ≤ this year → train")
    ap.add_argument("--val-cutoff", type=int, default=2020,
                    help="train-cutoff < year ≤ this → val; year > this → test")
    args = ap.parse_args(argv)

    # With train-cutoff above val-cutoff the val bucket can never receive a
    # query; fail loudly instead of silently writing an empty val split.
    if args.train_cutoff > args.val_cutoff:
        ap.error(
            f"--train-cutoff ({args.train_cutoff}) must not be greater than "
            f"--val-cutoff ({args.val_cutoff})"
        )

    if not args.inp.exists():
        print(f"missing {args.inp} — run expand_titles.py first?")
        return 1

    src = json.loads(args.inp.read_text("utf-8"))
    queries = src.get("queries", [])

    buckets = {"train": [], "val": [], "test": []}
    for q in queries:
        bucket, q_keep = split_query_by_year(q, args.train_cutoff, args.val_cutoff)
        buckets[bucket].append(q_keep)

    schema_v = src.get("schema_version", 1)
    base_desc = src.get("description", "Chronological holdout split")

    # write_text() does not create missing directories, so make sure the
    # target exists before writing the first split file.
    args.out_dir.mkdir(parents=True, exist_ok=True)

    for name, qs in buckets.items():
        path = args.out_dir / f"relevance_{name}.json"
        path.write_text(json.dumps({
            "schema_version": schema_v,
            "description": f"{base_desc} — {name} split "
                           f"(train≤{args.train_cutoff} < val≤{args.val_cutoff} < test).",
            "queries": qs,
        }, indent=2) + "\n", encoding="utf-8")
        print(f" {name:<6} {len(qs):3d} queries → {path.relative_to(args.out_dir.parent)}")

    print(f"\n[split-holdout] total={len(queries)} "
          f"train={len(buckets['train'])} val={len(buckets['val'])} test={len(buckets['test'])}")
    return 0
if __name__ == "__main__":
    # Run the CLI and hand main()'s return value to the interpreter as the
    # process exit status.
    status = main()
    raise SystemExit(status)