update pretrainer to not use distributed sampler (Trainer uses accelerate)
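Context for the change: recent transformers releases have Trainer wrap its DataLoaders with accelerate, which shards an ordinary (non-distributed) sampler across processes on its own, so the custom per-rank samplers removed below are no longer needed. A minimal standalone sketch of that behavior (illustrative only, not part of this commit; the toy dataset and batch size are made up):

    import torch
    from torch.utils.data import DataLoader, RandomSampler, TensorDataset
    from accelerate import Accelerator

    accelerator = Accelerator()  # e.g. started with `accelerate launch toy_script.py`
    dataset = TensorDataset(torch.arange(1000).unsqueeze(1))  # toy stand-in for the tokenized corpus
    loader = DataLoader(dataset, batch_size=8, sampler=RandomSampler(dataset))
    loader = accelerator.prepare(loader)  # prepare() re-shards the loader so each process sees its own slice
    for batch in loader:
        pass  # each rank iterates over only its share of the data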
geneformer/pretrainer.py  CHANGED  (+5 -171)
@@ -607,7 +607,7 @@ class GeneformerPretrainer(Trainer):
             )
         super().__init__(*args, **kwargs)

-    #
+    # updated to not use distributed sampler since Trainer now distributes with accelerate
     def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
         if not isinstance(self.train_dataset, collections.abc.Sized):
             return None

@@ -630,181 +630,15 @@ class GeneformerPretrainer(Trainer):
                 if self.tokenizer is not None
                 else None
             )
-            if self.args.world_size <= 1:
-                return LengthGroupedSampler(
+            return LengthGroupedSampler(
                     dataset=self.train_dataset,
                     batch_size=self.args.train_batch_size,
                     lengths=lengths,
                     model_input_name=model_input_name,
                     generator=generator,
-                )
-            else:
-                return CustomDistributedLengthGroupedSampler(
-                    dataset=self.train_dataset,
-                    batch_size=self.args.train_batch_size,
-                    num_replicas=self.args.world_size,
-                    rank=self.args.process_index,
-                    lengths=lengths,
-                    model_input_name=model_input_name,
-                    seed=self.args.seed,
-                )
-
-        else:
-            if self.args.world_size <= 1:
-                if _is_torch_generator_available:
-                    return RandomSampler(self.train_dataset, generator=generator)
-                return RandomSampler(self.train_dataset)
-            elif (
-                self.args.parallel_mode
-                in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
-                and not self.args.dataloader_drop_last
-            ):
-                # Use a loop for TPUs when drop_last is False to have all batches have the same size.
-                return DistributedSamplerWithLoop(
-                    self.train_dataset,
-                    batch_size=self.args.per_device_train_batch_size,
-                    num_replicas=self.args.world_size,
-                    rank=self.args.process_index,
-                    seed=self.args.seed,
-                )
-            else:
-                return DistributedSampler(
-                    self.train_dataset,
-                    num_replicas=self.args.world_size,
-                    rank=self.args.process_index,
-                    seed=self.args.seed,
-                )
-
-
-class CustomDistributedLengthGroupedSampler(DistributedLengthGroupedSampler):
-    r"""
-    Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
-    length while keeping a bit of randomness.
-    """
-
-    # Copied and adapted from PyTorch DistributedSampler.
-    def __init__(
-        self,
-        dataset: Dataset,
-        batch_size: int,
-        num_replicas: Optional[int] = None,
-        rank: Optional[int] = None,
-        seed: int = 0,
-        drop_last: bool = False,
-        lengths: Optional[List[int]] = None,
-        model_input_name: Optional[str] = None,
-    ):
-        if num_replicas is None:
-            if not dist.is_available():
-                raise RuntimeError("Requires distributed package to be available")
-            num_replicas = dist.get_world_size()
-        if rank is None:
-            if not dist.is_available():
-                raise RuntimeError("Requires distributed package to be available")
-            rank = dist.get_rank()
-        self.dataset = dataset
-        self.batch_size = batch_size
-        self.num_replicas = num_replicas
-        self.rank = rank
-        self.epoch = 0
-        self.drop_last = drop_last
-        # If the dataset length is evenly divisible by # of replicas, then there
-        # is no need to drop any data, since the dataset will be split equally.
-        if self.drop_last and len(self.dataset) % self.num_replicas != 0:
-            # Split to nearest available length that is evenly divisible.
-            # This is to ensure each rank receives the same amount of data when
-            # using this Sampler.
-            self.num_samples = math.ceil(
-                (len(self.dataset) - self.num_replicas) / self.num_replicas
             )
-        else:
-            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
-        self.total_size = self.num_samples * self.num_replicas
-        self.seed = seed
-        self.model_input_name = (
-            model_input_name if model_input_name is not None else "input_ids"
-        )
-
-        if lengths is None:
-            print("Lengths is none - calculating lengths.")
-            if (
-                not (
-                    isinstance(dataset[0], dict)
-                    or isinstance(dataset[0], BatchEncoding)
-                )
-                or self.model_input_name not in dataset[0]
-            ):
-                raise ValueError(
-                    "Can only automatically infer lengths for datasets whose items are dictionaries with an "
-                    f"'{self.model_input_name}' key."
-                )
-            lengths = [len(feature[self.model_input_name]) for feature in dataset]
-        self.lengths = lengths
-
-    def __iter__(self) -> Iterator:
-        # Deterministically shuffle based on epoch and seed
-        g = torch.Generator()
-        g.manual_seed(self.seed + self.epoch)
-
-        indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)

-        if not self.drop_last:
-            # add extra samples to make it evenly divisible
-            indices += indices[: (self.total_size - len(indices))]
         else:
+            if _is_torch_generator_available:
+                return RandomSampler(self.train_dataset, generator=generator)
+            return RandomSampler(self.train_dataset)
-            # remove tail of data to make it evenly divisible.
-            indices = indices[: self.total_size]
-        assert len(indices) == self.total_size
-
-        # subsample
-        indices = indices[self.rank : self.total_size : self.num_replicas]
-        assert len(indices) == self.num_samples
-
-        return iter(indices)
-
-
-def get_length_grouped_indices(
-    lengths, batch_size, mega_batch_mult=None, generator=None
-):
-    """
-    Return a list of indices so that each slice of :obj:`batch_size` consecutive indices correspond to elements of
-    similar lengths. To do this, the indices are:
-
-    - randomly permuted
-    - grouped in mega-batches of size :obj:`mega_batch_mult * batch_size`
-    - sorted by length in each mega-batch
-
-    The result is the concatenation of all mega-batches, with the batch of :obj:`batch_size` containing the element of
-    maximum length placed first, so that an OOM happens sooner rather than later.
-    """
-    # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
-    if mega_batch_mult is None:
-        # mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
-        mega_batch_mult = min(len(lengths) // (batch_size * 4), 1000)
-        # Just in case, for tiny datasets
-        if mega_batch_mult == 0:
-            mega_batch_mult = 1
-
-    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
-    indices = torch.randperm(len(lengths), generator=generator)
-    megabatch_size = mega_batch_mult * batch_size
-    megabatches = [
-        indices[i : i + megabatch_size].tolist()
-        for i in range(0, len(lengths), megabatch_size)
-    ]
-    megabatches = [
-        list(sorted(megabatch, key=lambda i: lengths[i], reverse=True))
-        for megabatch in megabatches
-    ]
-
-    # The rest is to get the biggest batch first.
-    # Since each megabatch is sorted by descending length, the longest element is the first
-    megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
-    max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
-    # Switch to put the longest element in first position
-    megabatches[0][0], megabatches[max_idx][0] = (
-        megabatches[max_idx][0],
-        megabatches[0][0],
-    )
-
-    return [item for sublist in megabatches for item in sublist]
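The group_by_length path still behaves the same way after this commit; it simply relies on the stock LengthGroupedSampler kept above instead of the custom distributed copy that was removed. A rough standalone illustration of the underlying idea (a simplified helper written for this note, not the transformers implementation; the example lengths are made up): shuffle the indices, cut them into mega-batches, and sort each mega-batch by length so examples of similar size land in the same batch and padding is minimized.

    import torch

    def length_grouped_indices(lengths, batch_size, mega_batch_mult=4, generator=None):
        # shuffle all example indices, then work in mega-batches of mega_batch_mult * batch_size
        indices = torch.randperm(len(lengths), generator=generator).tolist()
        megabatch_size = mega_batch_mult * batch_size
        megabatches = [indices[i : i + megabatch_size] for i in range(0, len(indices), megabatch_size)]
        # sort each mega-batch by descending length so neighbouring examples pad similarly
        megabatches = [sorted(mb, key=lambda i: lengths[i], reverse=True) for mb in megabatches]
        return [i for mb in megabatches for i in mb]

    lengths = [5, 2048, 30, 512, 7, 1024, 64, 256]  # toy per-example token counts
    g = torch.Generator().manual_seed(0)
    print(length_grouped_indices(lengths, batch_size=2, mega_batch_mult=2, generator=g))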