Pachinee commited on
Commit
7844a9c
·
verified ·
1 Parent(s): 4224fa4

Create predict.py

Browse files
Files changed (1) hide show
  1. utils/predict.py +33 -0
utils/predict.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import joblib
3
+ from sentence_transformers import SentenceTransformer
4
+
5
+ # ---------------------------
6
+ # Path setup
7
+ # ---------------------------
8
+ BASE_DIR = Path(__file__).resolve().parents[1]
9
+ MODEL_DIR = BASE_DIR / "model"
10
+
11
+ # ---------------------------
12
+ # Load models
13
+ # ---------------------------
14
+ def load_model():
15
+ logistic_model = joblib.load(MODEL_DIR / "logistic_model.pkl")
16
+
17
+ s2v_model = SentenceTransformer(
18
+ "Pachinee/sentence2vec-brd"
19
+ )
20
+
21
+ return logistic_model, s2v_model
22
+
23
+
24
+ # ---------------------------
25
+ # Predict
26
+ # ---------------------------
27
+ def predict_label(texts, logistic_model, s2v_model):
28
+ embeddings = s2v_model.encode(
29
+ list(texts),
30
+ convert_to_numpy=True
31
+ )
32
+ preds = logistic_model.predict(embeddings)
33
+ return ["Clear" if p == 1 else "Unclear" for p in preds]