File size: 1,026 Bytes
2742ed8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import abc
from abc import abstractmethod


class SequenceParallel(abc.ABC):
    """Abstract interface for a sequence-parallel strategy.

    A concrete strategy must implement every method marked abstract below
    before it can be instantiated. ``sp_group`` has a default
    implementation that returns ``None``.
    """

    @abstractmethod
    def init_sequence_parallel(self, size):
        """Initialise sequence parallelism.

        Args:
            size: Presumably the number of ranks in a sequence-parallel
                group (the SP degree); confirm against implementations.
        """

    @abstractmethod
    def prepare_model(self, model, tokenizer, split_in_forward):
        """Adapt ``model`` (and, if needed, ``tokenizer``) for this strategy.

        Args:
            model: The model to prepare.
            tokenizer: The tokenizer paired with ``model``.
            split_in_forward: Flag chosen by the implementation. Judging by
                the name, it presumably requests that inputs be split inside
                the forward pass; confirm.
        """

    @abstractmethod
    def pad_and_split_inputs(self,
                             tokenizer,
                             input_ids,
                             input_embeds,
                             labels,
                             position_ids,
                             attention_mask,
                             loss_scale,
                             embed_tokens=None):
        """Pad and split a batch of model inputs for this strategy.

        The structure of the return value is defined by each implementation.
        ``embed_tokens`` is optional and defaults to ``None``.
        """

    @abstractmethod
    def reduce_outputs(self, loss, labels):
        """Combine per-rank outputs (``loss`` / ``labels``).

        The exact reduction is up to the implementation.
        """

    @property
    def sp_group(self):
        """Process group used for sequence parallelism.

        The base implementation returns ``None``. Presumably subclasses
        override it to return their communication group; confirm.
        """
        return None

    # Declared as an abstract *property* (like ``sp_group``), so callers
    # access it as ``obj.world_size`` rather than calling it.
    @property
    @abstractmethod
    def world_size(self):
        """Size reported by the strategy; subclasses must override this.

        Presumably this is the sequence-parallel world size; confirm
        against implementations.
        """

    @abstractmethod
    def prepare_trainer(self, trainer):
        """Adapt ``trainer`` for use with this strategy."""

    @abstractmethod
    def get_dataloader(self, trainer, dataset, batch_size):
        """Build a dataloader over ``dataset`` for ``trainer`` with ``batch_size``."""