Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import time | |
| import joblib | |
| from loguru import logger | |
| import numpy as np | |
| import pandas as pd | |
| from predicting_outcomes_in_heart_failure.app.schema import HeartSample | |
| from predicting_outcomes_in_heart_failure.config import ( | |
| FIGURES_DIR, | |
| INPUT_COLUMNS, | |
| MODEL_PATH, | |
| MULTI_CAT, | |
| NUM_COLS_DEFAULT, | |
| SCALER_PATH, | |
| ) | |
| from predicting_outcomes_in_heart_failure.modeling.explainability import ( | |
| explain_prediction, | |
| save_shap_waterfall_plot, | |
| ) | |
def _encode_binary(series: pd.Series, mapping: dict[str, int]) -> pd.Series:
    """Map a two-category column to integers, failing loudly on unmapped values.

    Args:
        series: Raw categorical values.
        mapping: Category -> integer code, e.g. ``{"M": 1, "F": 0}``.

    Returns:
        An ``int`` Series aligned with ``series``.

    Raises:
        ValueError: If any value (including a missing one) is not in ``mapping``.
    """
    encoded = series.map(mapping)
    unknown = series[encoded.isna()]
    if not unknown.empty:
        raise ValueError(
            f"Unexpected value(s) in column '{series.name}': "
            f"{sorted(unknown.astype(str).unique())}; expected one of {sorted(mapping)}."
        )
    return encoded.astype(int)


def preprocessing(sample_df: pd.DataFrame) -> pd.DataFrame:
    """
    Apply the same preprocessing used during training to raw inference input.

    Steps, in order:
      1. Drop ``Sex`` if it is present but not one of the model's input columns.
      2. Encode the binary categoricals ``Sex`` (M=1, F=0) and
         ``ExerciseAngina`` (Y=1, N=0) when they are model inputs.
      3. One-hot encode the multi-category columns listed in ``MULTI_CAT``.
      4. Reindex to ``INPUT_COLUMNS`` (training feature order): columns the
         sample lacks are filled with 0, columns not in the list are dropped.
      5. Scale the ``NUM_COLS_DEFAULT`` columns with the scaler loaded from
         ``SCALER_PATH``.

    Args:
        sample_df: Raw feature rows, e.g. from ``HeartSample.to_dataframe()``.
            It is not modified; a processed copy is returned.

    Returns:
        A DataFrame whose columns are exactly ``INPUT_COLUMNS``, in order.

    Raises:
        FileNotFoundError: If the scaler or the model artifact is missing.
        ValueError: If a binary column holds a value outside its mapping.
    """
    logger.info("Applying preprocessing pipeline for inference...")

    # Both artifacts are checked up front so a missing model is reported
    # before any preprocessing work is done.
    missing = [str(path) for path in (SCALER_PATH, MODEL_PATH) if not path.exists()]
    if missing:
        raise FileNotFoundError(f"Preprocessing artifacts missing: {', '.join(missing)}")

    scaler = joblib.load(SCALER_PATH)
    input_columns = INPUT_COLUMNS
    logger.debug(f"Loaded scaler from {SCALER_PATH}")
    logger.debug(f"Using {len(input_columns)} input columns")

    # Work on a copy: the column assignments below would otherwise write into
    # the caller's DataFrame.
    df = sample_df.copy()

    if "Sex" in df.columns and "Sex" not in input_columns:
        logger.debug("Dropping column 'Sex' since it's not used by the current model variant.")
        df = df.drop(columns=["Sex"])

    if "Sex" in df.columns and "Sex" in input_columns:
        df["Sex"] = _encode_binary(df["Sex"], {"M": 1, "F": 0})
        logger.debug("Mapped 'Sex' to binary values (M=1, F=0).")

    if "ExerciseAngina" in df.columns and "ExerciseAngina" in input_columns:
        df["ExerciseAngina"] = _encode_binary(df["ExerciseAngina"], {"Y": 1, "N": 0})
        logger.debug("Mapped 'ExerciseAngina' to binary values (Y=1, N=0).")

    present_multi = [c for c in MULTI_CAT if c in df.columns]
    if present_multi:
        logger.debug(f"Performing one-hot encoding on: {present_multi}")
        df = pd.get_dummies(df, columns=present_multi, drop_first=False)

    # A single sample only yields dummy columns for the categories it contains;
    # reindex adds every other expected column as 0 and enforces the order.
    df = df.reindex(columns=input_columns, fill_value=0)
    logger.debug("Aligned input columns with training feature order.")

    cols_to_scale = [c for c in NUM_COLS_DEFAULT if c in df.columns]
    df[cols_to_scale] = scaler.transform(df[cols_to_scale])
    logger.debug(f"Scaled numerical columns: {cols_to_scale}")

    logger.success("Preprocessing completed successfully.")
    return df
