# SuperKart Sales Predictor: Streamlit front end.
#
# The page has two tabs. The first collects product/store attributes through
# widgets and requests a single prediction; the second accepts a CSV upload,
# requests predictions for every row, and renders the results together with a
# download button, an error summary and an actual-vs-predicted chart.

import streamlit as st
import pandas as pd
import requests
import altair as alt
from predictor.batch_handler import get_predictions, get_single_prediction
from predictor.utils import reorder_columns, get_csv_download
from predictor.chart_plotter import plot_actual_vs_predicted
from predictor.error_summary_table import show_error_summary

# Choices offered by the input widgets, kept together so the form layout
# below stays readable.
SUGAR_CONTENT_OPTIONS = ["Low Sugar", "Regular", "No Sugar", "reg"]
PRODUCT_TYPE_OPTIONS = [
    "Fruits and Vegetables",
    "Snack Foods",
    "Frozen Foods",
    "Dairy",
    "Household",
    "Baking Goods",
    "Canned",
    "Health and Hygiene",
    "Meat",
    "Soft Drinks",
    "Breads",
    "Hard Drinks",
    "Others",
    "Starchy Foods",
    "Breakfast",
    "Seafood",
]
STORE_ID_OPTIONS = ["OUT001", "OUT002", "OUT003", "OUT004"]
STORE_SIZE_OPTIONS = ["Small", "Medium", "High"]
CITY_TIER_OPTIONS = ["Tier 1", "Tier 2", "Tier 3"]
STORE_TYPE_OPTIONS = [
    "Supermarket Type1",
    "Supermarket Type2",
    "Departmental Store",
    "Grocery Store",
]

# Column names used when arranging the batch results table.
ACTUAL_SALES_COLUMN = "Product_Store_Sales_Total"
PREDICTED_SALES_COLUMN = "Predicted_Sales"

st.set_page_config(page_title="SuperKart Sales Predictor", layout="wide")
st.title("🛒 SuperKart Sales Predictor")
st.markdown("Use the tabs below to predict sales for a single product or upload a CSV file for batch prediction.")

single_tab, batch_tab = st.tabs(["🔍 Single Prediction", "📄 Batch Prediction"])

# ======================= Single prediction tab =======================
with single_tab:
    product_col, store_col = st.columns(2)

    with product_col:
        with st.expander("📦 Product Details", expanded=True):
            product_weight = st.slider(
                "Product Weight (kg)",
                min_value=4.0,
                max_value=22.0,
                value=12.65,
                step=0.1,
            )
            product_allocated_area = st.slider(
                "Allocated Shelf Area",
                min_value=0.0,
                max_value=0.3,
                value=0.07,
                step=0.01,
            )
            product_mrp = st.slider(
                "Product MRP",
                min_value=31.0,
                max_value=266.0,
                value=147.0,
            )
            product_sugar_content = st.radio(
                "Sugar Content",
                SUGAR_CONTENT_OPTIONS,
                horizontal=True,
            )
            product_type = st.selectbox("Product Type", PRODUCT_TYPE_OPTIONS)

    with store_col:
        with st.expander("🏬 Store Details", expanded=True):
            store_id = st.radio("Store ID", STORE_ID_OPTIONS, horizontal=True)
            store_size = st.selectbox("Store Size", STORE_SIZE_OPTIONS)
            store_location = st.radio(
                "City Tier",
                CITY_TIER_OPTIONS,
                horizontal=True,
            )
            store_type = st.selectbox("Store Type", STORE_TYPE_OPTIONS)
            est_year = st.slider(
                "Establishment Year",
                min_value=1987,
                max_value=2009,
                value=2002,
            )

    # Narrow column for the button, wide column for the outcome next to it.
    button_col, result_col = st.columns([1, 3])
    with button_col:
        predict_clicked = st.button("🎯 Predict Sales 🎯", key="predict_button")

    if predict_clicked:
        # Keys are the feature names sent to the prediction helper.
        payload = {
            "Product_Weight": product_weight,
            "Product_Allocated_Area": product_allocated_area,
            "Product_MRP": product_mrp,
            "Product_Sugar_Content": product_sugar_content,
            "Product_Type": product_type,
            "Store_Id": store_id,
            "Store_Size": store_size,
            "Store_Location_City_Type": store_location,
            "Store_Type": store_type,
            "Store_Establishment_Year": est_year,
        }
        try:
            prediction = get_single_prediction(payload)
            # The helper's result is unwrapped when it arrives as a list.
            if isinstance(prediction, list):
                predicted_sales = prediction[0]
            else:
                predicted_sales = prediction
            with result_col:
                st.success(f"✅ Predicted Sales: ₹{predicted_sales:,.2f}")
                st.json(dict(payload, Predicted_Sales=predicted_sales))
        except Exception as exc:
            with result_col:
                st.error(f"⚠️ Error during prediction: {exc}")

# ======================= Batch prediction tab =======================
with batch_tab:
    st.subheader("📄 Upload CSV for Batch Prediction")
    uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])

    if uploaded_file:
        try:
            batch_df = pd.read_csv(uploaded_file)
            if batch_df.empty:
                st.warning("Uploaded file is empty.")
            else:
                st.write("📋 Uploaded Data Preview:")
                st.dataframe(batch_df.head())

                batch_df = get_predictions(batch_df)

                # Pass the actual-sales column to reorder_columns as well,
                # but only when the upload contains it.
                if ACTUAL_SALES_COLUMN in batch_df.columns:
                    priority_columns = [ACTUAL_SALES_COLUMN, PREDICTED_SALES_COLUMN]
                else:
                    priority_columns = [PREDICTED_SALES_COLUMN]
                batch_df = reorder_columns(batch_df, priority_columns)

                heading_col, download_col = st.columns([6, 1])
                with heading_col:
                    st.subheader("📈 Prediction Results:")
                with download_col:
                    st.download_button(
                        label="📥 Download CSV",
                        data=get_csv_download(batch_df),
                        file_name="batch_predictions.csv",
                        mime="text/csv",
                        use_container_width=True,
                    )

                st.dataframe(batch_df)
                show_error_summary(batch_df)
                plot_actual_vs_predicted(batch_df)
        except Exception as exc:
            st.error(f"⚠️ Error while processing the file: {exc}")