File size: 7,863 Bytes
fcf905c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
"""
GCC Ramadan Retail Demand Forecasting - Inference Script

This script demonstrates how to use the trained demand forecasting model.
When downloaded from HuggingFace, this script works alongside model.joblib, encoders.joblib, and config.json.

Usage:
    python inference.py

Or import and use programmatically:
    from inference import DemandForecaster
    forecaster = DemandForecaster()
    prediction = forecaster.predict(...)
"""

import joblib
import json
import numpy as np
import os

# Directory containing this script; DemandForecaster falls back to it when no
# explicit model_dir is supplied.
MODEL_DIR = os.path.dirname(os.path.abspath(__file__))


class DemandForecaster:
    """Class for loading and using the GCC Ramadan demand forecasting model.

    Construction reads three files from ``model_dir``: ``model.joblib``,
    ``encoders.joblib`` and ``config.json``. After loading, the instance
    exposes ``model``, ``country_encoder``, ``category_encoder``, ``config``,
    ``countries`` and ``categories``.
    """

    def __init__(self, model_dir=None):
        """
        Initialize the forecaster by loading the model and encoders.

        Args:
            model_dir: Path to directory containing model files.
                      Defaults to same directory as this script.
        """
        self.model_dir = model_dir or MODEL_DIR
        self._load_model()

    def _load_model(self):
        """
        Load the trained model, encoders, and configuration from ``model_dir``.

        Raises:
            KeyError: If the encoders mapping or the config lacks one of the
                keys read here (``country_encoder``, ``category_encoder``,
                ``countries``, ``categories``).
        """
        # NOTE(security): joblib.load unpickles the file, and unpickling can
        # execute arbitrary code. Only point model_dir at artifacts that come
        # from a trusted source.
        model_path = os.path.join(self.model_dir, "model.joblib")
        self.model = joblib.load(model_path)

        encoders_path = os.path.join(self.model_dir, "encoders.joblib")
        encoders = joblib.load(encoders_path)
        self.country_encoder = encoders['country_encoder']
        self.category_encoder = encoders['category_encoder']

        # Explicit encoding so the JSON is decoded the same way regardless of
        # the platform's default locale.
        config_path = os.path.join(self.model_dir, "config.json")
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)

        self.countries = self.config['countries']
        self.categories = self.config['categories']

    def predict(self,
                is_ramadan: int,
                ramadan_week: int,
                days_to_eid: int,
                is_eid_fitr: int,
                is_eid_adha: int,
                is_hajj_season: int,
                country: str,
                category: str,
                temperature: float,
                day_of_week: int,
                month: int,
                hijri_month: int,
                hijri_day: int) -> float:
        """
        Predict demand index for given features.

        Args:
            is_ramadan: 1 if Ramadan, 0 otherwise
            ramadan_week: Week of Ramadan (1-5), 0 if not Ramadan
            days_to_eid: Days until Eid al-Fitr (-1 if not applicable)
            is_eid_fitr: 1 if Eid al-Fitr, 0 otherwise
            is_eid_adha: 1 if Eid al-Adha, 0 otherwise
            is_hajj_season: 1 if Hajj season, 0 otherwise
            country: One of UAE, KSA, Qatar, Kuwait, Bahrain, Oman
            category: Product category (dates_sweets, electronics, fashion_abayas,
                      gifts, groceries, perfumes_oud)
            temperature: Temperature in Celsius
            day_of_week: Day of week (0-6, Monday=0)
            month: Gregorian month (1-12)
            hijri_month: Hijri month (1-12)
            hijri_day: Hijri day (1-30)

        Returns:
            Predicted demand index (typically 30-200 range) as a built-in
            ``float``.

        Raises:
            ValueError: If ``country`` or ``category`` is not listed in the
                loaded configuration.
        """
        # Validate inputs against the values listed in config.json
        if country not in self.countries:
            raise ValueError(f"Invalid country: {country}. Must be one of {self.countries}")
        if category not in self.categories:
            raise ValueError(f"Invalid category: {category}. Must be one of {self.categories}")

        # Encode categorical features
        country_encoded = self.country_encoder.transform([country])[0]
        category_encoded = self.category_encoder.transform([category])[0]

        # Single-row feature matrix; the column order below is the order the
        # model receives.
        features = np.array([[
            is_ramadan,
            ramadan_week,
            days_to_eid,
            is_eid_fitr,
            is_eid_adha,
            is_hajj_season,
            country_encoded,
            category_encoded,
            temperature,
            day_of_week,
            month,
            hijri_month,
            hijri_day,
        ]])

        # The model hands back a NumPy scalar (possibly float32); convert it
        # so the declared ``float`` return type holds and the value is
        # JSON-serializable.
        prediction = self.model.predict(features)[0]
        return float(prediction)

    def predict_dict(self, data: dict) -> float:
        """
        Predict demand index from a dictionary of features.

        Args:
            data: Dictionary with keys matching the predict() parameters

        Returns:
            Predicted demand index

        Raises:
            ValueError: If ``country`` or ``category`` is not recognized.
            TypeError: If ``data`` has missing or unexpected keys.
        """
        return self.predict(**data)

    def predict_batch(self, data_list: list) -> list:
        """
        Predict demand index for multiple records.

        Args:
            data_list: List of dictionaries with feature values

        Returns:
            List of predicted demand indices, in the same order as
            ``data_list`` (empty list for empty input)
        """
        return [self.predict(**record) for record in data_list]


