File size: 7,863 Bytes
fcf905c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 |
"""
GCC Ramadan Retail Demand Forecasting - Inference Script
This script demonstrates how to use the trained demand forecasting model.
When downloaded from HuggingFace, this script expects model.joblib, encoders.joblib and config.json to sit in the same directory as the script.
Usage:
python inference.py
Or import and use programmatically:
from inference import DemandForecaster
forecaster = DemandForecaster()
prediction = forecaster.predict(...)
"""
import joblib
import json
import numpy as np
import os
# Default location of the model artifacts: the folder holding this script,
# so the files are found regardless of the current working directory.
MODEL_DIR = os.path.split(os.path.abspath(__file__))[0]
class DemandForecaster:
    """Load the GCC Ramadan demand forecasting model and run predictions.

    The model directory must contain three artifacts:

    * ``model.joblib`` -- the trained model; it is called as
      ``model.predict(X)`` with a 2-D feature array;
    * ``encoders.joblib`` -- a dict holding ``'country_encoder'`` and
      ``'category_encoder'``, each exposing ``transform(values)``
      (presumably scikit-learn LabelEncoders -- TODO confirm);
    * ``config.json`` -- containing at least the ``'countries'`` and
      ``'categories'`` lists used to validate inputs.
    """

    def __init__(self, model_dir=None):
        """Initialize the forecaster by loading the model, encoders and config.

        Args:
            model_dir: Path to the directory containing the model files.
                Falls back to ``MODEL_DIR`` (the directory of this script)
                when ``None`` or empty.
        """
        self.model_dir = model_dir or MODEL_DIR
        self._load_model()

    def _load_model(self):
        """Load the trained model, the label encoders and the JSON config."""
        # SECURITY NOTE: joblib.load unpickles arbitrary Python objects and can
        # execute code while doing so. Only load artifacts from a trusted source.
        model_path = os.path.join(self.model_dir, "model.joblib")
        self.model = joblib.load(model_path)

        encoders_path = os.path.join(self.model_dir, "encoders.joblib")
        encoders = joblib.load(encoders_path)
        self.country_encoder = encoders['country_encoder']
        self.category_encoder = encoders['category_encoder']

        config_path = os.path.join(self.model_dir, "config.json")
        # Explicit encoding so loading does not depend on the platform locale.
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)
        self.countries = self.config['countries']
        self.categories = self.config['categories']

    def _feature_row(self, is_ramadan, ramadan_week, days_to_eid, is_eid_fitr,
                     is_eid_adha, is_hajj_season, country, category,
                     temperature, day_of_week, month, hijri_month, hijri_day):
        """Validate one record and return its feature values in model column order.

        Accepts exactly the same parameters as predict(). A ``**record``
        with a missing or unexpected key therefore raises TypeError, just as
        calling predict() with it would.

        Raises:
            ValueError: If ``country`` or ``category`` is not listed in the config.
        """
        if country not in self.countries:
            raise ValueError(f"Invalid country: {country}. Must be one of {self.countries}")
        if category not in self.categories:
            raise ValueError(f"Invalid category: {category}. Must be one of {self.categories}")

        country_encoded = self.country_encoder.transform([country])[0]
        category_encoded = self.category_encoder.transform([category])[0]

        # The column order is what the model receives. Presumably it must
        # match the training column order -- TODO confirm before changing it.
        return [
            is_ramadan,
            ramadan_week,
            days_to_eid,
            is_eid_fitr,
            is_eid_adha,
            is_hajj_season,
            country_encoded,
            category_encoded,
            temperature,
            day_of_week,
            month,
            hijri_month,
            hijri_day,
        ]

    def predict(self,
                is_ramadan: int,
                ramadan_week: int,
                days_to_eid: int,
                is_eid_fitr: int,
                is_eid_adha: int,
                is_hajj_season: int,
                country: str,
                category: str,
                temperature: float,
                day_of_week: int,
                month: int,
                hijri_month: int,
                hijri_day: int) -> float:
        """Predict the demand index for a single set of features.

        Args:
            is_ramadan: 1 if Ramadan, 0 otherwise
            ramadan_week: Week of Ramadan (1-5), 0 if not Ramadan
            days_to_eid: Days until Eid al-Fitr (-1 if not applicable)
            is_eid_fitr: 1 if Eid al-Fitr, 0 otherwise
            is_eid_adha: 1 if Eid al-Adha, 0 otherwise
            is_hajj_season: 1 if Hajj season, 0 otherwise
            country: One of UAE, KSA, Qatar, Kuwait, Bahrain, Oman
            category: Product category (dates_sweets, electronics, fashion_abayas,
                gifts, groceries, perfumes_oud)
            temperature: Temperature in Celsius
            day_of_week: Day of week (0-6, Monday=0)
            month: Gregorian month (1-12)
            hijri_month: Hijri month (1-12)
            hijri_day: Hijri day (1-30)

        Returns:
            Predicted demand index as a plain Python float
            (typically in the 30-200 range).

        Raises:
            ValueError: If ``country`` or ``category`` is not supported.
        """
        row = self._feature_row(
            is_ramadan, ramadan_week, days_to_eid, is_eid_fitr, is_eid_adha,
            is_hajj_season, country, category, temperature, day_of_week,
            month, hijri_month, hijri_day,
        )
        # The model expects a 2-D array: one row per sample.
        return float(self.model.predict(np.array([row]))[0])

    def predict_dict(self, data: dict) -> float:
        """Predict the demand index from a dictionary of features.

        Args:
            data: Dictionary whose keys match the predict() parameters.

        Returns:
            Predicted demand index.
        """
        return self.predict(**data)

    def predict_batch(self, data_list: list) -> list:
        """Predict the demand index for multiple records with one model call.

        Every record is validated and encoded first. The rows are then stacked
        into a single feature matrix, so the model is invoked once rather than
        once per record.

        Args:
            data_list: Iterable of dictionaries whose keys match the predict()
                parameters.

        Returns:
            List of predicted demand indices (plain floats), in input order.
            An empty input yields an empty list.

        Raises:
            ValueError: If any record has an unsupported country or category.
            TypeError: If any record has missing or unexpected keys.
        """
        rows = [self._feature_row(**record) for record in data_list]
        if not rows:
            # np.array([]) is 1-D, not a valid feature matrix, so skip the model.
            return []
        predictions = self.model.predict(np.array(rows))
        return [float(value) for value in predictions]
