File size: 7,180 Bytes
52cc99a
 
 
 
 
 
 
 
 
0269b4b
798c69b
52cc99a
0269b4b
52cc99a
798c69b
 
 
 
 
52cc99a
 
 
 
 
 
 
 
 
1b2b7c9
52cc99a
0269b4b
52cc99a
 
 
 
 
 
 
 
 
 
 
 
 
 
0269b4b
52cc99a
 
 
 
1b2b7c9
52cc99a
 
 
 
 
0269b4b
 
 
 
 
 
 
 
52cc99a
 
 
 
 
 
 
1b2b7c9
 
 
 
 
 
 
 
 
 
 
 
 
 
52cc99a
 
 
 
 
 
 
 
 
0269b4b
52cc99a
 
 
 
 
 
 
 
 
798c69b
52cc99a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b2b7c9
52cc99a
 
1b2b7c9
52cc99a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
798c69b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52cc99a
 
 
 
0269b4b
52cc99a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from contextlib import asynccontextmanager
import pandas as pd
import numpy as np
import xgboost as xgb
import logging
import os
from pathlib import Path
import json
from time import perf_counter

from src.shared.config import DEFAULT_MODEL_METADATA_PATH, DEFAULT_MODEL_PATH, settings
from src.shared.schemas import PredictionRequest, PredictionResponse, ExplanationItem
from src.serving.monitoring import (
    DEFAULT_INFERENCE_LOG_PATH,
    append_jsonl_record,
    build_inference_log_entry,
)
from src.training.data_loader import load_store_data
from src.training.features import (
    apply_feature_pipeline,
    build_feature_matrix
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Runtime assets. All three are populated by load_runtime_assets(), which the
# FastAPI lifespan hook runs once at startup; until then the model is None and
# /predict responds 503.
model: xgb.Booster | None = None
# Maps Store id -> row of store metadata (column name -> value).
store_lookup: dict[int, dict[str, object]] = {}
# Version string read from the model metadata JSON; "unknown" when absent.
model_version: str = "unknown"


def load_store_lookup() -> dict[int, dict[str, object]]:
    """Build a Store-id -> metadata mapping used to assemble prediction rows.

    Missing values are imputed before indexing: a very large competition
    distance (100000), zero for the Promo2 columns, and the string sentinel
    "0" for the categorical columns.
    """
    stores = load_store_data(settings.data.store_path).copy()
    stores["CompetitionDistance"] = stores["CompetitionDistance"].fillna(100000)
    for name in ("Promo2", "Promo2SinceWeek", "Promo2SinceYear"):
        stores[name] = stores[name].fillna(0).astype(int)
    for name in ("StoreType", "Assortment"):
        stores[name] = stores[name].fillna("0").astype(str)
    return stores.set_index("Store").to_dict(orient="index")


def load_runtime_assets() -> None:
    """Populate the module-level ``model``, ``model_version`` and ``store_lookup``.

    Paths come from the MODEL_PATH / MODEL_METADATA_PATH environment variables,
    falling back to the configured defaults. Missing files are tolerated: the
    model stays None (so /predict returns 503) and the version stays "unknown".
    """
    global model, model_version, store_lookup

    store_lookup = load_store_lookup()

    model_path = Path(os.environ.get("MODEL_PATH", str(DEFAULT_MODEL_PATH)))
    if model_path.exists():
        # Assign the global only after a successful load so a load_model
        # failure cannot leave a half-initialized Booster behind.
        booster = xgb.Booster()
        booster.load_model(str(model_path))
        model = booster
        logger.info("Model loaded from %s", model_path)
    else:
        logger.warning("Model not found at %s. Predict endpoint will fail.", model_path)

    metadata_path = Path(os.environ.get("MODEL_METADATA_PATH", str(DEFAULT_MODEL_METADATA_PATH)))
    if metadata_path.exists():
        with metadata_path.open("r", encoding="utf-8") as f:
            model_version = json.load(f).get("model_version", "unknown")
    else:
        logger.warning("Model metadata not found at %s. Using unknown version.", metadata_path)
        model_version = "unknown"


@asynccontextmanager
async def lifespan(_app: FastAPI):
    """FastAPI lifespan hook: load the model and store metadata once at startup."""
    load_runtime_assets()
    yield


def predict_with_model(loaded_model: object, X: pd.DataFrame) -> np.ndarray:
    """Run inference, wrapping X in a DMatrix when given a raw xgboost Booster."""
    if not isinstance(loaded_model, xgb.Booster):
        # sklearn-style wrappers accept the DataFrame directly.
        return loaded_model.predict(X)
    return loaded_model.predict(xgb.DMatrix(X))


def predict_contributions(loaded_model: object, X: pd.DataFrame) -> np.ndarray:
    """Return per-feature contribution predictions (xgboost ``pred_contribs``).

    Accepts a raw Booster or a sklearn-style wrapper exposing ``get_booster()``;
    any other model type is rejected.
    """
    if isinstance(loaded_model, xgb.Booster):
        booster = loaded_model
    elif hasattr(loaded_model, "get_booster"):
        booster = loaded_model.get_booster()
    else:
        raise TypeError("Loaded model does not support feature contribution prediction")
    return booster.predict(xgb.DMatrix(X), pred_contribs=True)


# FastAPI application; model/store-metadata loading is deferred to the
# lifespan hook so import of this module has no heavy side effects.
app = FastAPI(
    title=settings.model.name,
    description=settings.model.description,
    version="1.0.0",
    lifespan=lifespan,
)

@app.get("/health")
def health():
    return {"status": "healthy", "model_loaded": model is not None, "model_version": model_version}

@app.post("/predict", response_model=PredictionResponse)
def predict(request: PredictionRequest):
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if request.Store not in store_lookup:
        raise HTTPException(status_code=404, detail=f"Store {request.Store} not found in metadata")

    try:
        started_at = perf_counter()
        store_meta = store_lookup[request.Store]
        start_date = pd.to_datetime(request.Date)
        dates = [start_date + pd.Timedelta(days=i) for i in range(request.ForecastDays)]

        rows = []
        for d in dates:
            rows.append({
                "Store": request.Store,
                "Date": d,
                "Promo": request.Promo,
                "StateHoliday": request.StateHoliday,
                "SchoolHoliday": request.SchoolHoliday,
                "Assortment": store_meta["Assortment"],
                "StoreType": store_meta["StoreType"],
                "CompetitionDistance": store_meta["CompetitionDistance"],
                "Promo2": store_meta["Promo2"],
                "Promo2SinceWeek": store_meta["Promo2SinceWeek"],
                "Promo2SinceYear": store_meta["Promo2SinceYear"],
                "Open": 1
            })

        df = pd.DataFrame(rows)
        df = apply_feature_pipeline(
            df,
            fourier_period=settings.pipeline.fourier_period,
            fourier_order=settings.pipeline.fourier_order,
        )

        feature_cols = settings.data.features
        X = build_feature_matrix(df, feature_cols)

        y_log = predict_with_model(model, X)
        y_sales = np.expm1(y_log)

        contribs = predict_contributions(model, X)
        avg_contribs = contribs[:, :-1].mean(axis=0)

        explanation_items = []
        impact_map = sorted(zip(feature_cols, avg_contribs), key=lambda x: abs(x[1]), reverse=True)[:5]
        for name, score in impact_map:
            explanation_items.append(ExplanationItem(
                feature=name,
                score=float(score),
                formatted_val=f"{score:+.3f}"
            ))

        forecast_result = []
        for d, s in zip(dates, y_sales):
            date_str = d.strftime("%Y-%m-%d")
            sales_val = float(round(s, 2))
            forecast_result.append({
                "date": date_str,
                "sales": sales_val
            })

        latency_ms = (perf_counter() - started_at) * 1000
        append_jsonl_record(
            DEFAULT_INFERENCE_LOG_PATH,
            build_inference_log_entry(
                store=request.Store,
                start_date=request.Date,
                forecast_days=request.ForecastDays,
                promo=request.Promo,
                state_holiday=request.StateHoliday,
                school_holiday=request.SchoolHoliday,
                model_version=model_version,
                latency_ms=latency_ms,
            ),
        )
        logger.info(
            "prediction_complete store=%s horizon=%s model_version=%s latency_ms=%.3f",
            request.Store,
            request.ForecastDays,
            model_version,
            latency_ms,
        )

        return PredictionResponse(
            Store=request.Store,
            Date=request.Date,
            PredictedSales=float(y_sales[0]),
            ModelVersion=model_version,
            Explanation=explanation_items,
            Forecast=forecast_result,
            Status="success"
        )

    except Exception as e:
        logger.error(f"Prediction error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/", response_class=HTMLResponse)
def index():
    from web.frontend import get_frontend_html
    return get_frontend_html()

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)