sathishaiuse commited on
Commit
175c1d7
·
verified ·
1 Parent(s): 1cf65c0

Update predict_utils.py

Browse files
Files changed (1) hide show
  1. predict_utils.py +48 -62
predict_utils.py CHANGED
@@ -1,5 +1,5 @@
1
  # predict_utils.py
2
- # Robust loader + monkey-patches for XGBoost sklearn wrappers
3
  import os
4
  import logging
5
  import joblib
@@ -21,11 +21,9 @@ LOCAL_CANDIDATES = [
21
  ]
22
 
23
  # -------------------------
24
- # Monkey-patch xgboost sklearn wrappers & base class to add missing attributes.
25
- # This handles errors like:
26
- # "'XGBClassifier' object has no attribute 'use_label_encoder'"
27
- # "'XGBModel' object has no attribute 'gpu_id'"
28
- # Call this BEFORE joblib.load so unpickling has the attributes available.
29
  # -------------------------
30
  def ensure_xgb_sklearn_compat():
31
  try:
@@ -34,66 +32,54 @@ def ensure_xgb_sklearn_compat():
34
  logger.debug(f"xgboost not importable for patching: {e}")
35
  return
36
 
37
- # Base class: XGBModel (add gpu_id if missing)
38
  XGBModel = getattr(xgb, "XGBModel", None)
39
  if XGBModel is not None:
40
- if not hasattr(XGBModel, "gpu_id"):
 
 
 
 
 
 
 
41
  try:
42
- setattr(XGBModel, "gpu_id", None) # default None means CPU
43
- logger.info("Patched XGBModel.gpu_id = None")
 
44
  except Exception as e:
45
- logger.debug(f"Could not set XGBModel.gpu_id: {e}")
46
-
47
- # XGBClassifier: add use_label_encoder if missing, maybe objective
48
- XGBClassifier = getattr(xgb, "XGBClassifier", None)
49
- if XGBClassifier is not None:
50
- # Add attributes on the class object so unpickle will find defaults
51
- if not hasattr(XGBClassifier, "use_label_encoder"):
52
- try:
53
- setattr(XGBClassifier, "use_label_encoder", False)
54
- logger.info("Patched XGBClassifier.use_label_encoder = False")
55
- except Exception as e:
56
- logger.debug(f"Could not patch XGBClassifier.use_label_encoder: {e}")
57
- # sometimes older pickles expect 'objective' attribute default
58
- if not hasattr(XGBClassifier, "objective"):
59
- try:
60
- setattr(XGBClassifier, "objective", None)
61
- logger.info("Patched XGBClassifier.objective = None")
62
- except Exception as e:
63
- logger.debug(f"Could not patch XGBClassifier.objective: {e}")
64
-
65
- # XGBRegressor: similar patches
66
- XGBRegressor = getattr(xgb, "XGBRegressor", None)
67
- if XGBRegressor is not None:
68
- if not hasattr(XGBRegressor, "use_label_encoder"):
69
- try:
70
- setattr(XGBRegressor, "use_label_encoder", False)
71
- logger.info("Patched XGBRegressor.use_label_encoder = False")
72
- except Exception as e:
73
- logger.debug(f"Could not patch XGBRegressor.use_label_encoder: {e}")
74
- if not hasattr(XGBRegressor, "objective"):
75
- try:
76
- setattr(XGBRegressor, "objective", None)
77
- logger.info("Patched XGBRegressor.objective = None")
78
- except Exception as e:
79
- logger.debug(f"Could not patch XGBRegressor.objective: {e}")
80
-
81
- # Also handle the case where pickled objects expect 'nthread' or 'n_jobs'
82
- if XGBModel is not None:
83
- if not hasattr(XGBModel, "nthread"):
84
- try:
85
- setattr(XGBModel, "nthread", None)
86
- logger.info("Patched XGBModel.nthread = None")
87
- except Exception as e:
88
- logger.debug(f"Could not patch XGBModel.nthread: {e}")
89
- if not hasattr(XGBModel, "n_jobs"):
90
- try:
91
- setattr(XGBModel, "n_jobs", None)
92
- logger.info("Patched XGBModel.n_jobs = None")
93
- except Exception as e:
94
- logger.debug(f"Could not patch XGBModel.n_jobs: {e}")
95
 
