File size: 3,843 Bytes
84548c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import json
import pickle
import pandas as pd
import gradio as gr
import pyarrow.feather as feather
from functools import lru_cache

# --- Data & Model Loading Logic ---

def load_model():
    """Unpickle and return the trained sales forecast model.

    Reads ``models/sales_forecast_model.pkl`` relative to the current
    working directory.

    Returns:
        The unpickled object, or ``None`` when the file does not exist (a
        message is printed in that case). Any other failure propagates.

    Note:
        ``pickle.load`` can execute arbitrary code while deserialising, so
        the model file must come from a trusted source.
    """
    model_path = "models/sales_forecast_model.pkl"
    try:
        with open(model_path, "rb") as handle:
            return pickle.load(handle)
    except FileNotFoundError:
        # A missing model is only reported on stdout; callers decide how to
        # react to the None result.
        print("Error: 'models/sales_forecast_model.pkl' not found.")
        return None

def load_feature_stats():
    """Load the per-feature statistics used for normalization.

    Reads ``models/feature_stats.json`` relative to the current working
    directory, decoding it as UTF-8.

    Returns:
        The decoded JSON content, or an empty dict when the file is missing
        or cannot be decoded as UTF-8 JSON (a message is printed in either
        case).
    """
    try:
        # JSON is UTF-8 by specification; do not depend on the platform's
        # default text encoding.
        with open("models/feature_stats.json", "r", encoding="utf-8") as file:
            return json.load(file)
    except FileNotFoundError:
        print("Error: 'models/feature_stats.json' not found.")
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        # A corrupt stats file degrades the same way a missing one does
        # instead of aborting the caller.
        print(f"Error: 'models/feature_stats.json' is not valid JSON: {exc}")
    return {}

@lru_cache(maxsize=1)
def load_data():
    """Load the preprocessed sales data with ``date`` parsed as datetimes.

    The result is memoised by ``lru_cache``: the CSV is read at most once per
    process and every caller receives the *same* DataFrame object, so callers
    must copy it before mutating. The missing-file fallback is cached too;
    call ``load_data.cache_clear()`` to force a re-read.

    Returns:
        The contents of ``data/sales_data_preprocessed.csv``. When the file
        is missing, a message is printed and an empty frame with the columns
        ``date``, ``store`` and ``sales`` is returned. In both cases a
        ``date`` column, if present, has a datetime dtype.
    """
    try:
        df = pd.read_csv("data/sales_data_preprocessed.csv")
    except FileNotFoundError:
        print("Error: 'data/sales_data_preprocessed.csv' not found.")
        df = pd.DataFrame(columns=["date", "store", "sales"])

    # Convert after the try/except so the empty fallback gets a datetime
    # column as well; an object-dtype "date" makes the `.dt` accessor raise.
    if "date" in df.columns:
        df["date"] = pd.to_datetime(df["date"])
    return df

def load_feature_engineered_data():
    """Read the extended feature-engineered dataset from its Feather file.

    Returns:
        The DataFrame stored in
        ``data/feature_engineered_data_55_features.feather``. If reading
        fails for any reason, the error is printed and an empty DataFrame is
        returned instead.
    """
    feather_path = "data/feature_engineered_data_55_features.feather"
    try:
        return feather.read_feather(feather_path)
    except Exception as err:
        # Best-effort load: every failure degrades to an empty frame.
        print(f"Error loading feature engineered data: {str(err)}")
        return pd.DataFrame()

# --- Processing Logic ---

