|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import copy
|
|
|
import heapq
|
|
|
from typing import List, Tuple
|
|
|
|
|
|
import torch
|
|
|
from torch import distributed as dist
|
|
|
|
|
|
|
|
|
def karmarkar_karp(seqlen_list: List[int], k_partitions: int, equal_size: bool):
    """Split item indices into ``k_partitions`` groups with balanced length sums.

    Uses the largest differencing (Karmarkar-Karp) heuristic. Each item (or,
    with ``equal_size``, each run of ``k_partitions`` consecutive items in
    ascending length order) starts as its own partial solution. The two partial
    solutions with the largest spread are merged repeatedly, pairing the
    heaviest bucket of one with the lightest bucket of the other, until a
    single solution is left.

    Args:
        seqlen_list: length of every item; the list position is the item index.
        k_partitions: number of groups to build.
        equal_size: when True every group gets the same number of items, and
            ``len(seqlen_list)`` has to be a multiple of ``k_partitions``.

    Returns:
        ``k_partitions`` lists of item indices, ordered from the group with the
        largest sum to the group with the smallest. Without ``equal_size`` a
        group may be empty when there are fewer items than groups.
    """

    class Set:
        """One bucket: its ``(index, value)`` pairs and their running sum."""

        def __init__(self) -> None:
            self.sum = 0
            self.items = []

        def add(self, idx: int, val: int):
            self.items.append((idx, val))
            self.sum += val

        def merge(self, other):
            # Absorb the entries of ``other`` one at a time, in their order.
            for idx, val in other.items:
                self.add(idx=idx, val=val)

        def __lt__(self, other):
            # Order by sum, then by item count, then by the item lists.
            if self.sum == other.sum:
                own_count, other_count = len(self.items), len(other.items)
                if own_count == other_count:
                    return self.items < other.items
                return own_count < other_count
            return self.sum < other.sum

    class State:
        """A partial solution: ``k`` buckets kept in descending order."""

        def __init__(self, items: List[Tuple[int, int]], k: int) -> None:
            self.k = k
            buckets = [Set() for _ in range(k)]
            assert len(items) in [1, k], f"{len(items)} not in [1, {k}]"
            for slot, (idx, seqlen) in enumerate(items):
                buckets[slot].add(idx=idx, val=seqlen)
            # Invariant: buckets are sorted from largest to smallest.
            buckets.sort(reverse=True)
            self.sets = buckets

        def get_partitions(self):
            # Drop the values and keep only the item indices of every bucket.
            return [[idx for idx, _ in bucket.items] for bucket in self.sets]

        def merge(self, other):
            # Pair our largest bucket with their smallest, and so on, so the
            # differences between buckets cancel out as much as possible.
            for own, theirs in zip(self.sets, reversed(other.sets)):
                own.merge(theirs)
            self.sets.sort(reverse=True)

        @property
        def spread(self) -> int:
            # Gap between the heaviest and the lightest bucket.
            return self.sets[0].sum - self.sets[-1].sum

        def __lt__(self, other):
            # heapq is a min-heap, so "smaller" means "popped first": the state
            # with the larger spread wins, ties go to the larger top bucket.
            if self.spread == other.spread:
                return self.sets[0] > other.sets[0]
            return self.spread > other.spread

        def __repr__(self) -> str:
            rendered = ("{" + ",".join(str(seqlen) for _, seqlen in bucket.items) + "}" for bucket in self.sets)
            return "[" + ",".join(rendered) + "]"

    ascending = sorted((seqlen, pos) for pos, seqlen in enumerate(seqlen_list))
    states_pq = []
    if equal_size:
        assert len(seqlen_list) % k_partitions == 0, f"{len(seqlen_list)} % {k_partitions} != 0"
        # Seed each state with k neighbouring items so all buckets grow in lockstep.
        for offset in range(0, len(ascending), k_partitions):
            group = [ascending[offset + slot] for slot in range(k_partitions)]
            seed = [(idx, seqlen) for seqlen, idx in group]
            heapq.heappush(states_pq, State(items=seed, k=k_partitions))
    else:
        for seqlen, idx in ascending:
            heapq.heappush(states_pq, State(items=[(idx, seqlen)], k=k_partitions))

    # Keep merging the two states with the largest spread until one remains.
    while len(states_pq) > 1:
        first = heapq.heappop(states_pq)
        second = heapq.heappop(states_pq)
        first.merge(second)
        heapq.heappush(states_pq, first)

    partitions = states_pq[0].get_partitions()
    if equal_size:
        for partition in partitions:
            assert len(partition) * k_partitions == len(seqlen_list), f"{len(partition)} * {k_partitions} != {len(seqlen_list)}"
    return partitions
