ananttripathiak's picture
Upload folder using huggingface_hub
1aa7fae verified
"""
Inference utilities for the predictive maintenance model.
These functions are used both by scripts and by the Streamlit app
to load a trained model (from local storage or Hugging Face) and
generate predictions from raw sensor values.
"""
from __future__ import annotations
from typing import Dict, Iterable, Optional
import joblib
import numpy as np
import pandas as pd
import config
from hf_model_utils import download_model
def load_local_model() -> object:
"""
Load the best trained model from the local models directory.
"""
if not config.BEST_MODEL_LOCAL_PATH.exists():
raise FileNotFoundError(
f"Local model not found at {config.BEST_MODEL_LOCAL_PATH}. "
"Run train.py to create it, or configure HF model loading."
)
return joblib.load(config.BEST_MODEL_LOCAL_PATH)
def load_hf_model() -> object:
"""
Load the model directly from the Hugging Face model hub.
"""
return download_model()
def build_input_dataframe(
inputs: Dict[str, float],
) -> pd.DataFrame:
"""
Convert a dictionary of feature values into a single-row DataFrame
with columns ordered according to config.FEATURE_COLUMNS.
"""
data = {col: float(inputs.get(col, 0.0)) for col in config.FEATURE_COLUMNS}
return pd.DataFrame([data])
def predict_engine_condition(
inputs: Dict[str, float],
model: Optional[object] = None,
source: str = "local",
) -> Dict[str, float]:
"""
Predict whether the engine requires maintenance.
Parameters
----------
inputs : dict
Keys correspond to feature names in config.FEATURE_COLUMNS.
model : object, optional
Pre-loaded sklearn Pipeline model. If None, it will be loaded
from `source`.
source : {'local', 'hf'}
If model is None, determines where to load it from.
Returns
-------
dict
Contains the predicted class label (0/1) and the probability
of the positive class.
"""
if model is None:
if source == "hf":
model = load_hf_model()
else:
model = load_local_model()
df = build_input_dataframe(inputs)
proba = model.predict_proba(df)[0, 1]
pred = int(proba >= 0.5)
return {
"prediction": pred,
"probability_faulty": float(proba),
}
__all__ = [
"load_local_model",
"load_hf_model",
"build_input_dataframe",
"predict_engine_condition",
]