import os
import numpy as np
import tiktoken
from datasets import load_dataset, concatenate_datasets, interleave_datasets
from torch.utils.data import IterableDataset
import torch


class StreamingDataset(IterableDataset):
    """Streaming dataset that loads and processes data on the fly."""

    def __init__(self, dataset_configs, block_size=2048, batch_size=12):
        self.dataset_configs = dataset_configs
        self.block_size = block_size
        self.batch_size = batch_size
        self.enc = tiktoken.get_encoding("gpt2")

    def load_and_process_chunk(self, dataset_name, split="train"):
        # Load datasets with the appropriate configs
        if dataset_name == "openwebtext":
            dataset = load_dataset(dataset_name, split=split, streaming=True, trust_remote_code=True)
        elif dataset_name == "the_pile":
            dataset = load_dataset("the_pile", split=split, streaming=True)
        elif dataset_name == "red_pajama":
            dataset = load_dataset("togethercomputer/RedPajama-Data-1T", split=split, streaming=True)
        else:
            raise ValueError(f"Unsupported dataset name: {dataset_name}")

        for example in dataset:
            ids = self.enc.encode_ordinary(example['text'])
            ids.append(self.enc.eot_token)
            # Documents shorter than block_size are skipped; longer documents are split
            # into non-overlapping block_size chunks (any trailing remainder is dropped).
            if len(ids) >= self.block_size:
                for i in range(0, len(ids) - self.block_size + 1, self.block_size):
                    yield torch.tensor(ids[i:i + self.block_size])

    def __iter__(self):
        # Manually interleave the datasets according to the specified weights
        iterators = []
        weights = []
        for config in self.dataset_configs:
            iterators.append(self.load_and_process_chunk(config['name']))
            weights.append(config['weight'])

        # Normalize weights so they sum to 1
        weights = np.array(weights) / sum(weights)

        while True:
            # Randomly select a source dataset based on the weights
            dataset_idx = np.random.choice(len(iterators), p=weights)
            try:
                batch = []
                for _ in range(self.batch_size):
                    batch.append(next(iterators[dataset_idx]))
                yield torch.stack(batch)
            except StopIteration:
                # Restart the iterator if it is exhausted; the partially filled
                # batch is discarded and sampling continues.
                iterators[dataset_idx] = self.load_and_process_chunk(self.dataset_configs[dataset_idx]['name'])
                continue


# Example usage:
dataset_configs = [
    {'name': 'openwebtext', 'weight': 0.4},
    {'name': 'the_pile', 'weight': 0.3},
    {'name': 'red_pajama', 'weight': 0.3},
]
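

# --- Minimal usage sketch ---
# Assumptions: the three Hugging Face dataset ids above are still hosted and can be
# streamed from this environment; the names below (e.g. `streaming_dataset`,
# `max_demo_steps`) are illustrative and not part of the original code.
if __name__ == "__main__":
    streaming_dataset = StreamingDataset(dataset_configs, block_size=2048, batch_size=12)

    # __iter__ already yields stacked (batch_size, block_size) tensors of GPT-2 token
    # ids, so iterate directly; if a DataLoader is preferred, use
    # torch.utils.data.DataLoader(streaming_dataset, batch_size=None) to avoid adding
    # a second batch dimension.
    max_demo_steps = 3
    for step, batch in enumerate(streaming_dataset):
        print(step, batch.shape)  # expected: torch.Size([12, 2048])
        if step + 1 >= max_demo_steps:
            break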