|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import collections |
|
|
from typing import Dict, List, Tuple |
|
|
|
|
|
import numpy as np |
|
|
from tqdm import tqdm |
|
|
|
|
|
from nemo.utils import logging |
|
|
|
|
|
# Packing algorithm names advertised as choices for `create_packing_strategy`. Each entry is the name of a
# function defined in this module, which `create_packing_strategy` resolves by name. Note: `first_fit`
# (the default in `create_packing_strategy`'s signature) is also a function here but is not listed.
PACKING_ALGOS = ["first_fit_decreasing", "first_fit_shuffle"]
|
|
|
|
|
|
|
|
def find_first_bin_that_fits(bins: List[List[int]], s: int, bin_size: int) -> int:
    """
    Locate the earliest bin that still has room for an item of size ``s``.

    Args:
        bins: The bins built so far; each bin is a list of the sequence lengths already placed in it.
        s: Size (length) of the sequence that needs a home.
        bin_size: Capacity of every bin.

    Returns:
        Index of the first bin whose current total plus ``s`` does not exceed ``bin_size``,
        or -1 when no existing bin can take it.
    """
    candidates = (idx for idx, contents in enumerate(bins) if sum(contents) + s <= bin_size)
    return next(candidates, -1)
|
|
|
|
|
|
|
|
def first_fit(seqlens: List[int], pack_size: int) -> List[List[int]]:
    """
    Packs sequences of varying lengths into bins using the First-Fit algorithm.

    Lengths are processed in the given order. Each one goes into the first existing bin whose total
    stays within ``pack_size``; if none fits, a new bin is opened. A length larger than ``pack_size``
    therefore ends up alone in its own (overfull) bin.

    Args:
        seqlens: A list of integers, representing the lengths of the sequences to be packed.
        pack_size: The maximum capacity of each bin.

    Returns:
        A list of lists, where each inner list represents a bin and contains the sequence
        lengths (not indices) assigned to that bin, in placement order.
    """
    res: List[List[int]] = []
    # Running total of each bin in `res`, kept in sync so bins are not re-summed on every placement.
    bin_totals: List[int] = []
    for s in seqlens:
        for i, used in enumerate(bin_totals):
            if used + s <= pack_size:
                res[i].append(s)
                bin_totals[i] += s
                break
        else:
            # No existing bin has room: open a new one.
            res.append([s])
            bin_totals.append(s)
    return res
|
|
|
|
|
|
|
|
def first_fit_decreasing(seqlens: List[int], pack_size: int) -> List[List[int]]:
    """
    Packs sequence lengths into bins with the First-Fit Decreasing heuristic.

    The lengths are ordered from longest to shortest and then handed to `first_fit`.

    Args:
        seqlens: Lengths of the sequences to pack. The input list itself is not modified.
        pack_size: The maximum capacity of each bin.

    Returns:
        Bins in the same format as `first_fit` returns.
    """
    longest_first = list(seqlens)
    longest_first.sort(reverse=True)
    return first_fit(longest_first, pack_size)
|
|
|
|
|
|
|
|
def first_fit_shuffle(seqlens: List[int], pack_size: int) -> List[List[int]]:
    """
    Packs sequence lengths into bins using First-Fit on a randomly shuffled order.

    The shuffle draws from NumPy's global random state, so results depend on its seed.

    Args:
        seqlens: Lengths of the sequences to pack.
        pack_size: The maximum capacity of each bin.

    Returns:
        Bins in the same format as `first_fit` returns.
    """
    order = seqlens[:]  # shallow copy so the in-place shuffle below leaves the caller's list untouched
    np.random.shuffle(order)
    return first_fit(order, pack_size)
