Update embedding_model.py
Browse files- embedding_model.py +4 -13
embedding_model.py
CHANGED
|
@@ -6,11 +6,11 @@ import torch
|
|
| 6 |
from sentence_transformers import models
|
| 7 |
from transformers import AutoModel
|
| 8 |
|
| 9 |
-
MODEL_PATH_IN_MODULE = "0_Model"
|
| 10 |
-
|
| 11 |
|
| 12 |
class EmbeddingModel(models.Transformer):
|
| 13 |
def __init__(self, *args, **kwargs):
    """Initialize by forwarding every positional and keyword argument
    unchanged to the parent ``models.Transformer`` constructor."""
    # Explicit two-argument super() — behaviorally identical to the
    # zero-argument form in this single-inheritance class.
    super(EmbeddingModel, self).__init__(*args, **kwargs)
|
| 15 |
|
| 16 |
def tokenize(
|
|
@@ -63,15 +63,6 @@ class EmbeddingModel(models.Transformer):
|
|
| 63 |
with open(os.path.join(save_dir, "sentence_bert_config.json"), "w+") as f:
|
| 64 |
json.dump(self.get_config_dict(), f, indent=4)
|
| 65 |
|
| 66 |
-
model_path = os.path.join(save_dir, MODEL_PATH_IN_MODULE)
|
| 67 |
-
if not os.path.exists(model_path):
|
| 68 |
-
os.makedirs(model_path)
|
| 69 |
-
|
| 70 |
-
with open(f"{model_path}/config.json", "w+") as f:
|
| 71 |
-
json.dump(self.get_config_dict(), f, indent=4)
|
| 72 |
-
|
| 73 |
@staticmethod
|
| 74 |
-
def load(
|
| 75 |
-
|
| 76 |
-
config = json.load(fIn)
|
| 77 |
-
return EmbeddingModel(**config, **kwargs)
|
|
|
|
| 6 |
from sentence_transformers import models
|
| 7 |
from transformers import AutoModel
|
| 8 |
|
|
|
|
|
|
|
| 9 |
|
| 10 |
class EmbeddingModel(models.Transformer):
|
| 11 |
def __init__(self, *args, **kwargs):
    """Construct the embedding wrapper with a pinned checkpoint id.

    NOTE(review): ``self.model_name_or_path`` is assigned *before* the
    parent initializer runs, and any caller-supplied
    ``model_name_or_path`` keyword is silently discarded — presumably a
    hook in the parent/subclass reads ``self.model_name_or_path``
    instead; confirm against ``models.Transformer.__init__``.
    """
    # Hard-coded checkpoint identifier this wrapper always uses.
    self.model_name_or_path = "lamarr-llm-development/elbedding"
    # Drop the caller's path argument so the pinned one wins
    # (equivalent to kwargs.pop("model_name_or_path", None)).
    if "model_name_or_path" in kwargs:
        del kwargs["model_name_or_path"]
    super().__init__(*args, **kwargs)
|
| 15 |
|
| 16 |
def tokenize(
|
|
|
|
| 63 |
with open(os.path.join(save_dir, "sentence_bert_config.json"), "w+") as f:
|
| 64 |
json.dump(self.get_config_dict(), f, indent=4)
|
| 65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
@staticmethod
def load(**kwargs) -> "EmbeddingModel":
    """Factory helper: build an ``EmbeddingModel`` directly from the
    given keyword arguments and return it."""
    instance = EmbeddingModel(**kwargs)
    return instance
|
|
|
|
|
|