epitope / utils /inference.py
yunuk0's picture
Update utils/inference.py
4b88fce verified
import numpy as np
import pandas as pd
import joblib
import streamlit as st
from catboost import CatBoostClassifier
from xgboost import XGBClassifier
# Directory holding the serialized model artifacts read by load_models().
# It is a relative path, so it resolves against the process working directory.
MODEL_DIR = "models"
# Cutoff applied to the stage-1 positive-class probability in run_inference();
# rows scoring below it are dropped before stage 2 runs.
THRESHOLD_S1 = 0.463
@st.cache_resource
def load_models():
    """Load the stage-1 model, the stage-2 model and the stage-2 encoder.

    All three artifacts are read from ``MODEL_DIR``. The function is wrapped
    in ``st.cache_resource``, so Streamlit is expected to reuse the returned
    objects instead of reloading them on every call.

    Returns:
        A ``(stage1_model, stage2_model, encoder)`` tuple: a
        ``CatBoostClassifier``, an ``XGBClassifier`` and whatever object was
        stored in ``encoder_s2.pkl``.
    """
    # Stage 1: CatBoost classifier in CatBoost's native format.
    stage1 = CatBoostClassifier()
    stage1.load_model(f"{MODEL_DIR}/stage1_catboost.cbm")

    # Stage 2: XGBoost classifier stored as JSON.
    stage2 = XGBClassifier()
    stage2.load_model(f"{MODEL_DIR}/stage2_xgb.json")

    # NOTE: joblib deserializes via pickle; only load encoder files that come
    # from a trusted source.
    s2_encoder = joblib.load(f"{MODEL_DIR}/encoder_s2.pkl")

    return stage1, stage2, s2_encoder
def run_inference(X_emb, meta, threshold):
    """Run the two-stage classifier and return the rows that pass stage 1.

    Stage 1 scores every row of ``X_emb``; rows whose positive-class
    probability is at least ``THRESHOLD_S1`` are forwarded to stage 2. The
    final score is the fixed blend ``0.4 * stage1 + 0.6 * stage2`` of the two
    positive-class probabilities.

    Args:
        X_emb: Feature matrix passed to both models' ``predict_proba``. It
            must support boolean-mask row selection (``X_emb[mask]``).
        meta: Per-row metadata, anything ``pd.DataFrame`` accepts. It is
            assumed to have one row per row of ``X_emb``, in the same order.
        threshold: Cutoff on the blended final score; rows at or above it
            get ``Final_Prediction == 1``, the rest ``0``.

    Returns:
        A copy of the metadata rows that passed stage 1, with two extra
        columns: ``Final_Score`` (float) and ``Final_Prediction`` (int 0/1).
        If no row passes stage 1, the frame is empty but still carries both
        columns.
    """
    # load_models() also returns an encoder; it is not used by this function.
    model_s1, model_s2, _encoder = load_models()
    df_meta = pd.DataFrame(meta)

    # Column 1 of predict_proba is the positive-class probability.
    probs_s1 = model_s1.predict_proba(X_emb)[:, 1]
    mask = probs_s1 >= THRESHOLD_S1

    X_pass = X_emb[mask]
    df_pass = df_meta.loc[mask].copy()

    # Nothing survived stage 1: skip stage 2 rather than calling it on a
    # zero-row matrix, and return an empty frame with the expected columns.
    if not np.any(mask):
        df_pass["Final_Score"] = np.empty(0, dtype=float)
        df_pass["Final_Prediction"] = np.empty(0, dtype=int)
        return df_pass

    probs_s2 = model_s2.predict_proba(X_pass)[:, 1]

    # Fixed-weight blend of the two stages' positive-class probabilities.
    final_score = 0.4 * probs_s1[mask] + 0.6 * probs_s2

    df_pass["Final_Score"] = final_score
    df_pass["Final_Prediction"] = (final_score >= threshold).astype(int)
    return df_pass