Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
|
@@ -6,7 +6,7 @@ FastAPI application for Hugging Face Spaces deployment
|
|
| 6 |
from fastapi import FastAPI, HTTPException
|
| 7 |
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
from pydantic import BaseModel, Field
|
| 9 |
-
from typing import List, Dict
|
| 10 |
import numpy as np
|
| 11 |
import joblib
|
| 12 |
from pathlib import Path
|
|
@@ -45,18 +45,42 @@ BACKEND_ALLOWED_DISEASES = {
|
|
| 45 |
# Pydantic models
|
| 46 |
class PatientData(BaseModel):
|
| 47 |
"""Patient biomarker data for prediction"""
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
|
| 55 |
class PredictionResponse(BaseModel):
|
| 56 |
"""Response model for disease prediction"""
|
| 57 |
-
|
| 58 |
confidence: float
|
| 59 |
-
|
|
|
|
|
|
|
| 60 |
|
| 61 |
|
| 62 |
class HealthResponse(BaseModel):
|
|
@@ -158,17 +182,17 @@ def predict_disease(patient_features: np.ndarray):
|
|
| 158 |
disease = label_encoder.inverse_transform([y_pred])[0]
|
| 159 |
confidence = float(y_proba[y_pred])
|
| 160 |
|
| 161 |
-
# Get top
|
| 162 |
-
|
| 163 |
-
|
| 164 |
{
|
| 165 |
"disease": label_encoder.inverse_transform([idx])[0],
|
| 166 |
-
"
|
| 167 |
}
|
| 168 |
-
for idx in
|
| 169 |
]
|
| 170 |
|
| 171 |
-
return disease, confidence,
|
| 172 |
|
| 173 |
|
| 174 |
@app.get("/")
|
|
@@ -178,15 +202,16 @@ async def root():
|
|
| 178 |
"message": "MediGuard Disease Prediction API",
|
| 179 |
"version": "1.0.0",
|
| 180 |
"endpoints": {
|
| 181 |
-
"health": "/health",
|
| 182 |
-
"predict": "/predict (POST)",
|
| 183 |
-
"features": "/features",
|
|
|
|
| 184 |
"docs": "/docs"
|
| 185 |
}
|
| 186 |
}
|
| 187 |
|
| 188 |
|
| 189 |
-
@app.get("/health", response_model=HealthResponse)
|
| 190 |
async def health_check():
|
| 191 |
"""Health check endpoint"""
|
| 192 |
return HealthResponse(
|
|
@@ -196,7 +221,7 @@ async def health_check():
|
|
| 196 |
)
|
| 197 |
|
| 198 |
|
| 199 |
-
@app.get("/features"
|
| 200 |
async def get_features():
|
| 201 |
"""Get list of required features"""
|
| 202 |
if feature_cols is None:
|
|
@@ -209,13 +234,13 @@ async def get_features():
|
|
| 209 |
}
|
| 210 |
|
| 211 |
|
| 212 |
-
@app.post("/predict", response_model=PredictionResponse)
|
| 213 |
async def predict(patient_data: PatientData):
|
| 214 |
"""
|
| 215 |
Predict disease from patient biomarker data
|
| 216 |
|
| 217 |
Args:
|
| 218 |
-
patient_data: PatientData object with
|
| 219 |
|
| 220 |
Returns:
|
| 221 |
PredictionResponse with predicted disease and confidence
|
|
@@ -228,18 +253,34 @@ async def predict(patient_data: PatientData):
|
|
| 228 |
)
|
| 229 |
|
| 230 |
try:
|
| 231 |
-
# Convert to
|
| 232 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
|
| 234 |
# Predict
|
| 235 |
-
disease, confidence,
|
| 236 |
|
| 237 |
logger.info(f"Prediction: {disease} ({confidence*100:.2f}%)")
|
| 238 |
|
| 239 |
return PredictionResponse(
|
| 240 |
-
|
| 241 |
confidence=confidence,
|
| 242 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
)
|
| 244 |
|
| 245 |
except ValueError as e:
|
|
|
|
| 6 |
from fastapi import FastAPI, HTTPException
|
| 7 |
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
from pydantic import BaseModel, Field
|
| 9 |
+
from typing import List, Dict, Any
|
| 10 |
import numpy as np
|
| 11 |
import joblib
|
| 12 |
from pathlib import Path
|
|
|
|
| 45 |
# Pydantic models
|
| 46 |
class PatientData(BaseModel):
|
| 47 |
"""Patient biomarker data for prediction"""
|
| 48 |
+
Glucose: float
|
| 49 |
+
Cholesterol: float
|
| 50 |
+
Hemoglobin: float
|
| 51 |
+
Platelets: float
|
| 52 |
+
White_Blood_Cells: float = Field(..., alias="White Blood Cells")
|
| 53 |
+
Red_Blood_Cells: float = Field(..., alias="Red Blood Cells")
|
| 54 |
+
Hematocrit: float
|
| 55 |
+
Mean_Corpuscular_Volume: float = Field(..., alias="Mean Corpuscular Volume")
|
| 56 |
+
Mean_Corpuscular_Hemoglobin: float = Field(..., alias="Mean Corpuscular Hemoglobin")
|
| 57 |
+
Mean_Corpuscular_Hemoglobin_Concentration: float = Field(..., alias="Mean Corpuscular Hemoglobin Concentration")
|
| 58 |
+
Insulin: float
|
| 59 |
+
BMI: float
|
| 60 |
+
Systolic_Blood_Pressure: float = Field(..., alias="Systolic Blood Pressure")
|
| 61 |
+
Diastolic_Blood_Pressure: float = Field(..., alias="Diastolic Blood Pressure")
|
| 62 |
+
Triglycerides: float
|
| 63 |
+
HbA1c: float
|
| 64 |
+
LDL_Cholesterol: float = Field(..., alias="LDL Cholesterol")
|
| 65 |
+
HDL_Cholesterol: float = Field(..., alias="HDL Cholesterol")
|
| 66 |
+
ALT: float
|
| 67 |
+
AST: float
|
| 68 |
+
Heart_Rate: float = Field(..., alias="Heart Rate")
|
| 69 |
+
Creatinine: float
|
| 70 |
+
Troponin: float
|
| 71 |
+
C_reactive_Protein: float = Field(..., alias="C-reactive Protein")
|
| 72 |
+
|
| 73 |
+
class Config:
|
| 74 |
+
populate_by_name = True
|
| 75 |
|
| 76 |
|
| 77 |
class PredictionResponse(BaseModel):
|
| 78 |
"""Response model for disease prediction"""
|
| 79 |
+
prediction: str
|
| 80 |
confidence: float
|
| 81 |
+
top_5_predictions: List[Dict[str, float]]
|
| 82 |
+
raw_values: Dict[str, float]
|
| 83 |
+
model_info: Dict[str, Any]
|
| 84 |
|
| 85 |
|
| 86 |
class HealthResponse(BaseModel):
|
|
|
|
| 182 |
disease = label_encoder.inverse_transform([y_pred])[0]
|
| 183 |
confidence = float(y_proba[y_pred])
|
| 184 |
|
| 185 |
+
# Get top 5 predictions
|
| 186 |
+
top_5_idx = np.argsort(y_proba)[-5:][::-1]
|
| 187 |
+
top_5 = [
|
| 188 |
{
|
| 189 |
"disease": label_encoder.inverse_transform([idx])[0],
|
| 190 |
+
"probability": float(y_proba[idx])
|
| 191 |
}
|
| 192 |
+
for idx in top_5_idx
|
| 193 |
]
|
| 194 |
|
| 195 |
+
return disease, confidence, top_5
|
| 196 |
|
| 197 |
|
| 198 |
@app.get("/")
|
|
|
|
| 202 |
"message": "MediGuard Disease Prediction API",
|
| 203 |
"version": "1.0.0",
|
| 204 |
"endpoints": {
|
| 205 |
+
"health": "/api/health",
|
| 206 |
+
"predict": "/api/predict (POST)",
|
| 207 |
+
"features": "/api/features",
|
| 208 |
+
"diseases": "/api/diseases",
|
| 209 |
"docs": "/docs"
|
| 210 |
}
|
| 211 |
}
|
| 212 |
|
| 213 |
|
| 214 |
+
@app.get("/api/health", response_model=HealthResponse)
|
| 215 |
async def health_check():
|
| 216 |
"""Health check endpoint"""
|
| 217 |
return HealthResponse(
|
|
|
|
| 221 |
)
|
| 222 |
|
| 223 |
|
| 224 |
+
@app.get("/api/features")
|
| 225 |
async def get_features():
|
| 226 |
"""Get list of required features"""
|
| 227 |
if feature_cols is None:
|
|
|
|
| 234 |
}
|
| 235 |
|
| 236 |
|
| 237 |
+
@app.post("/api/predict", response_model=PredictionResponse)
|
| 238 |
async def predict(patient_data: PatientData):
|
| 239 |
"""
|
| 240 |
Predict disease from patient biomarker data
|
| 241 |
|
| 242 |
Args:
|
| 243 |
+
patient_data: PatientData object with biomarker key-value pairs
|
| 244 |
|
| 245 |
Returns:
|
| 246 |
PredictionResponse with predicted disease and confidence
|
|
|
|
| 253 |
)
|
| 254 |
|
| 255 |
try:
|
| 256 |
+
# Convert Pydantic model to dict and extract values in correct order
|
| 257 |
+
input_dict = patient_data.model_dump(by_alias=True)
|
| 258 |
+
|
| 259 |
+
# Build features array in the order expected by feature_cols
|
| 260 |
+
features = []
|
| 261 |
+
for feature_name in feature_cols:
|
| 262 |
+
if feature_name not in input_dict:
|
| 263 |
+
raise ValueError(f"Missing feature: {feature_name}")
|
| 264 |
+
features.append(float(input_dict[feature_name]))
|
| 265 |
+
|
| 266 |
+
features = np.array(features, dtype=np.float32)
|
| 267 |
|
| 268 |
# Predict
|
| 269 |
+
disease, confidence, top_5 = predict_disease(features)
|
| 270 |
|
| 271 |
logger.info(f"Prediction: {disease} ({confidence*100:.2f}%)")
|
| 272 |
|
| 273 |
return PredictionResponse(
|
| 274 |
+
prediction=disease,
|
| 275 |
confidence=confidence,
|
| 276 |
+
top_5_predictions=top_5,
|
| 277 |
+
raw_values=input_dict,
|
| 278 |
+
model_info={
|
| 279 |
+
"base_models": 2, # rf_model and nn_model
|
| 280 |
+
"features_used": len(feature_cols),
|
| 281 |
+
"meta_input_shape": [1, len(label_encoder.classes_) * 2],
|
| 282 |
+
"n_classes": len(label_encoder.classes_)
|
| 283 |
+
}
|
| 284 |
)
|
| 285 |
|
| 286 |
except ValueError as e:
|