mindpulse / inference.py
AtharvaXX's picture
Upload inference.py with huggingface_hub
d207fbd verified
from __future__ import annotations

import json
import math
from pathlib import Path

import joblib
import pandas as pd
def _load_artifact(model_dir: str | Path):
model_path = Path(model_dir) / "mindpulse_rf.joblib"
raw = joblib.load(model_path)
if isinstance(raw, dict) and "model" in raw:
model = raw["model"]
feature_names = raw.get(
"feature_names",
["heart_rate", "hrv_rmssd", "motion_level", "hr_baseline", "rmssd_baseline"],
)
return model, feature_names
return raw, ["heart_rate", "hrv_rmssd", "motion_level", "hr_baseline", "rmssd_baseline"]
def predict(model_dir: str, inputs: dict):
model, feature_names = _load_artifact(model_dir)
row = {name: float(inputs[name]) for name in feature_names}
x = pd.DataFrame([row], columns=feature_names)
score = float(model.predict(x)[0])
score = max(0.0, min(1.0, score))
return {"engagement_score": score}