from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List
import json
import os
import logging

# Import the existing symptom checker logic
from api_symptom_checker import load_artifacts, predict_symptoms_json
import numpy as np


def _normalize_symptom(text: str) -> str:
    """Return a lookup key that ignores case, underscores and extra whitespace."""
    return " ".join(text.replace("_", " ").split()).casefold()


def safe_predict_symptoms_json(symptoms, model, label_encoder, feature_names):
    """Safe prediction that only uses diseases the label encoder knows about.

    Args:
        symptoms: Iterable of symptom names supplied by the caller. A name
            matches a feature when it equals the feature's display form
            (underscores replaced by spaces, title-cased) or, failing that,
            when it matches ignoring case, underscores and surrounding or
            repeated whitespace (so ``"fever"``, ``"High_Fever"`` and
            ``"High Fever"`` are all accepted).
        model: Object exposing ``predict_proba(X)``; it is called with a
            single row holding 1.0 for every matched feature and 0.0
            elsewhere.
        label_encoder: Object exposing ``classes_`` and
            ``inverse_transform``; used to map class indices to names.
        feature_names: Sequence of feature names in the column order the
            model expects.

    Returns:
        ``{"error": <message>}`` when ``symptoms`` is empty or none of them
        match a feature. Otherwise a dict with ``input_symptoms`` (the
        matched inputs, as supplied), ``primary_diagnosis``,
        ``top_predictions`` (up to three entries, best first) and
        ``model_confidence`` (``"high"`` above 0.7, ``"medium"`` above 0.4,
        else ``"low"``).
    """
    if not symptoms:
        return {"error": "No symptoms provided"}

    # Column index of each feature name (first occurrence wins), built once so
    # that matching a symptom is a dict lookup instead of a list scan.
    first_index = {}
    for position, name in enumerate(feature_names):
        first_index.setdefault(name, position)

    # Exact display-name lookup, plus a relaxed lookup used only as a
    # fallback so inputs that matched exactly keep resolving the same way.
    exact_lookup = {}
    relaxed_lookup = {}
    for name in feature_names:
        exact_lookup[name.replace("_", " ").title()] = first_index[name]
        relaxed_lookup.setdefault(_normalize_symptom(name), first_index[name])

    x = np.zeros(len(feature_names))
    matched_symptoms = []
    for symptom in symptoms:
        column = exact_lookup.get(symptom)
        if column is None and isinstance(symptom, str):
            column = relaxed_lookup.get(_normalize_symptom(symptom))
        if column is None:
            continue  # Unknown symptom: ignore it rather than fail the request
        x[column] = 1.0
        matched_symptoms.append(symptom)

    if not matched_symptoms:
        return {"error": "No valid symptoms found"}

    x = x.reshape(1, -1)

    # Get predictions - but only use classes the label encoder knows about
    proba = model.predict_proba(x)[0]

    # SAFETY: Only use the first len(label_encoder.classes_) predictions, so
    # every index passed to inverse_transform below is one the encoder knows.
    max_valid_class = len(label_encoder.classes_)
    valid_proba = proba[:max_valid_class]

    # Get top 3 from valid classes only (highest probability first)
    top3_idx = np.argsort(valid_proba)[-3:][::-1]

    predictions = []
    for rank, idx in enumerate(top3_idx, 1):
        disease_name = label_encoder.inverse_transform([idx])[0]
        confidence = float(valid_proba[idx])
        predictions.append({
            "rank": rank,
            "disease": disease_name,
            "confidence": confidence,
            "confidence_percent": round(confidence * 100, 2)
        })

    top_confidence = predictions[0]["confidence"]
    if top_confidence > 0.7:
        model_confidence = "high"
    elif top_confidence > 0.4:
        model_confidence = "medium"
    else:
        model_confidence = "low"

    return {
        "input_symptoms": matched_symptoms,
        "primary_diagnosis": predictions[0],
        "top_predictions": predictions,
        "model_confidence": model_confidence
    }


# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="Symptom Checker API",
    description="AI-powered symptom analysis service",
    version="1.0.0"
)

