m97j committed on
Commit
54f8424
·
1 Parent(s): 0ce5a27

Initial commit

Browse files
Files changed (1) hide show
  1. models/model_loader.py +3 -3
models/model_loader.py CHANGED
@@ -8,7 +8,7 @@ from sentence_transformers import SentenceTransformer
8
 
9
 
10
  def load_emotion_model(model_name: str, model_dir: Path, token: str = None):
11
- if not model_dir.exists() or not any(model_dir.iterdir()):
12
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_auth_token=token)
13
  model = AutoModelForSequenceClassification.from_pretrained(model_name, trust_remote_code=True, use_auth_token=token)
14
  tokenizer.save_pretrained(model_dir)
@@ -20,7 +20,7 @@ def load_emotion_model(model_name: str, model_dir: Path, token: str = None):
20
 
21
 
22
  def load_fallback_model(model_name: str, model_dir: Path, token: str = None):
23
- if not model_dir.exists() or not any(model_dir.iterdir()):
24
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_auth_token=token)
25
  model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, use_auth_token=token)
26
  tokenizer.save_pretrained(model_dir)
@@ -32,7 +32,7 @@ def load_fallback_model(model_name: str, model_dir: Path, token: str = None):
32
 
33
 
34
  def load_embedder(model_name: str, model_dir: Path, token: str = None):
35
- if not model_dir.exists() or not any(model_dir.iterdir()):
36
  embedder = SentenceTransformer(model_name, use_auth_token=token)
37
  embedder.save(str(model_dir))
38
 
 
8
 
9
 
10
  def load_emotion_model(model_name: str, model_dir: Path, token: str = None):
11
+ if not (model_dir / "config.json").exists():
12
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_auth_token=token)
13
  model = AutoModelForSequenceClassification.from_pretrained(model_name, trust_remote_code=True, use_auth_token=token)
14
  tokenizer.save_pretrained(model_dir)
 
20
 
21
 
22
  def load_fallback_model(model_name: str, model_dir: Path, token: str = None):
23
+ if not (model_dir / "config.json").exists():
24
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_auth_token=token)
25
  model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, use_auth_token=token)
26
  tokenizer.save_pretrained(model_dir)
 
32
 
33
 
34
  def load_embedder(model_name: str, model_dir: Path, token: str = None):
35
+ if not (model_dir / "config.json").exists():
36
  embedder = SentenceTransformer(model_name, use_auth_token=token)
37
  embedder.save(str(model_dir))
38