typos
Browse files- step_finetune.py +3 -3
step_finetune.py
CHANGED
|
@@ -6,7 +6,7 @@ from typing import Optional, List, Callable, Mapping, Any, Union
|
|
| 6 |
import os
|
| 7 |
|
| 8 |
class STEPFinetuningModelConfig(T5Config):
|
| 9 |
-
model_type = "
|
| 10 |
|
| 11 |
def __init__(self,
|
| 12 |
num_examples: int = 512,
|
|
@@ -40,7 +40,7 @@ class STEPFinetuningModel(PreTrainedModel):
|
|
| 40 |
# There are two cases: (1) we initialize the model after STEP-pretraining, i.e. the tunable prefix is not set
|
| 41 |
# and (2) the model has been fine-tuned on downstream data, and hence there is meaningful data in the tunable prefix
|
| 42 |
|
| 43 |
-
# Initialize the prefix with NaNs. If we initialize from STEP-pretraining, this will
|
| 44 |
# if we initialize after fine-tuning, the NaNs will be overwritten anyway.
|
| 45 |
|
| 46 |
self.prefix_embedding = torch.nn.Parameter(torch.nan + torch.zeros((1, self.config.prefix_length, self.config.d_model)))
|
|
@@ -49,7 +49,7 @@ class STEPFinetuningModel(PreTrainedModel):
|
|
| 49 |
def _initialize_prefix(self):
|
| 50 |
prefix_init_tensor = self.prefix_init_tensor
|
| 51 |
if self.config.random_selection:
|
| 52 |
-
# randomize selection of
|
| 53 |
prefix_init_tensor = prefix_init_tensor[torch.randperm(prefix_init_tensor.shape[0]), :, :]
|
| 54 |
|
| 55 |
prefix_init_tensor = prefix_init_tensor[:self.config.num_examples, :self.config.prefix_length,
|
|
|
|
| 6 |
import os
|
| 7 |
|
| 8 |
class STEPFinetuningModelConfig(T5Config):
|
| 9 |
+
model_type = "STEP_finetuning"
|
| 10 |
|
| 11 |
def __init__(self,
|
| 12 |
num_examples: int = 512,
|
|
|
|
| 40 |
# There are two cases: (1) we initialize the model after STEP-pretraining, i.e. the tunable prefix is not set
|
| 41 |
# and (2) the model has been fine-tuned on downstream data, and hence there is meaningful data in the tunable prefix
|
| 42 |
|
| 43 |
+
# Initialize the prefix with NaNs. If we initialize from STEP-pretraining, this will be overwritten by a custom version of from_pretrained
|
| 44 |
# if we initialize after fine-tuning, the NaNs will be overwritten anyway.
|
| 45 |
|
| 46 |
self.prefix_embedding = torch.nn.Parameter(torch.nan + torch.zeros((1, self.config.prefix_length, self.config.d_model)))
|
|
|
|
| 49 |
def _initialize_prefix(self):
|
| 50 |
prefix_init_tensor = self.prefix_init_tensor
|
| 51 |
if self.config.random_selection:
|
| 52 |
+
# randomize selection of edgewise transformations to average for initializing the prefix.
|
| 53 |
prefix_init_tensor = prefix_init_tensor[torch.randperm(prefix_init_tensor.shape[0]), :, :]
|
| 54 |
|
| 55 |
prefix_init_tensor = prefix_init_tensor[:self.config.num_examples, :self.config.prefix_length,
|