|
|
|
|
|
|
|
|
|
def greedy_partition(seqlen_list: List[int], k_partitions: int, equal_size: bool):
    """Greedily assign each item to the currently lightest of ``k_partitions`` groups.

    Items are visited from longest to shortest (longest-processing-time-first
    order); each one goes to the group whose running sum is smallest at that
    moment, with ties going to the lowest group number.

    Args:
        seqlen_list: length of every item; the list position is the item index.
        k_partitions: number of groups to build.
        equal_size: when True every group must end up with the same number of
            items, so ``len(seqlen_list)`` has to be a multiple of
            ``k_partitions``.

    Returns:
        ``k_partitions`` lists of item indices, in the order the items were
        assigned.

    Raises:
        AssertionError: if ``equal_size`` is set and the groups do not all hold
            ``len(seqlen_list) / k_partitions`` items.
    """
    # With equal_size every length is inflated by more than the total, so a
    # group holding fewer items always has the smaller sum and is filled first.
    bias = sum(seqlen_list) + 1 if equal_size else 0
    biased = [(seqlen + bias, i) for i, seqlen in enumerate(seqlen_list)]
    # Longest first; the sort is stable, so equal lengths keep their input order.
    biased.sort(key=lambda pair: pair[0], reverse=True)

    partitions = [[] for _ in range(k_partitions)]
    partition_sums = [0] * k_partitions
    for seqlen, i in biased:
        # min() returns the first of several equally light groups.
        lightest = min(range(k_partitions), key=partition_sums.__getitem__)
        partitions[lightest].append(i)
        partition_sums[lightest] += seqlen

    if equal_size:
        for partition in partitions:
            assert len(partition) * k_partitions == len(seqlen_list), f"{len(partition)} * {k_partitions} != {len(seqlen_list)}"
    return partitions
|
|
|
|
|
|
|
|
|
def get_seqlen_balanced_partitions(seqlen_list: List[int], k_partitions: int, equal_size: bool):
    """Group item indices so that the summed sequence length per group is balanced.

    Used to even out the total sequence length across dp ranks and micro-batches.

    Parameters:
        seqlen_list (List[int]):
            sequence length of each item
        k_partitions (int):
            number of groups to produce
        equal_size (bool):
            if True, every group must contain the same number of items.
            if False, only the sums are balanced and groups may differ in size

    Returns:
        partitions (List[List[int]]):
            ``k_partitions`` lists of item indices, each sorted ascending.
    """
    assert len(seqlen_list) >= k_partitions, f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]"

    def _check_and_sort_partitions(groups):
        # Every group must be non-empty and together they must cover every index.
        assert len(groups) == k_partitions, f"{len(groups)} != {k_partitions}"
        covered = set()
        ordered = []
        for pos, group in enumerate(groups):
            assert len(group) > 0, f"the {pos}-th partition is empty"
            covered.update(group)
            ordered.append(sorted(group))
        assert covered == set(range(len(seqlen_list)))
        return ordered

    raw_partitions = karmarkar_karp(seqlen_list=seqlen_list, k_partitions=k_partitions, equal_size=equal_size)
    return _check_and_sort_partitions(raw_partitions)
|
|
|
|
|
|
|
|
|
def log_seqlen_unbalance(seqlen_list: List[int], partitions: List[List[int]], prefix):
    """Build metrics comparing sequence-length sums before and after balancing.

    The "before" figures come from cutting ``seqlen_list`` into consecutive
    chunks of ``len(seqlen_list) // len(partitions)`` items; the "balanced"
    figures come from summing the lengths selected by each entry of
    ``partitions``.

    Args:
        seqlen_list: length of every item.
        partitions: one list of item indices per group.
        prefix: string prepended to every metric name.

    Returns:
        A dict with the keys ``{prefix}/min``, ``{prefix}/max``,
        ``{prefix}/minmax_diff``, ``{prefix}/balanced_min``,
        ``{prefix}/balanced_max`` and ``{prefix}/mean``.
    """
    num_groups = len(partitions)
    chunk_len = len(seqlen_list) // num_groups

    # Sums of the consecutive, unbalanced chunks. When the length is not a
    # multiple of the group count, the leftover items form one extra chunk.
    chunk_sums = [sum(seqlen_list[start : start + chunk_len]) for start in range(0, len(seqlen_list), chunk_len)]
    lowest = min(chunk_sums)
    highest = max(chunk_sums)
    grand_total = sum(chunk_sums)

    # Sums of the balanced groups.
    group_sums = [sum(seqlen_list[i] for i in group) for group in partitions]
    balanced_lowest = min(group_sums)
    balanced_highest = max(group_sums)

    return {
        f"{prefix}/min": lowest,
        f"{prefix}/max": highest,
        f"{prefix}/minmax_diff": highest - lowest,
        f"{prefix}/balanced_min": balanced_lowest,
        f"{prefix}/balanced_max": balanced_highest,
        f"{prefix}/mean": grand_total / len(partitions),
    }
