Christina Theodoris committed
Commit fb130e6 · Parent(s): 86fe0dd

update kwargs for pretrainer
Files changed: geneformer/pretrainer.py (+10 -9)
geneformer/pretrainer.py CHANGED

@@ -106,9 +106,8 @@ class TensorType(ExplicitEnum):
 
 class GeneformerPreCollator(SpecialTokensMixin):
     def __init__(self, *args, **kwargs) -> None:
-
-
-
+        super().__init__(mask_token="<mask>", pad_token="<pad>")
+
         self.token_dictionary = kwargs.get("token_dictionary")
         # self.mask_token = "<mask>"
         # self.mask_token_id = self.token_dictionary.get("<mask>")
@@ -120,8 +119,8 @@ class GeneformerPreCollator(SpecialTokensMixin):
         # self.token_dictionary.get("<pad>"),
         # ]
         self.model_input_names = ["input_ids"]
-
-    def convert_ids_to_tokens(self,value):
+
+    def convert_ids_to_tokens(self, value):
         return self.token_dictionary.get(value)
 
     def _get_padding_truncation_strategies(
@@ -391,7 +390,6 @@ class GeneformerPreCollator(SpecialTokensMixin):
 
         for key, value in encoded_inputs.items():
             encoded_inputs[key] = to_py_obj(value)
-
 
         # Convert padding_strategy in PaddingStrategy
         padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
@@ -596,15 +594,17 @@ class GeneformerPreCollator(SpecialTokensMixin):
 
 class GeneformerPretrainer(Trainer):
     def __init__(self, *args, **kwargs):
-        data_collator = kwargs.get("data_collator",None)
+        data_collator = kwargs.get("data_collator", None)
         token_dictionary = kwargs.pop("token_dictionary")
+        mlm = kwargs.pop("mlm", True)
+        mlm_probability = kwargs.pop("mlm_probability", 0.15)
 
         if data_collator is None:
             precollator = GeneformerPreCollator(token_dictionary=token_dictionary)
 
             # # Data Collator Functions
             data_collator = DataCollatorForLanguageModeling(
-                tokenizer=precollator, mlm=True, mlm_probability=0.15
+                tokenizer=precollator, mlm=mlm, mlm_probability=mlm_probability
             )
             kwargs["data_collator"] = data_collator
 
@@ -694,6 +694,7 @@ class CustomDistributedLengthGroupedSampler(DistributedLengthGroupedSampler):
     Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
     length while keeping a bit of randomness.
     """
+
     # Copied and adapted from PyTorch DistributedSampler.
     def __init__(
         self,
@@ -757,7 +758,7 @@ class CustomDistributedLengthGroupedSampler(DistributedLengthGroupedSampler):
         # Deterministically shuffle based on epoch and seed
         g = torch.Generator()
         g.manual_seed(self.seed + self.epoch)
-
+
         indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)
 
         if not self.drop_last:
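With this change, the masking settings become configurable at the call site instead of being hardcoded in the collator. A minimal usage sketch, not a verbatim repo example: the file paths are placeholders, the model setup is a generic BertForMaskedLM, and example_lengths_file reflects versions of the repo whose length-grouped sampler expects a precomputed lengths pickle. Omitting mlm and mlm_probability keeps the previous behavior (mlm=True, mlm_probability=0.15).

import pickle

from datasets import load_from_disk
from transformers import BertConfig, BertForMaskedLM, TrainingArguments

from geneformer.pretrainer import GeneformerPretrainer

# Placeholder paths; substitute your own tokenized dataset and dictionary.
with open("token_dictionary.pkl", "rb") as f:
    token_dictionary = pickle.load(f)

train_dataset = load_from_disk("genecorpus_tokenized.dataset")

config = BertConfig(
    vocab_size=len(token_dictionary),
    pad_token_id=token_dictionary.get("<pad>"),
)
model = BertForMaskedLM(config)

training_args = TrainingArguments(
    output_dir="pretrain_out",
    per_device_train_batch_size=8,
)

trainer = GeneformerPretrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    token_dictionary=token_dictionary,
    example_lengths_file="sorted_lengths.pkl",  # precomputed example lengths (required in some repo versions)
    mlm=True,              # new: forwarded to DataCollatorForLanguageModeling
    mlm_probability=0.15,  # new: previously hardcoded in the collator
)
trainer.train()

If a data_collator is supplied explicitly, the trainer never builds its own collator, and the popped mlm settings are simply unused.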
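The shape of the change is a common pattern: pop optional settings out of **kwargs with the old hardcoded values as defaults, build the collator only if the caller did not supply one, and hand the remaining kwargs to the parent Trainer. A self-contained sketch of that pattern with stand-in classes (none of these names are from the repo):

class StandInTrainer:
    """Stand-in for transformers.Trainer: just records the collator."""
    def __init__(self, *args, **kwargs):
        self.data_collator = kwargs.get("data_collator")

class StandInCollator:
    """Stand-in for DataCollatorForLanguageModeling."""
    def __init__(self, mlm=True, mlm_probability=0.15):
        self.mlm = mlm
        self.mlm_probability = mlm_probability

class StandInPretrainer(StandInTrainer):
    def __init__(self, *args, **kwargs):
        data_collator = kwargs.get("data_collator", None)
        # Pop optional settings; defaults reproduce the old hardcoded behavior.
        mlm = kwargs.pop("mlm", True)
        mlm_probability = kwargs.pop("mlm_probability", 0.15)
        if data_collator is None:
            kwargs["data_collator"] = StandInCollator(
                mlm=mlm, mlm_probability=mlm_probability
            )
        super().__init__(*args, **kwargs)

trainer = StandInPretrainer(mlm_probability=0.2)
print(trainer.data_collator.mlm_probability)  # 0.2
trainer = StandInPretrainer()
print(trainer.data_collator.mlm_probability)  # 0.15 (old default preserved)

Popping before delegating matters: Trainer does not accept mlm or mlm_probability itself, so leaving them in kwargs would raise a TypeError in super().__init__.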