"""
GCC Ramadan Retail Demand Forecasting - Inference Script

This script demonstrates how to use the trained demand forecasting model.
When downloaded from Hugging Face, this script works alongside model.joblib,
encoders.joblib, and config.json, which must sit in the same directory.

Usage:
    python inference.py

Or import and use programmatically:
    from inference import DemandForecaster
    forecaster = DemandForecaster()
    prediction = forecaster.predict(...)
"""

import json
import os

import joblib
import numpy as np


# Directory containing this script; used as the default location of the
# model artifacts so the script works from any current working directory.
MODEL_DIR = os.path.dirname(os.path.abspath(__file__))


class DemandForecaster:
    """Load the GCC Ramadan demand forecasting artifacts and serve predictions.

    Three files are read from the model directory:

    * ``model.joblib``    - the trained estimator (must expose ``predict``)
    * ``encoders.joblib`` - a dict with ``country_encoder`` and
      ``category_encoder`` entries (each must expose ``transform``)
    * ``config.json``     - must provide the ``countries`` and ``categories``
      lists used for input validation
    """

    def __init__(self, model_dir=None):
        """
        Initialize the forecaster by loading the model and encoders.

        Args:
            model_dir: Path to directory containing model files.
                       Defaults to same directory as this script.
        """
        # Any falsy value (None, "") falls back to the script's directory.
        self.model_dir = model_dir or MODEL_DIR
        self._load_model()

    def _load_model(self):
        """Load the trained model, encoders, and configuration.

        Raises:
            FileNotFoundError: If ``config.json`` is missing (the joblib
                loads are expected to fail similarly for missing artifacts).
            KeyError: If the encoders dict or the config lacks a required key.
        """
        # SECURITY NOTE: joblib.load unpickles its input and can therefore
        # execute arbitrary code. Only point ``model_dir`` at artifacts
        # obtained from a source you trust.
        model_path = os.path.join(self.model_dir, "model.joblib")
        self.model = joblib.load(model_path)

        encoders_path = os.path.join(self.model_dir, "encoders.joblib")
        encoders = joblib.load(encoders_path)
        self.country_encoder = encoders['country_encoder']
        self.category_encoder = encoders['category_encoder']

        # Explicit encoding so parsing does not depend on the platform's
        # default locale encoding.
        config_path = os.path.join(self.model_dir, "config.json")
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)

        self.countries = self.config['countries']
        self.categories = self.config['categories']

    def predict(self,
                is_ramadan: int,
                ramadan_week: int,
                days_to_eid: int,
                is_eid_fitr: int,
                is_eid_adha: int,
                is_hajj_season: int,
                country: str,
                category: str,
                temperature: float,
                day_of_week: int,
                month: int,
                hijri_month: int,
                hijri_day: int) -> float:
        """
        Predict demand index for given features.

        Args:
            is_ramadan: 1 if Ramadan, 0 otherwise
            ramadan_week: Week of Ramadan (1-5), 0 if not Ramadan
            days_to_eid: Days until Eid al-Fitr (-1 if not applicable)
            is_eid_fitr: 1 if Eid al-Fitr, 0 otherwise
            is_eid_adha: 1 if Eid al-Adha, 0 otherwise
            is_hajj_season: 1 if Hajj season, 0 otherwise
            country: One of UAE, KSA, Qatar, Kuwait, Bahrain, Oman
            category: Product category (dates_sweets, electronics, fashion_abayas,
                      gifts, groceries, perfumes_oud)
            temperature: Temperature in Celsius
            day_of_week: Day of week (0-6, Monday=0)
            month: Gregorian month (1-12)
            hijri_month: Hijri month (1-12)
            hijri_day: Hijri day (1-30)

        Returns:
            Predicted demand index (typically 30-200 range)

        Raises:
            ValueError: If ``country`` or ``category`` is not listed in the
                loaded configuration.
        """
        # Validate against the config lists first so the caller gets a
        # message that names the accepted values.
        if country not in self.countries:
            raise ValueError(f"Invalid country: {country}. Must be one of {self.countries}")
        if category not in self.categories:
            raise ValueError(f"Invalid category: {category}. Must be one of {self.categories}")

        # The encoders map the string labels to the numeric codes the model
        # consumes; transform() takes a sequence, hence the one-element list.
        country_encoded = self.country_encoder.transform([country])[0]
        category_encoded = self.category_encoder.transform([category])[0]

        # Single-row 2-D array (1 x 13).
        # NOTE(review): the column order presumably has to match the order
        # used when the model was trained - confirm before reordering.
        features = np.array([[
            is_ramadan,
            ramadan_week,
            days_to_eid,
            is_eid_fitr,
            is_eid_adha,
            is_hajj_season,
            country_encoded,
            category_encoded,
            temperature,
            day_of_week,
            month,
            hijri_month,
            hijri_day,
        ]])

        prediction = self.model.predict(features)[0]
        # Convert to a built-in float so the result matches the annotation
        # regardless of the scalar type the model returns.
        return float(prediction)

    def predict_dict(self, data: dict) -> float:
        """
        Predict demand index from a dictionary of features.

        Args:
            data: Dictionary with keys matching the predict() parameters

        Returns:
            Predicted demand index

        Raises:
            TypeError: If ``data`` is missing a predict() parameter or
                contains an unexpected key.
            ValueError: If the country or category is invalid.
        """
        return self.predict(**data)

    def predict_batch(self, data_list: list) -> list:
        """
        Predict demand index for multiple records.

        Records are predicted one at a time, in order; an empty list
        yields an empty list.

        Args:
            data_list: List of dictionaries with feature values

        Returns:
            List of predicted demand indices
        """
        return [self.predict(**record) for record in data_list]


