# fix: pin xgboost to 3.1.0 to resolve prediction discrepancy and cleanup code (commit f1f65bd)
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
import pandas as pd
import numpy as np
import os
import sys
import pickle
# Add project root to path for imports if running from src
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from src.config import global_config
from src.pipeline import RossmannPipeline
from src.frontend import FRONTEND_HTML
from src.core import setup_logger, settings
# Module-level logger configured by the project's shared helper.
logger = setup_logger(__name__)
# FastAPI application; title/description come from the project config.
app = FastAPI(
title=global_config.model.name,
description=global_config.model.description,
version="2.0.0"
)
# Global variables
# Populated once in startup_event(); both remain None until startup runs.
pipeline = None
store_metadata = None
@app.on_event("startup")
def startup_event():
    """Initialize the pipeline, load the trained model, and read store metadata.

    Sets the module-level ``pipeline`` and ``store_metadata`` globals used by
    the request handlers. The app still starts if the model file is missing,
    but /predict will return 503 until a model is available.
    """
    global pipeline, store_metadata
    logger.info("Starting up application...")
    # 1. Locate the serialized production model.
    model_path = os.path.abspath("models/rossmann_production_model.pkl")
    model_available = os.path.exists(model_path)
    if not model_available:
        logger.warning(f"Model not found at {model_path}. Application may not work until trained.")
    # 2. Pipeline components are initialized from the configured archive path
    #    (train.csv or schema), mirroring the training setup.
    pipeline = RossmannPipeline(global_config.data.archive_path)
    if model_available:
        # NOTE(review): pickle is only safe for trusted, locally produced
        # artifacts — never load untrusted model files this way.
        with open(model_path, 'rb') as f:
            pipeline.model = pickle.load(f)
        logger.info("Model loaded successfully.")
    # 3. Optional store metadata (for Open/Promo2 checks or merging).
    store_path = global_config.data.store_path
    if store_path and os.path.exists(store_path):
        store_metadata = pd.read_csv(store_path)
        logger.info(f"Store metadata loaded from {store_path}")
class PredictionRequest(BaseModel):
    """Input payload for /predict: one store plus a forecast horizon."""
    Store: int
    Date: str  # forecast start date; parsed with pandas.to_datetime
    Promo: int
    StateHoliday: str
    SchoolHoliday: int
    Assortment: str  # categorical code 'a'/'b'/'c' (others encode to 0)
    StoreType: str  # categorical code 'a'-'d' (others encode to 0)
    CompetitionDistance: int
    ForecastDays: int = 1 # Horizon
class ExplanationItem(BaseModel):
    """One feature's contribution to the prediction, for display in the UI."""
    feature: str  # human-readable feature name
    impact: float  # percent uplift attributable to this feature
    formatted_val: str  # impact pre-formatted as a signed percentage string
class PredictionResponse(BaseModel):
    """Response body for /predict: point forecast, bands, and explanations."""
    Store: int
    Date: str
    PredictedSales: float  # prediction for the first forecast day
    ConfidenceInterval: list # [lower, upper]
    Explanation: list[ExplanationItem] = []
    Forecast: list = [] # List of {date: str, sales: float}
    Status: str
    DebugInfo: dict = {}  # raw model inputs/outputs for troubleshooting
@app.get("/", response_class=HTMLResponse)
def read_root():
    """Serve the embedded single-page frontend."""
    return FRONTEND_HTML
@app.get("/health")
def health_check():
    """Liveness probe: reports whether the model is loaded and the config name."""
    model_ready = pipeline is not None and pipeline.model is not None
    return {
        "status": "healthy",
        "model_loaded": model_ready,
        "config_name": global_config.model.name,
    }
# Human-readable display names for model features shown in explanations.
# Hoisted to module level: the original rebuilt this dict on every loop pass.
_FEATURE_DISPLAY_NAMES = {
    "Promo": "Promotion Lift",
    "CompetitionDistance": "Local Competition",
    "IsWeekend": "Weekend Traffic",
    "Month": "Seasonal Factor",
    "StateHoliday": "Holiday Impact",
    "SchoolHoliday": "School Schedule",
    "Year": "Annual Growth",
    "fourier_sin_1": "Core Seasonality",
}


def _feature_columns():
    """Return the ordered list of feature columns the production model expects."""
    cols = [
        'Store', 'DayOfWeek', 'Promo', 'StateHoliday', 'SchoolHoliday',
        'Year', 'Month', 'Day', 'IsWeekend', 'DayOfMonth',
        'CompetitionDistance', 'StoreType', 'Assortment'
    ]
    for i in range(1, 6):
        cols.extend([f'fourier_sin_{i}', f'fourier_cos_{i}'])
    cols.append('days_to_easter')
    cols.append('easter_effect')
    return cols


