|
|
""" |
|
|
Dataset preparation script - streams ORD dataset, applies transformations, and saves splits. |
|
|
""" |
|
|
import sys |
|
|
import os |
|
|
from pathlib import Path |
|
|
from datasets import load_dataset, DatasetDict, Dataset |
|
|
from huggingface_hub import HfApi, create_repo |
|
|
from config import ( |
|
|
HF_DATASET, CACHE_DIR, SEED, TRAIN_SPLIT, MAX_SAMPLES, |
|
|
FORWARD_DATASET_NAME, RETRO_DATASET_NAME |
|
|
) |
|
|
from utils import forward_example, retro_example, split_dataset_indices, deduplicate_examples |
|
|
from rdkit import RDLogger |
|
|
|
|
|
|
|
|
# Silence RDKit's per-molecule warning spam emitted during SMILES handling.
RDLogger.DisableLog("rdApp.warning")


# Uploads are on by default; set ORD_UPLOAD_DATASETS=0/false/no to opt out.
UPLOAD_DATASETS = os.environ.get("ORD_UPLOAD_DATASETS", "1").lower() not in {"0", "false", "no"}

# Prefer HF_TOKEN, fall back to HF_MODEL_TOKEN; may be None (uploads then skipped).
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HF_MODEL_TOKEN")

# Shared Hub API client (a None token means anonymous access).
HF_API = HfApi(token=HF_TOKEN)


# Ensure the local cache directory exists before any save_to_disk call.
CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
|
def push_dataset(dataset: DatasetDict, repo_id: str, label: str):
    """Upload *dataset* to the Hub repo *repo_id*, honoring env-based opt-outs.

    Best-effort: prints and returns when uploads are disabled or no token is
    configured; any upload failure is caught and reported, never raised.
    """
    # Guard clauses: respect the opt-out flag first, then require credentials.
    if not UPLOAD_DATASETS:
        print(f"Skipping upload of {label} dataset (ORD_UPLOAD_DATASETS disabled).")
        return
    if not HF_TOKEN:
        print(f"⚠️ Cannot upload {label} dataset: HF_TOKEN not set.")
        return

    print(f"Uploading {label} dataset to {repo_id} via push_to_hub...")
    try:
        # exist_ok=True makes repeated runs idempotent.
        create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True, token=HF_TOKEN)
        dataset.push_to_hub(
            repo_id=repo_id,
            token=HF_TOKEN,
            max_shard_size="2GB",
            private=False,
            commit_message=f"Update {label} dataset",
        )
    except Exception as exc:
        # Deliberate best-effort: report and continue so local artifacts survive.
        print(f"⚠️ Could not upload {label} dataset: {exc}")
    else:
        print(f"✅ Uploaded {label} dataset to Hugging Face Hub.")
|
|
|
|
|
|
|
|
def stream_transform(map_fn, max_samples=None):
    """Stream the ORD dataset and yield transformed examples.

    Args:
        map_fn: Callable applied to each raw sample; a falsy return value
            means "skip this sample".
        max_samples: Optional cap on the number of *yielded* examples
            (falsy/None means unlimited).

    Yields:
        The truthy results of ``map_fn``, up to ``max_samples``.
    """
    print(f"Loading dataset from {HF_DATASET}...")
    ds = load_dataset(HF_DATASET, streaming=True)["train"]

    count = 0
    for sample in ds:
        if max_samples and count >= max_samples:
            break

        mapped = map_fn(sample)
        if mapped:
            yield mapped
            count += 1

            # Progress report, tied to the increment so it fires exactly once
            # per 50k *kept* samples. Previously the check ran on every
            # iteration, so a stale count (e.g. "Processed 0 samples...") was
            # re-printed for every filtered-out sample.
            if count % 50000 == 0:
                print(f"Processed {count} samples...")
|
|
|
|
|
|
|
|
def build_dataset(map_fn, name: str, max_samples=None):
    """Build, deduplicate, save, and optionally upload one dataset variant.

    Args:
        map_fn: Per-sample transformation forwarded to ``stream_transform``.
        name: Variant name ("forward" or "retro"); also the on-disk folder
            name under ``CACHE_DIR`` and the key selecting the Hub repo.
        max_samples: Optional cap forwarded to ``stream_transform``.
    """
    print(f"\n{'='*60}")
    print(f"Building {name} dataset...")
    print(f"{'='*60}")

    # Deterministic split assignment: one split label drawn per streamed example.
    split_iter = split_dataset_indices(SEED, TRAIN_SPLIT)
    buckets = {"train": [], "validation": [], "test": []}

    # Fix: dropped the unused `enumerate` index from the original loop.
    for ex in stream_transform(map_fn, max_samples):
        buckets[next(split_iter)].append(ex)

    print(f"Deduplicating {name} splits...")
    for split_name in buckets:
        before = len(buckets[split_name])
        buckets[split_name] = deduplicate_examples(buckets[split_name])
        after = len(buckets[split_name])
        removed = before - after
        print(f" {split_name}: {before} -> {after} (removed {removed} duplicates)")

    # Drop empty splits so Dataset.from_list / push_to_hub never see them.
    datasets = {k: Dataset.from_list(v) for k, v in buckets.items() if v}
    dsd = DatasetDict(datasets)

    save_path = CACHE_DIR / name
    dsd.save_to_disk(str(save_path))

    # Upload to the repo matching this variant (no-op for other names).
    if name == "forward":
        push_dataset(dsd, FORWARD_DATASET_NAME, "forward")
    elif name == "retro":
        push_dataset(dsd, RETRO_DATASET_NAME, "retro")

    print(f"\n{name} dataset statistics:")
    for split_name, ds in dsd.items():
        print(f" {split_name}: {len(ds)} samples")

    print(f"Saved to {save_path}")
|
|
|
|
|
|
|
|
def main():
    """Run the full ORD dataset preparation pipeline (forward, then retro)."""
    banner = "=" * 60
    print("ORD Reaction Dataset Preparation Pipeline")
    print(banner)

    # Build both task variants with the same sample cap.
    for variant_name, example_fn in (("forward", forward_example), ("retro", retro_example)):
        build_dataset(example_fn, variant_name, max_samples=MAX_SAMPLES)

    print(f"\n{banner}")
    print("Dataset preparation complete!")
    print(f"{banner}")
|
|
|
|
|
|
|
|
# Script entry point: run the pipeline only when executed directly.
if __name__ == "__main__":
    main()
|
|
|