Spaces:
Sleeping
Sleeping
"""
sniper_model.py — SniperModel class definition.

The training script (sniper_v7_1.py) defined SniperModel in its __main__
module, so joblib recorded the class in the pickle as __main__.SniperModel.
When the model is loaded in the Space, __main__ is a different script and
Python cannot resolve that reference. This module works around it in two
ways: a custom Unpickler that redirects __main__.SniperModel to
src.sniper_model.SniperModel, and a monkey-patch that injects SniperModel
into __main__ so that a plain joblib.load can resolve it as well.
"""
| import io | |
| import pickle | |
| import numpy as np | |
| from scipy.special import expit | |
class SniperModel:
    """
    LightGBM wrapper that supports both focal-loss (raw booster) and
    standard LGBMClassifier paths.

    Keep this class compatible with the definition in sniper_v7_1.py:
    pickles written by that script restore the instance attributes
    ``params``, ``use_focal``, ``_booster`` and ``_clf`` by name, so renaming
    any of them would break loading of previously saved models.

    ``fit`` populates ``_booster`` (a raw LightGBM booster) when
    ``use_focal`` is true, otherwise ``_clf`` (an ``LGBMClassifier``).
    """

    def __init__(self, params: dict, use_focal: bool = False):
        """
        Parameters
        ----------
        params : LightGBM hyper-parameters. ``n_estimators`` is consumed by
            ``fit``; on the focal path ``n_jobs`` / ``random_state`` are
            translated to ``num_threads`` / ``seed``. A shallow copy is
            stored, so later changes to the caller's dict do not leak in.
        use_focal : train a raw booster with the focal-loss objective from
            ``src.focal`` instead of an ``LGBMClassifier``.
        """
        self.params = dict(params)
        self.use_focal = use_focal
        self._booster = None
        self._clf = None

    def fit(self, X, y, sample_weight=None, X_val=None, y_val=None,
            early_stopping_rounds=50):
        """
        Train the model, optionally with early stopping on a validation set.

        Parameters
        ----------
        X, y : training features and labels.
        sample_weight : optional per-row training weights; the focal path
            substitutes all-ones weights when omitted.
        X_val, y_val : optional validation set. When ``X_val`` is given,
            early stopping is enabled and ``y_val`` is required.
        early_stopping_rounds : patience passed to ``lightgbm.early_stopping``.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``X_val`` is given without ``y_val``.
        """
        # Validate before importing/training so the caller gets a clear error
        # instead of a failure deep inside LightGBM's Dataset construction.
        if X_val is not None and y_val is None:
            raise ValueError("y_val is required when X_val is given.")

        import lightgbm as lgb

        p = dict(self.params)
        # Boosting rounds are passed explicitly (num_boost_round / constructor).
        n_est = int(p.pop("n_estimators", 500))

        if self.use_focal:
            # Translate scikit-learn style names to native lgb.train names.
            p["num_threads"] = p.pop("n_jobs", -1)
            p["seed"] = p.pop("random_state", 42)
            # The objective is replaced by the focal-loss callable below and
            # evaluation comes from focal_loss_eval, so drop any configured
            # objective/metric. The imbalance knobs are dropped too
            # (presumably because focal loss handles imbalance itself).
            for key in ("objective", "metric", "is_unbalance",
                        "scale_pos_weight", "verbosity"):
                p.pop(key, None)
            p["verbose"] = -1

            y_arr = y.values if hasattr(y, "values") else np.asarray(y)
            w_arr = sample_weight if sample_weight is not None else np.ones(len(y_arr))
            ds = lgb.Dataset(X, label=y_arr, weight=w_arr)

            valid_sets = []
            # period=0 keeps per-iteration evaluation logging silent.
            callbacks = [lgb.log_evaluation(period=0)]
            if X_val is not None:
                y_val_arr = y_val.values if hasattr(y_val, "values") else np.asarray(y_val)
                valid_sets = [lgb.Dataset(X_val, label=y_val_arr)]
                callbacks.append(lgb.early_stopping(early_stopping_rounds, verbose=False))

            from src.focal import focal_loss_objective, focal_loss_eval
            p["objective"] = focal_loss_objective
            self._booster = lgb.train(
                p, ds, num_boost_round=n_est,
                valid_sets=valid_sets or None,
                feval=focal_loss_eval if valid_sets else None,
                callbacks=callbacks,
            )
        else:
            p.setdefault("verbosity", -1)
            self._clf = lgb.LGBMClassifier(n_estimators=n_est, **p)
            fit_kw = {"sample_weight": sample_weight}
            if X_val is not None:
                fit_kw["eval_set"] = [(X_val, y_val)]
                fit_kw["callbacks"] = [
                    lgb.early_stopping(early_stopping_rounds, verbose=False),
                    lgb.log_evaluation(-1),
                ]
            self._clf.fit(X, y, **fit_kw)
        return self

    def predict_proba(self, X) -> np.ndarray:
        """
        Return class probabilities as two columns ``[1 - p, p]``.

        On the focal path the booster's raw scores (``raw_score=True``) are
        mapped to probabilities with the logistic sigmoid ``expit``; the
        classifier path delegates to ``LGBMClassifier.predict_proba``.

        Raises
        ------
        RuntimeError
            If the model has not been fitted.
        """
        if self._booster is not None:
            raw = self._booster.predict(X, raw_score=True)
            p = expit(raw)
            return np.column_stack([1.0 - p, p])
        if self._clf is not None:
            return self._clf.predict_proba(X)
        raise RuntimeError("SniperModel has not been fitted.")

    def feature_importances_(self, importance_type=None) -> np.ndarray:
        """
        Return per-feature importances of the fitted model.

        Parameters
        ----------
        importance_type : None, "gain" or "split".
            ``None`` (default) keeps the historical behaviour: the focal path
            reports "gain" importance, while the classifier path returns
            ``LGBMClassifier.feature_importances_``, whose type follows the
            classifier's ``importance_type`` parameter (LightGBM's default is
            "split"). The two defaults are therefore not directly comparable.
            Pass an explicit type to get the same kind from either path.

        Raises
        ------
        RuntimeError
            If the model has not been fitted.
        """
        if self._booster is not None:
            kind = "gain" if importance_type is None else importance_type
            return self._booster.feature_importance(importance_type=kind)
        if self._clf is not None:
            if importance_type is None:
                return self._clf.feature_importances_
            return self._clf.booster_.feature_importance(importance_type=importance_type)
        raise RuntimeError("SniperModel has not been fitted.")
