# File size: 6,314 Bytes
# cef1463
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import datetime
import io

import pandas as pd
import requests
import streamlit as st

# Set the title of the Streamlit app
st.title("Product Store Sales Prediction")

# --- Online Prediction Section ---
st.subheader("Online Prediction")
st.markdown("Enter the details for a single product to get a sales prediction.")

# Latest selectable store establishment year. Derived from today's date (but
# never below the previously hard-coded 2024) so stores opened after this app
# was written can still be entered.
LATEST_ESTABLISHMENT_YEAR = max(2024, datetime.date.today().year)

# Create a two-column layout for better organization
col1, col2 = st.columns(2)

with col1:
    # Product-related inputs.
    # NOTE: the Product ID widget's return value is discarded, so it is shown
    # to the user but never used by this script.
    st.text_input("Product ID", "FD_123")
    product_weight = st.number_input("Product Weight (grams)", min_value=0.0, value=150.5, step=0.1)
    product_sugar_content = st.selectbox("Product Sugar Content", ["Low Sugar", "Regular", "High Sugar"])
    product_allocated_area = st.number_input("Product Allocated Area (sq. cm)", min_value=0.0, value=500.0, step=10.0)
    product_type = st.selectbox("Product Type", [
        "Dairy", "Soft Drinks", "Meat", "Fruits and Vegetables", "Household",
        "Baking Goods", "Snack Foods", "Frozen Foods", "Breakfast",
        "Health and Hygiene", "Hard Drinks", "Canned", "Breads", "Starchy Foods",
        "Others", "Seafood"
    ])
    product_mrp = st.number_input("Product MRP (in $)", min_value=0.0, value=150.0)

with col2:
    # Store-related inputs.
    # NOTE: like Product ID, the Store ID value is discarded (display only).
    st.text_input("Store ID", "STR_001")
    store_establishment_year = st.number_input(
        "Store Establishment Year",
        min_value=1900,
        max_value=LATEST_ESTABLISHMENT_YEAR,
        step=1,
        value=2010,
    )
    store_size = st.selectbox("Store Size", ["Small", "Medium", "High"])
    store_location_city_type = st.selectbox("Store Location City Type", ["City Tier 1", "City Tier 2", "City Tier 3"])
    store_type = st.selectbox("Store Type", ["Supermarket Type1", "Supermarket Type2", "Grocery Store", "Supermarket Type3"])

# Button to trigger the prediction
if st.button("Predict Sales"):
    # Collect user input into a dictionary; the keys are the feature names
    # sent as the JSON request body.
    input_data = {
        'Product_Weight': product_weight,
        'Product_Sugar_Content': product_sugar_content,
        'Product_Allocated_Area': product_allocated_area,
        'Product_Type': product_type,
        'Product_MRP': product_mrp,
        'Store_Establishment_Year': store_establishment_year,
        'Store_Size': store_size,
        'Store_Location_City_Type': store_location_city_type,
        'Store_Type': store_type
    }

    # Display the collected data for confirmation
    st.write("---")
    st.write("**Input Data for Prediction:**")
    st.json(input_data)

    # --- API Call ---
    # NOTE: Replace the URL with your actual API endpoint.
    API_URL_SINGLE = "https://vaishaliagarwal-ProductPricePredictionBackend.hf.space/v1/predict_sales"
    # Seconds to wait for the backend; without a timeout requests.post can
    # block this Streamlit run indefinitely if the server never answers.
    SINGLE_REQUEST_TIMEOUT = 30
    try:
        # Send data to the prediction API
        response = requests.post(API_URL_SINGLE, json=input_data, timeout=SINGLE_REQUEST_TIMEOUT)
        response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)
        prediction = response.json()
    except requests.exceptions.HTTPError as e:
        # raise_for_status() fired: show the status and whatever the backend
        # returned, since it usually explains what was wrong with the input.
        st.error(f"Error: Received status code {e.response.status_code}")
        st.text(e.response.text)
    except requests.exceptions.RequestException as e:
        st.error(f"API request failed: {e}")
        st.warning("Please ensure the backend API is running and the URL is correct.")
    except ValueError as e:
        # Older `requests` versions raise a plain ValueError for a non-JSON body.
        st.error(f"API returned a response that is not valid JSON: {e}")
    else:
        # Assuming the API returns a dictionary like: {'Predicted_Sales': 5500.75}
        predicted_sales = prediction.get('Predicted_Sales') if isinstance(prediction, dict) else None
        try:
            # float() also accepts numeric strings; a missing/non-numeric value
            # must not reach the ",.2f" format spec, which only accepts numbers.
            st.success(f"**Predicted Total Sales:** ${float(predicted_sales):,.2f}")
        except (TypeError, ValueError):
            st.error("Prediction not available: the API response did not contain a numeric 'Predicted_Sales' value.")
            st.write(prediction)


# --- Batch Prediction Section ---
st.subheader("Batch Prediction")
st.markdown("Upload a CSV file with multiple product details to get batch sales predictions.")

# Allow users to upload a CSV file
uploaded_file = st.file_uploader("Choose a CSV file", type=["csv"])

if uploaded_file is not None:
    # Top-level UI boundary: any failure while reading the file or turning the
    # API response into a table is reported instead of crashing the app.
    try:
        # Display a preview of the uploaded data
        df_upload = pd.read_csv(uploaded_file)
        st.write("**Uploaded Data Preview:**")
        st.dataframe(df_upload.head())

        # Button to trigger batch prediction
        if st.button("Predict for Batch"):
            # --- API Call for Batch Prediction ---
            # NOTE: Replace the URL with your actual batch prediction API endpoint.
            API_URL_BATCH = "https://vaishaliagarwal-ProductPricePredictionBackend.hf.space/v1/predict_sales_batch"
            # Seconds to wait for the backend (connect / between bytes); batches
            # can be slow, but an unbounded wait would hang this Streamlit run.
            BATCH_REQUEST_TIMEOUT = 120

            # read_csv consumed the upload buffer; rewind so the whole file is sent
            uploaded_file.seek(0)

            # The file needs to be sent in a 'files' dictionary (multipart upload)
            files = {"file": (uploaded_file.name, uploaded_file, "text/csv")}

            try:
                with st.spinner('Processing batch prediction...'):
                    response = requests.post(API_URL_BATCH, files=files, timeout=BATCH_REQUEST_TIMEOUT)
                    response.raise_for_status()  # 4xx/5xx -> HTTPError
                    predictions = response.json()
            except requests.exceptions.HTTPError as e:
                # Show the status and the backend's error body to explain the failure.
                st.error(f"Error: Received status code {e.response.status_code}")
                st.text(e.response.text)
            except requests.exceptions.RequestException as e:
                st.error(f"API request failed: {e}")
                st.warning("Please ensure the backend API is running and the URL is correct.")
            else:
                st.success("Batch predictions completed successfully!")

                # Assuming the API returns a JSON array of predictions
                # (e.g. a list of records) — TODO confirm against the backend.
                df_predictions = pd.DataFrame(predictions)
                st.write("**Prediction Results:**")
                st.dataframe(df_predictions)

                # Provide an option to download the results. Encoded directly:
                # it is cheap, and a @st.cache_data wrapper with no entry limit
                # would keep every distinct result set in memory for good.
                csv = df_predictions.to_csv(index=False).encode('utf-8')
                st.download_button(
                   label="Download Predictions as CSV",
                   data=csv,
                   file_name='sales_predictions.csv',
                   mime='text/csv',
                )

    except Exception as e:
        st.error(f"An error occurred while processing the file: {e}")