vijusudhi commited on
Commit
1835a4f
·
verified ·
1 Parent(s): baa7748

Update embedding_model.py

Browse files
Files changed (1) hide show
  1. embedding_model.py +6 -6
embedding_model.py CHANGED
@@ -52,11 +52,7 @@ class EmbeddingModel(models.Transformer):
52
  # this is specific to OpenGPT-X model
53
  output.pop("token_type_ids", None)
54
 
55
- return output
56
-
57
- def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_args) -> None:
58
- """Loads the transformer model"""
59
- self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=config, cache_dir=cache_dir, trust_remote_code=True, **model_args)
60
 
61
  def get_config_dict(self) -> dict[str, str]:
62
  return {"model_name_or_path": self.model_name_or_path}
@@ -79,4 +75,8 @@ class EmbeddingModel(models.Transformer):
79
  def load(load_dir: str, **kwargs) -> "EmbeddingModel":
80
  with open(os.path.join(load_dir, "config.json")) as fIn:
81
  config = json.load(fIn)
82
- return EmbeddingModel(**config)
 
 
 
 
 
52
  # this is specific to OpenGPT-X model
53
  output.pop("token_type_ids", None)
54
 
55
+ return output
 
 
 
 
56
 
57
  def get_config_dict(self) -> dict[str, str]:
58
  return {"model_name_or_path": self.model_name_or_path}
 
75
  def load(load_dir: str, **kwargs) -> "EmbeddingModel":
76
  with open(os.path.join(load_dir, "config.json")) as fIn:
77
  config = json.load(fIn)
78
+ return EmbeddingModel(
79
+ **config,
80
+ model_args={"trust_remote_code": True},
81
+ tokenizer_args={"trust_remote_code": True}
82
+ )