def demo():
    """Print a short walkthrough of the forecaster to stdout.

    Loads the model from the default directory, then prints:
    * the supported countries and categories;
    * the R2 and RMSE values stored under ``config['metrics']``;
    * one prediction for each hard-coded scenario.
    """
    heavy_rule = "=" * 60
    light_rule = "-" * 60

    print(heavy_rule)
    print("GCC Ramadan Retail Demand Forecasting - Demo")
    print(heavy_rule)

    forecaster = DemandForecaster()
    print(f"\nAvailable countries: {forecaster.countries}")
    print(f"Available categories: {forecaster.categories}")

    metrics = forecaster.config['metrics']
    print(f"\nModel metrics: R2={metrics['r2_score']:.3f}, RMSE={metrics['rmse']:.2f}")

    print("\n" + light_rule)
    print("Example Predictions:")
    print(light_rule)

    # (label, keyword arguments for DemandForecaster.predict)
    scenarios = [
        ("Normal day in UAE (groceries)",
         dict(is_ramadan=0, ramadan_week=0, days_to_eid=-1,
              is_eid_fitr=0, is_eid_adha=0, is_hajj_season=0,
              country="UAE", category="groceries", temperature=25.0,
              day_of_week=5, month=6, hijri_month=11, hijri_day=15)),
        ("Ramadan Week 2 in KSA (dates_sweets)",
         dict(is_ramadan=1, ramadan_week=2, days_to_eid=15,
              is_eid_fitr=0, is_eid_adha=0, is_hajj_season=0,
              country="KSA", category="dates_sweets", temperature=30.0,
              day_of_week=4, month=4, hijri_month=9, hijri_day=15)),
        ("Eid al-Fitr in Qatar (gifts)",
         dict(is_ramadan=0, ramadan_week=0, days_to_eid=0,
              is_eid_fitr=1, is_eid_adha=0, is_hajj_season=0,
              country="Qatar", category="gifts", temperature=35.0,
              day_of_week=0, month=5, hijri_month=10, hijri_day=1)),
        ("Hajj season in KSA (perfumes_oud)",
         dict(is_ramadan=0, ramadan_week=0, days_to_eid=-1,
              is_eid_fitr=0, is_eid_adha=0, is_hajj_season=1,
              country="KSA", category="perfumes_oud", temperature=40.0,
              day_of_week=3, month=7, hijri_month=12, hijri_day=8)),
        ("Eid al-Adha in Kuwait (fashion_abayas)",
         dict(is_ramadan=0, ramadan_week=0, days_to_eid=-1,
              is_eid_fitr=0, is_eid_adha=1, is_hajj_season=1,
              country="Kuwait", category="fashion_abayas", temperature=42.0,
              day_of_week=5, month=7, hijri_month=12, hijri_day=10)),
    ]

    for number, (title, kwargs) in enumerate(scenarios, start=1):
        demand = forecaster.predict(**kwargs)
        print(f"\n{number}. {title}: {demand:.2f}")

    print("\n" + heavy_rule)
    print("Demo complete!")
    print(heavy_rule)


if __name__ == "__main__":
    # Executing the file directly runs the demo; importing it does not.
    demo()
|