# GL / app.py
# DD009's picture
# Upload folder using huggingface_hub
# 7ca18a6 verified
import streamlit as st
import pandas as pd
import requests
from datetime import datetime
# ---------------------------------------------------------------------------
# Page header and the single-prediction input form.
#
# Renders the app title, then a two-column form: product attributes on the
# left, store attributes on the right. The chosen values are collected into
# `input_data`, which the "Predict Sales" handler posts as JSON.
# ---------------------------------------------------------------------------
st.title("Retail Product Sales Prediction")
st.subheader("Single Product-Store Prediction")

# Choices offered by the categorical dropdowns.
sugar_options = ["Low", "Medium", "High"]
product_type_options = [
    "Frozen Foods", "Dairy", "Canned", "Baking Goods",
    "Health and Hygiene", "Snack Foods", "Meat", "Household"
]
store_size_options = ["Small", "Medium", "Large"]
city_tier_options = ["Tier 1", "Tier 2", "Tier 3"]
store_type_options = [
    "Supermarket Type1", "Supermarket Type2",
    "Departmental Store", "Food Mart"
]

col1, col2 = st.columns(2)

# Left column: attributes describing the product.
with col1:
    st.markdown("**Product Details**")
    product_weight = st.slider(
        "Product Weight (kg)",
        min_value=4.0, max_value=22.0, value=12.66, step=0.01,
    )
    product_sugar = st.selectbox("Sugar Content", sugar_options)
    product_area = st.slider(
        "Allocated Area (sqm)",
        min_value=0.004, max_value=0.3, value=0.05, step=0.001,
    )
    product_type = st.selectbox("Product Type", product_type_options)
    product_mrp = st.slider(
        "Product MRP (price)",
        min_value=31.0, max_value=266.0, value=147.0, step=0.5,
    )

# Right column: attributes describing the store.
with col2:
    st.markdown("**Store Details**")
    # The upper bound follows the current calendar year at render time.
    establishment_year = st.slider(
        "Store Establishment Year",
        min_value=1987, max_value=datetime.now().year, value=2002,
    )
    store_size = st.selectbox("Store Size", store_size_options)
    city_type = st.selectbox("City Tier", city_tier_options)
    store_type = st.selectbox("Store Type", store_type_options)

# Request payload for the single-prediction endpoint. The original comment
# says only features used by the backend are included; the backend schema is
# not visible in this file, so verify the key names against the API.
input_data = dict(
    product_weight=product_weight,
    product_allocated_area=product_area,
    product_mrp=product_mrp,
    store_establishment_year=establishment_year,
    product_sugar_content=product_sugar,
    product_type=product_type,
    store_size=store_size,
    store_location_city_type=city_type,
    store_type=store_type,
)
# ---------------------------------------------------------------------------
# Single prediction: POST `input_data` as JSON when "Predict Sales" is clicked.
#
# The three failure modes are handled separately so each gets an accurate
# message:
#   1. transport failure (connection error, timeout)  -> "Connection error"
#   2. body is not JSON                               -> "Could not decode JSON"
#   3. JSON lacks a usable `predicted_sales` value    -> "Unexpected response format"
# Previously case 3 either crashed the page with an uncaught KeyError/TypeError
# or was misreported as a JSON decoding problem.
# ---------------------------------------------------------------------------
if st.button("Predict Sales"):
    backend_url = "https://DD009-SuperKartBackend.hf.space"
    endpoint = f"{backend_url}/v1/sales"
    st.info(f"Connecting to: {endpoint}")  # Show the endpoint being called

    try:
        # Only the network call lives in this `try`, so the handler below
        # cannot swallow unrelated errors.
        response = requests.post(endpoint, json=input_data, timeout=10)
    except requests.exceptions.RequestException as e:
        st.error(f"Connection error: {str(e)}")
        st.info("Please check if the backend server is running and accessible")
    else:
        # Decode once. A flag is used instead of a None check because a JSON
        # `null` body is valid JSON that decodes to None.
        try:
            result = response.json()
            is_json = True
        except ValueError:
            result = None
            is_json = False

        if response.status_code != 200:
            st.error(f"Error making prediction (Status {response.status_code})")
            if is_json:
                st.json(result)
            else:
                st.text(f"Raw response: {response.text}")
        elif not is_json:
            st.error("Could not decode JSON response from server")
            st.text(f"Raw response: {response.text}")
        else:
            # Validate the payload shape before formatting it.
            raw_prediction = (
                result.get("predicted_sales") if isinstance(result, dict) else None
            )
            try:
                predicted_sales = float(raw_prediction)
            except (TypeError, ValueError):
                st.error("Unexpected response format from server")
                st.json(result)
            else:
                st.success(f"Predicted Sales Total: ${predicted_sales:.2f}")
                # The feature list is informational; show it only when the
                # server actually returned a list.
                features_used = result.get("features_used")
                if isinstance(features_used, (list, tuple)):
                    st.markdown("**Features Used**")
                    st.write(", ".join(str(name) for name in features_used))
