Spaces:
Running
Running
File size: 9,823 Bytes
3635bbe c3d660b 3635bbe 9e63502 3635bbe c3d660b 3635bbe c3d660b 3635bbe c3d660b 3635bbe c3d660b 3635bbe c3d660b 3635bbe c3d660b 3635bbe c3d660b 3635bbe c3d660b 92ff78b 205fba1 92ff78b c3d660b 3635bbe c3d660b 3635bbe c3d660b 3635bbe e48acb8 3635bbe c3d660b 764e253 c3d660b 3635bbe c3d660b 764e253 c3d660b 3635bbe 4f9bbec 3635bbe c3d660b 3635bbe c3d660b 3635bbe c3d660b 3635bbe c3d660b 3635bbe e48acb8 c3d660b e48acb8 c3d660b e48acb8 c3d660b 764e253 3635bbe c3d660b 764e253 c3d660b 3635bbe 764e253 3635bbe c3d660b 3635bbe c3d660b 78005e1 c3d660b fc9ccdf c3d660b 764e253 3635bbe e48acb8 c3d660b 51b1929 c3d660b 764e253 c3d660b 764e253 c3d660b 764e253 c3d660b 764e253 c3d660b c37cd10 764e253 c3d660b 764e253 c37cd10 c3d660b 764e253 c3d660b 764e253 c37cd10 e48acb8 c37cd10 c3d660b e48acb8 c37cd10 332859c 3635bbe 764e253 e48acb8 4f9bbec c3d660b 3635bbe 764e253 c3d660b 764e253 3635bbe c3d660b 3635bbe c3d660b 3635bbe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 |
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
import pandas as pd
import numpy as np
import os
import sys
import pickle
# Add project root to path for imports if running from src
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.config import global_config
from src.pipeline import RossmannPipeline
from src.frontend import FRONTEND_HTML
from src.core import setup_logger
logger = setup_logger(__name__)
app = FastAPI(
    title=global_config.model.name,
    description=global_config.model.description,
    version="2.0.0",
)
# Module-level state, populated by the startup handler.
pipeline = None  # RossmannPipeline instance; its .model attribute holds the loaded estimator
store_metadata = None  # pandas DataFrame of per-store attributes, if the store CSV exists
@app.on_event("startup")
def startup_event():
    """Initialize the global pipeline and load the trained model and store metadata.

    Runs once at application start. On any failure the app still boots, but
    ``pipeline.model`` stays None and /predict will answer 503 (see /health).
    """
    global pipeline, store_metadata
    logger.info("Starting up application...")

    # 1. Initialize Pipeline (always created, even without a model, so that
    #    /health can report state and feature engineering stays available).
    pipeline = RossmannPipeline(global_config.data.archive_path)
    pipeline.model = None

    # 2. Load Model (single existence check instead of the former duplicate).
    model_path = os.path.abspath("models/rossmann_production_model.pkl")
    logger.info(f"Looking for model at: {model_path}")
    if not os.path.exists(model_path):
        logger.error(f"Model not found at {model_path}. Application will not work!")
    else:
        try:
            # NOTE(review): pickle.load is acceptable only because the model
            # file is bundled with the deployment, never user-supplied.
            with open(model_path, "rb") as f:
                pipeline.model = pickle.load(f)
            # Fix for XGBoost version compatibility - strip GPU-era params.
            try:
                if hasattr(pipeline.model, "gpu_id"):
                    pipeline.model = pipeline.model.set_params(gpu_id=None)
                if hasattr(pipeline.model, "device"):
                    pipeline.model = pipeline.model.set_params(device="cpu")
                # Force the CPU-friendly hist tree method.
                pipeline.model = pipeline.model.set_params(tree_method="hist")
                # Log key hyperparameters for traceability.
                for param in ("n_estimators", "max_depth", "learning_rate"):
                    logger.info(
                        f"Model params: {param}={getattr(pipeline.model, param, 'N/A')}"
                    )
                logger.info("Model parameters fixed for compatibility")
            except Exception as e:
                # Best-effort: an un-patched model may still be usable.
                logger.warning(f"Could not fix model params: {e}")
            logger.info(f"Model loaded successfully. Type: {type(pipeline.model)}")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            pipeline.model = None

    # 3. Load Store Metadata (optional enrichment data).
    store_path = global_config.data.store_path
    if store_path and os.path.exists(store_path):
        store_metadata = pd.read_csv(store_path)
        logger.info(f"Store metadata loaded from {store_path}")