def _prepare_features(processed_df):
    """Encode categoricals and select the model's feature matrix.

    Returns (X, feature_cols) where X is an all-numeric DataFrame with one
    column per expected feature (missing features filled with 0).
    """
    if 'StoreType' in processed_df.columns:
        processed_df['StoreType'] = processed_df['StoreType'].astype(str).map({'a': 1, 'b': 2, 'c': 3, 'd': 4}).fillna(0)
    if 'Assortment' in processed_df.columns:
        processed_df['Assortment'] = processed_df['Assortment'].astype(str).map({'a': 1, 'b': 2, 'c': 3}).fillna(0)
    feature_cols = _feature_columns()
    X = pd.DataFrame()
    for c in feature_cols:
        if c in processed_df.columns:
            val = processed_df[c]
            # Robustness: cap Year to the training range (2013-2015) so
            # future dates don't extrapolate an unseen feature value.
            if c == 'Year':
                val = val.clip(upper=2015)
            X[c] = val
        else:
            X[c] = 0
    # Ensure numeric types throughout; unparseable values become 0.
    return X.apply(pd.to_numeric, errors='coerce').fillna(0), feature_cols


def _build_explanations(model, X_first, feature_cols):
    """Top-6 per-feature contributions for a single-row feature frame.

    Uses XGBoost's pred_contribs (SHAP-style log-space contributions); the
    last contribution column is the bias term and is dropped.
    """
    import xgboost as xgb  # local import: only needed for explanations
    dmat = xgb.DMatrix(X_first, feature_names=feature_cols)
    contribs = model.get_booster().predict(dmat, pred_contribs=True)[0]
    ranked = sorted(zip(feature_cols, contribs[:-1]), key=lambda p: abs(p[1]), reverse=True)
    items = []
    for name, log_impact in ranked[:6]:
        # Convert a log-space contribution into a multiplicative percent uplift.
        uplift_pct = (np.exp(log_impact) - 1) * 100
        if "fourier" in name:
            display_name = "Seasonality"
        else:
            display_name = _FEATURE_DISPLAY_NAMES.get(name, name)
        items.append(ExplanationItem(
            feature=display_name,
            impact=uplift_pct,
            formatted_val=f"{uplift_pct:+.1f}%"
        ))
    return items


@app.post("/predict", response_model=PredictionResponse)
def predict(request: PredictionRequest):
    """Forecast sales for one store over ``ForecastDays`` days.

    Returns the first-day prediction with a heuristic confidence band, a
    per-day forecast series, and the top feature contributions.

    Raises:
        HTTPException 503 if no model is loaded, 500 on any prediction error.
    """
    if not pipeline or not pipeline.model:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # 1. Generate the date range for batch prediction.
        start_date = pd.to_datetime(request.Date)
        dates = [start_date + pd.Timedelta(days=i) for i in range(request.ForecastDays)]
        # 2. One input row per forecast day; stores are assumed Open.
        rows = [
            {
                'Store': request.Store,
                'Date': d,
                'Promo': request.Promo,
                'StateHoliday': request.StateHoliday,
                'SchoolHoliday': request.SchoolHoliday,
                'Assortment': request.Assortment,
                'StoreType': request.StoreType,
                'CompetitionDistance': request.CompetitionDistance,
                'Open': 1
            }
            for d in dates
        ]
        input_data = pd.DataFrame(rows)
        # 3. Shared feature engineering, then encoding/selection.
        processed_df = pipeline.run_feature_engineering(input_data)
        X, feature_cols = _prepare_features(processed_df)
        # 4. Predict; the model is trained on log1p(sales), so invert it.
        y_log = pipeline.model.predict(X)
        y_sales = np.expm1(y_log)
        # Debug via logger (was print) so output respects log configuration.
        logger.debug("X shape=%s, row0=%s, raw log pred row0=%.4f",
                     X.shape, X.iloc[0].to_dict(), y_log[0])
        logger.info("Target Pred: %.2f", y_sales[0])
        # 5. Explanations for the first forecast day only.
        explanation_items = _build_explanations(pipeline.model, X.iloc[[0]], feature_cols)
        # 6. Per-day forecast with a heuristic +/-15% confidence band.
        forecast_result = []
        for d, s in zip(dates, y_sales):
            forecast_result.append({
                "date": d.strftime('%Y-%m-%d'),
                "sales": float(round(s, 2)),
                "lb": float(round(s * 0.85, 2)),
                "ub": float(round(s * 1.15, 2))
            })
        # Global KPI bounds (first-day band).
        lower_bound = y_sales[0] * 0.85
        upper_bound = y_sales[0] * 1.15
        return PredictionResponse(
            Store=request.Store,
            Date=request.Date,
            PredictedSales=float(y_sales[0]),
            ConfidenceInterval=[float(lower_bound), float(upper_bound)],
            Explanation=explanation_items,
            Forecast=forecast_result,
            Status="success",
            DebugInfo={
                "y_log": float(y_log[0]),
                "X_row0": X.iloc[0].to_dict()
            }
        )
    except Exception as e:
        # logger.exception records the full traceback (replaces print_exc).
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    import uvicorn
    # Local/dev entry point; port 7860 presumably targets a hosted Spaces
    # deployment — confirm against the deployment config.
    uvicorn.run(app, host="0.0.0.0", port=7860)