""" FastAPI server for Symptom Checker ML model. Provides endpoints compatible with Flutter mobile app. """ from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from typing import List, Optional import numpy as np from contextlib import asynccontextmanager # Import from symptom_checker module from symptom_checker import load_artifacts, build_feature_vector # Global variables for model artifacts model = None label_encoder = None feature_names = None @asynccontextmanager async def lifespan(app: FastAPI): """Load model artifacts on startup.""" global model, label_encoder, feature_names try: model, label_encoder, feature_names = load_artifacts("symptom_model") print(f"✅ Model loaded successfully!") print(f" - Features: {len(feature_names)}") print(f" - Classes: {len(label_encoder.classes_)}") except FileNotFoundError as e: print(f"❌ Error loading model: {e}") raise RuntimeError("Failed to load model artifacts. Ensure symptom_model.* files exist.") yield # Cleanup (if needed) print("👋 Shutting down API server...") app = FastAPI( title="Symptom Checker API", description="AI-powered symptom checker using XGBoost", version="1.0.0", lifespan=lifespan ) # Enable CORS for Flutter app app.add_middleware( CORSMiddleware, allow_origins=["*"], # In production, specify your app's domain allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # ============== Pydantic Models ============== class SymptomCheckRequest(BaseModel): symptoms: List[str] class SymptomPrediction(BaseModel): rank: int disease: str confidence: float confidence_percent: str class SymptomCheckResponse(BaseModel): success: bool predictions: List[SymptomPrediction] input_symptoms: List[str] error: Optional[str] = None class AvailableSymptomsResponse(BaseModel): success: bool symptoms: List[str] total_symptoms: int error: Optional[str] = None # ============== API Endpoints ============== @app.get("/") async def root(): """Health check endpoint.""" return { "status": "online", "message": "Symptom Checker API is running", "endpoints": { "check_symptoms": "/api/check-symptoms", "available_symptoms": "/api/symptoms" } } @app.get("/api/symptoms", response_model=AvailableSymptomsResponse) async def get_available_symptoms(): """Get list of all available symptoms the model recognizes.""" try: if feature_names is None: raise HTTPException(status_code=503, detail="Model not loaded") return AvailableSymptomsResponse( success=True, symptoms=feature_names, total_symptoms=len(feature_names), error=None ) except Exception as e: return AvailableSymptomsResponse( success=False, symptoms=[], total_symptoms=0, error=str(e) ) @app.post("/api/check-symptoms", response_model=SymptomCheckResponse) async def check_symptoms(request: SymptomCheckRequest): """ Check symptoms and return disease predictions. Request body: { "symptoms": ["fever", "cough", "headache"] } """ try: if model is None or label_encoder is None or feature_names is None: raise HTTPException(status_code=503, detail="Model not loaded") symptoms = request.symptoms if not symptoms: return SymptomCheckResponse( success=False, predictions=[], input_symptoms=[], error="No symptoms provided" ) # Build feature vector from symptoms x = build_feature_vector(feature_names, symptoms) # Get predictions proba = model.predict_proba(x)[0] # Get top predictions (all classes sorted by probability) top_indices = np.argsort(proba)[::-1] # Build predictions list (top 5 most likely) predictions = [] for rank, idx in enumerate(top_indices[:5], start=1): disease_name = label_encoder.inverse_transform([idx])[0] confidence = float(proba[idx]) predictions.append(SymptomPrediction( rank=rank, disease=disease_name, confidence=confidence, confidence_percent=f"{confidence * 100:.2f}%" )) return SymptomCheckResponse( success=True, predictions=predictions, input_symptoms=symptoms, error=None ) except Exception as e: return SymptomCheckResponse( success=False, predictions=[], input_symptoms=request.symptoms if request.symptoms else [], error=str(e) ) # ============== Run Server ============== if __name__ == "__main__": import uvicorn import os # Use PORT env variable for Hugging Face Spaces, default to 8000 for local dev port = int(os.environ.get("PORT", 8000)) host = os.environ.get("HOST", "127.0.0.1") print("🚀 Starting Symptom Checker API server...") print(f"📍 Access the API at: http://{host}:{port}") print(f"📖 API docs at: http://{host}:{port}/docs") uvicorn.run(app, host=host, port=port)