|
|
|
|
|
|
|
|
import contextlib |
|
|
import json |
|
|
import logging |
|
|
import os |
|
|
from copy import deepcopy |
|
|
from dataclasses import dataclass, field |
|
|
from functools import partial |
|
|
from multiprocessing import Event, Process, Queue |
|
|
from multiprocessing.synchronize import Event as EventClass |
|
|
from pathlib import Path |
|
|
from queue import Empty, Full |
|
|
from typing import Any, Dict, Iterator, Optional, TypedDict |
|
|
|
|
|
import numpy as np |
|
|
|
|
|
from core.tokenizer import ChatFormat, TokenizerArgs, build_tokenizer |
|
|
|
|
|
logger = logging.getLogger() |
|
|
|
|
|
""" |
|
|
This file contains all code necessary for text data loading from preshuffled jsonl chunks. |
|
|
For example, given the following files with a world size of 8:
|
|
|
|
|
/path/to/arxiv: |
|
|
arxiv.chunk.00.jsonl (Contains many lines of {"text":...} or {"content":...}) |
|
|
arxiv.chunk.01.jsonl |
|
|
arxiv.chunk.02.jsonl |
|
|
arxiv.chunk.03.jsonl |
|
|
|
|
|
/path/to/wikipedia: |
|
|
wikipedia.chunk.00.jsonl |
|
|
wikipedia.chunk.01.jsonl |
|
|
wikipedia.chunk.02.jsonl |
|
|
wikipedia.chunk.03.jsonl |
|
|
|
|
|
Step (1) => read_jsonl / loop_on_jsonl
|
|
2 workers read each jsonl chunk (world_size = 8 distributed over 4 files) from each source.
|
|
Each worker reads 1 line and skips the next, so workers on the same file read in an interleaved manner.
|
|
|
|
|
Step (2) => choose_source
|
|
At every iteration, a source is sampled randomly according to the given weights.
|
|
|
|
|
Step (3) => tokenize and pack_tokens
|
|
Reads sequences until reaching seq_len tokens and yields a numpy array of shape (seq_len, n_views) |
|
|
|
|
|
Step (4) => batch_and_shuffle_prefetched_sequences
|
|
Prefetches batches in advance and shuffles them to reduce correlation, yields a numpy array of shape (batch_size, seq_len, n_views) |
|
|
|
|
|
This creates a nested iterator structure where each iterator is responsible for a specific task:
|
|
[ [ [ [ (1) read document ] -> (2) sample source ] -> (3) tokenize and pack into sequences of fixed seq_len ] -> (4) prefetch batches ]
|
|
|
|
|
Each iterator returns a tuple (output, state) where state contains all the info necessary to resume from the last output. |
|
|
|
|
|
init_dataloader_state_from_args creates the initial dataloader state and build_dataloader_from_args returns an iterator that does everything above.

build_dataloader_from_args can be given a previously saved state to resume from any position deterministically.
|
|
""" |
|
|
|
|
|
TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" |
|
|
|
|
|
|
|
|
class JSONLState(TypedDict): |
|
|
"""Represents the current state of a JSON line reader. |
|
|
|
|
|
Attributes: |
|
|
content (Dict): The JSON content of the line. |
|
|
file_path (str): The path to the JSONL file. |
|
|
position (int): The file position after reading the line (in bytes). |
|
|
window (int): The window size used for iteration. |
|
|
offset (int): The offset used for iteration. |
|
|
current_iter (Optional[int]): Number of iterations over the jsonl file (for infinite iteration). |
|
|
""" |
|
|
|
|
|
file_path: str |
|
|
position: int |
|
|
block_size: int |
|
|
offset: int |
|
|
current_iter: int |
|
|
|
|
|
|
|
|
class MultiChoiceState(TypedDict): |
|
|
"""Represents the current state of a Multi choice iterator. |
|
|
|
|
|
Attributes: |
|
|
root_dir: path to dataset root directory |
|
|
sources Dict[str, float]: Dict from subdirectory to the weight used for sampling |
|
|
source_states: Dict[str, Any] Dict from source to iterator state |
|
|
rng_state: dict numpy bit generator state used to resume rng |
|
|
""" |
|
|
|
|
|
root_dir: str |
|
|
sources: Dict[str, float] |
|
|
source_to_state: Dict[str, Any] |
|
|
rng_state: Dict[str, Any] |
|
|
|
|
|
|
|
|
class TokenizerState(TypedDict): |
|
|
it_state: Any |
|
|
name: str |
|
|
add_bos: bool |
|
|
add_eos: bool |
|
|
path: Optional[str] |
|
|
|
|
|
|
|
|
class PackTokensState(TypedDict): |
|
|
"""Represents the current state of a packing iterator. |
|
|
|
|
|
Attributes: |
|
|
start_token: int index to start reading from in the current sequence |
|
|
output_seq_len: int Length of sequences to output |
|
|
n_views: dict int Number of views to output. Each view is the same sequence but shifted by 1 from the previous |
|
|
""" |
|
|
|
|
|
start_token: int |
|
|
it_state: Any |
|
|
output_seq_len: int |
|
|
n_views: int |
|
|
seq_len: int |
|
|
|
|
|
|
|
|
class PrefetchState(TypedDict): |
|
|
"""Represents the current state of a prefetching iterator. |
|
|
|
|
|
Attributes: |
|
|
prefetch_buffer: numpy array to store prefetched data |
|
|
seq_idx: int index of the current sequence to resume from |
|
|
rng_state: dict numpy bit generator state used to resume rng |
|
|
""" |
|
|
|
|
|
it_state: Any |
|
|
seq_idx: int |
|
|
rng_state: Dict[str, Any] |
|
|
prefetch_size: int |
|
|
batch_size: int |
|
|
|
|
|
|
|
|
def read_jsonl( |
|
|
file_path: str, |
|
|
position: int, |
|
|
block_size: int, |
|
|
offset: int, |
|
|
current_iter: int, |
|
|
): |
|
|
"""Iterates over a JSON Lines file, yielding a line every `block_size` lines with an offset |
|
|
|
|
|
    Example: If block_size = 3, offset = 1, the iterator will yield lines 1 4 7 10 ...
    Example: If block_size = 2, offset = 0, the iterator will yield lines 0 2 4 6 ...
|
|
|
|
|
Args: |
|
|
file_path (str): Path to the JSONL file. |
|
|
position (int): The file position (in bytes) from which to start reading. |
|
|
        block_size (int): Yield one line every `block_size` lines (i.e. skip `block_size - 1` lines between yields).
|
|
offset (int): The initial number of lines skipped |
|
|
|
|
|
Yields: |
|
|
        JSONLState: Represents the state of each line read according to block_size and offset.
|
|
""" |
|
|
if (offset < 0) or (offset >= block_size): |
|
|
raise RuntimeError(f"JSONL iterator offset value is invalid") |
|
|
|
|
|
|
|
|
    # When resuming from a saved position, the next line read is the one right
    # after the last yielded line, i.e. at offset + 1 within its block.
    current_line = offset + 1 if position > 0 else 0
|
|
|
|
|
state = JSONLState( |
|
|
file_path=file_path, |
|
|
position=position, |
|
|
block_size=block_size, |
|
|
offset=offset, |
|
|
current_iter=current_iter, |
|
|
) |
|
|
with open(file_path, "r") as file: |
|
|
file.seek(position) |
|
|
while line := file.readline(): |
|
|
current_line += 1 |
|
|
if (current_line - 1) % block_size == offset: |
|
|
|
|
|
|
|
|
state = JSONLState( |
|
|
file_path=file_path, |
|
|
position=file.tell(), |
|
|
block_size=block_size, |
|
|
offset=offset, |
|
|
current_iter=current_iter, |
|
|
) |
|
|
yield json.loads(line), state |
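

# Illustrative, self-contained sketch (not part of the pipeline): shows how the
# block/offset scheme interleaves two workers over one throwaway temp file.
def _example_read_jsonl_interleaving():
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "demo.chunk.00.jsonl")
        with open(path, "w") as f:
            for i in range(6):
                f.write(json.dumps({"text": f"doc {i}"}) + "\n")
        # block_size=2 splits the file between two workers via their offsets.
        worker0 = [c["text"] for c, _ in read_jsonl(path, 0, 2, 0, 0)]
        worker1 = [c["text"] for c, _ in read_jsonl(path, 0, 2, 1, 0)]
        assert worker0 == ["doc 0", "doc 2", "doc 4"]
        assert worker1 == ["doc 1", "doc 3", "doc 5"]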
|
|
|
|
|
|
|
|
def loop_on_jsonl( |
|
|
file_path: str, |
|
|
position: int, |
|
|
block_size: int, |
|
|
offset: int, |
|
|
current_iter: int, |
|
|
): |
|
|
"""Makes the block jsonl iterator infinite and updates n_iter counter""" |
|
|
    it = None
    try:
|
|
while True: |
|
|
it = read_jsonl(file_path, position, block_size, offset, current_iter) |
|
|
for content, jsonl_state in it: |
|
|
yield content, jsonl_state |
|
|
current_iter += 1 |
|
|
position = 0 |
|
|
finally: |
|
|
        if it is not None:
            it.close()
|
|
|
|
|
|
|
|
def tokenize( |
|
|
iterator: Iterator, |
|
|
add_bos: bool, |
|
|
add_eos: bool, |
|
|
tokenizer_type: str, |
|
|
tokenizer_path: Optional[str] = None, |
|
|
): |
|
|
""" |
|
|
Tokenizes text from an iterator of content-state pairs using a specified tokenizer. |
|
|
|
|
|
Parameters: |
|
|
    - iterator: An iterable of (content, state) pairs where content is a dict with a 'text', 'content' or 'conversations' key.
|
|
    - tokenizer_type (str): Name of the tokenizer to build; its `encode` method converts text to tokens, supporting `add_bos` and `add_eos`.
    - tokenizer_path (Optional[str]): Path to the tokenizer model file, if one is required.
|
|
- add_bos (bool): Flag to add a beginning-of-sequence token. |
|
|
- add_eos (bool): Flag to add an end-of-sequence token. |
|
|
|
|
|
Yields: |
|
|
- (tokens, state) pairs, where `tokens` is a list of tokenized text, and `state` is the original state from the iterator. |
|
|
""" |
|
|
tokenizer = build_tokenizer(name=tokenizer_type, path=tokenizer_path) |
|
|
if tokenizer_type == "llama3": |
|
|
chat_format = ChatFormat(tokenizer) |
|
|
for content, state in iterator: |
|
|
if "conversations" in content: |
|
|
assert ( |
|
|
tokenizer_type == "llama3" |
|
|
), "conversations should be tokenized with llama3 tokenizer" |
|
|
dialog = [] |
|
|
first_user_media = None |
|
|
if "image" in content: |
|
|
first_user_media = content["image"] |
|
|
for conversation in content["conversations"]: |
|
|
role = "human" if conversation["from"] == "human" else "assistant" |
|
|
text = conversation["value"] |
|
|
if role == "human" and first_user_media: |
|
|
dialog.append( |
|
|
{ |
|
|
"role": role, |
|
|
"content": text, |
|
|
"type": "image", |
|
|
"num_chunks": first_user_media["num_chunks"], |
|
|
} |
|
|
) |
|
|
else: |
|
|
dialog.append({"role": role, "content": text}) |
|
|
tokens = chat_format.encode_dialog_prompt(dialog) |
|
|
else: |
|
|
assert ( |
|
|
"text" in content or "content" in content |
|
|
), "JSON line must contain either text or content key" |
|
|
content_key = "text" if ("text" in content) else "content" |
|
|
text = content[content_key] |
|
|
tokens = tokenizer.encode(text, add_bos=add_bos, add_eos=add_eos) |
|
|
yield tokens, TokenizerState( |
|
|
it_state=state, |
|
|
add_bos=add_bos, |
|
|
add_eos=add_eos, |
|
|
name=tokenizer_type, |
|
|
path=tokenizer_path, |
|
|
) |
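

# Example JSONL payloads accepted by `tokenize` (illustrative values):
#   {"text": "A plain pretraining document."}
#   {"content": "Alternative key for the same payload."}
#   {"conversations": [{"from": "human", "value": "Hi"},
#                      {"from": "gpt", "value": "Hello!"}]}  # llama3 tokenizer only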
|
|
|
|
|
|
|
|
def choose_source( |
|
|
source_to_iterator: Dict[str, Iterator], |
|
|
source_to_state: Dict[str, Any], |
|
|
root_dir: str, |
|
|
sources: Dict[str, float], |
|
|
rng_state: Dict[str, Any], |
|
|
): |
|
|
""" |
|
|
Iterates over multiple data sources, selecting sequences based on weighted random choice. |
|
|
|
|
|
Parameters: |
|
|
- source_to_iterator (Dict[str, Iterator]): Dict from source paths to their iterators. |
|
|
- source_to_state (Dict[str, State]): Initial state for each source, allowing state tracking. |
|
|
    - root_dir (str): Root dir of data sources.
    - sources (Dict[str, float]): Dict from subdirectory to the weight used for sampling.
|
|
- rng_state (dict): State of the random number generator for reproducibility. |
|
|
|
|
|
Yields: |
|
|
- Tuple of (seq, multi_choice_state) where `seq` is the next sequence from the chosen source, |
|
|
and `multi_choice_state` includes the current state of all sources and the RNG. |
|
|
|
|
|
This function ensures that sequences are chosen from the provided sources based on the specified weights, |
|
|
maintaining state information for each source and the RNG to allow for reproducible iteration. |
|
|
""" |
|
|
n_sources = len(sources) |
|
|
possible_sources = list(sources.keys()) |
|
|
weights = list(sources.values()) |
|
|
|
|
|
    rng = np.random.default_rng()
    rng.bit_generator.state = rng_state
    norm_weights = np.array(weights) / np.array(weights).sum()
    while True:
        source_choice = possible_sources[rng.choice(n_sources, p=norm_weights)]
|
|
seq, state = next(source_to_iterator[source_choice]) |
|
|
source_to_state = {**source_to_state, source_choice: state} |
|
|
|
|
|
multi_choice_state = MultiChoiceState( |
|
|
root_dir=root_dir, |
|
|
sources=sources, |
|
|
source_to_state=source_to_state, |
|
|
rng_state=rng.bit_generator.state, |
|
|
) |
|
|
yield seq, multi_choice_state |
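

# Illustrative, self-contained sketch: weighted sampling between two fake
# sources. Source names, weights and the seed are hypothetical.
def _example_choose_source():
    import itertools

    source_to_iterator = {
        "arxiv": (("arxiv doc", {}) for _ in itertools.count()),
        "wikipedia": (("wikipedia doc", {}) for _ in itertools.count()),
    }
    it = choose_source(
        source_to_iterator=source_to_iterator,
        source_to_state={"arxiv": {}, "wikipedia": {}},
        root_dir="/path/to",
        sources={"arxiv": 0.8, "wikipedia": 0.2},
        rng_state=np.random.default_rng(42).bit_generator.state,
    )
    # Roughly 80% of the draws should come from "arxiv".
    draws = [next(it)[0] for _ in range(100)]
    assert draws.count("arxiv doc") > draws.count("wikipedia doc")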
|
|
|
|
|
|
|
|
def get_empty_buffer_state(
    start_token: int,
    states: list,
) -> PackTokensState:
|
|
""" |
|
|
Calculates the state to resume iteration after the buffer is cleared. |
|
|
|
|
|
    This function determines the starting point for resuming iteration by rewinding `n_views - 1` tokens from the `end_token`.
|
|
It handles cases where the rewind goes beyond the current sequence, adjusting the starting sequence and token index accordingly. |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
seq_to_resume_from = -1 |
|
|
while start_token < 0: |
|
|
seq_to_resume_from -= 1 |
|
|
start_token += states[seq_to_resume_from]["seq_len"] |
|
|
resume_state = deepcopy(states[seq_to_resume_from]) |
|
|
resume_state["start_token"] = start_token |
|
|
|
|
|
del states[:seq_to_resume_from] |
|
|
if "seq_len" in resume_state: |
|
|
del resume_state["seq_len"] |
|
|
|
|
|
return resume_state |
|
|
|
|
|
|
|
|
def pack_tokens( |
|
|
iterator: Iterator, |
|
|
empty_buffer_state: PackTokensState, |
|
|
): |
|
|
""" |
|
|
Iterates over tokens, packing them into chunks. |
|
|
|
|
|
This function aggregates tokens into a buffer and yields fixed-size chunks with dimensions `(output_seq_len, n_views)`, |
|
|
where each column represents shifted sequences of tokens. It ensures continuity in token sequences across chunks, |
|
|
preventing boundary effects and maintaining consistency regardless of `n_views`. |
|
|
|
|
|
Parameters: |
|
|
    - iterator: An iterator that yields pairs of (tokens, state), where tokens is a 1D sequence of tokens and state contains all necessary information to resume the iterator from the current position.
    - empty_buffer_state (PackTokensState): Packing state to resume from, containing:
        - it_state: Current state of the underlying iterator.
        - start_token (int): The index of the first token to start reading from for the first sequence.
        - output_seq_len (int): The length of the output sequences to be generated.
        - n_views (int): The number of shifted views to include in each output chunk.
|
|
|
|
|
Yields: |
|
|
- numpy.ndarray: An array of shape `(output_seq_len, n_views)` containing the packed tokens. |
|
|
    - PackTokensState: The state required to resume packing tokens from where the last returned chunk ended.
|
|
|
|
|
The function handles the complexity of determining the correct state for resuming iteration after the buffer is cleared, ensuring seamless continuation of token sequences. |
|
|
""" |
|
|
buffer = [] |
|
|
states = [] |
|
|
output_seq_len = empty_buffer_state["output_seq_len"] |
|
|
n_views = empty_buffer_state["n_views"] |
|
|
start_token = empty_buffer_state["start_token"] |
|
|
previous_state = empty_buffer_state["it_state"] |
|
|
buffer_size = output_seq_len + n_views - 1 |
|
|
for i, (tokens, state) in enumerate(iterator): |
|
|
end_token = start_token |
|
|
sample_is_read = False |
|
|
while not sample_is_read: |
|
|
assert start_token < len( |
|
|
tokens |
|
|
), f"Start token index {start_token} bigger than sequence {len(tokens)}" |
|
|
free_space = buffer_size - len(buffer) |
|
|
seq_len = min(free_space, len(tokens) - start_token) |
|
|
end_token = start_token + seq_len |
|
|
buffer.extend(tokens[start_token:end_token]) |
|
|
start_token = end_token |
|
|
|
|
|
states.append( |
|
|
PackTokensState( |
|
|
start_token=start_token, |
|
|
seq_len=seq_len, |
|
|
it_state=previous_state, |
|
|
output_seq_len=output_seq_len, |
|
|
n_views=n_views, |
|
|
) |
|
|
) |
|
|
assert len(buffer) <= buffer_size, "Buffer overflow" |
|
|
|
|
|
if len(buffer) == buffer_size: |
|
|
out = np.array(buffer) |
|
|
assert out.ndim == 1, "Iterator should return 1D sequences" |
|
|
out = np.lib.stride_tricks.sliding_window_view( |
|
|
out, n_views, axis=0 |
|
|
) |
|
|
|
|
|
|
|
|
                # Rewind by n_views - 1 tokens so the tokens shared across views
                # are re-read when resuming from the saved state.
                rewinded_idx = start_token - (n_views - 1)
|
|
empty_buffer_state = get_empty_buffer_state(rewinded_idx, states) |
|
|
buffer = buffer[output_seq_len:] |
|
|
assert len(buffer) == (n_views - 1) |
|
|
|
|
|
yield out, empty_buffer_state |
|
|
|
|
|
if start_token == len(tokens): |
|
|
start_token = 0 |
|
|
sample_is_read = True |
|
|
previous_state = state |
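

# Illustrative, self-contained sketch: packing fake "documents" into one
# (output_seq_len, n_views) chunk. Token values and lengths are arbitrary.
def _example_pack_tokens():
    def fake_token_iterator():
        for doc in [[1, 2, 3, 4], [5, 6, 7], [8, 9, 10, 11, 12]]:
            yield doc, {}  # the state payload is opaque to pack_tokens

    start = PackTokensState(
        start_token=0, it_state={}, output_seq_len=4, n_views=2, seq_len=0
    )
    out, resume_state = next(pack_tokens(fake_token_iterator(), start))
    assert out.shape == (4, 2)
    assert out[:, 0].tolist() == [1, 2, 3, 4]  # view 0
    assert out[:, 1].tolist() == [2, 3, 4, 5]  # view 1 is shifted by one token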
|
|
|
|
|
|
|
|
def batch_and_shuffle_prefetched_sequences( |
|
|
data_loader: Iterator, |
|
|
batch_size: int, |
|
|
prefetch_size: int, |
|
|
seq_len: int, |
|
|
n_views: int, |
|
|
state: PrefetchState, |
|
|
): |
|
|
""" |
|
|
    Prepares batches in advance and shuffles them to reduce correlation inside batches (for example when a very long document is encountered).
|
|
|
|
|
    This function aggregates sequences into a buffer and yields fixed-size batches with dimensions `(batch_size, seq_len, n_views)`.
|
|
|
|
|
    It uses a prefetch buffer to store sequences in advance and shuffles them. The prefetch buffer is similar to `reservoir sampling`,
    but works block by block to keep reloading smooth, simple and deterministic. For more uniform sequence sampling, ensure prefetch_size * batch_size * seq_len >> max document sequence length.
|
|
|
|
|
Parameters: |
|
|
    - data_loader: An iterator that yields pairs of (sequence, state), where sequence is a random sequence sampled from a corpus (as produced by pack_tokens, for example).
|
|
- batch_size: The desired batch size. |
|
|
- prefetch_size: The number of batches to prefetch in advance. |
|
|
- seq_len (int): The length of the output sequences to be generated. |
|
|
- n_views (int): The number of shifted views to include in each output chunk. |
|
|
|
|
|
Yields: |
|
|
- numpy.ndarray: An array of shape `(batch_size, seq_len, n_views)` containing the packed tokens. |
|
|
    - PrefetchState: The state required to resume the prefetched batches. Also contains the internal state of the underlying iterator.
|
|
""" |
|
|
prefetch_buffer = -1 * np.ones( |
|
|
(prefetch_size * batch_size, seq_len, n_views), dtype=int |
|
|
) |
|
|
rng = np.random.default_rng() |
|
|
rng.bit_generator.state = state["rng_state"] |
|
|
|
|
|
|
|
|
seq_idx = state["seq_idx"] |
|
|
assert ( |
|
|
seq_idx >= 0 and seq_idx < prefetch_size |
|
|
), "Prefetch state seq_idx should be in 0 <= seq_idx < prefetch_size." |
|
|
|
|
|
_rng_state = state["rng_state"] |
|
|
_it_state = state["it_state"] |
|
|
|
|
|
for i in range(prefetch_size * batch_size): |
|
|
prefetch_buffer[i], next_it_state = next(data_loader) |
|
|
rng.shuffle(prefetch_buffer, axis=0) |
|
|
for i in range(seq_idx * batch_size): |
|
|
prefetch_buffer[i], _ = next(data_loader) |
|
|
|
|
|
idx = seq_idx |
|
|
while True: |
|
|
if idx == prefetch_size - 1: |
|
|
_it_state = next_it_state |
|
|
_rng_state = rng.bit_generator.state |
|
|
|
|
|
state = PrefetchState( |
|
|
it_state=_it_state, |
|
|
seq_idx=(idx + 1) % prefetch_size, |
|
|
rng_state=_rng_state, |
|
|
batch_size=batch_size, |
|
|
prefetch_size=prefetch_size, |
|
|
) |
|
|
|
|
|
yield prefetch_buffer[idx * batch_size : (idx + 1) * batch_size].copy(), state |
|
|
|
|
|
for i in range(batch_size): |
|
|
prefetch_buffer[idx * batch_size + i], pack_state = next(data_loader) |
|
|
|
|
|
if idx == prefetch_size - 1: |
|
|
next_it_state = pack_state |
|
|
rng.shuffle(prefetch_buffer, axis=0) |
|
|
|
|
|
idx = (idx + 1) % prefetch_size |
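

# Illustrative, self-contained sketch: prefetching and shuffling fake sequences.
# Shapes and the seed are arbitrary.
def _example_prefetch():
    import itertools

    def fake_sequences():
        for i in itertools.count():
            yield np.full((8, 2), i), {}  # (seq_len, n_views) plus opaque state

    state = PrefetchState(
        it_state={},
        seq_idx=0,
        rng_state=np.random.default_rng(0).bit_generator.state,
        prefetch_size=4,
        batch_size=2,
    )
    batches = batch_and_shuffle_prefetched_sequences(
        fake_sequences(),
        batch_size=2,
        prefetch_size=4,
        seq_len=8,
        n_views=2,
        state=state,
    )
    batch, state = next(batches)
    assert batch.shape == (2, 8, 2)  # (batch_size, seq_len, n_views)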
|
|
|
|
|
|
|
|
def find_and_sanitize_chunks( |
|
|
dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN |
|
|
): |
|
|
    # Sort so every rank sees the chunks in the same order regardless of
    # filesystem glob order.
    dataset_chunks = sorted(str(p) for p in Path(dataset_path).glob(file_pattern))
|
|
n_chunks = len(dataset_chunks) |
|
|
|
|
|
    if n_chunks == 0:
        raise RuntimeError(f"No valid chunks of pattern {file_pattern} in {dataset_path}")

    if n_chunks > world_size:
        n_discard = n_chunks - world_size
        logger.warning(
            f"Found {n_chunks} chunks in {dataset_path}, discarding {n_discard} to match world size {world_size}"
        )
        dataset_chunks = dataset_chunks[:world_size]
    else:
        assert (
            world_size % n_chunks == 0
        ), "World size should be a multiple of number of chunks"
|
|
|
|
|
return dataset_chunks |
|
|
|
|
|
|
|
|
def distribute_data_to_rank( |
|
|
dataset_path: str, rank: int, world_size: int, file_pattern: str |
|
|
): |
|
|
""" |
|
|
Distributes the chunk files in a dataset path to each worker. |
|
|
    If world_size is smaller than the number of chunks, the extra chunks are discarded.
    Otherwise, world_size is assumed to be a multiple of the number of chunks.
    In that case there are world_size // n_chunks workers on each chunk file, reading with different offsets.
|
|
""" |
|
|
dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size, file_pattern) |
|
|
n_ranks_per_chunk = world_size // len(dataset_chunks) |
|
|
rank_to_jsonl_iterator_params = [] |
|
|
for chunk_path in dataset_chunks: |
|
|
for i in range(n_ranks_per_chunk): |
|
|
rank_to_jsonl_iterator_params.append( |
|
|
JSONLState( |
|
|
file_path=chunk_path, |
|
|
position=0, |
|
|
block_size=n_ranks_per_chunk, |
|
|
offset=i, |
|
|
current_iter=0, |
|
|
) |
|
|
) |
|
|
|
|
|
return rank_to_jsonl_iterator_params[rank] |
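

# Illustrative, self-contained sketch: 4 ranks sharing 2 chunk files, so each
# file gets 2 readers with interleaved offsets. File names are throwaway.
def _example_distribute_data_to_rank():
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        for i in range(2):
            Path(tmp, f"demo.chunk.{i:02d}.jsonl").touch()
        states = [
            distribute_data_to_rank(tmp, rank, 4, TRAIN_DATA_FILE_PATTERN)
            for rank in range(4)
        ]
        assert [s["offset"] for s in states] == [0, 1, 0, 1]
        assert all(s["block_size"] == 2 for s in states)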
|
|
|
|
|
|
|
|
def init_choice_state( |
|
|
root_dir: str, |
|
|
sources: Dict[str, float], |
|
|
seed: int, |
|
|
rank: int, |
|
|
world_size: int, |
|
|
file_pattern: str, |
|
|
): |
|
|
data_path_to_jsonl_state = dict() |
|
|
for dataset_path in sources: |
|
|
logger.info( |
|
|
f"Distributing data to rank {rank} for dataset {dataset_path} from root {root_dir}" |
|
|
) |
|
|
jsonl_state = distribute_data_to_rank( |
|
|
os.path.join(root_dir, dataset_path), rank, world_size, file_pattern |
|
|
) |
|
|
data_path_to_jsonl_state[dataset_path] = jsonl_state |
|
|
|
|
|
multi_rng_state = np.random.default_rng( |
|
|
(seed, rank, world_size) |
|
|
).bit_generator.state |
|
|
|
|
|
multi_choice_state = MultiChoiceState( |
|
|
root_dir=root_dir, |
|
|
sources=sources, |
|
|
source_to_state=data_path_to_jsonl_state, |
|
|
rng_state=multi_rng_state, |
|
|
) |
|
|
return multi_choice_state |
|
|
|
|
|
|
|
|
def init_state( |
|
|
root_dir: str, |
|
|
sources: Dict[str, float], |
|
|
batch_size: int, |
|
|
prefetch_size: int, |
|
|
seq_len: int, |
|
|
n_views: int, |
|
|
seed: int, |
|
|
rank: int, |
|
|
world_size: int, |
|
|
add_bos: bool, |
|
|
add_eos: bool, |
|
|
tokenizer_name: str, |
|
|
tokenizer_path: Optional[str] = None, |
|
|
file_pattern: str = TRAIN_DATA_FILE_PATTERN, |
|
|
image_size: Optional[int] = None, |
|
|
patch_size: Optional[int] = None, |
|
|
max_num_tiles: Optional[int] = None, |
|
|
vision_input_type: Optional[str] = None, |
|
|
): |
|
|
multi_choice_state = init_choice_state( |
|
|
root_dir=root_dir, |
|
|
sources=sources, |
|
|
seed=seed, |
|
|
rank=rank, |
|
|
world_size=world_size, |
|
|
file_pattern=file_pattern, |
|
|
) |
|
|
tokenizer_state = TokenizerState( |
|
|
it_state=multi_choice_state, |
|
|
add_bos=add_bos, |
|
|
add_eos=add_eos, |
|
|
name=tokenizer_name, |
|
|
path=tokenizer_path, |
|
|
) |
|
|
pack_state = PackTokensState( |
|
|
start_token=0, |
|
|
it_state=tokenizer_state, |
|
|
output_seq_len=seq_len, |
|
|
n_views=n_views, |
|
|
seq_len=0, |
|
|
) |
|
|
|
|
|
prefetch_rng_state = np.random.default_rng( |
|
|
(seed + 1, rank, world_size) |
|
|
).bit_generator.state |
|
|
|
|
|
return PrefetchState( |
|
|
it_state=pack_state, |
|
|
seq_idx=0, |
|
|
rng_state=prefetch_rng_state, |
|
|
batch_size=batch_size, |
|
|
prefetch_size=prefetch_size, |
|
|
) |
|
|
|
|
|
|
|
|
def setup_sources(multi_state: MultiChoiceState):
|
|
path_to_iter = dict() |
|
|
for source in multi_state["sources"]: |
|
|
jsonl_state = multi_state["source_to_state"][source] |
|
|
path_to_iter[source] = loop_on_jsonl( |
|
|
jsonl_state["file_path"], |
|
|
jsonl_state["position"], |
|
|
jsonl_state["block_size"], |
|
|
jsonl_state["offset"], |
|
|
jsonl_state["current_iter"], |
|
|
) |
|
|
|
|
|
return path_to_iter |
|
|
|
|
|
|
|
|
@contextlib.contextmanager |
|
|
def build_dataloader( |
|
|
state: PrefetchState, |
|
|
): |
|
|
pack_state = state["it_state"] |
|
|
tokenizer_state = pack_state["it_state"] |
|
|
multi_state = tokenizer_state["it_state"] |
|
|
|
|
|
path_to_iter = setup_sources(multi_state) |
|
|
data_it = choose_source( |
|
|
source_to_iterator=path_to_iter, |
|
|
source_to_state=multi_state["source_to_state"], |
|
|
root_dir=multi_state["root_dir"], |
|
|
sources=multi_state["sources"], |
|
|
rng_state=multi_state["rng_state"], |
|
|
) |
|
|
data_it = tokenize( |
|
|
data_it, |
|
|
tokenizer_state["add_bos"], |
|
|
tokenizer_state["add_eos"], |
|
|
tokenizer_state["name"], |
|
|
tokenizer_state["path"], |
|
|
) |
|
|
|
|
|
data_it = pack_tokens( |
|
|
data_it, |
|
|
pack_state, |
|
|
) |
|
|
|
|
|
data_it = batch_and_shuffle_prefetched_sequences( |
|
|
data_loader=data_it, |
|
|
seq_len=pack_state["output_seq_len"], |
|
|
n_views=pack_state["n_views"], |
|
|
batch_size=state["batch_size"], |
|
|
prefetch_size=state["prefetch_size"], |
|
|
state=state, |
|
|
) |
|
|
    try:
        yield data_it
    finally:
        for it in path_to_iter.values():
            it.close()
        data_it.close()
|
|
|
|
|
|
|
|
def feed_buffer(queue: Queue, stop_event: EventClass, iterator_builder): |
|
|
""" |
|
|
Producer function to fetch data from an iterable dataset and put it into a queue. |
|
|
Incorporates timeout management to avoid hanging on queue.put() when the queue is full. |
|
|
""" |
|
|
with iterator_builder() as iterator: |
|
|
for item in iterator: |
|
|
while not stop_event.is_set(): |
|
|
try: |
|
|
queue.put( |
|
|
item, timeout=0.1 |
|
|
) |
|
|
break |
|
|
except Full: |
|
|
pass |
|
|
if stop_event.is_set(): |
|
|
break |
|
|
|
|
|
|
|
|
def consume_buffer(producer: Process, queue: Queue): |
|
|
""" |
|
|
Consumer function to process items from the queue. |
|
|
Handles cases where the queue might be empty by implementing timeouts on queue.get(). |
|
|
""" |
|
|
while producer.exitcode is None: |
|
|
try: |
|
|
item = queue.get( |
|
|
timeout=0.1 |
|
|
) |
|
|
yield item |
|
|
except Empty: |
|
|
pass |
|
|
|
|
|
    raise RuntimeError(
        "Data loader quit unexpectedly; the real error was raised earlier in the producer process"
    )
|
|
|
|
|
|
|
|
@contextlib.contextmanager |
|
|
def async_iterator(buffer_size: int, iterator_builder): |
|
|
""" |
|
|
Context manager to setup and manage asynchronous iteration with producer-consumer model. |
|
|
""" |
|
|
queue = Queue(maxsize=buffer_size) |
|
|
stop_event = Event() |
|
|
producer = Process(target=feed_buffer, args=(queue, stop_event, iterator_builder)) |
|
|
logger.info("Async dataloader started") |
|
|
producer.start() |
|
|
|
|
|
consumer = consume_buffer(producer, queue) |
|
|
try: |
|
|
yield consumer |
|
|
finally: |
|
|
stop_event.set() |
|
|
consumer.close() |
|
|
producer.join(timeout=0.2) |
|
|
if producer.exitcode is None: |
|
|
logger.info(f"Killing async data process {producer.pid} ...") |
|
|
producer.kill() |
|
|
else: |
|
|
logger.info( |
|
|
f"Async data process {producer.pid} exited with code {producer.exitcode}" |
|
|
) |
|
|
logger.info("Async dataloader cleaned up") |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class DataArgs: |
|
|
root_dir: Optional[str] = None |
|
|
sources: Dict[str, float] = field(default_factory=dict) |
|
|
batch_size: int = 2 |
|
|
seq_len: int = 2048 |
|
|
n_views: int = 2 |
|
|
seed: int = 42 |
|
|
add_bos: bool = True |
|
|
add_eos: bool = True |
|
|
load_async: bool = True |
|
|
prefetch_size: int = 64 |
|
|
image_size: Optional[int] = None |
|
|
patch_size: Optional[int] = None |
|
|
max_num_tiles: Optional[int] = None |
|
|
vision_input_type: Optional[str] = None |
|
|
tokenizer: TokenizerArgs = field(default_factory=TokenizerArgs) |
|
|
|
|
|
|
|
|
def init_dataloader_state_from_args( |
|
|
args: DataArgs, |
|
|
rank: int, |
|
|
world_size: int, |
|
|
): |
|
|
return init_state( |
|
|
root_dir=args.root_dir, |
|
|
sources=args.sources, |
|
|
seq_len=args.seq_len, |
|
|
batch_size=args.batch_size, |
|
|
prefetch_size=args.prefetch_size, |
|
|
n_views=args.n_views, |
|
|
seed=args.seed, |
|
|
rank=rank, |
|
|
world_size=world_size, |
|
|
tokenizer_name=args.tokenizer.name, |
|
|
tokenizer_path=args.tokenizer.path, |
|
|
add_bos=args.add_bos, |
|
|
add_eos=args.add_eos, |
|
|
image_size=args.image_size, |
|
|
patch_size=args.patch_size, |
|
|
max_num_tiles=args.max_num_tiles, |
|
|
vision_input_type=args.vision_input_type, |
|
|
) |
|
|
|
|
|
|
|
|
def build_dataloader_from_args( |
|
|
args: DataArgs, |
|
|
state: Optional[PrefetchState] = None, |
|
|
): |
|
|
    assert state is not None, "Dataloader state must be provided (see init_dataloader_state_from_args)"
    data_builder = partial(build_dataloader, state)
|
|
if args.load_async: |
|
|
return async_iterator(args.prefetch_size, data_builder) |
|
|
else: |
|
|
return data_builder() |
|
|
|