# Martinacap02 — Init deploy branch for HF Space (commit f7d11f7)
from __future__ import annotations
import time
import joblib
from loguru import logger
import numpy as np
import pandas as pd
from predicting_outcomes_in_heart_failure.app.schema import HeartSample
from predicting_outcomes_in_heart_failure.config import (
FIGURES_DIR,
INPUT_COLUMNS,
MODEL_PATH,
MULTI_CAT,
NUM_COLS_DEFAULT,
SCALER_PATH,
)
from predicting_outcomes_in_heart_failure.modeling.explainability import (
explain_prediction,
save_shap_waterfall_plot,
)
def preprocessing(sample_df: pd.DataFrame) -> pd.DataFrame:
    """
    Apply the same preprocessing used during training to a raw sample.

    Steps:
      1. Drop or binary-encode 'Sex' depending on the model variant's inputs.
      2. Binary-encode 'ExerciseAngina' (Y=1, N=0).
      3. One-hot encode the multi-category columns present in the sample.
      4. Align columns to the training feature order (missing levels -> 0).
      5. Scale numeric columns with the persisted training scaler.

    Args:
        sample_df: Raw input features (one or more rows). Never mutated.

    Returns:
        A new DataFrame aligned and scaled, ready for ``model.predict``.

    Raises:
        FileNotFoundError: If the scaler or model artifact is missing.
    """
    logger.info("Applying preprocessing pipeline for inference...")
    if not (SCALER_PATH.exists() and MODEL_PATH.exists()):
        raise FileNotFoundError("Preprocessing artifacts missing.")
    scaler = joblib.load(SCALER_PATH)
    input_columns = INPUT_COLUMNS
    multi_cat = MULTI_CAT
    num_cols = NUM_COLS_DEFAULT
    logger.debug(f"Loaded scaler from {SCALER_PATH}")
    logger.debug(f"Using {len(input_columns)} input columns")

    # Work on a copy so column assignments below never mutate the caller's frame.
    sample_df = sample_df.copy()

    if "Sex" in sample_df.columns and "Sex" not in input_columns:
        logger.debug("Dropping column 'Sex' since it's not used by the current model variant.")
        sample_df = sample_df.drop(columns=["Sex"])
    if "Sex" in sample_df.columns and "Sex" in input_columns:
        sample_df["Sex"] = sample_df["Sex"].map({"M": 1, "F": 0}).astype(int)
        logger.debug("Mapped 'Sex' to binary values (M=1, F=0).")
    if "ExerciseAngina" in sample_df.columns and "ExerciseAngina" in input_columns:
        sample_df["ExerciseAngina"] = sample_df["ExerciseAngina"].map({"Y": 1, "N": 0}).astype(int)
        logger.debug("Mapped 'ExerciseAngina' to binary values (Y=1, N=0).")

    present_multi = [c for c in multi_cat if c in sample_df.columns]
    if present_multi:
        logger.debug(f"Performing one-hot encoding on: {present_multi}")
        sample_df = pd.get_dummies(sample_df, columns=present_multi, drop_first=False)

    # reindex both orders the columns and fills any one-hot levels absent from
    # this sample with 0, so no explicit missing-column loop is needed.
    sample_df = sample_df.reindex(columns=input_columns, fill_value=0)
    logger.debug("Aligned input columns with training feature order.")

    cols_to_scale = [c for c in num_cols if c in sample_df.columns]
    if cols_to_scale:  # guard: transform on an empty selection would raise
        sample_df[cols_to_scale] = scaler.transform(sample_df[cols_to_scale])
        logger.debug(f"Scaled numerical columns: {cols_to_scale}")
    logger.success("Preprocessing completed successfully.")
    return sample_df
def main():
    """Run one static inference on a hard-coded heart-failure sample.

    Loads the persisted model once, predicts, then best-effort computes
    feature explanations and a SHAP waterfall plot (failures are logged,
    not raised).

    Returns:
        dict with keys:
          - "prediction": model output for the sample (plain int when integral)
          - "inference_time_seconds": wall-clock prediction time
          - "explanations" (optional): top-5 feature attributions
          - "explanation_plot" (optional): path to the saved SHAP plot

    Raises:
        FileNotFoundError: If the model artifact is missing.
    """
    logger.info("Starting static inference...")
    sample = HeartSample(
        Age=54,
        ChestPainType="ASY",
        RestingBP=140,
        Cholesterol=239,
        FastingBS=0,
        RestingECG="Normal",
        MaxHR=160,
        ExerciseAngina="N",
        Oldpeak=0.0,
        ST_Slope="Up",
    )
    logger.info("Sample created successfully.")
    X_raw = sample.to_dataframe()
    logger.debug(f"Raw input features:\n{X_raw}")
    X = preprocessing(X_raw)

    if not MODEL_PATH.exists():
        raise FileNotFoundError(f"Model not found: {MODEL_PATH}")
    model = joblib.load(MODEL_PATH)
    logger.success(f"Loaded model from {MODEL_PATH}")

    # Perform prediction
    t0 = time.perf_counter()
    y_pred = model.predict(X)[0]
    inference_time = time.perf_counter() - t0
    # Normalize numpy integer scalars to plain int for a JSON-friendly result.
    y_pred = int(y_pred) if np.issubdtype(type(y_pred), np.integer) else y_pred
    result = {
        "prediction": y_pred,
        "inference_time_seconds": inference_time,
    }

    # Explainability — reuse the already-loaded model instead of reloading it.
    model_type = MODEL_PATH.stem
    try:
        logger.info("Computing explanation for the prediction...")
        explanations = explain_prediction(model, X, model_type=model_type, top_k=5)
        result["explanations"] = explanations
        logger.success("Explanation computed successfully.")
    except Exception as e:
        logger.error(f"Failed to compute explanation: {e}")
    try:
        shap_path = FIGURES_DIR / f"shap_waterfall_{model_type}.png"
        saved = save_shap_waterfall_plot(model, X, model_type=model_type, output_path=shap_path)
        if saved is not None:
            result["explanation_plot"] = str(saved)
    except Exception as e:
        logger.error(f"Failed to generate SHAP waterfall plot: {e}")

    logger.info("Inference completed.")
    logger.success(f"Prediction result: {result}")
    return result


if __name__ == "__main__":
    main()