# github-actions[bot]
# Automated deployment to Hugging Face
# 89ca667
import os
import logging
import pandas as pd
import numpy as np
import lightgbm as lgb
import joblib
from pathlib import Path
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
# Configure the root logger once at import time: INFO level, with a
# timestamp and severity in front of every message.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
def _split_by_date(df, valid_days):
    """Return ``(train, valid)`` where ``valid`` holds the last *valid_days* days of *df*."""
    split_date = df['date'].max() - pd.Timedelta(days=valid_days)
    # Two explicit comparisons (rather than a negated mask) so rows with a
    # missing date end up in neither partition.
    return df[df['date'] <= split_date], df[df['date'] > split_date]


def _smape(actual, predicted):
    """Return the symmetric MAPE in percent; rows where both values are 0 count as 0 error."""
    actual = np.asarray(actual, dtype=float)
    predicted = np.asarray(predicted, dtype=float)
    denominator = (np.abs(actual) + np.abs(predicted)) / 2.0
    # Divide only where the denominator is non-zero; the pre-filled zeros
    # cover the 0/0 case without ever evaluating a division by zero.
    ratio = np.divide(
        np.abs(actual - predicted),
        denominator,
        out=np.zeros_like(denominator),
        where=denominator != 0,
    )
    return np.mean(ratio) * 100


def _fit_booster(params, train_data, valid_data, num_boost_round, stopping_rounds):
    """Train one LightGBM booster with early stopping monitored on *valid_data*."""
    return lgb.train(
        params,
        train_data,
        num_boost_round=num_boost_round,
        valid_sets=[train_data, valid_data],
        callbacks=[lgb.early_stopping(stopping_rounds=stopping_rounds)],
    )


