"""
FastAPI server for Symptom Checker ML model.
Provides endpoints compatible with Flutter mobile app.
"""
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional
import numpy as np
from contextlib import asynccontextmanager
# Import from symptom_checker module
from symptom_checker import load_artifacts, build_feature_vector
# Global variables for model artifacts, populated once by the lifespan handler.
model = None  # trained classifier exposing predict_proba (XGBoost per the API description)
label_encoder = None  # encoder with .classes_ and .inverse_transform mapping class index -> disease name
feature_names = None  # ordered list of symptom feature names the model expects
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load model artifacts on startup and log shutdown on exit.

    Populates the module-level ``model``, ``label_encoder`` and
    ``feature_names`` globals used by the API endpoints.

    Raises:
        RuntimeError: if the ``symptom_model.*`` artifact files are missing.
    """
    global model, label_encoder, feature_names
    try:
        model, label_encoder, feature_names = load_artifacts("symptom_model")
        print("[OK] Model loaded successfully!")
        print(f"   - Features: {len(feature_names)}")
        print(f"   - Classes: {len(label_encoder.classes_)}")
    except FileNotFoundError as e:
        print(f"[ERROR] Error loading model: {e}")
        # Chain the original exception so the root cause stays visible.
        raise RuntimeError(
            "Failed to load model artifacts. Ensure symptom_model.* files exist."
        ) from e
    yield
    # Cleanup (if needed)
    print("Shutting down API server...")
# Create the FastAPI application; `lifespan` loads the ML artifacts at startup.
app = FastAPI(
    title="Symptom Checker API",
    description="AI-powered symptom checker using XGBoost",
    version="1.0.0",
    lifespan=lifespan
)
# Enable CORS for Flutter app
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify your app's domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ============== Pydantic Models ==============
class SymptomCheckRequest(BaseModel):
    """Request payload for /api/check-symptoms: symptom names reported by the user."""
    symptoms: List[str]
class SymptomPrediction(BaseModel):
    """A single ranked disease prediction."""
    rank: int  # 1-based rank (1 = most likely)
    disease: str  # predicted disease name
    confidence: float  # raw model probability in [0, 1]
    confidence_percent: str  # formatted percentage string, e.g. "87.52%"
class SymptomCheckResponse(BaseModel):
    """Response envelope for /api/check-symptoms."""
    success: bool
    predictions: List[SymptomPrediction]  # top predictions, most likely first
    input_symptoms: List[str]  # echo of the symptoms received
    error: Optional[str] = None  # populated only when success is False
class AvailableSymptomsResponse(BaseModel):
    """Response envelope for /api/symptoms: all symptom names the model recognizes."""
    success: bool
    symptoms: List[str]
    total_symptoms: int
    error: Optional[str] = None  # populated only when success is False
# ============== API Endpoints ==============
@app.get("/")
async def root():
    """Health check endpoint."""
    available_routes = {
        "check_symptoms": "/api/check-symptoms",
        "available_symptoms": "/api/symptoms",
    }
    return {
        "status": "online",
        "message": "Symptom Checker API is running",
        "endpoints": available_routes,
    }
@app.get("/api/symptoms", response_model=AvailableSymptomsResponse)
async def get_available_symptoms():
    """Get list of all available symptoms the model recognizes.

    Returns:
        AvailableSymptomsResponse with the full feature-name list.

    Raises:
        HTTPException: 503 if the model artifacts have not been loaded.
    """
    # Checked OUTSIDE the try/except: previously the HTTPException was raised
    # inside the try and swallowed by `except Exception`, so clients got a
    # 200 with success=False instead of the intended 503.
    if feature_names is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        return AvailableSymptomsResponse(
            success=True,
            symptoms=feature_names,
            total_symptoms=len(feature_names),
            error=None
        )
    except Exception as e:
        # Best-effort error envelope so the client always gets a parseable body.
        return AvailableSymptomsResponse(
            success=False,
            symptoms=[],
            total_symptoms=0,
            error=str(e)
        )
@app.post("/api/check-symptoms", response_model=SymptomCheckResponse)
async def check_symptoms(request: SymptomCheckRequest):
    """
    Check symptoms and return disease predictions.

    Request body:
    {
        "symptoms": ["fever", "cough", "headache"]
    }

    Returns:
        SymptomCheckResponse with the top 5 diseases ranked by model
        probability, or success=False with an error message.

    Raises:
        HTTPException: 503 if the model artifacts have not been loaded.
    """
    # Guards live OUTSIDE the try/except: previously the HTTPException was
    # raised inside the try and swallowed by `except Exception`, so clients
    # got a 200 with success=False instead of the intended 503.
    if model is None or label_encoder is None or feature_names is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    symptoms = request.symptoms
    if not symptoms:
        return SymptomCheckResponse(
            success=False,
            predictions=[],
            input_symptoms=[],
            error="No symptoms provided"
        )
    try:
        # Build feature vector from symptoms
        x = build_feature_vector(feature_names, symptoms)
        # Class probabilities for the single input sample
        proba = model.predict_proba(x)[0]
        # Indices of all classes sorted by descending probability
        top_indices = np.argsort(proba)[::-1]
        # Build predictions list (top 5 most likely)
        predictions = []
        for rank, idx in enumerate(top_indices[:5], start=1):
            disease_name = label_encoder.inverse_transform([idx])[0]
            confidence = float(proba[idx])
            predictions.append(SymptomPrediction(
                rank=rank,
                disease=disease_name,
                confidence=confidence,
                confidence_percent=f"{confidence * 100:.2f}%"
            ))
        return SymptomCheckResponse(
            success=True,
            predictions=predictions,
            input_symptoms=symptoms,
            error=None
        )
    except Exception as e:
        # Best-effort error envelope; prediction failures become success=False.
        return SymptomCheckResponse(
            success=False,
            predictions=[],
            input_symptoms=request.symptoms or [],
            error=str(e)
        )
# ============== Run Server ==============
if __name__ == "__main__":
    import uvicorn
    import os

    # Use PORT env variable for Hugging Face Spaces, default to 8000 for local dev
    port = int(os.environ.get("PORT", 8000))
    host = os.environ.get("HOST", "127.0.0.1")
    # Banner strings were mojibake (misdecoded emoji); replaced with plain ASCII.
    print("Starting Symptom Checker API server...")
    print(f"Access the API at: http://{host}:{port}")
    print(f"API docs at: http://{host}:{port}/docs")
    uvicorn.run(app, host=host, port=port)