from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from contextlib import asynccontextmanager
import pandas as pd
import numpy as np
import xgboost as xgb
import logging
import os
from pathlib import Path
import json
from time import perf_counter
from src.shared.config import DEFAULT_MODEL_METADATA_PATH, DEFAULT_MODEL_PATH, settings
from src.shared.schemas import PredictionRequest, PredictionResponse, ExplanationItem
from src.serving.monitoring import (
DEFAULT_INFERENCE_LOG_PATH,
append_jsonl_record,
build_inference_log_entry,
)
from src.training.data_loader import load_store_data
from src.training.features import (
apply_feature_pipeline,
build_feature_matrix
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Runtime state populated by load_runtime_assets() at application startup.
# model stays None (and /predict returns 503) if the artifact is missing.
model: xgb.Booster | None = None
store_lookup: dict[int, dict[str, object]] = {}
model_version = "unknown"
def load_store_lookup() -> dict[int, dict[str, object]]:
    """Load and clean store-level metadata, keyed by store id.

    Returns a mapping of Store -> row dict used to enrich prediction rows.
    """
    stores = load_store_data(settings.data.store_path).copy()
    # A missing competition distance is treated as "very far away".
    stores["CompetitionDistance"] = stores["CompetitionDistance"].fillna(100000)
    int_columns = ["Promo2", "Promo2SinceWeek", "Promo2SinceYear"]
    text_columns = ["StoreType", "Assortment"]
    for column in int_columns:
        stores[column] = stores[column].fillna(0).astype(int)
    for column in text_columns:
        stores[column] = stores[column].fillna("0").astype(str)
    return stores.set_index("Store").to_dict(orient="index")
def load_runtime_assets() -> None:
    """Populate the module-level model, model_version and store_lookup.

    Paths may be overridden via the MODEL_PATH / MODEL_METADATA_PATH
    environment variables, falling back to the configured defaults.
    Missing files are logged but non-fatal so the service can still start
    and report its state via /health.
    """
    global model, model_version, store_lookup
    store_lookup = load_store_lookup()
    model_path = Path(os.environ.get("MODEL_PATH", str(DEFAULT_MODEL_PATH)))
    if model_path.exists():
        model = xgb.Booster()
        model.load_model(str(model_path))
        # Lazy %-style args (not f-strings) for consistency with the rest
        # of this module's logging and to defer formatting.
        logger.info("Model loaded from %s", model_path)
    else:
        logger.warning("Model not found at %s. Predict endpoint will fail.", model_path)
    metadata_path = Path(os.environ.get("MODEL_METADATA_PATH", str(DEFAULT_MODEL_METADATA_PATH)))
    if metadata_path.exists():
        with metadata_path.open("r", encoding="utf-8") as f:
            model_version = json.load(f).get("model_version", "unknown")
    else:
        logger.warning("Model metadata not found at %s. Using unknown version.", metadata_path)
        model_version = "unknown"
@asynccontextmanager
async def lifespan(_: FastAPI):
    """FastAPI lifespan hook: load model artifacts once before serving."""
    load_runtime_assets()
    yield
def predict_with_model(loaded_model: object, X: pd.DataFrame) -> np.ndarray:
    """Run inference, handling both a raw Booster and sklearn-style models."""
    if not isinstance(loaded_model, xgb.Booster):
        # sklearn-style wrapper: accepts the DataFrame directly.
        return loaded_model.predict(X)
    return loaded_model.predict(xgb.DMatrix(X))
def predict_contributions(loaded_model: object, X: pd.DataFrame) -> np.ndarray:
    """Return per-feature contribution scores via XGBoost pred_contribs.

    Raises:
        TypeError: the model is neither a Booster nor exposes get_booster().
    """
    if isinstance(loaded_model, xgb.Booster):
        booster = loaded_model
    elif hasattr(loaded_model, "get_booster"):
        booster = loaded_model.get_booster()
    else:
        raise TypeError("Loaded model does not support feature contribution prediction")
    return booster.predict(xgb.DMatrix(X), pred_contribs=True)
# FastAPI application; the lifespan hook loads the model and store
# metadata before the first request is served.
app = FastAPI(
    title=settings.model.name,
    description=settings.model.description,
    version="1.0.0",
    lifespan=lifespan,
)
@app.get("/health")
def health():
    """Health probe: reports whether the model is loaded and its version."""
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "model_version": model_version,
    }
