File size: 5,671 Bytes
2742ed8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any
import datasets
import torch
import torch.distributed as dist
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers.trainer_utils import seed_worker
from .base import SequenceParallel
class XTuner(SequenceParallel):
    """Sequence-parallel backend that delegates to XTuner's parallel utilities.

    All heavy lifting (padding, splitting, loss reduction, sampling) is
    forwarded to ``xtuner.parallel.sequence``; this class only adapts the
    swift ``SequenceParallel`` interface to those helpers. XTuner imports are
    kept inside the methods so the module can be imported without XTuner
    installed.
    """

    @staticmethod
    def assert_xtuner_runtime_condition():
        """Fail fast when XTuner is missing or distributed training is not up.

        NOTE: ``assert`` is stripped under ``python -O``; kept for consistency
        with the surrounding codebase.
        """
        from swift.utils import is_xtuner_available
        assert is_xtuner_available(), \
            ('Please install XTuner first to pack dataset to `max_length`.'
             '`pip install -U \'xtuner[deepspeed]\'`')
        assert dist.is_initialized(), 'pack_to_max_length is only available with distributed training.'

    def pack_dataset_xtuner(self, dataset: Dataset, args: Any) -> Any:
        """Pack ``dataset`` rows into fixed-length samples of ``args.max_length``.

        Packing runs only on rank 0; the packed dataset object is then
        broadcast to every other rank so all workers see identical data.

        Args:
            dataset: the encoded dataset to pack.
            args: training args; only ``args.max_length`` is read here.

        Returns:
            The packed ``datasets.Dataset`` (same object on every rank).
        """
        self.assert_xtuner_runtime_condition()
        if dist.get_rank() == 0:
            # dataset.data holds per-sample tuples; keep only the first
            # element of each. TODO(review): confirm the structure of
            # `dataset.data` against the caller.
            rows = [item[0] for item in dataset.data]
            train_dataset = Dataset.from_list(rows)
            from xtuner.dataset.huggingface import pack_dataset
            train_dataset = pack_dataset(
                train_dataset,
                max_length=args.max_length,
                use_varlen_attn=False,
                shuffle_before_pack=True,
                map_num_proc=16)
            # Fix: removed leftover debug call
            # `train_dataset.save_to_disk('alpaca_pack')`, which silently
            # wrote a hardcoded directory into the CWD on every run.
            objects = [train_dataset]
        else:
            objects = [None]
        # Rank 0 broadcasts the packed dataset object to all other ranks.
        dist.broadcast_object_list(objects, src=0)
        train_dataset = objects[0]
        return train_dataset

    @property
    def sp_group(self):
        """The process group used for sequence parallelism."""
        from xtuner.parallel.sequence import get_sequence_parallel_group
        return get_sequence_parallel_group()

    def init_sequence_parallel(self, size):
        """Initialize XTuner sequence parallelism with group size ``size``."""
        self.assert_xtuner_runtime_condition()
        from xtuner.parallel.sequence import init_sequence_parallel
        init_sequence_parallel(size)

    def prepare_model(self, model, tokenizer, split_in_forward):
        """Patch ``model``'s modules for sequence parallelism.

        ``tokenizer`` and ``split_in_forward`` are unused by this backend but
        kept for interface compatibility with other ``SequenceParallel``
        implementations.
        """
        self.assert_xtuner_runtime_condition()
        from xtuner.model.modules.dispatch import dispatch_modules
        dispatch_modules(model)

    def pad_and_split_inputs(self,
                             tokenizer,
                             input_ids,
                             input_embeds,
                             labels,
                             position_ids,
                             attention_mask,
                             loss_scale,
                             embed_tokens=None):
        """Pad the sequence dim for sequence parallelism, then split it so
        each rank receives its own slice.

        ``input_embeds`` and ``embed_tokens`` are ignored by this backend
        (XTuner splits token ids, not embeddings); the second element of the
        return tuple is therefore always ``None``.

        Returns:
            ``(input_ids, None, labels, position_ids, attention_mask,
            loss_scale)``, each padded and split along the sequence dim.
        """
        self.assert_xtuner_runtime_condition()
        from xtuner.parallel.sequence import (pad_for_sequence_parallel, split_for_sequence_parallel,
                                              get_sequence_parallel_group)
        # Pad every tensor along the sequence dim; -100 is the ignore index
        # used by reduce_outputs below, so padded label positions add no loss.
        input_ids = pad_for_sequence_parallel(input_ids, padding_value=tokenizer.pad_token_id, dim=-1)
        labels = pad_for_sequence_parallel(labels, padding_value=-100, dim=-1)
        position_ids = pad_for_sequence_parallel(position_ids, padding_value=0, dim=-1)
        if attention_mask is not None:
            attention_mask = pad_for_sequence_parallel(attention_mask, padding_value=0, dim=-1)
        sp_group = get_sequence_parallel_group()
        # Each rank keeps only its slice of the (padded) sequence dimension.
        input_ids = split_for_sequence_parallel(input_ids, dim=1, sp_group=sp_group)
        labels = split_for_sequence_parallel(labels, dim=1, sp_group=sp_group)
        position_ids = split_for_sequence_parallel(position_ids, dim=1, sp_group=sp_group)
        if attention_mask is not None:
            attention_mask = split_for_sequence_parallel(attention_mask, dim=-1, sp_group=sp_group)
        if loss_scale is not None:
            loss_scale = pad_for_sequence_parallel(loss_scale, padding_value=0., dim=-1)
            loss_scale = split_for_sequence_parallel(loss_scale, dim=1, sp_group=sp_group)
        return input_ids, None, labels, position_ids, attention_mask, loss_scale

    def reduce_outputs(self, loss, labels):
        """Reduce the per-rank loss across the sequence-parallel group,
        weighted by the number of supervised (label != -100) tokens, so the
        logged loss is correct."""
        from xtuner.parallel.sequence import (reduce_sequence_parallel_loss, get_sequence_parallel_group)
        # reduce loss for logging correctly
        num_tokens = (labels != -100).sum()
        return reduce_sequence_parallel_loss(loss, num_tokens, get_sequence_parallel_group())

    def world_size(self):
        """Return the sequence-parallel world size."""
        self.assert_xtuner_runtime_condition()
        from xtuner.parallel.sequence import get_sequence_parallel_world_size
        return get_sequence_parallel_world_size()

    def prepare_trainer(self, trainer):
        """No trainer-level patching is needed for the XTuner backend."""
        pass

    def get_dataloader(self, trainer, dataset, batch_size):
        """Build the train dataloader with XTuner's sampler.

        Modified from ``HFTrainer.get_train_dataloader``; the only change is
        RandomSampler -> SequenceParallelSampler.
        """
        self.assert_xtuner_runtime_condition()
        data_collator = trainer.data_collator
        if isinstance(dataset, datasets.Dataset):
            dataset = trainer._remove_unused_columns(dataset, description='training')
        else:
            data_collator = trainer._get_collator_with_removed_columns(data_collator, description='training')

        dataloader_params = {
            'batch_size': batch_size,
            'collate_fn': data_collator,
            'num_workers': trainer.args.dataloader_num_workers,
            'pin_memory': trainer.args.dataloader_pin_memory,
            'persistent_workers': trainer.args.dataloader_persistent_workers,
        }

        # Map-style datasets get the SP-aware sampler; iterable datasets
        # manage their own ordering and take none of these options.
        if not isinstance(dataset, torch.utils.data.IterableDataset):
            from xtuner.parallel import SequenceParallelSampler
            # Fixed seed keeps the sampling order deterministic across runs.
            dataloader_params['sampler'] = SequenceParallelSampler(dataset, seed=1024)
            dataloader_params['drop_last'] = trainer.args.dataloader_drop_last
            dataloader_params['worker_init_fn'] = seed_worker

        return DataLoader(dataset, **dataloader_params)
|