# Copyright (c) Alibaba, Inc. and its affiliates.
from functools import partial
from typing import Any, Dict

import torch
from megatron.core import mpu
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.utils import StragglerDetector
from megatron.training import get_args, get_timers
from megatron.training.training import cyclic_iter

from swift.llm import DataLoaderDispatcher

stimer = StragglerDetector()


def get_swift_datasets_provider(train_dataset, val_dataset):
def swift_datasets_provider(train_val_test_num_samples):
return train_dataset, val_dataset, None
return swift_datasets_provider
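

# The returned callable matches the ``train_valid_test_datasets_provider``
# interface that Megatron's training entry point expects. A minimal sketch of
# how it could be wired up (assuming the standard ``megatron.training.pretrain``
# signature; ``model_provider`` is a hypothetical user-supplied function):
#
#   from megatron.core.enums import ModelType
#   from megatron.training import pretrain
#
#   pretrain(get_swift_datasets_provider(train_dataset, val_dataset),
#            model_provider, ModelType.encoder_or_decoder, forward_step)

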
class MegatronDataLoaderDispatcher(DataLoaderDispatcher):
@property
def group(self):
return mpu.get_data_parallel_group()


def build_streaming_dataloader(args, dataset, collate_fn):
base_dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=args.num_workers,
pin_memory=True,
collate_fn=collate_fn,
batch_size=args.micro_batch_size,
prefetch_factor=args.dataloader_prefetch_factor,
persistent_workers=args.dataloader_persistent_workers,
)
return iter(cyclic_iter(MegatronDataLoaderDispatcher(base_dataloader)))
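

# A minimal usage sketch (hypothetical names): given a streaming dataset and a
# collate_fn from swift, this yields micro batches indefinitely, with samples
# dispatched across the data-parallel group by the dispatcher above.
#
#   train_iter = build_streaming_dataloader(args, train_dataset, data_collator)
#   batch = next(train_iter)  # dict of tensors for one micro batch

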
def get_batch_on_this_tp_rank(data_iterator):
    # Adapted from Megatron-LM's ``get_batch_on_this_tp_rank``, modified here
    # for variable-length (packed) sequences.
args = get_args()
    def _broadcast(item):
        # Broadcast from the TP source rank to the whole TP group; tensors a
        # stage does not need are passed as None and skipped.
        if item is not None:
            torch.distributed.broadcast(
                item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())
if mpu.get_tensor_model_parallel_rank() == 0:
try:
data = next(data_iterator)
        except StopIteration:
            # -1 sentinel: tells the other TP ranks that the iterator is exhausted.
            seq_length = -1
else:
tokens = data['input_ids']
seq_length = tokens.shape[1]
batch = {
'tokens': tokens.cuda(non_blocking=True),
'labels': data['labels'].cuda(non_blocking=True),
'attention_mask':
None if 'attention_mask' not in data else data['attention_mask'].cuda(non_blocking=True),
'position_ids': data['position_ids'].cuda(non_blocking=True)
}
seq_length = torch.tensor(seq_length).cuda(non_blocking=True)
_broadcast(seq_length)
if seq_length.item() == -1:
return {}
        # Broadcast only the tensors each pipeline stage actually consumes.
        if args.pipeline_model_parallel_size == 1:
_broadcast(batch['tokens'])
_broadcast(batch['labels'])
_broadcast(batch['attention_mask'])
_broadcast(batch['position_ids'])
elif mpu.is_pipeline_first_stage():
_broadcast(batch['tokens'])
_broadcast(batch['attention_mask'])
_broadcast(batch['position_ids'])
elif mpu.is_pipeline_last_stage():
_broadcast(batch['labels'])
_broadcast(batch['attention_mask'])
_broadcast(batch['position_ids'])
    else:
        # Non-src TP ranks first receive the sequence length so they can
        # allocate receive buffers of the correct shape.
        seq_length = torch.empty((), dtype=torch.int64, device=torch.cuda.current_device())
        _broadcast(seq_length)
        if seq_length.item() == -1:
            return {}
        # With qkv_format 'thd', packed sequences are flattened into one row,
        # so the micro batch dimension is always 1.
        micro_batch_size = 1
tokens = torch.empty((micro_batch_size, seq_length), dtype=torch.int64, device=torch.cuda.current_device())
labels = torch.empty((micro_batch_size, seq_length), dtype=torch.int64, device=torch.cuda.current_device())
if args.create_attention_mask_in_dataloader:
attention_mask = torch.empty((micro_batch_size, 1, seq_length, seq_length),
dtype=torch.bool,
device=torch.cuda.current_device())
else:
attention_mask = None
position_ids = torch.empty((micro_batch_size, seq_length),
dtype=torch.int64,
device=torch.cuda.current_device())
if args.pipeline_model_parallel_size == 1:
_broadcast(tokens)
_broadcast(labels)
_broadcast(attention_mask)
_broadcast(position_ids)
elif mpu.is_pipeline_first_stage():
labels = None
_broadcast(tokens)
_broadcast(attention_mask)
_broadcast(position_ids)
elif mpu.is_pipeline_last_stage():
tokens = None
_broadcast(labels)
_broadcast(attention_mask)
            _broadcast(position_ids)  # still needed here for packing & context parallelism
batch = {'tokens': tokens, 'labels': labels, 'attention_mask': attention_mask, 'position_ids': position_ids}
return batch
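

# Protocol summary (as implemented above): the TP source rank reads the next
# sample and broadcasts its sequence length first (-1 when the iterator is
# exhausted, in which case every rank returns an empty dict); the other TP
# ranks allocate empty buffers of that length and receive only the tensors
# their pipeline stage consumes.

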
def get_packed_seq_params(position_ids: torch.Tensor) -> PackedSeqParams:
    position_ids_f = position_ids.flatten()
    indices_q = torch.arange(position_ids_f.shape[0], device=position_ids_f.device, dtype=torch.int32)
    # Sequence starts are where position_ids resets to 0; append the total
    # token count to form cumulative sequence boundaries (cu_seqlens).
    cu_seqlens = torch.cat([
        indices_q[position_ids_f == 0],
        torch.tensor(position_ids_f.shape, device=position_ids_f.device, dtype=torch.int32),
    ])
max_length = position_ids_f.max() + 1
return PackedSeqParams(
cu_seqlens_q=cu_seqlens,
cu_seqlens_kv=cu_seqlens,
max_seqlen_q=max_length,
max_seqlen_kv=max_length,
qkv_format='thd')
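

# A worked example (hypothetical values): for two packed sequences of lengths
# 3 and 2, position_ids resets at each sequence start, so
#
#   position_ids = torch.tensor([[0, 1, 2, 0, 1]])
#   -> cu_seqlens = [0, 3, 5]  (starts where position_ids == 0, plus total length)
#   -> max_seqlen = 3          (max position id + 1)

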
def _split_tokens(tokens, cu_seqlens):
    assert tokens.shape[0] == 1, f'tokens.shape: {tokens.shape}'
    new_tokens = []
    cp_size = mpu.get_context_parallel_world_size()
    cp_rank = mpu.get_context_parallel_rank()
    # Each rank keeps a chunk from the front and its mirror from the back
    # (chunks cp_rank and 2 * cp_size - cp_rank - 1) so causal-attention load
    # is balanced; the selection index is constant, so build it once.
    index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1], device='cpu',
                         pin_memory=True).cuda(non_blocking=True)
    for i in range(cu_seqlens.shape[0] - 1):
        val = tokens[:, cu_seqlens[i]:cu_seqlens[i + 1]]
        # Split each packed sequence into 2 * cp_size equal chunks (assumes the
        # sequence length is divisible by 2 * cp_size).
        val = val.view(
            tokens.shape[0],
            2 * cp_size,
            val.shape[1] // (2 * cp_size),
        )
        val = val.index_select(1, index)
        new_tokens.append(val.view(tokens.shape[0], -1))
    return torch.cat(new_tokens, dim=1)
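

# A worked example (hypothetical values): with cp_size=2, a packed sequence of
# 8 tokens is viewed as 4 chunks [c0, c1, c2, c3] of 2 tokens each;
# cp_rank 0 keeps (c0, c3) and cp_rank 1 keeps (c1, c2), so every rank holds
# one cheap head chunk and one expensive tail chunk under causal attention.

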
def get_batch_on_this_cp_rank(batch: Dict[str, Any]):
"""Slice batch input along sequence dimension into multiple chunks,
which are parallelized across GPUs in a context parallel group.
"""
    # With causal masking, each token attends only to the tokens before it, so
    # naively splitting the sequence into CP chunks leads to severe load
    # imbalance: chunks near the end of the sequence carry a much larger
    # workload than earlier ones. To address this, we split the sequence into
    # 2 * CP chunks. Assuming CP=2, we get 4 chunks; chunk_0 and chunk_3 go to
    # GPU0 while chunk_1 and chunk_2 go to GPU1, balancing the workload across
    # the GPUs in a context-parallel group.
cp_size = mpu.get_context_parallel_world_size()
if cp_size > 1:
packed_seq_params = batch['packed_seq_params']
        for key, val in batch.items():
            if key == 'packed_seq_params':
                # Metadata rather than a per-token tensor; leave it unsliced.
                continue
if val is not None:
batch[key] = _split_tokens(val, packed_seq_params.cu_seqlens_q)
return batch


def get_batch(data_iterator):
"""Generate a batch."""
# TODO: this is pretty hacky, find a better way
if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
return None, None, None, None, None
# get batches based on the TP rank you are on
batch = get_batch_on_this_tp_rank(data_iterator)
    if not batch:
        # Empty dict: the src rank's data iterator is exhausted.
        return batch
batch['packed_seq_params'] = get_packed_seq_params(batch['position_ids'])
# slice batch along sequence dimension for context parallelism
batch = get_batch_on_this_cp_rank(batch)
    # Relies on dict insertion order: (tokens, labels, attention_mask,
    # position_ids, packed_seq_params), unpacked in forward_step.
    return batch.values()


def forward_step(data_iterator, model):
from pretrain_gpt import loss_func
timers = get_timers()
# Get the batch.
timers('batch-generator', log_level=2).start()
global stimer
with stimer(bdata=True):
data = get_batch(data_iterator)
if not data:
raise StopIteration
tokens, labels, attention_mask, position_ids, packed_seq_params = data
timers('batch-generator').stop()
with stimer:
output_tensor = model(tokens, position_ids, attention_mask, labels=labels, packed_seq_params=packed_seq_params)
    # Compute the loss only where labels are not -100 (the ignore index).
    loss_mask = None if labels is None else (labels != -100).float()
return output_tensor, partial(loss_func, loss_mask)