"""Streamlit front-end for the SuperKart product sales prediction backend.

Offers two flows:
* Batch: upload a CSV, add derived feature columns, send it to the batch
  endpoint and download the predictions.
* Single: fill in a form for one product/store pair and get one prediction.
"""
import io
import logging
from datetime import datetime

import numpy as np
import pandas as pd
import requests
import streamlit as st

# ------------------------------------------------
# CONSTANTS (API endpoints)
# ------------------------------------------------
# Backend Space: Lokiiparihar/backendlokii
API_SINGLE = "https://Lokiiparihar-backendlokii.hf.space/v1/sales/predict"
API_BATCH = "https://Lokiiparihar-backendlokii.hf.space/v1/sales/predict/batch"

# (connect, read) timeout in seconds for backend calls. Without a timeout,
# requests.post() can block forever and freeze the user's session. The read
# timeout is generous in case the hosted backend is slow to respond
# (e.g. a cold start) — tune as needed.
REQUEST_TIMEOUT = (10, 120)

# ------------------------------------------------
# LOGGING
# ------------------------------------------------
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

# ------------------------------------------------
# DIRECT PICKLIST VALUES (raw strings — no encoding)
# ------------------------------------------------
STORE_IDS = ["OUT001", "OUT002", "OUT003", "OUT004"]
STORE_TYPES = [
    "Departmental Store",
    "Food Mart",
    "Supermarket Type1",
    "Supermarket Type2",
]
STORE_SIZES = ["High", "Medium", "Small"]
STORE_CITY_TYPES = ["Tier 1", "Tier 2", "Tier 3"]
PRODUCT_TYPES = [
    "Baking Goods", "Breads", "Breakfast", "Canned", "Dairy",
    "Frozen Foods", "Fruits and Vegetables", "Hard Drinks",
    "Health and Hygiene", "Household", "Meat", "Others",
    "Seafood", "Snack Foods", "Soft Drinks", "Starchy Foods",
]
PRODUCT_SUGAR_CONTENT = ["Low Sugar", "No Sugar", "Regular"]

# Allowed prefixes
PRODUCT_PREFIXES = ["DR", "FD", "NC"]

# Columns the batch flow reads itself before calling the API.
BATCH_REQUIRED_COLUMNS = ("Product_Id", "Store_Establishment_Year")


# ------------------------------------------------
# API HELPERS
# ------------------------------------------------
def predict_single(input_data):
    """Send one raw (non-encoded) record to the single-prediction endpoint.

    Args:
        input_data: JSON-serialisable dict of feature values.

    Returns:
        The response's ``"Prediction"`` value converted to ``float``.

    Raises:
        ValueError: if the API answers with a non-200 status, or the response
            has no usable ``"Prediction"`` field.
        requests.RequestException: on connection errors or timeout.
    """
    response = requests.post(API_SINGLE, json=input_data, timeout=REQUEST_TIMEOUT)
    if response.status_code != 200:
        raise ValueError(f"API Error: {response.text}")

    result = response.json()
    try:
        return float(result["Prediction"])
    except (KeyError, TypeError) as err:
        # Unexpected response shape: report it instead of a bare KeyError.
        raise ValueError(f"Unexpected API response: {result!r}") from err


def predict_batch(df):
    """Upload *df* as CSV to the batch endpoint and attach the predictions.

    The server returns a JSON object mapping ``Product_Id`` to prediction.

    Args:
        df: DataFrame containing a ``Product_Id`` column plus the feature
            columns the backend expects.

    Returns:
        A copy of *df* with a ``Predicted_Sales`` column rounded to 2 decimals.
        Rows whose ``Product_Id`` is absent from the response get NaN, so a
        missing prediction is not mistaken for zero sales.

    Raises:
        ValueError: if the API answers with a non-200 status.
        requests.RequestException: on connection errors or timeout.
    """
    # Send the CSV as a multipart/form-data file upload.
    files = {"file": ("data.csv", df.to_csv(index=False), "text/csv")}
    response = requests.post(API_BATCH, files=files, timeout=REQUEST_TIMEOUT)
    if response.status_code != 200:
        raise ValueError(f"API Error: {response.text}")

    # NOTE(review): predictions are keyed by Product_Id, so duplicate IDs in
    # the upload all receive the same value — confirm IDs are unique per row.
    predictions_dict = response.json()

    def _lookup(product_id):
        # JSON object keys are always strings, so look up by str(product_id).
        value = predictions_dict.get(str(product_id))
        return np.nan if value is None else round(float(value), 2)

    df_out = df.copy()
    df_out["Predicted_Sales"] = df_out["Product_Id"].map(_lookup)
    return df_out


