Spaces:
Sleeping
Sleeping
Update predict_utils.py
Browse files- predict_utils.py +48 -62
predict_utils.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
# predict_utils.py
|
| 2 |
-
# Robust loader + monkey-patches for XGBoost sklearn wrappers
|
| 3 |
import os
|
| 4 |
import logging
|
| 5 |
import joblib
|
|
@@ -21,11 +21,9 @@ LOCAL_CANDIDATES = [
|
|
| 21 |
]
|
| 22 |
|
| 23 |
# -------------------------
|
| 24 |
-
#
|
| 25 |
-
#
|
| 26 |
-
#
|
| 27 |
-
# "'XGBModel' object has no attribute 'gpu_id'"
|
| 28 |
-
# Call this BEFORE joblib.load so unpickling has the attributes available.
|
| 29 |
# -------------------------
|
| 30 |
def ensure_xgb_sklearn_compat():
|
| 31 |
try:
|
|
@@ -34,66 +32,54 @@ def ensure_xgb_sklearn_compat():
|
|
| 34 |
logger.debug(f"xgboost not importable for patching: {e}")
|
| 35 |
return
|
| 36 |
|
| 37 |
-
#
|
| 38 |
XGBModel = getattr(xgb, "XGBModel", None)
|
| 39 |
if XGBModel is not None:
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
try:
|
| 42 |
-
|
| 43 |
-
|
|
|
|
| 44 |
except Exception as e:
|
| 45 |
-
logger.debug(f"Could not
|
| 46 |
-
|
| 47 |
-
#
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
# XGBRegressor: similar patches
|
| 66 |
-
XGBRegressor = getattr(xgb, "XGBRegressor", None)
|
| 67 |
-
if XGBRegressor is not None:
|
| 68 |
-
if not hasattr(XGBRegressor, "use_label_encoder"):
|
| 69 |
-
try:
|
| 70 |
-
setattr(XGBRegressor, "use_label_encoder", False)
|
| 71 |
-
logger.info("Patched XGBRegressor.use_label_encoder = False")
|
| 72 |
-
except Exception as e:
|
| 73 |
-
logger.debug(f"Could not patch XGBRegressor.use_label_encoder: {e}")
|
| 74 |
-
if not hasattr(XGBRegressor, "objective"):
|
| 75 |
-
try:
|
| 76 |
-
setattr(XGBRegressor, "objective", None)
|
| 77 |
-
logger.info("Patched XGBRegressor.objective = None")
|
| 78 |
-
except Exception as e:
|
| 79 |
-
logger.debug(f"Could not patch XGBRegressor.objective: {e}")
|
| 80 |
-
|
| 81 |
-
# Also handle the case where pickled objects expect 'nthread' or 'n_jobs'
|
| 82 |
-
if XGBModel is not None:
|
| 83 |
-
if not hasattr(XGBModel, "nthread"):
|
| 84 |
-
try:
|
| 85 |
-
setattr(XGBModel, "nthread", None)
|
| 86 |
-
logger.info("Patched XGBModel.nthread = None")
|
| 87 |
-
except Exception as e:
|
| 88 |
-
logger.debug(f"Could not patch XGBModel.nthread: {e}")
|
| 89 |
-
if not hasattr(XGBModel, "n_jobs"):
|
| 90 |
-
try:
|
| 91 |
-
setattr(XGBModel, "n_jobs", None)
|
| 92 |
-
logger.info("Patched XGBModel.n_jobs = None")
|
| 93 |
-
except Exception as e:
|
| 94 |
-
logger.debug(f"Could not patch XGBModel.n_jobs: {e}")
|
| 95 |
|
| 96 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
ensure_xgb_sklearn_compat()
|
| 98 |
|
| 99 |
# -------------------------
|
|
@@ -119,7 +105,7 @@ def inspect_file(path):
|
|
| 119 |
|
| 120 |
def try_joblib_load(path):
|
| 121 |
try:
|
| 122 |
-
#
|
| 123 |
ensure_xgb_sklearn_compat()
|
| 124 |
logger.info(f"Trying joblib.load on {path}")
|
| 125 |
m = joblib.load(path)
|
|
|
|
| 1 |
# predict_utils.py
|
| 2 |
+
# Robust loader + extended monkey-patches for XGBoost sklearn wrappers
|
| 3 |
import os
|
| 4 |
import logging
|
| 5 |
import joblib
|
|
|
|
| 21 |
]
|
| 22 |
|
| 23 |
# -------------------------
|
| 24 |
+
# Extended monkey-patch
|
| 25 |
+
# Add commonly-expected attributes so unpickling older models succeeds.
|
| 26 |
+
# Call this BEFORE joblib.load so unpickle finds these attributes.
|
|
|
|
|
|
|
| 27 |
# -------------------------
|
| 28 |
def ensure_xgb_sklearn_compat():
|
| 29 |
try:
|
|
|
|
| 32 |
logger.debug(f"xgboost not importable for patching: {e}")
|
| 33 |
return
|
| 34 |
|
| 35 |
+
# Attributes to add on XGBModel base class (safe defaults)
|
| 36 |
XGBModel = getattr(xgb, "XGBModel", None)
|
| 37 |
if XGBModel is not None:
|
| 38 |
+
for attr, val in {
|
| 39 |
+
"gpu_id": None,
|
| 40 |
+
"nthread": None,
|
| 41 |
+
"n_jobs": None,
|
| 42 |
+
"predictor": None,
|
| 43 |
+
"base_score": None,
|
| 44 |
+
"objective": None,
|
| 45 |
+
}.items():
|
| 46 |
try:
|
| 47 |
+
if not hasattr(XGBModel, attr):
|
| 48 |
+
setattr(XGBModel, attr, val)
|
| 49 |
+
logger.info(f"Patched XGBModel.{attr} = {val!r}")
|
| 50 |
except Exception as e:
|
| 51 |
+
logger.debug(f"Could not patch XGBModel.{attr}: {e}")
|
| 52 |
+
|
| 53 |
+
# Patch classifier/regressor class-level defaults used in older pickles
|
| 54 |
+
for cls_name in ("XGBClassifier", "XGBRegressor"):
|
| 55 |
+
cls = getattr(xgb, cls_name, None)
|
| 56 |
+
if cls is not None:
|
| 57 |
+
for attr, val in {
|
| 58 |
+
"use_label_encoder": False,
|
| 59 |
+
"objective": None,
|
| 60 |
+
"predictor": None,
|
| 61 |
+
"gpu_id": None,
|
| 62 |
+
"n_jobs": None,
|
| 63 |
+
"nthread": None,
|
| 64 |
+
}.items():
|
| 65 |
+
try:
|
| 66 |
+
if not hasattr(cls, attr):
|
| 67 |
+
setattr(cls, attr, val)
|
| 68 |
+
logger.info(f"Patched {cls_name}.{attr} = {val!r}")
|
| 69 |
+
except Exception as e:
|
| 70 |
+
logger.debug(f"Could not patch {cls_name}.{attr}: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
+
# Some pickles expect certain module-level names (rare) — leave safe no-op fallbacks
|
| 73 |
+
try:
|
| 74 |
+
# e.g., older pickles might refer to xgb.core.Booster attributes; skip if not present
|
| 75 |
+
core = getattr(xgb, "core", None)
|
| 76 |
+
if core is not None:
|
| 77 |
+
if not hasattr(core, "DataBatch"):
|
| 78 |
+
setattr(core, "DataBatch", object)
|
| 79 |
+
except Exception:
|
| 80 |
+
pass
|
| 81 |
+
|
| 82 |
+
# Run the patch early
|
| 83 |
ensure_xgb_sklearn_compat()
|
| 84 |
|
| 85 |
# -------------------------
|
|
|
|
| 105 |
|
| 106 |
def try_joblib_load(path):
|
| 107 |
try:
|
| 108 |
+
# ensure patch immediately before load (handles lazy imports)
|
| 109 |
ensure_xgb_sklearn_compat()
|
| 110 |
logger.info(f"Trying joblib.load on {path}")
|
| 111 |
m = joblib.load(path)
|