Integrate with Sentence Transformers v5.4

#13
by tomaarsen HF Staff - opened
Files changed (1) hide show
  1. sentence_transformers_impl.py +7 -3
sentence_transformers_impl.py CHANGED
@@ -54,17 +54,21 @@ class Transformer(nn.Module):
54
  if config_args is None:
55
  config_args = {}
56
 
 
 
 
 
 
57
  if not model_args.get("trust_remote_code", False):
58
  raise ValueError(
59
  "You need to set `trust_remote_code=True` to load this model."
60
  )
61
 
62
- self.config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
63
- self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=self.config, cache_dir=cache_dir, **model_args)
64
 
65
  self.tokenizer = AutoTokenizer.from_pretrained(
66
  model_name_or_path,
67
- cache_dir=cache_dir,
68
  **tokenizer_args,
69
  )
70
 
 
54
  if config_args is None:
55
  config_args = {}
56
 
57
+ if cache_dir is not None:
58
+ config_args["cache_dir"] = cache_dir
59
+ model_args["cache_dir"] = cache_dir
60
+ tokenizer_args["cache_dir"] = cache_dir
61
+
62
  if not model_args.get("trust_remote_code", False):
63
  raise ValueError(
64
  "You need to set `trust_remote_code=True` to load this model."
65
  )
66
 
67
+ self.config = AutoConfig.from_pretrained(model_name_or_path, **config_args)
68
+ self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=self.config, **model_args)
69
 
70
  self.tokenizer = AutoTokenizer.from_pretrained(
71
  model_name_or_path,
 
72
  **tokenizer_args,
73
  )
74