Support specifying upstream sentence transformers model revision

#2
by ccmaymay - opened
Files changed (3)
  1. config.json +1 -0
  2. config.py +5 -1
  3. model.py +5 -3
config.json CHANGED
@@ -8,6 +8,7 @@
   },
   "embedding_size": 512,
   "k_bucket_size": 1024,
+  "upstream_transformer_revision": null,
   "model_type": "LUAR",
   "q_bucket_size": 512,
   "torch_dtype": "float32",
config.py CHANGED
@@ -1,4 +1,6 @@
 
+from typing import Optional
+
 from transformers import PretrainedConfig
 
 class LUARConfig(PretrainedConfig):
@@ -9,10 +11,12 @@ class LUARConfig(PretrainedConfig):
         use_memory_efficient_attention=False,
         q_bucket_size=512,
         k_bucket_size=1024,
+        upstream_transformer_revision: Optional[str] = None,
         **kwargs,
     ):
         self.embedding_size = embedding_size
         self.use_memory_efficient_attention = use_memory_efficient_attention
         self.q_bucket_size = q_bucket_size
         self.k_bucket_size = k_bucket_size
-        super().__init__(**kwargs)
+        self.upstream_transformer_revision = upstream_transformer_revision
+        super().__init__(**kwargs)
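A hedged usage sketch for the new constructor argument, assuming this repo's config.py is importable from the working directory: the field should round-trip through the standard PretrainedConfig save/load path, which is how it ends up in config.json.

from config import LUARConfig

# "main" is a placeholder revision; a commit SHA works the same way.
cfg = LUARConfig(upstream_transformer_revision="main")
cfg.save_pretrained("./luar-config")                    # writes config.json including the new field
reloaded = LUARConfig.from_pretrained("./luar-config")
assert reloaded.upstream_transformer_revision == "main"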
model.py CHANGED
@@ -1,6 +1,7 @@
 
 import math
 from functools import partial
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -139,7 +140,7 @@ class LUAR(PreTrainedModel):
 
     def __init__(self, config):
         super().__init__(config)
-        self.create_transformer()
+        self.create_transformer(revision=config.upstream_transformer_revision)
         self.attn_fn = SelfAttention(
             config.use_memory_efficient_attention,
             config.q_bucket_size,
@@ -147,10 +148,11 @@ class LUAR(PreTrainedModel):
         )
         self.linear = nn.Linear(self.hidden_size, config.embedding_size)
 
-    def create_transformer(self):
+    def create_transformer(self, revision: Optional[str] = None):
         """Creates the Transformer backbone.
         """
-        self.transformer = AutoModel.from_pretrained("sentence-transformers/paraphrase-distilroberta-base-v1")
+        kwargs = {"revision": revision} if revision else {}
+        self.transformer = AutoModel.from_pretrained("sentence-transformers/paraphrase-distilroberta-base-v1", **kwargs)
         self.hidden_size = self.transformer.config.hidden_size
         self.num_attention_heads = self.transformer.config.num_attention_heads
         self.dim_head = self.hidden_size // self.num_attention_heads
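Finally, a rough end-to-end sketch of the intended flow after this change, assuming config.py and model.py from this repo are importable; instantiating the model downloads the backbone, so network access or a local cache is required, and the revision string is only a placeholder.

from config import LUARConfig
from model import LUAR

# Pin the upstream sentence-transformers backbone to a fixed revision.
# A commit SHA gives a fully reproducible load; "main" is a placeholder.
cfg = LUARConfig(upstream_transformer_revision="main")

# __init__ forwards the revision to create_transformer(), which passes it to
# AutoModel.from_pretrained(..., revision=...). Leaving the field as None omits
# the kwarg, so behavior matches the code before this PR.
model = LUAR(cfg)
print(model.hidden_size)  # 768 for the distilroberta backbone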