from functools import partial
from typing import Any, Dict, Optional

import torch
from megatron.core import mpu
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.utils import StragglerDetector
from megatron.training import get_args, get_timers
from megatron.training.training import cyclic_iter

from swift.llm import DataLoaderDispatcher

stimer = StragglerDetector()
|
|
def get_swift_datasets_provider(train_dataset, val_dataset):
    """Wrap prebuilt train/val datasets in the dataset-provider interface expected by Megatron."""

    def swift_datasets_provider(train_val_test_num_samples):
        # The datasets are already built, so the requested sample counts are ignored
        # and no test dataset is provided.
        return train_dataset, val_dataset, None

    return swift_datasets_provider
|
|
class MegatronDataLoaderDispatcher(DataLoaderDispatcher):
    """DataLoaderDispatcher bound to Megatron's data-parallel process group."""

    @property
    def group(self):
        return mpu.get_data_parallel_group()
|
|
def build_streaming_dataloader(args, dataset, collate_fn):
    """Build an endless, data-parallel-dispatched dataloader for a streaming dataset."""
    base_dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=args.num_workers,
        pin_memory=True,
        collate_fn=collate_fn,
        batch_size=args.micro_batch_size,
        prefetch_factor=args.dataloader_prefetch_factor,
        persistent_workers=args.dataloader_persistent_workers,
    )
    return iter(cyclic_iter(MegatronDataLoaderDispatcher(base_dataloader)))
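
# Typical use (illustrative; the variable names below are not part of this module):
#   train_iterator = build_streaming_dataloader(args, train_dataset, data_collator)
#   batch = next(train_iterator)  # cycles over the dataset indefinitely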
|
|
def get_batch_on_this_tp_rank(data_iterator):
    """Fetch a micro-batch on TP rank 0 and broadcast it to the other tensor-parallel ranks.

    The (dynamic) sequence length is broadcast first; a value of -1 signals that the data
    iterator is exhausted, in which case every rank returns an empty dict. Depending on the
    pipeline stage, only the tensors that the stage actually needs are broadcast.
    """
    args = get_args()

    def _broadcast(item):
        if item is not None:
            torch.distributed.broadcast(
                item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())

    if mpu.get_tensor_model_parallel_rank() == 0:
        try:
            data = next(data_iterator)
        except StopIteration:
            seq_length = -1
        else:
            tokens = data['input_ids']
            seq_length = tokens.shape[1]
            batch = {
                'tokens': tokens.cuda(non_blocking=True),
                'labels': data['labels'].cuda(non_blocking=True),
                'attention_mask':
                None if 'attention_mask' not in data else data['attention_mask'].cuda(non_blocking=True),
                'position_ids': data['position_ids'].cuda(non_blocking=True)
            }
        seq_length = torch.tensor(seq_length).cuda(non_blocking=True)
        _broadcast(seq_length)
        if seq_length.item() == -1:
            return {}

        if args.pipeline_model_parallel_size == 1:
            _broadcast(batch['tokens'])
            _broadcast(batch['labels'])
            _broadcast(batch['attention_mask'])
            _broadcast(batch['position_ids'])

        elif mpu.is_pipeline_first_stage():
            _broadcast(batch['tokens'])
            _broadcast(batch['attention_mask'])
            _broadcast(batch['position_ids'])

        elif mpu.is_pipeline_last_stage():
            _broadcast(batch['labels'])
            _broadcast(batch['attention_mask'])
            _broadcast(batch['position_ids'])

    else:
        # Non-source TP ranks first receive the sequence length, then allocate empty
        # buffers of the right shape to receive the broadcast tensors.
        seq_length = torch.empty((), dtype=torch.int64, device=torch.cuda.current_device())
        _broadcast(seq_length)
        if seq_length.item() == -1:
            return {}

        micro_batch_size = 1  # each micro-batch arrives as a single packed row
        tokens = torch.empty((micro_batch_size, seq_length), dtype=torch.int64, device=torch.cuda.current_device())
        labels = torch.empty((micro_batch_size, seq_length), dtype=torch.int64, device=torch.cuda.current_device())
        if args.create_attention_mask_in_dataloader:
            attention_mask = torch.empty((micro_batch_size, 1, seq_length, seq_length),
                                         dtype=torch.bool,
                                         device=torch.cuda.current_device())
        else:
            attention_mask = None
        position_ids = torch.empty((micro_batch_size, seq_length),
                                   dtype=torch.int64,
                                   device=torch.cuda.current_device())

        if args.pipeline_model_parallel_size == 1:
            _broadcast(tokens)
            _broadcast(labels)
            _broadcast(attention_mask)
            _broadcast(position_ids)

        elif mpu.is_pipeline_first_stage():
            labels = None

            _broadcast(tokens)
            _broadcast(attention_mask)
            _broadcast(position_ids)

        elif mpu.is_pipeline_last_stage():
            tokens = None

            _broadcast(labels)
            _broadcast(attention_mask)
            _broadcast(position_ids)

        batch = {'tokens': tokens, 'labels': labels, 'attention_mask': attention_mask, 'position_ids': position_ids}

    return batch
|
|
def get_packed_seq_params(position_ids: torch.Tensor) -> Optional[PackedSeqParams]:
    """Build PackedSeqParams in 'thd' format from packed position_ids.

    Sequence boundaries are the indices where position_ids resets to 0, with the total
    length appended as the final boundary.
    """
    position_ids_f = position_ids.flatten()
    indices_q = torch.arange(position_ids_f.shape[0], device=position_ids_f.device, dtype=torch.int32)

    cu_seqlens = torch.cat([
        indices_q[position_ids_f == 0],
        torch.tensor(position_ids_f.shape, device=position_ids_f.device, dtype=torch.int32),
    ])

    max_length = position_ids_f.max() + 1
    return PackedSeqParams(
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_kv=cu_seqlens,
        max_seqlen_q=max_length,
        max_seqlen_kv=max_length,
        qkv_format='thd')
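
# Illustrative example of the mapping above (the values are made up):
#   position_ids = [[0, 1, 2, 0, 1, 0, 1, 2, 3]]
#   -> cu_seqlens = [0, 3, 5, 9], max_seqlen = 4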
|
|
def _split_tokens(tokens, cu_seqlens):
    assert tokens.shape[0] == 1, f'tokens.shape: {tokens.shape}'
    new_tokens = []
    cp_size = mpu.get_context_parallel_world_size()
    cp_rank = mpu.get_context_parallel_rank()
    for i in range(cu_seqlens.shape[0] - 1):
        val = tokens[:, cu_seqlens[i]:cu_seqlens[i + 1]]
        val = val.view(
            tokens.shape[0],
            2 * cp_size,
            val.shape[1] // (2 * cp_size),
        )
        index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device='cpu',
                             pin_memory=True).cuda(non_blocking=True)
        val = val.index_select(1, index)
        new_tokens.append(val.view(tokens.shape[0], -1))
    return torch.cat(new_tokens, dim=1)
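
# Sketch of what _split_tokens does (numbers are illustrative): each packed sequence
# is viewed as 2 * cp_size equal chunks, and CP rank i keeps chunks i and
# 2 * cp_size - 1 - i. With cp_size = 2 and a sequence of length 8, rank 0 keeps
# chunks [0, 3] and rank 1 keeps chunks [1, 2], which balances attention cost across
# the context-parallel ranks.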
|
|
def get_batch_on_this_cp_rank(batch: Dict[str, Any]):
    """Slice batch inputs along the sequence dimension into multiple chunks,
    which are parallelized across GPUs in a context parallel group.
    """
    cp_size = mpu.get_context_parallel_world_size()
    if cp_size > 1:
        packed_seq_params = batch['packed_seq_params']
        for key, val in batch.items():
            if key == 'packed_seq_params':
                continue
            if val is not None:
                batch[key] = _split_tokens(val, packed_seq_params.cu_seqlens_q)

    return batch
|
|
def get_batch(data_iterator):
    """Generate a batch."""
    # Intermediate pipeline stages do not consume any batch inputs.
    if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
        return None, None, None, None, None

    batch = get_batch_on_this_tp_rank(data_iterator)
    if not batch:
        # The data iterator is exhausted; propagate the empty batch.
        return batch
    batch['packed_seq_params'] = get_packed_seq_params(batch['position_ids'])

    # Slice along the sequence dimension for context parallelism.
    batch = get_batch_on_this_cp_rank(batch)
    return batch.values()
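
# Note: get_batch returns batch.values() in insertion order, i.e.
# (tokens, labels, attention_mask, position_ids, packed_seq_params),
# which is the order in which forward_step below unpacks it.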
|
|
def forward_step(data_iterator, model):
    """Run one forward pass and return the output tensor together with the loss function."""
    from pretrain_gpt import loss_func

    timers = get_timers()

    # Fetch the batch (timed both by Megatron's timers and the straggler detector).
    timers('batch-generator', log_level=2).start()
    global stimer
    with stimer(bdata=True):
        data = get_batch(data_iterator)
        if not data:
            raise StopIteration
        tokens, labels, attention_mask, position_ids, packed_seq_params = data
    timers('batch-generator').stop()

    with stimer:
        output_tensor = model(tokens, position_ids, attention_mask, labels=labels, packed_seq_params=packed_seq_params)
    # Tokens whose label is -100 are excluded from the loss.
    loss_mask = None if labels is None else (labels != -100).float()
    return output_tensor, partial(loss_func, loss_mask)
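
# Rough sketch of how these pieces are typically wired together (illustrative only;
# `model_provider`, `train_dataset` and `val_dataset` are not defined in this module):
#   from megatron.training import pretrain
#   datasets_provider = get_swift_datasets_provider(train_dataset, val_dataset)
#   pretrain(datasets_provider, model_provider, ModelType.encoder_or_decoder, forward_step)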
|
|
|