# predict_utils.py
# Robust loader with upfront patches + manual-unpickle fallback for
# sklearn/xgboost compatibility.
import os
import logging
import joblib
import io
import pickle
from huggingface_hub import hf_hub_download

# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Repo/filename are overridable via environment for deployment flexibility.
HF_MODEL_REPO = os.getenv("HF_MODEL_REPO", "sathishaiuse/wellness-classifier-model")
HF_MODEL_FILENAME = os.getenv("HF_MODEL_FILENAME", "best_overall_XGBoost.joblib")
HF_TOKEN = os.getenv("HF_TOKEN") or None

# Paths checked before falling back to a hub download (container layouts).
LOCAL_CANDIDATES = [
    os.path.join("/app", HF_MODEL_FILENAME),
    os.path.join("/tmp", HF_MODEL_FILENAME),
    os.path.join("/home/user/app", HF_MODEL_FILENAME),
    HF_MODEL_FILENAME,
]


# -------------------------
# Upfront compatibility patches (run at import time)
# -------------------------
def patch_sklearn_base():
    """Make sure BaseEstimator exposes sklearn_tags/_get_tags/_more_tags used during unpickling.

    Older pickles may call these during __setstate__; newer/older sklearn
    versions don't always provide all three, so stub in no-op defaults.
    Safe to call repeatedly; each patch is applied only if missing.
    """
    try:
        import sklearn  # noqa: F401  (import validates sklearn is present)
        from sklearn.base import BaseEstimator
    except Exception as e:
        logger.debug(f"sklearn not available to patch: {e}")
        return

    # Provide sklearn_tags method if missing
    if not hasattr(BaseEstimator, "sklearn_tags"):
        def _sklearn_tags(self):
            return {}
        try:
            setattr(BaseEstimator, "sklearn_tags", _sklearn_tags)
            logger.info("Patched BaseEstimator.sklearn_tags()")
        except Exception as e:
            logger.debug(f"Could not set BaseEstimator.sklearn_tags: {e}")

    # Provide _get_tags if missing; merges _more_tags() and sklearn_tags().
    if not hasattr(BaseEstimator, "_get_tags"):
        def _get_tags(self):
            tags = {}
            more = getattr(self, "_more_tags", None)
            if callable(more):
                try:
                    tags.update(more())
                except Exception:
                    pass
            st = getattr(self, "sklearn_tags", None)
            if callable(st):
                try:
                    tags.update(st())
                except Exception:
                    pass
            return tags
        try:
            setattr(BaseEstimator, "_get_tags", _get_tags)
            logger.info("Patched BaseEstimator._get_tags()")
        except Exception as e:
            logger.debug(f"Could not set BaseEstimator._get_tags: {e}")

    # Provide a default _more_tags if missing
    if not hasattr(BaseEstimator, "_more_tags"):
        def _more_tags(self):
            return {}
        try:
            setattr(BaseEstimator, "_more_tags", _more_tags)
            logger.info("Patched BaseEstimator._more_tags()")
        except Exception as e:
            logger.debug(f"Could not set BaseEstimator._more_tags: {e}")


def patch_xgboost_wrappers():
    """Add common attributes expected by older pickles to XGBoost classes/base.

    Class-level defaults act as fallbacks when an unpickled instance's
    __dict__ lacks the attribute (e.g. pickles created by older xgboost).
    Safe to call repeatedly.
    """
    try:
        import xgboost as xgb
    except Exception as e:
        logger.debug(f"xgboost not available to patch: {e}")
        return

    XGBModel = getattr(xgb, "XGBModel", None)
    if XGBModel is not None:
        for attr, val in {
            "gpu_id": None,
            "nthread": None,
            "n_jobs": None,
            "predictor": None,
            "base_score": None,
            "objective": None,
        }.items():
            try:
                if not hasattr(XGBModel, attr):
                    setattr(XGBModel, attr, val)
                    logger.info(f"Patched XGBModel.{attr} = {val!r}")
            except Exception as e:
                logger.debug(f"Could not patch XGBModel.{attr}: {e}")

    for cls_name in ("XGBClassifier", "XGBRegressor"):
        cls = getattr(xgb, cls_name, None)
        if cls is None:
            continue
        for attr, val in {
            "use_label_encoder": False,
            "objective": None,
            "predictor": None,
            "gpu_id": None,
            "n_jobs": None,
            "nthread": None,
        }.items():
            try:
                if not hasattr(cls, attr):
                    setattr(cls, attr, val)
                    logger.info(f"Patched {cls_name}.{attr} = {val!r}")
            except Exception as e:
                logger.debug(f"Could not patch {cls_name}.{attr}: {e}")


