# model-tester / src/sniper_model.py
# Author: Arkm20 — commit 77c23e3 ("Create sniper_model.py", verified)
"""
sniper_model.py — SniperModel class definition.
The training script (sniper_v7_1.py) defined SniperModel in __main__, so
joblib pickled it with the reference __main__.SniperModel. When loading in
the Space, Python can't resolve that. We fix this with a custom Unpickler
that redirects __main__.SniperModel → src.sniper_model.SniperModel, plus
a monkey-patch of __main__ so any path works.
"""
import io
import pickle
import numpy as np
from scipy.special import expit
class SniperModel:
    """
    LightGBM wrapper that supports both focal-loss (raw booster) and
    standard LGBMClassifier paths.

    Instances are unpickled from models trained by sniper_v7_1.py, so the
    instance attribute names (``params``, ``use_focal``, ``_booster``,
    ``_clf``) must stay exactly as they are: pickle restores ``__dict__``
    as-is and the methods below read those names.
    """

    def __init__(self, params: dict, use_focal: bool = False):
        """
        Args:
            params: LightGBM hyper-parameters. Copied, so later mutation of
                the caller's dict does not affect this model.
            use_focal: If True, ``fit`` trains a raw ``lgb.Booster`` with the
                focal-loss objective from ``src.focal``; otherwise it trains
                an ``lgb.LGBMClassifier``.
        """
        self.params = dict(params)
        self.use_focal = use_focal
        self._booster = None  # trained lgb.Booster (focal path), else None
        self._clf = None  # trained lgb.LGBMClassifier (standard path), else None

    def fit(self, X, y, sample_weight=None, X_val=None, y_val=None,
            early_stopping_rounds=50):
        """
        Train the underlying LightGBM model.

        Args:
            X: Training features.
            y: Training labels; ``y.values`` is used when present (e.g. a
                pandas Series), otherwise ``np.asarray(y)``.
            sample_weight: Optional per-row weights (all ones on the focal
                path when omitted).
            X_val, y_val: Optional validation data. When ``X_val`` is given,
                early stopping is enabled.
            early_stopping_rounds: Patience for early stopping (only used
                when ``X_val`` is given).

        Returns:
            self. The model attributes are only replaced after training
            succeeds. Whichever backend is not used is cleared, so a model
            refit with a different ``use_focal`` never predicts with a stale
            backend.
        """
        import lightgbm as lgb

        p = dict(self.params)
        n_est = int(p.pop("n_estimators", 500))

        if self.use_focal:
            # Raw-booster path: map sklearn-style names to native LightGBM
            # ones, then drop objective/metric/class-balance/verbosity keys;
            # the objective is replaced by the focal loss below.
            p["num_threads"] = p.pop("n_jobs", -1)
            p["seed"] = p.pop("random_state", 42)
            for key in ("objective", "metric", "is_unbalance",
                        "scale_pos_weight", "verbosity"):
                p.pop(key, None)
            p["verbose"] = -1

            from src.focal import focal_loss_objective, focal_loss_eval
            p["objective"] = focal_loss_objective

            y_arr = y.values if hasattr(y, "values") else np.asarray(y)
            w_arr = (sample_weight if sample_weight is not None
                     else np.ones(len(y_arr)))
            train_set = lgb.Dataset(X, label=y_arr, weight=w_arr)

            valid_sets = None
            callbacks = [lgb.log_evaluation(period=0)]
            if X_val is not None:
                y_val_arr = (y_val.values if hasattr(y_val, "values")
                             else np.asarray(y_val))
                valid_sets = [lgb.Dataset(X_val, label=y_val_arr)]
                callbacks.append(
                    lgb.early_stopping(early_stopping_rounds, verbose=False))

            booster = lgb.train(
                p, train_set, num_boost_round=n_est,
                valid_sets=valid_sets,
                feval=focal_loss_eval if valid_sets else None,
                callbacks=callbacks,
            )
            # Assign only after a successful train; clear the other backend.
            self._booster, self._clf = booster, None
        else:
            p.setdefault("verbosity", -1)
            clf = lgb.LGBMClassifier(n_estimators=n_est, **p)
            fit_kw = {"sample_weight": sample_weight}
            if X_val is not None:
                fit_kw["eval_set"] = [(X_val, y_val)]
                fit_kw["callbacks"] = [
                    lgb.early_stopping(early_stopping_rounds, verbose=False),
                    lgb.log_evaluation(-1),
                ]
            clf.fit(X, y, **fit_kw)
            # Assign only after a successful fit; clear the other backend.
            self._clf, self._booster = clf, None
        return self

    def predict_proba(self, X) -> np.ndarray:
        """
        Return class probabilities as an array with columns [P(0), P(1)].

        Booster path: raw scores are passed through the logistic sigmoid
        (``expit``); assumes a binary model whose raw score is 1-D.
        Classifier path: delegates to ``LGBMClassifier.predict_proba``.

        Raises:
            RuntimeError: If the model has not been fitted.
        """
        if self._booster is not None:
            p = expit(self._booster.predict(X, raw_score=True))
            return np.column_stack([1.0 - p, p])
        if self._clf is not None:
            return self._clf.predict_proba(X)
        raise RuntimeError("SniperModel has not been fitted.")

    @property
    def feature_importances_(self) -> np.ndarray:
        """
        Per-feature importances.

        Booster path: ``importance_type="gain"``. Classifier path: the
        classifier's own ``feature_importances_``, which uses its configured
        ``importance_type`` (LightGBM's default is "split"). The two paths
        are therefore not on the same scale.

        Raises:
            RuntimeError: If the model has not been fitted.
        """
        if self._booster is not None:
            return self._booster.feature_importance(importance_type="gain")
        if self._clf is not None:
            return self._clf.feature_importances_
        raise RuntimeError("SniperModel has not been fitted.")
