# Source: xeroISB, commit 9091885
import logging
import os

import joblib
import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# ASGI application instance; the /predict route defined later in this file is
# registered on it via the @app.post decorator.
app = FastAPI()
class AttritionInput(BaseModel):
    """Request body for ``POST /predict``: one employee's feature vector.

    The first block of fields are numeric features passed to the model as-is.
    The remaining fields are presumably 0/1 one-hot indicators; note that their
    underscore spellings differ from the model's column names (for example
    ``Department_Research`` vs ``Department_Research & Development``), so they
    must be translated before scoring.
    """

    # Numeric features.
    Age: int
    DistanceFromHome: int
    Education: int
    NumCompaniesWorked: int
    PercentSalaryHike: int
    TotalWorkingYears: int
    TrainingTimesLastYear: int
    WorkLifeBalance: int
    YearsInCurrentRole: int
    YearsSinceLastPromotion: int
    YearsWithCurrManager: int
    # One-hot encoded categorical features (presumably 0 or 1).
    BusinessTravel_Travel_Rarely: int
    BusinessTravel_Travel_Frequently: int
    Department_Research: int
    Department_Sales: int
    EducationField_Life_Sciences: int
    EducationField_Medical: int
    EducationField_Marketing: int
    EducationField_Other: int
    EducationField_Technical_Degree: int
    Gender_Male: int
    JobRole_Research_Scientist: int
    JobRole_Sales_Executive: int
    JobRole_Laboratory_Technician: int
    JobRole_Manufacturing_Director: int
    JobRole_Healthcare_Representative: int
    JobRole_Manager: int
    JobRole_Sales_Representative: int
    JobRole_Research_Director: int
    MaritalStatus_Married: int
    MaritalStatus_Single: int
    OverTime_Yes: int
    # The model is also fitted on a 'JobRole_Human Resources' column that the
    # schema previously had no way to set. Optional and defaulting to 0, so
    # existing requests that omit it are accepted and scored exactly as before.
    JobRole_Human_Resources: int = 0
class HRAttritionModel:
    """Wrapper around a serialized survival model used to score attrition risk.

    The loaded object must expose ``predict_survival_function(DataFrame)``
    returning a DataFrame whose columns correspond to the scored rows.
    NOTE(review): presumably a lifelines ``CoxPHFitter`` (the default file is
    ``cox_model.pkl``) — confirm.
    """

    # Feature columns the model was fitted on, in training order. These are
    # the raw one-hot names, some of which contain spaces or '&'.
    MODEL_COLUMNS = (
        'Age', 'DistanceFromHome', 'Education', 'NumCompaniesWorked', 'PercentSalaryHike',
        'TotalWorkingYears', 'TrainingTimesLastYear', 'WorkLifeBalance', 'YearsInCurrentRole',
        'YearsSinceLastPromotion', 'YearsWithCurrManager', 'BusinessTravel_Travel_Rarely',
        'BusinessTravel_Travel_Frequently', 'Department_Research & Development', 'Department_Sales',
        'EducationField_Life Sciences', 'EducationField_Medical', 'EducationField_Marketing',
        'EducationField_Other', 'EducationField_Technical Degree', 'Gender_Male',
        'JobRole_Research Scientist', 'JobRole_Sales Executive', 'JobRole_Laboratory Technician',
        'JobRole_Manufacturing Director', 'JobRole_Healthcare Representative', 'JobRole_Manager',
        'JobRole_Sales Representative', 'JobRole_Research Director', 'MaritalStatus_Married',
        'MaritalStatus_Single', 'JobRole_Human Resources', 'OverTime_Yes',
    )

    # Request field names cannot contain spaces or '&', so several one-hot
    # features arrive spelled with underscores. Without this translation they
    # never matched a model column and were silently replaced by 0.
    FIELD_ALIASES = {
        'Department_Research': 'Department_Research & Development',
        'EducationField_Life_Sciences': 'EducationField_Life Sciences',
        'EducationField_Technical_Degree': 'EducationField_Technical Degree',
        'JobRole_Research_Scientist': 'JobRole_Research Scientist',
        'JobRole_Sales_Executive': 'JobRole_Sales Executive',
        'JobRole_Laboratory_Technician': 'JobRole_Laboratory Technician',
        'JobRole_Manufacturing_Director': 'JobRole_Manufacturing Director',
        'JobRole_Healthcare_Representative': 'JobRole_Healthcare Representative',
        'JobRole_Sales_Representative': 'JobRole_Sales Representative',
        'JobRole_Research_Director': 'JobRole_Research Director',
        'JobRole_Human_Resources': 'JobRole_Human Resources',
    }

    # Debug output goes through logging instead of print so that per-request
    # employee data is not written to stdout unconditionally.
    _log = logging.getLogger(__name__)

    def __init__(self, model_path):
        """Load the serialized model from *model_path* with ``joblib.load``.

        Raises:
            HTTPException: status 500 if loading fails for any reason.
        """
        try:
            # joblib.load unpickles arbitrary objects: only point this at a
            # trusted, deployer-controlled file.
            self.model = joblib.load(model_path)
        except Exception as e:
            # NOTE(review): HTTPException is kept for compatibility, although
            # construction normally happens at import time, outside any request.
            raise HTTPException(status_code=500, detail=f"Model loading failed: {str(e)}") from e

    def predict(self, input_data):
        """Score a single employee record and return its survival curve values.

        Args:
            input_data: mapping of feature name to value for one employee. Keys
                may use either the model column names or the underscore
                spellings listed in ``FIELD_ALIASES``. Keys that match no model
                column are ignored; model columns with no value default to 0.

        Returns:
            The first column of the model's survival-function DataFrame as a
            list (presumably survival probabilities along the model timeline).

        Raises:
            HTTPException: status 500 wrapping any error raised while building
                the feature frame or scoring it.
        """
        try:
            features = {
                self.FIELD_ALIASES.get(name, name): value
                for name, value in input_data.items()
            }
            # Selecting by MODEL_COLUMNS fixes the column set and order; any
            # column absent from ``features`` becomes NaN and is zero-filled.
            input_df = pd.DataFrame([features], columns=list(self.MODEL_COLUMNS)).fillna(0)
            self._log.debug("Model input:\n%s", input_df)
            survival_function = self.model.predict_survival_function(input_df)
            self._log.debug("Survival function:\n%s", survival_function)
            # One column per scored row; we scored exactly one.
            survival_values = survival_function.iloc[:, 0].tolist()
            return survival_values
        except Exception as e:
            # NOTE(review): str(e) exposes internal error text to API clients.
            raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e
# Load the model once at import time so every request reuses it. The path
# defaults to 'cox_model.pkl', which is resolved against the process working
# directory; set ATTRITION_MODEL_PATH to load from elsewhere (e.g. when the
# server is started from a different directory).
model = HRAttritionModel(os.environ.get("ATTRITION_MODEL_PATH", 'cox_model.pkl'))
@app.post("/predict")
def predict(input_data: AttritionInput):
    """Score one employee record and return ``{"prediction": <values>}``.

    The validated request body is converted to a plain dict and passed to the
    module-level ``model``; failures there are reported to the client as an
    HTTP 500 by ``HRAttritionModel.predict``.
    """
    # Pydantic v2 renamed ``.dict()`` to ``.model_dump()`` and deprecates the
    # old name; prefer the new API when present so this works on v1 and v2.
    if hasattr(input_data, "model_dump"):
        input_dict = input_data.model_dump()
    else:
        input_dict = input_data.dict()
    prediction = model.predict(input_dict)
    return {"prediction": prediction}