|
|
"""
|
|
|
Demand Prediction System - Prediction Script
|
|
|
|
|
|
This script loads a trained model and makes demand predictions for products
|
|
|
on future dates. Supports both ML models and time-series models (ARIMA, Prophet).
|
|
|
|
|
|
Usage (ML Models):
|
|
|
python predict.py --product_id 1 --date 2024-01-15 --price 100 --discount 10 --category Electronics
|
|
|
|
|
|
Usage (Time-Series Models - overall demand):
|
|
|
python predict.py --date 2024-01-15 --model_type timeseries
|
|
|
"""
|
|
|
|
|
|
import pandas as pd
|
|
|
import numpy as np
|
|
|
import joblib
|
|
|
import json
|
|
|
import argparse
|
|
|
from datetime import datetime
|
|
|
import os
|
|
|
import warnings
|
|
|
# Suppress every warning process-wide so the CLI output stays uncluttered.
warnings.filterwarnings('ignore')

# Directory holding the artifacts produced by train_model.py.
MODEL_DIR = 'models'

# Fitted ML regressor (used for per-product predictions).
MODEL_PATH = MODEL_DIR + '/best_model.joblib'
# Fitted time-series model (ARIMA or Prophet) for overall daily demand.
TS_MODEL_PATH = MODEL_DIR + '/best_timeseries_model.joblib'
# Encoders, scaler and feature ordering used by the ML model.
PREPROCESSING_PATH = MODEL_DIR + '/preprocessing.joblib'
# Metadata describing the saved ML model.
METADATA_PATH = MODEL_DIR + '/model_metadata.json'
# Metadata comparing all trained models, including which one is best.
ALL_MODELS_METADATA_PATH = MODEL_DIR + '/all_models_metadata.json'
|
|
|
|
|
|
|
|
|
# Model names (as recorded in the metadata) that are served by the time-series path.
_TIMESERIES_MODEL_NAMES = ('ARIMA', 'Prophet')


def _format_metric(value):
    """Format a metric for display with 4 decimals, or 'N/A' when not numeric.

    The metadata may omit a metric (or store JSON null). Formatting a string
    fallback such as 'N/A' with ':.4f' raises ValueError, so non-numeric values
    are reported as 'N/A' instead of crashing the loader.
    """
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return f"{value:.4f}"
    return 'N/A'


def _infer_timeseries_model_name(model):
    """Heuristically map a loaded time-series model object to 'ARIMA' or 'Prophet'.

    Looks only at the object's class name; returns None when neither name
    matches. NOTE(review): assumes Prophet models have 'Prophet' in their class
    name and statsmodels ARIMA results have 'ARIMA' in theirs -- confirm against
    train_model.py.
    """
    class_name = type(model).__name__
    if 'Prophet' in class_name:
        return 'Prophet'
    if 'ARIMA' in class_name.upper():
        return 'ARIMA'
    return None


def _load_timeseries_model(best_model_name, all_metadata):
    """Load the saved time-series model and resolve its name for dispatch."""
    if not os.path.exists(TS_MODEL_PATH):
        raise FileNotFoundError(
            f"Time-series model not found at {TS_MODEL_PATH}. Please run train_model.py first."
        )

    print("Loading time-series model...")
    model = joblib.load(TS_MODEL_PATH)

    # When --model_type timeseries is forced, the overall best model may be an
    # ML model whose name predict_demand_timeseries cannot dispatch on, so only
    # trust the metadata name if it is a known time-series model.
    if best_model_name in _TIMESERIES_MODEL_NAMES:
        model_name = best_model_name
    else:
        model_name = _infer_timeseries_model_name(model)

    if model_name:
        print(f"Model: {model_name}")
        metrics = all_metadata.get('all_models', {}).get(model_name)
        if metrics is not None:
            print(f"R2 Score: {_format_metric(metrics.get('r2'))}")

    # Time-series models need no preprocessing objects.
    return model, None, model_name or 'Time-Series', True


