Chandler May
committed on
Commit
·
e95c2b6
1
Parent(s):
858fcb1
Pin upstream transformer model revision in config.
Browse files- config.json +1 -0
- config.py +3 -1
- model.py +4 -3
config.json
CHANGED
|
@@ -8,6 +8,7 @@
|
|
| 8 |
},
|
| 9 |
"embedding_size": 512,
|
| 10 |
"k_bucket_size": 1024,
|
|
|
|
| 11 |
"model_type": "LUAR",
|
| 12 |
"q_bucket_size": 512,
|
| 13 |
"torch_dtype": "float32",
|
|
|
|
| 8 |
},
|
| 9 |
"embedding_size": 512,
|
| 10 |
"k_bucket_size": 1024,
|
| 11 |
+
"upstream_transformer_revision": "48bffbbd27bf028ecdd0cd55abb51236ec12ef1b",
|
| 12 |
"model_type": "LUAR",
|
| 13 |
"q_bucket_size": 512,
|
| 14 |
"torch_dtype": "float32",
|
config.py
CHANGED
|
@@ -9,10 +9,12 @@ class LUARConfig(PretrainedConfig):
|
|
| 9 |
use_memory_efficient_attention=False,
|
| 10 |
q_bucket_size=512,
|
| 11 |
k_bucket_size=1024,
|
|
|
|
| 12 |
**kwargs,
|
| 13 |
):
|
| 14 |
self.embedding_size = embedding_size
|
| 15 |
self.use_memory_efficient_attention = use_memory_efficient_attention
|
| 16 |
self.q_bucket_size = q_bucket_size
|
| 17 |
self.k_bucket_size = k_bucket_size
|
| 18 |
-
|
|
|
|
|
|
| 9 |
use_memory_efficient_attention=False,
|
| 10 |
q_bucket_size=512,
|
| 11 |
k_bucket_size=1024,
|
| 12 |
+
upstream_transformer_revision=None,
|
| 13 |
**kwargs,
|
| 14 |
):
|
| 15 |
self.embedding_size = embedding_size
|
| 16 |
self.use_memory_efficient_attention = use_memory_efficient_attention
|
| 17 |
self.q_bucket_size = q_bucket_size
|
| 18 |
self.k_bucket_size = k_bucket_size
|
| 19 |
+
self.upstream_transformer_revision = upstream_transformer_revision
|
| 20 |
+
super().__init__(**kwargs)
|
model.py
CHANGED
|
@@ -139,7 +139,7 @@ class LUAR(PreTrainedModel):
|
|
| 139 |
|
| 140 |
def __init__(self, config):
|
| 141 |
super().__init__(config)
|
| 142 |
-
self.create_transformer()
|
| 143 |
self.attn_fn = SelfAttention(
|
| 144 |
config.use_memory_efficient_attention,
|
| 145 |
config.q_bucket_size,
|
|
@@ -147,10 +147,11 @@ class LUAR(PreTrainedModel):
|
|
| 147 |
)
|
| 148 |
self.linear = nn.Linear(self.hidden_size, config.embedding_size)
|
| 149 |
|
| 150 |
-
def create_transformer(self):
|
| 151 |
"""Creates the Transformer backbone.
|
| 152 |
"""
|
| 153 |
-
|
|
|
|
| 154 |
self.hidden_size = self.transformer.config.hidden_size
|
| 155 |
self.num_attention_heads = self.transformer.config.num_attention_heads
|
| 156 |
self.dim_head = self.hidden_size // self.num_attention_heads
|
|
|
|
| 139 |
|
| 140 |
def __init__(self, config):
|
| 141 |
super().__init__(config)
|
| 142 |
+
self.create_transformer(revision=config.upstream_transformer_revision)
|
| 143 |
self.attn_fn = SelfAttention(
|
| 144 |
config.use_memory_efficient_attention,
|
| 145 |
config.q_bucket_size,
|
|
|
|
| 147 |
)
|
| 148 |
self.linear = nn.Linear(self.hidden_size, config.embedding_size)
|
| 149 |
|
| 150 |
+
def create_transformer(self, revision=None):
|
| 151 |
"""Creates the Transformer backbone.
|
| 152 |
"""
|
| 153 |
+
kwargs = {"revision": revision} if revision else {}
|
| 154 |
+
self.transformer = AutoModel.from_pretrained("sentence-transformers/paraphrase-distilroberta-base-v1", **kwargs)
|
| 155 |
self.hidden_size = self.transformer.config.hidden_size
|
| 156 |
self.num_attention_heads = self.transformer.config.num_attention_heads
|
| 157 |
self.dim_head = self.hidden_size // self.num_attention_heads
|