|
|
import json |
|
|
import pickle |
|
|
import pandas as pd |
|
|
import gradio as gr |
|
|
import pyarrow.feather as feather |
|
|
from functools import lru_cache |
|
|
|
|
|
|
|
|
|
|
|
def load_model():
    """Return the unpickled sales forecast model, or ``None`` if it is missing.

    The model is read from ``models/sales_forecast_model.pkl`` relative to the
    current working directory. Only ``FileNotFoundError`` is handled (a message
    is printed and ``None`` returned); any other error while opening or
    unpickling propagates to the caller.

    NOTE: ``pickle.load`` can execute arbitrary code embedded in the file, so
    this path must only ever point at a trusted, locally produced artifact.
    """
    try:
        with open("models/sales_forecast_model.pkl", "rb") as handle:
            loaded = pickle.load(handle)
    except FileNotFoundError:
        print("Error: 'models/sales_forecast_model.pkl' not found.")
        return None
    else:
        return loaded
|
|
|
|
|
def load_feature_stats():
    """Load the per-feature normalization statistics from JSON.

    Reads ``models/feature_stats.json`` relative to the current working
    directory. The parsed object is what gets passed to ``preprocess_data``,
    which reads ``"mean"`` and ``"std"`` entries from each value.

    Returns:
        The parsed JSON object, or an empty dict if the file does not exist.

    Raises:
        json.JSONDecodeError: if the file exists but is not valid JSON. This
            is intentionally not swallowed, because silently skipping
            normalization would produce wrong model inputs.
    """
    try:
        # JSON text is UTF-8 by specification (RFC 8259). Without an explicit
        # encoding, open() uses the locale default (e.g. cp1252 on Windows)
        # and can mis-decode non-ASCII feature names.
        with open("models/feature_stats.json", "r", encoding="utf-8") as file:
            return json.load(file)
    except FileNotFoundError:
        print("Error: 'models/feature_stats.json' not found.")
        return {}
|
|
|
|
|
@lru_cache(maxsize=1)
def load_data():
    """Load the preprocessed sales data, memoized for the process lifetime.

    ``lru_cache`` stands in for Streamlit's ``@st.cache_data``. Note that it
    hands back the *same* DataFrame object on every call, so callers must not
    mutate the result in place; copy it first. The missing-file fallback is
    cached as well, so call ``load_data.cache_clear()`` to retry after the
    file appears.

    Returns:
        The contents of ``data/sales_data_preprocessed.csv``, with ``date``
        parsed to datetime when that column exists. If the file does not
        exist, returns an empty frame with ``date``/``store``/``sales``
        columns.
    """
    try:
        df = pd.read_csv("data/sales_data_preprocessed.csv")
    except FileNotFoundError:
        print("Error: 'data/sales_data_preprocessed.csv' not found.")
        # Give the empty ``date`` column a datetime dtype. A bare
        # ``pd.DataFrame(columns=[...])`` makes it object dtype, and the
        # ``.dt`` accessor used downstream raises AttributeError on that.
        return pd.DataFrame(
            {
                "date": pd.Series(dtype="datetime64[ns]"),
                "store": pd.Series(dtype=object),
                "sales": pd.Series(dtype=object),
            }
        )

    if "date" in df.columns:
        df["date"] = pd.to_datetime(df["date"])
    return df
|
|
|
|
|
def load_feature_engineered_data():
    """Return the 55-feature engineered dataset, or an empty frame on failure.

    Reads ``data/feature_engineered_data_55_features.feather`` via
    ``pyarrow.feather.read_feather``. Any ``Exception`` raised while reading
    is caught and its message printed to stdout, and an empty
    ``pd.DataFrame`` is returned in its place, so this loader never raises.
    """
    source = "data/feature_engineered_data_55_features.feather"
    try:
        engineered = feather.read_feather(source)
    except Exception as exc:
        # Best-effort: report the failure and degrade to an empty frame.
        print(f"Error loading feature engineered data: {str(exc)}")
        engineered = pd.DataFrame()
    return engineered