|
|
|
|
|
|
|
|
def create_hist(dataset: np.array, truncate_seq_len: int):
    """
    Creates a histogram of sequence lengths from a tokenized dataset.

    Each example's length is taken as ``len(input_ids) - 1``. Examples are grouped by that length so
    that concrete examples can later be drawn for every length the packing strategy schedules.

    Args:
        dataset: An iterable (e.g. a NumPy array) of tokenized examples. Each element is a dictionary
            that contains at minimum the key `input_ids`.
        truncate_seq_len: The maximum sequence length to consider in the histogram.

    Returns:
        sequences: A ``collections.defaultdict(list)`` mapping each length to the examples of that
            length. Every length in ``[0, truncate_seq_len]`` is present as a key (possibly empty).
        histogram: A list of ``truncate_seq_len + 1`` counts; entry ``i`` is the number of examples
            of length ``i``.

    Raises:
        ValueError: if an example's length exceeds ``truncate_seq_len``.
    """
    logging.info("Creating histogram from tokenized dataset...")

    sequences = collections.defaultdict(list)
    counts = [0] * (truncate_seq_len + 1)
    num_empty = 0

    for item_dict in dataset:
        # NOTE(review): the -1 presumably accounts for the one-token shift between model input and
        # label in next-token prediction; confirm against the model's data pipeline.
        seq_len = len(item_dict["input_ids"]) - 1
        if seq_len < 0:
            # An empty example would produce index -1, which silently bumps counts[-1] and is never
            # packed. Skip it explicitly instead.
            num_empty += 1
            continue
        if seq_len > truncate_seq_len:
            raise ValueError(
                f"Example has sequence length {seq_len}, which exceeds truncate_seq_len={truncate_seq_len}. "
                "Truncate the dataset before building the histogram."
            )
        sequences[seq_len].append(item_dict)
        counts[seq_len] += 1

    if num_empty:
        logging.warning(f"Skipped {num_empty} example(s) with empty input_ids; they cannot be packed.")

    logging.debug("Histogram of sequence lengths")
    logging.debug(counts)

    # Indexing the defaultdict here also registers every length in [0, truncate_seq_len] as a key.
    histogram = [len(sequences[seq_len]) for seq_len in range(truncate_seq_len + 1)]

    return sequences, histogram
|
|
|
|
|
|
|
|
def create_packing_strategy(
    histogram: List[int], pack_size: int, packing_algorithm: str = "first_fit"
) -> Tuple[List[List[int]], dict]:
    """
    Packs sequences into bins using the specified packing algorithm.

    This function takes the histogram of sequence lengths, desired pack size, and a string representing the
    packing algorithm to use. It then calls the corresponding function (e.g., 'first_fit_decreasing') and
    performs the packing process using only sequence lengths as input (without the actual sequences).

    Args:
        histogram: A list representing the histogram data (number of sequences for each length).
        pack_size: The maximum capacity of each bin.
        packing_algorithm: 'first_fit' (the default) or one of the algorithms in PACKING_ALGOS
            (['first_fit_decreasing', 'first_fit_shuffle']).

    Returns:
        assignments: A list of lists, where each inner list represents a bin and contains the sequence
            lengths assigned to that bin.
        pack_metadata: A dict that records packing metadata, for instance the max number of samples per bin.

    Raises:
        ValueError: if ``packing_algorithm`` is not supported, or if the histogram contains no sequences.
    """
    logging.info(f"Packing sequences to length {pack_size}...")

    # Restrict name-based dispatch to the packing functions. A bare globals() lookup would "call"
    # any module-level name passed in.
    supported_algos = ["first_fit", *PACKING_ALGOS]
    if packing_algorithm not in supported_algos:
        raise ValueError(
            f"Unsupported packing_algorithm '{packing_algorithm}'. Supported algorithms: {supported_algos}"
        )

    # Expand the histogram back into one entry per sequence, e.g. [0, 2, 1] -> [1, 1, 2].
    all_seq_lens = [seq_len for seq_len, count in enumerate(histogram) for _ in range(count)]
    if not all_seq_lens:
        # Without this check the statistics below fail with ZeroDivisionError / empty max().
        raise ValueError("Cannot create a packing strategy from an empty histogram (no sequences to pack).")

    packing_fn = globals()[packing_algorithm]
    assignments: List[List[int]] = packing_fn(all_seq_lens, pack_size)
    packed_seq_lens = [sum(x) for x in assignments]

    packing_factor = len(all_seq_lens) / len(packed_seq_lens)
    packing_efficiency = sum(packed_seq_lens) / len(packed_seq_lens) / pack_size * 100

    packing_metadata = {
        "dataset_max_seqlen": max(all_seq_lens),
        "max_samples_per_bin": max(len(b) for b in assignments),
        "packing_factor": round(packing_factor, 2),
        "packing_efficiency": round(packing_efficiency, 2),
        "pack_size": pack_size,
        "min_packed_seqlen": min(packed_seq_lens),
    }

    logging.debug("Packed sequence lengths:")
    logging.debug(packed_seq_lens)
    logging.info(f"Packing is {packing_efficiency:.2f}% efficient")
    logging.info(
        f">>>>> For pack size {pack_size}, average number of sequences per pack is n = {packing_factor:.3f} <<<<<"
    )
    return assignments, packing_metadata