# Apply upfront patches at import time so any later unpickling sees them.
patch_sklearn_base()
patch_xgboost_wrappers()


# -------------------------
# Helpers: inspect file & try loaders
# -------------------------
def inspect_file(path):
    """Return a small diagnostics dict for *path*: existence, size, first 2KB.

    Never raises; any inspection failure is recorded under 'inspect_error'.
    """
    info = {"path": path, "exists": False}
    try:
        info["exists"] = os.path.exists(path)
        if not info["exists"]:
            return info
        info["size"] = os.path.getsize(path)
        with open(path, "rb") as f:
            head = f.read(2048)
        info["head_bytes"] = head
        try:
            info["head_text"] = head.decode("utf-8", errors="replace")
        except Exception:
            info["head_text"] = None
    except Exception as e:
        info["inspect_error"] = str(e)
    return info


def try_joblib_load(path):
    """Try standard joblib load.

    Returns ("joblib", model) on success or ("joblib", exception) on failure
    so callers can branch on the error message without a try/except.
    """
    try:
        # Re-apply patches immediately before load (cover lazy imports)
        patch_sklearn_base()
        patch_xgboost_wrappers()
        logger.info(f"Trying joblib.load on {path}")
        m = joblib.load(path)
        logger.info("joblib.load succeeded")
        return ("joblib", m)
    except Exception as e:
        logger.exception(f"joblib.load failed: {e}")
        return ("joblib", e)


def manual_pickle_unpickle(path):
    """
    Last-resort: attempt to unpickle the raw file bytes with a custom Unpickler
    that maps pickled references of sklearn base classes to the live patched
    classes. This may succeed when joblib.load fails due to base-class method
    mismatches.

    SECURITY NOTE: unpickling executes arbitrary code from the file. Only use
    this on model files from a trusted repo (HF_MODEL_REPO) — never on
    untrusted user uploads.

    Returns ("manual_pickle", obj_or_exception).
    """
    try:
        # BUGFIX: the original used open(path).read() without closing the
        # handle; use a context manager so the fd is always released.
        with open(path, "rb") as f:
            data = f.read()
    except Exception as e:
        return ("manual_pickle", e)

    class PatchedUnpickler(pickle.Unpickler):
        def find_class(self, module, name):
            # If pickle references sklearn.base.BaseEstimator, return the
            # live (patched) class instead of whatever the pickle recorded.
            if module.startswith("sklearn.") and name in ("BaseEstimator",):
                try:
                    from sklearn.base import BaseEstimator as LiveBase
                    # ensure our patches are present
                    try:
                        if not hasattr(LiveBase, "sklearn_tags"):
                            def _sklearn_tags(self):
                                return {}
                            setattr(LiveBase, "sklearn_tags", _sklearn_tags)
                    except Exception:
                        pass
                    return LiveBase
                except Exception:
                    pass
            # For xgboost wrappers, map to live classes if referenced
            if module.startswith("xgboost.") and name in ("XGBClassifier", "XGBRegressor", "XGBModel"):
                try:
                    import xgboost as xgb
                    cls = getattr(xgb, name, None)
                    if cls is not None:
                        return cls
                except Exception:
                    pass
            return super().find_class(module, name)

    try:
        bio = io.BytesIO(data)
        u = PatchedUnpickler(bio)
        obj = u.load()
        return ("manual_pickle", obj)
    except Exception as e:
        return ("manual_pickle", e)