|
|
|
|
|
|
|
|
|
|
|
def preprocess_data(df, feature_stats=None):
    """Add calendar features and z-score normalize selected columns.

    Works on a copy; ``df`` itself is never modified.

    Args:
        df: Input frame. If it has a ``date`` column, that column is parsed
            with ``pd.to_datetime`` (a no-op for datetime columns). The
            parsed values are used to add ``day_of_week`` (Monday=0),
            ``day_of_month``, ``month``, ``year`` and ``is_weekend`` (1 for
            Saturday/Sunday, else 0). The ``date`` column itself is left
            as-is.
        feature_stats: Optional mapping of column name to a mapping with
            ``"mean"`` and ``"std"`` keys. Each column present in the frame
            whose entry has both keys is replaced by ``(value - mean) / std``.
            Other entries are ignored. A ``std`` of 0 is treated as 1, so the
            column is only mean-centred instead of becoming inf/NaN.

    Returns:
        The processed copy of ``df``.
    """
    processed_df = df.copy()

    if "date" in processed_df.columns:
        # to_datetime also handles string dates and an empty object-dtype
        # column; calling ``.dt`` directly on those raises AttributeError.
        dates = pd.to_datetime(processed_df["date"])
        day_of_week = dates.dt.dayofweek
        processed_df["day_of_week"] = day_of_week
        processed_df["day_of_month"] = dates.dt.day
        processed_df["month"] = dates.dt.month
        processed_df["year"] = dates.dt.year
        # Vectorized replacement for a per-row lambda: Sat=5, Sun=6.
        processed_df["is_weekend"] = (day_of_week >= 5).astype("int64")

    # Normalization runs after the calendar features are added, so those
    # derived columns can be normalized too if they appear in feature_stats.
    for feature, feature_stat in (feature_stats or {}).items():
        if (
            feature in processed_df.columns
            and "mean" in feature_stat
            and "std" in feature_stat
        ):
            std = feature_stat["std"]
            # A feature that was constant in training has std 0; dividing by
            # it would fill the column with inf/NaN.
            if std == 0:
                std = 1
            processed_df[feature] = (
                processed_df[feature] - feature_stat["mean"]
            ) / std

    return processed_df
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the model and normalization stats once, at import time; the Gradio
# handler below reads these module-level globals on every click.
# Evaluation order is unchanged: load_model() runs before load_feature_stats().
model, stats = load_model(), load_feature_stats()
|
|
|
|
|
def predict_sales_ui(store_id):
    """Return a preview of the preprocessed rows for one store.

    Click handler for the Gradio "Predict" button. It does not call the model
    yet: it checks that the model loaded, selects the rows whose ``store``
    equals ``store_id``, preprocesses them and returns the first five.

    Args:
        store_id: Value from the ``gr.Number`` input; ``None`` when the box
            is left empty. NOTE(review): gr.Number presumably yields a float,
            which matches integer store ids under ``==`` but would not match
            a string-typed ``store`` column. Confirm against the data.

    Returns:
        A DataFrame with at most five rows (``DataFrame.head()``).

    Raises:
        gr.Error: if the model failed to load, no store id was entered, or
            the loaded data has no ``store`` column.
    """
    if model is None:
        raise gr.Error("Model not loaded. Check server logs.")
    if store_id is None:
        raise gr.Error("Please enter a Store ID.")

    data = load_data()
    if "store" not in data.columns:
        raise gr.Error("Loaded data has no 'store' column.")

    # Filter on the raw store id *before* preprocessing. preprocess_data
    # normalizes any column listed in the feature stats, including 'store',
    # after which a raw id would never match. Filtering first also means only
    # the selected store's rows are copied and transformed.
    store_rows = data[data["store"] == store_id]
    return preprocess_data(store_rows, stats).head()
|
|
|
|
|
|
|
|
# Minimal Gradio UI: enter a store id, click "Predict", and the table shows
# whatever predict_sales_ui returns for that id.
with gr.Blocks() as demo:
    gr.Markdown("# Sales Forecast Prediction")

    # Store id input; its value is passed as the only argument to
    # predict_sales_ui.
    store_input = gr.Number(label="Enter Store ID")

    # Displays the DataFrame returned by predict_sales_ui.
    output_table = gr.DataFrame(label="Preprocessed Data Preview")

    btn = gr.Button("Predict")

    # On click: read store_input, call predict_sales_ui, render the result
    # into output_table.
    btn.click(fn=predict_sales_ui, inputs=store_input, outputs=output_table)


# Launch the server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()