File size: 1,026 Bytes
2742ed8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import abc
from abc import abstractmethod
class SequenceParallel(abc.ABC):
    """Abstract interface for a sequence-parallel strategy.

    Concrete strategies subclass this and implement every method marked
    ``@abstractmethod``. Because the class derives from ``abc.ABC``,
    instantiating a subclass that leaves any of them unimplemented raises
    ``TypeError``. The only member with a default implementation is
    :attr:`sp_group`, which returns ``None``.

    NOTE(review): the per-method descriptions below are inferred from the
    method and parameter names only; confirm them against the concrete
    implementations before relying on them.
    """

    @abstractmethod
    def init_sequence_parallel(self, size):
        """Set up sequence parallelism for the given ``size``.

        Presumably ``size`` is the sequence-parallel degree (number of ranks
        a sequence is split across) -- TODO confirm against implementations.
        """

    @abstractmethod
    def prepare_model(self, model, tokenizer, split_in_forward):
        """Adapt ``model`` for sequence-parallel execution.

        ``split_in_forward`` presumably selects whether inputs are split
        inside the model's forward pass instead of beforehand -- verify in
        the concrete subclasses.
        """

    @abstractmethod
    def pad_and_split_inputs(self,
                             tokenizer,
                             input_ids,
                             input_embeds,
                             labels,
                             position_ids,
                             attention_mask,
                             loss_scale,
                             embed_tokens=None):
        """Pad the batch inputs and split them for sequence parallelism.

        All tensor-like arguments are passed through to the implementation
        unchanged by this interface; their shapes and the shape of the
        return value are implementation-defined (assumed to be
        (batch, seq, ...) -- TODO confirm). ``embed_tokens`` is optional and
        defaults to ``None``.
        """

    @abstractmethod
    def reduce_outputs(self, loss, labels):
        """Combine per-rank ``loss`` (given ``labels``) after the forward pass.

        Presumably this gathers/reduces the loss across the sequence-parallel
        group -- confirm in the concrete implementations.
        """

    @property
    def sp_group(self):
        """Process group used for sequence parallelism.

        The base implementation returns ``None``; subclasses that create a
        dedicated group are expected to override this property.
        """
        return None

    @abstractmethod
    def world_size(self):
        """Return the sequence-parallel world size.

        NOTE(review): unlike :attr:`sp_group` this is declared as a plain
        method rather than a property; subclasses may override it either
        way, so check how callers access it before changing the declaration.
        """

    @abstractmethod
    def prepare_trainer(self, trainer):
        """Hook for adapting ``trainer`` to sequence-parallel training."""

    @abstractmethod
    def get_dataloader(self, trainer, dataset, batch_size):
        """Build a dataloader over ``dataset`` for ``trainer``.

        ``batch_size`` is forwarded to the implementation; how data is
        sharded across sequence-parallel ranks is implementation-defined.
        """
|