dpanchali commited on
Commit
1cdd86e
·
verified ·
1 Parent(s): 50a99a9

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. app.py +17 -16
  2. requirements.txt +0 -1
app.py CHANGED
@@ -3,33 +3,34 @@ import pandas as pd
3
  from flask import Flask, request, jsonify
4
 
5
  # Initialize Flask app with a name
6
- churn_predictor_api = Flask("Customer Churn Predictor")
7
 
8
  # Load the trained churn prediction model
9
  model = joblib.load("churn_prediction_model_v1_0.joblib")
10
 
11
  # Define a route for the home page
12
- @churn_predictor_api.get('/')
13
  def home():
14
- return "Welcome to the Customer Churn Prediction API!"
15
 
16
  # Define an endpoint to predict churn for a single customer
17
- @churn_predictor_api.post('/v1/customer')
18
  def predict_churn():
19
  # Get JSON data from the request
20
  customer_data = request.get_json()
21
 
22
  # Extract relevant customer features from the input data
23
  sample = {
24
- 'CreditScore': customer_data['CreditScore'],
25
- 'Geography': customer_data['Geography'],
26
- 'Age': customer_data['Age'],
27
- 'Tenure': customer_data['Tenure'],
28
- 'Balance': customer_data['Balance'],
29
- 'NumOfProducts': customer_data['NumOfProducts'],
30
- 'HasCrCard': customer_data['HasCrCard'],
31
- 'IsActiveMember': customer_data['IsActiveMember'],
32
- 'EstimatedSalary': customer_data['EstimatedSalary']
 
33
  }
34
 
35
  # Convert the extracted data into a DataFrame
@@ -45,7 +46,7 @@ def predict_churn():
45
  return jsonify({'Prediction': prediction_label})
46
 
47
  # Define an endpoint to predict churn for a batch of customers
48
- @churn_predictor_api.post('/v1/customerbatch')
49
  def predict_churn_batch():
50
  # Get the uploaded CSV file from the request
51
  file = request.files['file']
@@ -57,10 +58,10 @@ def predict_churn_batch():
57
  predictions = [
58
  'Churn' if x == 1
59
  else "Not Churn"
60
- for x in model.predict(input_data.drop("CustomerId",axis=1)).tolist()
61
  ]
62
 
63
- cust_id_list = input_data.CustomerId.values.tolist()
64
  output_dict = dict(zip(cust_id_list, predictions))
65
 
66
  return output_dict
 
3
  from flask import Flask, request, jsonify
4
 
5
  # Initialize Flask app with a name
6
+ app = Flask("Telecom Customer Churn Predictor")
7
 
8
  # Load the trained churn prediction model
9
  model = joblib.load("churn_prediction_model_v1_0.joblib")
10
 
11
  # Define a route for the home page
12
+ @app.get('/')
13
  def home():
14
+ return "Welcome to the Telecom Customer Churn Prediction API"
15
 
16
  # Define an endpoint to predict churn for a single customer
17
+ @app.post('/v1/customer')
18
  def predict_churn():
19
  # Get JSON data from the request
20
  customer_data = request.get_json()
21
 
22
  # Extract relevant customer features from the input data
23
  sample = {
24
+ 'SeniorCitizen': customer_data['SeniorCitizen'],
25
+ 'Partner': customer_data['Partner'],
26
+ 'Dependents': customer_data['Dependents'],
27
+ 'tenure': customer_data['tenure'],
28
+ 'PhoneService': customer_data['PhoneService'],
29
+ 'InternetService': customer_data['InternetService'],
30
+ 'Contract': customer_data['Contract'],
31
+ 'PaymentMethod': customer_data['PaymentMethod'],
32
+ 'MonthlyCharges': customer_data['MonthlyCharges'],
33
+ 'TotalCharges': customer_data['TotalCharges']
34
  }
35
 
36
  # Convert the extracted data into a DataFrame
 
46
  return jsonify({'Prediction': prediction_label})
47
 
48
  # Define an endpoint to predict churn for a batch of customers
49
+ @app.post('/v1/customerbatch')
50
  def predict_churn_batch():
51
  # Get the uploaded CSV file from the request
52
  file = request.files['file']
 
58
  predictions = [
59
  'Churn' if x == 1
60
  else "Not Churn"
61
+ for x in model.predict(input_data.drop("customerID",axis=1)).tolist()
62
  ]
63
 
64
+ cust_id_list = input_data.customerID.values.tolist()
65
  output_dict = dict(zip(cust_id_list, predictions))
66
 
67
  return output_dict
requirements.txt CHANGED
@@ -8,4 +8,3 @@ flask==2.2.2
8
  gunicorn==20.1.0
9
  requests==2.28.1
10
  uvicorn[standard]
11
- streamlit==1.43.2
 
8
  gunicorn==20.1.0
9
  requests==2.28.1
10
  uvicorn[standard]