#!/usr/bin/env python3
"""Build per-qid record shards for the Open-WikiTable viewer.

Inputs:
  --in-split   path to a *.json from Open-WikiTable (columnar pandas format)
  --in-tables  path to splitted_tables.json (the row-wise chunked corpus)

Outputs:
  --out-dir    directory; writes <out-dir>/index.json and
               <out-dir>/records/<question_id>.json

Each record bundles the qid metadata with the three buckets of candidate
tables (hard_positive / positive / negative), with the referenced chunks
denormalized (header + rows + page_title + section_title + caption + name).

Indexing note: hard_positive_idx / positive_idx / negative_idx hold 1-based
ids into splitted_tables.id (verified against the reference dataloader at
Open_WikiTable/src/dataloader.py:227 -> `index = [i-1 for i in index]`).
"""
from __future__ import annotations

import argparse
import json
import shutil
from pathlib import Path

# Candidate-table buckets; each one is fed by the "<bucket>_idx" column of the
# split file. The order here is the key order of "tables" in every shard.
_BUCKETS = ("hard_positive", "positive", "negative")


def load_columnar(path: Path) -> list[dict]:
    """Read a pandas ``to_json`` columnar dump and return its rows as dicts.

    The file is expected to look like ``{column: {row_id: value, ...}, ...}``.
    Rows come back in the order the row ids appear in the first column, and
    each row dict has its keys in the file's column order.

    Args:
        path: Location of the JSON file (decoded as UTF-8).

    Returns:
        One dict per row; an empty list when the dump has no columns.

    Raises:
        KeyError: If a column lacks one of the first column's row ids.
    """
    with path.open(encoding="utf-8") as f:
        cols = json.load(f)
    if not cols:
        # No columns at all: there is no first column to take row ids from.
        return []
    row_ids = next(iter(cols.values()))
    return [{name: col[rid] for name, col in cols.items()} for rid in row_ids]


def chunk_payload(rec: dict) -> dict:
    """Project a splitted_tables row to the fields the UI renders.

    The source ``id`` is exposed as ``chunk_id``; every other key of ``rec``
    that is not listed below is dropped.
    """
    return {
        "chunk_id": rec["id"],
        "name": rec["name"],
        "page_title": rec["page_title"],
        "section_title": rec["section_title"],
        "caption": rec["caption"],
        "header": rec["header"],
        "rows": rec["rows"],
    }


def _build_record(
    row: dict, by_chunk_id: dict[int, dict], max_neg: int = 0
) -> tuple[dict, int]:
    """Assemble one shard record and count the chunk ids that did not resolve."""
    tables: dict[str, list[dict]] = {}
    missing = 0
    for bucket in _BUCKETS:
        # A missing or null "<bucket>_idx" value is treated as an empty list.
        ids = list(row.get(f"{bucket}_idx") or [])
        if bucket == "negative" and max_neg > 0:
            ids = ids[:max_neg]
        resolved = []
        for cid in ids:
            src = by_chunk_id.get(int(cid))
            if src is None:
                # Unknown id: skip it here; the caller reports the total.
                missing += 1
                continue
            resolved.append(chunk_payload(src))
        tables[bucket] = resolved
    record = {
        "question_id": row["question_id"],
        "dataset": row["dataset"],
        "question": row["question"],
        "sql": row["sql"],
        "answer": row["answer"],
        "original_table_id": row["original_table_id"],
        "tables": tables,
    }
    return record, missing


def _write_json(path: Path, payload: object) -> None:
    """Write ``payload`` to ``path`` as compact JSON encoded in UTF-8."""
    # ensure_ascii=False emits raw non-ASCII characters, so the encoding must
    # be pinned instead of inherited from the platform locale.
    with path.open("w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, separators=(",", ":"))


def main(argv: list[str] | None = None) -> None:
    """Build the index file and the per-question shards.

    Deletes and recreates ``<out-dir>/records``, writes one
    ``<question_id>.json`` per row of the split file, then writes
    ``<out-dir>/index.json`` with one summary entry per question.

    Args:
        argv: Command-line arguments without the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--in-split", required=True, type=Path)
    ap.add_argument("--in-tables", required=True, type=Path)
    ap.add_argument("--out-dir", required=True, type=Path)
    ap.add_argument(
        "--max-neg",
        type=int,
        default=0,
        help=(
            "If >0, truncate negative_idx to at most N chunks per qid "
            "(test ships ~2 each, so default 0 = no cap)."
        ),
    )
    args = ap.parse_args(argv)

    print(f"loading split: {args.in_split}")
    rows = load_columnar(args.in_split)
    print(f" -> {len(rows)} qids")

    print(f"loading tables: {args.in_tables}")
    table_rows = load_columnar(args.in_tables)
    print(f" -> {len(table_rows)} chunks")

    # Index by the 1-based `id` field used by the *_idx lists.
    by_chunk_id: dict[int, dict] = {int(r["id"]): r for r in table_rows}

    out_dir = args.out_dir
    records_dir = out_dir / "records"
    # Start from an empty directory so shards from an earlier run cannot linger.
    if records_dir.exists():
        shutil.rmtree(records_dir)
    records_dir.mkdir(parents=True)

    index_entries = []
    missing = 0
    for r in rows:
        record, unresolved = _build_record(r, by_chunk_id, args.max_neg)
        missing += unresolved
        qid = record["question_id"]
        _write_json(records_dir / f"{qid}.json", record)

        tables = record["tables"]
        index_entries.append(
            {
                "qid": qid,
                "dataset": record["dataset"],
                "question": record["question"],
                "n_hard": len(tables["hard_positive"]),
                "n_pos": len(tables["positive"]),
                "n_neg": len(tables["negative"]),
            }
        )

    index_path = out_dir / "index.json"
    _write_json(index_path, index_entries)

    total_bytes = sum(p.stat().st_size for p in records_dir.iterdir())
    print(
        f"wrote {len(index_entries)} shards under {records_dir} "
        f"({total_bytes/1024/1024:.1f} MB total)"
    )
    print(f"wrote index: {index_path} ({index_path.stat().st_size/1024:.1f} KB)")
    if missing:
        print(f"WARN: {missing} referenced chunk ids were not found in tables file")


if __name__ == "__main__":
    main()