Pachinee commited on
Commit
da10199
·
verified ·
1 Parent(s): bfab49a

Update utils/predict.py

Browse files
Files changed (1) hide show
  1. utils/predict.py +6 -13
utils/predict.py CHANGED
@@ -1,24 +1,17 @@
1
- from pathlib import Path
2
  import joblib
3
  from sentence_transformers import SentenceTransformer
4
 
5
- BASE_DIR = Path(__file__).resolve().parents[1]
6
- MODEL_DIR = BASE_DIR / "model"
7
-
8
  def load_model():
9
- logistic_model = joblib.load(MODEL_DIR / "logistic_model.pkl")
10
 
11
  s2v_model = SentenceTransformer(
12
- "Pachinee/sentence2vec-brd"
13
  )
14
 
15
- return logistic_model, s2v_model
16
 
17
 
18
- def predict_label(texts, logistic_model, s2v_model):
19
- embeddings = s2v_model.encode(
20
- list(texts),
21
- convert_to_numpy=True
22
- )
23
- preds = logistic_model.predict(embeddings)
24
  return ["Clear" if p == 1 else "Unclear" for p in preds]
 
 
1
  import joblib
2
  from sentence_transformers import SentenceTransformer
3
 
 
 
 
4
  def load_model():
5
+ clf = joblib.load("model/logistic_model.pkl")
6
 
7
  s2v_model = SentenceTransformer(
8
+ "Pachinee/sentence2vec-brd" # ← Hugging Face Model
9
  )
10
 
11
+ return clf, s2v_model
12
 
13
 
14
+ def predict_label(texts, clf, s2v_model):
15
+ embeddings = s2v_model.encode(list(texts))
16
+ preds = clf.predict(embeddings)
 
 
 
17
  return ["Clear" if p == 1 else "Unclear" for p in preds]