def _load_ml_model(best_model_name):
    """Load the saved ML model together with its preprocessing objects."""
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(
            f"ML model not found at {MODEL_PATH}. Please run train_model.py first."
        )
    if not os.path.exists(PREPROCESSING_PATH):
        raise FileNotFoundError(
            f"Preprocessing objects not found at {PREPROCESSING_PATH}. Please run train_model.py first."
        )

    print("Loading ML model and preprocessing objects...")
    model = joblib.load(MODEL_PATH)
    preprocessing_data = joblib.load(PREPROCESSING_PATH)

    if os.path.exists(METADATA_PATH):
        with open(METADATA_PATH, 'r', encoding='utf-8') as f:
            metadata = json.load(f)
        model_name = metadata.get('model_name', 'ML Model')
        print(f"Model: {model_name}")
        print(f"R2 Score: {_format_metric(metadata.get('metrics', {}).get('r2'))}")
    else:
        model_name = best_model_name or 'ML Model'

    return model, preprocessing_data, model_name, False


def load_model_and_preprocessing(model_type='auto'):
    """
    Load the trained model and preprocessing objects.

    Args:
        model_type: 'ml', 'timeseries', or 'auto'. With 'auto', the time-series
            model is used when the metadata's best model is ARIMA or Prophet,
            otherwise the ML model. Any value other than 'timeseries' (after
            resolving 'auto') loads the ML model.

    Returns:
        tuple: (model, preprocessing_data, model_name, is_timeseries).
        preprocessing_data is None for time-series models.

    Raises:
        FileNotFoundError: if the required model/preprocessing files are missing.
    """
    all_metadata = {}
    best_model_name = None
    if os.path.exists(ALL_MODELS_METADATA_PATH):
        with open(ALL_MODELS_METADATA_PATH, 'r', encoding='utf-8') as f:
            all_metadata = json.load(f)
        best_model_name = all_metadata.get('best_model', 'Unknown')

    if model_type == 'auto':
        model_type = 'timeseries' if best_model_name in _TIMESERIES_MODEL_NAMES else 'ml'

    if model_type == 'timeseries':
        return _load_timeseries_model(best_model_name, all_metadata)
    return _load_ml_model(best_model_name)
|
|
|
|
|
|
|
|
|
def _encode_or_default(encoder, value, label):
    """Encode one value with a fitted encoder, falling back when it is unseen.

    Args:
        encoder: Fitted encoder exposing ``transform`` and ``classes_``
            (presumably a scikit-learn ``LabelEncoder`` -- TODO confirm).
        value: Raw value to encode.
        label: Human-readable field name used in the warning message.

    Returns:
        The encoded value, or the encoding of ``encoder.classes_[0]`` when
        ``transform`` rejects ``value`` with ``ValueError``.
    """
    try:
        return encoder.transform([value])[0]
    except ValueError:
        print(f"Warning: {label} '{value}' not seen during training. Using default encoding.")
        return encoder.transform([encoder.classes_[0]])[0]


def prepare_features(product_id, date, price, discount, category, preprocessing_data):
    """
    Prepare features for prediction using the same preprocessing pipeline.

    Args:
        product_id: Product ID
        date: Date string (YYYY-MM-DD), datetime object or pandas Timestamp
        price: Product price
        discount: Discount percentage (0-100)
        category: Product category
        preprocessing_data: Dictionary with 'encoders' ('category' and
            'product_id'), 'feature_names' (column order) and 'scaler'

    Returns:
        numpy array: a single feature row, ordered by
        preprocessing_data['feature_names'] and passed through
        preprocessing_data['scaler'].transform.

    Raises:
        KeyError: if 'feature_names' lists a feature this function cannot compute.
    """
    # Coerce every accepted date form to a pandas Timestamp. A plain
    # datetime.datetime has no ``.quarter`` attribute, so passing one used to
    # raise AttributeError below.
    date = pd.to_datetime(date)

    encoders = preprocessing_data['encoders']
    day_of_week = date.weekday()  # Monday == 0 ... Sunday == 6

    feature_dict = {
        'price': price,
        'discount': discount,
        'day': date.day,
        'month': date.month,
        'day_of_week': day_of_week,
        'weekend': 1 if day_of_week >= 5 else 0,
        'year': date.year,
        'quarter': date.quarter,
        # Both categorical fields share one unseen-value fallback so they are
        # handled consistently.
        'category_encoded': _encode_or_default(encoders['category'], category, 'Category'),
        'product_id_encoded': _encode_or_default(encoders['product_id'], product_id, 'Product ID'),
    }

    # The array carries no column names, so the order must match training.
    feature_names = preprocessing_data['feature_names']
    missing = [name for name in feature_names if name not in feature_dict]
    if missing:
        raise KeyError(f"Model expects features this script cannot compute: {missing}")
    features = np.array([[feature_dict[name] for name in feature_names]])

    return preprocessing_data['scaler'].transform(features)