def try_xgboost_booster(path):
    """Try loading file as a native xgboost booster (json/bin).

    Returns ("xgboost_booster", wrapper_or_exception); the wrapper exposes
    sklearn-like predict/predict_proba so downstream code is uniform.
    """
    try:
        import xgboost as xgb
    except Exception as e:
        logger.exception(f"xgboost import failed: {e}")
        return ("xgboost_import", e)
    try:
        logger.info(f"Trying xgboost.Booster().load_model on {path}")
        booster = xgb.Booster()
        booster.load_model(path)
        logger.info("xgboost.Booster.load_model succeeded")

        class BoosterWrapper:
            """Sklearn-style facade over a raw xgboost Booster."""

            def __init__(self, booster):
                self.booster = booster
                # Marker checked by predict() to pick the booster code path.
                self._is_xgb_booster = True

            def predict(self, X):
                import numpy as _np
                import xgboost as _xgb
                arr = _np.array(X, dtype=float)
                dmat = _xgb.DMatrix(arr)
                pred = self.booster.predict(dmat)
                # 1-D output is assumed to be a positive-class probability;
                # threshold at 0.5 for hard labels.
                if hasattr(pred, "ndim") and pred.ndim == 1:
                    return (_np.where(pred >= 0.5, 1, 0)).tolist()
                return pred.tolist()

            def predict_proba(self, X):
                import numpy as _np
                import xgboost as _xgb
                arr = _np.array(X, dtype=float)
                dmat = _xgb.DMatrix(arr)
                pred = self.booster.predict(dmat)
                # Expand 1-D positive-class probability to [neg, pos] columns.
                if hasattr(pred, "ndim") and pred.ndim == 1:
                    return (_np.vstack([1 - pred, pred]).T).tolist()
                return pred.tolist()

        return ("xgboost_booster", BoosterWrapper(booster))
    except Exception as e:
        logger.exception(f"xgboost.Booster.load_model failed: {e}")
        return ("xgboost_booster", e)


def _load_from_path(path, label="file"):
    """Run the full fallback chain on one file: joblib -> manual unpickle
    (only for the known sklearn_tags incompatibility) -> native booster.

    Returns the loaded model, or None if every loader failed. *label* is
    used only in log messages.
    """
    t, res = try_joblib_load(path)
    if t == "joblib" and not isinstance(res, Exception):
        return res
    # If joblib failed with the sklearn_tags error, attempt manual unpickle.
    if isinstance(res, Exception):
        msg = str(res)
        if "sklearn_tags" in msg or "sklearn_tags" in getattr(res, "args", ()):
            logger.info(f"joblib.load failed with sklearn_tags on {label}; trying manual unpickle fallback")
            tm, obj = manual_pickle_unpickle(path)
            if tm == "manual_pickle" and not isinstance(obj, Exception):
                logger.info(f"manual unpickle succeeded on {label}")
                return obj
            logger.error(f"manual unpickle did not succeed on {label}; continuing to other fallbacks")
    # Try native booster formats last.
    t2, res2 = try_xgboost_booster(path)
    if t2 == "xgboost_booster" and not isinstance(res2, Exception):
        return res2
    return None


# -------------------------
# Main loader: try local -> try HF -> fallbacks
# -------------------------
def load_model():
    """Locate and load the model: local candidate paths first, then the
    Hugging Face Hub. Returns the model object, or None if all attempts fail.
    """
    logger.info("==== MODEL LOAD START ====")
    logger.info(f"Repo: {HF_MODEL_REPO}")
    logger.info(f"Filename: {HF_MODEL_FILENAME}")
    logger.info(f"HF_TOKEN present? {bool(HF_TOKEN)}")

    # try local candidates
    for path in LOCAL_CANDIDATES:
        try:
            info = inspect_file(path)
            logger.info(f"Inspecting local candidate: {info}")
            if not info.get("exists"):
                continue
            model = _load_from_path(path, label="local candidate")
            if model is not None:
                return model
        except Exception as e:
            logger.exception(f"Error while trying local candidate {path}: {e}")

    # try huggingface hub
    try:
        logger.info(f"Trying hf_hub_download from {HF_MODEL_REPO}/{HF_MODEL_FILENAME}")
        model_path = hf_hub_download(repo_id=HF_MODEL_REPO, filename=HF_MODEL_FILENAME, token=HF_TOKEN)
        logger.info(f"Downloaded model to: {model_path}")
        info = inspect_file(model_path)
        logger.info(f"Inspecting downloaded file: {info}")
        model = _load_from_path(model_path, label="downloaded file")
        if model is not None:
            return model
        logger.error("Tried joblib/manual-unpickle and xgboost loader on downloaded file but all failed.")
        return None
    except Exception as e:
        logger.exception(f"hf_hub_download failed: {e}")
        return None


