namednil committed on
Commit
59d55b8
·
verified ·
1 Parent(s): b0e8c81
Files changed (1) hide show
  1. step_finetune.py +3 -3
step_finetune.py CHANGED
@@ -6,7 +6,7 @@ from typing import Optional, List, Callable, Mapping, Any, Union
6
  import os
7
 
8
  class STEPFinetuningModelConfig(T5Config):
9
- model_type = "STEP_finetune"
10
 
11
  def __init__(self,
12
  num_examples: int = 512,
@@ -40,7 +40,7 @@ class STEPFinetuningModel(PreTrainedModel):
40
  # There are two cases: (1) we initialize the model after STEP-pretraining, i.e. the tunable prefix is not set
41
  # and (2) the model has been fine-tuned on downstream data, and hence there is meaningful data in the tunable prefix
42
 
43
- # Initialize the prefix with NaNs. If we initialize from STEP-pretraining, this will not be overwritten by a custom version of from_pretrained
44
  # if we initialize after fine-tuning, the NaNs will be overwritten anyway.
45
 
46
  self.prefix_embedding = torch.nn.Parameter(torch.nan + torch.zeros((1, self.config.prefix_length, self.config.d_model)))
@@ -49,7 +49,7 @@ class STEPFinetuningModel(PreTrainedModel):
49
  def _initialize_prefix(self):
50
  prefix_init_tensor = self.prefix_init_tensor
51
  if self.config.random_selection:
52
- # randomize selection of FSTs to average for initialization the prefix.
53
  prefix_init_tensor = prefix_init_tensor[torch.randperm(prefix_init_tensor.shape[0]), :, :]
54
 
55
  prefix_init_tensor = prefix_init_tensor[:self.config.num_examples, :self.config.prefix_length,
 
6
  import os
7
 
8
  class STEPFinetuningModelConfig(T5Config):
9
+ model_type = "STEP_finetuning"
10
 
11
  def __init__(self,
12
  num_examples: int = 512,
 
40
  # There are two cases: (1) we initialize the model after STEP-pretraining, i.e. the tunable prefix is not set
41
  # and (2) the model has been fine-tuned on downstream data, and hence there is meaningful data in the tunable prefix
42
 
43
+ # Initialize the prefix with NaNs. If we initialize from STEP-pretraining, this will be overwritten by a custom version of from_pretrained
44
  # if we initialize after fine-tuning, the NaNs will be overwritten anyway.
45
 
46
  self.prefix_embedding = torch.nn.Parameter(torch.nan + torch.zeros((1, self.config.prefix_length, self.config.d_model)))
 
49
  def _initialize_prefix(self):
50
  prefix_init_tensor = self.prefix_init_tensor
51
  if self.config.random_selection:
52
+ # randomize selection of edgewise transformations to average for initializing the prefix.
53
  prefix_init_tensor = prefix_init_tensor[torch.randperm(prefix_init_tensor.shape[0]), :, :]
54
 
55
  prefix_init_tensor = prefix_init_tensor[:self.config.num_examples, :self.config.prefix_length,