|
|
|
|
|
|
|
|
|
def predict_demand_ml(product_id, date, price, discount, category, model, preprocessing_data):
    """
    Predict demand for a product on a given date using the ML model.

    Args:
        product_id: Product ID
        date: Date string (YYYY-MM-DD) or datetime object
        price: Product price
        discount: Discount percentage (0-100)
        category: Product category
        model: Trained ML model exposing ``predict``
        preprocessing_data: Dictionary containing encoders, feature names and scaler

    Returns:
        float: Predicted sales quantity, clamped to be non-negative.
    """
    features = prepare_features(product_id, date, price, discount, category, preprocessing_data)
    raw_prediction = model.predict(features)[0]

    # Regressors can extrapolate below zero; a negative quantity is meaningless.
    # float() keeps the return type consistent (previously int 0 when clamped,
    # a numpy scalar otherwise).
    return float(max(0.0, raw_prediction))
|
|
|
|
|
|
|
|
|
def _forecast_arima(model, date):
    """Return the (unclamped) ARIMA forecast for ``date``.

    The dates attached to the model's one-step forecast determine how many
    steps ahead ``date`` lies. If the forecast has no DatetimeIndex, or
    ``date`` is not after the first forecast date, this falls back to the
    one-step-ahead forecast and prints a warning.
    """
    one_step = model.forecast(steps=1)
    index = getattr(one_step, 'index', None)
    if not isinstance(index, pd.DatetimeIndex) or len(index) == 0:
        print("Warning: ARIMA forecast has no date index; returning the next-step forecast.")
        return np.asarray(one_step).ravel()[0]

    first_forecast_date = index[0]
    # Count forecast periods from the first forecast date up to and including
    # ``date``; default to daily periods when the index carries no frequency.
    freq = index.freq or 'D'
    steps = len(pd.date_range(start=first_forecast_date, end=date, freq=freq))
    if steps < 1:
        print(f"Warning: {date.date()} is not after the training period; "
              f"returning the forecast for {first_forecast_date.date()}.")
        return np.asarray(one_step).ravel()[0]

    forecast = one_step if steps == 1 else model.forecast(steps=steps)
    # Positional access: indexing a date-indexed Series with [0] relies on a
    # deprecated pandas fallback.
    return np.asarray(forecast).ravel()[-1]


def _forecast_prophet(model, date):
    """Return the (unclamped) Prophet ``yhat`` for ``date``."""
    future = pd.DataFrame({'ds': [date]})
    return model.predict(future)['yhat'].iloc[0]


def predict_demand_timeseries(date, model, model_name):
    """
    Predict overall daily demand using a time-series model.

    Args:
        date: Date string (YYYY-MM-DD) or datetime object
        model: Trained time-series model (ARIMA results exposing ``forecast``,
            or Prophet exposing ``predict``)
        model_name: Name of the model ('ARIMA' or 'Prophet')

    Returns:
        float or None: Predicted total daily sales quantity (clamped to be
        non-negative), or None if the model name is unknown or the model
        raised while forecasting.
    """
    date = pd.to_datetime(date)

    forecasters = {'ARIMA': _forecast_arima, 'Prophet': _forecast_prophet}
    forecaster = forecasters.get(model_name)
    if forecaster is None:
        print(f"Unknown time-series model: {model_name}")
        return None

    # Deliberately best-effort: any model-library failure is reported and
    # surfaced to the caller as None.
    try:
        prediction = forecaster(model, date)
        return float(max(0.0, prediction))
    except Exception as e:
        print(f"Error in {model_name} prediction: {e}")
        return None
|
|
|
|
|
|
|
|
|
def predict_batch(predictions_data, model, preprocessing_data):
    """
    Predict demand for multiple products/dates at once with the ML model.

    Args:
        predictions_data: Iterable of dictionaries, each containing the keys
            'product_id', 'date', 'price', 'discount' and 'category'
        model: Trained ML model
        preprocessing_data: Dictionary containing encoders and scaler

    Returns:
        list: Predicted sales quantities, in the same order as the input.
    """
    # Previously called an undefined ``predict_demand`` and always raised
    # NameError; the per-row ML predictor is ``predict_demand_ml``.
    return [
        predict_demand_ml(
            row['product_id'],
            row['date'],
            row['price'],
            row['discount'],
            row['category'],
            model,
            preprocessing_data,
        )
        for row in predictions_data
    ]
