Pachinee commited on
Commit
109d1d4
·
verified ·
1 Parent(s): 4826eda

Create predict.py

Browse files
Files changed (1) hide show
  1. utils/predict.py +24 -0
utils/predict.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import joblib
3
+ from sentence_transformers import SentenceTransformer
4
+
5
+ BASE_DIR = Path(__file__).resolve().parents[1]
6
+ MODEL_DIR = BASE_DIR / "model"
7
+
8
+ def load_model():
9
+ logistic_model = joblib.load(MODEL_DIR / "logistic_model.pkl")
10
+
11
+ s2v_model = SentenceTransformer(
12
+ "Pachinee/sentence2vec-brd"
13
+ )
14
+
15
+ return logistic_model, s2v_model
16
+
17
+
18
+ def predict_label(texts, logistic_model, s2v_model):
19
+ embeddings = s2v_model.encode(
20
+ list(texts),
21
+ convert_to_numpy=True
22
+ )
23
+ preds = logistic_model.predict(embeddings)
24
+ return ["Clear" if p == 1 else "Unclear" for p in preds]