open-wikitable-viewer / scripts /build_dataset.py
timchen0618's picture
Initial: Open-WikiTable test split viewer (6,602 qids)
2da3a94 verified
Raw
History Blame Contribute Delete
4.78 kB
#!/usr/bin/env python3
"""Build per-qid record shards for the Open-WikiTable viewer.

Inputs:
    --in-split   path to a *.json from Open-WikiTable (columnar pandas format)
    --in-tables  path to splitted_tables.json (the row-wise chunked corpus)

Outputs:
    --out-dir    directory; writes <out-dir>/index.json and
                 <out-dir>/records/<qid>.json

Each record bundles the qid metadata with the three buckets of candidate tables
(hard_positive / positive / negative), with the referenced chunks denormalized
(header + rows + page_title + section_title + caption + name).

Indexing note: hard_positive_idx / positive_idx / negative_idx hold 1-based ids
into splitted_tables.id (verified against the reference dataloader at
Open_WikiTable/src/dataloader.py:227 -> `index = [i-1 for i in index]`).
"""
from __future__ import annotations
import argparse
import json
import shutil
from pathlib import Path
def load_columnar(path: Path) -> list[dict]:
    """Read a pandas ``to_json`` columnar dump and return its rows as dicts.

    The file is expected to hold a mapping of ``{column: {row_id: value}}``.
    Row ids and their order are taken from the first column; every other
    column is looked up with those same ids.

    Args:
        path: Location of the JSON file to read.

    Returns:
        One dict per row id (keyed by column name), in the order the row ids
        appear in the first column. An empty list when the dump has no
        columns at all.

    Raises:
        KeyError: If a column lacks a row id that the first column has.
    """
    # Decode as UTF-8 explicitly rather than relying on the locale default,
    # which is not UTF-8 on every platform.
    with path.open(encoding="utf-8") as f:
        cols = json.load(f)

    # An empty dump has no first column to take row ids from.
    if not cols:
        return []

    row_ids = list(next(iter(cols.values())))
    return [{name: column[rid] for name, column in cols.items()} for rid in row_ids]
def chunk_payload(rec: dict) -> dict:
    """Reduce a splitted_tables row to the subset of fields the UI renders.

    The row's ``id`` is exposed as ``chunk_id``; the remaining fields keep
    their original names. Any other keys on ``rec`` are dropped, and a
    missing field raises ``KeyError``.
    """
    copied_fields = (
        "name",
        "page_title",
        "section_title",
        "caption",
        "header",
        "rows",
    )
    payload = {"chunk_id": rec["id"]}
    payload.update((field, rec[field]) for field in copied_fields)
    return payload
def main():
    """Command-line entry point: build the per-qid shards and the index.

    Steps:
      1. Parse ``--in-split``, ``--in-tables``, ``--out-dir`` and ``--max-neg``.
      2. Load both columnar JSON inputs and index the table chunks by their
         ``id`` field.
      3. Delete any existing ``<out-dir>/records`` directory and recreate it.
      4. For every split row, resolve the hard_positive / positive / negative
         chunk ids to chunk payloads and write ``records/<qid>.json``.
      5. Write ``<out-dir>/index.json`` with one summary entry per qid and
         print size statistics.

    Chunk ids that are not present in the tables file are skipped and only
    counted; a single warning with the total is printed at the end.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--in-split", required=True, type=Path)
    ap.add_argument("--in-tables", required=True, type=Path)
    ap.add_argument("--out-dir", required=True, type=Path)
    ap.add_argument(
        "--max-neg",
        type=int,
        default=0,
        help="If >0, truncate negative_idx to at most N chunks per qid (test ships ~2 each, so default 0 = no cap).",
    )
    args = ap.parse_args()

    print(f"loading split: {args.in_split}")
    rows = load_columnar(args.in_split)
    print(f" -> {len(rows)} qids")

    print(f"loading tables: {args.in_tables}")
    table_rows = load_columnar(args.in_tables)
    print(f" -> {len(table_rows)} chunks")

    # Index by the 1-based `id` field used by the *_idx lists.
    by_chunk_id: dict[int, dict] = {int(r["id"]): r for r in table_rows}

    out_dir = args.out_dir
    records_dir = out_dir / "records"
    # Start from an empty records directory so shards from a previous run
    # cannot linger next to the fresh ones.
    if records_dir.exists():
        shutil.rmtree(records_dir)
    records_dir.mkdir(parents=True)

    index_entries = []
    missing = 0
    for r in rows:
        qid = r["question_id"]
        # A bucket that is absent or null in the split is treated as empty.
        idx_buckets = {
            "hard_positive": list(r.get("hard_positive_idx") or []),
            "positive": list(r.get("positive_idx") or []),
            "negative": list(r.get("negative_idx") or []),
        }
        if args.max_neg > 0:
            idx_buckets["negative"] = idx_buckets["negative"][: args.max_neg]

        tables = {bucket: [] for bucket in idx_buckets}
        for bucket, ids in idx_buckets.items():
            for cid in ids:
                src = by_chunk_id.get(int(cid))
                if src is None:
                    # Unknown chunk id: count it and keep going.
                    missing += 1
                    continue
                tables[bucket].append(chunk_payload(src))

        record = {
            "question_id": qid,
            "dataset": r["dataset"],
            "question": r["question"],
            "sql": r["sql"],
            "answer": r["answer"],
            "original_table_id": r["original_table_id"],
            "tables": tables,
        }
        shard_path = records_dir / f"{qid}.json"
        # ensure_ascii=False emits raw non-ASCII characters, so the file must
        # be opened as UTF-8 explicitly instead of with the locale default.
        with shard_path.open("w", encoding="utf-8") as f:
            json.dump(record, f, ensure_ascii=False, separators=(",", ":"))

        index_entries.append(
            {
                "qid": qid,
                "dataset": r["dataset"],
                "question": r["question"],
                # Counts reflect the chunks actually resolved, not the raw ids.
                "n_hard": len(tables["hard_positive"]),
                "n_pos": len(tables["positive"]),
                "n_neg": len(tables["negative"]),
            }
        )

    index_path = out_dir / "index.json"
    with index_path.open("w", encoding="utf-8") as f:
        json.dump(index_entries, f, ensure_ascii=False, separators=(",", ":"))

    total_bytes = sum(p.stat().st_size for p in records_dir.iterdir())
    print(
        f"wrote {len(index_entries)} shards under {records_dir} "
        f"({total_bytes/1024/1024:.1f} MB total)"
    )
    print(f"wrote index: {index_path} ({index_path.stat().st_size/1024:.1f} KB)")
    if missing:
        print(f"WARN: {missing} referenced chunk ids were not found in tables file")
# Run the builder only when executed as a script, not when imported.
if __name__ == "__main__":
    main()