# ------------------------------------------------
# INPUT PREPARATION
# ------------------------------------------------
def _add_derived_columns(df):
    """Return a copy of *df* with the engineered columns the model expects.

    Adds ``Product_Id_Prefix`` (first two characters of the normalised
    Product_Id) and ``Store_Age`` (current year minus establishment year).

    Raises:
        ValueError: if a column in ``BATCH_REQUIRED_COLUMNS`` is missing.
    """
    missing = [col for col in BATCH_REQUIRED_COLUMNS if col not in df.columns]
    if missing:
        raise ValueError(f"Uploaded CSV is missing required column(s): {missing}")

    df = df.copy()
    # Normalise like the single-prediction flow (strip + upper-case).
    df["Product_Id_Prefix"] = df["Product_Id"].str.strip().str.upper().str[:2]
    df["Store_Age"] = datetime.now().year - df["Store_Establishment_Year"]
    return df


def _parse_product_id(raw):
    """Normalise a user-entered Product ID and extract its two-letter prefix.

    Returns:
        ``(product_id, prefix)`` with *product_id* stripped and upper-cased.

    Raises:
        ValueError: with a user-facing message if the ID is too short or its
            prefix is not in ``PRODUCT_PREFIXES``.
    """
    product_id = raw.strip().upper()
    if len(product_id) < 2:
        raise ValueError(
            "Product ID must contain at least 2 characters (prefix) followed by few digits."
        )
    prefix = product_id[:2]
    if prefix not in PRODUCT_PREFIXES:
        raise ValueError(
            f"Invalid Product Prefix '{prefix}'. Valid prefixes: {PRODUCT_PREFIXES}"
        )
    return product_id, prefix


# ------------------------------------------------
# STREAMLIT UI SECTIONS
# ------------------------------------------------
def _render_batch_section():
    """Render the CSV upload + batch prediction UI."""
    st.header("Batch Prediction")
    uploaded = st.file_uploader("Upload CSV", type=["csv"])
    if uploaded is None:
        return

    try:
        df = pd.read_csv(uploaded)
    except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as e:
        logger.warning("Could not parse uploaded CSV: %s", e)
        st.error(f"Could not read CSV: {e}")
        return

    st.write(df.head())

    # Show Predict button only after upload
    if not st.button("Predict Batch"):
        return

    try:
        df_pred = predict_batch(_add_derived_columns(df))
    except Exception as e:  # UI boundary: report any failure on the page
        logger.exception("Batch prediction failed")
        st.error(str(e))
        return

    st.success("Batch Prediction Completed!")
    st.dataframe(df_pred)

    # ---- DOWNLOAD MERGED FILE ----
    st.download_button("Download CSV", df_pred.to_csv(index=False), "predictions.csv")


def _render_single_section():
    """Render the single-record form and show its prediction on submit."""
    st.header("Single Prediction")
    st.write("Store and Product details:")

    with st.form("single_form"):
        col1, col2 = st.columns(2)

        with col1:
            product_id_input = st.text_input("Product ID (e.g., DR123, FD889A)")
            sugar = st.selectbox("Product Sugar Content", PRODUCT_SUGAR_CONTENT)
            product_type = st.selectbox("Product Type", PRODUCT_TYPES)
            store_id = st.selectbox("Store ID", STORE_IDS)
            store_size = st.selectbox("Store Size", STORE_SIZES)
            store_type = st.selectbox("Store Type", STORE_TYPES)

        with col2:
            weight = st.number_input("Product Weight", min_value=0.0, step=1.0)
            area = st.number_input(
                "Product Allocated Area",
                min_value=0.0,
                max_value=1.0,
                step=0.010,
                format="%.3f",
            )
            mrp = st.number_input("Product MRP", min_value=0.0, step=20.0)
            # Positional args: min 1900, max current year, default 2015.
            est_year = st.number_input(
                "Store Establishment Year", 1900, datetime.now().year, 2015
            )
            store_city = st.selectbox("Store Location City Type", STORE_CITY_TYPES)

        submit = st.form_submit_button("Predict")

    if not submit:
        return

    try:
        product_id, prefix = _parse_product_id(product_id_input)
    except ValueError as e:
        st.error(str(e))
        return

    # Same derived feature as the batch flow.
    store_age = datetime.now().year - est_year

    # NOTE(review): the batch CSV uses "Product_Id" but this payload sends
    # "Product_ID" — confirm which key the single-prediction endpoint expects.
    payload = {
        "Store_Id": store_id,
        "Store_Type": store_type,
        "Store_Size": store_size,
        "Store_Location_City_Type": store_city,
        "Product_ID": product_id,
        "Product_Id_Prefix": prefix,
        "Product_Type": product_type,
        "Product_Sugar_Content": sugar,
        "Store_Age": store_age,
        "Product_Weight": weight,
        "Product_Allocated_Area": area,
        "Product_MRP": mrp,
    }

    try:
        pred = predict_single(payload)
    except Exception as e:  # UI boundary: report any failure on the page
        logger.exception("Single prediction failed")
        st.error(str(e))
        return

    st.success(f"Predicted Sales: {round(pred, 2)} units")


# ------------------------------------------------
# STREAMLIT UI
# ------------------------------------------------
st.title("Superkart Product Sales Prediction")
_render_batch_section()
_render_single_section()