File size: 1,864 Bytes
881d19a fcb06ad 881d19a fcb06ad 881d19a fcb06ad 881d19a fcb06ad 881d19a fcb06ad 881d19a fcb06ad 881d19a fcb06ad 881d19a fcb06ad 881d19a 11a55e4 881d19a 11a55e4 881d19a fcb06ad 881d19a fcb06ad 1c35d6c fcb06ad | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 | import json
import joblib
import pandas as pd
import yaml
from huggingface_hub import hf_hub_download
from src.preprocess import preprocess_input
def load_config(config_path: str = "config/config.yaml") -> dict:
    """Read the YAML configuration file and return its parsed contents.

    Args:
        config_path: Path to the YAML file. Defaults to
            ``config/config.yaml`` relative to the working directory.

    Returns:
        The parsed YAML document. An empty file yields an empty dict
        instead of ``None``.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        loaded = yaml.safe_load(f)
    # safe_load returns None for an empty document. Normalise it so the
    # declared ``dict`` return type holds and callers that index the result
    # get a KeyError naming the missing key rather than a TypeError on None.
    return {} if loaded is None else loaded
def load_model_and_info():
    """Download the model artifact and its metadata file, then load both.

    The Hugging Face repo id and the two filenames are read from the
    ``model`` section of the configuration returned by ``load_config()``.

    Returns:
        A ``(model, model_info)`` tuple: the object deserialised by
        ``joblib.load`` and the parsed JSON metadata.
    """
    model_cfg = load_config()["model"]
    repo_id = model_cfg["repo_id"]
    model_filename = model_cfg["filename"]
    info_filename = model_cfg["info_filename"]

    # Fetch the model file first, then the metadata file, from the same repo.
    model_path, info_path = [
        hf_hub_download(repo_id=repo_id, filename=name, repo_type="model")
        for name in (model_filename, info_filename)
    ]

    # NOTE(review): joblib.load unpickles the downloaded file, which can
    # execute arbitrary code -- the configured repo must be trusted.
    model = joblib.load(model_path)
    with open(info_path, "r", encoding="utf-8") as info_file:
        model_info = json.load(info_file)
    return model, model_info
def align_features_for_inference(input_df: pd.DataFrame, feature_columns: list[str]) -> pd.DataFrame:
    """One-hot encode *input_df* and align it to the given feature columns.

    Column labels are normalised (converted to ``str``, stripped,
    lower-cased, spaces replaced by underscores), the frame is expanded
    with ``pd.get_dummies``, and the result is reindexed to
    ``feature_columns``: columns not listed there are dropped and listed
    columns that are absent are added, filled with ``0``.

    Args:
        input_df: Input rows; not modified (a copy is made).
        feature_columns: Column names, in order, that the returned frame
            must have.

    Returns:
        A new DataFrame whose columns are exactly ``feature_columns``.
    """
    df = input_df.copy()
    # str(col) keeps non-string labels (e.g. integer column positions) from
    # raising AttributeError on .strip(); string labels are unaffected.
    df.columns = [str(col).strip().lower().replace(" ", "_") for col in df.columns]
    df = pd.get_dummies(df, drop_first=False)
    return df.reindex(columns=feature_columns, fill_value=0)
def predict_input(input_df: pd.DataFrame) -> dict:
    """Run the model on *input_df* and report the result for its first row.

    The model and its metadata are obtained from ``load_model_and_info()``,
    the input is passed through ``preprocess_input`` and then aligned to
    ``model_info["feature_columns"]`` before prediction.

    Args:
        input_df: Raw input rows. Only the first row's prediction is
            reported in the result.

    Returns:
        A dict with:
          * ``"prediction"``: the first element of ``model.predict``,
            converted to a native Python value when it exposes ``tolist``;
          * ``"processed_input"``: the first aligned row as a dict;
          * ``"probabilities"``: the first row of ``model.predict_proba``
            as a list, present only when the model has ``predict_proba``.
    """
    model, model_info = load_model_and_info()
    processed_df = preprocess_input(input_df)
    feature_columns = model_info["feature_columns"]
    aligned_df = align_features_for_inference(processed_df, feature_columns)

    first_prediction = model.predict(aligned_df)[0]
    # NumPy scalars/arrays are not JSON-serialisable; convert them the same
    # way the probabilities below are converted so the whole result dict is
    # made of native Python values.
    if hasattr(first_prediction, "tolist"):
        first_prediction = first_prediction.tolist()

    result = {
        "prediction": first_prediction,
        "processed_input": aligned_df.to_dict(orient="records")[0],
    }
    if hasattr(model, "predict_proba"):
        probabilities = model.predict_proba(aligned_df)
        result["probabilities"] = probabilities[0].tolist()
    return result
|