# -------------------------
# Prediction helper: accepts dict (col->val), list, or list-of-lists
# -------------------------
def predict(model, features):
    """Run inference and return {"prediction": ..., "probability": ...}
    or {"error": ...}.

    *features* may be a dict (column -> value, key order preserved), a flat
    list/tuple of numeric values (single row), or a list of rows (batch).
    """
    if model is None:
        return {"error": "Model not loaded"}
    try:
        import pandas as _pd
        import numpy as _np

        is_booster = hasattr(model, "_is_xgb_booster")

        # dict -> DataFrame (preserve key order)
        if isinstance(features, dict):
            cols = [str(k) for k in features.keys()]
            row = [features[k] for k in features.keys()]
            df = _pd.DataFrame([row], columns=cols)
            if is_booster:
                arr = df.values.astype(float)
                preds = model.predict(arr)
                prob = None
                if hasattr(model, "predict_proba"):
                    p = model.predict_proba(arr)
                    try:
                        prob = float(p[0][1])
                    except Exception:
                        prob = None
                pred_val = int(preds[0]) if isinstance(preds, (list, tuple)) else int(preds)
                return {"prediction": pred_val, "probability": prob}
            if hasattr(model, "predict"):
                pred = model.predict(df)[0]
                prob = None
                if hasattr(model, "predict_proba"):
                    p = model.predict_proba(df)[0]
                    try:
                        prob = float(max(p))
                    except Exception:
                        prob = None
                try:
                    pred = int(pred)
                except Exception:
                    pass
                return {"prediction": pred, "probability": prob}
            return {"error": "Loaded model object not recognized"}

        # BUGFIX: batch (list of rows) must be checked BEFORE the flat-list
        # branch; in the original ordering the batch branch was unreachable
        # and a list-of-lists produced a bogus 3-D array.
        if (isinstance(features, (list, tuple)) and len(features) > 0
                and isinstance(features[0], (list, tuple))):
            arr = _np.array(features, dtype=float)
            if is_booster:
                preds = model.predict(arr)
                prob = None
                if hasattr(model, "predict_proba"):
                    p = model.predict_proba(arr)
                    try:
                        prob = float(p[0][1])
                    except Exception:
                        prob = None
                # BUGFIX: BoosterWrapper.predict already returns a list, which
                # has no .tolist(); only convert numpy results.
                preds_list = preds if isinstance(preds, list) else preds.tolist()
                return {"prediction": preds_list, "probability": prob}
            if hasattr(model, "predict"):
                try:
                    pred = model.predict(arr)
                    prob = None
                    if hasattr(model, "predict_proba"):
                        p = model.predict_proba(arr)
                        try:
                            prob = float(max(p[0]))
                        except Exception:
                            prob = None
                    return {"prediction": pred.tolist(), "probability": prob}
                except Exception:
                    # Some pipelines require a DataFrame with named columns.
                    cols = [str(i) for i in range(arr.shape[1])]
                    df = _pd.DataFrame(arr, columns=cols)
                    pred = model.predict(df)
                    prob = None
                    if hasattr(model, "predict_proba"):
                        p = model.predict_proba(df)
                        try:
                            prob = float(max(p[0]))
                        except Exception:
                            prob = None
                    return {"prediction": pred.tolist(), "probability": prob}

        # list/tuple -> single row numeric
        if isinstance(features, (list, tuple)):
            arr2d = _np.array([features], dtype=float)
            if is_booster:
                preds = model.predict(arr2d)
                prob = None
                if hasattr(model, "predict_proba"):
                    p = model.predict_proba(arr2d)
                    try:
                        prob = float(p[0][1])
                    except Exception:
                        prob = None
                return {"prediction": int(preds[0]), "probability": prob}
            if hasattr(model, "predict"):
                try:
                    pred = model.predict(arr2d)[0]
                    prob = None
                    if hasattr(model, "predict_proba"):
                        p = model.predict_proba(arr2d)[0]
                        try:
                            prob = float(max(p))
                        except Exception:
                            prob = None
                    return {"prediction": pred, "probability": prob}
                except Exception:
                    # Retry with a DataFrame for column-name-sensitive models.
                    cols = [str(i) for i in range(arr2d.shape[1])]
                    df = _pd.DataFrame(arr2d, columns=cols)
                    pred = model.predict(df)[0]
                    prob = None
                    if hasattr(model, "predict_proba"):
                        p = model.predict_proba(df)[0]
                        try:
                            prob = float(max(p))
                        except Exception:
                            prob = None
                    return {"prediction": pred, "probability": prob}

        return {"error": "Unsupported features format. Provide dict (col->val) or list of values."}
    except Exception as e:
        logger.exception(f"Prediction error: {e}")
        return {"error": str(e)}