def demo():
    """Print a banner, the model's metadata, and five sample forecasts."""
    heavy_rule = "=" * 60
    light_rule = "-" * 60

    def scenario(name, country, category, temperature, day_of_week, month,
                 hijri_month, hijri_day, is_ramadan=0, ramadan_week=0,
                 days_to_eid=-1, is_eid_fitr=0, is_eid_adha=0,
                 is_hajj_season=0):
        # Bundle a display label with the keyword arguments for predict();
        # the calendar flags default to an ordinary (non-festive) day.
        return {
            "name": name,
            "params": {
                "is_ramadan": is_ramadan,
                "ramadan_week": ramadan_week,
                "days_to_eid": days_to_eid,
                "is_eid_fitr": is_eid_fitr,
                "is_eid_adha": is_eid_adha,
                "is_hajj_season": is_hajj_season,
                "country": country,
                "category": category,
                "temperature": temperature,
                "day_of_week": day_of_week,
                "month": month,
                "hijri_month": hijri_month,
                "hijri_day": hijri_day,
            },
        }

    print(heavy_rule)
    print("GCC Ramadan Retail Demand Forecasting - Demo")
    print(heavy_rule)

    # Loads model.joblib, encoders.joblib and config.json from the default
    # model directory.
    forecaster = DemandForecaster()

    print(f"\nAvailable countries: {forecaster.countries}")
    print(f"Available categories: {forecaster.categories}")
    metrics = forecaster.config['metrics']
    print(f"\nModel metrics: R2={metrics['r2_score']:.3f}, "
          f"RMSE={metrics['rmse']:.2f}")

    print("\n" + light_rule)
    print("Example Predictions:")
    print(light_rule)

    examples = [
        scenario("Normal day in UAE (groceries)",
                 "UAE", "groceries", 25.0, 5, 6, 11, 15),
        scenario("Ramadan Week 2 in KSA (dates_sweets)",
                 "KSA", "dates_sweets", 30.0, 4, 4, 9, 15,
                 is_ramadan=1, ramadan_week=2, days_to_eid=15),
        scenario("Eid al-Fitr in Qatar (gifts)",
                 "Qatar", "gifts", 35.0, 0, 5, 10, 1,
                 days_to_eid=0, is_eid_fitr=1),
        scenario("Hajj season in KSA (perfumes_oud)",
                 "KSA", "perfumes_oud", 40.0, 3, 7, 12, 8,
                 is_hajj_season=1),
        scenario("Eid al-Adha in Kuwait (fashion_abayas)",
                 "Kuwait", "fashion_abayas", 42.0, 5, 7, 12, 10,
                 is_eid_adha=1, is_hajj_season=1),
    ]

    position = 0
    for example in examples:
        position += 1
        pred = forecaster.predict(**example["params"])
        print(f"\n{position}. {example['name']}: {pred:.2f}")

    print("\n" + heavy_rule)
    print("Demo complete!")
    print(heavy_rule)


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    demo()