Spaces:
Sleeping
Sleeping
| """ | |
| FastAPI server for Symptom Checker ML model. | |
| Provides endpoints compatible with Flutter mobile app. | |
| """ | |
import logging
from contextlib import asynccontextmanager
from typing import List, Optional

import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# Import from symptom_checker module
from symptom_checker import load_artifacts, build_feature_vector
# Model artifacts shared by the request handlers. All three start out as None
# and are assigned by ``lifespan`` (through its ``global`` statement) at startup;
# the handlers treat None as "model not loaded".
model = label_encoder = feature_names = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load model artifacts on startup.

    Assigns the module-level ``model``, ``label_encoder`` and ``feature_names``
    from the ``symptom_model`` artifacts before yielding control to the
    server, and prints a short summary. Code after ``yield`` runs on shutdown.

    Raises:
        RuntimeError: if ``load_artifacts`` raises ``FileNotFoundError``; the
            original error is chained as the cause and application startup
            is aborted.
    """
    global model, label_encoder, feature_names
    # Keep the try narrow: only the artifact load is expected to fail here.
    try:
        model, label_encoder, feature_names = load_artifacts("symptom_model")
    except FileNotFoundError as e:
        print(f"Error loading model: {e}")
        raise RuntimeError(
            "Failed to load model artifacts. Ensure symptom_model.* files exist."
        ) from e
    print("Model loaded successfully!")
    print(f" - Features: {len(feature_names)}")
    print(f" - Classes: {len(label_encoder.classes_)}")
    yield
    # Cleanup (if needed)
    print("Shutting down API server...")
# The ASGI application. ``lifespan`` runs at startup to load the model
# artifacts and again (after its ``yield``) at shutdown.
app = FastAPI(
    lifespan=lifespan,
    title="Symptom Checker API",
    version="1.0.0",
    description="AI-powered symptom checker using XGBoost",
)
# Enable CORS for Flutter app. In production, replace "*" with your app's
# domain(s).
# Credentials are disabled on purpose: the CORS spec does not allow them with a
# wildcard origin, and Starlette's workaround (reflecting the request's Origin)
# would let any website make credentialed requests. None of the endpoints in
# this file read cookies or auth headers, so nothing depends on credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ============== Pydantic Models ==============
# Request body accepted by ``check_symptoms``.
class SymptomCheckRequest(BaseModel):
    # Symptom names from the client, passed unchanged to build_feature_vector;
    # an empty list makes the handler return an error response.
    symptoms: List[str]
# One ranked prediction inside a SymptomCheckResponse, as filled in by
# ``check_symptoms``.
class SymptomPrediction(BaseModel):
    # 1-based position in the ranking; 1 is the most probable disease.
    rank: int
    # Disease label decoded from the class index by the label encoder.
    disease: str
    # Model probability for this disease, as a plain float.
    confidence: float
    # ``confidence`` rendered as a percentage with two decimals, e.g. "87.50%".
    confidence_percent: str
# Response body returned by ``check_symptoms``.
class SymptomCheckResponse(BaseModel):
    # False when the request could not be processed; see ``error``.
    success: bool
    # Ranked predictions, most probable first; empty on failure.
    predictions: List[SymptomPrediction]
    # The symptoms the client sent (empty list if none were sent).
    input_symptoms: List[str]
    # Error message when ``success`` is False, otherwise None.
    error: Optional[str] = None
# Response body returned by ``get_available_symptoms``.
class AvailableSymptomsResponse(BaseModel):
    # False when the symptom list could not be produced; see ``error``.
    success: bool
    # The model's feature names, i.e. the symptom names it recognizes.
    symptoms: List[str]
    # Number of entries in ``symptoms`` (0 on failure).
    total_symptoms: int
    # Error message when ``success`` is False, otherwise None.
    error: Optional[str] = None
# ============== API Endpoints ==============
@app.get("/")
async def root():
    """Health check endpoint.

    Returns a static status payload together with the paths of the main API
    endpoints. It does not check whether the model artifacts are loaded.
    """
    return {
        "status": "online",
        "message": "Symptom Checker API is running",
        "endpoints": {
            "check_symptoms": "/api/check-symptoms",
            "available_symptoms": "/api/symptoms",
        },
    }
@app.get("/api/symptoms", response_model=AvailableSymptomsResponse)
async def get_available_symptoms():
    """Get list of all available symptoms the model recognizes.

    Returns an ``AvailableSymptomsResponse`` listing the model's feature
    names. Unexpected errors while building the response are reported
    in-band (``success=False`` plus ``error``) and logged server-side.

    Raises:
        HTTPException: 503 when the model artifacts have not been loaded.
    """
    # Checked before the ``try``: inside it, the generic handler below would
    # catch the HTTPException and turn the intended 503 into a 200 response.
    if feature_names is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # list() also accepts array-like feature containers (tuple, ndarray).
        symptoms = list(feature_names)
        return AvailableSymptomsResponse(
            success=True,
            symptoms=symptoms,
            total_symptoms=len(symptoms),
            error=None,
        )
    except Exception as e:
        # Boundary handler: keep the in-band error contract for the client,
        # but record the traceback so the failure is diagnosable.
        logging.getLogger(__name__).exception("Failed to list available symptoms")
        return AvailableSymptomsResponse(
            success=False,
            symptoms=[],
            total_symptoms=0,
            error=str(e),
        )
@app.post("/api/check-symptoms", response_model=SymptomCheckResponse)
async def check_symptoms(request: SymptomCheckRequest):
    """
    Check symptoms and return disease predictions.

    Request body:
    {
        "symptoms": ["fever", "cough", "headache"]
    }

    Returns a ``SymptomCheckResponse`` with up to five predictions ranked by
    descending model probability. An empty symptom list and any error raised
    while building features or predicting are reported in-band
    (``success=False`` plus ``error``) rather than as HTTP errors; such
    errors are also logged server-side.

    Raises:
        HTTPException: 503 when the model artifacts have not been loaded.
    """
    # Checked before the ``try``: inside it, the generic handler below would
    # catch the HTTPException and turn the intended 503 into a 200 response.
    if model is None or label_encoder is None or feature_names is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    symptoms = request.symptoms
    if not symptoms:
        return SymptomCheckResponse(
            success=False,
            predictions=[],
            input_symptoms=[],
            error="No symptoms provided",
        )

    try:
        # Build feature vector from symptoms
        x = build_feature_vector(feature_names, symptoms)
        # Class probabilities for the single input row
        proba = model.predict_proba(x)[0]
        # Class indices sorted by descending probability; keep only the top 5
        top_indices = np.argsort(proba)[::-1][:5]

        predictions = []
        for rank, idx in enumerate(top_indices, start=1):
            disease_name = label_encoder.inverse_transform([idx])[0]
            confidence = float(proba[idx])
            predictions.append(
                SymptomPrediction(
                    rank=rank,
                    disease=disease_name,
                    confidence=confidence,
                    confidence_percent=f"{confidence * 100:.2f}%",
                )
            )
        return SymptomCheckResponse(
            success=True,
            predictions=predictions,
            input_symptoms=symptoms,
            error=None,
        )
    except Exception as e:
        # Boundary handler: keep the in-band error contract for the client,
        # but record the traceback so the failure is diagnosable.
        logging.getLogger(__name__).exception("Symptom check failed")
        return SymptomCheckResponse(
            success=False,
            predictions=[],
            input_symptoms=symptoms,
            error=str(e),
        )
# ============== Run Server ==============
if __name__ == "__main__":
    import os

    import uvicorn

    # Use PORT env variable for Hugging Face Spaces, default to 8000 for local dev
    port = int(os.environ.get("PORT", 8000))
    # Defaults to loopback only; set HOST (e.g. 0.0.0.0) to accept connections
    # from other machines.
    host = os.environ.get("HOST", "127.0.0.1")
    print("Starting Symptom Checker API server...")
    print(f"Access the API at: http://{host}:{port}")
    print(f"API docs at: http://{host}:{port}/docs")
    uvicorn.run(app, host=host, port=port)