|
|
|
|
|
|
|
|
def _build_loss_masks(per_seq_data: List[Dict], pad_id: int) -> np.ndarray:
    """Return one loss-mask row per example for a group of examples that all share one length.

    Uses each example's own ``loss_mask`` when every example has one. Otherwise the mask is derived from
    ``answer_start_idx``: a position is kept when it is at or after ``answer_start_idx - 1`` and its
    token is not ``pad_id``.

    Raises:
        ValueError: if the examples carry neither ``loss_mask`` nor ``answer_start_idx``.
    """
    try:
        return np.array([x["loss_mask"] for x in per_seq_data])
    except KeyError:
        pass  # fall back to deriving the mask from answer_start_idx

    try:
        return np.array(
            [
                [
                    # NOTE(review): the -1 presumably mirrors the one-token input/label shift applied to
                    # sequence lengths; confirm against the model's data pipeline.
                    idx >= (x["answer_start_idx"] - 1) and x["input_ids"][idx] != pad_id
                    for idx in range(len(x["input_ids"]))
                ]
                for x in per_seq_data
            ]
        )
    except KeyError as err:
        err_msg = "Key errors loss_mask and answer_start_idx missing in example - "
        err_msg += f"{err} {per_seq_data[0]}"
        logging.error(err_msg)
        raise ValueError(err_msg) from err


def _shuffle_examples_by_length(
    sequences: Dict[int, List[Dict]], pack_size: int, pad_id: int
) -> Dict[int, Tuple[List, List]]:
    """Shuffle the examples of each length in ``[0, pack_size]`` into per-length pools.

    Returns a dict mapping every length that has at least one example to ``(input_ids, loss_mask)``.
    These are two parallel lists shuffled with the same permutation, so entry ``i`` of each belongs to
    the same example.
    """
    pools = {}
    for seq_len in tqdm(range(pack_size + 1)):
        # .get keeps plain dicts (as annotated) working, not only defaultdicts.
        per_seq_data = sequences.get(seq_len, [])
        if len(per_seq_data) == 0:
            continue
        perm = np.random.permutation(len(per_seq_data))
        input_ids = np.array([x["input_ids"] for x in per_seq_data])[perm].tolist()
        loss_mask = _build_loss_masks(per_seq_data, pad_id)[perm].tolist()
        pools[seq_len] = (input_ids, loss_mask)
    return pools


def _assemble_packs(assignments: List[List[int]], pools: Dict[int, Tuple[List, List]]) -> List[Dict]:
    """Build one output record per bin by popping examples of the requested lengths out of ``pools``."""
    output_data = []
    for assignment in tqdm(assignments, total=len(assignments)):
        input_ids, loss_mask, seq_start_id = [], [], []
        for seq_length in assignment:
            # Each example starts where the concatenation currently ends.
            seq_start_id.append(len(input_ids))
            pool_input_ids, pool_loss_mask = pools[seq_length]
            input_ids.extend(pool_input_ids.pop())
            loss_mask.extend(pool_loss_mask.pop())
        output_data.append(
            {
                "input_ids": input_ids,
                "loss_mask": loss_mask,
                "seq_start_id": seq_start_id,
            }
        )
    return output_data


def fill_packing_strategy(
    assignments: List[List[int]],
    sequences: Dict[int, List[Dict]],
    pack_size: int,
    pad_id: int,
) -> List[Dict]:
    """
    Fills the packing strategy with actual sequence data based on assignments and sequence information.

    The examples of each length are shuffled (using NumPy's global random state). Each bin in
    ``assignments`` is then filled by taking one example per requested length and concatenating them.

    Args:
        assignments: A list of lists, where each inner list represents a bin and contains the sequence
            lengths assigned to that bin (output of 'create_packing_strategy').
        sequences: A dictionary where keys are sequence lengths and values are lists of corresponding
            sequences from the dataset (output of 'create_hist').
        pack_size: The maximum capacity of each bin; only lengths in ``[0, pack_size]`` are drawn from.
        pad_id: The tokenizer's padding token, used when deriving loss masks from ``answer_start_idx``.

    Returns:
        output_data: A list of dictionaries, one per bin, with keys ``input_ids`` (concatenated tokens),
            ``loss_mask`` (concatenated masks) and ``seq_start_id`` (start offset of each packed example).

    Raises:
        ValueError: if examples have neither ``loss_mask`` nor ``answer_start_idx``.
        KeyError / IndexError: if ``assignments`` requests a length with no examples, or more examples
            of a length than exist.
        AssertionError: if some examples are left unassigned after filling every bin.
    """
    pools = _shuffle_examples_by_length(sequences, pack_size, pad_id)
    output_data = _assemble_packs(assignments, pools)

    # Every shuffled example must be consumed by exactly one bin. Raise explicitly instead of using
    # `assert` so the check is not stripped under `python -O`.
    if any(pool_ids or pool_masks for pool_ids, pool_masks in pools.values()):
        raise AssertionError("Error: There are items left over from the assignment")

    return output_data
|
|
|