# frontendloki / app.py
# Lokiiparihar's picture
# Update app.py
# cef008f verified
import io
import logging
import os
from datetime import datetime

import numpy as np
import pandas as pd
import requests
import streamlit as st
# ------------------------------------------------
# CONSTANTS (API endpoints)
# ------------------------------------------------
# Backend Hugging Face Space: Lokiiparihar/backendlokii.
# Set the SUPERKART_API_BASE environment variable to point this UI at a
# different deployment (e.g. a local server); when unset, the hosted
# backend below is used, so existing behaviour is unchanged.
API_BASE = os.environ.get(
    "SUPERKART_API_BASE", "https://Lokiiparihar-backendlokii.hf.space"
).rstrip("/")
API_SINGLE = f"{API_BASE}/v1/sales/predict"
API_BATCH = f"{API_BASE}/v1/sales/predict/batch"
# ------------------------------------------------
# LOGGING
# ------------------------------------------------
# Root logger: INFO and above, timestamped, written by a single StreamHandler
# (stderr by default). basicConfig does nothing once the root logger already
# has handlers, so repeated script runs do not stack handlers.
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(message)s"
_LOG_HANDLERS = [logging.StreamHandler()]
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=_LOG_HANDLERS)

# Module-level logger used by this app.
logger = logging.getLogger(__name__)
# ------------------------------------------------
# DIRECT PICKLIST VALUES (raw strings — no encoding)
# ------------------------------------------------
# These strings are sent to the API verbatim. Presumably they must match the
# category labels the backend model was trained on — confirm before editing.

# Store identifiers OUT001 .. OUT004.
STORE_IDS = [f"OUT{number:03d}" for number in range(1, 5)]

STORE_TYPES = [
    "Departmental Store",
    "Food Mart",
    "Supermarket Type1",
    "Supermarket Type2",
]

STORE_SIZES = ["High", "Medium", "Small"]

# City tiers "Tier 1" .. "Tier 3".
STORE_CITY_TYPES = [f"Tier {tier}" for tier in range(1, 4)]

PRODUCT_TYPES = [
    "Baking Goods",
    "Breads",
    "Breakfast",
    "Canned",
    "Dairy",
    "Frozen Foods",
    "Fruits and Vegetables",
    "Hard Drinks",
    "Health and Hygiene",
    "Household",
    "Meat",
    "Others",
    "Seafood",
    "Snack Foods",
    "Soft Drinks",
    "Starchy Foods",
]

PRODUCT_SUGAR_CONTENT = ["Low Sugar", "No Sugar", "Regular"]

# Two-letter Product ID prefixes accepted by the single-prediction form.
PRODUCT_PREFIXES = ["DR", "FD", "NC"]
# ------------------------------------------------
# API HELPERS
# ------------------------------------------------
def predict_single(input_data, *, timeout=60):
    """Send one raw (non-encoded) record to the single-prediction API.

    Args:
        input_data: JSON-serialisable dict mapping feature name to raw value.
        timeout: Seconds to wait for the backend before giving up.

    Returns:
        The predicted sales value as a float.

    Raises:
        ValueError: If the API answers with a non-200 status, or with a body
            that is not JSON or lacks a numeric ``"Prediction"`` field.
        requests.RequestException: On network failures or timeouts.
    """
    # Without a timeout a hung backend would block this script run indefinitely.
    response = requests.post(API_SINGLE, json=input_data, timeout=timeout)
    if response.status_code != 200:
        raise ValueError(f"API Error: {response.text}")
    try:
        # .json() raises a ValueError subclass on a non-JSON body.
        return float(response.json()["Prediction"])
    except (ValueError, KeyError, TypeError) as err:
        raise ValueError(f"Unexpected API response: {response.text}") from err
def predict_batch(df, *, timeout=120):
    """Upload a DataFrame as CSV to the batch API and attach its predictions.

    Args:
        df: Input rows; must contain a ``Product_Id`` column. It is not modified.
        timeout: Seconds to wait for the backend before giving up.

    Returns:
        A copy of ``df`` with a ``Predicted_Sales`` column, rounded to 2
        decimals. Rows whose ``Product_Id`` the server did not return are NaN.

    Raises:
        ValueError: If the API answers with a non-200 status, or with a body
            that is not a JSON object.
        requests.RequestException: On network failures or timeouts.
    """
    # The backend expects a multipart/form-data file upload named "file".
    csv_payload = df.to_csv(index=False)
    files = {"file": ("data.csv", csv_payload, "text/csv")}
    response = requests.post(API_BATCH, files=files, timeout=timeout)
    if response.status_code != 200:
        raise ValueError(f"API Error: {response.text}")

    # Server is expected to return a dict {Product_Id: prediction}.
    # NOTE(review): duplicate Product_Ids in the upload would share one value.
    try:
        predictions_dict = response.json()
    except ValueError as err:
        raise ValueError(f"Unexpected API response: {response.text}") from err
    if not isinstance(predictions_dict, dict):
        raise ValueError(f"Unexpected API response: {response.text}")

    df_out = df.copy()
    # JSON object keys are always strings, so look IDs up by their string form.
    # IDs missing from the response become NaN instead of a misleading 0.
    df_out["Predicted_Sales"] = pd.to_numeric(
        df_out["Product_Id"].astype(str).map(predictions_dict),
        errors="coerce",
    ).round(2)
    return df_out
