# Source snapshot: 3,843 bytes, revision 84548c1.
import json
import pickle
import pandas as pd
import gradio as gr
import pyarrow.feather as feather
from functools import lru_cache
# --- Data & Model Loading Logic ---
def load_model():
    """Return the unpickled sales forecast model, or None when it is missing.

    The model is read from ``models/sales_forecast_model.pkl`` relative to
    the current working directory. A missing file is reported on stdout
    instead of being raised, so callers must handle a ``None`` result.
    """
    model_path = "models/sales_forecast_model.pkl"
    try:
        # NOTE: unpickling can execute arbitrary code; this file must come
        # from a trusted source.
        with open(model_path, "rb") as handle:
            return pickle.load(handle)
    except FileNotFoundError:
        # Plain print: this runs at startup, outside any UI interaction.
        print("Error: 'models/sales_forecast_model.pkl' not found.")
    return None
def load_feature_stats():
    """Load the per-feature statistics used for normalization.

    Reads ``models/feature_stats.json`` relative to the current working
    directory.

    Returns:
        The parsed JSON content, or an empty dict when the file does not
        exist (the problem is reported on stdout rather than raised).
    """
    try:
        # JSON is exchanged as UTF-8; pin the encoding so non-ASCII feature
        # names decode the same way regardless of the platform locale.
        with open("models/feature_stats.json", "r", encoding="utf-8") as file:
            return json.load(file)
    except FileNotFoundError:
        print("Error: 'models/feature_stats.json' not found.")
        return {}
@lru_cache(maxsize=1)
def load_data():
    """Load the preprocessed sales data, reading the CSV at most once.

    Reads ``data/sales_data_preprocessed.csv`` relative to the current
    working directory and parses the ``date`` column, when present, into
    datetimes.

    The result is memoized by ``lru_cache``, so every caller receives the
    same DataFrame object and should not mutate it in place. The fallback
    returned for a missing file is memoized as well; call
    ``load_data.cache_clear()`` to force a re-read.

    Returns:
        The sales DataFrame, or an empty frame with ``date``, ``store`` and
        ``sales`` columns when the file does not exist.
    """
    try:
        df = pd.read_csv("data/sales_data_preprocessed.csv")
    except FileNotFoundError:
        print("Error: 'data/sales_data_preprocessed.csv' not found.")
        # Give the empty "date" column a datetime dtype: downstream code uses
        # the ``.dt`` accessor, which raises AttributeError on object dtype.
        return pd.DataFrame(
            {
                "date": pd.Series(dtype="datetime64[ns]"),
                "store": pd.Series(dtype="object"),
                "sales": pd.Series(dtype="object"),
            }
        )
    if "date" in df.columns:
        df["date"] = pd.to_datetime(df["date"])
    return df
def load_feature_engineered_data():
    """Read the feature-engineered dataset from its Feather file.

    The file is ``data/feature_engineered_data_55_features.feather``
    relative to the current working directory. Any failure while reading is
    reported on stdout and an empty DataFrame is returned instead.
    """
    feather_path = "data/feature_engineered_data_55_features.feather"
    try:
        return feather.read_feather(feather_path)
    except Exception as exc:
        # Best-effort load: callers get an empty frame rather than an error.
        print(f"Error loading feature engineered data: {exc}")
        return pd.DataFrame()
# --- Processing Logic ---
def preprocess_data(df, feature_stats=None):
    """Preprocess data for prediction (simplified version).

    Works on a copy, so ``df`` itself is never modified.

    Args:
        df: Input DataFrame. If it has a ``date`` column, the calendar
            features ``day_of_week``, ``day_of_month``, ``month``, ``year``
            and ``is_weekend`` are derived from it. A text-typed ``date``
            column is parsed with ``pd.to_datetime`` first.
        feature_stats: Optional mapping of column name to a dict holding
            ``"mean"`` and ``"std"``. Each listed column that exists in the
            frame is standardized as ``(value - mean) / std``. Entries that
            are not dicts, or lack either key, are skipped.

    Returns:
        A new DataFrame with the derived and normalized columns.
    """
    processed_df = df.copy()

    if "date" in processed_df.columns:
        dates = processed_df["date"]
        # The .dt accessor only works on datetimelike values; parse text
        # dates (including an empty object-dtype column) before using it.
        if pd.api.types.is_object_dtype(dates) or pd.api.types.is_string_dtype(dates):
            dates = pd.to_datetime(dates)
            processed_df["date"] = dates

        processed_df["day_of_week"] = dates.dt.dayofweek
        processed_df["day_of_month"] = dates.dt.day
        processed_df["month"] = dates.dt.month
        processed_df["year"] = dates.dt.year
        # Monday is 0, so 5 and 6 are Saturday and Sunday.
        processed_df["is_weekend"] = (processed_df["day_of_week"] >= 5).astype(int)

    if feature_stats:
        for feature, stats in feature_stats.items():
            if feature not in processed_df.columns or not isinstance(stats, dict):
                continue
            if "mean" not in stats or "std" not in stats:
                continue
            # A zero (or missing) std means the feature was constant; divide
            # by 1 so the column is centered instead of becoming inf/NaN.
            std = stats["std"]
            scale = std if std else 1.0
            processed_df[feature] = (processed_df[feature] - stats["mean"]) / scale

    return processed_df
# --- Gradio UI Implementation ---
# Load the model and the normalization stats a single time at import; the
# Gradio handler below reads these module-level names on every request.
model, stats = load_model(), load_feature_stats()
def predict_sales_ui(store_id):
    """Gradio handler: preview the preprocessed rows for one store.

    Args:
        store_id: Store identifier entered in the UI; compared for equality
            against the ``store`` column of the sales data.

    Returns:
        The first five preprocessed rows for that store (empty when no row
        matches). This is a placeholder until real ``model.predict`` logic
        is wired in.

    Raises:
        gr.Error: If the model was not loaded at startup, or the sales data
            has no ``store`` column.
    """
    if model is None:
        raise gr.Error("Model not loaded. Check server logs.")

    data = load_data()
    if "store" not in data.columns:
        raise gr.Error("Sales data has no 'store' column.")

    # Filter on the raw store IDs *before* preprocessing: if the feature
    # stats include "store", normalization rescales that column and the raw
    # ID would never match afterwards. It also avoids preprocessing rows of
    # every other store.
    store_rows = data[data["store"] == store_id]
    processed = preprocess_data(store_rows, stats)

    return processed.head()
# Minimal Gradio UI: a store-ID field, a trigger button and a result table.
with gr.Blocks() as demo:
    gr.Markdown("# Sales Forecast Prediction")

    # Components are created in display order.
    store_input = gr.Number(label="Enter Store ID")
    output_table = gr.DataFrame(label="Preprocessed Data Preview")
    btn = gr.Button("Predict")

    # Clicking the button feeds the store ID to the handler and shows the
    # returned frame in the table.
    btn.click(
        fn=predict_sales_ui,
        inputs=[store_input],
        outputs=[output_table],
    )
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    demo.launch()