Support specifying upstream sentence-transformers model revision
#2
by
ccmaymay
- opened
- config.json +1 -0
- config.py +5 -1
- model.py +5 -3
config.json
CHANGED
|
@@ -8,6 +8,7 @@
|
|
| 8 |
},
|
| 9 |
"embedding_size": 512,
|
| 10 |
"k_bucket_size": 1024,
|
|
|
|
| 11 |
"model_type": "LUAR",
|
| 12 |
"q_bucket_size": 512,
|
| 13 |
"torch_dtype": "float32",
|
|
|
|
| 8 |
},
|
| 9 |
"embedding_size": 512,
|
| 10 |
"k_bucket_size": 1024,
|
| 11 |
+
"upstream_transformer_revision": null,
|
| 12 |
"model_type": "LUAR",
|
| 13 |
"q_bucket_size": 512,
|
| 14 |
"torch_dtype": "float32",
|
config.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
| 1 |
|
|
|
|
|
|
|
| 2 |
from transformers import PretrainedConfig
|
| 3 |
|
| 4 |
class LUARConfig(PretrainedConfig):
|
|
@@ -9,10 +11,12 @@ class LUARConfig(PretrainedConfig):
|
|
| 9 |
use_memory_efficient_attention=False,
|
| 10 |
q_bucket_size=512,
|
| 11 |
k_bucket_size=1024,
|
|
|
|
| 12 |
**kwargs,
|
| 13 |
):
|
| 14 |
self.embedding_size = embedding_size
|
| 15 |
self.use_memory_efficient_attention = use_memory_efficient_attention
|
| 16 |
self.q_bucket_size = q_bucket_size
|
| 17 |
self.k_bucket_size = k_bucket_size
|
| 18 |
-
|
|
|
|
|
|
| 1 |
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
from transformers import PretrainedConfig
|
| 5 |
|
| 6 |
class LUARConfig(PretrainedConfig):
|
|
|
|
| 11 |
use_memory_efficient_attention=False,
|
| 12 |
q_bucket_size=512,
|
| 13 |
k_bucket_size=1024,
|
| 14 |
+
upstream_transformer_revision: Optional[str] = None,
|
| 15 |
**kwargs,
|
| 16 |
):
|
| 17 |
self.embedding_size = embedding_size
|
| 18 |
self.use_memory_efficient_attention = use_memory_efficient_attention
|
| 19 |
self.q_bucket_size = q_bucket_size
|
| 20 |
self.k_bucket_size = k_bucket_size
|
| 21 |
+
self.upstream_transformer_revision = upstream_transformer_revision
|
| 22 |
+
super().__init__(**kwargs)
|
model.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
|
| 2 |
import math
|
| 3 |
from functools import partial
|
|
|
|
| 4 |
|
| 5 |
import torch
|
| 6 |
import torch.nn as nn
|
|
@@ -139,7 +140,7 @@ class LUAR(PreTrainedModel):
|
|
| 139 |
|
| 140 |
def __init__(self, config):
|
| 141 |
super().__init__(config)
|
| 142 |
-
self.create_transformer()
|
| 143 |
self.attn_fn = SelfAttention(
|
| 144 |
config.use_memory_efficient_attention,
|
| 145 |
config.q_bucket_size,
|
|
@@ -147,10 +148,11 @@ class LUAR(PreTrainedModel):
|
|
| 147 |
)
|
| 148 |
self.linear = nn.Linear(self.hidden_size, config.embedding_size)
|
| 149 |
|
| 150 |
-
def create_transformer(self):
|
| 151 |
"""Creates the Transformer backbone.
|
| 152 |
"""
|
| 153 |
-
|
|
|
|
| 154 |
self.hidden_size = self.transformer.config.hidden_size
|
| 155 |
self.num_attention_heads = self.transformer.config.num_attention_heads
|
| 156 |
self.dim_head = self.hidden_size // self.num_attention_heads
|
|
|
|
| 1 |
|
| 2 |
import math
|
| 3 |
from functools import partial
|
| 4 |
+
from typing import Optional
|
| 5 |
|
| 6 |
import torch
|
| 7 |
import torch.nn as nn
|
|
|
|
| 140 |
|
| 141 |
def __init__(self, config):
|
| 142 |
super().__init__(config)
|
| 143 |
+
self.create_transformer(revision=config.upstream_transformer_revision)
|
| 144 |
self.attn_fn = SelfAttention(
|
| 145 |
config.use_memory_efficient_attention,
|
| 146 |
config.q_bucket_size,
|
|
|
|
| 148 |
)
|
| 149 |
self.linear = nn.Linear(self.hidden_size, config.embedding_size)
|
| 150 |
|
| 151 |
+
def create_transformer(self, revision: Optional[str] = None):
|
| 152 |
"""Creates the Transformer backbone.
|
| 153 |
"""
|
| 154 |
+
kwargs = {"revision": revision} if revision else {}
|
| 155 |
+
self.transformer = AutoModel.from_pretrained("sentence-transformers/paraphrase-distilroberta-base-v1", **kwargs)
|
| 156 |
self.hidden_size = self.transformer.config.hidden_size
|
| 157 |
self.num_attention_heads = self.transformer.config.num_attention_heads
|
| 158 |
self.dim_head = self.hidden_size // self.num_attention_heads
|