xeroISB commited on
Commit
0adc24e
·
1 Parent(s): 42f1c88

Commit Updates

Browse files
Files changed (3) hide show
  1. Dockerfile +1 -1
  2. app.py +35 -27
  3. request.py +76 -0
Dockerfile CHANGED
@@ -17,4 +17,4 @@ EXPOSE 8000
17
  ENV MODEL_PATH="/app/cox_model.pkl"
18
 
19
  # Run app.py when the container launches
20
- CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
 
17
  ENV MODEL_PATH="/app/cox_model.pkl"
18
 
19
  # Run app.py when the container launches
20
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8960"]
app.py CHANGED
@@ -1,28 +1,9 @@
1
- # main.py
2
-
3
- from fastapi import FastAPI
4
  from pydantic import BaseModel
5
- import joblib
6
  import pandas as pd
7
- from typing import Dict, Any
8
 
9
- class HRAttritionModel:
10
- def __init__(self, model_path: str):
11
- self.model = joblib.load(model_path)
12
- self.features = ['Age', 'DistanceFromHome', 'Education', 'NumCompaniesWorked', 'PercentSalaryHike',
13
- 'TotalWorkingYears', 'TrainingTimesLastYear', 'WorkLifeBalance', 'YearsInCurrentRole',
14
- 'YearsSinceLastPromotion', 'YearsWithCurrManager', 'BusinessTravel_Travel_Rarely',
15
- 'BusinessTravel_Travel_Frequently', 'Department_Research & Development', 'Department_Sales',
16
- 'EducationField_Life Sciences', 'EducationField_Medical', 'EducationField_Marketing',
17
- 'EducationField_Other', 'EducationField_Technical Degree', 'Gender_Male', 'JobRole_Research Scientist',
18
- 'JobRole_Sales Executive', 'JobRole_Laboratory Technician', 'JobRole_Manufacturing Director',
19
- 'JobRole_Healthcare Representative', 'JobRole_Manager', 'JobRole_Sales Representative',
20
- 'JobRole_Research Director', 'MaritalStatus_Married', 'MaritalStatus_Single', 'OverTime_Yes']
21
-
22
- def predict_survival(self, input_data: Dict[str, Any]) -> Any:
23
- df = pd.DataFrame([input_data], columns=self.features)
24
- survival_function = self.model.predict_survival_function(df)
25
- return survival_function.T
26
 
27
  class AttritionInput(BaseModel):
28
  Age: int
@@ -58,12 +39,39 @@ class AttritionInput(BaseModel):
58
  MaritalStatus_Single: int
59
  OverTime_Yes: int
60
 
61
- app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  model = HRAttritionModel('cox_model.pkl')
64
 
65
  @app.post("/predict")
66
- def predict(attrition_input: AttritionInput):
67
- input_data = attrition_input.dict()
68
- prediction = model.predict_survival(input_data)
69
- return {"prediction": prediction.tolist()}
 
1
+ from fastapi import FastAPI, HTTPException
 
 
2
  from pydantic import BaseModel
 
3
  import pandas as pd
4
+ import joblib
5
 
6
+ app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  class AttritionInput(BaseModel):
9
  Age: int
 
39
  MaritalStatus_Single: int
40
  OverTime_Yes: int
41
 
42
+ class HRAttritionModel:
43
+ def __init__(self, model_path):
44
+ try:
45
+ self.model = joblib.load(model_path)
46
+ except Exception as e:
47
+ raise HTTPException(status_code=500, detail=f"Model loading failed: {str(e)}")
48
+
49
+ def predict(self, input_data):
50
+ try:
51
+ all_columns = [
52
+ 'Age', 'DistanceFromHome', 'Education', 'NumCompaniesWorked', 'PercentSalaryHike',
53
+ 'TotalWorkingYears', 'TrainingTimesLastYear', 'WorkLifeBalance', 'YearsInCurrentRole',
54
+ 'YearsSinceLastPromotion', 'YearsWithCurrManager', 'BusinessTravel_Travel_Rarely',
55
+ 'BusinessTravel_Travel_Frequently', 'Department_Research & Development', 'Department_Sales',
56
+ 'EducationField_Life Sciences', 'EducationField_Medical', 'EducationField_Marketing',
57
+ 'EducationField_Other', 'EducationField_Technical Degree', 'Gender_Male',
58
+ 'JobRole_Research Scientist', 'JobRole_Sales Executive', 'JobRole_Laboratory Technician',
59
+ 'JobRole_Manufacturing Director', 'JobRole_Healthcare Representative', 'JobRole_Manager',
60
+ 'JobRole_Sales Representative', 'JobRole_Research Director', 'MaritalStatus_Married',
61
+ 'MaritalStatus_Single', 'JobRole_Human Resources','OverTime_Yes'
62
+ ]
63
+
64
+ input_df = pd.DataFrame([input_data], columns=all_columns).fillna(0)
65
+ survival_function = self.model.predict_survival_function(input_df)
66
+ survival_values = survival_function.iloc[:, 0].tolist() # Get survival values for the first instance
67
+ return survival_values
68
+ except Exception as e:
69
+ raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
70
 
