File size: 17,406 Bytes
1cf65c0
3b5d784
610ffb3
291fc57
427c6bb
3b5d784
 
291fc57
 
ad430b5
291fc57
 
 
 
 
 
 
 
 
 
 
 
 
 
427c6bb
3b5d784
51547bb
3b5d784
 
51547bb
 
 
 
3b5d784
51547bb
 
3b5d784
 
 
 
 
51547bb
3b5d784
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51547bb
 
3b5d784
 
51547bb
3b5d784
 
 
 
 
51547bb
3b5d784
 
 
51547bb
3b5d784
 
ad430b5
 
1cf65c0
3b5d784
1cf65c0
 
 
 
175c1d7
 
 
 
 
 
 
 
1cf65c0
175c1d7
 
 
1cf65c0
175c1d7
1cf65c0
175c1d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1cf65c0
3b5d784
 
 
ad430b5
 
3b5d784
427c6bb
 
 
 
 
 
 
 
 
51547bb
427c6bb
 
 
3b5d784
427c6bb
 
 
 
 
 
3b5d784
427c6bb
3b5d784
 
 
427c6bb
 
 
 
 
 
 
 
3b5d784
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
427c6bb
3b5d784
427c6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad430b5
427c6bb
 
 
 
 
 
 
ad430b5
427c6bb
 
 
 
 
 
 
 
3b5d784
427c6bb
291fc57
 
 
 
 
 
3b5d784
291fc57
610ffb3
427c6bb
 
 
 
 
 
 
 
 
3b5d784
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
427c6bb
610ffb3
427c6bb
3f74648
3b5d784
3f74648
291fc57
427c6bb
291fc57
427c6bb
 
 
 
 
 
291fc57
3b5d784
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
427c6bb
 
 
 
291fc57
427c6bb
3b5d784
427c6bb
 
291fc57
 
610ffb3
291fc57
427c6bb
 
 
ad430b5
 
3b5d784
427c6bb
3b5d784
 
 
427c6bb
 
 
 
 
 
 
 
 
 
 
3b5d784
427c6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b5d784
427c6bb
3b5d784
 
427c6bb
 
 
 
 
 
 
 
 
 
3b5d784
427c6bb
 
 
 
 
 
 
 
 
 
 
 
ad430b5
427c6bb
 
 
 
 
 
 
 
 
 
 
 
ad430b5
427c6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad430b5
427c6bb
 
 
 
 
 
 
 
 
 
 
 
 
291fc57
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
# predict_utils.py
# Robust loader with upfront patches + manual-unpickle fallback for sklearn/xgboost compatibility.
import os
import logging
import joblib
import io
import pickle
from huggingface_hub import hf_hub_download

# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face Hub coordinates; overridable via environment for deployment flexibility.
HF_MODEL_REPO = os.getenv("HF_MODEL_REPO", "sathishaiuse/wellness-classifier-model")
HF_MODEL_FILENAME = os.getenv("HF_MODEL_FILENAME", "best_overall_XGBoost.joblib")
# "or None" normalizes an empty-string HF_TOKEN to None (anonymous Hub access).
HF_TOKEN = os.getenv("HF_TOKEN") or None

# Local paths probed (in order) before falling back to a Hub download.
LOCAL_CANDIDATES = [
    os.path.join("/app", HF_MODEL_FILENAME),
    os.path.join("/tmp", HF_MODEL_FILENAME),
    os.path.join("/home/user/app", HF_MODEL_FILENAME),
    HF_MODEL_FILENAME
]

# -------------------------
# Upfront compatibility patches (run at import time)
# -------------------------
def patch_sklearn_base():
    """Make sure BaseEstimator exposes sklearn_tags/_get_tags/_more_tags used during unpickling."""
    try:
        import sklearn
        from sklearn.base import BaseEstimator
    except Exception as e:
        logger.debug(f"sklearn not available to patch: {e}")
        return

    def _sklearn_tags(self):
        return {}

    def _more_tags(self):
        return {}

    def _get_tags(self):
        # Merge whatever tag sources the instance exposes, ignoring failures.
        tags = {}
        more = getattr(self, "_more_tags", None)
        if callable(more):
            try:
                tags.update(more())
            except Exception:
                pass
        st = getattr(self, "sklearn_tags", None)
        if callable(st):
            try:
                tags.update(st())
            except Exception:
                pass
        return tags

    # Install each shim only when the running sklearn version lacks it.
    for method_name, shim in (
        ("sklearn_tags", _sklearn_tags),
        ("_get_tags", _get_tags),
        ("_more_tags", _more_tags),
    ):
        if hasattr(BaseEstimator, method_name):
            continue
        try:
            setattr(BaseEstimator, method_name, shim)
            logger.info(f"Patched BaseEstimator.{method_name}()")
        except Exception as e:
            logger.debug(f"Could not set BaseEstimator.{method_name}: {e}")

