# Tourism-Package / predict_utils.py
# Hugging Face Space file (uploaded by sathishaiuse, commit 3b5d784 verified):
# "Update predict_utils.py"
# predict_utils.py
# Robust loader with upfront patches + manual-unpickle fallback for sklearn/xgboost compatibility.
import os
import logging
import joblib
import io
import pickle
from huggingface_hub import hf_hub_download
# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
HF_MODEL_REPO = os.getenv("HF_MODEL_REPO", "sathishaiuse/wellness-classifier-model")
HF_MODEL_FILENAME = os.getenv("HF_MODEL_FILENAME", "best_overall_XGBoost.joblib")
HF_TOKEN = os.getenv("HF_TOKEN") or None
LOCAL_CANDIDATES = [
os.path.join("/app", HF_MODEL_FILENAME),
os.path.join("/tmp", HF_MODEL_FILENAME),
os.path.join("/home/user/app", HF_MODEL_FILENAME),
HF_MODEL_FILENAME
]
# -------------------------
# Upfront compatibility patches (run at import time)
# -------------------------
def patch_sklearn_base():
"""Make sure BaseEstimator exposes sklearn_tags/_get_tags/_more_tags used during unpickling."""
try:
import sklearn
from sklearn.base import BaseEstimator
except Exception as e:
logger.debug(f"sklearn not available to patch: {e}")
return
# Provide sklearn_tags method if missing
if not hasattr(BaseEstimator, "sklearn_tags"):
def _sklearn_tags(self):
return {}
try:
setattr(BaseEstimator, "sklearn_tags", _sklearn_tags)
logger.info("Patched BaseEstimator.sklearn_tags()")
except Exception as e:
logger.debug(f"Could not set BaseEstimator.sklearn_tags: {e}")
# Provide _get_tags if missing
if not hasattr(BaseEstimator, "_get_tags"):
def _get_tags(self):
tags = {}
more = getattr(self, "_more_tags", None)
if callable(more):
try:
tags.update(more())
except Exception:
pass
st = getattr(self, "sklearn_tags", None)
if callable(st):
try:
tags.update(st())
except Exception:
pass
return tags
try:
setattr(BaseEstimator, "_get_tags", _get_tags)
logger.info("Patched BaseEstimator._get_tags()")
except Exception as e:
logger.debug(f"Could not set BaseEstimator._get_tags: {e}")
# Provide a default _more_tags if missing
if not hasattr(BaseEstimator, "_more_tags"):
def _more_tags(self):
return {}
try:
setattr(BaseEstimator, "_more_tags", _more_tags)
logger.info("Patched BaseEstimator._more_tags()")
except Exception as e:
logger.debug(f"Could not set BaseEstimator._more_tags: {e}")
def patch_xgboost_wrappers():
"""Add common attributes expected by older pickles to XGBoost classes/base."""
try:
import xgboost as xgb
except Exception as e:
logger.debug(f"xgboost not available to patch: {e}")
return
XGBModel = getattr(xgb, "XGBModel", None)
if XGBModel is not None:
for attr, val in {
"gpu_id": None,
"nthread": None,
"n_jobs": None,
"predictor": None,
"base_score": None,
"objective": None,
}.items():
try:
if not hasattr(XGBModel, attr):
setattr(XGBModel, attr, val)
logger.info(f"Patched XGBModel.{attr} = {val!r}")
except Exception as e:
logger.debug(f"Could not patch XGBModel.{attr}: {e}")
for cls_name in ("XGBClassifier", "XGBRegressor"):
cls = getattr(xgb, cls_name, None)
if cls is not None:
for attr, val in {
"use_label_encoder": False,
"objective": None,
"predictor": None,
"gpu_id": None,
"n_jobs": None,
"nthread": None,
}.items():
try:
if not hasattr(cls, attr):
setattr(cls, attr, val)
logger.info(f"Patched {cls_name}.{attr} = {val!r}")
except Exception as e:
logger.debug(f"Could not patch {cls_name}.{attr}: {e}")
# Apply upfront patches
patch_sklearn_base()
patch_xgboost_wrappers()
# -------------------------
# Helpers: inspect file & try loaders
# -------------------------
def inspect_file(path):
info = {"path": path, "exists": False}
try:
info["exists"] = os.path.exists(path)
if not info["exists"]:
return info
info["size"] = os.path.getsize(path)
with open(path, "rb") as f:
head = f.read(2048)
info["head_bytes"] = head
try:
info["head_text"] = head.decode("utf-8", errors="replace")
except Exception:
info["head_text"] = None
except Exception as e:
info["inspect_error"] = str(e)
return info
def try_joblib_load(path):
"""Try standard joblib load. Return ("joblib", model) or ("joblib", exception)"""
try:
# Re-apply patches immediately before load (cover lazy imports)
patch_sklearn_base()
patch_xgboost_wrappers()
logger.info(f"Trying joblib.load on {path}")
m = joblib.load(path)
logger.info("joblib.load succeeded")
return ("joblib", m)
except Exception as e:
logger.exception(f"joblib.load failed: {e}")
return ("joblib", e)
def manual_pickle_unpickle(path):
"""
Last-resort: attempt to unpickle the raw file bytes with a custom Unpickler
that maps pickled references of sklearn base classes to the live patched classes.
This may succeed when joblib.load fails due to base-class method mismatches.
"""
try:
data = open(path, "rb").read()
except Exception as e:
return ("manual_pickle", e)
class PatchedUnpickler(pickle.Unpickler):
def find_class(self, module, name):
# If pickle references sklearn.base.BaseEstimator, return the live patched class
if module.startswith("sklearn.") and name in ("BaseEstimator",):
try:
from sklearn.base import BaseEstimator as LiveBase
# ensure our patches are present
try:
if not hasattr(LiveBase, "sklearn_tags"):
def _sklearn_tags(self): return {}
setattr(LiveBase, "sklearn_tags", _sklearn_tags)
except Exception:
pass
return LiveBase
except Exception:
pass
# For xgboost wrappers, map to live classes if referenced
if module.startswith("xgboost.") and name in ("XGBClassifier", "XGBRegressor", "XGBModel"):
try:
import xgboost as xgb
cls = getattr(xgb, name, None)
if cls is not None:
return cls
except Exception:
pass
return super().find_class(module, name)
try:
bio = io.BytesIO(data)
u = PatchedUnpickler(bio)
obj = u.load()
return ("manual_pickle", obj)
except Exception as e:
return ("manual_pickle", e)
def try_xgboost_booster(path):
"""Try loading file as a native xgboost booster (json/bin)"""
try:
import xgboost as xgb
except Exception as e:
logger.exception(f"xgboost import failed: {e}")
return ("xgboost_import", e)
try:
logger.info(f"Trying xgboost.Booster().load_model on {path}")
booster = xgb.Booster()
booster.load_model(path)
logger.info("xgboost.Booster.load_model succeeded")
class BoosterWrapper:
def __init__(self, booster):
self.booster = booster
self._is_xgb_booster = True
def predict(self, X):
import numpy as _np, xgboost as _xgb
arr = _np.array(X, dtype=float)
dmat = _xgb.DMatrix(arr)
pred = self.booster.predict(dmat)
if hasattr(pred, "ndim") and pred.ndim == 1:
return (_np.where(pred >= 0.5, 1, 0)).tolist()
return pred.tolist()
def predict_proba(self, X):
import numpy as _np, xgboost as _xgb
arr = _np.array(X, dtype=float)
dmat = _xgb.DMatrix(arr)
pred = self.booster.predict(dmat)
if hasattr(pred, "ndim") and pred.ndim == 1:
return (_np.vstack([1 - pred, pred]).T).tolist()
return pred.tolist()
return ("xgboost_booster", BoosterWrapper(booster))
except Exception as e:
logger.exception(f"xgboost.Booster.load_model failed: {e}")
return ("xgboost_booster", e)
# -------------------------
# Main loader: try local -> try HF -> fallbacks
# -------------------------
def load_model():
logger.info("==== MODEL LOAD START ====")
logger.info(f"Repo: {HF_MODEL_REPO}")
logger.info(f"Filename: {HF_MODEL_FILENAME}")
logger.info(f"HF_TOKEN present? {bool(HF_TOKEN)}")
# try local candidates
for path in LOCAL_CANDIDATES:
try:
info = inspect_file(path)
logger.info(f"Inspecting local candidate: {info}")
if not info.get("exists"):
continue
t, res = try_joblib_load(path)
if t == "joblib" and not isinstance(res, Exception):
return res
# if joblib failed with sklearn_tags error, attempt manual unpickle
if t == "joblib" and isinstance(res, Exception):
msg = str(res)
if "sklearn_tags" in msg or "sklearn_tags" in getattr(res, "args", ()):
logger.info("joblib.load failed with sklearn_tags; trying manual pickle unpickle fallback")
tm, obj = manual_pickle_unpickle(path)
if tm == "manual_pickle" and not isinstance(obj, Exception):
logger.info("manual unpickle succeeded")
return obj
else:
logger.error("manual unpickle did not succeed; continuing to other fallbacks")
# try native booster
t2, res2 = try_xgboost_booster(path)
if t2 == "xgboost_booster" and not isinstance(res2, Exception):
return res2
except Exception as e:
logger.exception(f"Error while trying local candidate {path}: {e}")
# try huggingface hub
try:
logger.info(f"Trying hf_hub_download from {HF_MODEL_REPO}/{HF_MODEL_FILENAME}")
model_path = hf_hub_download(repo_id=HF_MODEL_REPO, filename=HF_MODEL_FILENAME, token=HF_TOKEN)
logger.info(f"Downloaded model to: {model_path}")
info = inspect_file(model_path)
logger.info(f"Inspecting downloaded file: {info}")
t, res = try_joblib_load(model_path)
if t == "joblib" and not isinstance(res, Exception):
return res
if t == "joblib" and isinstance(res, Exception):
msg = str(res)
if "sklearn_tags" in msg or "sklearn_tags" in getattr(res, "args", ()):
logger.info("joblib.load failed on downloaded file with sklearn_tags; trying manual unpickle fallback")
tm, obj = manual_pickle_unpickle(model_path)
if tm == "manual_pickle" and not isinstance(obj, Exception):
logger.info("manual unpickle succeeded on downloaded file")
return obj
else:
logger.error("manual unpickle did not succeed on downloaded file")
t2, res2 = try_xgboost_booster(model_path)
if t2 == "xgboost_booster" and not isinstance(res2, Exception):
return res2
logger.error("Tried joblib/manual-unpickle and xgboost loader on downloaded file but all failed.")
return None
except Exception as e:
logger.exception(f"hf_hub_download failed: {e}")
return None
# -------------------------
# Prediction helper: accepts dict (col->val), list, or list-of-lists
# -------------------------
def predict(model, features):
if model is None:
return {"error": "Model not loaded"}
try:
import pandas as _pd
import numpy as _np
is_booster = hasattr(model, "_is_xgb_booster")
# dict -> DataFrame (preserve key order)
if isinstance(features, dict):
cols = [str(k) for k in features.keys()]
row = [features[k] for k in features.keys()]
df = _pd.DataFrame([row], columns=cols)
if is_booster:
arr = df.values.astype(float)
preds = model.predict(arr)
prob = None
if hasattr(model, "predict_proba"):
p = model.predict_proba(arr)
try:
prob = float(p[0][1])
except:
prob = None
return {"prediction": int(preds[0]) if isinstance(preds, (list,tuple)) else int(preds), "probability": prob}
if hasattr(model, "predict"):
pred = model.predict(df)[0]
prob = None
if hasattr(model, "predict_proba"):
p = model.predict_proba(df)[0]
try:
prob = float(max(p))
except:
prob = None
try:
pred = int(pred)
except:
pass
return {"prediction": pred, "probability": prob}
return {"error": "Loaded model object not recognized"}
# list -> single row numeric
if isinstance(features, (list,tuple)):
arr2d = _np.array([features], dtype=float)
if is_booster:
preds = model.predict(arr2d)
prob = None
if hasattr(model, "predict_proba"):
p = model.predict_proba(arr2d)
try:
prob = float(p[0][1])
except:
prob = None
return {"prediction": int(preds[0]), "probability": prob}
if hasattr(model, "predict"):
try:
pred = model.predict(arr2d)[0]
prob = None
if hasattr(model, "predict_proba"):
p = model.predict_proba(arr2d)[0]
try:
prob = float(max(p))
except:
prob = None
return {"prediction": pred, "probability": prob}
except Exception:
cols = [str(i) for i in range(arr2d.shape[1])]
df = _pd.DataFrame(arr2d, columns=cols)
pred = model.predict(df)[0]
prob = None
if hasattr(model, "predict_proba"):
p = model.predict_proba(df)[0]
try:
prob = float(max(p))
except:
prob = None
return {"prediction": pred, "probability": prob}
# batch
if isinstance(features, list) and len(features) > 0 and isinstance(features[0], (list, tuple)):
arr = _np.array(features, dtype=float)
if is_booster:
preds = model.predict(arr)
prob = None
if hasattr(model, "predict_proba"):
p = model.predict_proba(arr)
try:
prob = float(p[0][1])
except:
prob = None
return {"prediction": preds.tolist(), "probability": prob}
if hasattr(model, "predict"):
try:
pred = model.predict(arr)
prob = None
if hasattr(model, "predict_proba"):
p = model.predict_proba(arr)
try:
prob = float(max(p[0]))
except:
prob = None
return {"prediction": pred.tolist(), "probability": prob}
except Exception:
cols = [str(i) for i in range(arr.shape[1])]
df = _pd.DataFrame(arr, columns=cols)
pred = model.predict(df)
prob = None
if hasattr(model, "predict_proba"):
p = model.predict_proba(df)
try:
prob = float(max(p[0]))
except:
prob = None
return {"prediction": pred.tolist(), "probability": prob}
return {"error": "Unsupported features format. Provide dict (col->val) or list of values."}
except Exception as e:
logger.exception(f"Prediction error: {e}")
return {"error": str(e)}