class PredictionRequest(BaseModel):
    # Request payload for /predict: one store, one start date, one scenario.
    Store: int
    Date: str  # start date of the forecast window; parsed with pd.to_datetime
    Promo: int
    StateHoliday: str
    SchoolHoliday: int
    Assortment: str  # 'a'/'b'/'c'; ordinal-encoded downstream, unknown -> 0
    StoreType: str  # 'a'-'d'; ordinal-encoded downstream, unknown -> 0
    CompetitionDistance: int
    ForecastDays: int = 1  # Horizon: number of consecutive days to forecast
class ExplanationItem(BaseModel):
    # One feature's contribution to the first-day prediction.
    feature: str  # human-friendly display name of the feature
    impact: float  # approximate % sales uplift attributed to this feature
    formatted_val: str  # impact pre-formatted for display, e.g. "+3.1%"
class PredictionResponse(BaseModel):
    # Response body for /predict.
    Store: int
    Date: str  # echo of the requested start date
    PredictedSales: float  # point forecast for the first day
    ConfidenceInterval: list  # [lower, upper] bounds for the first day
    Explanation: list[ExplanationItem] = []  # top feature impacts for the first day
    Forecast: list = []  # List of {date: str, sales: float, lb: float, ub: float}
    Status: str  # "success" on the happy path
    DebugInfo: dict = {}  # raw log-space prediction plus the first feature row
@app.get("/", response_class=HTMLResponse)
def read_root():
    """Serve the embedded single-page frontend."""
    page = FRONTEND_HTML
    return page
@app.get("/health")
def health_check():
    """Liveness probe: reports whether the model is ready to serve predictions."""
    model_ready = pipeline is not None and pipeline.model is not None
    return {
        "status": "healthy",
        "model_loaded": model_ready,
        "config_name": global_config.model.name,
    }
