# (Removed non-code residue captured from the Hugging Face Spaces page
#  header: "Spaces" / "Sleeping" / "Sleeping" — not part of this module.)
| # predict_utils.py | |
| # Robust loader with upfront patches + manual-unpickle fallback for sklearn/xgboost compatibility. | |
| import os | |
| import logging | |
| import joblib | |
| import io | |
| import pickle | |
| from huggingface_hub import hf_hub_download | |
| # Logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| HF_MODEL_REPO = os.getenv("HF_MODEL_REPO", "sathishaiuse/wellness-classifier-model") | |
| HF_MODEL_FILENAME = os.getenv("HF_MODEL_FILENAME", "best_overall_XGBoost.joblib") | |
| HF_TOKEN = os.getenv("HF_TOKEN") or None | |
| LOCAL_CANDIDATES = [ | |
| os.path.join("/app", HF_MODEL_FILENAME), | |
| os.path.join("/tmp", HF_MODEL_FILENAME), | |
| os.path.join("/home/user/app", HF_MODEL_FILENAME), | |
| HF_MODEL_FILENAME | |
| ] | |
| # ------------------------- | |
| # Upfront compatibility patches (run at import time) | |
| # ------------------------- | |
| def patch_sklearn_base(): | |
| """Make sure BaseEstimator exposes sklearn_tags/_get_tags/_more_tags used during unpickling.""" | |
| try: | |
| import sklearn | |
| from sklearn.base import BaseEstimator | |
| except Exception as e: | |
| logger.debug(f"sklearn not available to patch: {e}") | |
| return | |
| # Provide sklearn_tags method if missing | |
| if not hasattr(BaseEstimator, "sklearn_tags"): | |
| def _sklearn_tags(self): | |
| return {} | |
| try: | |
| setattr(BaseEstimator, "sklearn_tags", _sklearn_tags) | |
| logger.info("Patched BaseEstimator.sklearn_tags()") | |
| except Exception as e: | |
| logger.debug(f"Could not set BaseEstimator.sklearn_tags: {e}") | |
| # Provide _get_tags if missing | |
| if not hasattr(BaseEstimator, "_get_tags"): | |
| def _get_tags(self): | |
| tags = {} | |
| more = getattr(self, "_more_tags", None) | |
| if callable(more): | |
| try: | |
| tags.update(more()) | |
| except Exception: | |
| pass | |
| st = getattr(self, "sklearn_tags", None) | |
| if callable(st): | |
| try: | |
| tags.update(st()) | |
| except Exception: | |
| pass | |
| return tags | |
| try: | |
| setattr(BaseEstimator, "_get_tags", _get_tags) | |
| logger.info("Patched BaseEstimator._get_tags()") | |
| except Exception as e: | |
| logger.debug(f"Could not set BaseEstimator._get_tags: {e}") | |
| # Provide a default _more_tags if missing | |
| if not hasattr(BaseEstimator, "_more_tags"): | |
| def _more_tags(self): | |
| return {} | |
| try: | |
| setattr(BaseEstimator, "_more_tags", _more_tags) | |
| logger.info("Patched BaseEstimator._more_tags()") | |
| except Exception as e: | |
| logger.debug(f"Could not set BaseEstimator._more_tags: {e}") | |
| def patch_xgboost_wrappers(): | |
| """Add common attributes expected by older pickles to XGBoost classes/base.""" | |
| try: | |
| import xgboost as xgb | |
| except Exception as e: | |
| logger.debug(f"xgboost not available to patch: {e}") | |
| return | |
| XGBModel = getattr(xgb, "XGBModel", None) | |
| if XGBModel is not None: | |
| for attr, val in { | |
| "gpu_id": None, | |
| "nthread": None, | |
| "n_jobs": None, | |
| "predictor": None, | |
| "base_score": None, | |
| "objective": None, | |
| }.items(): | |
| try: | |
| if not hasattr(XGBModel, attr): | |
| setattr(XGBModel, attr, val) | |
| logger.info(f"Patched XGBModel.{attr} = {val!r}") | |
| except Exception as e: | |
| logger.debug(f"Could not patch XGBModel.{attr}: {e}") | |
| for cls_name in ("XGBClassifier", "XGBRegressor"): | |
| cls = getattr(xgb, cls_name, None) | |
| if cls is not None: | |
| for attr, val in { | |
| "use_label_encoder": False, | |
| "objective": None, | |
| "predictor": None, | |
| "gpu_id": None, | |
| "n_jobs": None, | |
| "nthread": None, | |
| }.items(): | |
| try: | |
| if not hasattr(cls, attr): | |
| setattr(cls, attr, val) | |
| logger.info(f"Patched {cls_name}.{attr} = {val!r}") | |
| except Exception as e: | |
| logger.debug(f"Could not patch {cls_name}.{attr}: {e}") | |
| # Apply upfront patches | |
| patch_sklearn_base() | |
| patch_xgboost_wrappers() | |
| # ------------------------- | |
| # Helpers: inspect file & try loaders | |
| # ------------------------- | |
| def inspect_file(path): | |
| info = {"path": path, "exists": False} | |
| try: | |
| info["exists"] = os.path.exists(path) | |
| if not info["exists"]: | |
| return info | |
| info["size"] = os.path.getsize(path) | |
| with open(path, "rb") as f: | |
| head = f.read(2048) | |
| info["head_bytes"] = head | |
| try: | |
| info["head_text"] = head.decode("utf-8", errors="replace") | |
| except Exception: | |
| info["head_text"] = None | |
| except Exception as e: | |
| info["inspect_error"] = str(e) | |
| return info | |
| def try_joblib_load(path): | |
| """Try standard joblib load. Return ("joblib", model) or ("joblib", exception)""" | |
| try: | |
| # Re-apply patches immediately before load (cover lazy imports) | |
| patch_sklearn_base() | |
| patch_xgboost_wrappers() | |
| logger.info(f"Trying joblib.load on {path}") | |
| m = joblib.load(path) | |
| logger.info("joblib.load succeeded") | |
| return ("joblib", m) | |
| except Exception as e: | |
| logger.exception(f"joblib.load failed: {e}") | |
| return ("joblib", e) | |
| def manual_pickle_unpickle(path): | |
| """ | |
| Last-resort: attempt to unpickle the raw file bytes with a custom Unpickler | |
| that maps pickled references of sklearn base classes to the live patched classes. | |
| This may succeed when joblib.load fails due to base-class method mismatches. | |
| """ | |
| try: | |
| data = open(path, "rb").read() | |
| except Exception as e: | |
| return ("manual_pickle", e) | |
| class PatchedUnpickler(pickle.Unpickler): | |
| def find_class(self, module, name): | |
| # If pickle references sklearn.base.BaseEstimator, return the live patched class | |
| if module.startswith("sklearn.") and name in ("BaseEstimator",): | |
| try: | |
| from sklearn.base import BaseEstimator as LiveBase | |
| # ensure our patches are present | |
| try: | |
| if not hasattr(LiveBase, "sklearn_tags"): | |
| def _sklearn_tags(self): return {} | |
| setattr(LiveBase, "sklearn_tags", _sklearn_tags) | |
| except Exception: | |
| pass | |
| return LiveBase | |
| except Exception: | |
| pass | |
| # For xgboost wrappers, map to live classes if referenced | |
| if module.startswith("xgboost.") and name in ("XGBClassifier", "XGBRegressor", "XGBModel"): | |
| try: | |
| import xgboost as xgb | |
| cls = getattr(xgb, name, None) | |
| if cls is not None: | |
| return cls | |
| except Exception: | |
| pass | |
| return super().find_class(module, name) | |
| try: | |
| bio = io.BytesIO(data) | |
| u = PatchedUnpickler(bio) | |
| obj = u.load() | |
| return ("manual_pickle", obj) | |
| except Exception as e: | |
| return ("manual_pickle", e) | |
| def try_xgboost_booster(path): | |
| """Try loading file as a native xgboost booster (json/bin)""" | |
| try: | |
| import xgboost as xgb | |
| except Exception as e: | |
| logger.exception(f"xgboost import failed: {e}") | |
| return ("xgboost_import", e) | |
| try: | |
| logger.info(f"Trying xgboost.Booster().load_model on {path}") | |
| booster = xgb.Booster() | |
| booster.load_model(path) | |
| logger.info("xgboost.Booster.load_model succeeded") | |
| class BoosterWrapper: | |
| def __init__(self, booster): | |
| self.booster = booster | |
| self._is_xgb_booster = True | |
| def predict(self, X): | |
| import numpy as _np, xgboost as _xgb | |
| arr = _np.array(X, dtype=float) | |
| dmat = _xgb.DMatrix(arr) | |
| pred = self.booster.predict(dmat) | |
| if hasattr(pred, "ndim") and pred.ndim == 1: | |
| return (_np.where(pred >= 0.5, 1, 0)).tolist() | |
| return pred.tolist() | |
| def predict_proba(self, X): | |
| import numpy as _np, xgboost as _xgb | |
| arr = _np.array(X, dtype=float) | |
| dmat = _xgb.DMatrix(arr) | |
| pred = self.booster.predict(dmat) | |
| if hasattr(pred, "ndim") and pred.ndim == 1: | |
| return (_np.vstack([1 - pred, pred]).T).tolist() | |
| return pred.tolist() | |
| return ("xgboost_booster", BoosterWrapper(booster)) | |
| except Exception as e: | |
| logger.exception(f"xgboost.Booster.load_model failed: {e}") | |
| return ("xgboost_booster", e) | |
| # ------------------------- | |
| # Main loader: try local -> try HF -> fallbacks | |
| # ------------------------- | |
| def load_model(): | |
| logger.info("==== MODEL LOAD START ====") | |
| logger.info(f"Repo: {HF_MODEL_REPO}") | |
| logger.info(f"Filename: {HF_MODEL_FILENAME}") | |
| logger.info(f"HF_TOKEN present? {bool(HF_TOKEN)}") | |
| # try local candidates | |
| for path in LOCAL_CANDIDATES: | |
| try: | |
| info = inspect_file(path) | |
| logger.info(f"Inspecting local candidate: {info}") | |
| if not info.get("exists"): | |
| continue | |
| t, res = try_joblib_load(path) | |
| if t == "joblib" and not isinstance(res, Exception): | |
| return res | |
| # if joblib failed with sklearn_tags error, attempt manual unpickle | |
| if t == "joblib" and isinstance(res, Exception): | |
| msg = str(res) | |
| if "sklearn_tags" in msg or "sklearn_tags" in getattr(res, "args", ()): | |
| logger.info("joblib.load failed with sklearn_tags; trying manual pickle unpickle fallback") | |
| tm, obj = manual_pickle_unpickle(path) | |
| if tm == "manual_pickle" and not isinstance(obj, Exception): | |
| logger.info("manual unpickle succeeded") | |
| return obj | |
| else: | |
| logger.error("manual unpickle did not succeed; continuing to other fallbacks") | |
| # try native booster | |
| t2, res2 = try_xgboost_booster(path) | |
| if t2 == "xgboost_booster" and not isinstance(res2, Exception): | |
| return res2 | |
| except Exception as e: | |
| logger.exception(f"Error while trying local candidate {path}: {e}") | |
| # try huggingface hub | |
| try: | |
| logger.info(f"Trying hf_hub_download from {HF_MODEL_REPO}/{HF_MODEL_FILENAME}") | |
| model_path = hf_hub_download(repo_id=HF_MODEL_REPO, filename=HF_MODEL_FILENAME, token=HF_TOKEN) | |
| logger.info(f"Downloaded model to: {model_path}") | |
| info = inspect_file(model_path) | |
| logger.info(f"Inspecting downloaded file: {info}") | |
| t, res = try_joblib_load(model_path) | |
| if t == "joblib" and not isinstance(res, Exception): | |
| return res | |
| if t == "joblib" and isinstance(res, Exception): | |
| msg = str(res) | |
| if "sklearn_tags" in msg or "sklearn_tags" in getattr(res, "args", ()): | |
| logger.info("joblib.load failed on downloaded file with sklearn_tags; trying manual unpickle fallback") | |
| tm, obj = manual_pickle_unpickle(model_path) | |
| if tm == "manual_pickle" and not isinstance(obj, Exception): | |
| logger.info("manual unpickle succeeded on downloaded file") | |
| return obj | |
| else: | |
| logger.error("manual unpickle did not succeed on downloaded file") | |
| t2, res2 = try_xgboost_booster(model_path) | |
| if t2 == "xgboost_booster" and not isinstance(res2, Exception): | |
| return res2 | |
| logger.error("Tried joblib/manual-unpickle and xgboost loader on downloaded file but all failed.") | |
| return None | |
| except Exception as e: | |
| logger.exception(f"hf_hub_download failed: {e}") | |
| return None | |
| # ------------------------- | |
| # Prediction helper: accepts dict (col->val), list, or list-of-lists | |
| # ------------------------- | |
| def predict(model, features): | |
| if model is None: | |
| return {"error": "Model not loaded"} | |
| try: | |
| import pandas as _pd | |
| import numpy as _np | |
| is_booster = hasattr(model, "_is_xgb_booster") | |
| # dict -> DataFrame (preserve key order) | |
| if isinstance(features, dict): | |
| cols = [str(k) for k in features.keys()] | |
| row = [features[k] for k in features.keys()] | |
| df = _pd.DataFrame([row], columns=cols) | |
| if is_booster: | |
| arr = df.values.astype(float) | |
| preds = model.predict(arr) | |
| prob = None | |
| if hasattr(model, "predict_proba"): | |
| p = model.predict_proba(arr) | |
| try: | |
| prob = float(p[0][1]) | |
| except: | |
| prob = None | |
| return {"prediction": int(preds[0]) if isinstance(preds, (list,tuple)) else int(preds), "probability": prob} | |
| if hasattr(model, "predict"): | |
| pred = model.predict(df)[0] | |
| prob = None | |
| if hasattr(model, "predict_proba"): | |
| p = model.predict_proba(df)[0] | |
| try: | |
| prob = float(max(p)) | |
| except: | |
| prob = None | |
| try: | |
| pred = int(pred) | |
| except: | |
| pass | |
| return {"prediction": pred, "probability": prob} | |
| return {"error": "Loaded model object not recognized"} | |
| # list -> single row numeric | |
| if isinstance(features, (list,tuple)): | |
| arr2d = _np.array([features], dtype=float) | |
| if is_booster: | |
| preds = model.predict(arr2d) | |
| prob = None | |
| if hasattr(model, "predict_proba"): | |
| p = model.predict_proba(arr2d) | |
| try: | |
| prob = float(p[0][1]) | |
| except: | |
| prob = None | |
| return {"prediction": int(preds[0]), "probability": prob} | |
| if hasattr(model, "predict"): | |
| try: | |
| pred = model.predict(arr2d)[0] | |
| prob = None | |
| if hasattr(model, "predict_proba"): | |
| p = model.predict_proba(arr2d)[0] | |
| try: | |
| prob = float(max(p)) | |
| except: | |
| prob = None | |
| return {"prediction": pred, "probability": prob} | |
| except Exception: | |
| cols = [str(i) for i in range(arr2d.shape[1])] | |
| df = _pd.DataFrame(arr2d, columns=cols) | |
| pred = model.predict(df)[0] | |
| prob = None | |
| if hasattr(model, "predict_proba"): | |
| p = model.predict_proba(df)[0] | |
| try: | |
| prob = float(max(p)) | |
| except: | |
| prob = None | |
| return {"prediction": pred, "probability": prob} | |
| # batch | |
| if isinstance(features, list) and len(features) > 0 and isinstance(features[0], (list, tuple)): | |
| arr = _np.array(features, dtype=float) | |
| if is_booster: | |
| preds = model.predict(arr) | |
| prob = None | |
| if hasattr(model, "predict_proba"): | |
| p = model.predict_proba(arr) | |
| try: | |
| prob = float(p[0][1]) | |
| except: | |
| prob = None | |
| return {"prediction": preds.tolist(), "probability": prob} | |
| if hasattr(model, "predict"): | |
| try: | |
| pred = model.predict(arr) | |
| prob = None | |
| if hasattr(model, "predict_proba"): | |
| p = model.predict_proba(arr) | |
| try: | |
| prob = float(max(p[0])) | |
| except: | |
| prob = None | |
| return {"prediction": pred.tolist(), "probability": prob} | |
| except Exception: | |
| cols = [str(i) for i in range(arr.shape[1])] | |
| df = _pd.DataFrame(arr, columns=cols) | |
| pred = model.predict(df) | |
| prob = None | |
| if hasattr(model, "predict_proba"): | |
| p = model.predict_proba(df) | |
| try: | |
| prob = float(max(p[0])) | |
| except: | |
| prob = None | |
| return {"prediction": pred.tolist(), "probability": prob} | |
| return {"error": "Unsupported features format. Provide dict (col->val) or list of values."} | |
| except Exception as e: | |
| logger.exception(f"Prediction error: {e}") | |
| return {"error": str(e)} | |