| |
| """Build per-qid record shards for the Open-WikiTable viewer. |
| |
| Inputs: |
| --in-split path to a *.json from Open-WikiTable (columnar pandas format) |
| --in-tables path to splitted_tables.json (the row-wise chunked corpus) |
| |
| Outputs: |
| --out-dir directory; writes <out-dir>/index.json and <out-dir>/records/<qid>.json |
| |
| Each record bundles the qid metadata with the three buckets of candidate tables |
| (hard_positive / positive / negative), with the referenced chunks denormalized |
| (header + rows + page_title + section_title + caption + name). |
| |
| Indexing note: hard_positive_idx / positive_idx / negative_idx hold 1-based ids |
| into splitted_tables.id (verified against the reference dataloader at |
| Open_WikiTable/src/dataloader.py:227 -> `index = [i-1 for i in index]`). |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import shutil |
| from pathlib import Path |
|
|
|
|
def load_columnar(path: Path) -> list[dict]:
    """Read a pandas ``to_json`` columnar dump and return row dicts in original order.

    The file is expected to map ``column -> {row_id -> value}``. Row order
    follows the key order of the first column's mapping, which ``json.load``
    preserves from the file.

    Args:
        path: Path to the JSON file.

    Returns:
        One dict per row, keyed by column name. An empty JSON object yields
        an empty list.

    Raises:
        KeyError: if some column lacks a row id present in the first column.
    """
    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with path.open(encoding="utf-8") as f:
        cols = json.load(f)
    if not cols:
        return []
    keys = list(cols)
    row_ids = list(cols[keys[0]])
    return [{k: cols[k][rid] for k in keys} for rid in row_ids]
|
|
|
|
def chunk_payload(rec: dict) -> dict:
    """Return the subset of a splitted_tables row that the UI renders.

    The source ``id`` is exposed as ``chunk_id``; the remaining fields are
    copied verbatim under their own names. A missing field raises KeyError.
    """
    projected = {"chunk_id": rec["id"]}
    # Same-named fields, in the order the UI record is serialized.
    for field in ("name", "page_title", "section_title", "caption", "header", "rows"):
        projected[field] = rec[field]
    return projected
|
|
|
|
def _write_compact_json(path: Path, obj) -> None:
    """Serialize *obj* to *path* as compact JSON with non-ASCII kept verbatim."""
    # ensure_ascii=False emits raw non-ASCII characters, so the file encoding
    # must be pinned; the platform default (e.g. cp1252) may not encode them.
    with path.open("w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, separators=(",", ":"))


def _resolve_buckets(
    r: dict, by_chunk_id: dict[int, dict], max_neg: int
) -> tuple[dict[str, list[dict]], int]:
    """Denormalize one split row's candidate ids into chunk payloads.

    Returns the per-bucket payload lists (hard_positive / positive / negative)
    and the number of referenced ids absent from *by_chunk_id* (those are
    skipped). If *max_neg* > 0 the negative ids are truncated to that many.
    """
    idx_buckets = {
        "hard_positive": list(r.get("hard_positive_idx") or []),
        "positive": list(r.get("positive_idx") or []),
        "negative": list(r.get("negative_idx") or []),
    }
    if max_neg > 0:
        idx_buckets["negative"] = idx_buckets["negative"][:max_neg]

    missing = 0
    tables: dict[str, list[dict]] = {bucket: [] for bucket in idx_buckets}
    for bucket, ids in idx_buckets.items():
        for cid in ids:
            src = by_chunk_id.get(int(cid))
            if src is None:
                missing += 1
                continue
            tables[bucket].append(chunk_payload(src))
    return tables, missing


def main():
    """CLI entry point: build <out-dir>/records/<qid>.json shards and <out-dir>/index.json.

    The records directory is deleted and recreated on every run. Referenced
    chunk ids that are not present in the tables file are skipped and
    counted; a warning is printed at the end. Repeated question ids are also
    reported, because their shard files overwrite one another.
    """
    ap = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    ap.add_argument("--in-split", required=True, type=Path,
                    help="Open-WikiTable split JSON (columnar pandas format).")
    ap.add_argument("--in-tables", required=True, type=Path,
                    help="splitted_tables.json (row-wise chunked corpus).")
    ap.add_argument("--out-dir", required=True, type=Path,
                    help="Output directory for index.json and records/.")
    ap.add_argument(
        "--max-neg",
        type=int,
        default=0,
        help="If >0, truncate negative_idx to at most N chunks per qid (test ships ~2 each, so default 0 = no cap).",
    )
    args = ap.parse_args()

    print(f"loading split: {args.in_split}")
    rows = load_columnar(args.in_split)
    print(f" -> {len(rows)} qids")

    print(f"loading tables: {args.in_tables}")
    table_rows = load_columnar(args.in_tables)
    print(f" -> {len(table_rows)} chunks")

    by_chunk_id: dict[int, dict] = {int(r["id"]): r for r in table_rows}

    out_dir = args.out_dir
    records_dir = out_dir / "records"
    # Start from a clean records/ so shards from a previous run can't linger.
    if records_dir.exists():
        shutil.rmtree(records_dir)
    records_dir.mkdir(parents=True)

    index_entries = []
    missing = 0
    seen_qids: set[str] = set()
    duplicates = 0
    for r in rows:
        qid = r["question_id"]
        # Shard filenames are derived from str(qid), so compare on that form.
        if str(qid) in seen_qids:
            duplicates += 1
        seen_qids.add(str(qid))

        tables, n_missing = _resolve_buckets(r, by_chunk_id, args.max_neg)
        missing += n_missing

        record = {
            "question_id": qid,
            "dataset": r["dataset"],
            "question": r["question"],
            "sql": r["sql"],
            "answer": r["answer"],
            "original_table_id": r["original_table_id"],
            "tables": tables,
        }
        _write_compact_json(records_dir / f"{qid}.json", record)

        index_entries.append(
            {
                "qid": qid,
                "dataset": r["dataset"],
                "question": r["question"],
                "n_hard": len(tables["hard_positive"]),
                "n_pos": len(tables["positive"]),
                "n_neg": len(tables["negative"]),
            }
        )

    index_path = out_dir / "index.json"
    _write_compact_json(index_path, index_entries)

    total_bytes = sum(p.stat().st_size for p in records_dir.iterdir())
    print(
        f"wrote {len(index_entries)} shards under {records_dir} "
        f"({total_bytes/1024/1024:.1f} MB total)"
    )
    print(f"wrote index: {index_path} ({index_path.stat().st_size/1024:.1f} KB)")
    if missing:
        print(f"WARN: {missing} referenced chunk ids were not found in tables file")
    if duplicates:
        print(f"WARN: {duplicates} duplicate question_id(s); their shard files were overwritten")


if __name__ == "__main__":
    main()
|
|