"""
Dataset configurations and loaders for ULTRATHINK training.

Supports multiple popular datasets with easy switching.
"""

import json
import logging
import os
import random
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
from datasets import load_dataset

logger = logging.getLogger(__name__)


@dataclass
class DatasetConfig:
    """Configuration for dataset loading"""

    # Hugging Face dataset identifier ("dummy" selects a synthetic test set).
    name: str = "wikitext"
    subset: Optional[str] = "wikitext-2-raw-v1"
    split_train: str = "train"
    split_val: str = "validation"
    split_test: str = "test"
    text_column: str = "text"
    max_length: int = 512
    tokenizer_name: str = "gpt2"
    streaming: bool = False
    cache_dir: Optional[str] = None
    num_proc: int = 4

    # Streaming options: shuffle buffer and optional sharding across ranks.
    seed: int = 42
    buffer_size: int = 10000
    shard_rank: Optional[int] = None
    shard_num_shards: Optional[int] = None

    # Local files or remote URLs (used instead of the Hub when set).
    local_path: Optional[str] = None
    file_type: str = "json"

    # Filtering.
    min_length: int = 10
    max_samples: Optional[int] = None

    # Optional per-dataset weights when mixing multiple datasets.
    mixing_weights: Optional[Dict[str, float]] = None


DATASET_CONFIGS = {
    "wikitext": DatasetConfig(
        name="wikitext",
        subset="wikitext-2-raw-v1",
        text_column="text",
        max_length=512,
        streaming=False,
    ),
    "wikitext-103": DatasetConfig(
        name="wikitext",
        subset="wikitext-103-raw-v1",
        text_column="text",
        max_length=1024,
        streaming=True,
    ),
    "openwebtext": DatasetConfig(
        name="Skylion007/openwebtext",
        subset=None,
        text_column="text",
        max_length=1024,
        streaming=True,
    ),
    "slim-pajama": DatasetConfig(
        name="cerebras/SlimPajama-627B",
        subset=None,
        text_column="text",
        max_length=2048,
        streaming=True,
    ),
    "pile": DatasetConfig(
        name="EleutherAI/pile",
        subset=None,
        text_column="text",
        max_length=2048,
        streaming=True,
    ),
    "pile-unc": DatasetConfig(
        name="monology/pile-uncopyrighted",
        subset=None,
        text_column="text",
        max_length=2048,
        streaming=True,
    ),
    "c4": DatasetConfig(
        name="allenai/c4",
        subset="en",
        text_column="text",
        max_length=512,
        streaming=True,
    ),
    "bookcorpus": DatasetConfig(
        name="bookcorpusopen",
        subset=None,
        text_column="text",
        max_length=1024,
        streaming=True,
    ),
    "oscar": DatasetConfig(
        name="oscar",
        subset="unshuffled_deduplicated_en",
        text_column="text",
        max_length=512,
        streaming=True,
    ),
    "wikipedia": DatasetConfig(
        name="wikimedia/wikipedia",
        subset="20231101.en",
        text_column="text",
        max_length=1024,
        streaming=True,
    ),
    "dummy": DatasetConfig(
        name="dummy",
        subset=None,
        text_column="text",
        max_length=512,
        streaming=False,
    ),
}
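

# Example usage (a minimal sketch; "my-corpus" and its path are hypothetical
# placeholders, not presets shipped above):
#
#     from dataclasses import replace
#
#     base = DATASET_CONFIGS["wikitext"]
#     long_ctx = replace(base, max_length=1024)      # derive a variant of a preset
#     DATASET_CONFIGS["my-corpus"] = DatasetConfig(   # register a fully custom entry
#         name="my-corpus",
#         local_path="data/my_corpus.jsonl",
#         file_type="json",
#         streaming=False,
#     )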


class TextDataset(Dataset):
    """Generic text dataset for language modeling"""

    def __init__(self, config: DatasetConfig, split: str = "train"):
        self.config = config
        self.split = split
        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)

        # Tokenizers such as GPT-2's ship without a pad token; reuse EOS for padding.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.data = self._load_data()

    def _load_data(self):
        """Load and preprocess data"""
        if self.config.name == "dummy":
            return self._create_dummy_data()
        elif self.config.local_path:
            return self._load_local_data()
        else:
            return self._load_hf_data()

    def _create_dummy_data(self):
        """Create dummy data for testing"""
        logger.info("Creating dummy dataset for testing...")
        dummy_texts = [
            "The quick brown fox jumps over the lazy dog.",
            "Machine learning is a subset of artificial intelligence.",
            "Natural language processing enables computers to understand human language.",
            "Deep learning models can learn complex patterns from data.",
            "Transformers have revolutionized the field of NLP.",
        ] * 2000

        return [{"text": text} for text in dummy_texts]

    def _load_local_data(self):
        """Load data from local files or remote URLs.

        Supports:
        - small local JSONL/TXT files (read directly into memory)
        - HTTP/HTTPS URLs or local globs via datasets.load_dataset with streaming
        - multiple files given as a comma-separated list
        """
        path = self.config.local_path
        logger.info(f"Loading local/remote data from {path}")

        if "," in path:
            paths = [p.strip() for p in path.split(",") if p.strip()]
        else:
            paths = [path]

        def is_remote(p: str) -> bool:
            return p.startswith("http://") or p.startswith("https://")

        # Remote URLs and glob patterns go through datasets.load_dataset in
        # streaming mode so nothing has to be downloaded up front.
        if any(is_remote(p) or ("*" in p) for p in paths):
            sample = paths[0]
            lower = sample.lower()
            if lower.endswith(".parquet"):
                builder = "parquet"
            elif lower.endswith((".jsonl", ".json", ".jsonl.zst", ".jsonl.gz")):
                builder = "json"
            else:
                # Default to the JSON builder for unknown extensions.
                builder = "json"

            logger.info(f"Using datasets.load_dataset builder='{builder}' with streaming for data_files={paths}")
            # Note: generic builders expose data_files as a single "train" split,
            # so self.split should normally be "train" here.
            dataset = load_dataset(builder, data_files=paths, split=self.split, streaming=True)
            return dataset

        # Small local files: read everything into memory.
        data = []
        for p in paths:
            if self.config.file_type == "json":
                # One JSON object per line (JSONL); skip blank lines.
                with open(p, 'r', encoding='utf-8') as f:
                    for line in f:
                        line = line.strip()
                        if not line:
                            continue
                        item = json.loads(line)
                        if self.config.text_column in item:
                            data.append({self.config.text_column: item[self.config.text_column]})
            elif self.config.file_type == "txt":
                # Plain text: split on blank lines into paragraph-sized chunks.
                with open(p, 'r', encoding='utf-8') as f:
                    text = f.read()
                chunks = text.split('\n\n')
                data.extend([
                    {self.config.text_column: chunk.strip()}
                    for chunk in chunks
                    if len(chunk.strip()) > self.config.min_length
                ])

        logger.info(f"Loaded {len(data)} samples from local files")
        return data
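
    # Example (a sketch; the file path below is a hypothetical placeholder):
    #
    #     cfg = DatasetConfig(
    #         name="local-notes",
    #         local_path="data/notes.jsonl",   # one {"text": "..."} object per line
    #         file_type="json",
    #         text_column="text",
    #         max_length=512,
    #     )
    #     ds = TextDataset(cfg, split="train")   # reads the file via _load_local_data()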

    def _load_hf_data(self):
        """Load data from Hugging Face datasets"""
        logger.info(f"Loading {self.config.name} dataset from Hugging Face...")

        try:
            target_name = self.config.name
            target_subset = self.config.subset

            # Map legacy short names to their current Hub locations.
            legacy_map = {
                "openwebtext": "Skylion007/openwebtext",
                "bookcorpus": "bookcorpusopen",
                "c4": "allenai/c4",
            }
            if target_name in legacy_map:
                target_name = legacy_map[target_name]

            kwargs = {
                "split": self.split,
                "streaming": self.config.streaming,
                "cache_dir": self.config.cache_dir,
            }

            def try_load():
                if target_subset:
                    return load_dataset(target_name, target_subset, **kwargs)
                else:
                    return load_dataset(target_name, **kwargs)

            try:
                dataset = try_load()
            except Exception as e:
                logger.error(f"HF load_dataset failed for {target_name} ({target_subset}): {e}")
                raise

            if not self.config.streaming:
                # Map-style mode: materialize, filter, and optionally cap the samples.
                data = []
                for item in dataset:
                    if self.config.text_column in item and item[self.config.text_column]:
                        text = item[self.config.text_column].strip()
                        if len(text) >= self.config.min_length:
                            data.append({self.config.text_column: text})

                    if self.config.max_samples and len(data) >= self.config.max_samples:
                        break

                logger.info(f"Loaded {len(data)} samples from {self.config.name}")
                return data
            else:
                # Streaming mode: optionally shard across ranks, then shuffle with a buffer.
                if self.config.shard_num_shards is not None and self.config.shard_rank is not None:
                    try:
                        dataset = dataset.shard(self.config.shard_num_shards, self.config.shard_rank)
                    except Exception as e:
                        logger.warning(f"Streaming shard not applied: {e}")
                try:
                    dataset = dataset.shuffle(seed=self.config.seed, buffer_size=self.config.buffer_size)
                except Exception as e:
                    logger.warning(f"Streaming shuffle not applied: {e}")
                return dataset

        except Exception as e:
            logger.error(f"Failed to load {self.config.name}: {e}")
            logger.info("Falling back to dummy dataset...")
            return self._create_dummy_data()

    def __len__(self):
        if hasattr(self.data, '__len__'):
            return len(self.data)
        else:
            # Streaming datasets have no length; report a nominal size.
            return self.config.max_samples or 100000

    def __getitem__(self, idx):
        if isinstance(self.data, list):
            item = self.data[idx % len(self.data)]
        else:
            # Streaming fallback: skip to the requested position. Simple but slow,
            # since it re-iterates the stream on every access.
            stream = self.data
            try:
                stream = stream.shuffle(seed=self.config.seed, buffer_size=self.config.buffer_size)
            except Exception:
                pass
            item = next(iter(stream.skip(idx).take(1)))

        text = item[self.config.text_column]

        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.config.max_length,
            return_tensors='pt'
        )

        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        # Causal LM labels: a copy of the inputs, with padding positions ignored by the loss.
        labels = input_ids.clone()
        labels = labels.masked_fill(attention_mask == 0, -100)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }
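

# What a TextDataset sample looks like (a sketch; uses the "dummy" preset so no
# corpus download is needed, though the GPT-2 tokenizer is still fetched):
#
#     ds = TextDataset(DATASET_CONFIGS["dummy"], split="train")
#     sample = ds[0]
#     sample["input_ids"].shape   # torch.Size([512]) -- padded to max_length
#     sample["labels"]            # input_ids with pad positions set to -100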


class MixedDataset(Dataset):
    """Dataset that mixes multiple datasets with specified weights"""

    def __init__(self, datasets: Dict[str, Dataset], weights: Dict[str, float]):
        self.datasets = datasets
        self.weights = weights
        self.dataset_names = list(datasets.keys())

        # Normalize the weights into cumulative probabilities for sampling.
        total_weight = sum(weights.values())
        self.cumulative_weights = []
        cumsum = 0
        for name in self.dataset_names:
            cumsum += weights[name] / total_weight
            self.cumulative_weights.append(cumsum)

        self.total_length = sum(len(ds) for ds in datasets.values())

    def __len__(self):
        return self.total_length

    def __getitem__(self, idx):
        # Pick a source dataset at random according to the mixing weights;
        # idx only selects the position within the chosen dataset.
        rand = random.random()

        for i, cum_weight in enumerate(self.cumulative_weights):
            if rand <= cum_weight:
                dataset_name = self.dataset_names[i]
                dataset = self.datasets[dataset_name]
                dataset_idx = idx % len(dataset)
                return dataset[dataset_idx]

        # Fallback (should not be reached): use the first dataset.
        first = self.datasets[self.dataset_names[0]]
        return first[idx % len(first)]
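

# Mixing example (a sketch; the 70/30 split is an arbitrary illustration, and the
# "wikitext"/"dummy" presets are used only because they are small):
#
#     mixed = create_mixed_dataset(
#         configs={"wiki": "wikitext", "synthetic": "dummy"},
#         weights={"wiki": 0.7, "synthetic": 0.3},
#         split="train",
#     )
#     mixed[0]   # ~70% of accesses sample a wikitext example, ~30% a dummy one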


def create_dataset(config: Union[str, DatasetConfig], split: str = "train") -> Dataset:
    """Create a dataset from config"""
    if isinstance(config, str):
        if config in DATASET_CONFIGS:
            config = DATASET_CONFIGS[config]
        else:
            raise ValueError(f"Unknown dataset config: {config}")

    return TextDataset(config, split)


def create_mixed_dataset(configs: Dict[str, Union[str, DatasetConfig]],
                         weights: Dict[str, float],
                         split: str = "train") -> MixedDataset:
    """Create a mixed dataset from multiple configs"""
    datasets = {}
    for name, config in configs.items():
        datasets[name] = create_dataset(config, split)

    return MixedDataset(datasets, weights)


def create_dataloader(dataset: Dataset,
                      batch_size: int = 8,
                      shuffle: bool = True,
                      num_workers: int = 0,
                      pin_memory: bool = False) -> DataLoader:
    """Create a DataLoader with optimized settings"""
    # Heuristic: double the requested worker count but cap it at 6 to avoid
    # oversubscribing the host; with the default num_workers=0 this stays 0.
    optimal_workers = min(num_workers * 2, 6)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=optimal_workers,
        pin_memory=pin_memory,
        drop_last=True,
        persistent_workers=optimal_workers > 0,
        prefetch_factor=4 if optimal_workers > 0 else None
    )
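

# Dataloader example (a sketch; batch size and worker count are illustrative):
#
#     train_ds = create_dataset("wikitext", split="train")
#     train_dl = create_dataloader(train_ds, batch_size=8, shuffle=True,
#                                  num_workers=2, pin_memory=True)
#     batch = next(iter(train_dl))
#     batch["input_ids"].shape   # (8, 512) for the wikitext preset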


# Quick reference for the most common presets (sizes are approximate).
DATASET_INFO = {
    "wikitext": {
        "description": "WikiText language modeling benchmark built from Wikipedia articles",
        "size": "~100MB (wikitext-2), ~500MB (wikitext-103)",
        "language": "English",
        "domain": "Wikipedia articles",
        "license": "Creative Commons"
    },
    "openwebtext": {
        "description": "Open source recreation of WebText",
        "size": "~40GB",
        "language": "English",
        "domain": "Web pages",
        "license": "Public domain"
    },
    "pile": {
        "description": "Large-scale curated text dataset",
        "size": "~800GB",
        "language": "English",
        "domain": "Books, web, academic papers, code",
        "license": "MIT"
    },
    "c4": {
        "description": "Colossal Clean Crawled Corpus",
        "size": "~750GB",
        "language": "Multiple (English subset available)",
        "domain": "Web crawl data",
        "license": "ODC-BY"
    },
    "bookcorpus": {
        "description": "Collection of over 11,000 books",
        "size": "~5GB",
        "language": "English",
        "domain": "Books and novels",
        "license": "Research use"
    }
}


def print_dataset_info():
    """Print information about available datasets"""
    print("\n📚 Available Datasets for ULTRATHINK Training:\n")

    for name, info in DATASET_INFO.items():
        print(f"🔹 {name.upper()}")
        print(f"   Description: {info['description']}")
        print(f"   Size: {info['size']}")
        print(f"   Language: {info['language']}")
        print(f"   Domain: {info['domain']}")
        print(f"   License: {info['license']}")
        print()

    print("💡 Usage Examples:")
    print("   --dataset wikitext      # Small, fast download")
    print("   --dataset openwebtext   # Medium size, diverse")
    print("   --dataset pile          # Large, comprehensive")
    print("   --dataset custom --data_path /path/to/data.json")
    print()


if __name__ == "__main__":
    print_dataset_info()