| # --------------------------------------------------------------------------- | |
| # Custom unpickler — redirects __main__.SniperModel to this module | |
| # --------------------------------------------------------------------------- | |
class _SniperUnpickler(pickle.Unpickler):
    """
    Unpickler that resolves selected ``__main__`` globals to this module.

    Only the ``(module, name)`` pairs listed in ``_REMAP`` are redirected;
    currently that is just ``__main__.SniperModel`` -> the ``SniperModel``
    class defined in this module. Every other global is resolved by the
    standard ``pickle.Unpickler.find_class``.
    """

    # (module recorded in the pickle, class name) -> object to return instead.
    _REMAP = {("__main__", "SniperModel"): SniperModel}

    def find_class(self, module, name):
        lookup = (module, name)
        if lookup not in self._REMAP:
            # Not one of ours: defer to the stock resolution logic.
            return super().find_class(module, name)
        return self._REMAP[lookup]
def safe_load(path) -> object:
    """
    Load a joblib/pickle file, transparently remapping ``__main__.SniperModel``
    to the ``SniperModel`` class defined in this module.

    Strategies, tried in order:

    1. ``joblib.load`` as-is (works when ``__main__.SniperModel`` already
       resolves, e.g. after ``patch_main()`` has run).
    2. ``joblib.load`` with ``SniperModel`` temporarily injected into
       ``sys.modules['__main__']``; ``__main__`` is restored afterwards.
       Skipped when there is no ``__main__`` module, since it would just
       repeat strategy 1.
    3. A raw pickle load of the file bytes through ``_SniperUnpickler``.
       This is a last resort. A plain Unpickler does not understand
       joblib-specific storage such as compression or joblib's embedded numpy
       array format, so expect it to succeed only for plain pickle files.
       NOTE(review): confirm against the joblib version in use.

    Parameters
    ----------
    path : file path accepted by both ``joblib.load`` and ``open``.

    Returns
    -------
    The unpickled object.

    Raises
    ------
    RuntimeError
        If every strategy fails. The message carries the strategy-3 error and,
        when available, the strategy-2 error (usually the more informative one
        for joblib files). The strategy-3 exception is chained as ``__cause__``.
    Exception
        Anything other than AttributeError/ModuleNotFoundError raised by the
        first ``joblib.load`` (e.g. a missing file) propagates unchanged.
    """
    import sys

    import joblib

    # --- Strategy 1: plain joblib load ---------------------------------------
    try:
        return joblib.load(path)
    except (AttributeError, ModuleNotFoundError):
        # Expected when the pickle references __main__.SniperModel and the
        # running __main__ does not define it.
        pass

    # --- Strategy 2: temporarily inject SniperModel into __main__ ------------
    patched_error = None
    main_mod = sys.modules.get("__main__")
    if main_mod is not None:
        had_attr = hasattr(main_mod, "SniperModel")
        old_val = getattr(main_mod, "SniperModel", None)
        main_mod.SniperModel = SniperModel
        try:
            return joblib.load(path)
        except Exception as exc:
            # Best effort: remember the failure so the final error can report
            # it instead of silently discarding it.
            patched_error = exc
        finally:
            # Restore __main__ exactly as we found it.
            if had_attr:
                main_mod.SniperModel = old_val
            else:
                try:
                    delattr(main_mod, "SniperModel")
                except AttributeError:
                    pass

    # --- Strategy 3: raw pickle through the remapping Unpickler --------------
    with open(path, "rb") as fh:
        data = fh.read()
    try:
        return _SniperUnpickler(io.BytesIO(data)).load()
    except Exception as e:
        details = f"Original error: {e}"
        if patched_error is not None:
            details += (
                f" (joblib load with __main__ patch failed with: {patched_error!r})"
            )
        raise RuntimeError(
            f"Could not load {path} with any strategy. "
            f"Ensure the model was saved with joblib and SniperModel is compatible. "
            f"{details}"
        ) from e
def patch_main():
    """
    Make ``__main__.SniperModel`` resolve to this module's ``SniperModel``.

    Call once at app startup, before any ``joblib.load``, so pickles that
    reference ``__main__.SniperModel`` can be loaded directly. The patch is
    permanent. It is a no-op when ``sys.modules`` has no ``__main__`` entry,
    and an existing ``__main__.SniperModel`` attribute is never overwritten.
    """
    import sys

    main = sys.modules.get("__main__")
    if main is None or hasattr(main, "SniperModel"):
        return
    main.SniperModel = SniperModel