def preprocess_data(df, feature_stats=None):
    """Derive calendar features and z-score normalise known columns.

    The input frame is never modified; all work happens on a copy.

    Args:
        df: Input DataFrame. If it has a ``date`` column, that column is
            parsed to datetimes when it is not already datetime-like, and
            the columns ``day_of_week``, ``day_of_month``, ``month``,
            ``year`` and ``is_weekend`` are added.
        feature_stats: Optional mapping of column name to a dict holding
            ``"mean"`` and ``"std"``. Each matching column is replaced by
            ``(value - mean) / std``. Entries that lack either key, or that
            name a column not present in ``df``, are ignored.

    Returns:
        A new, preprocessed DataFrame.
    """
    processed_df = df.copy()

    if "date" in processed_df.columns:
        dates = processed_df["date"]
        # The .dt accessor raises AttributeError on non-datetime-like data
        # (e.g. date strings), which makes hasattr() False; parse those first.
        if not hasattr(dates, "dt"):
            dates = pd.to_datetime(dates)
            processed_df["date"] = dates

        day_of_week = dates.dt.dayofweek
        processed_df["day_of_week"] = day_of_week
        processed_df["day_of_month"] = dates.dt.day
        processed_df["month"] = dates.dt.month
        processed_df["year"] = dates.dt.year
        # Monday is 0, so 5 and 6 are Saturday and Sunday. Missing dates
        # compare False and therefore count as weekdays.
        processed_df["is_weekend"] = (day_of_week >= 5).astype("int64")

    if feature_stats:
        for feature, spec in feature_stats.items():
            if feature not in processed_df.columns:
                continue
            if "mean" not in spec or "std" not in spec:
                continue
            # A zero std would turn the whole column into inf/NaN; fall back
            # to unit scale so such a column is only centred.
            scale = spec["std"] if spec["std"] != 0 else 1
            processed_df[feature] = (processed_df[feature] - spec["mean"]) / scale

    return processed_df

# --- Gradio UI Implementation ---

# Load resources once when the app starts (these run at import time), so the
# UI callback below reuses the same objects on every click.
# `model` is None when the pickle file is missing; predict_sales_ui checks
# for that before doing any work.
model = load_model()
# `stats` is an empty dict when the JSON file is missing, in which case
# preprocess_data skips normalization.
stats = load_feature_stats()

def predict_sales_ui(store_id):
    """Return a preview of the preprocessed rows for one store.

    Callback for the "Predict" button. The model is not invoked yet: the
    function returns the first rows of the preprocessed data as a placeholder
    for the real prediction logic.

    Args:
        store_id: Store identifier to match against the ``store`` column of
            the sales data.

    Returns:
        The first five preprocessed rows for that store (an empty frame when
        no row matches).

    Raises:
        gr.Error: If the model could not be loaded at start-up, or if no
            store ID was supplied.
    """
    if model is None:
        raise gr.Error("Model not loaded. Check server logs.")
    if store_id is None:
        # NOTE(review): presumably what gr.Number passes for an empty field
        # - confirm against the Gradio version in use.
        raise gr.Error("Please enter a store ID.")

    data = load_data()

    # Filter on the raw store IDs *before* preprocessing: preprocess_data
    # normalises every column listed in the feature stats, and if "store" is
    # among them the comparison with the user-supplied ID would never match.
    # Filtering first also means only this store's rows get preprocessed.
    store_rows = data[data["store"] == store_id]
    processed = preprocess_data(store_rows, stats)

    # Return results (placeholder for actual model.predict logic)
    return processed.head()

# Simple Gradio Interface
# Components are created inside the Blocks context in the order written
# below (NOTE(review): presumably also their on-page order - confirm).
with gr.Blocks() as demo:
    gr.Markdown("# Sales Forecast Prediction")
    # Numeric field whose value is passed to predict_sales_ui as `store_id`.
    store_input = gr.Number(label="Enter Store ID")
    # Table that displays the DataFrame returned by predict_sales_ui.
    output_table = gr.DataFrame(label="Preprocessed Data Preview")
    btn = gr.Button("Predict")

    # Wire the button: a click calls predict_sales_ui with the number field's
    # value and sends the result to the table.
    btn.click(fn=predict_sales_ui, inputs=store_input, outputs=output_table)

# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()