| | """
|
| | DatasetLoader: Loads and processes open scientific datasets.
|
| | Supports streaming from HuggingFace datasets with sharding.
|
| | """
|
| |
|
| | import os
|
| | import json
|
| | from typing import List, Dict, Optional, Iterator
|
| | from pathlib import Path
|
| |
|
| | try:
|
| | from datasets import load_dataset, Dataset, IterableDataset
|
| | import pyarrow.parquet as pq
|
| | except ImportError:
|
| | print("Please install datasets and pyarrow: pip install datasets pyarrow")
|
| | raise
|
| |
|
| |
|
class VortexDatasetLoader:
    """
    Loads and processes open scientific datasets.
    Supports streaming with sharding to Parquet files.
    """

    # Registry of supported datasets: HuggingFace Hub path, optional config
    # name ("subset"), split to read, the field that carries the raw text,
    # and a coarse domain label attached to every yielded sample.
    DATASETS = {
        "pile_scientific": {
            "path": "EleutherAI/pile",
            "subset": "pubmed_central",
            "split": "train",
            "text_field": "text",
            "domain": "biology",
        },
        "s2orc": {
            "path": "allenai/s2orc",
            "subset": None,
            "split": "train",
            "text_field": "text",
            "domain": "multidisciplinary",
        },
        "pes2o": {
            "path": "allenai/peS2o",
            "subset": None,
            "split": "train",
            "text_field": "text",
            "domain": "multidisciplinary",
        },
        "automath": {
            "path": "math-ai/AutoMathText",
            "subset": None,
            "split": "train",
            "text_field": "text",
            "domain": "math",
        },
        "deepmind_math": {
            "path": "deepmind/math_dataset",
            "subset": "algebra__linear_1d",
            "split": "train",
            "text_field": "text",
            "domain": "math",
        },
        "pubmed_qa": {
            "path": "bigbio/pubmed_qa",
            "subset": "pubmed_qa_labeled_fold0_source",
            "split": "train",
            "text_field": "question",
            "domain": "biology",
        },
    }

    def __init__(
        self,
        cache_dir: str = "./data/cache",
        output_dir: str = "./data/processed",
        num_proc: int = 4,
    ):
        """
        Initialize dataset loader.

        Args:
            cache_dir: Directory for caching downloaded datasets
            output_dir: Directory for processed shards
            num_proc: Number of processes for data processing
                (stored for callers; no method in this class reads it)
        """
        self.cache_dir = Path(cache_dir)
        self.output_dir = Path(output_dir)
        self.num_proc = num_proc

        # Create both directories up front so later downloads/writes cannot
        # fail on a missing path.
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def load_dataset(
        self,
        dataset_name: str,
        streaming: bool = True,
        max_samples: Optional[int] = None,
    ) -> Iterator[Dict]:
        """
        Load a dataset as an iterator.

        Args:
            dataset_name: Name from DATASETS config
            streaming: Use streaming mode for large datasets
            max_samples: Maximum number of samples to yield
                (None means unlimited; 0 yields nothing)

        Yields:
            Dictionary with text and metadata

        Raises:
            ValueError: If dataset_name is not a key of DATASETS.
        """
        if dataset_name not in self.DATASETS:
            raise ValueError(f"Unknown dataset: {dataset_name}. Available: {list(self.DATASETS.keys())}")

        config = self.DATASETS[dataset_name]

        print(f"Loading dataset: {dataset_name}")
        print(f" Path: {config['path']}")
        print(f" Streaming: {streaming}")

        try:
            dataset = load_dataset(
                config["path"],
                name=config["subset"],
                split=config["split"],
                streaming=streaming,
                cache_dir=str(self.cache_dir),
            )

            count = 0
            for sample in dataset:
                # BUG FIX: the limit was previously a truthiness test applied
                # after yielding, so max_samples=0 meant "no limit". Check
                # explicitly against None before yielding so 0 yields nothing.
                if max_samples is not None and count >= max_samples:
                    break

                text = sample.get(config["text_field"], "")
                if not text or not isinstance(text, str):
                    # Skip empty or non-string payloads rather than failing.
                    continue

                yield {
                    "text": text,
                    "dataset": dataset_name,
                    "domain": config["domain"],
                    "source": config["path"],
                }
                count += 1

            print(f"Loaded {count} samples from {dataset_name}")

        except Exception as e:
            # Best-effort by design: a failed download or a mid-stream error
            # ends this dataset's iteration instead of aborting the caller's
            # whole pipeline. The error is surfaced via print, matching the
            # file's existing reporting style.
            print(f"Error loading dataset {dataset_name}: {e}")

        return

    def load_multiple_datasets(
        self,
        dataset_names: List[str],
        streaming: bool = True,
        max_per_dataset: Optional[int] = None,
    ) -> Iterator[Dict]:
        """
        Load multiple datasets and yield samples interleaved round-robin.

        Args:
            dataset_names: List of dataset names
            streaming: Use streaming mode
            max_per_dataset: Max samples per dataset

        Yields:
            Dictionary with text and metadata
        """
        iterators = [
            self.load_dataset(name, streaming=streaming, max_samples=max_per_dataset)
            for name in dataset_names
        ]

        # BUG FIX: the previous loop broke out of the round-robin scan
        # whenever one iterator was exhausted and restarted from index 0,
        # which skewed interleaving toward earlier datasets. Each pass now
        # visits every live iterator exactly once and drops exhausted ones
        # before the next pass.
        while iterators:
            survivors = []
            for it in iterators:
                try:
                    sample = next(it)
                except StopIteration:
                    continue
                yield sample
                survivors.append(it)
            iterators = survivors

    def shard_to_parquet(
        self,
        samples: Iterator[Dict],
        output_prefix: str,
        samples_per_shard: int = 10000,
    ):
        """
        Write samples to sharded Parquet files.

        Args:
            samples: Iterator of sample dictionaries
            output_prefix: Prefix for output files (e.g., "train")
            samples_per_shard: Number of samples per shard
        """
        shard_index = 0
        buffer = []

        for sample in samples:
            buffer.append(sample)

            if len(buffer) >= samples_per_shard:
                self._write_shard(buffer, output_prefix, shard_index)
                shard_index += 1
                buffer = []

        # Flush the final partial shard, if any.
        if buffer:
            self._write_shard(buffer, output_prefix, shard_index)
            shard_index += 1

        # BUG FIX: previously reported shard_index + 1, which over-counted
        # by one when the sample count was an exact multiple of
        # samples_per_shard, and reported "1 shard" for empty input.
        print(f"Wrote {shard_index} shards to {self.output_dir}")

    def _write_shard(
        self,
        buffer: List[Dict],
        output_prefix: str,
        shard_index: int,
    ):
        """Write a single shard to Parquet as <prefix>_<index:05d>.parquet."""
        # Imported lazily so the loader can be constructed and used for
        # streaming without pandas until a shard is actually written.
        import pandas as pd

        df = pd.DataFrame(buffer)
        output_path = self.output_dir / f"{output_prefix}_{shard_index:05d}.parquet"
        df.to_parquet(output_path, index=False)

    def get_shard_list(
        self,
        prefix: str,
    ) -> List[Path]:
        """Get the sorted list of shard files matching prefix in output_dir."""
        return sorted(self.output_dir.glob(f"{prefix}_*.parquet"))

    def read_shard(
        self,
        shard_path: Path,
    ) -> List[Dict]:
        """Read a single Parquet shard back as a list of row dictionaries."""
        import pandas as pd
        df = pd.read_parquet(shard_path)
        return df.to_dict('records')
|
| |
|
| |
|
def test_dataset_loader():
    """Smoke-test VortexDatasetLoader by streaming a few pubmed_qa samples."""
    loader = VortexDatasetLoader()

    print("Testing dataset loader...")
    n = 0
    # Stream at most 10 samples; the loader handles download errors itself.
    for sample in loader.load_dataset("pubmed_qa", streaming=True, max_samples=10):
        preview = sample["text"][:100]
        print(f"Sample {n}: {preview}...")
        n += 1

    print(f"Loaded {n} samples")
    print("DatasetLoader test passed!")
|
| |
|
| |
|
# Run the smoke test when this module is executed as a script.
if __name__ == "__main__":
    test_dataset_loader()
|
| |
|