# ------------------------------------------------
# STREAMLIT UI
# ------------------------------------------------
st.title("Superkart Product Sales Prediction")
# ---------------------
# BATCH PREDICTION
# ---------------------
# Columns this page reads directly to derive model features. The backend may
# need further columns. NOTE(review): confirm against the API's schema.
BATCH_REQUIRED_COLUMNS = ("Product_Id", "Store_Establishment_Year")

st.header("Batch Prediction")
uploaded = st.file_uploader("Upload CSV", type=["csv"])
if uploaded is not None:
    try:
        df = pd.read_csv(uploaded)
    except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as e:
        # Report a bad upload instead of crashing the page. st.stop() is not
        # used here because it would also hide the single-prediction form below.
        logger.warning("Could not parse uploaded CSV: %s", e)
        st.error(f"Could not read the uploaded CSV: {e}")
    else:
        st.write(df.head())
        # Show Predict button only after a successful upload
        if st.button("Predict Batch"):
            missing_cols = [col for col in BATCH_REQUIRED_COLUMNS if col not in df.columns]
            if missing_cols:
                st.error(f"Uploaded CSV is missing required column(s): {missing_cols}")
            else:
                try:
                    # --- Add derived columns required by the model ---
                    # Upper-case the prefix, as the single-prediction form does
                    # with the typed Product ID.
                    df["Product_Id_Prefix"] = (
                        df["Product_Id"].astype(str).str.strip().str.upper().str[:2]
                    )
                    current_year = datetime.now().year
                    df["Store_Age"] = current_year - df["Store_Establishment_Year"]
                    df_pred = predict_batch(df)
                    st.success("Batch Prediction Completed!")
                    st.dataframe(df_pred)
                    # ---- DOWNLOAD MERGED FILE ----
                    csv_text = df_pred.to_csv(index=False)
                    st.download_button(
                        "Download CSV", csv_text, "predictions.csv", mime="text/csv"
                    )
                except Exception as e:
                    # UI boundary: keep the traceback in the logs, show the message.
                    logger.exception("Batch prediction failed")
                    st.error(str(e))
# ---------------------
# SINGLE PREDICTION
# ---------------------
st.header("Single Prediction")
st.write("Store and Product details:")
with st.form("single_form"):
    col1, col2 = st.columns(2)
    with col1:
        product_id_input = st.text_input("Product ID (e.g., DR123, FD889A)")
        sugar = st.selectbox("Product Sugar Content", PRODUCT_SUGAR_CONTENT)
        product_type = st.selectbox("Product Type", PRODUCT_TYPES)
        store_id = st.selectbox("Store ID", STORE_IDS)
        store_size = st.selectbox("Store Size", STORE_SIZES)
        store_type = st.selectbox("Store Type", STORE_TYPES)
    with col2:
        weight = st.number_input("Product Weight", min_value=0.0, step=1.0)
        area = st.number_input(
            "Product Allocated Area",
            min_value=0.0,
            max_value=1.0,
            step=0.010,
            format="%.3f",
        )
        mrp = st.number_input("Product MRP", min_value=0.0, step=20.0)
        # Positional args: min_value=1900, max_value=current year, value=2015.
        est_year = st.number_input(
            "Store Establishment Year",
            1900, datetime.now().year,
            2015
        )
        store_city = st.selectbox("Store Location City Type", STORE_CITY_TYPES)
    submit = st.form_submit_button("Predict")

if submit:
    # ---- Extract and validate prefix ----
    # Validation runs before the try-block and uses if/elif rather than
    # st.stop(): st.stop() halts the run via an internal exception, which
    # should not sit inside a catch-all handler.
    product_id = product_id_input.strip().upper()
    prefix = product_id[:2]
    if len(product_id) < 2:
        st.error("Product ID must contain at least 2 characters (prefix) followed by few digits.")
    elif prefix not in PRODUCT_PREFIXES:
        st.error(f"Invalid Product Prefix '{prefix}'. Valid prefixes: {PRODUCT_PREFIXES}")
    else:
        try:
            # Compute store age
            current_year = datetime.now().year
            store_age = current_year - est_year
            # Build payload (raw values; the backend does any encoding).
            # NOTE(review): the batch CSV column is "Product_Id" but this key
            # is "Product_ID"; confirm which spelling the backend expects.
            payload = {
                "Store_Id": store_id,
                "Store_Type": store_type,
                "Store_Size": store_size,
                "Store_Location_City_Type": store_city,
                "Product_ID": product_id,
                "Product_Id_Prefix": prefix,
                "Product_Type": product_type,
                "Product_Sugar_Content": sugar,
                "Store_Age": store_age,
                "Product_Weight": weight,
                "Product_Allocated_Area": area,
                "Product_MRP": mrp,
            }
            pred = predict_single(payload)
            final_pred = round(pred, 2)
            st.success(f"Predicted Sales: {final_pred} units")
        except Exception as e:
            # UI boundary: keep the traceback in the logs, show the message.
            logger.exception("Single prediction failed")
            st.error(str(e))