96
- # Call patch early so joblib.load can use these defaults
 
 
 
 
 
 
 
 
 
 
97
  ensure_xgb_sklearn_compat()
98
 
99
  # -------------------------
@@ -119,7 +105,7 @@ def inspect_file(path):
119
 
120
  def try_joblib_load(path):
121
  try:
122
- # Ensure patch right before load (in case xgboost gets imported lazily)
123
  ensure_xgb_sklearn_compat()
124
  logger.info(f"Trying joblib.load on {path}")
125
  m = joblib.load(path)
 
1
  # predict_utils.py
2
+ # Robust loader + extended monkey-patches for XGBoost sklearn wrappers
3
  import os
4
  import logging
5
  import joblib
 
21
  ]
22
 
23
  # -------------------------
24
+ # Extended monkey-patch
25
+ # Add commonly-expected attributes so unpickling older models succeeds.
26
+ # Call this BEFORE joblib.load so unpickle finds these attributes.
 
 
27
  # -------------------------
28
  def ensure_xgb_sklearn_compat():
29
  try:
 
32
  logger.debug(f"xgboost not importable for patching: {e}")
33
  return
34
 
35
+ # Attributes to add on XGBModel base class (safe defaults)
36
  XGBModel = getattr(xgb, "XGBModel", None)
37
  if XGBModel is not None:
38
+ for attr, val in {
39
+ "gpu_id": None,
40
+ "nthread": None,
41
+ "n_jobs": None,
42
+ "predictor": None,
43
+ "base_score": None,
44
+ "objective": None,
45
+ }.items():
46
  try:
47
+ if not hasattr(XGBModel, attr):
48
+ setattr(XGBModel, attr, val)
49
+ logger.info(f"Patched XGBModel.{attr} = {val!r}")
50
  except Exception as e:
51
+ logger.debug(f"Could not patch XGBModel.{attr}: {e}")
52
+
53
+ # Patch classifier/regressor class-level defaults used in older pickles
54
+ for cls_name in ("XGBClassifier", "XGBRegressor"):
55
+ cls = getattr(xgb, cls_name, None)
56
+ if cls is not None:
57
+ for attr, val in {
58
+ "use_label_encoder": False,
59
+ "objective": None,
60
+ "predictor": None,
61
+ "gpu_id": None,
62
+ "n_jobs": None,
63
+ "nthread": None,
64
+ }.items():
65
+ try:
66
+ if not hasattr(cls, attr):
67
+ setattr(cls, attr, val)
68
+ logger.info(f"Patched {cls_name}.{attr} = {val!r}")
69
+ except Exception as e:
70
+ logger.debug(f"Could not patch {cls_name}.{attr}: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
+ # Some pickles expect certain module-level names (rare) leave safe no-op fallbacks
73
+ try:
74
+ # e.g., older pickles might refer to xgb.core.Booster attributes; skip if not present
75
+ core = getattr(xgb, "core", None)
76
+ if core is not None:
77
+ if not hasattr(core, "DataBatch"):
78
+ setattr(core, "DataBatch", object)
79
+ except Exception:
80
+ pass
81
+
82
+ # Run the patch early
83
  ensure_xgb_sklearn_compat()
84
 
85
  # -------------------------
 
105
 
106
  def try_joblib_load(path):
107
  try:
108
+ # ensure patch immediately before load (handles lazy imports)
109
  ensure_xgb_sklearn_compat()
110
  logger.info(f"Trying joblib.load on {path}")
111
  m = joblib.load(path)