"""
Dataset preparation script - streams ORD dataset, applies transformations, and saves splits.
"""
import os
import sys
from pathlib import Path
from typing import Any, Callable, Iterator, Optional

from datasets import load_dataset, DatasetDict, Dataset
from huggingface_hub import HfApi, create_repo
from rdkit import RDLogger

from config import (
    HF_DATASET,
    CACHE_DIR,
    SEED,
    TRAIN_SPLIT,
    MAX_SAMPLES,
    FORWARD_DATASET_NAME,
    RETRO_DATASET_NAME,
)
from utils import forward_example, retro_example, split_dataset_indices, deduplicate_examples

# Silence RDKit warnings (e.g., "not removing hydrogen atom without neighbors")
RDLogger.DisableLog("rdApp.warning")

# Uploads are enabled by default; set ORD_UPLOAD_DATASETS to "0", "false" or "no" to disable.
UPLOAD_DATASETS = os.environ.get("ORD_UPLOAD_DATASETS", "1").lower() not in {"0", "false", "no"}
# HF_TOKEN takes precedence; HF_MODEL_TOKEN is accepted as a fallback.
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HF_MODEL_TOKEN")
# NOTE(review): HF_API is not referenced elsewhere in this module; kept in case
# other modules import it.
HF_API = HfApi(token=HF_TOKEN)

# Create cache directory
CACHE_DIR.mkdir(parents=True, exist_ok=True)


def push_dataset(dataset: DatasetDict, repo_id: str, label: str) -> None:
    """Upload ``dataset`` to the Hugging Face Hub under ``repo_id`` (best effort).

    Only prints a message and returns when uploads are disabled via
    ``ORD_UPLOAD_DATASETS`` or when no token is available. Any failure while
    creating the repo or pushing is reported and swallowed, so a failed upload
    never aborts the local preparation pipeline.

    Args:
        dataset: The splits to upload.
        repo_id: Target dataset repository id on the Hub.
        label: Human-readable name used in log messages and the commit message.
    """
    if not UPLOAD_DATASETS:
        print(f"Skipping upload of {label} dataset (ORD_UPLOAD_DATASETS disabled).")
        return
    if not HF_TOKEN:
        # The token may come from either variable, so name both in the message.
        print(f"⚠️ Cannot upload {label} dataset: HF_TOKEN (or HF_MODEL_TOKEN) not set.")
        return
    try:
        print(f"Uploading {label} dataset to {repo_id} via push_to_hub...")
        create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True, token=HF_TOKEN)
        dataset.push_to_hub(
            repo_id=repo_id,
            token=HF_TOKEN,
            max_shard_size="2GB",
            private=False,
            commit_message=f"Update {label} dataset",
        )
    except Exception as exc:  # Deliberate best-effort: report and keep going.
        print(f"⚠️ Could not upload {label} dataset: {exc}")
    else:
        print(f"✅ Uploaded {label} dataset to Hugging Face Hub.")


def stream_transform(
    map_fn: Callable[[dict], Any],
    max_samples: Optional[int] = None,
) -> Iterator[Any]:
    """Stream the ``train`` split of ``HF_DATASET`` and yield transformed samples.

    Samples for which ``map_fn`` returns a falsy value are skipped and do not
    count toward ``max_samples``.

    Args:
        map_fn: Transformation applied to each raw streamed sample.
        max_samples: Maximum number of transformed samples to yield. A falsy
            value (``None`` or ``0``) means no limit.

    Yields:
        The truthy results of ``map_fn``, in stream order.
    """
    print(f"Loading dataset from {HF_DATASET}...")
    ds = load_dataset(HF_DATASET, streaming=True)["train"]

    count = 0
    # Preserve the original semantics for a non-positive (but truthy) limit:
    # nothing is yielded.
    if max_samples and count >= max_samples:
        return

    for sample in ds:
        mapped = map_fn(sample)
        if not mapped:
            continue
        yield mapped
        count += 1
        if count % 50000 == 0:
            print(f"Processed {count} samples...")
        # Stop right away once the limit is reached, instead of pulling one
        # more sample from the remote stream before noticing.
        if max_samples and count >= max_samples:
            break


def build_dataset(
    map_fn: Callable[[dict], Any],
    name: str,
    max_samples: Optional[int] = None,
) -> None:
    """Build, deduplicate, save and (optionally) upload one dataset variant.

    Each streamed example is routed to the split label produced by
    ``split_dataset_indices(SEED, TRAIN_SPLIT)``. Every split is then
    deduplicated with ``deduplicate_examples``. The non-empty splits are saved
    to ``CACHE_DIR / name`` and, for the known names ``"forward"`` and
    ``"retro"``, pushed to their configured Hub repositories.

    Args:
        map_fn: Per-sample transformation passed to ``stream_transform``.
        name: Dataset variant name; also the cache sub-directory name.
        max_samples: Optional cap on transformed samples (see ``stream_transform``).
    """
    print(f"\n{'='*60}")
    print(f"Building {name} dataset...")
    print(f"{'='*60}")

    split_iter = split_dataset_indices(SEED, TRAIN_SPLIT)
    buckets = {"train": [], "validation": [], "test": []}

    # Stream and collect data; each example goes to the next split label.
    for ex in stream_transform(map_fn, max_samples):
        buckets[next(split_iter)].append(ex)

    # Deduplicate.
    # NOTE(review): dedup runs per split, so an example duplicated across
    # splits (e.g. train and test) is not removed -- confirm this leakage is
    # acceptable for evaluation.
    print(f"Deduplicating {name} splits...")
    for split_name, examples in buckets.items():
        before = len(examples)
        buckets[split_name] = deduplicate_examples(examples)
        after = len(buckets[split_name])
        print(f"  {split_name}: {before} -> {after} (removed {before - after} duplicates)")

    # Create and save dataset (empty splits are omitted).
    splits = {k: Dataset.from_list(v) for k, v in buckets.items() if v}
    dsd = DatasetDict(splits)
    save_path = CACHE_DIR / name
    dsd.save_to_disk(str(save_path))

    repo_ids = {"forward": FORWARD_DATASET_NAME, "retro": RETRO_DATASET_NAME}
    repo_id = repo_ids.get(name)
    if repo_id is not None:
        if splits:
            push_dataset(dsd, repo_id, name)
        else:
            # Never push an empty DatasetDict over a previously published dataset.
            print(f"⚠️ No {name} examples were produced; skipping upload.")

    print(f"\n{name} dataset statistics:")
    for split_name, ds in dsd.items():
        print(f"  {split_name}: {len(ds)} samples")
    print(f"Saved to {save_path}")


def main() -> None:
    """Build the forward and retro datasets in sequence."""
    print("ORD Reaction Dataset Preparation Pipeline")
    print("=" * 60)

    # Build forward dataset
    build_dataset(forward_example, "forward", max_samples=MAX_SAMPLES)

    # Build retro dataset
    build_dataset(retro_example, "retro", max_samples=MAX_SAMPLES)

    print(f"\n{'='*60}")
    print("Dataset preparation complete!")
    print(f"{'='*60}")


if __name__ == "__main__":
    main()