Update embedding_model.py
Browse files- embedding_model.py +6 -6
embedding_model.py
CHANGED
|
@@ -52,11 +52,7 @@ class EmbeddingModel(models.Transformer):
|
|
| 52 |
# this is specific to OpenGPT-X model
|
| 53 |
output.pop("token_type_ids", None)
|
| 54 |
|
| 55 |
-
return output
|
| 56 |
-
|
| 57 |
-
def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_args) -> None:
|
| 58 |
-
"""Loads the transformer model"""
|
| 59 |
-
self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=config, cache_dir=cache_dir, trust_remote_code=True, **model_args)
|
| 60 |
|
| 61 |
def get_config_dict(self) -> dict[str, str]:
|
| 62 |
return {"model_name_or_path": self.model_name_or_path}
|
|
@@ -79,4 +75,8 @@ class EmbeddingModel(models.Transformer):
|
|
| 79 |
def load(load_dir: str, **kwargs) -> "EmbeddingModel":
|
| 80 |
with open(os.path.join(load_dir, "config.json")) as fIn:
|
| 81 |
config = json.load(fIn)
|
| 82 |
-
return EmbeddingModel(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
# this is specific to OpenGPT-X model
|
| 53 |
output.pop("token_type_ids", None)
|
| 54 |
|
| 55 |
+
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
def get_config_dict(self) -> dict[str, str]:
|
| 58 |
return {"model_name_or_path": self.model_name_or_path}
|
|
|
|
| 75 |
def load(load_dir: str, **kwargs) -> "EmbeddingModel":
|
| 76 |
with open(os.path.join(load_dir, "config.json")) as fIn:
|
| 77 |
config = json.load(fIn)
|
| 78 |
+
return EmbeddingModel(
|
| 79 |
+
**config,
|
| 80 |
+
model_args={"trust_remote_code": True},
|
| 81 |
+
tokenizer_args={"trust_remote_code": True}
|
| 82 |
+
)
|