sathishaiuse committed on
Commit
3b5d784
·
verified ·
1 Parent(s): 51547bb

Update predict_utils.py

Browse files
Files changed (1) hide show
  1. predict_utils.py +148 -100
predict_utils.py CHANGED
@@ -1,8 +1,10 @@
1
  # predict_utils.py
2
- # Robust loader with extended monkey-patches for XGBoost and scikit-learn compatibility.
3
  import os
4
  import logging
5
  import joblib
 
 
6
  from huggingface_hub import hf_hub_download
7
 
8
  # Logging
@@ -21,76 +23,68 @@ LOCAL_CANDIDATES = [
21
  ]
22
 
23
  # -------------------------
24
- # Monkey-patch scikit-learn to add missing tags/APIs used during unpickling
25
  # -------------------------
26
- def ensure_sklearn_compat():
 
27
  try:
28
  import sklearn
29
  from sklearn.base import BaseEstimator
30
  except Exception as e:
31
- logger.debug(f"scikit-learn not importable for patching: {e}")
32
  return
33
 
34
- # If older/newer pickles expect 'sklearn_tags' attribute/method on BaseEstimator, provide it.
35
- try:
36
- if not hasattr(BaseEstimator, "sklearn_tags"):
37
- # Provide a method attribute that returns an empty dict by default.
38
- def _sklearn_tags(self):
39
- # Estimators can override by defining sklearn_tags attribute at instance/class level.
40
- return {}
41
  setattr(BaseEstimator, "sklearn_tags", _sklearn_tags)
42
- logger.info("Patched BaseEstimator.sklearn_tags() -> default {}")
43
- except Exception as e:
44
- logger.debug(f"Could not patch BaseEstimator.sklearn_tags: {e}")
45
-
46
- # Ensure _get_tags exists (some older/newer flows call this)
47
- try:
48
- if not hasattr(BaseEstimator, "_get_tags"):
49
- def _get_tags(self):
50
- # If estimator defines _more_tags, call it to merge tags; otherwise use sklearn_tags if present.
51
- tags = {}
52
- # _more_tags (newer style)
53
- more = getattr(self, "_more_tags", None)
54
- if callable(more):
55
- try:
56
- tags.update(more())
57
- except Exception:
58
- pass
59
- # fallback to sklearn_tags method if present
60
- st = getattr(self, "sklearn_tags", None)
61
- if callable(st):
62
- try:
63
- tags.update(st())
64
- except Exception:
65
- pass
66
- return tags
67
  setattr(BaseEstimator, "_get_tags", _get_tags)
68
  logger.info("Patched BaseEstimator._get_tags()")
69
- except Exception as e:
70
- logger.debug(f"Could not patch BaseEstimator._get_tags: {e}")
71
 
72
- # Provide a safe _more_tags no-op if missing on class-level to avoid AttributeError
73
- try:
74
- if not hasattr(BaseEstimator, "_more_tags"):
75
- def _more_tags(self):
76
- return {}
77
  setattr(BaseEstimator, "_more_tags", _more_tags)
78
- logger.info("Patched BaseEstimator._more_tags() -> default {}")
79
- except Exception as e:
80
- logger.debug(f"Could not patch BaseEstimator._more_tags: {e}")
81
 
82
- # -------------------------
83
- # Monkey-patch xgboost sklearn wrappers & base class to add missing attributes.
84
- # Handles 'use_label_encoder', 'gpu_id', 'predictor', etc.
85
- # -------------------------
86
- def ensure_xgb_sklearn_compat():
87
  try:
88
  import xgboost as xgb
89
  except Exception as e:
90
- logger.debug(f"xgboost not importable for patching: {e}")
91
  return
92
 
93
- # Base class: XGBModel (add common attrs)
94
  XGBModel = getattr(xgb, "XGBModel", None)
95
  if XGBModel is not None:
96
  for attr, val in {
@@ -108,7 +102,6 @@ def ensure_xgb_sklearn_compat():
108
  except Exception as e:
109
  logger.debug(f"Could not patch XGBModel.{attr}: {e}")
110
 
111
- # XGBClassifier and XGBRegressor class-level defaults
112
  for cls_name in ("XGBClassifier", "XGBRegressor"):
113
  cls = getattr(xgb, cls_name, None)
114
  if cls is not None:
@@ -127,14 +120,12 @@ def ensure_xgb_sklearn_compat():
127
  except Exception as e:
128
  logger.debug(f"Could not patch {cls_name}.{attr}: {e}")
129
 
130
- # -------------------------
131
- # Call compatibility patches early so joblib.load has them available
132
- # -------------------------
133
- ensure_sklearn_compat()
134
- ensure_xgb_sklearn_compat()
135
 
136
  # -------------------------
137
- # Helpers: file inspection and loader attempts
138
  # -------------------------
139
  def inspect_file(path):
140
  info = {"path": path, "exists": False}
@@ -148,17 +139,18 @@ def inspect_file(path):
148
  info["head_bytes"] = head
149
  try:
150
  info["head_text"] = head.decode("utf-8", errors="replace")
151
- except:
152
  info["head_text"] = None
153
  except Exception as e:
154
  info["inspect_error"] = str(e)
155
  return info
156
 
157
  def try_joblib_load(path):
 
158
  try:
159
- # ensure patches just before load (in case of lazy imports)
160
- ensure_sklearn_compat()
161
- ensure_xgb_sklearn_compat()
162
  logger.info(f"Trying joblib.load on {path}")
163
  m = joblib.load(path)
164
  logger.info("joblib.load succeeded")
@@ -167,7 +159,54 @@ def try_joblib_load(path):
167
  logger.exception(f"joblib.load failed: {e}")
168
  return ("joblib", e)
169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  def try_xgboost_booster(path):
 
171
  try:
172
  import xgboost as xgb
173
  except Exception as e:
@@ -179,12 +218,10 @@ def try_xgboost_booster(path):
179
  booster = xgb.Booster()
180
  booster.load_model(path)
181
  logger.info("xgboost.Booster.load_model succeeded")
182
-
183
  class BoosterWrapper:
184
  def __init__(self, booster):
185
  self.booster = booster
186
  self._is_xgb_booster = True
187
-
188
  def predict(self, X):
189
  import numpy as _np, xgboost as _xgb
190
  arr = _np.array(X, dtype=float)
@@ -193,7 +230,6 @@ def try_xgboost_booster(path):
193
  if hasattr(pred, "ndim") and pred.ndim == 1:
194
  return (_np.where(pred >= 0.5, 1, 0)).tolist()
195
  return pred.tolist()
196
-
197
  def predict_proba(self, X):
198
  import numpy as _np, xgboost as _xgb
199
  arr = _np.array(X, dtype=float)
@@ -202,14 +238,13 @@ def try_xgboost_booster(path):
202
  if hasattr(pred, "ndim") and pred.ndim == 1:
203
  return (_np.vstack([1 - pred, pred]).T).tolist()
204
  return pred.tolist()
205
-
206
  return ("xgboost_booster", BoosterWrapper(booster))
207
  except Exception as e:
208
  logger.exception(f"xgboost.Booster.load_model failed: {e}")
209
  return ("xgboost_booster", e)
210
 
211
  # -------------------------
212
- # Core loader
213
  # -------------------------
214
  def load_model():
215
  logger.info("==== MODEL LOAD START ====")
@@ -217,7 +252,7 @@ def load_model():
217
  logger.info(f"Filename: {HF_MODEL_FILENAME}")
218
  logger.info(f"HF_TOKEN present? {bool(HF_TOKEN)}")
219
 
220
- # Try local candidate paths
221
  for path in LOCAL_CANDIDATES:
222
  try:
223
  info = inspect_file(path)
@@ -229,14 +264,27 @@ def load_model():
229
  if t == "joblib" and not isinstance(res, Exception):
230
  return res
231
 
232
- t, res = try_xgboost_booster(path)
233
- if t == "xgboost_booster" and not isinstance(res, Exception):
234
- return res
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
  except Exception as e:
237
  logger.exception(f"Error while trying local candidate {path}: {e}")
238
 
239
- # Try huggingface hub download
240
  try:
241
  logger.info(f"Trying hf_hub_download from {HF_MODEL_REPO}/{HF_MODEL_FILENAME}")
242
  model_path = hf_hub_download(repo_id=HF_MODEL_REPO, filename=HF_MODEL_FILENAME, token=HF_TOKEN)
@@ -248,27 +296,31 @@ def load_model():
248
  if t == "joblib" and not isinstance(res, Exception):
249
  return res
250
 
251
- t, res = try_xgboost_booster(model_path)
252
- if t == "xgboost_booster" and not isinstance(res, Exception):
253
- return res
254
-
255
- logger.error("Tried joblib and xgboost loader on downloaded file but both failed.")
 
 
 
 
 
 
 
 
 
 
 
256
  return None
257
  except Exception as e:
258
  logger.exception(f"hf_hub_download failed: {e}")
259
  return None
260
 
261
  # -------------------------
262
- # Robust predict
263
  # -------------------------
264
  def predict(model, features):
265
- """
266
- Accepts:
267
- - dict (col_name -> value) -> builds a single-row pandas.DataFrame preserving key order
268
- - list/tuple -> single row (numeric)
269
- - list-of-lists -> batch
270
- Returns: {"prediction": ..., "probability": ...} or {"error": "..."}
271
- """
272
  if model is None:
273
  return {"error": "Model not loaded"}
274
 
@@ -278,12 +330,11 @@ def predict(model, features):
278
 
279
  is_booster = hasattr(model, "_is_xgb_booster")
280
 
281
- # dict -> DataFrame
282
  if isinstance(features, dict):
283
- col_names = [str(k) for k in features.keys()]
284
- row_values = [features[k] for k in features.keys()]
285
- df = _pd.DataFrame([row_values], columns=col_names)
286
- logger.info(f"Prepared DataFrame for prediction with columns: {col_names}")
287
 
288
  if is_booster:
289
  arr = df.values.astype(float)
@@ -295,8 +346,7 @@ def predict(model, features):
295
  prob = float(p[0][1])
296
  except:
297
  prob = None
298
- pred_val = int(preds[0]) if isinstance(preds, (list, tuple)) else int(preds)
299
- return {"prediction": pred_val, "probability": prob}
300
 
301
  if hasattr(model, "predict"):
302
  pred = model.predict(df)[0]
@@ -313,10 +363,10 @@ def predict(model, features):
313
  pass
314
  return {"prediction": pred, "probability": prob}
315
 
316
- return {"error": "Loaded model object not recognized (no predict method)"}
317
 
318
- # list/tuple single row
319
- if isinstance(features, (list, tuple)):
320
  arr2d = _np.array([features], dtype=float)
321
  if is_booster:
322
  preds = model.predict(arr2d)
@@ -327,8 +377,7 @@ def predict(model, features):
327
  prob = float(p[0][1])
328
  except:
329
  prob = None
330
- pred_val = int(preds[0]) if isinstance(preds, (list, tuple)) else int(preds)
331
- return {"prediction": pred_val, "probability": prob}
332
 
333
  if hasattr(model, "predict"):
334
  try:
@@ -392,7 +441,6 @@ def predict(model, features):
392
  return {"prediction": pred.tolist(), "probability": prob}
393
 
394
  return {"error": "Unsupported features format. Provide dict (col->val) or list of values."}
395
-
396
  except Exception as e:
397
  logger.exception(f"Prediction error: {e}")
398
  return {"error": str(e)}
 
1
  # predict_utils.py
2
+ # Robust loader with upfront patches + manual-unpickle fallback for sklearn/xgboost compatibility.
3
  import os
4
  import logging
5
  import joblib
6
+ import io
7
+ import pickle
8
  from huggingface_hub import hf_hub_download
9
 
10
  # Logging
 
23
  ]