@app.post("/predict", response_model=PredictionResponse)
def predict(request: PredictionRequest):
    """Forecast daily sales for one store over ``request.ForecastDays`` days.

    Pipeline: build a one-row-per-day input batch, run feature engineering,
    ordinal-encode categoricals, select the training feature set, predict with
    the XGBoost model (log1p-space target), and attach per-feature
    contribution explanations for the first forecast day.

    Raises:
        HTTPException 503: model not loaded.
        HTTPException 500: any failure during feature prep or prediction.
    """
    if not pipeline or not pipeline.model:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # 1. Generate the forecast date range (one row per day).
        start_date = pd.to_datetime(request.Date)
        dates = [start_date + pd.Timedelta(days=i) for i in range(request.ForecastDays)]

        # 2. Prepare the input batch; stores are assumed open on forecast days.
        rows = [
            {
                "Store": request.Store,
                "Date": d,
                "Promo": request.Promo,
                "StateHoliday": request.StateHoliday,
                "SchoolHoliday": request.SchoolHoliday,
                "Assortment": request.Assortment,
                "StoreType": request.StoreType,
                "CompetitionDistance": request.CompetitionDistance,
                "Open": 1,
            }
            for d in dates
        ]
        input_data = pd.DataFrame(rows)

        # 3. Feature engineering (calendar, Fourier, Easter features).
        processed_df = pipeline.run_feature_engineering(input_data)

        # 4. Ordinal-encode categoricals exactly as during training;
        #    unknown categories fall back to 0.
        if "StoreType" in processed_df.columns:
            processed_df["StoreType"] = (
                processed_df["StoreType"]
                .astype(str)
                .map({"a": 1, "b": 2, "c": 3, "d": 4})
                .fillna(0)
            )
        if "Assortment" in processed_df.columns:
            processed_df["Assortment"] = (
                processed_df["Assortment"]
                .astype(str)
                .map({"a": 1, "b": 2, "c": 3})
                .fillna(0)
            )

        # 5. Select the feature columns the model was trained on.
        feature_cols = [
            "Store",
            "DayOfWeek",
            "Promo",
            "StateHoliday",
            "SchoolHoliday",
            "Year",
            "Month",
            "Day",
            "IsWeekend",
            "DayOfMonth",
            "CompetitionDistance",
            "StoreType",
            "Assortment",
        ]
        for i in range(1, 6):
            feature_cols.extend([f"fourier_sin_{i}", f"fourier_cos_{i}"])
        feature_cols.extend(["days_to_easter", "easter_effect"])

        X = pd.DataFrame()
        for c in feature_cols:
            if c in processed_df.columns:
                val = processed_df[c]
                # Robustness: cap Year to the training range (2013-2015) so
                # far-future dates do not extrapolate the Year feature.
                if c == "Year":
                    val = val.clip(upper=2015)
                X[c] = val
            else:
                # Missing engineered feature: neutral default.
                X[c] = 0
        # Ensure numeric dtypes; non-coercible values (e.g. StateHoliday
        # letters) become 0, matching the previous behavior.
        X = X.apply(pd.to_numeric, errors="coerce").fillna(0)

        # 6. Predict (model targets log1p(sales)) and invert the transform.
        #    Debug output goes through the logger, not print().
        y_log = pipeline.model.predict(X)
        y_sales = np.expm1(y_log)
        logger.debug(f"X shape={X.shape}, row0={X.iloc[0].to_dict()}")
        logger.debug(f"Raw log prediction row0={y_log[0]:.4f}")
        logger.info(f"Target Pred: {y_sales[0]:.2f}")

        # 7. Explanations for the first forecast day via per-feature
        #    contributions (log-space; the last column is the bias term).
        #    xgboost is imported lazily so startup stays fast.
        import xgboost as xgb

        X_first = X.iloc[[0]]
        dmat = xgb.DMatrix(X_first, feature_names=feature_cols)
        booster = pipeline.model.get_booster()
        contribs = booster.predict(dmat, pred_contribs=True)[0]
        feature_impacts = contribs[:-1]  # drop bias term

        # Display-name mapping hoisted out of the loop (it was rebuilt every
        # iteration). fourier_* features are collapsed to "Seasonality" below,
        # which also retires the dead, misspelled "fourier_sin_1" entry.
        display_names = {
            "Promo": "Promotion Lift",
            "CompetitionDistance": "Local Competition",
            "IsWeekend": "Weekend Traffic",
            "Month": "Seasonal Factor",
            "StateHoliday": "Holiday Impact",
            "SchoolHoliday": "School Schedule",
            "Year": "Annual Growth",
        }
        indicators = sorted(
            zip(feature_cols, feature_impacts), key=lambda x: abs(x[1]), reverse=True
        )
        explanation_items = []
        for name, log_impact in indicators[:6]:
            # Convert a log-space contribution into an approximate % uplift.
            uplift_pct = (np.exp(log_impact) - 1) * 100
            if "fourier" in name:
                display_name = "Seasonality"
            else:
                display_name = display_names.get(name, name)
            explanation_items.append(
                ExplanationItem(
                    feature=display_name,
                    impact=uplift_pct,
                    formatted_val=f"{uplift_pct:+.1f}%",
                )
            )

        # 8. Forecast series with +/-15% confidence bands (heuristic; the old
        #    comment claimed +/-23.5% but the code has always used +/-15%).
        forecast_result = [
            {
                "date": d.strftime("%Y-%m-%d"),
                "sales": float(round(s, 2)),
                "lb": float(round(s * 0.85, 2)),
                "ub": float(round(s * 1.15, 2)),
            }
            for d, s in zip(dates, y_sales)
        ]
        # Headline KPI bounds for the first day.
        lower_bound = y_sales[0] * 0.85
        upper_bound = y_sales[0] * 1.15

        return PredictionResponse(
            Store=request.Store,
            Date=request.Date,
            PredictedSales=float(y_sales[0]),
            ConfidenceInterval=[float(lower_bound), float(upper_bound)],
            Explanation=explanation_items,
            Forecast=forecast_result,
            Status="success",
            DebugInfo={"y_log": float(y_log[0]), "X_row0": X.iloc[0].to_dict()},
        )
    except HTTPException:
        raise
    except Exception as e:
        # logger.exception records the full traceback (replaces the former
        # logger.error + traceback.print_exc pair).
        logger.exception(f"Prediction error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    # Local development entrypoint: serve on all interfaces.
    # NOTE(review): 7860 looks like the Hugging Face Spaces default port —
    # confirm against the deployment config.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
|