File size: 5,761 Bytes
9d27b5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
"""

FastAPI server for Symptom Checker ML model.

Provides endpoints compatible with Flutter mobile app.

"""

import os
from contextlib import asynccontextmanager
from typing import List, Optional

import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# Import from symptom_checker module
from symptom_checker import load_artifacts, build_feature_vector

# Model artifacts shared by the request handlers. They stay None until
# ``lifespan`` assigns them (via ``global``) during application startup.
model = label_encoder = feature_names = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model artifacts at startup and announce shutdown.

    FastAPI runs the code before ``yield`` when the application starts and
    the code after it when the application stops. The loaded objects are
    stored in the module-level ``model``, ``label_encoder`` and
    ``feature_names`` globals that the request handlers read.

    Raises:
        RuntimeError: if ``load_artifacts`` reports missing
            ``symptom_model`` files (FileNotFoundError). The original error
            is chained as ``__cause__`` so the startup traceback shows it.
    """
    global model, label_encoder, feature_names
    try:
        model, label_encoder, feature_names = load_artifacts("symptom_model")
    except FileNotFoundError as e:
        print(f"❌ Error loading model: {e}")
        raise RuntimeError(
            "Failed to load model artifacts. Ensure symptom_model.* files exist."
        ) from e

    print("✅ Model loaded successfully!")
    print(f"   - Features: {len(feature_names)}")
    print(f"   - Classes: {len(label_encoder.classes_)}")
    yield
    # Nothing needs explicit cleanup; just report that the server is stopping.
    print("👋 Shutting down API server...")


app = FastAPI(
    title="Symptom Checker API",
    description="AI-powered symptom checker using XGBoost",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS only affects browser clients (e.g. a Flutter web build); native mobile
# apps are not subject to it. In production set CORS_ALLOW_ORIGINS to a
# comma-separated list of allowed origins; the default allows any origin.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
_allow_any_origin = "*" in _cors_origins

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # None of the endpoints read cookies or auth headers. With a wildcard
    # origin plus credentials, Starlette echoes back whatever Origin the
    # browser sent, which lets every site make credentialed calls. Only
    # allow credentials when an explicit origin list is configured.
    allow_credentials=not _allow_any_origin,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ============== Pydantic Models ==============

class SymptomCheckRequest(BaseModel):
    """Request body for ``POST /api/check-symptoms``."""

    # Symptom names reported by the user; passed together with the model's
    # feature names to build_feature_vector.
    symptoms: List[str]


class SymptomPrediction(BaseModel):
    """One ranked disease prediction inside a SymptomCheckResponse."""

    # 1-based position in the ranking; 1 is the most probable disease.
    rank: int
    # Disease label decoded from the model's class index.
    disease: str
    # Probability for this disease as returned by the model's predict_proba.
    confidence: float
    # The same probability formatted as a percentage string, e.g. "87.50%".
    confidence_percent: str


class SymptomCheckResponse(BaseModel):
    """Response body for ``POST /api/check-symptoms``."""

    # False when the request could not be processed; see ``error``.
    success: bool
    # Ranked predictions, most probable first; empty on failure.
    predictions: List[SymptomPrediction]
    # The symptoms the client submitted.
    input_symptoms: List[str]
    # Human-readable failure reason; None on success.
    error: Optional[str] = None


class AvailableSymptomsResponse(BaseModel):
    """Response body for ``GET /api/symptoms``."""

    # False when the symptom list could not be produced; see ``error``.
    success: bool
    # Every symptom (feature) name the loaded model recognizes.
    symptoms: List[str]
    # Number of entries in ``symptoms``.
    total_symptoms: int
    # Human-readable failure reason; None on success.
    error: Optional[str] = None


# ============== API Endpoints ==============

@app.get("/")
async def root():
    """Health check endpoint."""
    return {
        "status": "online",
        "message": "Symptom Checker API is running",
        "endpoints": {
            "check_symptoms": "/api/check-symptoms",
            "available_symptoms": "/api/symptoms"
        }
    }


@app.get("/api/symptoms", response_model=AvailableSymptomsResponse)
async def get_available_symptoms():
    """Get list of all available symptoms the model recognizes."""
    try:
        if feature_names is None:
            raise HTTPException(status_code=503, detail="Model not loaded")
        
        return AvailableSymptomsResponse(
            success=True,
            symptoms=feature_names,
            total_symptoms=len(feature_names),
            error=None
        )
    except Exception as e:
        return AvailableSymptomsResponse(
            success=False,
            symptoms=[],
            total_symptoms=0,
            error=str(e)
        )


@app.post("/api/check-symptoms", response_model=SymptomCheckResponse)
async def check_symptoms(request: SymptomCheckRequest):
    """

    Check symptoms and return disease predictions.

    

    Request body:

    {

        "symptoms": ["fever", "cough", "headache"]

    }

    """
    try:
        if model is None or label_encoder is None or feature_names is None:
            raise HTTPException(status_code=503, detail="Model not loaded")
        
        symptoms = request.symptoms
        
        if not symptoms:
            return SymptomCheckResponse(
                success=False,
                predictions=[],
                input_symptoms=[],
                error="No symptoms provided"
            )
        
        # Build feature vector from symptoms
        x = build_feature_vector(feature_names, symptoms)
        
        # Get predictions
        proba = model.predict_proba(x)[0]
        
        # Get top predictions (all classes sorted by probability)
        top_indices = np.argsort(proba)[::-1]
        
        # Build predictions list (top 5 most likely)
        predictions = []
        for rank, idx in enumerate(top_indices[:5], start=1):
            disease_name = label_encoder.inverse_transform([idx])[0]
            confidence = float(proba[idx])
            predictions.append(SymptomPrediction(
                rank=rank,
                disease=disease_name,
                confidence=confidence,
                confidence_percent=f"{confidence * 100:.2f}%"
            ))
        
        return SymptomCheckResponse(
            success=True,
            predictions=predictions,
            input_symptoms=symptoms,
            error=None
        )
        
    except Exception as e:
        return SymptomCheckResponse(
            success=False,
            predictions=[],
            input_symptoms=request.symptoms if request.symptoms else [],
            error=str(e)
        )


# ============== Run Server ==============

if __name__ == "__main__":
    import uvicorn
    import os
    
    # Use PORT env variable for Hugging Face Spaces, default to 8000 for local dev
    port = int(os.environ.get("PORT", 8000))
    host = os.environ.get("HOST", "127.0.0.1")
    
    print("๐Ÿš€ Starting Symptom Checker API server...")
    print(f"๐Ÿ“ Access the API at: http://{host}:{port}")
    print(f"๐Ÿ“– API docs at: http://{host}:{port}/docs")
    uvicorn.run(app, host=host, port=port)