|
|
|
|
|
|
|
|
|
def _build_arg_parser():
    """Create the command-line argument parser for this script."""
    parser = argparse.ArgumentParser(
        description='Predict product demand for a given date and product details',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples (ML Models):
  python predict.py --product_id 1 --date 2024-01-15 --price 100 --discount 10 --category Electronics
  python predict.py --product_id 5 --date 2024-06-20 --price 50 --discount 0 --category Clothing

Examples (Time-Series Models - overall daily demand):
  python predict.py --date 2024-01-15 --model_type timeseries
"""
    )
    parser.add_argument('--product_id', type=int, default=None,
                        help='Product ID (required for ML models)')
    parser.add_argument('--date', type=str, required=True,
                        help='Date in YYYY-MM-DD format')
    parser.add_argument('--price', type=float, default=None,
                        help='Product price (required for ML models)')
    parser.add_argument('--discount', type=float, default=0,
                        help='Discount percentage (0-100), default: 0 (for ML models)')
    parser.add_argument('--category', type=str, default=None,
                        help='Product category (required for ML models)')
    parser.add_argument('--model_type', type=str, default='auto',
                        choices=['auto', 'ml', 'timeseries'],
                        help='Model type to use: auto (best model), ml, or timeseries')
    return parser


def _print_date_information(date_obj):
    """Print weekday, weekend flag, month and quarter for the prediction date."""
    print("\nDate Information:")
    print(f"  Day of week: {date_obj.strftime('%A')}")
    print(f"  Weekend: {'Yes' if date_obj.weekday() >= 5 else 'No'}")
    print(f"  Month: {date_obj.strftime('%B')}")
    print(f"  Quarter: Q{date_obj.quarter}")


def main():
    """
    Main function for the command-line interface.

    Returns:
        int: exit status. 0 on success; 1 if the date is invalid, model files
        are missing, required ML arguments are absent, or the time-series
        prediction fails.
    """
    args = _build_arg_parser().parse_args()

    try:
        date_obj = pd.to_datetime(args.date)
    except ValueError:
        date_obj = None
    # pd.to_datetime('') returns NaT instead of raising, so reject NaT too.
    if date_obj is None or pd.isna(date_obj):
        print(f"Error: Invalid date format '{args.date}'. Please use YYYY-MM-DD format.")
        return 1

    try:
        model, preprocessing_data, model_name, is_timeseries = load_model_and_preprocessing(args.model_type)
    except FileNotFoundError as e:
        print(f"Error: {e}")
        return 1

    if not is_timeseries:
        if args.product_id is None or args.price is None or args.category is None:
            print("Error: ML models require --product_id, --price, and --category arguments.")
            return 1
        if args.discount < 0 or args.discount > 100:
            print(f"Warning: Discount {args.discount}% is outside 0-100 range. Clamping to valid range.")
            args.discount = max(0, min(100, args.discount))

    print("\n" + "=" * 60)
    print("MAKING PREDICTION")
    print("=" * 60)
    print(f"Model: {model_name}")
    print(f"Model Type: {'Time-Series' if is_timeseries else 'Machine Learning'}")
    print(f"Date: {args.date}")
    if not is_timeseries:
        print(f"Product ID: {args.product_id}")
        print(f"Price: ${args.price:.2f}")
        print(f"Discount: {args.discount}%")
        print(f"Category: {args.category}")
    print("-" * 60)

    if is_timeseries:
        predicted_demand = predict_demand_timeseries(date_obj, model, model_name)
        if predicted_demand is None:
            print("Error: Failed to make prediction.")
            return 1
        print(f"\nPredicted Total Daily Sales Quantity: {predicted_demand:.0f} units")
        print("(This is the predicted total demand across all products for this date)")
    else:
        predicted_demand = predict_demand_ml(
            args.product_id,
            date_obj,
            args.price,
            args.discount,
            args.category,
            model,
            preprocessing_data,
        )
        print(f"\nPredicted Sales Quantity: {predicted_demand:.0f} units")
        print("(This is the predicted demand for this specific product)")

    print("=" * 60)

    # Reuse the already-parsed date rather than parsing args.date a second time.
    _print_date_information(date_obj)
    return 0
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Forward main()'s return value as the process exit status so shell
    # scripts can detect failures (SystemExit(None) exits with status 0).
    raise SystemExit(main())
|
|
|
|