Update embedding_model.py
Browse files- embedding_model.py +4 -0
embedding_model.py
CHANGED
|
@@ -53,6 +53,10 @@ class EmbeddingModel(models.Transformer):
|
|
| 53 |
|
| 54 |
return output
|
| 55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
def get_config_dict(self) -> dict[str, str]:
|
| 57 |
return {"model_name_or_path": self.model_name_or_path}
|
| 58 |
|
|
|
|
| 53 |
|
| 54 |
return output
|
| 55 |
|
| 56 |
+
def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_args) -> None:
|
| 57 |
+
"""Loads the transformer model"""
|
| 58 |
+
return self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=config, cache_dir=cache_dir, trust_remote_code=True, **model_args)
|
| 59 |
+
|
| 60 |
def get_config_dict(self) -> dict[str, str]:
|
| 61 |
return {"model_name_or_path": self.model_name_or_path}
|
| 62 |
|