ymlin105 commited on
Commit
1b2b7c9
·
1 Parent(s): fc6fdcf

fix: load serving model with xgboost booster

Browse files
Files changed (1) hide show
  1. src/serving/api.py +18 -4
src/serving/api.py CHANGED
@@ -26,7 +26,7 @@ from src.training.features import (
26
  logging.basicConfig(level=logging.INFO)
27
  logger = logging.getLogger(__name__)
28
 
29
- model = None
30
  store_lookup: dict[int, dict[str, object]] = {}
31
  model_version = "unknown"
32
 
@@ -48,7 +48,7 @@ def load_runtime_assets() -> None:
48
  store_lookup = load_store_lookup()
49
  model_path = Path(os.environ.get("MODEL_PATH", str(DEFAULT_MODEL_PATH)))
50
  if model_path.exists():
51
- model = xgb.XGBRegressor()
52
  model.load_model(str(model_path))
53
  logger.info(f"Model loaded from {model_path}")
54
  else:
@@ -69,6 +69,20 @@ async def lifespan(_: FastAPI):
69
  yield
70
 
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  app = FastAPI(
73
  title=settings.model.name,
74
  description=settings.model.description,
@@ -120,10 +134,10 @@ def predict(request: PredictionRequest):
120
  feature_cols = settings.data.features
121
  X = build_feature_matrix(df, feature_cols)
122
 
123
- y_log = model.predict(X)
124
  y_sales = np.expm1(y_log)
125
 
126
- contribs = model.get_booster().predict(xgb.DMatrix(X), pred_contribs=True)
127
  avg_contribs = contribs[:, :-1].mean(axis=0)
128
 
129
  explanation_items = []
 
26
  logging.basicConfig(level=logging.INFO)
27
  logger = logging.getLogger(__name__)
28
 
29
+ model: xgb.Booster | None = None
30
  store_lookup: dict[int, dict[str, object]] = {}
31
  model_version = "unknown"
32
 
 
48
  store_lookup = load_store_lookup()
49
  model_path = Path(os.environ.get("MODEL_PATH", str(DEFAULT_MODEL_PATH)))
50
  if model_path.exists():
51
+ model = xgb.Booster()
52
  model.load_model(str(model_path))
53
  logger.info(f"Model loaded from {model_path}")
54
  else:
 
69
  yield
70
 
71
 
72
+ def predict_with_model(loaded_model: object, X: pd.DataFrame) -> np.ndarray:
73
+ if isinstance(loaded_model, xgb.Booster):
74
+ return loaded_model.predict(xgb.DMatrix(X))
75
+ return loaded_model.predict(X)
76
+
77
+
78
+ def predict_contributions(loaded_model: object, X: pd.DataFrame) -> np.ndarray:
79
+ if isinstance(loaded_model, xgb.Booster):
80
+ return loaded_model.predict(xgb.DMatrix(X), pred_contribs=True)
81
+ if hasattr(loaded_model, "get_booster"):
82
+ return loaded_model.get_booster().predict(xgb.DMatrix(X), pred_contribs=True)
83
+ raise TypeError("Loaded model does not support feature contribution prediction")
84
+
85
+
86
  app = FastAPI(
87
  title=settings.model.name,
88
  description=settings.model.description,
 
134
  feature_cols = settings.data.features
135
  X = build_feature_matrix(df, feature_cols)
136
 
137
+ y_log = predict_with_model(model, X)
138
  y_sales = np.expm1(y_log)
139
 
140
+ contribs = predict_contributions(model, X)
141
  avg_contribs = contribs[:, :-1].mean(axis=0)
142
 
143
  explanation_items = []