def main() -> dict:
    """
    Run a static end-to-end inference on a hard-coded example patient.

    Builds a ``HeartSample``, preprocesses it, and predicts with the model
    stored at ``MODEL_PATH``. On a best-effort basis it also attaches a top-5
    explanation and a SHAP waterfall plot saved under ``FIGURES_DIR``.

    Returns:
        A dict with:
          - ``prediction``: the prediction as a plain Python scalar.
          - ``inference_time_seconds``: the time taken by the predict call.
        It may also contain:
          - ``explanations``: added only when that step succeeds.
          - ``explanation_plot``: added only when the plot helper succeeds
            and returns a path.

    Raises:
        FileNotFoundError: If the scaler or the model artifact is missing.
    """
    logger.info("Starting static inference...")
    sample = HeartSample(
        Age=54,
        ChestPainType="ASY",
        RestingBP=140,
        Cholesterol=239,
        FastingBS=0,
        RestingECG="Normal",
        MaxHR=160,
        ExerciseAngina="N",
        Oldpeak=0.0,
        ST_Slope="Up",
    )
    logger.info("Sample created successfully.")

    X_raw = sample.to_dataframe()
    logger.debug(f"Raw input features:\n{X_raw}")
    X = preprocessing(X_raw)

    if not MODEL_PATH.exists():
        raise FileNotFoundError(f"Model not found: {MODEL_PATH}")
    # Load once; the same estimator is reused for prediction and explainability.
    model = joblib.load(MODEL_PATH)
    logger.success(f"Loaded model from {MODEL_PATH}")

    # Time only the predict call, not loading or preprocessing.
    t0 = time.perf_counter()
    y_pred = model.predict(X)[0]
    inference_time = time.perf_counter() - t0

    # Convert any NumPy scalar (np.int64, np.float64, np.bool_, np.str_, ...)
    # to its native Python equivalent so the result holds plain values.
    if isinstance(y_pred, np.generic):
        y_pred = y_pred.item()

    result = {
        "prediction": y_pred,
        "inference_time_seconds": inference_time,
    }

    # Explainability is best-effort: failures are logged with a traceback and
    # the prediction is still returned.
    model_type = MODEL_PATH.stem
    try:
        logger.info("Computing explanation for the prediction...")
        result["explanations"] = explain_prediction(model, X, model_type=model_type, top_k=5)
        logger.success("Explanation computed successfully.")
    except Exception as e:
        logger.exception(f"Failed to compute explanation: {e}")

    try:
        # Make sure the output directory exists before the plot is written.
        FIGURES_DIR.mkdir(parents=True, exist_ok=True)
        shap_path = FIGURES_DIR / f"shap_waterfall_{model_type}.png"
        saved = save_shap_waterfall_plot(model, X, model_type=model_type, output_path=shap_path)
        if saved is not None:
            result["explanation_plot"] = str(saved)
    except Exception as e:
        logger.exception(f"Failed to generate SHAP waterfall plot: {e}")

    logger.info("Inference completed.")
    logger.success(f"Prediction result: {result}")
    return result


if __name__ == "__main__":
    main()