"""
DatasetLoader: Loads and processes open scientific datasets.
Supports streaming from HuggingFace datasets with sharding.
"""

import os
import json
from typing import List, Dict, Optional, Iterator
from pathlib import Path

try:
    from datasets import load_dataset, Dataset, IterableDataset
    import pyarrow.parquet as pq
except ImportError:
    print("Please install datasets and pyarrow: pip install datasets pyarrow")
    raise


class VortexDatasetLoader:
    """
    Loads and processes open scientific datasets.

    Supports streaming with sharding to Parquet files.
    """

    # Open datasets configuration. Each entry maps a short name to:
    #   path / subset / split -> arguments for ``datasets.load_dataset``
    #   text_field            -> column read as the sample text
    #   domain                -> coarse label copied onto every yielded sample
    # NOTE(review): ``load_dataset`` below silently skips samples whose
    # ``text_field`` is missing or not a non-empty string, so a wrong field
    # name yields zero samples instead of an error. Confirm each field name
    # against the dataset's actual schema.
    DATASETS = {
        "pile_scientific": {
            "path": "EleutherAI/pile",
            "subset": "pubmed_central",
            "split": "train",
            "text_field": "text",
            "domain": "biology",  # approximate
        },
        "s2orc": {
            "path": "allenai/s2orc",
            "subset": None,
            "split": "train",
            "text_field": "text",
            "domain": "multidisciplinary",
        },
        "pes2o": {
            "path": "allenai/peS2o",
            "subset": None,
            "split": "train",
            "text_field": "text",
            "domain": "multidisciplinary",
        },
        "automath": {
            "path": "math-ai/AutoMathText",
            "subset": None,
            "split": "train",
            "text_field": "text",
            "domain": "math",
        },
        "deepmind_math": {
            "path": "deepmind/math_dataset",
            "subset": "algebra__linear_1d",
            "split": "train",
            "text_field": "text",
            "domain": "math",
        },
        "pubmed_qa": {
            "path": "bigbio/pubmed_qa",
            "subset": "pubmed_qa_labeled_fold0_source",
            "split": "train",
            "text_field": "question",
            "domain": "biology",
        },
    }

    def __init__(
        self,
        cache_dir: str = "./data/cache",
        output_dir: str = "./data/processed",
        num_proc: int = 4,
    ):
        """
        Initialize dataset loader.

        Both directories are created (including parents) if they do not
        already exist.

        Args:
            cache_dir: Directory for caching downloaded datasets
            output_dir: Directory for processed shards
            num_proc: Number of processes for data processing
        """
        self.cache_dir = Path(cache_dir)
        self.output_dir = Path(output_dir)
        # NOTE(review): stored for callers, but no method of this class
        # currently reads it.
        self.num_proc = num_proc

        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def load_dataset(
        self,
        dataset_name: str,
        streaming: bool = True,
        max_samples: Optional[int] = None,
    ) -> Iterator[Dict]:
        """
        Load a dataset as an iterator.

        This is a generator. Nothing is printed or downloaded, and the
        unknown-name check does not run, until the first sample is requested.

        Args:
            dataset_name: Name from DATASETS config
            streaming: Use streaming mode for large datasets
            max_samples: Maximum number of samples to yield. ``None`` means
                no limit. ``0`` or a negative value yields nothing.

        Yields:
            Dictionary with keys ``text``, ``dataset``, ``domain`` and
            ``source``.

        Raises:
            ValueError: If ``dataset_name`` is not a key of ``DATASETS``
                (raised on first iteration).
        """
        if dataset_name not in self.DATASETS:
            raise ValueError(f"Unknown dataset: {dataset_name}. Available: {list(self.DATASETS.keys())}")

        config = self.DATASETS[dataset_name]

        print(f"Loading dataset: {dataset_name}")
        print(f" Path: {config['path']}")
        print(f" Streaming: {streaming}")

        # An explicit non-positive limit means "no samples", so skip the
        # (possibly expensive) download entirely.
        if max_samples is not None and max_samples <= 0:
            print(f"Loaded 0 samples from {dataset_name}")
            return

        try:
            # This calls the module-level ``datasets.load_dataset``. The
            # method's own name does not shadow it, because method names are
            # class attributes, not names in this function's scope.
            dataset = load_dataset(
                config["path"],
                name=config["subset"],
                split=config["split"],
                streaming=streaming,
                cache_dir=str(self.cache_dir),
            )

            count = 0
            for sample in dataset:
                text = sample.get(config["text_field"], "")
                # Skip missing, empty, or non-string text values.
                if not text or not isinstance(text, str):
                    continue

                yield {
                    "text": text,
                    "dataset": dataset_name,
                    "domain": config["domain"],
                    "source": config["path"],
                }
                count += 1

                # Stop right after the last requested sample, so a streaming
                # source is not asked for an extra record.
                if max_samples is not None and count >= max_samples:
                    break

            print(f"Loaded {count} samples from {dataset_name}")

        except Exception as e:
            # Best-effort: a dataset that fails to load (or fails mid-stream)
            # is reported and then simply ends. Callers that interleave
            # several datasets keep going with the rest.
            print(f"Error loading dataset {dataset_name}: {e}")
            # Return empty iterator
            return

    def load_multiple_datasets(
        self,
        dataset_names: List[str],
        streaming: bool = True,
        max_per_dataset: Optional[int] = None,
    ) -> Iterator[Dict]:
        """
        Load multiple datasets and yield samples interleaved.

        Samples are taken round-robin, one from each dataset in the order of
        ``dataset_names``. When a dataset is exhausted it drops out, and the
        rotation continues over the remaining datasets in the same order.

        Args:
            dataset_names: List of dataset names
            streaming: Use streaming mode
            max_per_dataset: Max samples per dataset

        Yields:
            Dictionary with text and metadata (see ``load_dataset``)
        """
        active = [
            self.load_dataset(name, streaming=streaming, max_samples=max_per_dataset)
            for name in dataset_names
        ]

        exhausted = object()  # sentinel: distinguishes "no more samples"
        while active:
            still_active = []
            for it in active:
                sample = next(it, exhausted)
                if sample is exhausted:
                    continue
                still_active.append(it)
                yield sample
            active = still_active

    def shard_to_parquet(
        self,
        samples: Iterator[Dict],
        output_prefix: str,
        samples_per_shard: int = 10000,
    ) -> int:
        """
        Write samples to sharded Parquet files.

        Shards are named ``{output_prefix}_{index:05d}.parquet`` inside
        ``output_dir``, and indices start at 0. Existing files with the same
        names are overwritten. Shards with higher indices left over from an
        earlier, larger run are not removed.

        Args:
            samples: Iterator of sample dictionaries
            output_prefix: Prefix for output files (e.g., "train")
            samples_per_shard: Number of samples per shard (must be >= 1)

        Returns:
            The number of shard files written (0 for empty input).

        Raises:
            ValueError: If ``samples_per_shard`` is less than 1.
        """
        if samples_per_shard < 1:
            raise ValueError(f"samples_per_shard must be >= 1, got {samples_per_shard}")

        num_shards = 0
        buffer = []

        for sample in samples:
            buffer.append(sample)

            if len(buffer) >= samples_per_shard:
                self._write_shard(buffer, output_prefix, num_shards)
                num_shards += 1
                buffer = []

        # Write remaining
        if buffer:
            self._write_shard(buffer, output_prefix, num_shards)
            num_shards += 1

        print(f"Wrote {num_shards} shards to {self.output_dir}")
        return num_shards

    def _write_shard(
        self,
        buffer: List[Dict],
        output_prefix: str,
        shard_index: int,
    ):
        """Write one buffer of samples to ``{output_prefix}_{shard_index:05d}.parquet``."""
        import pandas as pd

        df = pd.DataFrame(buffer)
        output_path = self.output_dir / f"{output_prefix}_{shard_index:05d}.parquet"
        df.to_parquet(output_path, index=False)

    def get_shard_list(
        self,
        prefix: str,
    ) -> List[Path]:
        """
        Get the shard files written for ``prefix``, in shard-index order.

        Only files named ``{prefix}_{digits}.parquet`` are returned. This is
        the pattern ``_write_shard`` produces, so a prefix such as ``"train"``
        does not also pick up ``"train_eval_00000.parquet"``. Shards are
        ordered by their numeric index. That order stays correct past 99999
        shards, where zero-padded names no longer sort lexicographically.
        """
        stem_prefix = f"{prefix}_"
        indexed = []
        for path in self.output_dir.glob("*.parquet"):
            if not path.stem.startswith(stem_prefix):
                continue
            index_part = path.stem[len(stem_prefix):]
            if index_part.isdecimal():
                indexed.append((int(index_part), path))
        return [path for _, path in sorted(indexed)]

    def read_shard(
        self,
        shard_path: Path,
    ) -> List[Dict]:
        """Read a single shard and return its rows as a list of dicts."""
        import pandas as pd

        df = pd.read_parquet(shard_path)
        return df.to_dict('records')


def test_dataset_loader():
    """Smoke-test the dataset loader against a small remote dataset."""
    loader = VortexDatasetLoader()

    # Test loading a small dataset
    print("Testing dataset loader...")
    count = 0
    for sample in loader.load_dataset("pubmed_qa", streaming=True, max_samples=10):
        print(f"Sample {count}: {sample['text'][:100]}...")
        count += 1

    print(f"Loaded {count} samples")
    # load_dataset swallows load errors and yields nothing, so an empty
    # result means the load failed; do not report success in that case.
    assert count > 0, "DatasetLoader test failed: no samples were loaded"
    print("DatasetLoader test passed!")


if __name__ == "__main__":
    test_dataset_loader()