# Add CORS middleware
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; restrict allow_origins before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure this properly for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global variables for model artifacts (populated by startup_event)
model = None
label_encoder = None
feature_names = None


# Pydantic models for request/response
# (documented with comments rather than docstrings so the generated schema
# descriptions stay unchanged)

# Request body for POST /api/check-symptoms.
class SymptomRequest(BaseModel):
    symptoms: List[str]


# One ranked disease prediction.
class PredictionItem(BaseModel):
    rank: int
    disease: str
    confidence: float
    confidence_percent: float


# Structured prediction result (not referenced by any endpoint in this file).
class SymptomResponse(BaseModel):
    input_symptoms: List[str]
    primary_diagnosis: PredictionItem
    top_predictions: List[PredictionItem]
    model_confidence: str


# Response body for GET /api/symptoms.
class AvailableSymptomsResponse(BaseModel):
    success: bool = True
    symptoms: List[str]
    total_symptoms: int


@app.on_event("startup")
async def startup_event():
    """Load model artifacts on startup"""
    global model, label_encoder, feature_names
    try:
        logger.info("Loading symptom checker model artifacts...")
        model, label_encoder, feature_names = load_artifacts("symptom_model")
        logger.info("Model loaded successfully with %s features", len(feature_names))
    except Exception as e:
        logger.error("Failed to load model artifacts: %s", e)
        # Bare raise keeps the original traceback intact.
        raise


@app.get("/")
async def root():
    """Root endpoint"""
    return {
        "message": "Symptom Checker API",
        "version": "1.0.0",
        "endpoints": ["/health", "/api/symptoms", "/api/check-symptoms"]
    }


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    return {
        "status": "healthy",
        "service": "symptom-checker",
        "model_loaded": model is not None,
        # Explicit None check: truthiness is ambiguous for array-like containers.
        "features_count": len(feature_names) if feature_names is not None else 0
    }


@app.get("/api/symptoms", response_model=AvailableSymptomsResponse)
async def get_available_symptoms():
    """Get list of all available symptoms that the model can recognize"""
    if feature_names is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Convert from feature format ("high_fever") to readable format ("High Fever")
    clean_symptoms = [symptom.replace('_', ' ').title() for symptom in feature_names]

    return AvailableSymptomsResponse(
        success=True,
        symptoms=sorted(clean_symptoms),
        total_symptoms=len(clean_symptoms)
    )


@app.post("/api/check-symptoms")
async def check_symptoms(request: SymptomRequest):
    """Analyze symptoms and return disease predictions"""
    if model is None or label_encoder is None or feature_names is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    if not request.symptoms:
        raise HTTPException(status_code=400, detail="No symptoms provided")

    try:
        # Use the SAFE prediction logic that handles class mismatch
        result = safe_predict_symptoms_json(request.symptoms, model, label_encoder, feature_names)

        if "error" in result:
            raise HTTPException(status_code=400, detail=result["error"])

        # Return format that matches Flutter's SymptomCheckResponse expectations
        return {
            "success": True,
            "predictions": [
                {
                    "rank": pred["rank"],
                    "disease": pred["disease"],
                    "confidence": pred["confidence"],
                    "confidence_percent": f"{pred['confidence_percent']:.2f}%"
                }
                for pred in result["top_predictions"]
            ],
            "input_symptoms": request.symptoms,
            "primary_diagnosis": result["primary_diagnosis"]["disease"],
            "model_confidence": result["model_confidence"]
        }

    except HTTPException:
        # Deliberate client errors (e.g. the 400 above) must reach the caller
        # unchanged instead of being rewrapped as a 500 by the handler below.
        raise
    except Exception as e:
        logger.exception("Error during symptom prediction: %s", e)
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e


if __name__ == "__main__":
    import uvicorn

    # Defaults to port 7860 (used for Hugging Face Spaces); set the PORT
    # environment variable to choose another port, e.g. for local development.
    port = int(os.getenv("PORT", 7860))
    uvicorn.run("main:app", host="0.0.0.0", port=port, reload=False)