def patch_xgboost_wrappers():
    """Add common attributes expected by older pickles to XGBoost classes/base."""
    try:
        import xgboost as xgb
    except Exception as e:
        logger.debug(f"xgboost not available to patch: {e}")
        return

    def _ensure_attrs(cls, cls_name, defaults):
        # Install each default only when the class doesn't already define it.
        for attr, val in defaults.items():
            try:
                if not hasattr(cls, attr):
                    setattr(cls, attr, val)
                    logger.info(f"Patched {cls_name}.{attr} = {val!r}")
            except Exception as e:
                logger.debug(f"Could not patch {cls_name}.{attr}: {e}")

    base_cls = getattr(xgb, "XGBModel", None)
    if base_cls is not None:
        _ensure_attrs(base_cls, "XGBModel", {
            "gpu_id": None,
            "nthread": None,
            "n_jobs": None,
            "predictor": None,
            "base_score": None,
            "objective": None,
        })

    for cls_name in ("XGBClassifier", "XGBRegressor"):
        wrapper_cls = getattr(xgb, cls_name, None)
        if wrapper_cls is not None:
            _ensure_attrs(wrapper_cls, cls_name, {
                "use_label_encoder": False,
                "objective": None,
                "predictor": None,
                "gpu_id": None,
                "n_jobs": None,
                "nthread": None,
            })

# Apply upfront patches
# Run both compatibility shims at import time so any unpickling triggered
# later in this process already sees the patched classes.
patch_sklearn_base()
patch_xgboost_wrappers()

# -------------------------
# Helpers: inspect file & try loaders
# -------------------------
def inspect_file(path):
    """Collect lightweight diagnostics for *path*: existence, size, and the first 2 KiB.

    Never raises; failures are recorded under the "inspect_error" key.
    """
    report = {"path": path, "exists": False}
    try:
        report["exists"] = os.path.exists(path)
        if not report["exists"]:
            return report
        report["size"] = os.path.getsize(path)
        with open(path, "rb") as fh:
            prefix = fh.read(2048)
        report["head_bytes"] = prefix
        try:
            report["head_text"] = prefix.decode("utf-8", errors="replace")
        except Exception:
            report["head_text"] = None
    except Exception as exc:
        report["inspect_error"] = str(exc)
    return report

def try_joblib_load(path):
    """Try standard joblib load. Return ("joblib", model) or ("joblib", exception)"""
    try:
        # Re-apply patches immediately before load (cover lazy imports)
        patch_sklearn_base()
        patch_xgboost_wrappers()
        logger.info(f"Trying joblib.load on {path}")
        loaded = joblib.load(path)
        logger.info("joblib.load succeeded")
        return ("joblib", loaded)
    except Exception as exc:
        logger.exception(f"joblib.load failed: {exc}")
        return ("joblib", exc)

def manual_pickle_unpickle(path):
    """
    Last-resort: attempt to unpickle the raw file bytes with a custom Unpickler
    that maps pickled references of sklearn base classes to the live patched classes.
    This may succeed when joblib.load fails due to base-class method mismatches.

    Returns a ("manual_pickle", obj_or_exception) tuple mirroring the other loaders.
    """
    try:
        # Fix: read via a context manager so the descriptor is always closed
        # (the previous open(path, "rb").read() leaked the file handle).
        with open(path, "rb") as fh:
            data = fh.read()
    except Exception as e:
        return ("manual_pickle", e)

    class PatchedUnpickler(pickle.Unpickler):
        def find_class(self, module, name):
            # If pickle references sklearn.base.BaseEstimator, return the live patched class
            if module.startswith("sklearn.") and name in ("BaseEstimator",):
                try:
                    from sklearn.base import BaseEstimator as LiveBase
                    # ensure our patches are present
                    try:
                        if not hasattr(LiveBase, "sklearn_tags"):
                            def _sklearn_tags(self): return {}
                            setattr(LiveBase, "sklearn_tags", _sklearn_tags)
                    except Exception:
                        pass
                    return LiveBase
                except Exception:
                    pass
            # For xgboost wrappers, map to live classes if referenced
            if module.startswith("xgboost.") and name in ("XGBClassifier", "XGBRegressor", "XGBModel"):
                try:
                    import xgboost as xgb
                    cls = getattr(xgb, name, None)
                    if cls is not None:
                        return cls
                except Exception:
                    pass
            return super().find_class(module, name)

    try:
        obj = PatchedUnpickler(io.BytesIO(data)).load()
        return ("manual_pickle", obj)
    except Exception as e:
        return ("manual_pickle", e)

