# fix: load serving model with xgboost booster (commit 1b2b7c9, ymlin105)
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from contextlib import asynccontextmanager
import pandas as pd
import numpy as np
import xgboost as xgb
import logging
import os
from pathlib import Path
import json
from time import perf_counter
from src.shared.config import DEFAULT_MODEL_METADATA_PATH, DEFAULT_MODEL_PATH, settings
from src.shared.schemas import PredictionRequest, PredictionResponse, ExplanationItem
from src.serving.monitoring import (
DEFAULT_INFERENCE_LOG_PATH,
append_jsonl_record,
build_inference_log_entry,
)
from src.training.data_loader import load_store_data
from src.training.features import (
apply_feature_pipeline,
build_feature_matrix
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Runtime state populated by load_runtime_assets() during app startup (lifespan).
model: xgb.Booster | None = None  # trained booster, or None until a model file is loaded
store_lookup: dict[int, dict[str, object]] = {}  # Store id -> imputed store metadata row
model_version: str = "unknown"  # read from the model metadata JSON when available
def load_store_lookup(
    missing_competition_distance: float = 100000,
) -> dict[int, dict[str, object]]:
    """Loads store-level metadata used to build prediction rows.

    Args:
        missing_competition_distance: Sentinel distance imputed for stores
            with no recorded competitor. Defaults to the value the original
            pipeline used (100000).

    Returns:
        Mapping of Store id to a column/value dict with all NaNs imputed,
        so downstream row building never sees missing values.
    """
    store_df = load_store_data(settings.data.store_path).copy()
    # "No competitor recorded" is treated as a very distant competitor.
    store_df["CompetitionDistance"] = store_df["CompetitionDistance"].fillna(
        missing_competition_distance
    )
    # Promo2 fields are absent for stores that never joined the promo scheme.
    for col in ["Promo2", "Promo2SinceWeek", "Promo2SinceYear"]:
        store_df[col] = store_df[col].fillna(0).astype(int)
    # Categorical columns use the string "0" as the missing-category token.
    for col in ["StoreType", "Assortment"]:
        store_df[col] = store_df[col].fillna("0").astype(str)
    return store_df.set_index("Store").to_dict(orient="index")
def load_runtime_assets() -> None:
    """Populate module-level model, model_version, and store_lookup from disk.

    Paths may be overridden via the MODEL_PATH / MODEL_METADATA_PATH
    environment variables. Missing files degrade gracefully with warnings
    rather than raising, so the app can still serve /health.
    """
    global model, model_version, store_lookup
    store_lookup = load_store_lookup()
    model_path = Path(os.environ.get("MODEL_PATH", str(DEFAULT_MODEL_PATH)))
    if model_path.exists():
        model = xgb.Booster()
        model.load_model(str(model_path))
        logger.info("Model loaded from %s", model_path)
    else:
        # Reset explicitly so a stale model from a previous load is never served.
        model = None
        logger.warning("Model not found at %s. Predict endpoint will fail.", model_path)
    metadata_path = Path(os.environ.get("MODEL_METADATA_PATH", str(DEFAULT_MODEL_METADATA_PATH)))
    if metadata_path.exists():
        with metadata_path.open("r", encoding="utf-8") as f:
            model_version = json.load(f).get("model_version", "unknown")
    else:
        logger.warning("Model metadata not found at %s. Using unknown version.", metadata_path)
        model_version = "unknown"
@asynccontextmanager
async def lifespan(_: FastAPI):
    """FastAPI lifespan hook: load model, metadata, and store lookup once at
    startup. No shutdown work is needed, so nothing follows the yield."""
    load_runtime_assets()
    yield
def predict_with_model(loaded_model: object, X: pd.DataFrame) -> np.ndarray:
    """Run inference, wrapping the features in a DMatrix for raw Boosters.

    Sklearn-style wrappers accept the DataFrame directly; a bare
    xgb.Booster requires an explicit DMatrix conversion first.
    """
    if not isinstance(loaded_model, xgb.Booster):
        return loaded_model.predict(X)
    return loaded_model.predict(xgb.DMatrix(X))
def predict_contributions(loaded_model: object, X: pd.DataFrame) -> np.ndarray:
    """Return per-feature contribution scores (pred_contribs) for each row.

    Normalizes the input to an xgb.Booster first: raw Boosters are used
    directly, sklearn-style wrappers are unwrapped via get_booster().

    Raises:
        TypeError: if the model exposes neither interface.
    """
    booster = None
    if isinstance(loaded_model, xgb.Booster):
        booster = loaded_model
    elif hasattr(loaded_model, "get_booster"):
        booster = loaded_model.get_booster()
    if booster is None:
        raise TypeError("Loaded model does not support feature contribution prediction")
    return booster.predict(xgb.DMatrix(X), pred_contribs=True)
# FastAPI application; the lifespan hook loads the model and store metadata
# before the first request is served.
app = FastAPI(
    title=settings.model.name,
    description=settings.model.description,
    version="1.0.0",
    lifespan=lifespan,
)
@app.get("/health")
def health():
    """Liveness probe: reports whether a model is currently loaded and its version."""
    payload = {
        "status": "healthy",
        "model_loaded": model is not None,
        "model_version": model_version,
    }
    return payload
def _build_prediction_frame(
    request: PredictionRequest,
    store_meta: dict[str, object],
    dates: list[pd.Timestamp],
) -> pd.DataFrame:
    """Assemble one raw input row per forecast date from the request and store metadata."""
    rows = [
        {
            "Store": request.Store,
            "Date": d,
            "Promo": request.Promo,
            "StateHoliday": request.StateHoliday,
            "SchoolHoliday": request.SchoolHoliday,
            "Assortment": store_meta["Assortment"],
            "StoreType": store_meta["StoreType"],
            "CompetitionDistance": store_meta["CompetitionDistance"],
            "Promo2": store_meta["Promo2"],
            "Promo2SinceWeek": store_meta["Promo2SinceWeek"],
            "Promo2SinceYear": store_meta["Promo2SinceYear"],
            # Forecast assumes the store is open on every requested day.
            "Open": 1,
        }
        for d in dates
    ]
    return pd.DataFrame(rows)


def _top_explanations(
    feature_cols: list[str], contribs: np.ndarray, top_n: int = 5
) -> list[ExplanationItem]:
    """Average contributions over the horizon and keep the top_n by magnitude.

    The last contribution column is the bias term and is dropped before
    pairing with feature_cols.
    """
    avg_contribs = contribs[:, :-1].mean(axis=0)
    ranked = sorted(zip(feature_cols, avg_contribs), key=lambda x: abs(x[1]), reverse=True)
    return [
        ExplanationItem(feature=name, score=float(score), formatted_val=f"{score:+.3f}")
        for name, score in ranked[:top_n]
    ]


@app.post("/predict", response_model=PredictionResponse)
def predict(request: PredictionRequest):
    """Forecast daily sales for one store over request.ForecastDays days.

    Builds one feature row per day, runs the loaded model, inverts the
    log1p target transform, and returns the forecast with the top-5
    average feature contributions. Each request is appended to the
    inference JSONL log.

    Raises:
        HTTPException: 503 if no model is loaded, 404 for an unknown
            store, 500 for any runtime failure during prediction.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if request.Store not in store_lookup:
        raise HTTPException(status_code=404, detail=f"Store {request.Store} not found in metadata")
    try:
        started_at = perf_counter()
        store_meta = store_lookup[request.Store]
        start_date = pd.to_datetime(request.Date)
        dates = [start_date + pd.Timedelta(days=i) for i in range(request.ForecastDays)]
        df = apply_feature_pipeline(
            _build_prediction_frame(request, store_meta, dates),
            fourier_period=settings.pipeline.fourier_period,
            fourier_order=settings.pipeline.fourier_order,
        )
        feature_cols = settings.data.features
        X = build_feature_matrix(df, feature_cols)
        # The model was trained on log1p(sales); expm1 inverts the transform.
        y_sales = np.expm1(predict_with_model(model, X))
        explanation_items = _top_explanations(feature_cols, predict_contributions(model, X))
        forecast_result = [
            {"date": d.strftime("%Y-%m-%d"), "sales": float(round(s, 2))}
            for d, s in zip(dates, y_sales)
        ]
        latency_ms = (perf_counter() - started_at) * 1000
        append_jsonl_record(
            DEFAULT_INFERENCE_LOG_PATH,
            build_inference_log_entry(
                store=request.Store,
                start_date=request.Date,
                forecast_days=request.ForecastDays,
                promo=request.Promo,
                state_holiday=request.StateHoliday,
                school_holiday=request.SchoolHoliday,
                model_version=model_version,
                latency_ms=latency_ms,
            ),
        )
        logger.info(
            "prediction_complete store=%s horizon=%s model_version=%s latency_ms=%.3f",
            request.Store,
            request.ForecastDays,
            model_version,
            latency_ms,
        )
        return PredictionResponse(
            Store=request.Store,
            Date=request.Date,
            PredictedSales=float(y_sales[0]),
            ModelVersion=model_version,
            Explanation=explanation_items,
            Forecast=forecast_result,
            Status="success",
        )
    except Exception as e:
        # logger.exception preserves the traceback; "from e" keeps the cause chain.
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/", response_class=HTMLResponse)
def index():
    """Serve the HTML frontend page (imported at request time, not at startup)."""
    from web.frontend import get_frontend_html

    page = get_frontend_html()
    return page
if __name__ == "__main__":
    # Standalone dev entry point: serve on all interfaces, port 7860.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)