# ---------------------------------------------------------------------------
# Batch prediction: upload a CSV, send it to the batch endpoint as a
# multipart file, and show the returned predictions as a table with a CSV
# download.
# ---------------------------------------------------------------------------
st.subheader("Batch Prediction")
st.write("Upload a CSV file containing multiple product-store combinations")

# Instructions for CSV format
st.markdown("""
**CSV File Requirements:**
- Must contain these exact columns:
- Product_Weight, Product_Allocated_Area, Product_MRP
- Store_Establishment_Year, Product_Sugar_Content, Product_Type
- Store_Size, Store_Location_City_Type, Store_Type
""")

# File uploader for batch predictions
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")

if uploaded_file is not None:
    # NOTE(review): results are rendered only during the run in which this
    # button was clicked. Clicking the download button below reruns the
    # script, so the table is expected to disappear afterwards. Persist the
    # results in st.session_state if they should stay visible.
    if st.button("Predict Batch Sales"):
        backend_url = "https://DD009-SuperKartBackend.hf.space"
        endpoint = f"{backend_url}/v1/salesbatch"
        st.info(f"Connecting to: {endpoint}")  # Show the endpoint being called

        # getvalue() returns the whole upload regardless of the stream's
        # current read position, so the complete file is always sent.
        files = {
            'file': (uploaded_file.name, uploaded_file.getvalue(), 'text/csv')
        }

        try:
            # Only the network call lives in this `try`, so the handler below
            # cannot swallow unrelated errors.
            response = requests.post(endpoint, files=files, timeout=30)
        except requests.exceptions.RequestException as e:
            st.error(f"Connection error: {str(e)}")
            st.info("Please check if the backend server is running and accessible")
        else:
            # Decode once. A flag is used instead of a None check because a
            # JSON `null` body is valid JSON that decodes to None.
            try:
                results = response.json()
                is_json = True
            except ValueError:
                results = None
                is_json = False

            if response.status_code != 200:
                st.error(f"Error making predictions (Status {response.status_code})")
                if is_json:
                    st.json(results)
                else:
                    st.text(f"Raw response: {response.text}")
            elif not is_json:
                st.error("Could not decode JSON response from server")
                st.text(f"Raw response: {response.text}")
            else:
                # Build the table defensively: a non-dict body or a malformed
                # `predictions` value is a format problem, not a JSON
                # decoding problem, and must not crash the page.
                results_df = None
                if isinstance(results, dict) and 'predictions' in results:
                    try:
                        results_df = pd.DataFrame(results['predictions'])
                    except (ValueError, TypeError):
                        results_df = None

                if results_df is None:
                    st.error("Unexpected response format from server")
                    st.json(results)
                else:
                    st.success("Batch predictions completed!")
                    st.dataframe(results_df)
                    # Download button for results
                    st.download_button(
                        label="Download predictions as CSV",
                        data=results_df.to_csv(index=False),
                        file_name='sales_predictions.csv',
                        mime='text/csv'
                    )
# ---------------------------------------------------------------------------
# Sidebar: example input.
#
# Shows one sample record as JSON, using the same keys as the
# single-prediction payload, and offers it as a one-row CSV download.
# ---------------------------------------------------------------------------
st.sidebar.markdown("### Sample Data")
if st.sidebar.button("Show Sample Input"):
    sample_data = {
        "product_weight": 12.66,
        "product_allocated_area": 0.027,
        "product_mrp": 117.08,
        "store_establishment_year": 2009,
        "product_sugar_content": "Low",
        "product_type": "Frozen Foods",
        "store_size": "Medium",
        "store_location_city_type": "Tier 2",
        "store_type": "Supermarket Type2"
    }
    st.sidebar.json(sample_data)

    # The batch section of this page documents Title_Case column names as the
    # "exact columns" a CSV must contain, so the sample CSV uses those headers
    # rather than the lowercase JSON keys.
    # NOTE(review): the backend schema is not visible here; confirm the batch
    # endpoint expects these header names.
    csv_column_names = {
        "product_weight": "Product_Weight",
        "product_allocated_area": "Product_Allocated_Area",
        "product_mrp": "Product_MRP",
        "store_establishment_year": "Store_Establishment_Year",
        "product_sugar_content": "Product_Sugar_Content",
        "product_type": "Product_Type",
        "store_size": "Store_Size",
        "store_location_city_type": "Store_Location_City_Type",
        "store_type": "Store_Type",
    }
    sample_csv = (
        pd.DataFrame([sample_data])
        .rename(columns=csv_column_names)
        .to_csv(index=False)
    )
    st.sidebar.download_button(
        label="Download Sample CSV",
        data=sample_csv,
        file_name='sample_input.csv',
        mime='text/csv'
    )