Pushpak21 commited on
Commit
c4fa469
ยท
verified ยท
1 Parent(s): f11414d

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +125 -64
app.py CHANGED
@@ -1,97 +1,158 @@
1
 
2
  import streamlit as st
3
  import pandas as pd
4
-
5
- # Custom module imports
6
- from predictor.batch_handler import get_predictions, get_single_prediction
7
- from predictor.chart_plotter import plot_actual_vs_predicted
8
- from predictor.utils import reorder_columns, get_csv_download
9
 
10
  st.set_page_config(page_title="SuperKart Sales Predictor", layout="wide")
11
 
12
- # App header
13
- st.markdown("## ๐Ÿ›’ SuperKart Sales Predictor")
14
- st.write("Use the tabs below to predict sales for a single product or upload a CSV file for batch prediction.")
15
 
16
- # Define tabs
17
  tab1, tab2 = st.tabs(["๐Ÿ” Single Prediction", "๐Ÿ“„ Batch Prediction"])
18
 
19
- # ๐Ÿ” TAB 1: Single Prediction
20
  with tab1:
21
- st.subheader("๐Ÿ” Enter Product Details for Prediction")
22
-
23
- with st.form("single_form"):
24
- col1, col2 = st.columns(2)
25
-
26
- with col1:
27
- product_weight = st.number_input("Product Weight (kg)", min_value=0.0)
28
- sugar_content = st.selectbox("Sugar Content", ["Low Sugar", "Regular", "No Sugar"])
29
- allocated_area = st.number_input("Allocated Area (mยฒ)", min_value=0.0)
30
- product_type = st.selectbox("Product Type", ["Dairy", "Canned", "Frozen Foods", "Health and Hygiene", "Baking Goods"])
31
-
32
- with col2:
33
- mrp = st.number_input("Product MRP (โ‚น)", min_value=0.0)
34
- store_id = st.selectbox("Store ID", ["OUT001", "OUT002", "OUT003", "OUT004"])
35
- store_year = st.number_input("Establishment Year", min_value=1980, max_value=2025)
 
 
36
  store_size = st.selectbox("Store Size", ["Small", "Medium", "High"])
37
- city_type = st.selectbox("Store City Tier", ["Tier 1", "Tier 2", "Tier 3"])
38
- store_type = st.selectbox("Store Type", ["Supermarket Type1", "Supermarket Type2", "Departmental Store", "Food Mart"])
39
-
40
- submitted = st.form_submit_button("๐Ÿ”ฎ Predict Sales")
41
-
42
- if submitted:
43
- input_record = {
 
 
 
44
  "Product_Weight": product_weight,
45
- "Product_Sugar_Content": sugar_content,
46
- "Product_Allocated_Area": allocated_area,
 
47
  "Product_Type": product_type,
48
- "Product_MRP": mrp,
49
  "Store_Id": store_id,
50
- "Store_Establishment_Year": store_year,
51
  "Store_Size": store_size,
52
- "Store_Location_City_Type": city_type,
53
- "Store_Type": store_type
 
54
  }
55
 
56
- try:
57
- prediction = get_single_prediction(input_record)
58
- st.success(f"โœ… Predicted Sales: โ‚น{prediction:,.2f}")
59
- st.json({**input_record, "Predicted_Sales": prediction})
60
- except Exception as e:
61
- st.error(f"โš ๏ธ Error during prediction: {e}")
 
 
 
 
 
62
 
63
- # ๐Ÿ“„ TAB 2: Batch Prediction
64
  with tab2:
65
  st.subheader("๐Ÿ“„ Upload CSV for Batch Prediction")
66
  uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
67
 
68
- if uploaded_file:
69
  try:
70
  df = pd.read_csv(uploaded_file)
71
 
72
  if df.empty:
73
- st.warning("Uploaded file is empty.")
74
  else:
75
  st.write("๐Ÿ“‹ Uploaded Data Preview:")
76
  st.dataframe(df.head())
77
 
