ashaddams commited on
Commit
5117c71
·
verified ·
1 Parent(s): 7f7e9b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -16
app.py CHANGED
@@ -30,12 +30,11 @@ import xgboost as xgb
30
  import lightgbm as lgb
31
  from catboost import CatBoostRegressor
32
  import tensorflow as tf
33
- # ---- Robust sklearn tag shim for 3rd-party estimators ----
34
  def _safe_sklearn_tags(self):
35
  """
36
  Return sklearn tags without relying on super().sklearn_tags().
37
- Uses get_tags() when available (works on sklearn < 1.6) and
38
- falls back to an empty dict otherwise.
39
  """
40
  try:
41
  if hasattr(self, "get_tags"):
@@ -44,19 +43,33 @@ def _safe_sklearn_tags(self):
44
  pass
45
  return {}
46
 
47
- # Always override on these classes to avoid super() lookups that may fail
48
- try:
49
- xgb.XGBRegressor.sklearn_tags = _safe_sklearn_tags
50
- except Exception:
51
- pass
52
- try:
53
- lgb.LGBMRegressor.sklearn_tags = _safe_sklearn_tags
54
- except Exception:
55
- pass
56
- try:
57
- CatBoostRegressor.sklearn_tags = _safe_sklearn_tags
58
- except Exception:
59
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
 
62
  # -----------------------------
@@ -627,21 +640,31 @@ def _load_ensemble(target: str) -> EnsembleBundle:
627
  cyc_cols = cfg["num_cols_cycle_first"]
628
  num_plain = cfg["num_cols_plain"]
629
 
 
630
  xgb_model = xgb.XGBRegressor()
631
  xgb_model.load_model(str(base / "xgb.json"))
 
632
 
 
633
  lgb_booster, lgb_model = None, None
634
  if (base / "lgb.txt").exists():
635
  lgb_booster = lgb.Booster(model_file=str(base / "lgb.txt"))
 
636
  elif (base / "lgb.joblib").exists():
637
  lgb_model = joblib.load(base / "lgb.joblib")
 
638
  else:
639
  raise FileNotFoundError("Neither lgb.txt nor lgb.joblib found for LGBM.")
640
 
 
641
  cat_model = CatBoostRegressor()
642
  cat_model.load_model(str(base / "cat.cbm"))
 
643
 
 
644
  mlp_model = tf.keras.models.load_model(base / "mlp.keras")
 
 
645
  meta = joblib.load(base / "meta.joblib")
646
 
647
  bundle = EnsembleBundle(
 
30
  import lightgbm as lgb
31
  from catboost import CatBoostRegressor
32
  import tensorflow as tf
33
+ # ---------- universal sklearn_tags patcher ----------
34
  def _safe_sklearn_tags(self):
35
  """
36
  Return sklearn tags without relying on super().sklearn_tags().
37
+ Uses get_tags() when available and falls back to {} otherwise.
 
38
  """
39
  try:
40
  if hasattr(self, "get_tags"):
 
43
  pass
44
  return {}
45
 
46
+ def ensure_sklearn_tags_on_mro(est):
47
+ """Attach a safe sklearn_tags() to every class in the estimator's MRO
48
+ that lacks it (or whose implementation fails)."""
49
+ try:
50
+ mro = getattr(est.__class__, "mro", lambda: [])()
51
+ except Exception:
52
+ mro = []
53
+ for cls in mro:
54
+ if cls is object:
55
+ continue
56
+ needs_patch = not hasattr(cls, "sklearn_tags") or not callable(getattr(cls, "sklearn_tags"))
57
+ if not needs_patch:
58
+ try:
59
+ # dry-run; if it errors, we’ll patch
60
+ getattr(est, "sklearn_tags")()
61
+ continue
62
+ except Exception:
63
+ needs_patch = True
64
+ if needs_patch:
65
+ try:
66
+ setattr(cls, "sklearn_tags", _safe_sklearn_tags)
67
+ except Exception:
68
+ try:
69
+ # fallback: instance-level bind
70
+ setattr(est, "sklearn_tags", _safe_sklearn_tags.__get__(est, est.__class__))
71
+ except Exception:
72
+ pass
73
 
74
 
75
  # -----------------------------
 
640
  cyc_cols = cfg["num_cols_cycle_first"]
641
  num_plain = cfg["num_cols_plain"]
642
 
643
+ # XGB
644
  xgb_model = xgb.XGBRegressor()
645
  xgb_model.load_model(str(base / "xgb.json"))
646
+ ensure_sklearn_tags_on_mro(xgb_model) # <-- add this
647
 
648
+ # LGBM
649
  lgb_booster, lgb_model = None, None
650
  if (base / "lgb.txt").exists():
651
  lgb_booster = lgb.Booster(model_file=str(base / "lgb.txt"))
652
+ # Booster is not a sklearn estimator; no patch needed.
653
  elif (base / "lgb.joblib").exists():
654
  lgb_model = joblib.load(base / "lgb.joblib")
655
+ ensure_sklearn_tags_on_mro(lgb_model) # <-- add this
656
  else:
657
  raise FileNotFoundError("Neither lgb.txt nor lgb.joblib found for LGBM.")
658
 
659
+ # CAT
660
  cat_model = CatBoostRegressor()
661
  cat_model.load_model(str(base / "cat.cbm"))
662
+ ensure_sklearn_tags_on_mro(cat_model) # <-- add this
663
 
664
+ # MLP
665
  mlp_model = tf.keras.models.load_model(base / "mlp.keras")
666
+ # (tf models don’t use sklearn tags)
667
+
668
  meta = joblib.load(base / "meta.joblib")
669
 
670
  bundle = EnsembleBundle(