""" |
|
|
Dataset classes for LLM training. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import json |
|
|
import random |
|
|
import numpy as np |
|
|
import time |
|
|
import threading |
|
|
import queue |
|
|
import logging |
|
|
from typing import Dict, List, Optional, Tuple, Union, Callable, Iterator, Any |
|
|
import jax.numpy as jnp |
|
|
from data.tokenizer import Tokenizer |
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
try: |
|
|
import datasets |
|
|
from datasets import load_dataset, Dataset as HFDataset |
|
|
DATASETS_AVAILABLE = True |
|
|
except ImportError: |
|
|
logger.warning("HuggingFace datasets library not available. Streaming datasets will be disabled.") |
|
|
DATASETS_AVAILABLE = False |
|
|
|
|
|
|
|
|
class Dataset:
    """
    Base dataset class.
    """

    def __init__(self, tokenizer: Tokenizer):
        """
        Initialize dataset.

        Args:
            tokenizer: Tokenizer for encoding/decoding text
        """
        self.tokenizer = tokenizer

    def __len__(self) -> int:
        """
        Get dataset length.

        Returns:
            Dataset length
        """
        raise NotImplementedError

    def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
        """
        Get dataset item.

        Args:
            idx: Item index

        Returns:
            Dictionary of tensors
        """
        raise NotImplementedError


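# Minimal sketch (illustrative only) of what a Dataset subclass is expected to
# provide: a length and items returned as dictionaries of NumPy arrays. The
# class name below is hypothetical; the concrete classes in this module are the
# real implementations.
#
#     class InMemoryDataset(Dataset):
#         def __init__(self, tokenizer: Tokenizer, token_ids: List[List[int]]):
#             super().__init__(tokenizer)
#             self.token_ids = token_ids
#
#         def __len__(self) -> int:
#             return len(self.token_ids)
#
#         def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
#             return {"input_ids": np.array(self.token_ids[idx], dtype=np.int32)}

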
class StreamingDataset(Dataset):
    """
    Streaming dataset for efficient training with large datasets.

    This dataset streams data from disk or remote sources, minimizing memory usage.
    It supports HuggingFace datasets streaming mode for efficient processing.

    Attributes:
        tokenizer: Tokenizer for encoding/decoding text
        dataset_path: Path to dataset file or HuggingFace dataset name
        max_seq_length: Maximum sequence length
        streaming: Whether to use streaming mode
        buffer_size: Size of buffer for streaming
        seed: Random seed for shuffling
        hf_dataset: HuggingFace dataset object
        text_column: Name of text column in dataset
        buffer: Buffer of processed examples
    """

    def __init__(
        self,
        tokenizer: Tokenizer,
        dataset_path: str,
        max_seq_length: int = 131072,
        streaming: bool = True,
        buffer_size: int = 1000,
        seed: int = 42,
        text_column: str = "text",
        preprocessing_num_workers: int = 16,
        use_auth_token: Optional[str] = None
    ):
        """
        Initialize streaming dataset.

        Args:
            tokenizer: Tokenizer for encoding/decoding text
            dataset_path: Path to dataset file or HuggingFace dataset name
            max_seq_length: Maximum sequence length
            streaming: Whether to use streaming mode
            buffer_size: Size of buffer for streaming
            seed: Random seed for shuffling
            text_column: Name of text column in dataset
            preprocessing_num_workers: Number of workers for preprocessing
            use_auth_token: HuggingFace auth token for private datasets
        """
        super().__init__(tokenizer)

        self.dataset_path = dataset_path
        self.max_seq_length = max_seq_length
        self.streaming = streaming and DATASETS_AVAILABLE
        self.buffer_size = buffer_size
        self.seed = seed
        self.text_column = text_column
        self.preprocessing_num_workers = preprocessing_num_workers

        # Buffer state shared with the background producer thread.
        self.buffer = []
        self.buffer_lock = threading.Lock()
        self.buffer_ready = threading.Event()
        self.buffer_idx = 0
        self.dataset_exhausted = False

        self._load_dataset(use_auth_token)

        # In streaming mode, a daemon thread keeps the buffer filled so that
        # __getitem__ never touches the underlying iterator directly.
        if self.streaming:
            self.buffer_thread = threading.Thread(target=self._fill_buffer)
            self.buffer_thread.daemon = True
            self.buffer_thread.start()

    def _load_dataset(self, use_auth_token: Optional[str] = None):
        """
        Load dataset from file or HuggingFace.

        Args:
            use_auth_token: HuggingFace auth token for private datasets
        """
        if not DATASETS_AVAILABLE:
            raise ImportError("HuggingFace datasets library is required for streaming datasets")

        logger.info(f"Loading dataset from {self.dataset_path}")
        start_time = time.time()

        if os.path.exists(self.dataset_path):
            # Local file: pick the loader from the file extension. JSON/JSONL
            # files are expected to carry the text under `self.text_column`,
            # e.g. one {"text": "..."} object per line.
            file_extension = os.path.splitext(self.dataset_path)[1]
            if file_extension == ".jsonl" or file_extension == ".json":
                self.hf_dataset = load_dataset(
                    "json",
                    data_files=self.dataset_path,
                    streaming=self.streaming,
                    use_auth_token=use_auth_token
                )["train"]
            elif file_extension == ".txt":
                self.hf_dataset = load_dataset(
                    "text",
                    data_files=self.dataset_path,
                    streaming=self.streaming,
                    use_auth_token=use_auth_token
                )["train"]
            else:
                raise ValueError(f"Unsupported file extension: {file_extension}")
        else:
            # Otherwise treat the path as a dataset name on the HuggingFace Hub.
            self.hf_dataset = load_dataset(
                self.dataset_path,
                streaming=self.streaming,
                use_auth_token=use_auth_token
            )["train"]

        # Shuffle with a bounded buffer when streaming.
        if self.streaming:
            self.hf_dataset = self.hf_dataset.shuffle(seed=self.seed, buffer_size=self.buffer_size)

        logger.info(f"Dataset loaded in {time.time() - start_time:.2f} seconds")

        # Length is only known up front for non-streaming (map-style) datasets.
        if not self.streaming:
            self.dataset_length = len(self.hf_dataset)
            logger.info(f"Dataset length: {self.dataset_length}")

    def _fill_buffer(self):
        """
        Fill buffer with processed examples in background thread.
        """
        try:
            dataset_iter = iter(self.hf_dataset)

            while True:
                # If the buffer is full, signal consumers and back off briefly,
                # sleeping outside the lock so consumers are not blocked.
                with self.buffer_lock:
                    buffer_full = len(self.buffer) >= self.buffer_size
                if buffer_full:
                    self.buffer_ready.set()
                    time.sleep(0.1)
                    continue

                try:
                    example = next(dataset_iter)
                except StopIteration:
                    # Underlying stream is finished; let consumers drain the buffer.
                    self.dataset_exhausted = True
                    self.buffer_ready.set()
                    break

                processed = self._process_example(example)

                with self.buffer_lock:
                    self.buffer.append(processed)

                if len(self.buffer) > 0:
                    self.buffer_ready.set()
        except Exception as e:
            logger.error(f"Error in buffer filling thread: {e}")
            self.dataset_exhausted = True
            self.buffer_ready.set()

    def _process_example(self, example: Dict[str, Any]) -> Dict[str, np.ndarray]:
        """
        Process example from dataset.

        Args:
            example: Example from dataset

        Returns:
            Processed example
        """
        # Locate the text field, preferring the configured column name.
        if self.text_column in example:
            text = example[self.text_column]
        else:
            # Fall back to common column names, then to any string-valued field.
            text_columns = ["text", "content", "document", "input_text"]
            for col in text_columns:
                if col in example:
                    text = example[col]
                    break
            else:
                for key, value in example.items():
                    if isinstance(value, str):
                        text = value
                        break
                else:
                    raise ValueError(f"No text column found in example: {example}")

        # Tokenize and truncate to the maximum sequence length.
        input_ids = self.tokenizer.encode(text)
        if len(input_ids) > self.max_seq_length:
            input_ids = input_ids[:self.max_seq_length]

        input_ids = np.array(input_ids, dtype=np.int32)

        return {"input_ids": input_ids}

    def __len__(self) -> int:
        """
        Get dataset length.

        Returns:
            Dataset length
        """
        if self.streaming:
            # Streaming datasets have no known length; report a large sentinel
            # so length-based samplers keep working.
            return 1_000_000_000
        else:
            return self.dataset_length

    def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
        """
        Get dataset item.

        Args:
            idx: Item index (ignored in streaming mode)

        Returns:
            Dictionary of tensors
        """
        if self.streaming:
            # Items are served from the shared buffer in arrival order; the
            # requested index is ignored in streaming mode.
            while True:
                # Wait until the producer thread signals that data is available.
                self.buffer_ready.wait()

                with self.buffer_lock:
                    if len(self.buffer) == 0:
                        if self.dataset_exhausted:
                            self.buffer_idx = 0
                            raise StopIteration("Dataset exhausted")
                        # Buffer drained but the producer is still running: clear
                        # the event and wait again on the next loop iteration,
                        # outside the lock (re-entering __getitem__ while holding
                        # the non-reentrant lock would deadlock).
                        self.buffer_ready.clear()
                        continue

                    item = self.buffer[self.buffer_idx]
                    self.buffer_idx += 1

                    # Reset the buffer once it has been fully consumed.
                    if self.buffer_idx >= len(self.buffer):
                        self.buffer = []
                        self.buffer_idx = 0
                        self.buffer_ready.clear()

                    return item
        else:
            # Map-style access for non-streaming datasets.
            example = self.hf_dataset[idx]
            return self._process_example(example)


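# Usage sketch (illustrative only): stream a local JSONL corpus. Assumes a
# `Tokenizer` instance built elsewhere and a hypothetical file path.
#
#     dataset = StreamingDataset(
#         tokenizer=tokenizer,
#         dataset_path="path/to/corpus.jsonl",
#         max_seq_length=2048,
#         streaming=True,
#         buffer_size=1000,
#     )
#     item = dataset[0]          # index is ignored in streaming mode
#     print(item["input_ids"].shape)

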
class TextDataset(Dataset):
    """
    Text dataset.

    Attributes:
        texts: List of texts
        tokenizer: Tokenizer for encoding/decoding text
        max_length: Maximum sequence length
        add_bos: Whether to add beginning of sequence token
        add_eos: Whether to add end of sequence token
    """

    def __init__(
        self,
        texts: List[str],
        tokenizer: Tokenizer,
        max_length: int = 1024,
        add_bos: bool = True,
        add_eos: bool = False
    ):
        """
        Initialize dataset.

        Args:
            texts: List of texts
            tokenizer: Tokenizer for encoding/decoding text
            max_length: Maximum sequence length
            add_bos: Whether to add beginning of sequence token
            add_eos: Whether to add end of sequence token
        """
        super().__init__(tokenizer)
        self.texts = texts
        self.max_length = max_length
        self.add_bos = add_bos
        self.add_eos = add_eos

    def __len__(self) -> int:
        """
        Get dataset length.

        Returns:
            Dataset length
        """
        return len(self.texts)

    def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
        """
        Get dataset item.

        Args:
            idx: Item index

        Returns:
            Dictionary of tensors
        """
        text = self.texts[idx]

        token_ids = self.tokenizer.encode(
            text,
            add_bos=self.add_bos,
            add_eos=self.add_eos
        )

        if len(token_ids) > self.max_length:
            token_ids = token_ids[:self.max_length]

        attention_mask = np.ones(len(token_ids), dtype=np.int32)
        position_ids = np.arange(len(token_ids), dtype=np.int32)

        return {
            "input_ids": np.array(token_ids, dtype=np.int32),
            "attention_mask": attention_mask,
            "position_ids": position_ids
        }


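# Usage sketch (illustrative only): wrap a handful of raw strings. Assumes the
# `Tokenizer.encode` signature used above (add_bos/add_eos keyword arguments).
#
#     dataset = TextDataset(
#         texts=["first document", "second document"],
#         tokenizer=tokenizer,
#         max_length=128,
#     )
#     item = dataset[0]
#     # item holds "input_ids", "attention_mask" and "position_ids" arrays
#     # of identical length.

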
class TokenizedDataset(Dataset):
    """
    Pre-tokenized dataset.

    Attributes:
        token_ids: List of token ID sequences
        tokenizer: Tokenizer for encoding/decoding text
        max_length: Maximum sequence length
        add_bos: Whether to add beginning of sequence token
        add_eos: Whether to add end of sequence token
    """

    def __init__(
        self,
        token_ids: List[List[int]],
        tokenizer: Tokenizer,
        max_length: int = 1024,
        add_bos: bool = True,
        add_eos: bool = False
    ):
        """
        Initialize dataset.

        Args:
            token_ids: List of token ID sequences
            tokenizer: Tokenizer for encoding/decoding text
            max_length: Maximum sequence length
            add_bos: Whether to add beginning of sequence token
            add_eos: Whether to add end of sequence token
        """
        super().__init__(tokenizer)
        self.token_ids = token_ids
        self.max_length = max_length
        self.add_bos = add_bos
        self.add_eos = add_eos

    def __len__(self) -> int:
        """
        Get dataset length.

        Returns:
            Dataset length
        """
        return len(self.token_ids)

    def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
        """
        Get dataset item.

        Args:
            idx: Item index

        Returns:
            Dictionary of tensors
        """
        ids = self.token_ids[idx].copy()

        if self.add_bos:
            ids = [self.tokenizer.bos_token_id] + ids

        if self.add_eos:
            ids = ids + [self.tokenizer.eos_token_id]

        if len(ids) > self.max_length:
            ids = ids[:self.max_length]

        attention_mask = np.ones(len(ids), dtype=np.int32)
        position_ids = np.arange(len(ids), dtype=np.int32)

        return {
            "input_ids": np.array(ids, dtype=np.int32),
            "attention_mask": attention_mask,
            "position_ids": position_ids
        }


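# Usage sketch (illustrative only): wrap sequences tokenized ahead of time.
# Assumes the tokenizer exposes `bos_token_id` / `eos_token_id`.
#
#     dataset = TokenizedDataset(
#         token_ids=[[12, 87, 4], [5, 5, 91, 33]],
#         tokenizer=tokenizer,
#         max_length=64,
#         add_bos=True,
#     )
#     item = dataset[1]

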
class ConcatDataset(Dataset):
    """
    Concatenated dataset.

    Attributes:
        datasets: List of datasets
        tokenizer: Tokenizer for encoding/decoding text
        weights: Weights for sampling from datasets
    """

    def __init__(
        self,
        datasets: List[Dataset],
        tokenizer: Tokenizer,
        weights: Optional[List[float]] = None
    ):
        """
        Initialize dataset.

        Args:
            datasets: List of datasets
            tokenizer: Tokenizer for encoding/decoding text
            weights: Weights for sampling from datasets
        """
        super().__init__(tokenizer)
        self.datasets = datasets

        # Default to uniform weights, then normalize so they sum to 1.
        if weights is None:
            weights = [1.0] * len(datasets)
        total = sum(weights)
        self.weights = [w / total for w in weights]

        # Cumulative weights for inverse-CDF sampling of a dataset index.
        self.cumulative_weights = np.cumsum(self.weights)

        self.lengths = [len(dataset) for dataset in datasets]
        self.total_length = sum(self.lengths)

    def __len__(self) -> int:
        """
        Get dataset length.

        Returns:
            Dataset length
        """
        return self.total_length

    def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
        """
        Get dataset item.

        Args:
            idx: Item index (ignored; a dataset and an item are sampled at random)

        Returns:
            Dictionary of tensors
        """
        # Sample a dataset according to the normalized weights.
        r = random.random()
        dataset_idx = 0
        for i, cw in enumerate(self.cumulative_weights):
            if r <= cw:
                dataset_idx = i
                break

        # Sample a random item from the chosen dataset.
        item_idx = random.randint(0, self.lengths[dataset_idx] - 1)

        return self.datasets[dataset_idx][item_idx]


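# Usage sketch (illustrative only): mix two corpora roughly 80/20. The dataset
# names are hypothetical; weights are normalized internally, so any positive
# values work.
#
#     mixed = ConcatDataset(
#         datasets=[web_dataset, code_dataset],
#         tokenizer=tokenizer,
#         weights=[0.8, 0.2],
#     )
#     # Each lookup first samples a dataset by weight, then a random item from it.
#     item = mixed[0]

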
def load_jsonl_dataset(
    file_path: str,
    tokenizer: Tokenizer,
    text_key: str = "text",
    max_length: int = 1024,
    add_bos: bool = True,
    add_eos: bool = False,
    max_samples: Optional[int] = None
) -> TextDataset:
    """
    Load dataset from JSONL file.

    Args:
        file_path: Path to JSONL file
        tokenizer: Tokenizer for encoding/decoding text
        text_key: Key for text field in JSON objects
        max_length: Maximum sequence length
        add_bos: Whether to add beginning of sequence token
        add_eos: Whether to add end of sequence token
        max_samples: Maximum number of samples to load

    Returns:
        Text dataset
    """
    texts = []
    with open(file_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            if max_samples is not None and i >= max_samples:
                break

            data = json.loads(line)
            texts.append(data[text_key])

    return TextDataset(
        texts=texts,
        tokenizer=tokenizer,
        max_length=max_length,
        add_bos=add_bos,
        add_eos=add_eos
    )


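# Usage sketch (illustrative only; the path is hypothetical). Each JSONL line
# is expected to be an object such as {"text": "..."}.
#
#     dataset = load_jsonl_dataset(
#         file_path="data/train.jsonl",
#         tokenizer=tokenizer,
#         max_length=1024,
#         max_samples=10_000,
#     )

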
def load_text_dataset(
    file_path: str,
    tokenizer: Tokenizer,
    max_length: int = 1024,
    add_bos: bool = True,
    add_eos: bool = False,
    max_samples: Optional[int] = None
) -> TextDataset:
    """
    Load dataset from text file.

    Args:
        file_path: Path to text file
        tokenizer: Tokenizer for encoding/decoding text
        max_length: Maximum sequence length
        add_bos: Whether to add beginning of sequence token
        add_eos: Whether to add end of sequence token
        max_samples: Maximum number of samples to load

    Returns:
        Text dataset
    """
    texts = []
    with open(file_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            if max_samples is not None and i >= max_samples:
                break

            texts.append(line.strip())

    return TextDataset(
        texts=texts,
        tokenizer=tokenizer,
        max_length=max_length,
        add_bos=add_bos,
        add_eos=add_eos
    )


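# Usage sketch (illustrative only; the path is hypothetical): one training
# example per line of a plain-text file.
#
#     dataset = load_text_dataset(
#         file_path="data/train.txt",
#         tokenizer=tokenizer,
#         max_length=512,
#     )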