# Pushpak21's picture
# Upload folder using huggingface_hub
# 90e759f verified
import streamlit as st
import pandas as pd
import requests
import altair as alt
from predictor.batch_handler import get_predictions, get_single_prediction
from predictor.utils import reorder_columns, get_csv_download
from predictor.chart_plotter import plot_actual_vs_predicted
from predictor.error_summary_table import show_error_summary
# Configure the page and render the header before any other element is drawn.
st.set_page_config(page_title="SuperKart Sales Predictor", layout="wide")
st.title("🛒 SuperKart Sales Predictor")
st.markdown(
    "Use the tabs below to predict sales for a single product "
    "or upload a CSV file for batch prediction."
)

# Create tabs
# NOTE(review): the tab icons were restored from mis-encoded UTF-8; the
# magnifying glass on the first tab is a best guess (one byte was lost) — confirm.
tab1, tab2 = st.tabs(["🔍 Single Prediction", "📄 Batch Prediction"])
# ----------------- Tab 1: Single Prediction -----------------
# Collects one product/store description through widgets, sends it to
# get_single_prediction, and shows the predicted sales next to the button.
with tab1:
    product_col, store_col = st.columns(2)

    with product_col:
        with st.expander("📦 Product Details", expanded=True):
            # Slider arguments are (label, min, max, default[, step]).
            product_weight = st.slider("Product Weight (kg)", 4.0, 22.0, 12.65, 0.1)
            product_allocated_area = st.slider("Allocated Shelf Area", 0.0, 0.3, 0.07, 0.01)
            product_mrp = st.slider("Product MRP", 31.0, 266.0, 147.0)
            # NOTE(review): "reg" is offered alongside "Regular" — presumably it
            # mirrors a label present in the training data; confirm before removing.
            product_sugar_content = st.radio(
                "Sugar Content",
                ["Low Sugar", "Regular", "No Sugar", "reg"],
                horizontal=True,
            )
            product_type = st.selectbox("Product Type", [
                "Fruits and Vegetables", "Snack Foods", "Frozen Foods", "Dairy", "Household", "Baking Goods",
                "Canned", "Health and Hygiene", "Meat", "Soft Drinks", "Breads", "Hard Drinks", "Others",
                "Starchy Foods", "Breakfast", "Seafood"
            ])

    with store_col:
        with st.expander("🏬 Store Details", expanded=True):
            store_id = st.radio("Store ID", ["OUT001", "OUT002", "OUT003", "OUT004"], horizontal=True)
            store_size = st.selectbox("Store Size", ["Small", "Medium", "High"])
            store_location = st.radio("City Tier", ["Tier 1", "Tier 2", "Tier 3"], horizontal=True)
            store_type = st.selectbox("Store Type", [
                "Supermarket Type1", "Supermarket Type2", "Departmental Store", "Grocery Store"
            ])
            est_year = st.slider("Establishment Year", 1987, 2009, 2002)

    # Submit button: narrow column for the button, wide column for the result.
    button_col, result_col = st.columns([1, 3])  # Adjust ratio as needed
    with button_col:
        predict_clicked = st.button("🎯 Predict Sales 🎯", key="predict_button")

    if predict_clicked:
        payload = {
            "Product_Weight": product_weight,
            "Product_Allocated_Area": product_allocated_area,
            "Product_MRP": product_mrp,
            "Product_Sugar_Content": product_sugar_content,
            "Product_Type": product_type,
            "Store_Id": store_id,
            "Store_Size": store_size,
            "Store_Location_City_Type": store_location,
            "Store_Type": store_type,
            "Store_Establishment_Year": est_year,
        }
        # Broad catch is deliberate: this is the UI boundary, and any failure
        # (request error, unexpected response shape, non-numeric value) is
        # reported to the user instead of crashing the app.
        try:
            prediction = get_single_prediction(payload)
            # The helper may return either a bare value or a list; use the first item of a list.
            pred_value = prediction[0] if isinstance(prediction, list) else prediction
            with result_col:
                st.success(f"✅ Predicted Sales: ₹{pred_value:,.2f}")
                # Echo the submitted inputs together with the prediction.
                st.json({**payload, "Predicted_Sales": pred_value})
        except Exception as e:
            with result_col:
                st.error(f"⚠️ Error during prediction: {e}")
# ----------------- Tab 2: Batch Prediction -----------------
# Accepts a CSV upload, passes the rows to get_predictions, then shows the
# scored table with a download button, an error summary and a chart.
with tab2:
    st.subheader("📄 Upload CSV for Batch Prediction")
    uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])

    if uploaded_file:
        # Broad catch is deliberate: this is the UI boundary, so any parsing or
        # prediction failure is shown to the user rather than crashing the app.
        try:
            try:
                df = pd.read_csv(uploaded_file)
            except pd.errors.EmptyDataError:
                # A zero-byte upload has no header row, so read_csv raises
                # instead of returning an empty frame; treat it like the
                # header-only case below and show the same warning.
                df = pd.DataFrame()

            if df.empty:
                st.warning("Uploaded file is empty.")
            else:
                st.write("📋 Uploaded Data Preview:")
                st.dataframe(df.head())

                df = get_predictions(df)

                # When the upload carries actual sales, list them together
                # with the predictions in the reorder call.
                # NOTE(review): presumably get_predictions adds "Predicted_Sales"
                # and reorder_columns moves the listed columns — confirm in predictor/.
                if "Product_Store_Sales_Total" in df.columns:
                    df = reorder_columns(df, ["Product_Store_Sales_Total", "Predicted_Sales"])
                else:
                    df = reorder_columns(df, ["Predicted_Sales"])

                # Wide column for the heading, narrow column for the download button.
                heading_col, download_col = st.columns([6, 1])
                with heading_col:
                    st.subheader("📈 Prediction Results:")
                with download_col:
                    st.download_button(
                        label="📥 Download CSV",
                        data=get_csv_download(df),
                        file_name="batch_predictions.csv",
                        mime="text/csv",
                        use_container_width=True,
                    )

                st.dataframe(df)
                # NOTE(review): both helpers run even when the upload has no
                # "Product_Store_Sales_Total" column — verify they handle that.
                show_error_summary(df)
                plot_actual_vs_predicted(df)
        except Exception as e:
            st.error(f"⚠️ Error while processing the file: {e}")