Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| import requests | |
| import logging | |
| from datetime import datetime | |
| import io | |
# ------------------------------------------------
# CONSTANTS (API endpoints)
# ------------------------------------------------
# Backend Space: Lokiiparihar/backendlokii
# Both endpoints share one root so the host is spelled out only once.
_API_ROOT = "https://Lokiiparihar-backendlokii.hf.space/v1/sales"
API_SINGLE = _API_ROOT + "/predict"
API_BATCH = API_SINGLE + "/batch"
# ------------------------------------------------
# LOGGING
# ------------------------------------------------
# Root logging: INFO and above, timestamped, written to stderr via a
# StreamHandler (a no-op if the root logger is already configured).
_log_handlers = [logging.StreamHandler()]
logging.basicConfig(
    handlers=_log_handlers,
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# ------------------------------------------------
# DIRECT PICKLIST VALUES (raw strings — no encoding)
# ------------------------------------------------
# These are sent to the backend verbatim; the UI offers them as choices.
STORE_IDS = [f"OUT{n:03d}" for n in range(1, 5)]
STORE_TYPES = [
    "Departmental Store",
    "Food Mart",
    "Supermarket Type1",
    "Supermarket Type2",
]
STORE_SIZES = ["High", "Medium", "Small"]
STORE_CITY_TYPES = [f"Tier {n}" for n in range(1, 4)]
PRODUCT_TYPES = [
    "Baking Goods",
    "Breads",
    "Breakfast",
    "Canned",
    "Dairy",
    "Frozen Foods",
    "Fruits and Vegetables",
    "Hard Drinks",
    "Health and Hygiene",
    "Household",
    "Meat",
    "Others",
    "Seafood",
    "Snack Foods",
    "Soft Drinks",
    "Starchy Foods",
]
PRODUCT_SUGAR_CONTENT = ["Low Sugar", "No Sugar", "Regular"]
# Allowed two-letter prefixes at the start of a Product ID.
PRODUCT_PREFIXES = ["DR", "FD", "NC"]
# ------------------------------------------------
# API HELPERS
# ------------------------------------------------
def predict_single(input_data: dict, timeout: float = 120) -> float:
    """Send raw (non-encoded) feature values to the single-prediction API.

    Args:
        input_data: JSON-serialisable dict of feature values, posted as-is
            (category strings are not encoded client-side).
        timeout: Seconds to wait for the backend before giving up. Generous
            by default because a hosted backend may need time to wake up.

    Returns:
        The ``"Prediction"`` field of the JSON response, converted to float.

    Raises:
        ValueError: if the API responds with a non-200 status, or the JSON
            body has no ``"Prediction"`` field.
        requests.RequestException: on connection failure or timeout.
    """
    # Without a timeout, requests would wait indefinitely on a hung backend.
    response = requests.post(API_SINGLE, json=input_data, timeout=timeout)
    if response.status_code != 200:
        raise ValueError(f"API Error: {response.text}")
    result = response.json()
    try:
        prediction = result["Prediction"]
    except (KeyError, TypeError) as err:
        # Surface a readable message instead of a bare KeyError: 'Prediction'.
        raise ValueError(f"API response has no 'Prediction' field: {result!r}") from err
    return float(prediction)
def predict_batch(df, timeout: float = 120):
    """Upload ``df`` as CSV to the batch API and attach the predictions.

    The frame is serialised to CSV (without the index) and sent as the
    multipart form field ``file``. Per the endpoint contract used here, the
    server answers with a JSON object ``{Product_Id: prediction}``.

    Args:
        df: DataFrame to score; must contain a ``Product_Id`` column.
        timeout: Seconds to wait for the backend before giving up.

    Returns:
        A copy of ``df`` with an extra ``Predicted_Sales`` column rounded to
        2 decimals. Rows whose ``Product_Id`` is absent from the response get
        NaN rather than a fabricated 0.

    Raises:
        ValueError: on a non-200 response, or a response that is not a JSON
            object.
        requests.RequestException: on connection failure or timeout.
    """
    # Send file as multipart/form-data
    files = {"file": ("data.csv", df.to_csv(index=False))}
    response = requests.post(API_BATCH, files=files, timeout=timeout)
    if response.status_code != 200:
        raise ValueError(f"API Error: {response.text}")

    predictions_dict = response.json()
    if not isinstance(predictions_dict, dict):
        raise ValueError(f"Unexpected batch API response: {predictions_dict!r}")

    # Map predictions back to DataFrame. JSON object keys are always strings,
    # so look up by the string form of each Product_Id.
    # NOTE(review): results are keyed by Product_Id only, so rows sharing a
    # Product_Id all receive the same value — confirm the server's keying.
    df_out = df.copy()
    df_out["Predicted_Sales"] = (
        df_out["Product_Id"].astype(str).map(predictions_dict).astype(float).round(2)
    )
    return df_out
# ------------------------------------------------
# STREAMLIT UI
# ------------------------------------------------
st.title("Superkart Product Sales Prediction")

# ---------------------
# BATCH PREDICTION
# ---------------------
st.header("Batch Prediction")

# Columns this section reads directly to build the derived model features.
BATCH_REQUIRED_COLUMNS = ["Product_Id", "Store_Establishment_Year"]

uploaded = st.file_uploader("Upload CSV", type=["csv"])
if uploaded is not None:
    try:
        df = pd.read_csv(uploaded)
    except ValueError as e:
        # pandas ParserError/EmptyDataError and UnicodeDecodeError are all
        # ValueError subclasses; report them instead of crashing the app.
        logger.exception("Could not read uploaded CSV")
        st.error(f"Could not read the uploaded CSV: {e}")
        df = None

    if df is not None:
        st.write(df.head())
        missing_cols = [col for col in BATCH_REQUIRED_COLUMNS if col not in df.columns]
        if missing_cols:
            st.error(f"CSV is missing required column(s): {missing_cols}")
        # Show Predict button only after a valid upload
        elif st.button("Predict Batch"):
            try:
                # --- Add derived columns required by the model ---
                # Normalise the prefix the same way the single-prediction form does.
                df["Product_Id_Prefix"] = df["Product_Id"].str.strip().str.upper().str[:2]
                df["Store_Age"] = datetime.now().year - df["Store_Establishment_Year"]

                df_pred = predict_batch(df)
                st.success("Batch Prediction Completed!")
                st.dataframe(df_pred)

                # ---- DOWNLOAD MERGED FILE ----
                # NOTE(review): by default a download-button click reruns the
                # script, so these results vanish afterwards; keep df_pred in
                # st.session_state if they should stay on screen.
                csv = df_pred.to_csv(index=False)
                st.download_button("Download CSV", csv, "predictions.csv")
            except Exception as e:
                # UI boundary: keep the traceback in the logs, show the message.
                logger.exception("Batch prediction failed")
                st.error(str(e))
# ---------------------
# SINGLE PREDICTION
# ---------------------
st.header("Single Prediction")
st.write("Store and Product details:")

with st.form("single_form"):
    col1, col2 = st.columns(2)
    with col1:
        product_id_input = st.text_input("Product ID (e.g., DR123, FD889A)")
        sugar = st.selectbox("Product Sugar Content", PRODUCT_SUGAR_CONTENT)
        product_type = st.selectbox("Product Type", PRODUCT_TYPES)
        store_id = st.selectbox("Store ID", STORE_IDS)
        store_size = st.selectbox("Store Size", STORE_SIZES)
        store_type = st.selectbox("Store Type", STORE_TYPES)
    with col2:
        weight = st.number_input("Product Weight", min_value=0.0, step=1.0)
        area = st.number_input(
            "Product Allocated Area",
            min_value=0.0,
            max_value=1.0,
            step=0.010,
            format="%.3f",
        )
        mrp = st.number_input("Product MRP", min_value=0.0, step=20.0)
        est_year = st.number_input(
            "Store Establishment Year",
            min_value=1900,
            max_value=datetime.now().year,
            value=2015,
        )
        store_city = st.selectbox("Store Location City Type", STORE_CITY_TYPES)
    submit = st.form_submit_button("Predict")

if submit:
    # ---- Extract and validate prefix ----
    # Validation runs before the try block so st.stop() is never inside the
    # broad `except Exception` handler below.
    product_id = product_id_input.strip().upper()
    if len(product_id) < 2:
        st.error("Product ID must contain at least 2 characters (prefix) followed by a few digits.")
        st.stop()
    prefix = product_id[:2]
    if prefix not in PRODUCT_PREFIXES:
        st.error(f"Invalid Product Prefix '{prefix}'. Valid prefixes: {PRODUCT_PREFIXES}")
        st.stop()

    # Store age in whole years, relative to the current calendar year.
    store_age = datetime.now().year - est_year

    # Build payload
    # NOTE(review): the batch path uses the column name "Product_Id"; confirm
    # which spelling the backend expects for the product-ID key below.
    payload = {
        "Store_Id": store_id,
        "Store_Type": store_type,
        "Store_Size": store_size,
        "Store_Location_City_Type": store_city,
        "Product_ID": product_id,
        "Product_Id_Prefix": prefix,
        "Product_Type": product_type,
        "Product_Sugar_Content": sugar,
        "Store_Age": store_age,
        "Product_Weight": weight,
        "Product_Allocated_Area": area,
        "Product_MRP": mrp,
    }

    try:
        pred = predict_single(payload)
    except Exception as e:
        # UI boundary: keep the traceback in the logs, show the message.
        logger.exception("Single prediction failed")
        st.error(str(e))
    else:
        final_pred = round(pred, 2)
        st.success(f"Predicted Sales: {final_pred} units")