ashaddams commited on
Commit
a5206fc
·
verified ·
1 Parent(s): 5117c71

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -36
app.py CHANGED
@@ -30,12 +30,10 @@ import xgboost as xgb
30
  import lightgbm as lgb
31
  from catboost import CatBoostRegressor
32
  import tensorflow as tf
33
- # ---------- universal sklearn_tags patcher ----------
 
34
  def _safe_sklearn_tags(self):
35
- """
36
- Return sklearn tags without relying on super().sklearn_tags().
37
- Uses get_tags() when available and falls back to {} otherwise.
38
- """
39
  try:
40
  if hasattr(self, "get_tags"):
41
  return self.get_tags()
@@ -43,34 +41,52 @@ def _safe_sklearn_tags(self):
43
  pass
44
  return {}
45
 
46
- def ensure_sklearn_tags_on_mro(est):
47
- """Attach a safe sklearn_tags() to every class in the estimator's MRO
48
- that lacks it (or whose implementation fails)."""
49
- try:
50
- mro = getattr(est.__class__, "mro", lambda: [])()
51
- except Exception:
52
- mro = []
53
- for cls in mro:
54
- if cls is object:
55
  continue
56
- needs_patch = not hasattr(cls, "sklearn_tags") or not callable(getattr(cls, "sklearn_tags"))
57
- if not needs_patch:
58
- try:
59
- # dry-run; if it errors, we’ll patch
60
- getattr(est, "sklearn_tags")()
61
- continue
62
- except Exception:
63
- needs_patch = True
64
- if needs_patch:
65
- try:
66
- setattr(cls, "sklearn_tags", _safe_sklearn_tags)
67
- except Exception:
68
  try:
69
- # fallback: instance-level bind
70
- setattr(est, "sklearn_tags", _safe_sklearn_tags.__get__(est, est.__class__))
71
- except Exception:
72
  pass
73
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  # -----------------------------
76
  # Paths (relative in a Space)
@@ -626,10 +642,12 @@ _ENSEMBLES: dict[str, EnsembleBundle] = {}
626
  def _load_ensemble(target: str) -> EnsembleBundle:
627
  if target in _ENSEMBLES:
628
  return _ENSEMBLES[target]
 
629
  base = MODEL_DIR / target
630
  if not base.exists():
631
  raise FileNotFoundError(f"Model folder not found: {base}")
632
 
 
633
  encoders_b = joblib.load(base / "encoders.joblib")
634
  imputer_b = joblib.load(base / "imputer.joblib")
635
  scaler_b = joblib.load(base / "scaler.joblib") if (base / "scaler.joblib").exists() else None
@@ -643,28 +661,28 @@ def _load_ensemble(target: str) -> EnsembleBundle:
643
  # XGB
644
  xgb_model = xgb.XGBRegressor()
645
  xgb_model.load_model(str(base / "xgb.json"))
646
- ensure_sklearn_tags_on_mro(xgb_model) # <-- add this
647
 
648
  # LGBM
649
  lgb_booster, lgb_model = None, None
650
  if (base / "lgb.txt").exists():
651
  lgb_booster = lgb.Booster(model_file=str(base / "lgb.txt"))
652
- # Booster is not a sklearn estimator; no patch needed.
653
  elif (base / "lgb.joblib").exists():
654
  lgb_model = joblib.load(base / "lgb.joblib")
655
- ensure_sklearn_tags_on_mro(lgb_model) # <-- add this
656
  else:
657
  raise FileNotFoundError("Neither lgb.txt nor lgb.joblib found for LGBM.")
658
 
659
  # CAT
660
  cat_model = CatBoostRegressor()
661
  cat_model.load_model(str(base / "cat.cbm"))
662
- ensure_sklearn_tags_on_mro(cat_model) # <-- add this
663
 
664
- # MLP
665
  mlp_model = tf.keras.models.load_model(base / "mlp.keras")
666
- # (tf models don’t use sklearn tags)
667
 
 
668
  meta = joblib.load(base / "meta.joblib")
669
 
