# fintech_recon_classifier / inference_helper.py
# Uploaded by IITRohit: "Upload trained reconciliation model artifacts"
# (commit 4679a4a, verified).
import json
from pathlib import Path

import joblib
import pandas as pd
def load_model(model_dir):
    """Load the trained classifier and its feature schema from ``model_dir``.

    Args:
        model_dir: Directory (``str`` or ``os.PathLike``) that contains
            ``best_model.joblib`` and ``feature_schema.json``.

    Returns:
        A ``(model, schema)`` tuple: the object deserialized by
        ``joblib.load`` and the parsed contents of the JSON schema file.

    Raises:
        FileNotFoundError: If either artifact is missing.
        json.JSONDecodeError: If the schema file is not valid JSON.
    """
    # Path joining instead of f"{model_dir}/..." so that an empty model_dir
    # means the current directory rather than the filesystem root.
    model_dir = Path(model_dir)
    # SECURITY: joblib.load unpickles the file and can execute arbitrary code.
    # Only load artifacts that come from a trusted source.
    model = joblib.load(model_dir / "best_model.joblib")
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale.
    with open(model_dir / "feature_schema.json", encoding="utf-8") as f:
        schema = json.load(f)
    return model, schema
def predict(df: pd.DataFrame, model, schema, *, threshold: float = 0.5) -> pd.DataFrame:
    """Score ``df`` with ``model`` and return a copy with prediction columns.

    Args:
        df: Input rows. Must contain every column listed in
            ``schema['feature_cols']``. It is not modified.
        model: Fitted estimator exposing ``predict_proba(X)`` that returns a
            2-D array. Column 1 is taken as the positive-class probability.
            This assumes a binary classifier whose second class is the
            positive label (TODO confirm against ``model.classes_``).
        schema: Mapping whose ``'feature_cols'`` entry lists the model's
            input columns.
        threshold: Probability at or above which a row is labelled ``1``.
            Defaults to ``0.5``, the previously hard-coded cut-off.

    Returns:
        A copy of ``df`` with two added columns: ``pred_label`` (0/1 ints)
        and ``pred_probability`` (the probability used for the label).
        Existing columns with those names are overwritten.

    Raises:
        KeyError: If ``df`` lacks any of the schema's feature columns.
    """
    # list() so a tuple of names selects columns instead of one tuple key.
    feature_cols = list(schema["feature_cols"])
    missing = [col for col in feature_cols if col not in df.columns]
    if missing:
        # Same exception type pandas would raise, but naming every absent column.
        raise KeyError(f"input DataFrame is missing feature columns: {missing}")

    # Column selection already produces a new frame and X is never mutated,
    # so no defensive copy is needed here.
    X = df[feature_cols]
    if len(X) == 0:
        # scikit-learn estimators raise ValueError on zero-sample input;
        # let an empty batch produce an empty (but correctly typed) result.
        proba = pd.Series([], dtype="float64").to_numpy()
    else:
        proba = model.predict_proba(X)[:, 1]
    pred = (proba >= threshold).astype(int)

    out = df.copy()
    out["pred_label"] = pred
    out["pred_probability"] = proba
    return out