71
  model = HRAttritionModel('cox_model.pkl')
72
 
73
  @app.post("/predict")
74
+ def predict(input_data: AttritionInput):
75
+ input_dict = input_data.dict()
76
+ prediction = model.predict(input_dict)
77
+ return {"prediction": prediction}
request.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import matplotlib.pyplot as plt
3
+
4
+ # Define the URL for the prediction service
5
+ url = "http://localhost:8960/predict"
6
+
7
+ # Prepare the data for the prediction
8
+ data = {
9
+ "Age": 35,
10
+ "DistanceFromHome": 10,
11
+ "Education": 2,
12
+ "NumCompaniesWorked": 3,
13
+ "PercentSalaryHike": 15,
14
+ "TotalWorkingYears": 10,
15
+ "TrainingTimesLastYear": 3,
16
+ "WorkLifeBalance": 2,
17
+ "YearsInCurrentRole": 4,
18
+ "YearsSinceLastPromotion": 1,
19
+ "YearsWithCurrManager": 2,
20
+ "BusinessTravel_Travel_Rarely": 1,
21
+ "BusinessTravel_Travel_Frequently": 0,
22
+ "Department_Research": 1,
23
+ "Department_Sales": 0,
24
+ "EducationField_Life_Sciences": 1,
25
+ "EducationField_Medical": 0,
26
+ "EducationField_Marketing": 0,
27
+ "EducationField_Other": 0,
28
+ "EducationField_Technical_Degree": 0,
29
+ "Gender_Male": 1,
30
+ "JobRole_Research_Scientist": 1,
31
+ "JobRole_Sales_Executive": 0,
32
+ "JobRole_Laboratory_Technician": 0,
33
+ "JobRole_Manufacturing_Director": 0,
34
+ "JobRole_Healthcare_Representative": 0,
35
+ "JobRole_Manager": 0,
36
+ "JobRole_Sales_Representative": 0,
37
+ "JobRole_Research_Director": 0,
38
+ "MaritalStatus_Married": 1,
39
+ "MaritalStatus_Single": 0,
40
+ "OverTime_Yes": 0
41
+ }
42
+
43
+ # Make the POST request
44
+ try:
45
+ response = requests.post(url, json=data)
46
+ response.raise_for_status() # Raise an error for bad responses
47
+ prediction = response.json() # Parse the JSON response
48
+
49
+ # Check if the prediction contains the expected key and is a list
50
+ if isinstance(prediction, dict) and 'prediction' in prediction and isinstance(prediction['prediction'], list):
51
+ survival_probabilities = prediction['prediction']
52
+ else:
53
+ raise ValueError("Unexpected response format: {}".format(prediction))
54
+
55
+ # Create a list of years based on the number of predictions
56
+ years = list(range(1, len(survival_probabilities) + 1))
57
+
58
+ # Plot the data
59
+ plt.figure(figsize=(10, 6))
60
+ plt.plot(years, survival_probabilities, marker='o', linestyle='-', color='b')
61
+ plt.xlabel('Years')
62
+ plt.ylabel('Survival Probability')
63
+ plt.title('Employee Survival Probability Over Time')
64
+ plt.grid(True)
65
+ plt.xticks(years)
66
+ plt.ylim(0, 1)
67
+
68
+ # Show the plot
69
+ plt.show()
70
+
71
+ except requests.exceptions.RequestException as e:
72
+ print("An error occurred while making the request:", e)
73
+ except ValueError as ve:
74
+ print("Value error:", ve)
75
+ except Exception as ex:
76
+ print("An unexpected error occurred:", ex)