78
- df = get_predictions(df)
79
- df = reorder_columns(df, ["Product_Store_Sales_Total", "Predicted_Sales"])
80
-
81
- col1, col2 = st.columns([6, 1])
82
- with col1:
83
- st.subheader("๐Ÿ“ˆ Prediction Results:")
84
- with col2:
85
- st.download_button(
86
- label="๐Ÿ“ฅ Download CSV",
87
- data=get_csv_download(df),
88
- file_name="batch_predictions.csv",
89
- mime="text/csv",
90
- use_container_width=True
91
- )
92
-
93
- st.dataframe(df)
94
- plot_actual_vs_predicted(df)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
  except Exception as e:
97
  st.error(f"โš ๏ธ Error while processing the file: {e}")
 
1
 
2
  import streamlit as st
3
  import pandas as pd
4
+ import requests
5
+ import altair as alt
 
 
 
6
 
7
  st.set_page_config(page_title="SuperKart Sales Predictor", layout="wide")
8
 
9
+ st.title("๐Ÿ›’ SuperKart Sales Predictor")
10
+ st.markdown("Use the tabs below to predict sales for a single product or upload a CSV file for batch prediction.")
 
11
 
12
+ # Create two tabs
13
  tab1, tab2 = st.tabs(["๐Ÿ” Single Prediction", "๐Ÿ“„ Batch Prediction"])
14
 
15
+ # ----------------- Tab 1: Single Prediction -----------------
16
  with tab1:
17
+ col1, col2 = st.columns(2)
18
+
19
+ with col1:
20
+ with st.expander("๐Ÿ“ฆ Product Details", expanded=True):
21
+ product_weight = st.slider("Product Weight (kg)", 4.0, 22.0, 12.65, 0.1)
22
+ product_allocated_area = st.slider("Allocated Shelf Area", 0.0, 0.3, 0.07, 0.01)
23
+ product_mrp = st.slider("Product MRP", 31.0, 266.0, 147.0)
24
+ product_sugar_content = st.radio("Sugar Content", ["Low Sugar", "Regular", "No Sugar", "reg"], horizontal=True)
25
+ product_type = st.selectbox("Product Type", [
26
+ "Fruits and Vegetables", "Snack Foods", "Frozen Foods", "Dairy", "Household", "Baking Goods",
27
+ "Canned", "Health and Hygiene", "Meat", "Soft Drinks", "Breads", "Hard Drinks", "Others",
28
+ "Starchy Foods", "Breakfast", "Seafood"
29
+ ])
30
+
31
+ with col2:
32
+ with st.expander("๐Ÿฌ Store Details", expanded=True):
33
+ store_id = st.radio("Store ID", ["OUT001", "OUT002", "OUT003", "OUT004"], horizontal=True)
34
  store_size = st.selectbox("Store Size", ["Small", "Medium", "High"])
35
+ store_location = st.radio("City Tier", ["Tier 1", "Tier 2", "Tier 3"], horizontal=True)
36
+ store_type = st.selectbox("Store Type", [
37
+ "Supermarket Type1", "Supermarket Type2", "Departmental Store", "Grocery Store"
38
+ ])
39
+ est_year = st.slider("Establishment Year", 1987, 2009, 2002)
40
+
41
+ # Submit action
42
+ if st.button("๐ŸŽฏ Predict Sales ๐ŸŽฏ"):
43
+ try:
44
+ payload = {
45
  "Product_Weight": product_weight,
46
+ "Product_Allocated_Area": product_allocated_area,
47
+ "Product_MRP": product_mrp,
48
+ "Product_Sugar_Content": product_sugar_content,
49
  "Product_Type": product_type,
 
50
  "Store_Id": store_id,
 
51
  "Store_Size": store_size,
52
+ "Store_Location_City_Type": store_location,
53
+ "Store_Type": store_type,
54
+ "Store_Establishment_Year": est_year
55
  }
56
 