670
  bundle = EnsembleBundle(
 
30
  import lightgbm as lgb
31
  from catboost import CatBoostRegressor
32
  import tensorflow as tf
33
+ # ===== Robust sklearn_tags compatibility layer =====
34
+ # Works on sklearn<1.6 + 3rd-party wrappers that call super().sklearn_tags()
35
  def _safe_sklearn_tags(self):
36
+ """Return sklearn tags without relying on super().sklearn_tags()."""
 
 
 
37
  try:
38
  if hasattr(self, "get_tags"):
39
  return self.get_tags()
 
41
  pass
42
  return {}
43
 
44
+ def _patch_class_and_mro(cls):
45
+ """Attach a safe sklearn_tags to cls and all parents in its MRO."""
46
+ if not cls or cls is object:
47
+ return
48
+ for c in getattr(cls, "mro", lambda: [])():
49
+ if c is object:
 
 
 
50
  continue
51
+ try:
52
+ # If missing or likely to fail, replace with safe version
53
+ need = not hasattr(c, "sklearn_tags") or not callable(getattr(c, "sklearn_tags"))
54
+ if not need:
 
 
 
 
 
 
 
 
55
  try:
56
+ # Dry run on a dummy instance if possible
57
+ # (Some classes require init args, so ignore errors.)
 
58
  pass
59
+ except Exception:
60
+ need = True
61
+ if need:
62
+ setattr(c, "sklearn_tags", _safe_sklearn_tags)
63
+ except Exception:
64
+ # As a last resort, patch instance later (handled in loader too)
65
+ pass
66
+
67
+ # Patch common estimator classes up-front
68
+ try:
69
+ _patch_class_and_mro(xgb.XGBRegressor)
70
+ _patch_class_and_mro(xgb.XGBClassifier)
71
+ _patch_class_and_mro(xgb.XGBRFRegressor)
72
+ _patch_class_and_mro(xgb.XGBRFClassifier)
73
+ except Exception:
74
+ pass
75
+
76
+ try:
77
+ _patch_class_and_mro(lgb.LGBMRegressor)
78
+ _patch_class_and_mro(lgb.LGBMClassifier)
79
+ except Exception:
80
+ pass
81
+
82
+ try:
83
+ _patch_class_and_mro(CatBoostRegressor)
84
+ # (Classifier not used here, but harmless to patch if you add later)
85
+ # from catboost import CatBoostClassifier
86
+ # _patch_class_and_mro(CatBoostClassifier)
87
+ except Exception:
88
+ pass
89
+ # ===== end compatibility layer =====
90
 
91
  # -----------------------------
92
  # Paths (relative in a Space)
 
642
  def _load_ensemble(target: str) -> EnsembleBundle:
643
  if target in _ENSEMBLES:
644
  return _ENSEMBLES[target]
645
+
646
  base = MODEL_DIR / target
647
  if not base.exists():
648
  raise FileNotFoundError(f"Model folder not found: {base}")
649
 
650
+ # Preprocess artifacts
651
  encoders_b = joblib.load(base / "encoders.joblib")
652
  imputer_b = joblib.load(base / "imputer.joblib")
653
  scaler_b = joblib.load(base / "scaler.joblib") if (base / "scaler.joblib").exists() else None
 
661
  # XGB
662
  xgb_model = xgb.XGBRegressor()
663
  xgb_model.load_model(str(base / "xgb.json"))
664
+ _patch_class_and_mro(xgb_model.__class__)
665
 
666
  # LGBM
667
  lgb_booster, lgb_model = None, None
668
  if (base / "lgb.txt").exists():
669
  lgb_booster = lgb.Booster(model_file=str(base / "lgb.txt"))
670
+ # Booster is not an sklearn estimator -> no patch needed
671
  elif (base / "lgb.joblib").exists():
672
  lgb_model = joblib.load(base / "lgb.joblib")
673
+ _patch_class_and_mro(lgb_model.__class__)
674
  else:
675
  raise FileNotFoundError("Neither lgb.txt nor lgb.joblib found for LGBM.")
676
 
677
  # CAT
678
  cat_model = CatBoostRegressor()
679
  cat_model.load_model(str(base / "cat.cbm"))
680
+ _patch_class_and_mro(cat_model.__class__)
681
 
682
+ # MLP (Keras)
683
  mlp_model = tf.keras.models.load_model(base / "mlp.keras")
 
684
 
685
+ # Meta
686
  meta = joblib.load(base / "meta.joblib")
687
 
688
  bundle = EnsembleBundle(