# SalesPredictionSuperKart / streamlit_app.py
# Uploaded by vikas0615 using huggingface_hub (commit 1275ff8, verified).
import streamlit as st
import pandas as pd
import joblib
import requests
import os
# Environment tweaks applied before the UI is built.
# HOME is pointed at /tmp, presumably because the hosting container's default
# home directory is not writable (TODO confirm). The .streamlit directory is
# created relative to the current working directory, not under HOME.
# NOTE(review): streamlit is imported above this point; confirm that changing
# HOME here still affects the paths Streamlit resolves.
os.environ.update(HOME="/tmp")
config_dir = "./.streamlit"
os.makedirs(name=config_dir, exist_ok=True)
# Load the trained model once per server process.
@st.cache_resource
def load_model():
    """Deserialize the trained SuperKart sales model from disk.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` makes the expensive ``joblib.load`` run only once
    and hands the same object back on later reruns. The cached object is
    shared, not copied, so callers must treat it as read-only.

    Returns:
        The object stored in ``superkart_sales_prediction_model_v1_0.joblib``.
        The path is resolved relative to the current working directory.
    """
    # joblib.load unpickles the file, which can execute arbitrary code.
    # Only ever point this at the trusted artifact shipped with the app.
    return joblib.load("superkart_sales_prediction_model_v1_0.joblib")


model = load_model()
# --- Page header ---------------------------------------------------------------
st.title("SuperKart Sales Prediction App")
st.write("The Sales Prediction App is an internal tool to predict sales based on past sales, product types, and store.")
st.write("Kindly enter the details to predict sales forecast.")

# --- Choices offered by the categorical inputs ---------------------------------
# Each list is shown in the given order; its first entry is preselected.
SUGAR_CONTENT_OPTIONS = ["No Sugar", "Low Sugar", "Regular"]
PRODUCT_TYPE_OPTIONS = [
    "Baking Goods", "Breads", "Breakfast", "Canned", "Dairy", "Frozen Food",
    "Fruits and Vegetables", "Hard Drinks", "Health and Hygiene", "Household",
    "Meat", "Others", "Seafood", "Snack Foods", "Soft Drinks", "Starchy Foods",
]
STORE_ID_OPTIONS = ["OUT001", "OUT002", "OUT003", "OUT004"]
STORE_SIZE_OPTIONS = ["Small", "Medium", "High"]
CITY_TIER_OPTIONS = ["Tier 1", "Tier 2", "Tier 3"]
STORE_TYPE_OPTIONS = ["Supermarket Type2", "Departmental Store", "Supermarket Type1", "Food Mart"]

# --- Collect one record of user input ------------------------------------------
# Each widget label is spelled exactly like the feature column it fills.
Product_Weight = st.number_input(
    "Product_Weight", value=90.0, min_value=0.0, max_value=100.0, step=0.1
)
Product_Sugar_Content = st.selectbox("Product_Sugar_Content", SUGAR_CONTENT_OPTIONS)
Product_Allocated_Area = st.number_input(
    "Product_Allocated_Area", value=1.0, min_value=0.0, max_value=3.0, step=0.01
)
Product_Type = st.selectbox("Product_Type", PRODUCT_TYPE_OPTIONS)
# NOTE(review): confirm that 1.0-50.0 covers the MRP range seen in training data.
Product_MRP = st.number_input(
    "Product_MRP", value=40.0, min_value=1.0, max_value=50.0, step=0.1
)
Store_Id = st.selectbox("Store_Id", STORE_ID_OPTIONS)
Store_Size = st.selectbox("Store_Size", STORE_SIZE_OPTIONS)
Store_Location_City_Type = st.selectbox("Store_Location_City_Type", CITY_TIER_OPTIONS)
Store_Type = st.selectbox("Store_Type", STORE_TYPE_OPTIONS)
# NOTE(review): the app predicts sales, yet this asks the user for a sales total.
# Confirm the model was really trained with this column as an input feature
# rather than as its target.
Product_Store_Sales_Total = st.number_input(
    "Product_Store_Sales_Total", value=90.0, min_value=1.0, max_value=10000.0, step=0.01
)
# Assemble the widget values into a one-row DataFrame. Its keys become the
# column names, which must match the feature names the model expects.
feature_row = {
    "Product_Weight": Product_Weight,
    "Product_Sugar_Content": Product_Sugar_Content,
    "Product_Allocated_Area": Product_Allocated_Area,
    "Product_Type": Product_Type,
    "Product_MRP": Product_MRP,
    "Store_Id": Store_Id,
    "Store_Size": Store_Size,
    "Store_Location_City_Type": Store_Location_City_Type,
    "Store_Type": Store_Type,
    "Product_Store_Sales_Total": Product_Store_Sales_Total,
}
input_data = pd.DataFrame([feature_row])
# --- Batch prediction --------------------------------------------------------------
# Sends an uploaded CSV to the remote forecasting API and displays its JSON reply.
BATCH_API_URL = (
    "https://vikas0615-vikas0615--superkartsalesprediction-updated.hf.space/v1/forecastbatch"
)
# (connect, read) timeouts in seconds. Without a timeout, requests waits
# forever and the page hangs if the service stalls.
BATCH_API_TIMEOUT = (10, 120)

st.subheader("Batch Prediction")
uploaded_file = st.file_uploader("Upload CSV file for batch prediction", type=["csv"])
# The button is only rendered once a file has been uploaded.
if uploaded_file is not None and st.button("Predict Batch"):
    try:
        # Send the bytes explicitly, with filename and MIME type, instead of the
        # file object. This way the upload does not depend on the object's read position.
        response = requests.post(
            BATCH_API_URL,
            files={"file": (uploaded_file.name, uploaded_file.getvalue(), "text/csv")},
            timeout=BATCH_API_TIMEOUT,
        )
        if response.status_code == 200:
            predictions = response.json()
            st.success("Batch predictions completed!")
            st.write(predictions)
        else:
            st.error(f"Error from batch API: {response.status_code} - {response.text}")
    except requests.exceptions.RequestException as e:
        # Network failure, timeout, invalid URL, etc.
        st.error(f"Batch request failed: {e}")
    except ValueError as e:
        # response.json() raises a ValueError subclass when the body is not JSON.
        st.error(f"Batch API returned a response that is not valid JSON: {e}")
# --- Single-record prediction ----------------------------------------------------
# The app forecasts a sales amount, which is a continuous number, so the model's
# predict() output is reported directly. There is no class probability or
# decision threshold involved, and regressors do not expose predict_proba().
# NOTE(review): assumes the joblib artifact is a fitted regressor (or a pipeline
# ending in one) trained on the columns of input_data. TODO confirm.
if st.button("Predict"):
    # input_data holds a single row, so the prediction array has one element.
    predicted_sales = float(model.predict(input_data)[0])
    st.write(
        "Based on the information provided, the predicted "
        f"Product_Store_Sales_Total is {predicted_sales:,.2f}."
    )