Update embedding_model.py
Browse files- embedding_model.py +2 -2
embedding_model.py
CHANGED
|
@@ -57,8 +57,8 @@ class EmbeddingModel(models.Transformer):
|
|
| 57 |
return {"model_name_or_path": self.model_name_or_path}
|
| 58 |
|
| 59 |
def save(self, save_dir: str, **kwargs) -> None:
|
| 60 |
-
|
| 61 |
-
|
| 62 |
|
| 63 |
with open(os.path.join(save_dir, "sentence_bert_config.json"), "w+") as f:
|
| 64 |
json.dump(self.get_config_dict(), f, indent=4)
|
|
|
|
| 57 |
return {"model_name_or_path": self.model_name_or_path}
|
| 58 |
|
| 59 |
def save(self, save_dir: str, **kwargs) -> None:
|
| 60 |
+
self.auto_model.save_pretrained(save_dir, safe_serialization=True)
|
| 61 |
+
self.tokenizer.save_pretrained(save_dir)
|
| 62 |
|
| 63 |
with open(os.path.join(save_dir, "sentence_bert_config.json"), "w+") as f:
|
| 64 |
json.dump(self.get_config_dict(), f, indent=4)
|