Spaces:
Sleeping
Sleeping
| """ | |
| FastAPI Backend for Symptom Checker | |
| Provides REST API endpoints for the Flutter mobile application. | |
| """ | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import List, Optional | |
| import numpy as np | |
| import xgboost as xgb | |
| from sklearn.preprocessing import LabelEncoder | |
| import os | |
| # ============================================================================ | |
| # Pydantic Models (matching Flutter frontend expectations) | |
| # ============================================================================ | |
class SymptomsRequest(BaseModel):
    """Request body for a symptom check: the symptom strings the user picked."""

    symptoms: List[str]  # matched case-insensitively against the model features
class SymptomPrediction(BaseModel):
    """A single ranked disease prediction returned by the symptom check."""

    rank: int  # 1-based position; 1 is the most probable disease
    disease: str  # decoded class label
    confidence: float  # model probability (the endpoint rounds it to 4 decimals)
    confidence_percent: str  # the same probability rendered like "87.50%"
class SymptomCheckResponse(BaseModel):
    """Response body of the symptom-check endpoint.

    ``success`` tells the client whether ``predictions`` is meaningful. On
    failure the handler sends an empty ``predictions`` list and puts a
    human-readable reason in ``error``.
    """

    success: bool
    predictions: List[SymptomPrediction]
    input_symptoms: List[str]  # the request's symptoms, echoed back
    # Explicit default: under Pydantic v2 a bare ``Optional[str]`` annotation is
    # a *required* field (it may be None but must always be supplied), whereas
    # v1 implied ``= None``. The default makes both versions behave the same.
    error: Optional[str] = None
class AvailableSymptomsResponse(BaseModel):
    """Response body listing every symptom the model recognizes.

    On failure the handler sends ``success=False``, an empty ``symptoms``
    list, ``total_symptoms=0`` and an explanatory ``error``.
    """

    success: bool
    symptoms: List[str]  # display-formatted symptom names
    total_symptoms: int  # len(symptoms)
    # Explicit default: under Pydantic v2 a bare ``Optional[str]`` annotation is
    # a *required* field; v1 implied ``= None``. Keep both versions consistent.
    error: Optional[str] = None
