Chandler May committed on
Commit
e95c2b6
·
1 Parent(s): 858fcb1

Pin upstream transformer model revision in config.

Browse files
Files changed (3) hide show
  1. config.json +1 -0
  2. config.py +3 -1
  3. model.py +4 -3
config.json CHANGED
@@ -8,6 +8,7 @@
8
  },
9
  "embedding_size": 512,
10
  "k_bucket_size": 1024,
 
11
  "model_type": "LUAR",
12
  "q_bucket_size": 512,
13
  "torch_dtype": "float32",
 
8
  },
9
  "embedding_size": 512,
10
  "k_bucket_size": 1024,
11
+ "upstream_transformer_revision": "48bffbbd27bf028ecdd0cd55abb51236ec12ef1b",
12
  "model_type": "LUAR",
13
  "q_bucket_size": 512,
14
  "torch_dtype": "float32",
config.py CHANGED
@@ -9,10 +9,12 @@ class LUARConfig(PretrainedConfig):
9
  use_memory_efficient_attention=False,
10
  q_bucket_size=512,
11
  k_bucket_size=1024,
 
12
  **kwargs,
13
  ):
14
  self.embedding_size = embedding_size
15
  self.use_memory_efficient_attention = use_memory_efficient_attention
16
  self.q_bucket_size = q_bucket_size
17
  self.k_bucket_size = k_bucket_size
18
- super().__init__(**kwargs)
 
 
9
  use_memory_efficient_attention=False,
10
  q_bucket_size=512,
11
  k_bucket_size=1024,
12
+ upstream_transformer_revision=None,
13
  **kwargs,
14
  ):
15
  self.embedding_size = embedding_size
16
  self.use_memory_efficient_attention = use_memory_efficient_attention
17
  self.q_bucket_size = q_bucket_size
18
  self.k_bucket_size = k_bucket_size
19
+ self.upstream_transformer_revision = upstream_transformer_revision
20
+ super().__init__(**kwargs)
model.py CHANGED
@@ -139,7 +139,7 @@ class LUAR(PreTrainedModel):
139
 
140
  def __init__(self, config):
141
  super().__init__(config)
142
- self.create_transformer()
143
  self.attn_fn = SelfAttention(
144
  config.use_memory_efficient_attention,
145
  config.q_bucket_size,
@@ -147,10 +147,11 @@ class LUAR(PreTrainedModel):
147
  )
148
  self.linear = nn.Linear(self.hidden_size, config.embedding_size)
149
 
150
- def create_transformer(self):
151
  """Creates the Transformer backbone.
152
  """
153
- self.transformer = AutoModel.from_pretrained("sentence-transformers/paraphrase-distilroberta-base-v1")
 
154
  self.hidden_size = self.transformer.config.hidden_size
155
  self.num_attention_heads = self.transformer.config.num_attention_heads
156
  self.dim_head = self.hidden_size // self.num_attention_heads
 
139
 
140
  def __init__(self, config):
141
  super().__init__(config)
142
+ self.create_transformer(revision=config.upstream_transformer_revision)
143
  self.attn_fn = SelfAttention(
144
  config.use_memory_efficient_attention,
145
  config.q_bucket_size,
 
147
  )
148
  self.linear = nn.Linear(self.hidden_size, config.embedding_size)
149
 
150
+ def create_transformer(self, revision=None):
151
  """Creates the Transformer backbone.
152
  """
153
+ kwargs = {"revision": revision} if revision else {}
154
+ self.transformer = AutoModel.from_pretrained("sentence-transformers/paraphrase-distilroberta-base-v1", **kwargs)
155
  self.hidden_size = self.transformer.config.hidden_size
156
  self.num_attention_heads = self.transformer.config.num_attention_heads
157
  self.dim_head = self.hidden_size // self.num_attention_heads