# Final-text-classify / api_server.py
# Uploaded by Tantawi ("Upload 14 files"), commit 9d27b5e (verified).
"""
FastAPI server for Symptom Checker ML model.
Provides endpoints compatible with Flutter mobile app.
"""
import os
from contextlib import asynccontextmanager
from typing import List, Optional

import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# Import from symptom_checker module
from symptom_checker import load_artifacts, build_feature_vector
# Model artifacts shared by the endpoints. ``lifespan`` fills them in at
# startup; until loading succeeds they remain None, which the endpoints test
# for before using them.
model = label_encoder = feature_names = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load model artifacts on startup and log on shutdown.

    Populates the module-level ``model``, ``label_encoder`` and
    ``feature_names`` from ``load_artifacts("symptom_model")`` before the
    application starts serving requests.

    Raises:
        RuntimeError: if the artifact files cannot be found. The underlying
            ``FileNotFoundError`` is chained as ``__cause__`` so the missing
            path stays visible in the startup traceback.
    """
    global model, label_encoder, feature_names
    try:
        model, label_encoder, feature_names = load_artifacts("symptom_model")
    except FileNotFoundError as e:
        print(f"❌ Error loading model: {e}")
        raise RuntimeError(
            "Failed to load model artifacts. Ensure symptom_model.* files exist."
        ) from e
    print("✅ Model loaded successfully!")
    print(f" - Features: {len(feature_names)}")
    print(f" - Classes: {len(label_encoder.classes_)}")
    yield
    # Everything after ``yield`` runs when the application shuts down.
    print("👋 Shutting down API server...")
app = FastAPI(
    title="Symptom Checker API",
    description="AI-powered symptom checker using XGBoost",
    version="1.0.0",
    lifespan=lifespan,
)

# Enable CORS for browser clients (e.g. a Flutter web build).
# CORS_ALLOW_ORIGINS is a comma-separated list of allowed origins; it
# defaults to "*" so the out-of-the-box behaviour stays permissive.
# In production, set it to your app's domain(s).
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # Credentials must not be combined with a wildcard origin: with "*" plus
    # credentials, Starlette echoes back any requesting origin, which would
    # let every site make credentialed requests. Credentials are therefore
    # enabled only when explicit origins are configured.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ============== Pydantic Models ==============
class SymptomCheckRequest(BaseModel):
    """Request body for ``POST /api/check-symptoms``."""

    # Symptom names supplied by the client, e.g. ["fever", "cough", "headache"].
    # An empty list passes validation here; the endpoint itself reports it as
    # an error.
    symptoms: List[str]
class SymptomPrediction(BaseModel):
    """One ranked disease prediction returned by the symptom-check endpoint."""

    # 1-based position in the ranking (1 = most likely).
    rank: int
    # Disease label decoded from the model's predicted class index.
    disease: str
    # Model probability for this disease, as a fraction (not a percentage).
    confidence: float
    # ``confidence`` rendered as a percentage string with two decimals, e.g. "87.50%".
    confidence_percent: str
class SymptomCheckResponse(BaseModel):
    """Response envelope for ``POST /api/check-symptoms``.

    On success ``success`` is True and ``error`` is None. On failure
    ``success`` is False, ``predictions`` is empty and ``error`` carries a
    human-readable message.
    """

    success: bool
    # Ranked predictions, most likely first.
    predictions: List[SymptomPrediction]
    # The symptoms the client sent, echoed back.
    input_symptoms: List[str]
    error: Optional[str] = None
class AvailableSymptomsResponse(BaseModel):
    """Response envelope for ``GET /api/symptoms``.

    On success ``symptoms`` lists the model's feature names. On failure
    ``success`` is False, ``symptoms`` is empty and ``error`` holds a message.
    """

    success: bool
    # Symptom (feature) names the model was trained on.
    symptoms: List[str]
    # Number of entries in ``symptoms``.
    total_symptoms: int
    error: Optional[str] = None
# ============== API Endpoints ==============
@app.get("/")
async def root():
    """Health check: confirm the service is up and advertise its main routes."""
    routes = {
        "check_symptoms": "/api/check-symptoms",
        "available_symptoms": "/api/symptoms",
    }
    return dict(
        status="online",
        message="Symptom Checker API is running",
        endpoints=routes,
    )
@app.get("/api/symptoms", response_model=AvailableSymptomsResponse)
async def get_available_symptoms():
    """Get list of all available symptoms the model recognizes.

    Raises:
        HTTPException: 503 if the model artifacts have not been loaded.

    Any other failure while building the response is reported in the
    envelope (``success=False`` with the exception text in ``error``).
    """
    # Checked before the try block so the blanket ``except Exception`` below
    # cannot swallow the 503 and turn it into a 200 response.
    if feature_names is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        return AvailableSymptomsResponse(
            success=True,
            symptoms=feature_names,
            total_symptoms=len(feature_names),
            error=None,
        )
    except Exception as e:
        return AvailableSymptomsResponse(
            success=False,
            symptoms=[],
            total_symptoms=0,
            error=str(e),
        )
@app.post("/api/check-symptoms", response_model=SymptomCheckResponse)
async def check_symptoms(request: SymptomCheckRequest):
    """
    Check symptoms and return disease predictions.

    Request body:
    {
        "symptoms": ["fever", "cough", "headache"]
    }

    Returns up to five predictions ranked by descending model probability.
    Two cases are reported in the response envelope (``success=False`` with
    ``error`` set) rather than as an HTTP error:

    - an empty symptom list;
    - any failure while building features or predicting.

    Raises:
        HTTPException: 503 if the model artifacts have not been loaded.
    """
    # Checked before the try block so the blanket ``except Exception`` below
    # cannot swallow the 503 and turn it into a 200 response.
    if model is None or label_encoder is None or feature_names is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    symptoms = request.symptoms
    if not symptoms:
        return SymptomCheckResponse(
            success=False,
            predictions=[],
            input_symptoms=[],
            error="No symptoms provided",
        )

    max_predictions = 5
    try:
        # NOTE(review): feature building and inference run synchronously
        # inside this ``async def`` and block the event loop while they run.
        x = build_feature_vector(feature_names, symptoms)
        # Class probabilities for the first (presumably only) row of x.
        proba = model.predict_proba(x)[0]
        # Class indices from most to least likely, truncated to the top N.
        top_indices = np.argsort(proba)[::-1][:max_predictions]
        # Decode all selected labels in a single call rather than per index.
        disease_names = label_encoder.inverse_transform(top_indices)

        predictions = []
        for rank, (idx, disease_name) in enumerate(
            zip(top_indices, disease_names), start=1
        ):
            confidence = float(proba[idx])
            predictions.append(
                SymptomPrediction(
                    rank=rank,
                    # str() turns numpy scalar labels into plain Python strings.
                    disease=str(disease_name),
                    confidence=confidence,
                    confidence_percent=f"{confidence * 100:.2f}%",
                )
            )

        return SymptomCheckResponse(
            success=True,
            predictions=predictions,
            input_symptoms=symptoms,
            error=None,
        )
    except Exception as e:
        return SymptomCheckResponse(
            success=False,
            predictions=[],
            input_symptoms=symptoms,
            error=str(e),
        )
# ============== Run Server ==============
if __name__ == "__main__":
    import os

    import uvicorn

    # Use PORT env variable for Hugging Face Spaces, default to 8000 for local dev
    port = int(os.environ.get("PORT", 8000))
    # Defaults to loopback only; set HOST (e.g. 0.0.0.0) to listen on other interfaces.
    host = os.environ.get("HOST", "127.0.0.1")
    print("🚀 Starting Symptom Checker API server...")
    print(f"📍 Access the API at: http://{host}:{port}")
    print(f"📖 API docs at: http://{host}:{port}/docs")
    uvicorn.run(app, host=host, port=port)