| # ============================================================================ | |
| # Model Loading (same as symptom_checker.py) | |
| # ============================================================================ | |
class LoadedModel:
    """Wrap a raw XGBoost ``Booster`` with a scikit-learn style ``predict_proba``.

    Attributes:
        booster: The loaded ``xgb.Booster``.
        n_classes: Number of target classes (the length of the label array
            loaded next to the model).
        feature_names: Column names handed to ``xgb.DMatrix`` so inputs line
            up with the booster's features; ``None`` leaves columns unnamed.
    """

    def __init__(
        self,
        booster: xgb.Booster,
        n_classes: int,
        feature_names: Optional[List[str]] = None,
    ):
        self.booster = booster
        self.n_classes = n_classes
        self.feature_names = feature_names

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Return class probabilities, one row per sample.

        A 1-D ``X`` is treated as a single sample. A 1-D booster output is
        assumed to be positive-class probabilities from a binary objective and
        is expanded to two columns ``[P(class 0), P(class 1)]``. A 2-D output
        (multi-class probabilities) is returned unchanged.

        Raises:
            ValueError: if the booster returns a 1-D result while more than two
                classes are expected. That output holds class labels (e.g. from
                ``multi:softmax``), not probabilities.
        """
        # DMatrix needs a 2-D matrix; accept a single flat feature vector too.
        if np.ndim(X) == 1:
            X = np.asarray(X).reshape(1, -1)
        dmatrix = xgb.DMatrix(X, feature_names=self.feature_names)
        preds = self.booster.predict(dmatrix)
        if preds.ndim == 1:
            if self.n_classes > 2:
                raise ValueError(
                    f"Booster returned 1-D predictions for a {self.n_classes}-class "
                    "model; expected per-class probabilities (e.g. multi:softprob)."
                )
            return np.column_stack([1 - preds, preds])
        return preds
def load_artifacts(prefix: str):
    """Load the trained model, label encoder and feature list saved under *prefix*.

    Three sibling files are required:
      * ``<prefix>.json``         - XGBoost booster (loaded with ``load_model``),
      * ``<prefix>.labels.npy``   - class labels, installed as ``classes_``,
      * ``<prefix>.features.txt`` - one feature name per line; blank lines are
        skipped and surrounding whitespace is stripped.

    Returns:
        ``(model, label_encoder, feature_names)``: a ``LoadedModel`` wrapper,
        a ``LabelEncoder`` whose ``classes_`` are the loaded labels, and the
        list of feature names (also given to the wrapper).

    Raises:
        FileNotFoundError: if any of the three files does not exist.
    """
    model_path, labels_path, features_path = (
        f"{prefix}{suffix}" for suffix in (".json", ".labels.npy", ".features.txt")
    )
    if not all(map(os.path.exists, (model_path, labels_path, features_path))):
        raise FileNotFoundError(
            f"Missing artifacts. Expected: '{model_path}', '{labels_path}', '{features_path}'."
        )

    # NOTE(review): allow_pickle=True is presumably needed because the labels
    # were saved as an object-dtype (string) array. It also means a tampered
    # .npy file can execute arbitrary code, so only load trusted artifacts.
    encoder = LabelEncoder()
    encoder.classes_ = np.load(labels_path, allow_pickle=True)

    with open(features_path, "r", encoding="utf-8") as handle:
        names = [cleaned for cleaned in (raw.strip() for raw in handle) if cleaned]

    booster = xgb.Booster()
    booster.load_model(model_path)

    wrapper = LoadedModel(booster, len(encoder.classes_), names)
    return wrapper, encoder, names
def build_feature_vector(symptom_names: List[str], selected: List[str]) -> np.ndarray:
    """Build a one-row binary feature vector from user-selected symptoms.

    Matching ignores case, treats underscores as spaces and collapses runs of
    whitespace. A model feature such as ``"skin_rash"`` therefore matches
    ``"skin_rash"``, ``" SKIN_RASH "`` and the display form ``"Skin Rash"``
    (the ``.replace("_", " ").title()`` formatting the symptom-list endpoint
    sends to clients). Symptoms that match no feature are ignored.

    Args:
        symptom_names: Model feature names, in model column order.
        selected: Symptom strings supplied by the client.

    Returns:
        Float array of shape ``(1, len(symptom_names))`` holding 1.0 in every
        column whose symptom was selected and 0.0 elsewhere.
    """

    def normalize(name: str) -> str:
        # "Skin  Rash", "skin_rash" and " skin_rash " all become "skin rash".
        return " ".join(name.replace("_", " ").lower().split())

    features = np.zeros(len(symptom_names), dtype=float)
    name_to_index = {normalize(name): idx for idx, name in enumerate(symptom_names)}
    for symptom in selected:
        idx = name_to_index.get(normalize(symptom))
        if idx is not None:
            features[idx] = 1.0
    return features.reshape(1, -1)
| # ============================================================================ | |
| # FastAPI App Setup | |
| # ============================================================================ | |
app = FastAPI(
    title="Symptom Checker API",
    description="AI-powered symptom checker using XGBoost",
    version="1.0.0",
)

# CORS only matters for browser clients (e.g. a Flutter web build); native
# mobile apps do not enforce it.
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS environment
# variable, e.g. "https://app.example.com,https://admin.example.com". When the
# variable is unset every origin is allowed, as before.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]
# Never pair a wildcard origin with credentials. Browsers reject
# "Access-Control-Allow-Origin: *" on credentialed requests, and Starlette
# works around that by reflecting the caller's Origin, which would let any
# site make credentialed calls. Nothing visible in this API reads cookies or
# auth headers, so credentials are enabled only for an explicit allow-list.
_cors_allow_credentials = "*" not in _cors_origins

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=_cors_allow_credentials,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Model state shared by the request handlers. Each name starts as None and is
# filled in by startup_event() from load_artifacts(); the handlers check for
# None and answer "Model not loaded" until then.
model: Optional["LoadedModel"] = None
label_encoder: Optional["LabelEncoder"] = None
feature_names: Optional[List[str]] = None
async def startup_event():
    """Populate the module-level model globals from artifacts next to this file.

    Calls ``load_artifacts`` with the prefix ``<this script's dir>/symptom_model``
    and prints a short summary. Any failure is printed and then re-raised.

    NOTE(review): no ``@app.on_event("startup")`` decorator is visible in this
    copy of the file. Unless this coroutine is registered elsewhere, the
    globals stay ``None`` and every endpoint reports "Model not loaded".
    """
    global model, label_encoder, feature_names

    here = os.path.dirname(os.path.abspath(__file__))
    prefix = os.path.join(here, "symptom_model")
    try:
        model, label_encoder, feature_names = load_artifacts(prefix)
        print("✅ Model loaded successfully!")
        print(f" - Features: {len(feature_names)}")
        print(f" - Classes: {len(label_encoder.classes_)}")
    except Exception as exc:
        print(f"❌ Failed to load model: {exc}")
        raise
| # ============================================================================ | |
| # API Endpoints | |
| # ============================================================================ | |
async def root():
    """Liveness probe: always return the same static status payload.

    It does not check whether the model has been loaded.

    NOTE(review): no route decorator is visible in this copy of the file;
    confirm which path this handler is mounted on.
    """
    payload = dict(status="healthy", message="Symptom Checker API is running")
    return payload
async def get_available_symptoms():
    """List every symptom the loaded model recognizes, in display form.

    Each feature name has its underscores replaced by spaces and is then
    title-cased (``"skin_rash"`` -> ``"Skin Rash"``). If the model globals are
    not populated, or anything raises, a ``success=False`` response carrying an
    ``error`` message is returned instead of propagating the exception.

    NOTE(review): clients presumably echo these display strings back to the
    symptom-check endpoint; confirm build_feature_vector maps them back to the
    underscore feature names.
    """

    def failure(message):
        return AvailableSymptomsResponse(
            success=False,
            symptoms=[],
            total_symptoms=0,
            error=message,
        )

    try:
        if feature_names is None:
            return failure("Model not loaded")
        display_names = list(map(lambda raw: raw.replace("_", " ").title(), feature_names))
        return AvailableSymptomsResponse(
            success=True,
            symptoms=display_names,
            total_symptoms=len(display_names),
            error=None,
        )
    except Exception as exc:
        return failure(str(exc))
async def check_symptoms(request: SymptomsRequest):
    """Rank the most likely diseases for the symptoms in *request*.

    NOTE(review): no route decorator is visible in this copy of the file;
    confirm how this handler is registered on ``app``.

    Returns:
        SymptomCheckResponse. On success it holds up to five predictions,
        ordered by descending model probability. Each has a 1-based ``rank``,
        the decoded disease label, the probability rounded to 4 decimals and
        the same probability as a percentage string.

        On failure ``success`` is False, ``predictions`` is empty and ``error``
        gives the reason. Failures are: model not loaded, no symptoms, no
        recognized symptoms, or any exception while predicting. Errors are
        returned in the body, not raised.
    """
    max_predictions = 5  # how many ranked diseases the client receives

    def failure(message: str, echoed: List[str]) -> SymptomCheckResponse:
        return SymptomCheckResponse(
            success=False,
            predictions=[],
            input_symptoms=echoed,
            error=message,
        )

    try:
        if model is None or label_encoder is None or feature_names is None:
            return failure("Model not loaded", request.symptoms)
        if not request.symptoms:
            return failure("No symptoms provided", [])

        x = build_feature_vector(feature_names, request.symptoms)
        # build_feature_vector silently skips symptoms it does not recognize.
        # Predicting on an all-zero vector would still produce confident-looking
        # but meaningless results, so report it as an error instead.
        if not np.any(x):
            return failure(
                "None of the provided symptoms are recognized", request.symptoms
            )

        proba = model.predict_proba(x)[0]
        # Class indices with the highest probability, best first.
        top_indices = np.argsort(proba)[::-1][:max_predictions]
        # Decode all selected class indices in a single call.
        diseases = label_encoder.inverse_transform(top_indices)

        predictions = []
        for rank, (idx, disease) in enumerate(zip(top_indices, diseases), start=1):
            confidence = float(proba[idx])
            predictions.append(
                SymptomPrediction(
                    rank=rank,
                    disease=str(disease),
                    confidence=round(confidence, 4),
                    confidence_percent=f"{confidence * 100:.2f}%",
                )
            )

        return SymptomCheckResponse(
            success=True,
            predictions=predictions,
            input_symptoms=request.symptoms,
            error=None,
        )
    except Exception as e:
        return failure(str(e), request.symptoms)
| # ============================================================================ | |
| # Run with: uvicorn main:app --reload --host 0.0.0.0 --port 8000 | |
| # ============================================================================ | |
if __name__ == "__main__":
    import uvicorn

    # The port can be overridden with the PORT environment variable (a common
    # hosting convention). The default of 8000 matches the uvicorn command above.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")))