import strawberry
from strawberry.asgi import GraphQL
import pandas as pd
import joblib
from sklearn.pipeline import Pipeline
from sklearn.preprocessing._label import LabelEncoder
import httpx
from io import BytesIO
from typing import Tuple, List, Optional, Union
from enum import Enum
from config import RANDOM_FOREST_URL, XGBOOST_URL, ENCODER_URL
import logging
# API input features
@strawberry.enum
class ModelChoice(Enum):
RandomForestClassifier = RANDOM_FOREST_URL
XGBoostClassifier = XGBOOST_URL
@strawberry.input
class SepsisFeatures:
prg: List[int]
pl: List[int]
pr: List[int]
sk: List[int]
ts: List[int]
m11: List[float]
bd2: List[float]
age: List[int]
insurance: List[int]
@strawberry.type
class Url:
url: str
pipeline_url: str
encoder_url: str
@strawberry.type
class ResultData:
prediction: List[str]
probability: List[float]
@strawberry.type
class PredictionResponse:
execution_msg: str
execution_code: int
result: ResultData
@strawberry.type
class ErrorResponse:
execution_msg: str
execution_code: int
error: Optional[str]
logging.basicConfig(level=logging.ERROR,
format='%(asctime)s - %(levelname)s - %(message)s')
async def url_to_data(url: Url) -> BytesIO:
async with httpx.AsyncClient() as client:
response = await client.get(url)
response.raise_for_status() # Ensure we catch any HTTP errors
# Convert response content to BytesIO object
data = BytesIO(response.content)
return data
# Load the model pipelines and encoder
async def load_pipeline(pipeline_url: Url, encoder_url: Url) -> Tuple[Pipeline, LabelEncoder]:
pipeline, encoder = None, None
try:
pipeline: Pipeline = joblib.load(await url_to_data(pipeline_url))
encoder: LabelEncoder = joblib.load(await url_to_data(encoder_url))
except Exception as e:
logging.error(
"Omg, an error occurred in loading the pipeline resources: %s", e)
finally:
return pipeline, encoder
async def pipeline_classifier(pipeline: Pipeline, encoder: LabelEncoder, data: SepsisFeatures) -> Union[ErrorResponse, PredictionResponse]:
msg = 'Execution failed'
code = 0
output = ErrorResponse(**{'execution_msg': msg,
'execution_code': code, 'error': None})
try:
# Create dataframe
df = pd.DataFrame.from_dict(data.__dict__)
# Make prediction
preds = pipeline.predict(df)
preds_int = [int(pred) for pred in preds]
predictions = encoder.inverse_transform(preds_int)
probabilities_np = pipeline.predict_proba(df)
probabilities = [round(float(max(prob)*100), 2)
for prob in probabilities_np]
result = ResultData(**{"prediction": predictions,
"probability": probabilities}
)
msg = 'Execution was successful'
code = 1
output = PredictionResponse(
**{'execution_msg': msg,
'execution_code': code, 'result': result}
)
except Exception as e:
error = f"Omg, pipeline classifier and/or encoder failure. {e}"
output = ErrorResponse(**{'execution_msg': msg,
'execution_code': code, 'error': error})
finally:
return output
@strawberry.type
class Query:
@strawberry.field
async def predict_sepsis(self, model: ModelChoice, data: SepsisFeatures) -> Union[ErrorResponse, PredictionResponse]:
pipeline_url: Url = model.value
pipeline, encoder = await load_pipeline(pipeline_url, ENCODER_URL)
output = await pipeline_classifier(pipeline, encoder, data)
return output
# Create the GraphQL Schema
schema = strawberry.Schema(query=Query)
# Create the GraphQL application
graphql_app = GraphQL(schema)
|