File size: 4,100 Bytes
07eac76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import strawberry
from strawberry.asgi import GraphQL

import pandas as pd
import joblib
from sklearn.pipeline import Pipeline
from sklearn.preprocessing._label import LabelEncoder

import httpx
from io import BytesIO

from typing import Tuple, List, Optional, Union
from enum import Enum


from config import RANDOM_FOREST_URL, XGBOOST_URL, ENCODER_URL

import logging


# API input features

@strawberry.enum
class ModelChoice(Enum):
    """GraphQL enum of selectable classifiers.

    Each member's *value* is the download URL (from config) of that
    model's joblib-serialized pipeline; `predict_sepsis` passes it
    straight to `load_pipeline`.
    """
    RandomForestClassifier = RANDOM_FOREST_URL
    XGBoostClassifier = XGBOOST_URL


@strawberry.input
class SepsisFeatures:
    """GraphQL input type carrying a batch of patient feature vectors.

    Each field is a parallel list (one entry per patient); the lists are
    zipped column-wise into a pandas DataFrame in `pipeline_classifier`,
    so all lists must have the same length. Field names presumably mirror
    the training dataset's column names — TODO confirm their meanings
    against the dataset documentation.
    """
    prg: List[int]
    pl: List[int]
    pr: List[int]
    sk: List[int]
    ts: List[int]
    m11: List[float]
    bd2: List[float]
    age: List[int]
    insurance: List[int]


@strawberry.type
class Url:
    """GraphQL output type grouping the service's artifact URLs.

    NOTE(review): this class is not constructed anywhere in this file;
    `url_to_data`/`load_pipeline` annotate plain string URLs with it,
    which does not match this three-field shape — verify intended use.
    """
    url: str
    pipeline_url: str
    encoder_url: str


@strawberry.type
class ResultData:
    """Per-batch prediction payload: parallel lists, one entry per patient."""
    # Decoded class labels from LabelEncoder.inverse_transform.
    prediction: List[str]
    # Confidence of the predicted class, as a percentage rounded to 2 dp.
    probability: List[float]


@strawberry.type
class PredictionResponse:
    """Successful response envelope (execution_code == 1)."""
    execution_msg: str
    execution_code: int
    result: ResultData


@strawberry.type
class ErrorResponse:
    """Failure response envelope (execution_code == 0)."""
    execution_msg: str
    execution_code: int
    # Human-readable failure description; None before an error is recorded.
    error: Optional[str]


# Root-logger setup: only ERROR and above are emitted, with a timestamp.
logging.basicConfig(level=logging.ERROR,
                    format='%(asctime)s - %(levelname)s - %(message)s')


async def url_to_data(url: str) -> BytesIO:
    """Download the resource at *url* and return its bytes as a BytesIO.

    Args:
        url: Absolute HTTP(S) URL of a serialized artifact (pipeline or
            encoder). Callers pass plain strings (e.g. enum values from
            ModelChoice), so the parameter is annotated `str` — the old
            `Url` annotation named a three-field GraphQL type that is
            never what arrives here.

    Returns:
        BytesIO wrapping the full response body, ready for joblib.load.

    Raises:
        httpx.HTTPStatusError: If the server responds with a 4xx/5xx.
        httpx.HTTPError: On connection or transport failures.
    """
    async with httpx.AsyncClient() as client:
        response = await client.get(url)
        # Fail fast on 4xx/5xx instead of handing an error page to joblib.
        response.raise_for_status()
        return BytesIO(response.content)


# Load the model pipelines and encoder
async def load_pipeline(pipeline_url: str, encoder_url: str) -> Tuple[Optional[Pipeline], Optional[LabelEncoder]]:
    """Fetch and deserialize the model pipeline and its label encoder.

    Args:
        pipeline_url: URL of the joblib-serialized sklearn Pipeline.
        encoder_url: URL of the joblib-serialized LabelEncoder.

    Returns:
        (pipeline, encoder) on success. On any failure the error is
        logged and whatever loaded so far is returned with None in the
        failed slot(s) — e.g. (None, None) or (pipeline, None).

    Note:
        The previous version returned from inside a ``finally`` block,
        which silently swallows *every* in-flight exception — including
        ``asyncio.CancelledError`` and ``KeyboardInterrupt`` — breaking
        task cancellation. The return now happens after the handler, so
        only ``Exception`` is caught.
    """
    pipeline, encoder = None, None
    try:
        pipeline = joblib.load(await url_to_data(pipeline_url))
        encoder = joblib.load(await url_to_data(encoder_url))
    except Exception as e:
        logging.error(
            "Omg, an error occurred in loading the pipeline resources: %s", e)
    return pipeline, encoder


async def pipeline_classifier(pipeline: Pipeline, encoder: LabelEncoder, data: SepsisFeatures) -> Union[ErrorResponse, PredictionResponse]:
    """Run the classifier over the batched input features.

    Args:
        pipeline: Fitted sklearn Pipeline (may be None if loading failed;
            that surfaces here as an AttributeError and is reported).
        encoder: Fitted LabelEncoder used to decode class ids to labels.
        data: Batched patient features; each attribute is a parallel list.

    Returns:
        PredictionResponse (execution_code 1) on success, or
        ErrorResponse (execution_code 0) describing the failure.

    Note:
        The previous version returned from inside ``finally``, which
        swallows *all* in-flight exceptions — including
        ``asyncio.CancelledError`` — and it never logged the error it
        caught. Both are fixed; the user-facing error string is unchanged.
    """
    try:
        # Each SepsisFeatures attribute becomes one DataFrame column;
        # rows line up across the parallel input lists by index.
        df = pd.DataFrame.from_dict(data.__dict__)

        # Make prediction; cast numpy integers to plain ints for the encoder.
        preds = pipeline.predict(df)
        preds_int = [int(pred) for pred in preds]

        # Decode integer class ids back to their original string labels.
        predictions = encoder.inverse_transform(preds_int)
        probabilities_np = pipeline.predict_proba(df)

        # Confidence = probability of the predicted (max) class, as a
        # percentage rounded to 2 decimal places.
        probabilities = [round(float(max(prob) * 100), 2)
                         for prob in probabilities_np]

        result = ResultData(prediction=predictions,
                            probability=probabilities)

        return PredictionResponse(execution_msg='Execution was successful',
                                  execution_code=1,
                                  result=result)

    except Exception as e:
        # Log the failure (previously it was silently wrapped and lost).
        logging.error("Omg, pipeline classifier and/or encoder failure. %s", e)
        error = f"Omg, pipeline classifier and/or encoder failure. {e}"
        return ErrorResponse(execution_msg='Execution failed',
                             execution_code=0,
                             error=error)


@strawberry.type
class Query:
    """Root GraphQL query type exposing the sepsis-prediction field."""

    @strawberry.field
    async def predict_sepsis(self, model: ModelChoice, data: SepsisFeatures) -> Union[ErrorResponse, PredictionResponse]:
        """Predict sepsis for each patient row in *data* using *model*.

        The enum member's value is the download URL of the chosen
        model's serialized pipeline; the shared label encoder comes
        from config. Any load or inference failure is reported as an
        ErrorResponse rather than raised.
        """
        # model.value is the pipeline URL for the selected classifier.
        pipeline, encoder = await load_pipeline(model.value, ENCODER_URL)
        return await pipeline_classifier(pipeline, encoder, data)


# Create the GraphQL Schema
schema = strawberry.Schema(query=Query)

# Create the GraphQL application — an ASGI app (see the
# `strawberry.asgi` import), to be served by an ASGI server or mounted
# into a larger ASGI framework.
graphql_app = GraphQL(schema)