@app.post("/predict", response_model=PredictionResponse)
def predict(request: PredictionRequest):
    """Forecast daily sales for one store over request.ForecastDays days.

    Builds one feature row per forecast date from the request plus cached
    store metadata, predicts log-sales with the loaded model, derives the
    top-5 average feature contributions as an explanation, appends an
    inference-log record, and returns the full forecast.

    Raises:
        HTTPException 503: model artifact was not loaded at startup.
        HTTPException 404: unknown store id.
        HTTPException 500: any unexpected failure during inference.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if request.Store not in store_lookup:
        raise HTTPException(status_code=404, detail=f"Store {request.Store} not found in metadata")
    try:
        started_at = perf_counter()
        store_meta = store_lookup[request.Store]
        start_date = pd.to_datetime(request.Date)
        dates = [start_date + pd.Timedelta(days=i) for i in range(request.ForecastDays)]
        # One input row per forecast day; Open=1 since we forecast trading days.
        rows = [
            {
                "Store": request.Store,
                "Date": d,
                "Promo": request.Promo,
                "StateHoliday": request.StateHoliday,
                "SchoolHoliday": request.SchoolHoliday,
                "Assortment": store_meta["Assortment"],
                "StoreType": store_meta["StoreType"],
                "CompetitionDistance": store_meta["CompetitionDistance"],
                "Promo2": store_meta["Promo2"],
                "Promo2SinceWeek": store_meta["Promo2SinceWeek"],
                "Promo2SinceYear": store_meta["Promo2SinceYear"],
                "Open": 1,
            }
            for d in dates
        ]
        df = apply_feature_pipeline(
            pd.DataFrame(rows),
            fourier_period=settings.pipeline.fourier_period,
            fourier_order=settings.pipeline.fourier_order,
        )
        feature_cols = settings.data.features
        X = build_feature_matrix(df, feature_cols)
        # Model is trained on log-transformed sales; invert with expm1.
        y_log = predict_with_model(model, X)
        y_sales = np.expm1(y_log)
        # Drop the bias term (last column) and average over the horizon.
        contribs = predict_contributions(model, X)
        avg_contribs = contribs[:, :-1].mean(axis=0)
        top_features = sorted(
            zip(feature_cols, avg_contribs), key=lambda item: abs(item[1]), reverse=True
        )[:5]
        explanation_items = [
            ExplanationItem(
                feature=name,
                score=float(score),
                formatted_val=f"{score:+.3f}",
            )
            for name, score in top_features
        ]
        forecast_result = [
            {"date": d.strftime("%Y-%m-%d"), "sales": float(round(s, 2))}
            for d, s in zip(dates, y_sales)
        ]
        latency_ms = (perf_counter() - started_at) * 1000
        append_jsonl_record(
            DEFAULT_INFERENCE_LOG_PATH,
            build_inference_log_entry(
                store=request.Store,
                start_date=request.Date,
                forecast_days=request.ForecastDays,
                promo=request.Promo,
                state_holiday=request.StateHoliday,
                school_holiday=request.SchoolHoliday,
                model_version=model_version,
                latency_ms=latency_ms,
            ),
        )
        logger.info(
            "prediction_complete store=%s horizon=%s model_version=%s latency_ms=%.3f",
            request.Store,
            request.ForecastDays,
            model_version,
            latency_ms,
        )
        return PredictionResponse(
            Store=request.Store,
            Date=request.Date,
            PredictedSales=float(y_sales[0]),
            ModelVersion=model_version,
            Explanation=explanation_items,
            Forecast=forecast_result,
            Status="success"
        )
    except Exception as e:
        # logger.exception records the traceback (logger.error with an
        # f-string did not); chain the cause so debuggers see the origin.
        logger.exception("Prediction error")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/", response_class=HTMLResponse)
def index():
    """Serve the frontend HTML page at the root path."""
    # Local import — presumably to avoid an import cycle or to keep the
    # API importable without the web package; TODO confirm.
    from web.frontend import get_frontend_html
    return get_frontend_html()
if __name__ == "__main__":
    # Direct-run entry point; port 7860 appears chosen for the hosting
    # platform's convention (NOTE(review): confirm deployment target).
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)