Spaces:
Build error
Build error
File size: 4,607 Bytes
ea6f215 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import os
import sys
import pickle
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
# Add project root to path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from src.pipeline import RossmannPipeline
from src.core import setup_logger
logger = setup_logger(__name__)
def visualize_model_performance():
"""
Evaluates the production model and generates visual reports.
"""
train_csv = os.path.abspath("data/raw/train.csv")
model_path = os.path.abspath("models/rossmann_production_model.pkl")
if not os.path.exists(model_path):
logger.error(f"Production model not found at {model_path}. Run production training first.")
return
logger.info("Initializing Visualization Pipeline...")
pipeline = RossmannPipeline(train_csv)
# Load model
with open(model_path, 'rb') as f:
pipeline.model = pickle.load(f)
# 1. Prepare Validation Data (Final month of data)
df_raw = pipeline.ingestor.ingest(train_csv)
df_raw['Date'] = pd.to_datetime(df_raw['Date'])
# Take latest 30 days for evaluation
max_date = df_raw['Date'].max()
eval_df = df_raw[df_raw['Date'] > (max_date - pd.Timedelta(days=30))]
logger.info(f"Evaluating on {len(eval_df)} records from {eval_df['Date'].min().date()} to {max_date.date()}")
df_feat = pipeline.run_feature_engineering(eval_df)
from sklearn.preprocessing import LabelEncoder
le = LabelEncoder()
for col in ['StoreType', 'Assortment']:
if col in df_feat.columns:
df_feat[col] = le.fit_transform(df_feat[col].astype(str))
feature_cols = [
'Store', 'DayOfWeek', 'Promo', 'StateHoliday', 'SchoolHoliday',
'Year', 'Month', 'Day', 'IsWeekend', 'DayOfMonth',
'CompetitionDistance', 'CompetitionOpenTime', 'StoreType', 'Assortment'
] + [c for c in df_feat.columns if 'fourier' in c or 'easter' in c]
X_eval = df_feat[feature_cols].fillna(0)
y_eval = df_feat['target']
# 2. Generate Predictions
y_pred_log = pipeline.model.predict(X_eval)
y_pred = np.expm1(y_pred_log)
y_true = np.expm1(y_eval)
metrics = pipeline.evaluate(X_eval, y_eval)
logger.info(f"Evaluation Metrics: {metrics}")
# --- VISUALIZATION ---
os.makedirs('reports/figures', exist_ok=True)
sns.set(style='whitegrid', palette='muted')
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'sans-serif'] # Fix for potential glyph issues
plt.rcParams['axes.unicode_minus'] = False
# Plot 1: Actual vs Predicted (Time Series Sample)
plt.figure(figsize=(15, 6))
# Aggregate by date for a cleaner plot
eval_plot_df = pd.DataFrame({'Date': df_feat['Date'], 'Actual': y_true, 'Predicted': y_pred})
ts_agg = eval_plot_df.groupby('Date').mean()
plt.plot(ts_agg.index, ts_agg['Actual'], label='Actual Sales (Mean)', color='blue', alpha=0.7)
plt.plot(ts_agg.index, ts_agg['Predicted'], label='Predicted Sales (Mean)', color='red', linestyle='--', alpha=0.9)
plt.fill_between(ts_agg.index, ts_agg['Predicted']*0.9, ts_agg['Predicted']*1.1, color='red', alpha=0.1, label='10% error margin')
plt.title('XGBoost Model: Daily Sales Forecast vs Actuals', fontsize=15)
plt.xlabel('Date')
plt.ylabel('Sales')
plt.legend()
plt.savefig('reports/figures/actual_vs_predicted.png')
plt.close()
# Plot 2: Residual Distribution
plt.figure(figsize=(10, 6))
residuals = y_true - y_pred
sns.histplot(residuals, kde=True, color='purple')
plt.title('Residual Distribution (Forecast Error)', fontsize=15)
plt.xlabel('Residual (Actual - Predicted)')
plt.ylabel('Frequency')
plt.savefig('reports/figures/residuals.png')
plt.close()
# Plot 3: Feature Importance (Gains for XGBoost)
plt.figure(figsize=(12, 8))
importance = pd.DataFrame({
'Feature': feature_cols,
'Importance': pipeline.model.feature_importances_
}).sort_values('Importance', ascending=False)
sns.barplot(data=importance, x='Importance', y='Feature', palette='viridis', hue='Feature', legend=False)
plt.title('XGBoost Model Interpretability: Feature Importance (Gain)', fontsize=15)
plt.savefig('reports/figures/feature_importance.png')
plt.close()
logger.info("Visualizations saved to reports/figures/")
print(f"\nFinal SMAPE: {metrics['SMAPE']:.2f}%")
print(f"Final MAE: {metrics['MAE']:.2f}")
if __name__ == "__main__":
visualize_model_performance()
|