Update utils/inference.py
Browse files- utils/inference.py +15 -7
utils/inference.py
CHANGED
|
@@ -1,22 +1,30 @@
|
|
| 1 |
import numpy as np
|
| 2 |
import pandas as pd
|
| 3 |
import joblib
|
|
|
|
| 4 |
from catboost import CatBoostClassifier
|
| 5 |
from xgboost import XGBClassifier
|
| 6 |
|
| 7 |
MODEL_DIR = "models"
|
|
|
|
| 8 |
|
| 9 |
-
model_s1 = CatBoostClassifier()
|
| 10 |
-
model_s1.load_model(f"{MODEL_DIR}/stage1_catboost.cbm")
|
| 11 |
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
| 14 |
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
-
THRESHOLD_S1 = 0.463
|
| 18 |
|
| 19 |
def run_inference(X_emb, meta, threshold):
|
|
|
|
|
|
|
| 20 |
df_meta = pd.DataFrame(meta)
|
| 21 |
|
| 22 |
probs_s1 = model_s1.predict_proba(X_emb)[:, 1]
|
|
@@ -31,4 +39,4 @@ def run_inference(X_emb, meta, threshold):
|
|
| 31 |
df_pass["Final_Score"] = final_score
|
| 32 |
df_pass["Final_Prediction"] = (final_score >= threshold).astype(int)
|
| 33 |
|
| 34 |
-
return df_pass
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
import pandas as pd
|
| 3 |
import joblib
|
| 4 |
+
import streamlit as st
|
| 5 |
from catboost import CatBoostClassifier
|
| 6 |
from xgboost import XGBClassifier
|
| 7 |
|
| 8 |
MODEL_DIR = "models"
|
| 9 |
+
THRESHOLD_S1 = 0.463
|
| 10 |
|
|
|
|
|
|
|
| 11 |
|
| 12 |
+
@st.cache_resource
|
| 13 |
+
def load_models():
|
| 14 |
+
model_s1 = CatBoostClassifier()
|
| 15 |
+
model_s1.load_model(f"{MODEL_DIR}/stage1_catboost.cbm")
|
| 16 |
|
| 17 |
+
model_s2 = XGBClassifier()
|
| 18 |
+
model_s2.load_model(f"{MODEL_DIR}/stage2_xgb.json")
|
| 19 |
+
|
| 20 |
+
encoder = joblib.load(f"{MODEL_DIR}/encoder_s2.pkl")
|
| 21 |
+
|
| 22 |
+
return model_s1, model_s2, encoder
|
| 23 |
|
|
|
|
| 24 |
|
| 25 |
def run_inference(X_emb, meta, threshold):
|
| 26 |
+
model_s1, model_s2, encoder = load_models()
|
| 27 |
+
|
| 28 |
df_meta = pd.DataFrame(meta)
|
| 29 |
|
| 30 |
probs_s1 = model_s1.predict_proba(X_emb)[:, 1]
|
|
|
|
| 39 |
df_pass["Final_Score"] = final_score
|
| 40 |
df_pass["Final_Prediction"] = (final_score >= threshold).astype(int)
|
| 41 |
|
| 42 |
+
return df_pass
|