Dattaluri committed on
Commit
1af38a8
·
verified ·
1 Parent(s): baa4dde

Upload backend app files

Browse files
Files changed (3) hide show
  1. Dockerfile +11 -7
  2. app.py +49 -46
  3. requirements.txt +4 -2
Dockerfile CHANGED
@@ -4,13 +4,17 @@ FROM python:3.9-slim
4
  # Set the working directory inside the container to /app
5
  WORKDIR /app
6
 
7
- # Copy all files from the current directory on the host to the container's /app directory
8
- COPY . .
 
 
 
9
 
10
- # Install Python dependencies listed in requirements.txt
11
- RUN pip3 install -r requirements.txt
12
 
13
- # Define the command to run the Streamlit app on port 8501 and make it accessible externally
14
- CMD ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
15
 
16
- # NOTE: Disable XSRF protection for easier external access in order to make batch predictions
 
 
4
  # Set the working directory inside the container to /app
5
  WORKDIR /app
6
 
7
+ # Copy the requirements.txt file into the working directory
8
+ COPY requirements.txt .
9
+
10
+ # Install Python dependencies
11
+ RUN pip install --no-cache-dir -r requirements.txt
12
 
13
+ # Copy the rest of the application files into the working directory
14
+ COPY . .
15
 
16
+ # Expose port 5000
17
+ EXPOSE 5000
18
 
19
+ # Define the command to run the Flask application
20
+ CMD ["flask", "run", "--host=0.0.0.0"]
app.py CHANGED
@@ -1,51 +1,54 @@
1
- import streamlit as st
2
- import requests
3
- import json
4
  import pandas as pd
5
 
6
- # Define the backend API endpoint URL
7
- # Replace with the actual URL of your deployed backend API
8
- BACKEND_API_URL = "https://huggingface.co/spaces/Dattaluri/Great_Learning"
9
-
10
- st.title("SuperKart Sales Forecast")
11
-
12
- st.write("Enter the product and store details to get a sales forecast.")
13
-
14
- # Create input fields for features
15
- product_weight = st.number_input("Product Weight", min_value=0.0)
16
- product_sugar_content = st.selectbox("Product Sugar Content", ["Low Sugar", "Regular", "No Sugar", "reg"])
17
- product_allocated_area = st.number_input("Product Allocated Area", min_value=0.0)
18
- product_type = st.selectbox("Product Type", ["Frozen Foods", "Dairy", "Canned", "Baking Goods", "Health and Hygiene", "Meat", "Snack Foods", "Hard Drinks", "Breakfast", "Household", "Breads", "Starchy Foods", "Fruits and Vegetables", "Seafood", "Others", "Soft Drinks"])
19
- product_mrp = st.number_input("Product MRP", min_value=0.0)
20
- store_size = st.selectbox("Store Size", ["Medium", "High", "Small"])
21
- store_location_city_type = st.selectbox("Store Location City Type", ["Tier 2", "Tier 1", "Tier 3"])
22
- store_type = st.selectbox("Store Type", ["Supermarket Type2", "Departmental Store", "Supermarket Type1", "Food Mart"])
23
- store_establishment_year = st.number_input("Store Establishment Year", min_value=1985, max_value=2025, step=1)
24
-
25
- # Create a button to trigger the prediction
26
- if st.button("Get Sales Forecast"):
27
- # Prepare the input data as a dictionary
28
- input_data = {
29
- "Product_Weight": product_weight,
30
- "Product_Sugar_Content": product_sugar_content,
31
- "Product_Allocated_Area": product_allocated_area,
32
- "Product_Type": product_type,
33
- "Product_MRP": product_mrp,
34
- "Store_Size": store_size,
35
- "Store_Location_City_Type": store_location_city_type,
36
- "Store_Type": store_type,
37
- "Store_Establishment_Year": store_establishment_year # Include original year for 'Store_Age' calculation in backend
38
- }
39
-
40
- # Send the input data to the backend API for prediction
41
  try:
42
- response = requests.post(f"{https://huggingface.co/spaces/Dattaluri/Great_Learning/blob/main/app.py}/predict", json=input_data)
43
- response.raise_for_status() # Raise an exception for bad status codes
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- prediction = response.json()["prediction"][0]
46
- st.success(f"The estimated sales total is: ${prediction:.2f}")
47
 
48
- except requests.exceptions.RequestException as e:
49
- st.error(f"Error communicating with the backend API: {e}")
50
- except KeyError:
51
- st.error("Invalid response from the backend API.")
 
1
+ from flask import Flask, request, jsonify
2
+ import joblib
 
3
  import pandas as pd
4
 
5
+ app = Flask(__name__)
6
+
7
+ # Load the serialized full pipeline
8
+ try:
9
+ full_pipeline = joblib.load('deployment_files/SuperKart_model_v1_0.joblib')
10
+ # Get the list of columns from the training data used by the pipeline
11
+ # Adjusting to correctly get column names after one-hot encoding and scaling
12
+ # This part might need refinement based on the exact structure of your pipeline's preprocessor
13
+ # A safer approach is to save the column names of X_train during preprocessing
14
+ # For now, let's assume the order is numerical followed by one-hot encoded categorical
15
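+ # We need to get the feature names from the OneHotEncoder and combine them with the numerical names (NOTE(review): get_feature_names_out returns post-encoding names; confirm they match the raw request keys selected in predict)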
16
+ categorical_feature_names = full_pipeline.named_steps['preprocessor'].transformers_[0][1].get_feature_names_out(
17
+ full_pipeline.named_steps['preprocessor'].transformers_[0][2]
18
+ )
19
+ numerical_feature_names = full_pipeline.named_steps['scaler'].feature_names_in_
20
+
21
+ # Combine numerical and categorical feature names in the correct order
22
+ pipeline_columns = list(numerical_feature_names) + list(categorical_feature_names)
23
+
24
+ except Exception as e:
25
+ full_pipeline = None
26
+ print(f"Error loading pipeline: {e}")
27
+
28
+ @app.route('/predict', methods=['POST'])
29
+ def predict():
30
+ if full_pipeline is None:
31
+ return jsonify({'error': 'Model not loaded'}), 500
32
+
 
 
 
 
 
 
 
33
  try:
34
+ data = request.get_json(force=True)
35
+
36
+ # Wrap the JSON payload in a single-row DataFrame; columns are reordered to the training order below
37
+ input_df = pd.DataFrame([data])
38
+
39
+ # Reorder columns to match the order expected by the pipeline
40
+ # This assumes all expected columns are present in the input data; a missing column raises KeyError, which the handler below returns as a 400 error
41
+ input_df = input_df[pipeline_columns]
42
+
43
+
44
+ # Make prediction
45
+ prediction = full_pipeline.predict(input_df)
46
+
47
+ # Return prediction as JSON
48
+ return jsonify({'prediction': prediction.tolist()})
49
 
50
+ except Exception as e:
51
+ return jsonify({'error': str(e)}), 400
52
 
53
+ if __name__ == '__main__':
54
+ app.run(debug=True, host='0.0.0.0', port=5000)
 
 
requirements.txt CHANGED
@@ -1,3 +1,5 @@
1
- streamlit==1.36.0
2
- requests==2.32.3
3
  pandas==2.2.2
 
 
 
1
+ Flask==3.0.3
2
+ joblib==1.4.2
3
  pandas==2.2.2
4
+ scikit-learn==1.6.1
5
+ numpy==2.0.2