57
+ url = "https://Pushpak21-SuperKart-Sales-Forecast-Backend.hf.space/predict"
58
+ response = requests.post(url, json=payload)
59
+ result = response.json()
60
+
61
+ if response.status_code == 200:
62
+ st.success(f"๐Ÿ“ˆ Predicted Sales: โ‚น{result['prediction'][0]:,.2f}")
63
+ else:
64
+ st.error(f"โŒ Error: {result.get('error', 'Unknown error')}")
65
+
66
+ except Exception as e:
67
+ st.error(f"โš ๏ธ Request failed: {e}")
68
 
69
+ # ----------------- Tab 2: Batch Prediction -----------------
70
  with tab2:
71
  st.subheader("๐Ÿ“„ Upload CSV for Batch Prediction")
72
  uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
73
 
74
+ if uploaded_file is not None:
75
  try:
76
  df = pd.read_csv(uploaded_file)
77
 
78
  if df.empty:
79
+ st.warning("Uploaded file is empty!")
80
  else:
81
  st.write("๐Ÿ“‹ Uploaded Data Preview:")
82
  st.dataframe(df.head())
83
 
84
+ # Convert DataFrame to list of records
85
+ records = df.to_dict(orient="records")
86
+
87
+ # Make prediction API call
88
+ response = requests.post(
89
+ "https://Pushpak21-SuperKart-Sales-Forecast-Backend.hf.space/predict_batch",
90
+ json=records
91
+ )
92
+
93
+ if response.status_code == 200:
94
+ predictions = response.json().get("predictions", [])
95
+ df["Predicted_Sales"] = predictions
96
+
97
+ # Reorder columns: move actual & predicted sales to front
98
+ priority_cols = ["Product_Store_Sales_Total", "Predicted_Sales"]
99
+ other_cols = [col for col in df.columns if col not in priority_cols]
100
+ df = df[priority_cols + other_cols]
101
+
102
+ st.success("โœ… Batch prediction complete!")
103
+
104
+ # Header + Download button aligned
105
+ col1, col2 = st.columns([6, 1])
106
+ with col1:
107
+ st.subheader("๐Ÿ“ˆ Prediction Results:")
108
+ with col2:
109
+ csv = df.to_csv(index=False).encode('utf-8')
110
+ st.download_button(
111
+ label="๐Ÿ“ฅ Download CSV",
112
+ data=csv,
113
+ file_name="batch_predictions.csv",
114
+ mime="text/csv",
115
+ use_container_width=True
116
+ )
117
+
118
+ # Display result table
119
+ st.dataframe(df)
120
+
121
+ # ๐Ÿ“‰ Altair Line Chart: Actual vs Predicted
122
+ if "Product_Store_Sales_Total" in df.columns:
123
+ plot_df = df[["Product_Store_Sales_Total", "Predicted_Sales"]].dropna()
124
+
125
+ if not plot_df.empty:
126
+ st.subheader("๐Ÿ“‰ Actual vs Predicted Sales")
127
+
128
+ # Prepare long-format dataframe for Altair
129
+ plot_df = plot_df.reset_index().rename(columns={
130
+ "Product_Store_Sales_Total": "Actual Sales",
131
+ "Predicted_Sales": "Predicted Sales",
132
+ "index": "Index"
133
+ })
134
+
135
+ plot_df_melted = plot_df.melt(
136
+ id_vars="Index",
137
+ var_name="Type",
138
+ value_name="Sales"
139
+ )
140
+
141
+ line_chart = alt.Chart(plot_df_melted).mark_line(point=True).encode(
142
+ x=alt.X("Index:O", title="Record Index"),
143
+ y=alt.Y("Sales:Q", title="Sales Value"),
144
+ color=alt.Color("Type:N", title="Sales Type"),
145
+ tooltip=["Index", "Type", "Sales"]
146
+ ).properties(
147
+ width=700,
148
+ height=400
149
+ )
150
+
151
+ st.altair_chart(line_chart, use_container_width=True)
152
+ else:
153
+ st.info("โ„น๏ธ Not enough valid rows for plotting.")
154
+ else:
155
+ st.error(f"โŒ API Error {response.status_code}: {response.text}")
156
 
157
  except Exception as e:
158
  st.error(f"โš ๏ธ Error while processing the file: {e}")