|
|
|
|
|
|
|
|
|
def ceildiv(a, b):
    """Return ``a / b`` rounded up, using floor division only."""
    # Floor-dividing by the negated divisor rounds toward +inf once negated back.
    negated_quotient = a // -b
    return -negated_quotient
|
|
|
|
|
|
|
|
|
def rearrange_micro_batches(batch, max_token_len, dp_group=None, same_micro_num_in_dp=True, min_num_micro_batch=None):
    """
    Split a batch into micro-batches with balanced token counts.

    The number of micro-batches is ``ceil(total_tokens / max_token_len)``, capped
    at the number of sequences and floored at one. Sequences are then spread
    over the micro-batches so that the token sums are balanced; the balancing
    itself does not enforce ``max_token_len`` per micro-batch.

    Args:
        batch (TensorDict): must include "attention_mask" (B*S); the batch is sliced row by row.
        max_token_len (int): token budget used to choose the number of micro-batches;
            must be at least the sequence dimension of "attention_mask".
        dp_group (optional): torch.distributed group used for the all-reduce
            (``None`` is passed through to ``dist.all_reduce``).
        same_micro_num_in_dp (bool): if True and torch.distributed is initialized,
            every rank uses the maximum micro-batch count found in the group.
        min_num_micro_batch (int, optional): lower bound on the number of micro-batches.

    Returns:
        List[TensorDict]: the micro-batches.
        List[List[int]]: index lists mapping each micro-batch back to original positions.

    Raises:
        AssertionError: if ``max_token_len`` is smaller than the sequence
            dimension, or if more micro-batches are requested than there are
            sequences (no micro-batch may be empty).
    """
    max_seq_len = batch["attention_mask"].shape[-1]
    assert max_token_len >= max_seq_len, f"max_token_len must be greater than the sequence length. Got {max_token_len=} and {max_seq_len=}"
    seq_len_effective: torch.Tensor = batch["attention_mask"].sum(dim=1)
    total_seqlen = seq_len_effective.sum().item()

    # Never ask for zero micro-batches: an all-zero attention mask makes the
    # ceil-division yield 0, which the partitioner cannot handle.
    num_micro_batches = max(1, min(len(seq_len_effective), ceildiv(total_seqlen, max_token_len)))
    if min_num_micro_batch is not None:
        num_micro_batches = max(min_num_micro_batch, num_micro_batches)
    if dist.is_initialized() and same_micro_num_in_dp:
        # Agree on the largest count so every rank runs the same number of steps.
        # NOTE(review): the device is hard-coded to "cuda" — confirm this matches the backend in use.
        num_micro_batches = torch.tensor([num_micro_batches], device="cuda")
        dist.all_reduce(num_micro_batches, op=dist.ReduceOp.MAX, group=dp_group)
        num_micro_batches = num_micro_batches.cpu().item()

    seq_len_effective = seq_len_effective.tolist()
    assert num_micro_batches <= len(seq_len_effective)

    micro_bsz_idx = get_seqlen_balanced_partitions(seq_len_effective, num_micro_batches, equal_size=False)

    # Gather the rows of each partition (as length-1 slices) into one micro-batch.
    micro_batches = [torch.cat([batch[idx : idx + 1] for idx in partition]) for partition in micro_bsz_idx]

    return micro_batches, micro_bsz_idx
|
|
|
|
|
|
|
|
|
def get_reverse_idx(idx_map):
    """Return the inverse of an index mapping.

    If ``idx_map[i] == j`` then the result holds ``i`` at position ``j``. The
    input is left untouched; the result starts as a deep copy of it, so it has
    the same container type.
    """
    inverse = copy.deepcopy(idx_map)

    for position, target in enumerate(idx_map):
        inverse[target] = position

    return inverse
|
|
|
|