def demo():
    """Demonstrate the model with example predictions.

    Builds a ``DemandForecaster`` from the default model directory, prints
    the available countries and categories plus the ``r2_score`` and ``rmse``
    values stored under ``metrics`` in the config, then prints one prediction
    for each built-in example scenario.
    """
    banner = "=" * 60
    rule = "-" * 60

    print(banner)
    print("GCC Ramadan Retail Demand Forecasting - Demo")
    print(banner)

    forecaster = DemandForecaster()

    print(f"\nAvailable countries: {forecaster.countries}")
    print(f"Available categories: {forecaster.categories}")
    print(f"\nModel metrics: R2={forecaster.config['metrics']['r2_score']:.3f}, "
          f"RMSE={forecaster.config['metrics']['rmse']:.2f}")

    print("\n" + rule)
    print("Example Predictions:")
    print(rule)

    # Each scenario pairs a display name with the keyword arguments that are
    # passed straight to DemandForecaster.predict().
    examples = [
        {
            "name": "Normal day in UAE (groceries)",
            "params": {
                "is_ramadan": 0, "ramadan_week": 0, "days_to_eid": -1,
                "is_eid_fitr": 0, "is_eid_adha": 0, "is_hajj_season": 0,
                "country": "UAE", "category": "groceries", "temperature": 25.0,
                "day_of_week": 5, "month": 6, "hijri_month": 11, "hijri_day": 15,
            },
        },
        {
            "name": "Ramadan Week 2 in KSA (dates_sweets)",
            "params": {
                "is_ramadan": 1, "ramadan_week": 2, "days_to_eid": 15,
                "is_eid_fitr": 0, "is_eid_adha": 0, "is_hajj_season": 0,
                "country": "KSA", "category": "dates_sweets", "temperature": 30.0,
                "day_of_week": 4, "month": 4, "hijri_month": 9, "hijri_day": 15,
            },
        },
        {
            "name": "Eid al-Fitr in Qatar (gifts)",
            "params": {
                "is_ramadan": 0, "ramadan_week": 0, "days_to_eid": 0,
                "is_eid_fitr": 1, "is_eid_adha": 0, "is_hajj_season": 0,
                "country": "Qatar", "category": "gifts", "temperature": 35.0,
                "day_of_week": 0, "month": 5, "hijri_month": 10, "hijri_day": 1,
            },
        },
        {
            "name": "Hajj season in KSA (perfumes_oud)",
            "params": {
                "is_ramadan": 0, "ramadan_week": 0, "days_to_eid": -1,
                "is_eid_fitr": 0, "is_eid_adha": 0, "is_hajj_season": 1,
                "country": "KSA", "category": "perfumes_oud", "temperature": 40.0,
                "day_of_week": 3, "month": 7, "hijri_month": 12, "hijri_day": 8,
            },
        },
        {
            "name": "Eid al-Adha in Kuwait (fashion_abayas)",
            "params": {
                "is_ramadan": 0, "ramadan_week": 0, "days_to_eid": -1,
                "is_eid_fitr": 0, "is_eid_adha": 1, "is_hajj_season": 1,
                "country": "Kuwait", "category": "fashion_abayas", "temperature": 42.0,
                "day_of_week": 5, "month": 7, "hijri_month": 12, "hijri_day": 10,
            },
        },
    ]

    # Number the scenarios from 1 for display.
    for i, example in enumerate(examples, 1):
        pred = forecaster.predict(**example["params"])
        print(f"\n{i}. {example['name']}: {pred:.2f}")

    print("\n" + banner)
    print("Demo complete!")
    print(banner)


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    demo()