def train_lightgbm(
    parquet_path: str,
    model_dir: str,
    *,
    valid_days: int = 30,
    num_boost_round: int = 500,
    early_stopping_rounds: int = 50,
) -> None:
    """
    Train and persist two LightGBM models on the engineered features.

    A Tweedie point-forecast model and a 95th-quantile model are fitted on
    everything except the last ``valid_days`` days, which are held out for
    early stopping and for the logged validation metrics (MAE, RMSE, R2,
    SMAPE). Both boosters are written to ``model_dir`` with joblib as
    ``lgb_model.pkl`` and ``lgb_model_q95.pkl``.

    Args:
        parquet_path: Path of the features Parquet file to read.
        model_dir: Directory that receives the pickled models; it is created
            if it does not exist yet.
        valid_days: Length, in days, of the trailing validation window.
        num_boost_round: Maximum number of boosting rounds per model.
        early_stopping_rounds: Rounds without improvement before training
            of a model stops early.

    Raises:
        KeyError: If the Parquet file lacks ``date``, the target or any
            feature column.
        ValueError: If the time-based split leaves the training or the
            validation partition empty.
    """
    logging.info(f"Loading features from {parquet_path}...")
    df = pd.read_parquet(parquet_path)

    # Define features and target
    features = [
        'store_nbr', 'family', 'city', 'state', 'store_type', 'cluster',
        'onpromotion', 'month', 'day_of_week', 'day_of_year', 'is_weekend', 'is_holiday',
        'dcoilwtico',
        'sales_lag_1', 'sales_lag_7', 'sales_lag_28',
        'transactions_lag_1', 'transactions_lag_7',
        'rolling_mean_7', 'rolling_std_7', 'rolling_mean_28',
    ]
    target = 'sales'

    # Fail fast, before any expensive work, when the input is unusable.
    missing = [col for col in ['date', target, *features] if col not in df.columns]
    if missing:
        raise KeyError(f"Features Parquet is missing required columns: {missing}")

    # Create the output directory up front so a bad location is reported
    # before training instead of when the finished models are dumped.
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    # The split below does date arithmetic, so make sure 'date' is datetime-like.
    if not pd.api.types.is_datetime64_any_dtype(df['date']):
        df['date'] = pd.to_datetime(df['date'])

    # Sort temporally for time-based split
    df = df.sort_values('date')

    # Convert identifiers to pandas categoricals on the full frame (before
    # splitting) so train and validation share the same category sets.
    categorical_cols = ['store_nbr', 'family', 'city', 'state', 'store_type', 'cluster', 'is_holiday']
    for col in categorical_cols:
        df[col] = df[col].astype('category')

    # Time-based train/validation split (last `valid_days` days for validation)
    train, valid = _split_by_date(df, valid_days)
    if train.empty or valid.empty:
        raise ValueError(
            f"Time-based split with valid_days={valid_days} produced "
            f"{len(train)} training rows and {len(valid)} validation rows; "
            "both partitions must be non-empty."
        )

    X_train, y_train = train[features], train[target]
    X_valid, y_valid = valid[features], valid[target]
    logging.info(f"Training LightGBM with {len(X_train)} rows and {len(X_valid)} validation rows.")

    # LightGBM setup
    train_data = lgb.Dataset(X_train, label=y_train)
    valid_data = lgb.Dataset(X_valid, label=y_valid, reference=train_data)

    # 1. Model 1: mean point forecast with the Tweedie objective, chosen by the
    # original author to cope with both high-volume and zero-inflated series.
    params_point = {
        'objective': 'tweedie',
        'tweedie_variance_power': 1.5,  # between Poisson (1) and Gamma (2)
        'metric': 'rmse',
        'boosting_type': 'gbdt',
        'learning_rate': 0.05,
        'num_leaves': 127,  # large trees so one global model can learn per-SKU rules
        'min_data_in_leaf': 10,  # small leaves so rare, spiky events can be learned
        'max_bin': 511,  # finer splits on numeric features (e.g. rolling stats)
        'feature_fraction': 0.8,
        'seed': 42,
        'verbose': -1,
    }
    logging.info("Training LightGBM Model 1 (Tweedie Mean Forecast)...")
    model_point = _fit_booster(
        params_point, train_data, valid_data, num_boost_round, early_stopping_rounds
    )

    # 2. Model 2: 95th quantile forecast, used for safety stock
    params_q95 = params_point.copy()
    params_q95['objective'] = 'quantile'
    params_q95['alpha'] = 0.95
    params_q95['metric'] = 'quantile'
    logging.info("Training LightGBM Model 2 (95th Quantile Forecast)...")
    model_q95 = _fit_booster(
        params_q95, train_data, valid_data, num_boost_round, early_stopping_rounds
    )

    # Evaluation of the point-forecast model on the held-out window
    preds = model_point.predict(X_valid)
    preds = np.maximum(0, preds)  # Ensure no negative predictions
    mae = mean_absolute_error(y_valid, preds)
    rmse = np.sqrt(mean_squared_error(y_valid, preds))
    r2 = r2_score(y_valid, preds)
    smape = _smape(y_valid, preds)
    logging.info(f"Validation MAE: {mae:.2f}")
    logging.info(f"Validation RMSE: {rmse:.2f}")
    logging.info(f"Validation R² (Accuracy): {r2:.3f}")
    logging.info(f"Validation (S)MAPE: {smape:.2f}%")

    # Save models
    model_path = os.path.join(model_dir, 'lgb_model.pkl')
    joblib.dump(model_point, model_path)  # Main model used by the UI Explainability
    logging.info(f"Model saved to {model_path}")
    model_q95_path = os.path.join(model_dir, 'lgb_model_q95.pkl')
    joblib.dump(model_q95, model_q95_path)
    logging.info(f"Quantile Model saved to {model_q95_path}")
    # NOTE: per-SKU statistical models (e.g. statsmodels SARIMA) could also be
    # trained here, but for millions of rows a single global LightGBM model on
    # lag features is far cheaper to fit.
if __name__ == "__main__":
    # Project directory: two levels above the folder that contains this script.
    project_dir = Path(__file__).resolve().parents[2]
    parquet_path = str(project_dir / "data" / "processed" / "features.parquet")
    model_dir = str(project_dir / "src" / "models")

    # Train only when the engineered feature file is present; otherwise
    # point the user at the step that produces it.
    if os.path.exists(parquet_path):
        train_lightgbm(parquet_path, model_dir)
    else:
        logging.error("Features Parquet not found. Please run build_features.py first.")