| | from fastapi import FastAPI, HTTPException |
| | from pydantic import BaseModel |
| | import pickle |
| | import pandas as pd |
| |
|
| | |
# The ASGI application object. The title/description/version below are the
# metadata FastAPI publishes in the generated OpenAPI schema and docs pages.
app = FastAPI(
    version="1.0",
    title="Sepsis Prediction API",
    description=(
        "This FastAPI application provides sepsis predictions "
        "using a machine learning model."
    ),
)
| |
|
| | |
# Load the training-time bundle (a dict with 'model', 'encoder' and 'scaler'
# entries) from the current working directory at import time.
# NOTE: pickle.load can execute arbitrary code -- only ship a file you built.
with open('model_and_key_components.pkl', 'rb') as file:
    loaded_components = pickle.load(file)

# Unpack the three components in a fixed order; a missing key raises KeyError.
loaded_model, loaded_encoder, loaded_scaler = (
    loaded_components[component] for component in ('model', 'encoder', 'scaler')
)
| |
|
| | |
class InputData(BaseModel):
    """Request body for ``POST /predict/``: one patient record of model features.

    The field names and their declaration order mirror the feature columns fed
    to the scaler in ``preprocess_input_data``; keep the two in sync.
    """

    # NOTE(review): the clinical meaning/units of these abbreviated features
    # come from the training dataset and are not documented here -- TODO confirm.
    PRG: int
    PL: float
    PR: float
    SK: float
    TS: int
    M11: float
    BD2: float
    Age: int
| |
|
| | |
class OutputData(BaseModel):
    """Response body for ``POST /predict/``."""

    # Prediction label; make_predictions maps class 0 -> 'Negative', 1 -> 'Positive'.
    Sepsis: str
| |
|
| | |
def preprocess_input_data(input_data: InputData):
    """Scale one request's features and return them as a one-row DataFrame.

    Args:
        input_data: The validated request body.

    Returns:
        A single-row ``pd.DataFrame`` holding the output of
        ``loaded_scaler.transform``, with columns named in ``numerical_cols`` order.
    """
    numerical_cols = ['PRG', 'PL', 'PR', 'SK', 'TS', 'M11', 'BD2', 'Age']

    # Read every feature by name, in the explicit column order, instead of
    # relying on the schema's field declaration order via ``.dict().values()``.
    # Reordering the schema would otherwise silently feed misaligned features
    # to the scaler. This also avoids ``.dict()``, which Pydantic v2 deprecates.
    row = [getattr(input_data, col) for col in numerical_cols]
    input_data_scaled = loaded_scaler.transform([row])

    return pd.DataFrame(input_data_scaled, columns=numerical_cols)
| |
|
| | |
def make_predictions(input_data_scaled_df: pd.DataFrame, model=None):
    """Predict the sepsis label for the first row of a preprocessed frame.

    Args:
        input_data_scaled_df: Features as produced by ``preprocess_input_data``.
        model: Any object exposing ``predict``; defaults to the module-level
            ``loaded_model``. Optional so another estimator can be injected.

    Returns:
        ``'Negative'`` when the predicted class is 0, ``'Positive'`` when it is 1.

    Raises:
        ValueError: If the model predicts a class other than 0 or 1.
    """
    if model is None:
        model = loaded_model
    y_pred = model.predict(input_data_scaled_df)
    sepsis_mapping = {0: 'Negative', 1: 'Positive'}
    try:
        return sepsis_mapping[y_pred[0]]
    except KeyError as err:
        # A bare KeyError stringifies to just the key (e.g. "2"), which is what
        # an API client would see; raise a message that explains the problem.
        raise ValueError(
            f"Unexpected model prediction {y_pred[0]!r}; expected 0 or 1"
        ) from err
| |
|
@app.get("/")
async def root():
    """Landing endpoint: return a welcome string that points users to the docs."""
    return (
        "Welcome to the Sepsis Classification API! This API Provides predictions "
        "for Sepsis based on several medical inputs. To use this API, please "
        "access the API documentation here: "
        "https://rasmodev-sepsis-prediction.hf.space/docs/"
    )
| |
|
@app.post("/predict/", response_model=OutputData)
async def predict_sepsis(input_data: InputData):
    """Return the sepsis prediction for one patient record.

    Any exception raised while preprocessing or predicting is logged with its
    traceback and reported to the client as HTTP 500, with the error message
    as ``detail``.
    """
    import logging  # local import keeps this block self-contained

    # NOTE(review): preprocessing/inference run synchronously inside an async
    # handler and will block the event loop while they run -- consider a plain
    # ``def`` endpoint or a threadpool if latency under load matters.
    try:
        input_data_scaled_df = preprocess_input_data(input_data)
        sepsis_status = make_predictions(input_data_scaled_df)
    except Exception as e:
        # FastAPI turns HTTPException into a plain JSON response without logging
        # the underlying cause, so record the traceback here before replacing it.
        logging.getLogger(__name__).exception("Sepsis prediction failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"Sepsis": sepsis_status}
| | |
if __name__ == "__main__":
    import uvicorn

    # Python equivalent of `uvicorn app:app --host 0.0.0.0 --port 7860 --reload`.
    # reload=True requires the app to be given as a "module:attribute" string.
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)