# ---------------------------------------------------------------------------
# Custom unpickler — redirects __main__.SniperModel to this module
# ---------------------------------------------------------------------------
class _SniperUnpickler(pickle.Unpickler):
    """
    Unpickler that resolves ``__main__.SniperModel`` to this module's
    SniperModel class.

    Only the (module, name) pairs listed in ``_REMAP`` are redirected; every
    other global is resolved normally by ``pickle.Unpickler.find_class``.
    Add entries to ``_REMAP`` if other classes from the training script's
    ``__main__`` need redirecting.

    Security: like any unpickler, this can execute arbitrary code embedded
    in the input. Only use it on trusted model files.
    """

    # (module, name) as recorded in the pickle -> class to use instead.
    _REMAP = {
        ("__main__", "SniperModel"): SniperModel,
    }

    def find_class(self, module, name):
        """Return the remapped class for (module, name), else defer to pickle."""
        remapped = self._REMAP.get((module, name))
        if remapped is not None:
            return remapped
        return super().find_class(module, name)
def safe_load(path) -> object:
    """
    Load a joblib/pickle file, transparently remapping __main__.SniperModel
    to the local SniperModel class.

    Strategies, in order:
      1. Plain ``joblib.load`` (works when ``__main__.SniperModel`` is
         already resolvable, e.g. after ``patch_main()``).
      2. ``joblib.load`` with SniperModel temporarily injected into
         ``__main__``. ``__main__`` is restored afterwards, whether or not
         the load succeeds.
      3. A raw-pickle load through ``_SniperUnpickler``. This can only work
         when the file is a single plain pickle stream. NOTE(review): joblib
         files that are compressed or that carry numpy array payloads
         alongside the pickle are presumably not loadable this way; confirm
         against the joblib version in use.

    Args:
        path: Path of the model file (anything ``open()`` and
            ``joblib.load`` accept).

    Returns:
        The unpickled object.

    Raises:
        RuntimeError: If every strategy fails. The message includes the
            raw-pickle error and the error from strategy 2.
        Exception: Errors from strategy 1 other than AttributeError /
            ModuleNotFoundError (e.g. a missing file) propagate unchanged.

    Security: every strategy unpickles, which can execute arbitrary code.
    Only load trusted model files.
    """
    import sys

    import joblib

    # --- 1. joblib directly (works if the class is already importable) ---
    try:
        return joblib.load(path)
    except (AttributeError, ModuleNotFoundError):
        # Unresolvable class reference; retry with __main__ patched.
        pass

    # --- 2. joblib with SniperModel temporarily injected into __main__ ---
    main_mod = sys.modules.get("__main__")
    had_attr = main_mod is not None and hasattr(main_mod, "SniperModel")
    old_val = (getattr(main_mod, "SniperModel", None)
               if main_mod is not None else None)
    joblib_error = None
    try:
        if main_mod is not None:
            main_mod.SniperModel = SniperModel
        return joblib.load(path)
    except Exception as exc:  # best-effort: remember it, try strategy 3
        joblib_error = exc
    finally:
        # Restore __main__ to its original state.
        if main_mod is not None:
            if had_attr:
                main_mod.SniperModel = old_val
            else:
                try:
                    delattr(main_mod, "SniperModel")
                except AttributeError:
                    pass

    # --- 3. Last resort: raw pickle stream through the remapping unpickler ---
    with open(path, "rb") as fh:
        try:
            return _SniperUnpickler(fh).load()
        except Exception as exc:
            raise RuntimeError(
                f"Could not load {path} with any strategy. "
                f"Ensure the model was saved with joblib and SniperModel is compatible. "
                f"Original error: {exc} (joblib error: {joblib_error!r})"
            ) from exc
def patch_main():
    """
    Expose this module's SniperModel as ``__main__.SniperModel``.

    Call once at app startup, before any joblib.load() call, so pickles that
    reference ``__main__.SniperModel`` resolve directly. Does nothing if
    there is no ``__main__`` module or if it already defines SniperModel
    (an existing attribute is never overwritten).
    """
    import sys

    main_module = sys.modules.get("__main__")
    if main_module is None or hasattr(main_module, "SniperModel"):
        return
    main_module.SniperModel = SniperModel