# ord-training-simple / src / dataset_prepare.py
"""
Dataset preparation script - streams the ORD dataset, applies transformations, and saves the resulting splits.
"""
import os
from pathlib import Path
from datasets import load_dataset, DatasetDict, Dataset
from huggingface_hub import HfApi, create_repo
from config import (
    HF_DATASET, CACHE_DIR, SEED, TRAIN_SPLIT, MAX_SAMPLES,
    FORWARD_DATASET_NAME, RETRO_DATASET_NAME
)
from utils import forward_example, retro_example, split_dataset_indices, deduplicate_examples
from rdkit import RDLogger
# Silence RDKit warnings (e.g., "not removing hydrogen atom without neighbors")
RDLogger.DisableLog("rdApp.warning")
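# Upload behaviour is controlled by environment variables:
#   ORD_UPLOAD_DATASETS=0/false/no  -> skip pushing splits to the Hub
#   HF_TOKEN (or HF_MODEL_TOKEN)    -> token used for create_repo / push_to_hub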
UPLOAD_DATASETS = os.environ.get("ORD_UPLOAD_DATASETS", "1").lower() not in {"0", "false", "no"}
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HF_MODEL_TOKEN")
HF_API = HfApi(token=HF_TOKEN)
# Create cache directory
CACHE_DIR.mkdir(parents=True, exist_ok=True)


def push_dataset(dataset: DatasetDict, repo_id: str, label: str):
    """Push a prepared DatasetDict to the Hugging Face Hub, if uploads are enabled."""
    if not UPLOAD_DATASETS:
        print(f"Skipping upload of {label} dataset (ORD_UPLOAD_DATASETS disabled).")
        return
    if not HF_TOKEN:
        print(f"⚠️ Cannot upload {label} dataset: HF_TOKEN not set.")
        return
    try:
        print(f"Uploading {label} dataset to {repo_id} via push_to_hub...")
        create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True, token=HF_TOKEN)
        dataset.push_to_hub(
            repo_id=repo_id,
            token=HF_TOKEN,
            max_shard_size="2GB",
            private=False,
            commit_message=f"Update {label} dataset",
        )
        print(f"✅ Uploaded {label} dataset to Hugging Face Hub.")
    except Exception as exc:
        print(f"⚠️ Could not upload {label} dataset: {exc}")


def stream_transform(map_fn, max_samples=None):
    """Stream dataset and apply transformation function."""
    print(f"Loading dataset from {HF_DATASET}...")
    ds = load_dataset(HF_DATASET, streaming=True)["train"]
    count = 0
    for sample in ds:
        if max_samples and count >= max_samples:
            break
        mapped = map_fn(sample)
        if mapped:
            yield mapped
        # Count every streamed record (not only yielded ones) so the progress log
        # and max_samples refer to raw samples processed.
        count += 1
        if count % 50000 == 0:
            print(f"Processed {count} samples...")


def build_dataset(map_fn, name: str, max_samples=None):
    """Build, deduplicate, and save a dataset using the given transformation function."""
    print(f"\n{'='*60}")
    print(f"Building {name} dataset...")
    print(f"{'='*60}")

    # split_dataset_indices yields a split label ("train"/"validation"/"test")
    # for each streamed example, driven by SEED and TRAIN_SPLIT from config.
    split_iter = split_dataset_indices(SEED, TRAIN_SPLIT)
    buckets = {"train": [], "validation": [], "test": []}

    # Stream and collect data
    for ex in stream_transform(map_fn, max_samples):
        split = next(split_iter)
        buckets[split].append(ex)

    # Deduplicate
    print(f"Deduplicating {name} splits...")
    for split_name in buckets:
        before = len(buckets[split_name])
        buckets[split_name] = deduplicate_examples(buckets[split_name])
        after = len(buckets[split_name])
        removed = before - after
        print(f" {split_name}: {before} -> {after} (removed {removed} duplicates)")

    # Create and save dataset (empty splits are dropped)
    splits = {k: Dataset.from_list(v) for k, v in buckets.items() if v}
    dsd = DatasetDict(splits)
    save_path = CACHE_DIR / name
    dsd.save_to_disk(str(save_path))

    # Push to the Hub so later runs can resume from the uploaded copy
    if name == "forward":
        push_dataset(dsd, FORWARD_DATASET_NAME, "forward")
    elif name == "retro":
        push_dataset(dsd, RETRO_DATASET_NAME, "retro")

    print(f"\n{name} dataset statistics:")
    for split_name, ds in dsd.items():
        print(f" {split_name}: {len(ds)} samples")
    print(f"Saved to {save_path}")


def main():
    """Main dataset preparation pipeline."""
    print("ORD Reaction Dataset Preparation Pipeline")
    print("=" * 60)

    # Build forward dataset
    build_dataset(forward_example, "forward", max_samples=MAX_SAMPLES)

    # Build retro dataset
    build_dataset(retro_example, "retro", max_samples=MAX_SAMPLES)

    print(f"\n{'='*60}")
    print("Dataset preparation complete!")
    print(f"{'='*60}")


if __name__ == "__main__":
    main()
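
# Sketch (assumption, not part of this script): a training script could resume by
# loading the pushed splits back from the Hub and falling back to the local cache
# written by save_to_disk above:
#
#     from datasets import load_dataset, load_from_disk
#     try:
#         forward = load_dataset(FORWARD_DATASET_NAME)
#     except Exception:
#         forward = load_from_disk(str(CACHE_DIR / "forward"))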