Vortex-13b-V1 / data /dataset_loader.py
Zandy-Wandy's picture
Upload Vortex model
5c43f61 verified
"""
DatasetLoader: Loads and processes open scientific datasets.
Supports streaming from HuggingFace datasets with sharding.
"""
import os
import json
from typing import List, Dict, Optional, Iterator
from pathlib import Path
try:
from datasets import load_dataset, Dataset, IterableDataset
import pyarrow.parquet as pq
except ImportError:
print("Please install datasets and pyarrow: pip install datasets pyarrow")
raise
class VortexDatasetLoader:
    """
    Loads and processes open scientific datasets.

    Samples are read through ``datasets.load_dataset`` (optionally in
    streaming mode) and can be written to numbered Parquet shards named
    ``<prefix>_<index:05d>.parquet`` inside ``output_dir``.
    """

    # Open datasets configuration.
    # Each entry gives the arguments passed to ``datasets.load_dataset``
    # ("path", "subset" -> ``name``, "split"), the sample key holding the
    # text ("text_field"), and a coarse "domain" label that is copied onto
    # every yielded record.
    DATASETS = {
        "pile_scientific": {
            "path": "EleutherAI/pile",
            "subset": "pubmed_central",
            "split": "train",
            "text_field": "text",
            "domain": "biology",  # approximate
        },
        "s2orc": {
            "path": "allenai/s2orc",
            "subset": None,
            "split": "train",
            "text_field": "text",
            "domain": "multidisciplinary",
        },
        "pes2o": {
            "path": "allenai/peS2o",
            "subset": None,
            "split": "train",
            "text_field": "text",
            "domain": "multidisciplinary",
        },
        "automath": {
            "path": "math-ai/AutoMathText",
            "subset": None,
            "split": "train",
            "text_field": "text",
            "domain": "math",
        },
        "deepmind_math": {
            "path": "deepmind/math_dataset",
            "subset": "algebra__linear_1d",
            "split": "train",
            # NOTE(review): if this dataset exposes no "text" column (it may
            # use "question"/"answer"), every sample is skipped — confirm
            # against the dataset card.
            "text_field": "text",
            "domain": "math",
        },
        "pubmed_qa": {
            "path": "bigbio/pubmed_qa",
            "subset": "pubmed_qa_labeled_fold0_source",
            "split": "train",
            "text_field": "question",
            "domain": "biology",
        },
    }

    def __init__(
        self,
        cache_dir: str = "./data/cache",
        output_dir: str = "./data/processed",
        num_proc: int = 4,
    ):
        """
        Initialize dataset loader and create its directories.

        Args:
            cache_dir: Directory for caching downloaded datasets
            output_dir: Directory for processed shards
            num_proc: Number of processes for data processing (stored only;
                not used by the methods of this class)
        """
        self.cache_dir = Path(cache_dir)
        self.output_dir = Path(output_dir)
        self.num_proc = num_proc
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def load_dataset(
        self,
        dataset_name: str,
        streaming: bool = True,
        max_samples: Optional[int] = None,
    ) -> Iterator[Dict]:
        """
        Load a dataset as an iterator.

        The name is validated immediately, so an unknown dataset raises at
        call time rather than on the first ``next()`` of the iterator.

        Args:
            dataset_name: Name from DATASETS config
            streaming: Use streaming mode for large datasets
            max_samples: Maximum number of samples to yield. ``None`` means
                no limit; ``0`` or a negative value yields nothing.

        Returns:
            Iterator of dicts with keys ``text``, ``dataset``, ``domain`` and
            ``source``. Load/iteration errors are printed and end the
            iterator early (best-effort) instead of propagating.

        Raises:
            ValueError: If ``dataset_name`` is not a key of DATASETS.
        """
        if dataset_name not in self.DATASETS:
            raise ValueError(f"Unknown dataset: {dataset_name}. Available: {list(self.DATASETS.keys())}")
        return self._iter_dataset(
            dataset_name, self.DATASETS[dataset_name], streaming, max_samples
        )

    def _iter_dataset(
        self,
        dataset_name: str,
        config: Dict,
        streaming: bool,
        max_samples: Optional[int],
    ) -> Iterator[Dict]:
        """Generator behind ``load_dataset``; ``dataset_name`` is already validated."""
        print(f"Loading dataset: {dataset_name}")
        print(f" Path: {config['path']}")
        print(f" Streaming: {streaming}")
        if max_samples is not None and max_samples <= 0:
            # Nothing requested: skip the (possibly networked) load entirely.
            print(f"Loaded 0 samples from {dataset_name}")
            return
        text_field = config["text_field"]
        count = 0
        try:
            # Module-level ``datasets.load_dataset``, not this class's method.
            dataset = load_dataset(
                config["path"],
                name=config["subset"],
                split=config["split"],
                streaming=streaming,
                cache_dir=str(self.cache_dir),
            )
            for sample in dataset:
                text = sample.get(text_field, "")
                # Skip missing, empty, or non-string text values.
                if not text or not isinstance(text, str):
                    continue
                yield {
                    "text": text,
                    "dataset": dataset_name,
                    "domain": config["domain"],
                    "source": config["path"],
                }
                count += 1
                # Checked right after counting so no extra record is pulled
                # from the stream once the limit is reached.
                if max_samples is not None and count >= max_samples:
                    break
            print(f"Loaded {count} samples from {dataset_name}")
        except Exception as e:
            # Deliberate best-effort: report and end this iterator instead of
            # aborting the caller (e.g. one bad source in a multi-load).
            print(f"Error loading dataset {dataset_name}: {e}")

    def load_multiple_datasets(
        self,
        dataset_names: List[str],
        streaming: bool = True,
        max_per_dataset: Optional[int] = None,
    ) -> Iterator[Dict]:
        """
        Load multiple datasets and yield samples interleaved round-robin.

        Each pass takes one sample from every dataset that still has data,
        in the order of ``dataset_names``; exhausted datasets drop out.

        Args:
            dataset_names: List of dataset names
            streaming: Use streaming mode
            max_per_dataset: Max samples per dataset (``None`` = no limit)

        Yields:
            Dictionary with text and metadata

        Raises:
            ValueError: On the first ``next()`` if any name is unknown.
        """
        iterators = [
            self.load_dataset(name, streaming=streaming, max_samples=max_per_dataset)
            for name in dataset_names
        ]
        while iterators:
            still_active = []
            for it in iterators:
                try:
                    sample = next(it)
                except StopIteration:
                    continue
                still_active.append(it)
                yield sample
            iterators = still_active

    def shard_to_parquet(
        self,
        samples: Iterator[Dict],
        output_prefix: str,
        samples_per_shard: int = 10000,
    ) -> int:
        """
        Write samples to sharded Parquet files.

        Shards are numbered from 0; every shard except possibly the last
        holds exactly ``samples_per_shard`` samples.

        Args:
            samples: Iterator of sample dictionaries
            output_prefix: Prefix for output files (e.g., "train")
            samples_per_shard: Number of samples per shard (must be > 0)

        Returns:
            Number of shard files written (0 for empty input).

        Raises:
            ValueError: If ``samples_per_shard`` is not positive.
        """
        if samples_per_shard <= 0:
            raise ValueError(f"samples_per_shard must be positive, got {samples_per_shard}")
        num_shards = 0
        buffer = []
        for sample in samples:
            buffer.append(sample)
            if len(buffer) >= samples_per_shard:
                self._write_shard(buffer, output_prefix, num_shards)
                num_shards += 1
                buffer = []
        # Write remaining partial shard, if any.
        if buffer:
            self._write_shard(buffer, output_prefix, num_shards)
            num_shards += 1
        print(f"Wrote {num_shards} shards to {self.output_dir}")
        return num_shards

    def _write_shard(
        self,
        buffer: List[Dict],
        output_prefix: str,
        shard_index: int,
    ):
        """Write one shard as ``<prefix>_<index:05d>.parquet`` via pandas (no index column)."""
        import pandas as pd
        df = pd.DataFrame(buffer)
        output_path = self.output_dir / f"{output_prefix}_{shard_index:05d}.parquet"
        df.to_parquet(output_path, index=False)

    def get_shard_list(
        self,
        prefix: str,
    ) -> List[Path]:
        """
        Get the shard files written for ``prefix``, ordered by shard index.

        Only files named exactly ``<prefix>_<digits>.parquet`` match, so
        shards of another prefix that merely starts with ``<prefix>_``
        (e.g. ``train_val`` vs ``train``) are excluded. Sorting is numeric,
        so shard 100000 correctly follows shard 10001.

        Args:
            prefix: Prefix that was passed to ``shard_to_parquet``.

        Returns:
            Sorted list of shard paths; empty if ``output_dir`` is missing.
        """
        head = f"{prefix}_"
        tail = ".parquet"
        if not self.output_dir.is_dir():
            return []
        indexed = []
        for path in self.output_dir.iterdir():
            name = path.name
            if not (name.startswith(head) and name.endswith(tail)):
                continue
            digits = name[len(head):len(name) - len(tail)]
            if digits.isascii() and digits.isdigit() and path.is_file():
                indexed.append((int(digits), path))
        indexed.sort()
        return [path for _, path in indexed]

    def read_shard(
        self,
        shard_path: Path,
    ) -> List[Dict]:
        """Read a single Parquet shard and return its rows as a list of dicts."""
        import pandas as pd
        df = pd.read_parquet(shard_path)
        return df.to_dict('records')
def test_dataset_loader():
    """
    Smoke-test the loader by streaming up to 10 pubmed_qa samples.

    Prints a 100-character preview of each sample and the total count.
    Requires network access to the HuggingFace Hub.
    """
    loader = VortexDatasetLoader()
    print("Testing dataset loader...")
    stream = loader.load_dataset("pubmed_qa", streaming=True, max_samples=10)
    seen = 0
    for idx, record in enumerate(stream):
        preview = record['text'][:100]
        print(f"Sample {idx}: {preview}...")
        seen = idx + 1
    print(f"Loaded {seen} samples")
    print("DatasetLoader test passed!")
if __name__ == "__main__":
    # Executed only when run as a script, never on import.
    test_dataset_loader()