24
 
25
  # -------------------------
26
+ # Upfront compatibility patches (run at import time)
27
  # -------------------------
28
+ def patch_sklearn_base():
29
+ """Make sure BaseEstimator exposes sklearn_tags/_get_tags/_more_tags used during unpickling."""
30
  try:
31
  import sklearn
32
  from sklearn.base import BaseEstimator
33
  except Exception as e:
34
+ logger.debug(f"sklearn not available to patch: {e}")
35
  return
36
 
37
+ # Provide sklearn_tags method if missing
38
+ if not hasattr(BaseEstimator, "sklearn_tags"):
39
+ def _sklearn_tags(self):
40
+ return {}
41
+ try:
 
 
42
  setattr(BaseEstimator, "sklearn_tags", _sklearn_tags)
43
+ logger.info("Patched BaseEstimator.sklearn_tags()")
44
+ except Exception as e:
45
+ logger.debug(f"Could not set BaseEstimator.sklearn_tags: {e}")
46
+
47
+ # Provide _get_tags if missing
48
+ if not hasattr(BaseEstimator, "_get_tags"):
49
+ def _get_tags(self):
50
+ tags = {}
51
+ more = getattr(self, "_more_tags", None)
52
+ if callable(more):
53
+ try:
54
+ tags.update(more())
55
+ except Exception:
56
+ pass
57
+ st = getattr(self, "sklearn_tags", None)
58
+ if callable(st):
59
+ try:
60
+ tags.update(st())
61
+ except Exception:
62
+ pass
63
+ return tags
64
+ try:
 
 
 
