"""Create a notebook-style local validation split. The official example notebook creates validation data by: 1. sampling 90% of train author-paper edges as training edges; 2. using the remaining 10% known positives as validation positives; 3. sampling the same number of random negatives not present in all known refs. This script materializes that split so later experiments use identical data. """ from __future__ import annotations import argparse from pathlib import Path import numpy as np import pandas as pd def read_txt(path: Path) -> list[list[int]]: rows: list[list[int]] = [] with path.open("r") as f: for line in f: rows.append(list(map(int, line.strip().split()))) return rows def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--package-root", type=Path, default=Path(__file__).resolve().parents[1]) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--train-frac", type=float, default=0.9) args = parser.parse_args() root = args.package_root data_dir = root / "data_and_docs" split_dir = root / "splits" / f"notebook_seed{args.seed}" split_dir.mkdir(parents=True, exist_ok=True) refs = read_txt(data_dir / "bipartite_train_ann.txt") coauthor = read_txt(data_dir / "author_file_ann.txt") citation = read_txt(data_dir / "paper_file_ann.txt") test_refs = read_txt(data_dir / "bipartite_test_ann.txt") ref_edges = pd.DataFrame(refs, columns=["source", "target"]) ref_edges = ref_edges.set_index("r-" + ref_edges.index.astype(str)) coauthor_edges = pd.DataFrame(coauthor, columns=["source", "target"]) citation_edges = pd.DataFrame(citation, columns=["source", "target"]) test_arr = np.array(test_refs, dtype=np.int64) node_tmp = pd.concat([citation_edges["source"], citation_edges["target"], ref_edges["target"]]) paper_ids = pd.unique(node_tmp).astype(np.int64) node_tmp = pd.concat([ref_edges["source"], coauthor_edges["source"], coauthor_edges["target"]]) author_ids = pd.unique(node_tmp).astype(np.int64) train_refs = ref_edges.sample(frac=args.train_frac, random_state=args.seed, axis=0) val_pos = ref_edges[~ref_edges.index.isin(train_refs.index)].copy() val_pos.loc[:, "label"] = 1 existing_ref_set = set(map(tuple, ref_edges[["source", "target"]].to_numpy().tolist())) neg_pairs: list[tuple[int, int]] = [] rng = np.random.default_rng(args.seed) while len(neg_pairs) < len(val_pos): src = int(rng.choice(author_ids)) dst = int(rng.choice(paper_ids)) if (src, dst) not in existing_ref_set: neg_pairs.append((src, dst)) val_neg = pd.DataFrame(neg_pairs, columns=["source", "target"]) val_neg.loc[:, "label"] = 0 val_pairs = pd.concat([val_pos.reset_index(drop=True), val_neg], ignore_index=True) val_pairs = val_pairs.sample(frac=1, random_state=args.seed, axis=0).reset_index(drop=True) train_refs[["source", "target"]].to_csv(split_dir / "train_refs.csv", index=False) val_pairs[["source", "target", "label"]].to_csv(split_dir / "val_pairs.csv", index=False) np.save(split_dir / "test_refs.npy", test_arr) print(f"wrote {split_dir}") print(f"train positives: {len(train_refs)}") print(f"val positives: {int(val_pairs['label'].sum())}") print(f"val negatives: {int((val_pairs['label'] == 0).sum())}") print(f"val total: {len(val_pairs)}") if __name__ == "__main__": main()