Christina Theodoris committed
Commit: b925dcc · Parent: 8ce598f
Update pretrainer for transformers==4.28.0
examples/pretrain_geneformer_w_deepspeed.py CHANGED

```diff
@@ -137,9 +137,8 @@ training_args = {
     "weight_decay": weight_decay,
     "per_device_train_batch_size": geneformer_batch_size,
     "num_train_epochs": epochs,
-    "load_best_model_at_end": True,
     "save_strategy": "steps",
-    "save_steps": num_examples / geneformer_batch_size / 8, # 8 saves per epoch
+    "save_steps": np.floor(num_examples / geneformer_batch_size / 8), # 8 saves per epoch
     "logging_steps": 1000,
     "output_dir": training_output_dir,
     "logging_dir": logging_dir,
```
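Both edits track transformers 4.28 validation: `load_best_model_at_end=True` requires the save and evaluation strategies to match, and this pretraining script configures no evaluation, so the flag is dropped; `save_steps` is now floored because Python 3 division always yields a float, while a step count should be a whole number. A minimal sketch with hypothetical sizes (note `np.floor` returns a whole-valued NumPy float, not an int):

```python
# Hypothetical sizes; shows the float -> whole-number issue the commit fixes.
import numpy as np

num_examples = 27_406_208    # hypothetical corpus size
geneformer_batch_size = 12   # hypothetical per-device batch size

raw_steps = num_examples / geneformer_batch_size / 8  # 285481.33..., a float
save_steps = np.floor(raw_steps)                      # 285481.0, whole-valued
print(raw_steps, save_steps, int(save_steps))
```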
geneformer/pretrainer.py CHANGED

```diff
@@ -106,19 +106,23 @@ class TensorType(ExplicitEnum):
 
 class GeneformerPreCollator(SpecialTokensMixin):
     def __init__(self, *args, **kwargs) -> None:
+
+        super().__init__(mask_token = "<mask>", pad_token = "<pad>")
+
         self.token_dictionary = kwargs.get("token_dictionary")
-        self.mask_token = "<mask>"
-        self.mask_token_id = self.token_dictionary.get("<mask>")
-        self.pad_token = "<pad>"
-        self.pad_token_id = self.token_dictionary.get("<pad>")
+        # self.mask_token = "<mask>"
+        # self.mask_token_id = self.token_dictionary.get("<mask>")
+        # self.pad_token = "<pad>"
+        # self.pad_token_id = self.token_dictionary.get("<pad>")
         self.padding_side = "right"
-        self.all_special_ids = [
-            self.token_dictionary.get("<mask>"),
-            self.token_dictionary.get("<pad>"),
-        ]
+        # self.all_special_ids = [
+        #     self.token_dictionary.get("<mask>"),
+        #     self.token_dictionary.get("<pad>"),
+        # ]
         self.model_input_names = ["input_ids"]
-
-
+
+    def convert_ids_to_tokens(self,value):
+        return self.token_dictionary.get(value)
 
     def _get_padding_truncation_strategies(
         self,
```
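The substantive change in this hunk: special tokens are no longer assigned as plain attributes but passed to `SpecialTokensMixin.__init__`, which manages them as properties and derives ids such as `mask_token_id` through a `convert_tokens_to_ids` hook. A minimal sketch of that pattern (the `TinyCollator` class and toy token dictionary are illustrative, not the repo's code):

```python
# Minimal sketch of the SpecialTokensMixin pattern used above; the toy
# token dictionary and TinyCollator class are illustrative only.
from transformers import SpecialTokensMixin

class TinyCollator(SpecialTokensMixin):
    def __init__(self, token_dictionary):
        # pass special tokens up instead of setting self.mask_token directly
        super().__init__(mask_token="<mask>", pad_token="<pad>")
        self.token_dictionary = token_dictionary

    def convert_tokens_to_ids(self, token):
        # the mixin resolves mask_token_id / pad_token_id through this hook
        return self.token_dictionary.get(token)

collator = TinyCollator({"<pad>": 0, "<mask>": 1})
print(collator.mask_token, collator.mask_token_id)  # <mask> 1
print(collator.pad_token, collator.pad_token_id)    # <pad> 0
```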
```diff
@@ -592,8 +596,8 @@ class GeneformerPreCollator(SpecialTokensMixin):
 
 class GeneformerPretrainer(Trainer):
     def __init__(self, *args, **kwargs):
-        data_collator = kwargs.get("data_collator")
-        token_dictionary = kwargs.get("token_dictionary")
+        data_collator = kwargs.get("data_collator",None)
+        token_dictionary = kwargs.pop("token_dictionary")
 
         if data_collator is None:
             precollator = GeneformerPreCollator(token_dictionary=token_dictionary)
```
```diff
@@ -604,17 +608,17 @@ class GeneformerPretrainer(Trainer):
         )
         kwargs["data_collator"] = data_collator
 
-        super().__init__(*args, **kwargs)
-
         # load previously saved length vector for dataset to speed up LengthGroupedSampler
         # pre-obtained with [dataset[i]["length"] for i in range(len(dataset))]
-        if kwargs.get("example_lengths_file"):
-            with open(kwargs.get("example_lengths_file"), "rb") as f:
+        example_lengths_file = kwargs.pop("example_lengths_file")
+        if example_lengths_file:
+            with open(example_lengths_file, "rb") as f:
                 self.example_lengths = pickle.load(f)
         else:
             raise Exception(
                 "example_lengths_file is required; e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/genecorpus_30M_2048_sorted_lengths.pkl"
             )
+        super().__init__(*args, **kwargs)
 
     # modify LengthGroupedSampler to avoid dataset[length_column_name] hanging
     def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
```
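Both `__init__` hunks follow one rule: `Trainer.__init__` takes no `**kwargs`, so any custom argument (`token_dictionary`, `example_lengths_file`) must be popped out before `super().__init__(*args, **kwargs)` runs, which is also why that call moved below the length-file handling. A condensed sketch of the pattern:

```python
# Condensed sketch of the pop-before-super pattern; Trainer.__init__ has
# no **kwargs, so unknown keyword arguments raise a TypeError.
import pickle

from transformers import Trainer

class CustomPretrainer(Trainer):
    def __init__(self, *args, **kwargs):
        # consume custom kwargs so Trainer never sees them
        self.token_dictionary = kwargs.pop("token_dictionary")
        lengths_file = kwargs.pop("example_lengths_file", None)
        if lengths_file is None:
            raise ValueError("example_lengths_file is required")
        with open(lengths_file, "rb") as f:
            self.example_lengths = pickle.load(f)
        # kwargs now holds only arguments Trainer itself understands
        super().__init__(*args, **kwargs)
```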
```diff
@@ -634,7 +638,6 @@ class GeneformerPretrainer(Trainer):
             lengths = self.example_lengths
         else:
             lengths = None
-        print(f"Lengths: {len(lengths)}")
         model_input_name = (
             self.tokenizer.model_input_names[0]
             if self.tokenizer is not None
```
```diff
@@ -642,16 +645,16 @@
         )
         if self.args.world_size <= 1:
             return LengthGroupedSampler(
-                self.train_dataset,
-                self.args.train_batch_size,
+                dataset=self.train_dataset,
+                batch_size=self.args.train_batch_size,
                 lengths=lengths,
                 model_input_name=model_input_name,
                 generator=generator,
             )
         else:
             return CustomDistributedLengthGroupedSampler(
-                self.train_dataset,
-                self.args.train_batch_size,
+                dataset=self.train_dataset,
+                batch_size=self.args.train_batch_size,
                 num_replicas=self.args.world_size,
                 rank=self.args.process_index,
                 lengths=lengths,
```
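Switching to keyword arguments is the load-bearing fix here: in transformers 4.28, `LengthGroupedSampler` takes `batch_size` as its first positional parameter (earlier releases took the dataset first), so the old positional `(dataset, batch_size)` call binds the wrong parameters. A small usage sketch with made-up lengths; when `lengths` is supplied the dataset itself is not required:

```python
# Made-up lengths; demonstrates the keyword style the commit adopts.
import torch
from transformers.trainer_pt_utils import LengthGroupedSampler

lengths = [5, 7, 3, 9, 4, 8, 6, 2]   # per-example sequence lengths
sampler = LengthGroupedSampler(
    batch_size=2,
    lengths=lengths,                 # dataset may be omitted when lengths are given
    generator=torch.Generator().manual_seed(0),
)
print(list(sampler))                 # indices grouped so batches have similar lengths
```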
```diff
@@ -754,7 +757,7 @@ class CustomDistributedLengthGroupedSampler(DistributedLengthGroupedSampler):
         # Deterministically shuffle based on epoch and seed
         g = torch.Generator()
         g.manual_seed(self.seed + self.epoch)
-
+
         indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)
 
         if not self.drop_last:
```
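One detail of this `__iter__` deserves a note: seeding a fresh generator with `seed + epoch` lets every distributed rank derive an identical shuffled order with no communication, while `set_epoch()` still reshuffles between epochs. A standalone sketch of that idea:

```python
# Standalone sketch of epoch-based deterministic shuffling (illustrative).
import torch

def epoch_indices(num_examples: int, seed: int, epoch: int) -> list:
    g = torch.Generator()
    g.manual_seed(seed + epoch)   # identical on every rank, no communication
    return torch.randperm(num_examples, generator=g).tolist()

print(epoch_indices(8, seed=42, epoch=0))  # same order on all ranks
print(epoch_indices(8, seed=42, epoch=1))  # different order next epoch
```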