def try_xgboost_booster(path):
    """Try loading file as a native xgboost booster (json/bin)"""
    try:
        import xgboost as xgb
    except Exception as e:
        logger.exception(f"xgboost import failed: {e}")
        return ("xgboost_import", e)

    try:
        logger.info(f"Trying xgboost.Booster().load_model on {path}")
        booster = xgb.Booster()
        booster.load_model(path)
        logger.info("xgboost.Booster.load_model succeeded")
    except Exception as e:
        logger.exception(f"xgboost.Booster.load_model failed: {e}")
        return ("xgboost_booster", e)

    class BoosterWrapper:
        """Duck-typed sklearn-style facade over a raw Booster."""
        def __init__(self, booster):
            self.booster = booster
            self._is_xgb_booster = True  # marker inspected by predict()

        def _raw_scores(self, X):
            # Coerce any array-like into a DMatrix and return (numpy, scores).
            import numpy as _np, xgboost as _xgb
            matrix = _xgb.DMatrix(_np.array(X, dtype=float))
            return _np, self.booster.predict(matrix)

        def predict(self, X):
            _np, scores = self._raw_scores(X)
            if hasattr(scores, "ndim") and scores.ndim == 1:
                # 1-D output = binary probabilities; threshold at 0.5.
                return (_np.where(scores >= 0.5, 1, 0)).tolist()
            return scores.tolist()

        def predict_proba(self, X):
            _np, scores = self._raw_scores(X)
            if hasattr(scores, "ndim") and scores.ndim == 1:
                # Expand positive-class scores to [P(class 0), P(class 1)] columns.
                return (_np.vstack([1 - scores, scores]).T).tolist()
            return scores.tolist()

    return ("xgboost_booster", BoosterWrapper(booster))

# -------------------------
# Main loader: try local -> try HF -> fallbacks
# -------------------------
def load_model():
    """
    Locate and deserialize the classifier model.

    Order of attempts:
      1. Each path in LOCAL_CANDIDATES, running the loader chain
         (joblib -> manual unpickle on sklearn_tags failures -> native booster).
      2. Download from the Hugging Face Hub, then the same loader chain.

    Returns the loaded model object (possibly a BoosterWrapper), or None when
    every strategy fails.
    """
    logger.info("==== MODEL LOAD START ====")
    logger.info(f"Repo: {HF_MODEL_REPO}")
    logger.info(f"Filename: {HF_MODEL_FILENAME}")
    logger.info(f"HF_TOKEN present? {bool(HF_TOKEN)}")

    # try local candidates first (no network needed)
    for path in LOCAL_CANDIDATES:
        try:
            info = inspect_file(path)
            logger.info(f"Inspecting local candidate: {info}")
            if not info.get("exists"):
                continue
            model = _load_from_path(path)
            if model is not None:
                return model
        except Exception as e:
            logger.exception(f"Error while trying local candidate {path}: {e}")

    # fall back to the Hugging Face Hub
    try:
        logger.info(f"Trying hf_hub_download from {HF_MODEL_REPO}/{HF_MODEL_FILENAME}")
        model_path = hf_hub_download(repo_id=HF_MODEL_REPO, filename=HF_MODEL_FILENAME, token=HF_TOKEN)
        logger.info(f"Downloaded model to: {model_path}")
        info = inspect_file(model_path)
        logger.info(f"Inspecting downloaded file: {info}")

        model = _load_from_path(model_path)
        if model is not None:
            return model

        logger.error("Tried joblib/manual-unpickle and xgboost loader on downloaded file but all failed.")
        return None
    except Exception as e:
        logger.exception(f"hf_hub_download failed: {e}")
        return None


def _load_from_path(path):
    """Run the full loader chain (joblib -> manual unpickle -> booster) on one file; None on failure."""
    t, res = try_joblib_load(path)
    if t == "joblib" and not isinstance(res, Exception):
        return res

    # joblib failed; a "sklearn_tags" failure suggests an sklearn version
    # mismatch, which the manual-unpickle fallback was built to work around.
    if isinstance(res, Exception):
        msg = str(res)
        if "sklearn_tags" in msg or "sklearn_tags" in getattr(res, "args", ()):
            logger.info("joblib.load failed with sklearn_tags; trying manual pickle unpickle fallback")
            tm, obj = manual_pickle_unpickle(path)
            if tm == "manual_pickle" and not isinstance(obj, Exception):
                logger.info("manual unpickle succeeded")
                return obj
            logger.error("manual unpickle did not succeed; continuing to other fallbacks")

    # finally, try interpreting the file as a native xgboost booster
    t2, res2 = try_xgboost_booster(path)
    if t2 == "xgboost_booster" and not isinstance(res2, Exception):
        return res2
    return None

