# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any

import datasets
import torch
import torch.distributed as dist
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers.trainer_utils import seed_worker

from .base import SequenceParallel


class XTuner(SequenceParallel):
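    """Sequence parallel implementation built on XTuner's `xtuner.parallel.sequence` utilities."""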

    @staticmethod
    def assert_xtuner_runtime_condition():
        from swift.utils import is_xtuner_available
        assert is_xtuner_available(), \
            ('Please install XTuner first to use sequence parallel or to pack the dataset to `max_length`: '
             '`pip install -U \'xtuner[deepspeed]\'`')
        assert dist.is_initialized(), 'XTuner sequence parallel and pack_to_max_length are only available with distributed training.'

    def pack_dataset_xtuner(self, dataset: Dataset, args: Any) -> Any:
        self.assert_xtuner_runtime_condition()
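        # Pack the dataset on rank 0 only, then broadcast the packed dataset to the other ranks.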
        if dist.get_rank() == 0:
            ds = [i[0] for i in dataset.data]
            train_dataset = Dataset.from_list(ds)
            from xtuner.dataset.huggingface import pack_dataset
            train_dataset = pack_dataset(
                train_dataset,
                max_length=args.max_length,
                use_varlen_attn=False,
                shuffle_before_pack=True,
                map_num_proc=16)
            objects = [train_dataset]
            train_dataset.save_to_disk('alpaca_pack')
        else:
            objects = [None]
        dist.broadcast_object_list(objects, src=0)
        train_dataset = objects[0]
        return train_dataset

    @property
    def sp_group(self):
        from xtuner.parallel.sequence import get_sequence_parallel_group
        return get_sequence_parallel_group()

    def init_sequence_parallel(self, size):
        self.assert_xtuner_runtime_condition()
        from xtuner.parallel.sequence import init_sequence_parallel
        init_sequence_parallel(size)

    def prepare_model(self, model, tokenizer, split_in_forward):
        self.assert_xtuner_runtime_condition()
        from xtuner.model.modules.dispatch import dispatch_modules
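        # Patch the model's modules (e.g. attention) with XTuner's dispatched, sequence-parallel-aware implementations.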
        dispatch_modules(model)

    def pad_and_split_inputs(self,
                             tokenizer,
                             input_ids,
                             input_embeds,
                             labels,
                             position_ids,
                             attention_mask,
                             loss_scale,
                             embed_tokens=None):
        self.assert_xtuner_runtime_condition()
        from xtuner.parallel.sequence import (pad_for_sequence_parallel, split_for_sequence_parallel,
                                              get_sequence_parallel_group)
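        # Pad each tensor so its sequence length is divisible by the sequence parallel world size.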
        input_ids = pad_for_sequence_parallel(input_ids, padding_value=tokenizer.pad_token_id, dim=-1)
        labels = pad_for_sequence_parallel(labels, padding_value=-100, dim=-1)
        position_ids = pad_for_sequence_parallel(position_ids, padding_value=0, dim=-1)
        if attention_mask is not None:
            attention_mask = pad_for_sequence_parallel(attention_mask, padding_value=0, dim=-1)

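        # Split the padded sequences along the sequence dimension; each rank in the SP group keeps one shard.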
        sp_group = get_sequence_parallel_group()
        input_ids = split_for_sequence_parallel(input_ids, dim=1, sp_group=sp_group)
        labels = split_for_sequence_parallel(labels, dim=1, sp_group=sp_group)
        position_ids = split_for_sequence_parallel(position_ids, dim=1, sp_group=sp_group)
        if attention_mask is not None:
            attention_mask = split_for_sequence_parallel(attention_mask, dim=-1, sp_group=sp_group)
        if loss_scale is not None:
            loss_scale = pad_for_sequence_parallel(loss_scale, padding_value=0., dim=-1)
            loss_scale = split_for_sequence_parallel(loss_scale, dim=1, sp_group=sp_group)

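        # input_embeds is not handled here, so None is returned in its place.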
        return input_ids, None, labels, position_ids, attention_mask, loss_scale

    def reduce_outputs(self, loss, labels):
        from xtuner.parallel.sequence import (reduce_sequence_parallel_loss, get_sequence_parallel_group)
        # reduce the loss across the sequence parallel group, weighted by valid token count, so it is logged correctly
        num_tokens = (labels != -100).sum()
        return reduce_sequence_parallel_loss(loss, num_tokens, get_sequence_parallel_group())

    def world_size(self):
        self.assert_xtuner_runtime_condition()
        from xtuner.parallel.sequence import get_sequence_parallel_world_size
        return get_sequence_parallel_world_size()

    def prepare_trainer(self, trainer):
        pass

    def get_dataloader(self, trainer, dataset, batch_size):
        # modified from HFTrainer.get_train_dataloader
        # RandomSampler -> SequenceParallelSampler
        self.assert_xtuner_runtime_condition()
        data_collator = trainer.data_collator
        if isinstance(dataset, datasets.Dataset):
            dataset = trainer._remove_unused_columns(dataset, description='training')
        else:
            data_collator = trainer._get_collator_with_removed_columns(data_collator, description='training')

        dataloader_params = {
            'batch_size': batch_size,
            'collate_fn': data_collator,
            'num_workers': trainer.args.dataloader_num_workers,
            'pin_memory': trainer.args.dataloader_pin_memory,
            'persistent_workers': trainer.args.dataloader_persistent_workers,
        }

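        # Use XTuner's SequenceParallelSampler so that all ranks in the same sequence parallel group
        # receive the same batch of samples.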
        if not isinstance(dataset, torch.utils.data.IterableDataset):
            from xtuner.parallel import SequenceParallelSampler
            dataloader_params['sampler'] = SequenceParallelSampler(dataset, seed=1024)
            dataloader_params['drop_last'] = trainer.args.dataloader_drop_last
            dataloader_params['worker_init_fn'] = seed_worker

        return DataLoader(dataset, **dataloader_params)