# Source: HuggingFace upload by GencoDiv ("Upload folder using huggingface_hub"),
# commit fcf905c (verified).
"""
GCC Ramadan Retail Demand Forecasting - Inference Script
This script demonstrates how to use the trained demand forecasting model.
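When downloaded from HuggingFace, this script works alongside model.joblib, encoders.joblib, and config.json, which must sit in the same directory as this script (or in the directory passed as model_dir).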
Usage:
python inference.py
Or import and use programmatically:
from inference import DemandForecaster
forecaster = DemandForecaster()
prediction = forecaster.predict(...)
"""
import joblib
import json
import numpy as np
import os
# Default location of the model artifacts: the absolute path of the
# directory that contains this very file.
MODEL_DIR = os.path.dirname(
    os.path.abspath(__file__)
)
class DemandForecaster:
    """Class for loading and using the GCC Ramadan demand forecasting model.

    The model directory must contain three artifacts:

    * ``model.joblib`` -- the fitted model; it is called as
      ``model.predict(X)`` on a 2-D array with one row per record;
    * ``encoders.joblib`` -- a mapping holding ``'country_encoder'`` and
      ``'category_encoder'``, each exposing ``transform(list_of_labels)``;
    * ``config.json`` -- a JSON object with at least ``'countries'`` and
      ``'categories'`` lists, used to validate inputs.
    """

    def __init__(self, model_dir=None):
        """
        Initialize the forecaster by loading the model and encoders.

        Args:
            model_dir: Path to directory containing model files.
                Defaults to the directory containing this script; any
                falsy value (``None``, ``""``) selects that default.
        """
        self.model_dir = model_dir or MODEL_DIR
        self._load_model()

    def _load_model(self):
        """Load the trained model, encoders, and configuration from disk.

        Raises:
            OSError: (e.g. FileNotFoundError) if an artifact is missing.
            KeyError: if the encoders mapping or the config lacks a
                required key.
        """
        # SECURITY: joblib.load unpickles the file, which can execute
        # arbitrary code. Only load artifacts from a trusted source.
        model_path = os.path.join(self.model_dir, "model.joblib")
        self.model = joblib.load(model_path)

        encoders_path = os.path.join(self.model_dir, "encoders.joblib")
        encoders = joblib.load(encoders_path)
        self.country_encoder = encoders['country_encoder']
        self.category_encoder = encoders['category_encoder']

        # JSON is UTF-8 by specification; don't rely on the platform's
        # locale-dependent default encoding.
        config_path = os.path.join(self.model_dir, "config.json")
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)
        self.countries = self.config['countries']
        self.categories = self.config['categories']

    def _encode_row(self, is_ramadan, ramadan_week, days_to_eid, is_eid_fitr,
                    is_eid_adha, is_hajj_season, country, category,
                    temperature, day_of_week, month, hijri_month, hijri_day):
        """Validate one record and return its features in model column order.

        Takes the same parameters as ``predict``. Country and category are
        replaced by their label-encoded values; every other value is passed
        through unchanged.

        Raises:
            ValueError: if ``country`` or ``category`` is not listed in the
                loaded config.
        """
        if country not in self.countries:
            raise ValueError(f"Invalid country: {country}. Must be one of {self.countries}")
        if category not in self.categories:
            raise ValueError(f"Invalid category: {category}. Must be one of {self.categories}")

        country_encoded = self.country_encoder.transform([country])[0]
        category_encoded = self.category_encoder.transform([category])[0]

        # Column order fed to the model; it is unchanged from the original
        # single-record implementation.
        return [
            is_ramadan,
            ramadan_week,
            days_to_eid,
            is_eid_fitr,
            is_eid_adha,
            is_hajj_season,
            country_encoded,
            category_encoded,
            temperature,
            day_of_week,
            month,
            hijri_month,
            hijri_day,
        ]

    def _predict_rows(self, rows):
        """Score already-encoded rows with one model call; return Python floats."""
        # A fixed float dtype keeps the feature matrix consistent regardless
        # of whether callers pass ints or floats for numeric fields.
        features = np.array(rows, dtype=float)
        # Convert numpy scalars (e.g. float32) to plain floats so results
        # honour the ``-> float`` contract and serialize with json.
        return [float(value) for value in self.model.predict(features)]

    def predict(self,
                is_ramadan: int,
                ramadan_week: int,
                days_to_eid: int,
                is_eid_fitr: int,
                is_eid_adha: int,
                is_hajj_season: int,
                country: str,
                category: str,
                temperature: float,
                day_of_week: int,
                month: int,
                hijri_month: int,
                hijri_day: int) -> float:
        """
        Predict demand index for given features.

        Args:
            is_ramadan: 1 if Ramadan, 0 otherwise
            ramadan_week: Week of Ramadan (1-5), 0 if not Ramadan
            days_to_eid: Days until Eid al-Fitr (-1 if not applicable)
            is_eid_fitr: 1 if Eid al-Fitr, 0 otherwise
            is_eid_adha: 1 if Eid al-Adha, 0 otherwise
            is_hajj_season: 1 if Hajj season, 0 otherwise
            country: One of the countries in config.json
                (UAE, KSA, Qatar, Kuwait, Bahrain, Oman)
            category: Product category from config.json (dates_sweets,
                electronics, fashion_abayas, gifts, groceries, perfumes_oud)
            temperature: Temperature in Celsius
            day_of_week: Day of week (0-6, Monday=0)
            month: Gregorian month (1-12)
            hijri_month: Hijri month (1-12)
            hijri_day: Hijri day (1-30)

        Returns:
            Predicted demand index as a Python float (typically 30-200 range)

        Raises:
            ValueError: if ``country`` or ``category`` is not supported.
        """
        row = self._encode_row(
            is_ramadan, ramadan_week, days_to_eid, is_eid_fitr, is_eid_adha,
            is_hajj_season, country, category, temperature, day_of_week,
            month, hijri_month, hijri_day,
        )
        return self._predict_rows([row])[0]

    def predict_dict(self, data: dict) -> float:
        """
        Predict demand index from a dictionary of features.

        Args:
            data: Dictionary with keys matching the predict() parameters

        Returns:
            Predicted demand index
        """
        return self.predict(**data)

    def predict_batch(self, data_list: list) -> list:
        """
        Predict demand index for multiple records.

        Every record is validated and encoded first. The whole batch is then
        scored with a single model call instead of one call per record.

        Args:
            data_list: Iterable of dictionaries with keys matching the
                predict() parameters

        Returns:
            List of predicted demand indices (Python floats), in input order;
            an empty input yields an empty list.

        Raises:
            ValueError: if any record has an unsupported country/category.
            TypeError: if a record has missing or unexpected keys.
        """
        rows = [self._encode_row(**record) for record in data_list]
        if not rows:
            # Avoid handing the model an empty (0,)-shaped array.
            return []
        return self._predict_rows(rows)
def demo():
    """Print a short walkthrough of the model followed by sample forecasts.

    Loads the forecaster from the default model directory and prints the
    supported countries/categories and the metrics stored in its config.
    It then scores a fixed list of illustrative scenarios, one per line.
    """
    heavy_rule = "=" * 60
    light_rule = "-" * 60

    print(heavy_rule)
    print("GCC Ramadan Retail Demand Forecasting - Demo")
    print(heavy_rule)

    forecaster = DemandForecaster()
    print(f"\nAvailable countries: {forecaster.countries}")
    print(f"Available categories: {forecaster.categories}")

    metrics = forecaster.config['metrics']
    print(f"\nModel metrics: R2={metrics['r2_score']:.3f}, "
          f"RMSE={metrics['rmse']:.2f}")

    print("\n" + light_rule)
    print("Example Predictions:")
    print(light_rule)

    scenarios = [
        {
            "name": "Normal day in UAE (groceries)",
            "params": {
                "is_ramadan": 0, "ramadan_week": 0, "days_to_eid": -1,
                "is_eid_fitr": 0, "is_eid_adha": 0, "is_hajj_season": 0,
                "country": "UAE", "category": "groceries", "temperature": 25.0,
                "day_of_week": 5, "month": 6, "hijri_month": 11, "hijri_day": 15,
            },
        },
        {
            # NOTE(review): if ramadan_week counts 7-day blocks from day 1,
            # hijri_day 15 falls in week 3, not week 2. Confirm how the
            # training data derived ramadan_week.
            "name": "Ramadan Week 2 in KSA (dates_sweets)",
            "params": {
                "is_ramadan": 1, "ramadan_week": 2, "days_to_eid": 15,
                "is_eid_fitr": 0, "is_eid_adha": 0, "is_hajj_season": 0,
                "country": "KSA", "category": "dates_sweets", "temperature": 30.0,
                "day_of_week": 4, "month": 4, "hijri_month": 9, "hijri_day": 15,
            },
        },
        {
            "name": "Eid al-Fitr in Qatar (gifts)",
            "params": {
                "is_ramadan": 0, "ramadan_week": 0, "days_to_eid": 0,
                "is_eid_fitr": 1, "is_eid_adha": 0, "is_hajj_season": 0,
                "country": "Qatar", "category": "gifts", "temperature": 35.0,
                "day_of_week": 0, "month": 5, "hijri_month": 10, "hijri_day": 1,
            },
        },
        {
            "name": "Hajj season in KSA (perfumes_oud)",
            "params": {
                "is_ramadan": 0, "ramadan_week": 0, "days_to_eid": -1,
                "is_eid_fitr": 0, "is_eid_adha": 0, "is_hajj_season": 1,
                "country": "KSA", "category": "perfumes_oud", "temperature": 40.0,
                "day_of_week": 3, "month": 7, "hijri_month": 12, "hijri_day": 8,
            },
        },
        {
            "name": "Eid al-Adha in Kuwait (fashion_abayas)",
            "params": {
                "is_ramadan": 0, "ramadan_week": 0, "days_to_eid": -1,
                "is_eid_fitr": 0, "is_eid_adha": 1, "is_hajj_season": 1,
                "country": "Kuwait", "category": "fashion_abayas", "temperature": 42.0,
                "day_of_week": 5, "month": 7, "hijri_month": 12, "hijri_day": 10,
            },
        },
    ]

    for number, scenario in enumerate(scenarios, start=1):
        value = forecaster.predict(**scenario["params"])
        print(f"\n{number}. {scenario['name']}: {value:.2f}")

    print("\n" + heavy_rule)
    print("Demo complete!")
    print(heavy_rule)
if __name__ == "__main__":
    # Run the walkthrough only when executed as a script, not on import.
    demo()