# -------------------------
# Prediction helper: accepts dict (col->val), list, or list-of-lists
# -------------------------
def predict(model, features):
    """
    Run inference with *model* on *features*.

    features may be:
      * dict (column -> value)   -> single row, key order preserved as columns
      * flat list/tuple          -> single numeric row
      * list of lists/tuples     -> batch of numeric rows

    Returns {"prediction": ..., "probability": ...} on success, or
    {"error": ...} on failure / unsupported input.
    """
    if model is None:
        return {"error": "Model not loaded"}

    try:
        import pandas as _pd
        import numpy as _np

        # BoosterWrapper instances are tagged by try_xgboost_booster().
        is_booster = hasattr(model, "_is_xgb_booster")

        def _safe_float(getter):
            # Probability extraction is best-effort; malformed output -> None.
            try:
                return float(getter())
            except Exception:
                return None

        # dict -> DataFrame (preserve key order)
        if isinstance(features, dict):
            cols = [str(k) for k in features.keys()]
            row = [features[k] for k in features.keys()]
            df = _pd.DataFrame([row], columns=cols)

            if is_booster:
                arr = df.values.astype(float)
                preds = model.predict(arr)
                prob = None
                if hasattr(model, "predict_proba"):
                    p = model.predict_proba(arr)
                    prob = _safe_float(lambda: p[0][1])
                return {"prediction": int(preds[0]) if isinstance(preds, (list, tuple)) else int(preds), "probability": prob}

            if hasattr(model, "predict"):
                pred = model.predict(df)[0]
                prob = None
                if hasattr(model, "predict_proba"):
                    p = model.predict_proba(df)[0]
                    prob = _safe_float(lambda: max(p))
                try:
                    pred = int(pred)
                except Exception:
                    pass
                return {"prediction": pred, "probability": prob}

            return {"error": "Loaded model object not recognized"}

        # Bug fix: test for a batch (list of rows) BEFORE the flat-list case.
        # Previously the flat-list branch consumed every list, so a list of
        # lists was coerced into one malformed 3-D "row" and the batch branch
        # below was unreachable.
        if isinstance(features, (list, tuple)) and len(features) > 0 and isinstance(features[0], (list, tuple)):
            arr = _np.array(features, dtype=float)
            if is_booster:
                preds = model.predict(arr)
                prob = None
                if hasattr(model, "predict_proba"):
                    p = model.predict_proba(arr)
                    prob = _safe_float(lambda: p[0][1])
                # BoosterWrapper already returns a list; ndarray needs .tolist().
                return {"prediction": preds if isinstance(preds, list) else preds.tolist(), "probability": prob}
            if hasattr(model, "predict"):
                try:
                    pred = model.predict(arr)
                    prob = None
                    if hasattr(model, "predict_proba"):
                        p = model.predict_proba(arr)
                        prob = _safe_float(lambda: max(p[0]))
                    return {"prediction": pred.tolist(), "probability": prob}
                except Exception:
                    # Some pipelines require named columns; retry as DataFrame.
                    cols = [str(i) for i in range(arr.shape[1])]
                    df = _pd.DataFrame(arr, columns=cols)
                    pred = model.predict(df)
                    prob = None
                    if hasattr(model, "predict_proba"):
                        p = model.predict_proba(df)
                        prob = _safe_float(lambda: max(p[0]))
                    return {"prediction": pred.tolist(), "probability": prob}

        # flat list/tuple -> single numeric row
        if isinstance(features, (list, tuple)):
            arr2d = _np.array([features], dtype=float)
            if is_booster:
                preds = model.predict(arr2d)
                prob = None
                if hasattr(model, "predict_proba"):
                    p = model.predict_proba(arr2d)
                    prob = _safe_float(lambda: p[0][1])
                return {"prediction": int(preds[0]), "probability": prob}

            if hasattr(model, "predict"):
                try:
                    pred = model.predict(arr2d)[0]
                    prob = None
                    if hasattr(model, "predict_proba"):
                        p = model.predict_proba(arr2d)[0]
                        prob = _safe_float(lambda: max(p))
                    return {"prediction": pred, "probability": prob}
                except Exception:
                    # Retry with named columns for column-sensitive pipelines.
                    cols = [str(i) for i in range(arr2d.shape[1])]
                    df = _pd.DataFrame(arr2d, columns=cols)
                    pred = model.predict(df)[0]
                    prob = None
                    if hasattr(model, "predict_proba"):
                        p = model.predict_proba(df)[0]
                        prob = _safe_float(lambda: max(p))
                    return {"prediction": pred, "probability": prob}

        return {"error": "Unsupported features format. Provide dict (col->val) or list of values."}
    except Exception as e:
        logger.exception(f"Prediction error: {e}")
        return {"error": str(e)}