bappiahk commited on
Commit
5873aef
·
verified ·
1 Parent(s): 9f93e68

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -19
app.py CHANGED
@@ -1,45 +1,40 @@
1
  import gradio as gr
2
  import pandas as pd
3
  from joblib import load
4
- import os
5
 
6
- # Load the model once when the app starts
7
  rf_model = load('churn_prediction_model.joblib')
8
 
9
  def predict_churn(file):
10
- """
11
- Process the uploaded file and return churn predictions
12
- """
13
  try:
14
- # Read the uploaded file
15
  if file.name.endswith('.csv'):
16
  test_data = pd.read_csv(file.name)
17
  else:
18
  test_data = pd.read_excel(file.name)
19
 
20
- # Ensure 'major_issue' is one-hot encoded
21
  if 'major_issue_Technical Issue' not in test_data.columns:
22
  test_data = pd.get_dummies(test_data, columns=['major_issue'], drop_first=True)
23
 
24
- # Ensure all required columns are present
25
  required_columns = [
26
  'late_payments_last_year', 'missed_payments_last_year', 'plan_tenure',
27
  'num_employees', 'avg_monthly_contribution', 'annual_revenue',
28
  'support_calls_last_year', 'support_engagement_per_year',
29
  'major_issue_Technical Issue'
30
  ]
31
-
32
  for col in required_columns:
33
  if col not in test_data.columns:
34
  test_data[col] = 0
35
 
36
  # Extract features
37
  test_data_features = test_data[required_columns]
38
-
39
  # Make predictions
40
  test_data['predicted_churn'] = rf_model.predict(test_data_features)
41
 
42
- # Prepare output
43
  output_columns = [
44
  'customer_id', 'late_payments_last_year', 'missed_payments_last_year',
45
  'plan_tenure', 'num_employees', 'avg_monthly_contribution',
@@ -49,7 +44,7 @@ def predict_churn(file):
49
  output_data = test_data[output_columns]
50
 
51
  # Save to temporary file
52
- temp_output = "temp_output.csv"
53
  output_data.to_csv(temp_output, index=False)
54
 
55
  return temp_output
@@ -67,13 +62,7 @@ iface = gr.Interface(
67
  ),
68
  outputs=gr.File(label="Download Predictions"),
69
  title="Customer Churn Prediction",
70
- description="""Upload a file with customer data to predict churn probability.
71
- Required columns: customer_id, late_payments_last_year, missed_payments_last_year,
72
- plan_tenure, num_employees, avg_monthly_contribution, annual_revenue,
73
- support_calls_last_year, support_engagement_per_year, major_issue""",
74
- examples=[
75
- ["sample_input.csv"] # Add a sample file if you have one
76
- ],
77
  cache_examples=False
78
  )
79
 
 
1
  import gradio as gr
2
  import pandas as pd
3
  from joblib import load
 
4
 
5
+ # Load the model at startup
6
  rf_model = load('churn_prediction_model.joblib')
7
 
8
  def predict_churn(file):
 
 
 
9
  try:
10
+ # Load the uploaded file
11
  if file.name.endswith('.csv'):
12
  test_data = pd.read_csv(file.name)
13
  else:
14
  test_data = pd.read_excel(file.name)
15
 
16
+ # Ensure 'major_issue' is one-hot encoded if needed
17
  if 'major_issue_Technical Issue' not in test_data.columns:
18
  test_data = pd.get_dummies(test_data, columns=['major_issue'], drop_first=True)
19
 
20
+ # Ensure all training-time columns are present
21
  required_columns = [
22
  'late_payments_last_year', 'missed_payments_last_year', 'plan_tenure',
23
  'num_employees', 'avg_monthly_contribution', 'annual_revenue',
24
  'support_calls_last_year', 'support_engagement_per_year',
25
  'major_issue_Technical Issue'
26
  ]
 
27
  for col in required_columns:
28
  if col not in test_data.columns:
29
  test_data[col] = 0
30
 
31
  # Extract features
32
  test_data_features = test_data[required_columns]
33
+
34
  # Make predictions
35
  test_data['predicted_churn'] = rf_model.predict(test_data_features)
36
 
37
+ # Select output columns
38
  output_columns = [
39
  'customer_id', 'late_payments_last_year', 'missed_payments_last_year',
40
  'plan_tenure', 'num_employees', 'avg_monthly_contribution',
 
44
  output_data = test_data[output_columns]
45
 
46
  # Save to temporary file
47
+ temp_output = "output.csv"
48
  output_data.to_csv(temp_output, index=False)
49
 
50
  return temp_output
 
62
  ),
63
  outputs=gr.File(label="Download Predictions"),
64
  title="Customer Churn Prediction",
65
+ description="Upload your customer data file to predict churn probability",
 
 
 
 
 
 
66
  cache_examples=False
67
  )
68