65
  setattr(BaseEstimator, "_get_tags", _get_tags)
66
  logger.info("Patched BaseEstimator._get_tags()")
67
+ except Exception as e:
68
+ logger.debug(f"Could not set BaseEstimator._get_tags: {e}")
69
 
70
+ # Provide a default _more_tags if missing
71
+ if not hasattr(BaseEstimator, "_more_tags"):
72
+ def _more_tags(self):
73
+ return {}
74
+ try:
75
  setattr(BaseEstimator, "_more_tags", _more_tags)
76
+ logger.info("Patched BaseEstimator._more_tags()")
77
+ except Exception as e:
78
+ logger.debug(f"Could not set BaseEstimator._more_tags: {e}")
79
 
80
+ def patch_xgboost_wrappers():
81
+ """Add common attributes expected by older pickles to XGBoost classes/base."""
 
 
 
82
  try:
83
  import xgboost as xgb
84
  except Exception as e:
85
+ logger.debug(f"xgboost not available to patch: {e}")
86
  return
87
 
 
88
  XGBModel = getattr(xgb, "XGBModel", None)
89
  if XGBModel is not None:
90
  for attr, val in {
 
102
  except Exception as e:
103
  logger.debug(f"Could not patch XGBModel.{attr}: {e}")
104
 
 
105
  for cls_name in ("XGBClassifier", "XGBRegressor"):
106
  cls = getattr(xgb, cls_name, None)
107
  if cls is not None:
 
120
  except Exception as e:
121
  logger.debug(f"Could not patch {cls_name}.{attr}: {e}")
122
 
123
+ # Apply upfront patches
124
+ patch_sklearn_base()
125
+ patch_xgboost_wrappers()
 
 
126
 
127
  # -------------------------
128
+ # Helpers: inspect file & try loaders
129
  # -------------------------
130
  def inspect_file(path):
131
  info = {"path": path, "exists": False}
 
139
  info["head_bytes"] = head
140
  try:
141
  info["head_text"] = head.decode("utf-8", errors="replace")
142
+ except Exception:
143
  info["head_text"] = None
144
  except Exception as e:
145
  info["inspect_error"] = str(e)
146
  return info
147
 
148
  def try_joblib_load(path):
149
+ """Try standard joblib load. Return ("joblib", model) or ("joblib", exception)"""
150
  try:
151
+ # Re-apply patches immediately before load (cover lazy imports)
152
+ patch_sklearn_base()
153
+ patch_xgboost_wrappers()
154
  logger.info(f"Trying joblib.load on {path}")
155
  m = joblib.load(path)
156
  logger.info("joblib.load succeeded")
 
159
  logger.exception(f"joblib.load failed: {e}")
160
  return ("joblib", e)
161
 
162
def manual_pickle_unpickle(path):
    """
    Last-resort loader: unpickle the raw file bytes with a custom Unpickler
    that redirects pickled references to sklearn/xgboost classes onto the
    live (patched) classes in the current environment.

    This can succeed when ``joblib.load`` fails due to base-class method
    mismatches between the pickling and unpickling library versions.

    Returns:
        ("manual_pickle", obj) on success, or ("manual_pickle", exception)
        on any failure (unreadable file, unpickle error, ...).

    NOTE(review): this bypasses joblib's own deserialization and assumes the
    file is a plain pickle stream — confirm for joblib files that use
    compression or memory mapping. Also, unpickling executes arbitrary code;
    only call this on model files from a trusted repo.
    """
    try:
        # Read the whole file up front; use a context manager so the file
        # handle is always closed (the previous version leaked it).
        with open(path, "rb") as fh:
            data = fh.read()
    except Exception as e:
        return ("manual_pickle", e)

    class PatchedUnpickler(pickle.Unpickler):
        def find_class(self, module, name):
            # Map pickled sklearn.base.BaseEstimator references to the live class.
            if module.startswith("sklearn.") and name == "BaseEstimator":
                try:
                    from sklearn.base import BaseEstimator as LiveBase
                    # Make sure the compatibility shim is present on the live class.
                    try:
                        if not hasattr(LiveBase, "sklearn_tags"):
                            def _sklearn_tags(self):
                                return {}
                            setattr(LiveBase, "sklearn_tags", _sklearn_tags)
                    except Exception:
                        pass
                    return LiveBase
                except Exception:
                    pass
            # Map pickled xgboost wrapper references to the live classes.
            if module.startswith("xgboost.") and name in ("XGBClassifier", "XGBRegressor", "XGBModel"):
                try:
                    import xgboost as xgb
                    cls = getattr(xgb, name, None)
                    if cls is not None:
                        return cls
                except Exception:
                    pass
            # Everything else resolves normally.
            return super().find_class(module, name)

    try:
        obj = PatchedUnpickler(io.BytesIO(data)).load()
        return ("manual_pickle", obj)
    except Exception as e:
        return ("manual_pickle", e)
208
  def try_xgboost_booster(path):
209
+ """Try loading file as a native xgboost booster (json/bin)"""
210
  try:
211
  import xgboost as xgb
212
  except Exception as e:
 
218
  booster = xgb.Booster()
219
  booster.load_model(path)
220
  logger.info("xgboost.Booster.load_model succeeded")
 
221
  class BoosterWrapper:
222
  def __init__(self, booster):
223
  self.booster = booster
224
  self._is_xgb_booster = True
 
225
  def predict(self, X):
226
  import numpy as _np, xgboost as _xgb
227
  arr = _np.array(X, dtype=float)
 
230
  if hasattr(pred, "ndim") and pred.ndim == 1:
231
  return (_np.where(pred >= 0.5, 1, 0)).tolist()
232
  return pred.tolist()
 
233
  def predict_proba(self, X):
234
  import numpy as _np, xgboost as _xgb
235
  arr = _np.array(X, dtype=float)
 
238
  if hasattr(pred, "ndim") and pred.ndim == 1:
239
  return (_np.vstack([1 - pred, pred]).T).tolist()
240
  return pred.tolist()
 
241
  return ("xgboost_booster", BoosterWrapper(booster))
242
  except Exception as e:
243
  logger.exception(f"xgboost.Booster.load_model failed: {e}")
244
  return ("xgboost_booster", e)
245
 
246
  # -------------------------
247
+ # Main loader: try local -> try HF -> fallbacks
248
  # -------------------------
249
  def load_model():
250
  logger.info("==== MODEL LOAD START ====")
 
252
  logger.info(f"Filename: {HF_MODEL_FILENAME}")
253
  logger.info(f"HF_TOKEN present? {bool(HF_TOKEN)}")
254
 
255
+ # try local candidates
256
  for path in LOCAL_CANDIDATES:
257
  try:
258
  info = inspect_file(path)
 
264
  if t == "joblib" and not isinstance(res, Exception):
265
  return res
266
 
267
+ # if joblib failed with sklearn_tags error, attempt manual unpickle
268
+ if t == "joblib" and isinstance(res, Exception):
269
+ msg = str(res)
270
+ if "sklearn_tags" in msg or "sklearn_tags" in getattr(res, "args", ()):
271
+ logger.info("joblib.load failed with sklearn_tags; trying manual pickle unpickle fallback")
272
+ tm, obj = manual_pickle_unpickle(path)
273
+ if tm == "manual_pickle" and not isinstance(obj, Exception):
274
+ logger.info("manual unpickle succeeded")
275
+ return obj
276
+ else:
277
+ logger.error("manual unpickle did not succeed; continuing to other fallbacks")
278
+
279
+ # try native booster
280
+ t2, res2 = try_xgboost_booster(path)
281
+ if t2 == "xgboost_booster" and not isinstance(res2, Exception):
282
+ return res2
283
 
284
  except Exception as e:
285
  logger.exception(f"Error while trying local candidate {path}: {e}")
286
 
287
+ # try huggingface hub
288
  try:
289
  logger.info(f"Trying hf_hub_download from {HF_MODEL_REPO}/{HF_MODEL_FILENAME}")
290
  model_path = hf_hub_download(repo_id=HF_MODEL_REPO, filename=HF_MODEL_FILENAME, token=HF_TOKEN)
 
296
  if t == "joblib" and not isinstance(res, Exception):
297
  return res
298
 
299
+ if t == "joblib" and isinstance(res, Exception):
300
+ msg = str(res)
301
+ if "sklearn_tags" in msg or "sklearn_tags" in getattr(res, "args", ()):
302
+ logger.info("joblib.load failed on downloaded file with sklearn_tags; trying manual unpickle fallback")
303
+ tm, obj = manual_pickle_unpickle(model_path)
304
+ if tm == "manual_pickle" and not isinstance(obj, Exception):
305
+ logger.info("manual unpickle succeeded on downloaded file")
306
+ return obj
307
+ else:
308
+ logger.error("manual unpickle did not succeed on downloaded file")
309
+
310
+ t2, res2 = try_xgboost_booster(model_path)
311
+ if t2 == "xgboost_booster" and not isinstance(res2, Exception):
312
+ return res2
313
+
314
+ logger.error("Tried joblib/manual-unpickle and xgboost loader on downloaded file but all failed.")
315
  return None
316
  except Exception as e:
317
  logger.exception(f"hf_hub_download failed: {e}")
318
  return None
319
 
320
  # -------------------------
321
+ # Prediction helper: accepts dict (col->val), list, or list-of-lists
322
  # -------------------------
323
  def predict(model, features):
 
 
 
 
 
 
 
324
  if model is None:
325
  return {"error": "Model not loaded"}
326
 
 
330
 
331
  is_booster = hasattr(model, "_is_xgb_booster")
332
 
333
+ # dict -> DataFrame (preserve key order)
334
  if isinstance(features, dict):
335
+ cols = [str(k) for k in features.keys()]
336
+ row = [features[k] for k in features.keys()]
337
+ df = _pd.DataFrame([row], columns=cols)
 
338
 
339
  if is_booster:
340
  arr = df.values.astype(float)
 
346
  prob = float(p[0][1])
347
  except:
348
  prob = None
349
+ return {"prediction": int(preds[0]) if isinstance(preds, (list,tuple)) else int(preds), "probability": prob}
 
350
 
351
  if hasattr(model, "predict"):
352
  pred = model.predict(df)[0]
 
363
  pass
364
  return {"prediction": pred, "probability": prob}
365
 
366
+ return {"error": "Loaded model object not recognized"}
367
 
368
+ # list -> single row numeric
369
+ if isinstance(features, (list,tuple)):
370
  arr2d = _np.array([features], dtype=float)
371
  if is_booster:
372
  preds = model.predict(arr2d)
 
377
  prob = float(p[0][1])
378
  except:
379
  prob = None
380
+ return {"prediction": int(preds[0]), "probability": prob}
 
381
 
382
  if hasattr(model, "predict"):
383
  try:
 
441
  return {"prediction": pred.tolist(), "probability": prob}
442
 
443
  return {"error": "Unsupported features format. Provide dict (col->val) or list of values."}
 
444
  except Exception as e:
445
  logger.exception(f"Prediction error: {e}")
446
  return {"error": str(e)}