Tantawi committed on
Commit 9d27b5e · verified · 1 parent: 7fcdac5

Upload 14 files
.dockerignore ADDED
@@ -0,0 +1,21 @@
+ __pycache__/
+ *.pyc
+ *.pyo
+ *.pyd
+ .Python
+ venv/
+ env/
+ .env
+ .venv
+ *.egg-info/
+ dist/
+ build/
+ .git/
+ .gitignore
+ *.md
+ !README.md
+ .DS_Store
+ Thumbs.db
+ *.log
+ cleaned_dataset.csv
+ *.csv
Dockerfile ADDED
@@ -0,0 +1,40 @@
+ # Use Python 3.10 slim image
+ FROM python:3.10-slim
+
+ # Set working directory
+ WORKDIR /app
+
+ # Set environment variables
+ ENV PYTHONDONTWRITEBYTECODE=1
+ ENV PYTHONUNBUFFERED=1
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     build-essential \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy requirements first for better caching
+ COPY requirements.txt .
+
+ # Install Python dependencies
+ RUN pip install --no-cache-dir --upgrade pip && \
+     pip install --no-cache-dir -r requirements.txt
+
+ # Copy application code
+ COPY symptom_checker.py .
+ COPY api_server.py .
+
+ # Copy model artifacts
+ COPY symptom_model.json .
+ COPY symptom_model.labels.npy .
+ COPY symptom_model.features.txt .
+
+ # Expose port (Hugging Face Spaces uses port 7860)
+ EXPOSE 7860
+
+ # Create a non-root user for security
+ RUN useradd -m -u 1000 appuser
+ USER appuser
+
+ # Run the FastAPI server
+ CMD ["uvicorn", "api_server:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,10 +1,60 @@
- ---
- title: Final Text Classify
- emoji: 📊
- colorFrom: gray
- colorTo: red
- sdk: docker
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Symptom Checker (XGBoost)
+
+ This folder contains the standalone symptom checker based on XGBoost.
+
+ ## Contents
+ - `symptom_checker.py` — Train, save artifacts, evaluate, and interactive prediction
+ - `preprocess_data.py` — Clean the raw dataset into `cleaned_dataset.csv`
+ - `evaluate_symptom_checker.py` — Train/test split evaluation
+ - `requirements.txt` — Minimal dependencies
+ - (Optional) `cleaned_dataset.csv` — Cleaned dataset for evaluation
+ - Artifacts after saving:
+   - `symptom_model.json` — Trained model
+   - `symptom_model.labels.npy` — Label encoder classes
+   - `symptom_model.features.txt` — Feature order used by the model
+
+ ## Setup
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ## Preprocess (optional)
+ ```bash
+ python preprocess_data.py --input "Disease and symptoms dataset.csv" --output cleaned_dataset.csv
+ ```
+
+ ## Train and Save Artifacts (one-time)
+ ```bash
+ python symptom_checker.py --csv cleaned_dataset.csv --save-prefix symptom_model
+ ```
+ Creates:
+ - `symptom_model.json`
+ - `symptom_model.labels.npy`
+ - `symptom_model.features.txt`
+
+ ## Evaluate Saved Model (no retraining)
+ ```bash
+ python symptom_checker.py --eval-only --csv cleaned_dataset.csv --artifacts-prefix symptom_model
+ ```
+
+ ## Interactive Predictions (No Training Needed)
+ ```bash
+ python symptom_checker.py --interactive-only --artifacts-prefix symptom_model
+ ```
+ Enter symptoms separated by commas (e.g., `fever, cough, headache`). Type `list` to see features, or `quit` to exit.
+
+ ## API-Style Predictions (Multiple Output Formats)
+ ```bash
+ # JSON format (best for web apps/APIs)
+ python api_symptom_checker.py --symptoms fever cough headache --format json
+
+ # CSV format (best for data analysis)
+ python api_symptom_checker.py --symptoms fever cough headache --format csv
+
+ # Simple format (best for CLI use)
+ python api_symptom_checker.py --symptoms fever cough headache --format simple
+ ```
+
+ ## Notes
+ - GPU is used automatically if supported by your XGBoost build; otherwise CPU is used.
+ - Keep the three artifact files together for evaluation and interactive use.
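
The model consumes a binary bag-of-symptoms row whose column order comes from `symptom_model.features.txt`. A minimal self-contained sketch of that encoding, mirroring `build_feature_vector` in `api_symptom_checker.py` (the four feature names below are made up for illustration):

```python
import numpy as np

def build_feature_vector(feature_names, selected):
    # One column per known symptom; 1.0 if the user reported it.
    features = np.zeros(len(feature_names), dtype=float)
    index = {name.lower().strip(): i for i, name in enumerate(feature_names)}
    for symptom in selected:
        key = symptom.lower().strip()
        if key in index:  # unknown symptoms are silently ignored
            features[index[key]] = 1.0
    return features.reshape(1, -1)

features = ["fever", "cough", "headache", "nausea"]  # illustrative only
x = build_feature_vector(features, ["Fever", " cough ", "rash"])
print(x.tolist())  # [[1.0, 1.0, 0.0, 0.0]]; "rash" is not a known feature
```

Matching is case- and whitespace-insensitive, so `"Fever"` and `" cough "` both land on the right columns.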
README_HF.md ADDED
@@ -0,0 +1,53 @@
+ ---
+ title: Symptom Checker API
+ emoji: 🩺
+ colorFrom: blue
+ colorTo: green
+ sdk: docker
+ pinned: false
+ license: mit
+ app_port: 7860
+ ---
+
+ # Symptom Checker API
+
+ AI-powered symptom checker using an XGBoost machine learning model.
+
+ ## API Endpoints
+
+ ### Health Check
+ - **GET** `/` - Check if the API is running
+
+ ### Get Available Symptoms
+ - **GET** `/api/symptoms` - Returns the list of all symptoms the model recognizes
+
+ ### Check Symptoms
+ - **POST** `/api/check-symptoms` - Analyze symptoms and get disease predictions
+
+ #### Request Body:
+ ```json
+ {
+   "symptoms": ["fever", "cough", "headache"]
+ }
+ ```
+
+ #### Response:
+ ```json
+ {
+   "success": true,
+   "predictions": [
+     {
+       "rank": 1,
+       "disease": "Disease Name",
+       "confidence": 0.85,
+       "confidence_percent": "85.00%"
+     }
+   ],
+   "input_symptoms": ["fever", "cough", "headache"],
+   "error": null
+ }
+ ```
+
+ ## Documentation
+
+ Visit `/docs` for interactive Swagger documentation.
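
A client only needs to build the request JSON and unpack the response shape documented above. A hedged client-side sketch (`make_request_body` and `parse_predictions` are hypothetical helpers, not part of this repo; the sample response mirrors the schema shown earlier):

```python
import json

def make_request_body(symptoms):
    # Matches the /api/check-symptoms request schema documented above.
    return json.dumps({"symptoms": symptoms})

def parse_predictions(response_text):
    # Unpack a response into (rank, disease, confidence) tuples.
    data = json.loads(response_text)
    if not data.get("success"):
        raise ValueError(data.get("error") or "request failed")
    return [(p["rank"], p["disease"], p["confidence"]) for p in data["predictions"]]

body = make_request_body(["fever", "cough", "headache"])
sample = (
    '{"success": true, "predictions": [{"rank": 1, "disease": "Disease Name", '
    '"confidence": 0.85, "confidence_percent": "85.00%"}], '
    '"input_symptoms": ["fever", "cough", "headache"], "error": null}'
)
print(parse_predictions(sample))  # [(1, 'Disease Name', 0.85)]
```

POST `body` to `/api/check-symptoms` with any HTTP client and feed the response text to `parse_predictions`.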
api.py ADDED
File without changes
api_server.py ADDED
@@ -0,0 +1,194 @@
+ """
+ FastAPI server for Symptom Checker ML model.
+ Provides endpoints compatible with the Flutter mobile app.
+ """
+
+ from fastapi import FastAPI, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+ from typing import List, Optional
+ import numpy as np
+ from contextlib import asynccontextmanager
+
+ # Import from symptom_checker module
+ from symptom_checker import load_artifacts, build_feature_vector
+
+ # Global variables for model artifacts
+ model = None
+ label_encoder = None
+ feature_names = None
+
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     """Load model artifacts on startup."""
+     global model, label_encoder, feature_names
+     try:
+         model, label_encoder, feature_names = load_artifacts("symptom_model")
+         print("✅ Model loaded successfully!")
+         print(f"   - Features: {len(feature_names)}")
+         print(f"   - Classes: {len(label_encoder.classes_)}")
+     except FileNotFoundError as e:
+         print(f"❌ Error loading model: {e}")
+         raise RuntimeError("Failed to load model artifacts. Ensure symptom_model.* files exist.")
+     yield
+     # Cleanup (if needed)
+     print("👋 Shutting down API server...")
+
+
+ app = FastAPI(
+     title="Symptom Checker API",
+     description="AI-powered symptom checker using XGBoost",
+     version="1.0.0",
+     lifespan=lifespan
+ )
+
+ # Enable CORS for Flutter app
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # In production, specify your app's domain
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+
+ # ============== Pydantic Models ==============
+
+ class SymptomCheckRequest(BaseModel):
+     symptoms: List[str]
+
+
+ class SymptomPrediction(BaseModel):
+     rank: int
+     disease: str
+     confidence: float
+     confidence_percent: str
+
+
+ class SymptomCheckResponse(BaseModel):
+     success: bool
+     predictions: List[SymptomPrediction]
+     input_symptoms: List[str]
+     error: Optional[str] = None
+
+
+ class AvailableSymptomsResponse(BaseModel):
+     success: bool
+     symptoms: List[str]
+     total_symptoms: int
+     error: Optional[str] = None
+
+
+ # ============== API Endpoints ==============
+
+ @app.get("/")
+ async def root():
+     """Health check endpoint."""
+     return {
+         "status": "online",
+         "message": "Symptom Checker API is running",
+         "endpoints": {
+             "check_symptoms": "/api/check-symptoms",
+             "available_symptoms": "/api/symptoms"
+         }
+     }
+
+
+ @app.get("/api/symptoms", response_model=AvailableSymptomsResponse)
+ async def get_available_symptoms():
+     """Get list of all available symptoms the model recognizes."""
+     try:
+         if feature_names is None:
+             raise HTTPException(status_code=503, detail="Model not loaded")
+
+         return AvailableSymptomsResponse(
+             success=True,
+             symptoms=feature_names,
+             total_symptoms=len(feature_names),
+             error=None
+         )
+     except Exception as e:
+         return AvailableSymptomsResponse(
+             success=False,
+             symptoms=[],
+             total_symptoms=0,
+             error=str(e)
+         )
+
+
+ @app.post("/api/check-symptoms", response_model=SymptomCheckResponse)
+ async def check_symptoms(request: SymptomCheckRequest):
+     """
+     Check symptoms and return disease predictions.
+
+     Request body:
+     {
+         "symptoms": ["fever", "cough", "headache"]
+     }
+     """
+     try:
+         if model is None or label_encoder is None or feature_names is None:
+             raise HTTPException(status_code=503, detail="Model not loaded")
+
+         symptoms = request.symptoms
+
+         if not symptoms:
+             return SymptomCheckResponse(
+                 success=False,
+                 predictions=[],
+                 input_symptoms=[],
+                 error="No symptoms provided"
+             )
+
+         # Build feature vector from symptoms
+         x = build_feature_vector(feature_names, symptoms)
+
+         # Get predictions
+         proba = model.predict_proba(x)[0]
+
+         # All classes sorted by probability, descending
+         top_indices = np.argsort(proba)[::-1]
+
+         # Build predictions list (top 5 most likely)
+         predictions = []
+         for rank, idx in enumerate(top_indices[:5], start=1):
+             disease_name = label_encoder.inverse_transform([idx])[0]
+             confidence = float(proba[idx])
+             predictions.append(SymptomPrediction(
+                 rank=rank,
+                 disease=disease_name,
+                 confidence=confidence,
+                 confidence_percent=f"{confidence * 100:.2f}%"
+             ))
+
+         return SymptomCheckResponse(
+             success=True,
+             predictions=predictions,
+             input_symptoms=symptoms,
+             error=None
+         )
+
+     except Exception as e:
+         return SymptomCheckResponse(
+             success=False,
+             predictions=[],
+             input_symptoms=request.symptoms if request.symptoms else [],
+             error=str(e)
+         )
+
+
+ # ============== Run Server ==============
+
+ if __name__ == "__main__":
+     import uvicorn
+     import os
+
+     # Use PORT env variable for Hugging Face Spaces; default to 8000 for local dev
+     port = int(os.environ.get("PORT", 8000))
+     host = os.environ.get("HOST", "127.0.0.1")
+
+     print("🚀 Starting Symptom Checker API server...")
+     print(f"📍 Access the API at: http://{host}:{port}")
+     print(f"📖 API docs at: http://{host}:{port}/docs")
+     uvicorn.run(app, host=host, port=port)
api_symptom_checker.py ADDED
@@ -0,0 +1,142 @@
+ import argparse
+ import json
+ import os
+ from typing import List, Dict, Any
+
+ import numpy as np
+ import xgboost as xgb
+ from sklearn.preprocessing import LabelEncoder
+
+
+ def load_artifacts(prefix: str):
+     """Load the trained model artifacts."""
+     model_path = f"{prefix}.json"
+     labels_path = f"{prefix}.labels.npy"
+     features_path = f"{prefix}.features.txt"
+
+     if not (os.path.exists(model_path) and os.path.exists(labels_path) and os.path.exists(features_path)):
+         raise FileNotFoundError(f"Missing artifacts. Expected: {model_path}, {labels_path}, {features_path}")
+
+     model = xgb.XGBClassifier()
+     model.load_model(model_path)
+
+     label_encoder = LabelEncoder()
+     classes = np.load(labels_path, allow_pickle=True)
+     label_encoder.classes_ = classes
+
+     with open(features_path, "r", encoding="utf-8") as f:
+         feature_names = [line.strip() for line in f if line.strip()]
+
+     return model, label_encoder, feature_names
+
+
+ def build_feature_vector(symptom_names: List[str], selected_symptoms: List[str]) -> np.ndarray:
+     """Convert symptom list to feature vector."""
+     features = np.zeros(len(symptom_names), dtype=float)
+     name_to_index = {name.lower().strip(): idx for idx, name in enumerate(symptom_names)}
+
+     for symptom in selected_symptoms:
+         key = symptom.lower().strip()
+         if key in name_to_index:
+             features[name_to_index[key]] = 1.0
+
+     return features.reshape(1, -1)
+
+
+ def predict_symptoms_json(symptoms: List[str], model, label_encoder, feature_names: List[str]) -> Dict[str, Any]:
+     """Return predictions in JSON format for API integration."""
+     if not symptoms:
+         return {"error": "No symptoms provided"}
+
+     # Build feature vector
+     x = build_feature_vector(feature_names, symptoms)
+
+     # Get predictions
+     proba = model.predict_proba(x)[0]
+     top3_idx = np.argsort(proba)[-3:][::-1]
+
+     # Format results
+     predictions = []
+     for rank, idx in enumerate(top3_idx, 1):
+         disease_name = label_encoder.inverse_transform([idx])[0]
+         confidence = float(proba[idx])
+         predictions.append({
+             "rank": rank,
+             "disease": disease_name,
+             "confidence": confidence,
+             "confidence_percent": round(confidence * 100, 2)
+         })
+
+     return {
+         "input_symptoms": symptoms,
+         "primary_diagnosis": predictions[0],
+         "top_predictions": predictions,
+         "model_confidence": "high" if predictions[0]["confidence"] > 0.7 else "medium" if predictions[0]["confidence"] > 0.4 else "low"
+     }
+
+
+ def predict_symptoms_csv(symptoms: List[str], model, label_encoder, feature_names: List[str]) -> str:
+     """Return predictions in CSV format."""
+     if not symptoms:
+         return "error,No symptoms provided"
+
+     x = build_feature_vector(feature_names, symptoms)
+     proba = model.predict_proba(x)[0]
+     top3_idx = np.argsort(proba)[-3:][::-1]
+
+     csv_lines = ["rank,disease,confidence,confidence_percent"]
+     for rank, idx in enumerate(top3_idx, 1):
+         disease_name = label_encoder.inverse_transform([idx])[0]
+         confidence = proba[idx]
+         csv_lines.append(f"{rank},{disease_name},{confidence:.4f},{confidence*100:.2f}")
+
+     return "\n".join(csv_lines)
+
+
+ def predict_symptoms_simple(symptoms: List[str], model, label_encoder, feature_names: List[str]) -> str:
+     """Return simple text format."""
+     if not symptoms:
+         return "Error: No symptoms provided"
+
+     x = build_feature_vector(feature_names, symptoms)
+     proba = model.predict_proba(x)[0]
+     top1_idx = np.argmax(proba)
+
+     disease_name = label_encoder.inverse_transform([top1_idx])[0]
+     confidence = proba[top1_idx]
+
+     return f"Diagnosis: {disease_name} (Confidence: {confidence*100:.1f}%)"
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="API-style symptom checker using saved model")
+     parser.add_argument("--symptoms", nargs="+", required=True, help="List of symptoms")
+     parser.add_argument("--format", choices=["json", "csv", "simple"], default="json", help="Output format")
+     parser.add_argument("--artifacts-prefix", default="symptom_checker/symptom_model", help="Path to model artifacts")
+     args = parser.parse_args()
+
+     try:
+         # Load the trained model
+         model, label_encoder, feature_names = load_artifacts(args.artifacts_prefix)
+
+         # Get predictions in requested format
+         if args.format == "json":
+             result = predict_symptoms_json(args.symptoms, model, label_encoder, feature_names)
+             print(json.dumps(result, indent=2))
+         elif args.format == "csv":
+             result = predict_symptoms_csv(args.symptoms, model, label_encoder, feature_names)
+             print(result)
+         elif args.format == "simple":
+             result = predict_symptoms_simple(args.symptoms, model, label_encoder, feature_names)
+             print(result)
+
+     except Exception as e:
+         error_result = {"error": str(e), "input_symptoms": args.symptoms}
+         if args.format == "json":
+             print(json.dumps(error_result, indent=2))
+         else:
+             print(f"Error: {e}")
+
+
+ if __name__ == "__main__":
+     main()
evaluate_symptom_checker.py ADDED
@@ -0,0 +1,103 @@
+ import argparse
+ import os
+ from typing import Tuple
+
+ import numpy as np
+ import pandas as pd
+ import xgboost as xgb
+ from sklearn.model_selection import train_test_split
+ from sklearn.preprocessing import LabelEncoder
+ from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
+
+
+ def load_data(csv_path: str) -> pd.DataFrame:
+     if not os.path.exists(csv_path):
+         raise FileNotFoundError(f"CSV not found: {csv_path}")
+     df = pd.read_csv(csv_path)
+     if df.shape[1] < 2:
+         raise ValueError("CSV must have at least 2 columns (target + features)")
+     return df
+
+
+ def split_encode(df: pd.DataFrame, test_size: float, seed: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, LabelEncoder, list]:
+     target = df.columns[0]
+     X = df.iloc[:, 1:]
+     y = df[target]
+
+     X_train, X_test, y_train, y_test = train_test_split(
+         X, y, test_size=test_size, random_state=seed, stratify=y
+     )
+
+     label_encoder = LabelEncoder()
+     y_train_enc = label_encoder.fit_transform(y_train)
+     y_test_enc = label_encoder.transform(y_test)
+
+     return X_train.values, X_test.values, y_train_enc, y_test_enc, label_encoder, X.columns.tolist()
+
+
+ def build_model(num_classes: int):
+     common_kwargs = dict(
+         objective="multi:softprob",
+         num_class=num_classes,
+         eval_metric="mlogloss",
+         tree_method="hist",
+         n_estimators=300,
+         max_depth=6,
+         learning_rate=0.05,
+         subsample=0.8,
+         colsample_bytree=0.8,
+         random_state=42,
+     )
+     # Prefer GPU when the installed XGBoost supports it; fall back to CPU.
+     try:
+         model = xgb.XGBClassifier(device="cuda", **common_kwargs)  # XGBoost >= 2.0
+     except TypeError:
+         try:
+             model = xgb.XGBClassifier(tree_method="gpu_hist", **{k: v for k, v in common_kwargs.items() if k != "tree_method"})  # older XGBoost
+         except Exception:
+             model = xgb.XGBClassifier(**common_kwargs)
+     return model
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Evaluate XGBoost Symptom Checker accuracy")
+     parser.add_argument("--csv", required=True, help="Path to cleaned CSV (target + binary features)")
+     parser.add_argument("--test-size", type=float, default=0.2, help="Test set fraction (default 0.2)")
+     parser.add_argument("--seed", type=int, default=42, help="Random seed (default 42)")
+     args = parser.parse_args()
+
+     print("Loading data...")
+     df = load_data(args.csv)
+     print(f"Shape: {df.shape}")
+
+     print("Splitting and encoding labels...")
+     X_train, X_test, y_train, y_test, label_enc, feature_names = split_encode(df, args.test_size, args.seed)
+     num_classes = len(np.unique(y_train))
+     print(f"Classes: {num_classes}; Features: {len(feature_names)}")
+
+     print("Training model...")
+     model = build_model(num_classes)
+     try:
+         model.fit(X_train, y_train, eval_set=[(X_test, y_test)], verbose=50, early_stopping_rounds=30)
+     except TypeError:
+         # Newer XGBoost versions moved early_stopping_rounds to the constructor
+         model.fit(X_train, y_train, eval_set=[(X_test, y_test)], verbose=50)
+
+     print("Evaluating...")
+     y_proba = model.predict_proba(X_test)
+     y_pred = np.argmax(y_proba, axis=1)
+
+     acc = accuracy_score(y_test, y_pred)
+     print(f"\nAccuracy: {acc:.4f} ({acc*100:.2f}%)")
+
+     print("\nClassification report:")
+     target_names = label_enc.inverse_transform(np.arange(num_classes))
+     print(classification_report(y_test, y_pred, target_names=target_names, zero_division=0))
+
+     print("Confusion matrix (rows=true, cols=pred):")
+     cm = confusion_matrix(y_test, y_pred)
+     print(cm)
+
+
+ if __name__ == "__main__":
+     main()
main.py ADDED
@@ -0,0 +1,257 @@
+ """
+ FastAPI Backend for Symptom Checker
+ Provides REST API endpoints for the Flutter mobile application.
+ """
+
+ from fastapi import FastAPI, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+ from typing import List, Optional
+ import numpy as np
+ import xgboost as xgb
+ from sklearn.preprocessing import LabelEncoder
+ import os
+
+ # ============================================================================
+ # Pydantic Models (matching Flutter frontend expectations)
+ # ============================================================================
+
+ class SymptomsRequest(BaseModel):
+     symptoms: List[str]
+
+
+ class SymptomPrediction(BaseModel):
+     rank: int
+     disease: str
+     confidence: float
+     confidence_percent: str
+
+
+ class SymptomCheckResponse(BaseModel):
+     success: bool
+     predictions: List[SymptomPrediction]
+     input_symptoms: List[str]
+     error: Optional[str]
+
+
+ class AvailableSymptomsResponse(BaseModel):
+     success: bool
+     symptoms: List[str]
+     total_symptoms: int
+     error: Optional[str]
+
+
+ # ============================================================================
+ # Model Loading (same as symptom_checker.py)
+ # ============================================================================
+
+ class LoadedModel:
+     """Wrapper for loaded XGBoost model that provides predict_proba functionality."""
+     def __init__(self, booster: xgb.Booster, n_classes: int, feature_names: List[str] = None):
+         self.booster = booster
+         self.n_classes = n_classes
+         self.feature_names = feature_names
+
+     def predict_proba(self, X: np.ndarray) -> np.ndarray:
+         """Return probability predictions using the booster."""
+         dmatrix = xgb.DMatrix(X, feature_names=self.feature_names)
+         preds = self.booster.predict(dmatrix)
+         if len(preds.shape) == 1:
+             # Binary case: booster returns P(class 1) only
+             return np.column_stack([1 - preds, preds])
+         return preds
+
+
+ def load_artifacts(prefix: str):
+     """Load model artifacts from files."""
+     model_path = f"{prefix}.json"
+     labels_path = f"{prefix}.labels.npy"
+     features_path = f"{prefix}.features.txt"
+
+     if not (os.path.exists(model_path) and os.path.exists(labels_path) and os.path.exists(features_path)):
+         raise FileNotFoundError(
+             f"Missing artifacts. Expected: '{model_path}', '{labels_path}', '{features_path}'."
+         )
+
+     # Load label encoder classes
+     label_encoder = LabelEncoder()
+     classes = np.load(labels_path, allow_pickle=True)
+     label_encoder.classes_ = classes
+     n_classes = len(classes)
+
+     # Load feature names
+     with open(features_path, "r", encoding="utf-8") as f:
+         feature_names = [line.strip() for line in f if line.strip()]
+
+     # Load model using Booster
+     booster = xgb.Booster()
+     booster.load_model(model_path)
+
+     model = LoadedModel(booster, n_classes, feature_names)
+
+     return model, label_encoder, feature_names
+
+
+ def build_feature_vector(symptom_names: List[str], selected: List[str]) -> np.ndarray:
+     """Build a binary feature vector from selected symptoms."""
+     features = np.zeros(len(symptom_names), dtype=float)
+     name_to_index = {name.lower().strip(): idx for idx, name in enumerate(symptom_names)}
+     for s in selected:
+         key = s.lower().strip()
+         if key in name_to_index:
+             features[name_to_index[key]] = 1.0
+     return features.reshape(1, -1)
+
+
+ # ============================================================================
+ # FastAPI App Setup
+ # ============================================================================
+
+ app = FastAPI(
+     title="Symptom Checker API",
+     description="AI-powered symptom checker using XGBoost",
+     version="1.0.0"
+ )
+
+ # Enable CORS for Flutter app
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # In production, specify your app's domain
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Global variables for model (loaded on startup)
+ model = None
+ label_encoder = None
+ feature_names = None
+
+
+ @app.on_event("startup")
+ async def startup_event():
+     """Load model artifacts on startup."""
+     global model, label_encoder, feature_names
+
+     # Get the directory where this script is located
+     script_dir = os.path.dirname(os.path.abspath(__file__))
+     artifacts_prefix = os.path.join(script_dir, "symptom_model")
+
+     try:
+         model, label_encoder, feature_names = load_artifacts(artifacts_prefix)
+         print("✅ Model loaded successfully!")
+         print(f"   - Features: {len(feature_names)}")
+         print(f"   - Classes: {len(label_encoder.classes_)}")
+     except Exception as e:
+         print(f"❌ Failed to load model: {e}")
+         raise
+
+
+ # ============================================================================
+ # API Endpoints
+ # ============================================================================
+
+ @app.get("/")
+ async def root():
+     """Health check endpoint."""
+     return {"status": "healthy", "message": "Symptom Checker API is running"}
+
+
+ @app.get("/api/symptoms", response_model=AvailableSymptomsResponse)
+ async def get_available_symptoms():
+     """
+     Get list of all available symptoms the model recognizes.
+     """
+     try:
+         if feature_names is None:
+             return AvailableSymptomsResponse(
+                 success=False,
+                 symptoms=[],
+                 total_symptoms=0,
+                 error="Model not loaded"
+             )
+
+         # Return symptoms with proper capitalization
+         formatted_symptoms = [s.replace("_", " ").title() for s in feature_names]
+
+         return AvailableSymptomsResponse(
+             success=True,
+             symptoms=formatted_symptoms,
+             total_symptoms=len(formatted_symptoms),
+             error=None
+         )
+     except Exception as e:
+         return AvailableSymptomsResponse(
+             success=False,
+             symptoms=[],
+             total_symptoms=0,
+             error=str(e)
+         )
+
+
+ @app.post("/api/check-symptoms", response_model=SymptomCheckResponse)
+ async def check_symptoms(request: SymptomsRequest):
+     """
+     Check symptoms and return disease predictions.
+     """
+     try:
+         if model is None or label_encoder is None or feature_names is None:
+             return SymptomCheckResponse(
+                 success=False,
+                 predictions=[],
+                 input_symptoms=request.symptoms,
+                 error="Model not loaded"
+             )
+
+         if not request.symptoms:
+             return SymptomCheckResponse(
+                 success=False,
+                 predictions=[],
+                 input_symptoms=[],
+                 error="No symptoms provided"
+             )
+
+         # Build feature vector
+         x = build_feature_vector(feature_names, request.symptoms)
+
+         # Get predictions
+         proba = model.predict_proba(x)[0]
+
+         # All classes sorted by probability, descending
+         top_indices = np.argsort(proba)[::-1]
+
+         # Build predictions list (top 5 by default)
+         predictions = []
+         for rank, idx in enumerate(top_indices[:5], start=1):
+             disease_name = label_encoder.inverse_transform([idx])[0]
+             confidence = float(proba[idx])
+
+             predictions.append(SymptomPrediction(
+                 rank=rank,
+                 disease=str(disease_name),
+                 confidence=round(confidence, 4),
+                 confidence_percent=f"{confidence * 100:.2f}%"
+             ))
+
+         return SymptomCheckResponse(
+             success=True,
+             predictions=predictions,
+             input_symptoms=request.symptoms,
+             error=None
+         )
+
+     except Exception as e:
+         return SymptomCheckResponse(
+             success=False,
+             predictions=[],
+             input_symptoms=request.symptoms,
+             error=str(e)
+         )
+
+
+ # ============================================================================
+ # Run with: uvicorn main:app --reload --host 0.0.0.0 --port 8000
+ # ============================================================================
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=8000)
preprocess_data.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import argparse
+ import os
+ import sys
+ import pandas as pd
+ import numpy as np
+
+
+ def standardize_columns(df: pd.DataFrame) -> pd.DataFrame:
+     """Lowercase, strip, and replace spaces with underscores in column names."""
+     df = df.copy()
+     df.columns = [c.strip().lower().replace(" ", "_") for c in df.columns]
+     return df
+
+
+ def drop_invalid_rows(df: pd.DataFrame) -> pd.DataFrame:
+     """Drop rows with missing target (first column) and fully empty feature rows."""
+     df = df.copy()
+     target_col = df.columns[0]
+     df = df[~df[target_col].isna()]
+     feature_df = df.iloc[:, 1:]
+     non_empty_mask = ~(feature_df.isna().all(axis=1) | (feature_df.sum(axis=1) == 0))
+     df = df.loc[non_empty_mask]
+     return df
+
+
+ def remove_constant_and_sparse_features(df: pd.DataFrame, min_positive_frac: float = 0.0005):
+     """Remove columns that are constant or extremely sparse (near-zero variance)."""
+     target = df.columns[0]
+     X = df.iloc[:, 1:]
+     keep_cols = []
+     for col in X.columns:
+         series = X[col]
+         if series.nunique(dropna=True) <= 1:
+             continue
+         # If binary-like, compute positive ratio
+         try:
+             pos_frac = (series.fillna(0) > 0).mean()
+         except Exception:
+             pos_frac = 1.0
+         if pos_frac < min_positive_frac:
+             continue
+         keep_cols.append(col)
+     cleaned = pd.concat([df[[target]], X[keep_cols]], axis=1)
+     return cleaned
+
+
+ def impute_missing(df: pd.DataFrame) -> pd.DataFrame:
+     """Impute missing values in features with 0, keep target as is."""
+     target = df.columns[0]
+     X = df.iloc[:, 1:].fillna(0)
+     return pd.concat([df[[target]], X], axis=1)
+
+
+ def limit_classes(df: pd.DataFrame, min_samples: int = 5) -> pd.DataFrame:
+     """Keep only classes with at least min_samples samples."""
+     target = df.columns[0]
+     counts = df[target].value_counts()
+     keep = counts[counts >= min_samples].index
+     return df[df[target].isin(keep)]
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Preprocess disease-symptom CSV for training.")
+     parser.add_argument("--input", required=True, help="Path to raw CSV")
+     parser.add_argument("--output", default="cleaned_dataset.csv", help="Path to save cleaned CSV")
+     args = parser.parse_args()
+
+     if not os.path.exists(args.input):
+         print(f"❌ Input CSV not found: {args.input}")
+         sys.exit(1)
+
+     print("Loading CSV...")
+     df = pd.read_csv(args.input)
+     print(f"Raw shape: {df.shape}")
+
+     print("Standardizing column names...")
+     df = standardize_columns(df)
+
+     print("Dropping invalid/empty rows...")
+     df = drop_invalid_rows(df)
+     print(f"After row cleanup: {df.shape}")
+
+     print("Removing constant and sparse features...")
+     df = remove_constant_and_sparse_features(df)
+     print(f"After feature cleanup: {df.shape}")
+
+     print("Imputing missing values (0 for symptoms)...")
+     df = impute_missing(df)
+
+     print("Limiting classes with very few samples...")
+     df = limit_classes(df, min_samples=5)
+     print(f"After class filtering: {df.shape}")
+
+     print(f"Saving cleaned CSV to: {args.output}")
+     df.to_csv(args.output, index=False)
+     print("Done.")
+
+
+ if __name__ == "__main__":
+     main()
+
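The row and column cleanup rules above can be exercised on a toy frame; a minimal sketch (the disease and symptom names here are made up, and the core logic of `standardize_columns` and `drop_invalid_rows` is restated inline):

```python
import pandas as pd

# Toy frame mimicking the expected layout: target first, symptom flags after.
df = pd.DataFrame({
    "Disease ": ["flu", "flu", None, "cold"],
    "High Fever": [1, 0, 1, 0],
    "Dry Cough": [1, 0, 0, 0],
})

# standardize_columns: lowercase, strip, spaces -> underscores
df.columns = [c.strip().lower().replace(" ", "_") for c in df.columns]
print(df.columns.tolist())  # ['disease', 'high_fever', 'dry_cough']

# drop_invalid_rows: missing target rows and all-zero/all-NaN feature rows go
df = df[~df["disease"].isna()]
features = df.iloc[:, 1:]
df = df[~(features.isna().all(axis=1) | (features.sum(axis=1) == 0))]
print(len(df))  # 1 — only the ("flu", 1, 1) row survives
```

Note that a patient row with no symptoms set is treated as empty and dropped, not kept as a valid all-negative sample.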
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ xgboost
+ pandas
+ scikit-learn
+ numpy
+ fastapi
+ uvicorn[standard]
symptom_checker.py ADDED
@@ -0,0 +1,304 @@
+ import argparse
+ import os
+ from typing import List, Tuple
+
+ import numpy as np
+ import pandas as pd
+ import xgboost as xgb
+ from sklearn.model_selection import train_test_split
+ from sklearn.preprocessing import LabelEncoder
+
+
+ def load_dataset(csv_path: str) -> pd.DataFrame:
+     if not os.path.exists(csv_path):
+         raise FileNotFoundError(
+             f"CSV not found at '{csv_path}'. Provide a valid path with --csv <path>."
+         )
+     data = pd.read_csv(csv_path)
+     if data.shape[1] < 2:
+         raise ValueError("Dataset must have at least 2 columns: target then feature columns.")
+     return data
+
+
+ def train_model(data: pd.DataFrame):
+     y = data.iloc[:, 0]
+
+     # Remove diseases with only 1 record so the stratified split can work
+     value_counts = y.value_counts()
+     rare_diseases = value_counts[value_counts < 2].index
+     data_filtered = data[~data.iloc[:, 0].isin(rare_diseases)]
+
+     X = data_filtered.iloc[:, 1:]
+     y = data_filtered.iloc[:, 0]
+
+     X_train, X_test, y_train, y_test = train_test_split(
+         X, y, test_size=0.2, random_state=42, stratify=y
+     )
+
+     label_encoder = LabelEncoder()
+     y_train_encoded = label_encoder.fit_transform(y_train)
+     y_test_encoded = label_encoder.transform(y_test)
+
+     # Prefer GPU if available, but fall back to CPU if not supported
+     common_kwargs = dict(
+         objective="multi:softprob",
+         num_class=len(np.unique(y_train_encoded)),
+         eval_metric="mlogloss",
+         tree_method="hist",
+         n_estimators=400,
+         max_depth=6,
+         learning_rate=0.05,
+         subsample=0.8,
+         colsample_bytree=0.8,
+         random_state=42,
+     )
+
+     try:
+         model = xgb.XGBClassifier(device="cuda", **common_kwargs)
+     except TypeError:
+         # Older xgboost: no 'device' param. Try GPU via tree_method if supported, else CPU.
+         try:
+             model = xgb.XGBClassifier(tree_method="gpu_hist", **{k: v for k, v in common_kwargs.items() if k != "tree_method"})
+         except Exception:
+             model = xgb.XGBClassifier(**common_kwargs)
+
+     try:
+         model.fit(
+             X_train,
+             y_train_encoded,
+             eval_set=[(X_test, y_test_encoded)],
+             verbose=50,
+             early_stopping_rounds=50,
+         )
+     except TypeError:
+         # Some xgboost versions (notably >=2.0) do not accept
+         # early_stopping_rounds in the sklearn fit() API
+         model.fit(
+             X_train,
+             y_train_encoded,
+             eval_set=[(X_test, y_test_encoded)],
+             verbose=50,
+         )
+
+     return model, label_encoder, X.columns.tolist()
+
+
+ def save_artifacts(model: xgb.XGBClassifier, label_encoder: LabelEncoder, feature_names: List[str], prefix: str) -> Tuple[str, str, str]:
+     os.makedirs(os.path.dirname(prefix) or ".", exist_ok=True)
+     model_path = f"{prefix}.json"
+     labels_path = f"{prefix}.labels.npy"
+     features_path = f"{prefix}.features.txt"
+
+     try:
+         model.save_model(model_path)
+     except Exception:
+         model.get_booster().save_model(model_path)
+
+     # Save label encoder classes with allow_pickle=True since they contain strings
+     np.save(labels_path, label_encoder.classes_, allow_pickle=True)
+
+     with open(features_path, "w", encoding="utf-8") as f:
+         for name in feature_names:
+             f.write(f"{name}\n")
+
+     return model_path, labels_path, features_path
+
+
+ class LoadedModel:
+     """Wrapper for a loaded XGBoost booster that provides predict_proba functionality."""
+     def __init__(self, booster: xgb.Booster, n_classes: int, feature_names: List[str] = None):
+         self.booster = booster
+         self.n_classes = n_classes
+         self.feature_names = feature_names
+
+     def predict_proba(self, X: np.ndarray) -> np.ndarray:
+         """Return probability predictions using the booster."""
+         dmatrix = xgb.DMatrix(X, feature_names=self.feature_names)
+         preds = self.booster.predict(dmatrix)
+         # For multi-class, preds is already (n_samples, n_classes)
+         if len(preds.shape) == 1:
+             # Binary classification case
+             return np.column_stack([1 - preds, preds])
+         return preds
+
+     def predict(self, X: np.ndarray) -> np.ndarray:
+         """Return class predictions."""
+         proba = self.predict_proba(X)
+         return np.argmax(proba, axis=1)
+
+
+ def load_artifacts(prefix: str) -> Tuple[LoadedModel, LabelEncoder, List[str]]:
+     model_path = f"{prefix}.json"
+     labels_path = f"{prefix}.labels.npy"
+     features_path = f"{prefix}.features.txt"
+
+     if not (os.path.exists(model_path) and os.path.exists(labels_path) and os.path.exists(features_path)):
+         raise FileNotFoundError(
+             f"Missing artifacts. Expected: '{model_path}', '{labels_path}', '{features_path}'."
+         )
+
+     # Load label encoder classes
+     label_encoder = LabelEncoder()
+     classes = np.load(labels_path, allow_pickle=True)
+     label_encoder.classes_ = classes
+     n_classes = len(classes)
+
+     # Load feature names first (needed for the model's DMatrix)
+     with open(features_path, "r", encoding="utf-8") as f:
+         feature_names = [line.strip() for line in f if line.strip()]
+
+     # Load model using Booster directly
+     booster = xgb.Booster()
+     booster.load_model(model_path)
+
+     # Wrap in our custom class that provides predict_proba (with feature names)
+     model = LoadedModel(booster, n_classes, feature_names)
+
+     return model, label_encoder, feature_names
+
+
+ def build_feature_vector(symptom_names: List[str], selected: List[str]) -> np.ndarray:
+     features = np.zeros(len(symptom_names), dtype=float)
+     name_to_index = {name.lower().strip(): idx for idx, name in enumerate(symptom_names)}
+     for s in selected:
+         key = s.lower().strip()
+         if key in name_to_index:
+             features[name_to_index[key]] = 1.0
+     return features.reshape(1, -1)
+
+
+ def interactive_loop(model, label_encoder, symptom_names: List[str]):
+     print("\n" + "=" * 60)
+     print("🩺 Symptom Checker (XGBoost)")
+     print("=" * 60)
+     print("Enter symptoms separated by commas. Example: fever, cough, headache")
+     print("Type 'list' to see all available symptoms, or 'quit' to exit.")
+     print("=" * 60)
+
+     while True:
+         try:
+             user = input("\n💬 Symptoms: ").strip()
+             if user.lower() in {"quit", "exit", "q"}:
+                 print("👋 Goodbye!")
+                 break
+             if user == "":
+                 continue
+             if user.lower() == "list":
+                 print("\nAvailable symptoms (features):")
+                 print(", ".join(symptom_names))
+                 continue
+
+             selected = [s for s in user.split(",") if s.strip()]
+             if not selected:
+                 print("⚠️ Please enter at least one symptom.")
+                 continue
+
+             x = build_feature_vector(symptom_names, selected)
+             proba = model.predict_proba(x)[0]
+             top3_idx = np.argsort(proba)[-3:][::-1]
+             top1 = top3_idx[0]
+
+             top1_label = label_encoder.inverse_transform([top1])[0]
+             top1_conf = proba[top1]
+
+             print("\n📊 Prediction Results")
+             print("-" * 60)
+             print(f"🏥 Primary Diagnosis: {top1_label}")
+             print(f"📈 Confidence: {top1_conf:.4f} ({top1_conf*100:.2f}%)")
+             print("\n🏆 Top 3 Possible Conditions:")
+             for rank, idx in enumerate(top3_idx, start=1):
+                 label = label_encoder.inverse_transform([idx])[0]
+                 print(f"  {rank}. {label}: {proba[idx]:.4f} ({proba[idx]*100:.2f}%)")
+
+         except KeyboardInterrupt:
+             print("\n👋 Interrupted. Goodbye!")
+             break
+         except Exception as e:
+             print(f"❌ Error: {e}")
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Symptom checker using an XGBoost classifier.")
+     parser.add_argument(
+         "--csv",
+         type=str,
+         required=False,
+         help="Path to CSV dataset. First column must be target (disease), remaining columns symptoms.",
+     )
+     parser.add_argument(
+         "--save-prefix",
+         type=str,
+         default=None,
+         help="Prefix to save artifacts (creates .json/.labels.npy/.features.txt)",
+     )
+     parser.add_argument(
+         "--eval-only",
+         action="store_true",
+         help="Evaluate previously saved artifacts on --csv and exit (no training).",
+     )
+     parser.add_argument(
+         "--artifacts-prefix",
+         type=str,
+         default="symptom_checker/symptom_model",
+         help="Prefix path to load artifacts (default: symptom_checker/symptom_model)",
+     )
+     parser.add_argument(
+         "--interactive-only",
+         action="store_true",
+         help="Start interactive mode using saved artifacts only (no training).",
+     )
+     args = parser.parse_args()
+
+     if args.interactive_only:
+         try:
+             model, label_encoder, feature_names = load_artifacts(args.artifacts_prefix)
+         except FileNotFoundError as e:
+             print(str(e))
+             print("Train and save first, e.g.:\n  python symptom_checker/symptom_checker.py --csv cleaned_dataset.csv --save-prefix symptom_checker/symptom_model")
+             return
+         interactive_loop(model, label_encoder, feature_names)
+         return
+
+     if args.eval_only:
+         if not args.csv:
+             print("Provide CSV for evaluation. Example:\n  python symptom_checker/symptom_checker.py --eval-only --csv cleaned_dataset.csv --artifacts-prefix symptom_checker/symptom_model")
+             return
+         data = load_dataset(args.csv)
+         try:
+             model, label_encoder, feature_names = load_artifacts(args.artifacts_prefix)
+         except FileNotFoundError as e:
+             print(str(e))
+             return
+         target_col = data.columns[0]
+         missing = [c for c in feature_names if c not in data.columns]
+         if missing:
+             print(f"CSV missing {len(missing)} feature columns from training. Example missing: {missing[:10]}")
+             return
+         X = data.loc[:, feature_names].fillna(0).values
+         y = data[target_col].values
+         y_enc = label_encoder.transform(y)
+         proba = model.predict_proba(X)
+         y_pred = np.argmax(proba, axis=1)
+         acc = (y_pred == y_enc).mean()
+         print(f"Accuracy on provided CSV: {acc:.4f} ({acc*100:.2f}%)")
+         return
+
+     if not args.csv:
+         print("❗ No CSV provided. Run: python symptom_checker/symptom_checker.py --csv path/to/dataset.csv")
+         return
+
+     data = load_dataset(args.csv)
+     print("Shape of dataset:", data.shape)
+     model, label_encoder, symptom_names = train_model(data)
+
+     if args.save_prefix:
+         print("Saving artifacts...")
+         paths = save_artifacts(model, label_encoder, symptom_names, args.save_prefix)
+         for p in paths:
+             print(f"  - {p}")
+
+     interactive_loop(model, label_encoder, symptom_names)
+
+
+ if __name__ == "__main__":
+     main()
+
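The interactive loop maps free-text symptom names onto the training feature order via `build_feature_vector`; that encoding step can be exercised in isolation (the three symptom names below are an illustrative subset, and the function is restated so the sketch is self-contained):

```python
import numpy as np

def build_feature_vector(symptom_names, selected):
    # One-hot vector over the training feature order; unknown names are ignored.
    features = np.zeros(len(symptom_names), dtype=float)
    name_to_index = {name.lower().strip(): i for i, name in enumerate(symptom_names)}
    for s in selected:
        key = s.lower().strip()
        if key in name_to_index:
            features[name_to_index[key]] = 1.0
    return features.reshape(1, -1)

symptoms = ["fever", "cough", "headache"]
# Case and surrounding whitespace are normalized; unrecognized names drop out silently.
x = build_feature_vector(symptoms, ["  Fever ", "cough", "unknown_symptom"])
print(x)  # [[1. 1. 0.]]
```

Because unknown names are dropped silently, a user who misspells every symptom gets a prediction from an all-zero vector; callers that need stricter behavior would have to report unmatched names themselves.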
symptom_model.features.txt ADDED
@@ -0,0 +1,297 @@
+ anxiety_and_nervousness
+ depression
+ shortness_of_breath
+ depressive_or_psychotic_symptoms
+ sharp_chest_pain
+ dizziness
+ insomnia
+ abnormal_involuntary_movements
+ chest_tightness
+ palpitations
+ irregular_heartbeat
+ breathing_fast
+ hoarse_voice
+ sore_throat
+ difficulty_speaking
+ cough
+ nasal_congestion
+ throat_swelling
+ diminished_hearing
+ lump_in_throat
+ throat_feels_tight
+ difficulty_in_swallowing
+ skin_swelling
+ retention_of_urine
+ groin_mass
+ leg_pain
+ hip_pain
+ suprapubic_pain
+ blood_in_stool
+ lack_of_growth
+ emotional_symptoms
+ elbow_weakness
+ back_weakness
+ symptoms_of_the_scrotum_and_testes
+ swelling_of_scrotum
+ pain_in_testicles
+ flatulence
+ pus_draining_from_ear
+ jaundice
+ mass_in_scrotum
+ white_discharge_from_eye
+ irritable_infant
+ abusing_alcohol
+ fainting
+ hostile_behavior
+ drug_abuse
+ sharp_abdominal_pain
+ feeling_ill
+ vomiting
+ headache
+ nausea
+ diarrhea
+ vaginal_itching
+ vaginal_dryness
+ painful_urination
+ involuntary_urination
+ pain_during_intercourse
+ frequent_urination
+ lower_abdominal_pain
+ vaginal_discharge
+ blood_in_urine
+ hot_flashes
+ intermenstrual_bleeding
+ hand_or_finger_pain
+ wrist_pain
+ hand_or_finger_swelling
+ arm_pain
+ wrist_swelling
+ arm_stiffness_or_tightness
+ arm_swelling
+ hand_or_finger_stiffness_or_tightness
+ wrist_stiffness_or_tightness
+ lip_swelling
+ toothache
+ abnormal_appearing_skin
+ skin_lesion
+ acne_or_pimples
+ dry_lips
+ facial_pain
+ mouth_ulcer
+ skin_growth
+ eye_deviation
+ diminished_vision
+ double_vision
+ cross-eyed
+ symptoms_of_eye
+ pain_in_eye
+ eye_moves_abnormally
+ abnormal_movement_of_eyelid
+ foreign_body_sensation_in_eye
+ irregular_appearing_scalp
+ swollen_lymph_nodes
+ back_pain
+ neck_pain
+ low_back_pain
+ pain_of_the_anus
+ pain_during_pregnancy
+ pelvic_pain
+ impotence
+ vomiting_blood
+ regurgitation
+ burning_abdominal_pain
+ restlessness
+ symptoms_of_infants
+ wheezing
+ peripheral_edema
+ neck_mass
+ ear_pain
+ jaw_swelling
+ mouth_dryness
+ neck_swelling
+ knee_pain
+ foot_or_toe_pain
+ ankle_pain
+ bones_are_painful
+ knee_weakness
+ elbow_pain
+ knee_swelling
+ skin_moles
+ knee_lump_or_mass
+ weight_gain
+ problems_with_movement
+ knee_stiffness_or_tightness
+ leg_swelling
+ foot_or_toe_swelling
+ heartburn
+ smoking_problems
+ muscle_pain
+ infant_feeding_problem
+ recent_weight_loss
+ difficulty_eating
+ vaginal_pain
+ vaginal_redness
+ vulvar_irritation
+ weakness
+ decreased_heart_rate
+ increased_heart_rate
+ bleeding_or_discharge_from_nipple
+ ringing_in_ear
+ plugged_feeling_in_ear
+ itchy_ear(s)
+ frontal_headache
+ fluid_in_ear
+ neck_stiffness_or_tightness
+ spots_or_clouds_in_vision
+ eye_redness
+ lacrimation
+ itchiness_of_eye
+ blindness
+ eye_burns_or_stings
+ itchy_eyelid
+ decreased_appetite
+ excessive_appetite
+ excessive_anger
+ loss_of_sensation
+ focal_weakness
+ slurring_words
+ symptoms_of_the_face
+ disturbance_of_memory
+ paresthesia
+ side_pain
+ fever
+ shoulder_pain
+ shoulder_stiffness_or_tightness
+ shoulder_weakness
+ shoulder_swelling
+ tongue_lesions
+ leg_cramps_or_spasms
+ ache_all_over
+ lower_body_pain
+ problems_during_pregnancy
+ spotting_or_bleeding_during_pregnancy
+ cramps_and_spasms
+ upper_abdominal_pain
+ stomach_bloating
+ changes_in_stool_appearance
+ unusual_color_or_odor_to_urine
+ kidney_mass
+ swollen_abdomen
+ symptoms_of_prostate
+ leg_stiffness_or_tightness
+ difficulty_breathing
+ rib_pain
+ joint_pain
+ muscle_stiffness_or_tightness
+ hand_or_finger_lump_or_mass
+ chills
+ groin_pain
+ fatigue
+ abdominal_distention
+ regurgitation.1
+ symptoms_of_the_kidneys
+ melena
+ coughing_up_sputum
+ seizures
+ delusions_or_hallucinations
+ pain_or_soreness_of_breast
+ excessive_urination_at_night
+ bleeding_from_eye
+ rectal_bleeding
+ constipation
+ temper_problems
+ coryza
+ hemoptysis
+ lymphedema
+ skin_on_leg_or_foot_looks_infected
+ allergic_reaction
+ congestion_in_chest
+ muscle_swelling
+ sleepiness
+ apnea
+ abnormal_breathing_sounds
+ blood_clots_during_menstrual_periods
+ absence_of_menstruation
+ pulling_at_ears
+ gum_pain
+ redness_in_ear
+ fluid_retention
+ flu-like_syndrome
+ sinus_congestion
+ painful_sinuses
+ fears_and_phobias
+ recent_pregnancy
+ uterine_contractions
+ burning_chest_pain
+ back_cramps_or_spasms
+ stiffness_all_over
+ muscle_cramps,_contractures,_or_spasms
+ back_mass_or_lump
+ nosebleed
+ long_menstrual_periods
+ heavy_menstrual_flow
+ unpredictable_menstruation
+ painful_menstruation
+ infertility
+ frequent_menstruation
+ sweating
+ mass_on_eyelid
+ swollen_eye
+ eyelid_swelling
+ eyelid_lesion_or_rash
+ symptoms_of_bladder
+ irregular_appearing_nails
+ itching_of_skin
+ hurts_to_breath
+ skin_dryness,_peeling,_scaliness,_or_roughness
+ skin_on_arm_or_hand_looks_infected
+ skin_irritation
+ itchy_scalp
+ warts
+ bumps_on_penis
+ too_little_hair
+ skin_rash
+ mass_or_swelling_around_the_anus
+ ankle_swelling
+ dry_or_flaky_scalp
+ foot_or_toe_stiffness_or_tightness
+ elbow_swelling
+ early_or_late_onset_of_menopause
+ bleeding_from_ear
+ hand_or_finger_weakness
+ low_self-esteem
+ itching_of_the_anus
+ swollen_or_red_tonsils
+ irregular_belly_button
+ hip_stiffness_or_tightness
+ mouth_pain
+ arm_weakness
+ penis_pain
+ loss_of_sex_drive
+ obsessions_and_compulsions
+ antisocial_behavior
+ neck_cramps_or_spasms
+ sneezing
+ leg_weakness
+ penis_redness
+ penile_discharge
+ shoulder_lump_or_mass
+ cloudy_eye
+ hysterical_behavior
+ arm_lump_or_mass
+ nightmares
+ bleeding_gums
+ pain_in_gums
+ bedwetting
+ diaper_rash
+ lump_or_mass_of_breast
+ postpartum_problems_of_the_breast
+ hesitancy
+ throat_redness
+ joint_swelling
+ redness_in_or_around_nose
+ wrinkles_on_skin
+ back_stiffness_or_tightness
+ wrist_lump_or_mass
+ low_urine_output
+ sore_in_nose
symptom_model.labels.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4a75f3ff2598b2868f33d82be336e32dee0793e863e426887ceb7aa62f36d813
+ size 15480
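The `.labels.npy` artifact tracked above via Git LFS holds the label encoder's `classes_` array as saved by `save_artifacts`. The save/load round-trip reduces to a plain NumPy string-array dump; a minimal sketch (the disease names and file path are illustrative):

```python
import os
import tempfile
import numpy as np

# Stand-in for LabelEncoder.classes_ after fitting on disease names.
classes = np.array(["cold", "flu", "migraine"], dtype=object)

path = os.path.join(tempfile.mkdtemp(), "labels.npy")
# allow_pickle is required for object (string) arrays, matching the
# allow_pickle=True used in save_artifacts/load_artifacts.
np.save(path, classes, allow_pickle=True)

restored = np.load(path, allow_pickle=True)
print(restored.tolist())  # ['cold', 'flu', 'migraine']
```

`load_artifacts` then assigns the restored array to a fresh `LabelEncoder`'s `classes_`, so `inverse_transform` works without refitting.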