emiraran committed on
Commit
feed4c5
·
verified ·
1 Parent(s): 530614b

Upload 7 files

Files changed (7)
  1. README.md +411 -14
  2. app.py +347 -0
  3. best_model_final.h5 +3 -0
  4. gradcam_utils.py +187 -0
  5. label_encoder.pkl +3 -0
  6. optimal_thresholds.pkl +3 -0
  7. requirements.txt +10 -0
README.md CHANGED
@@ -1,14 +1,411 @@
1
- ---
2
- title: Chest Xray Classification
3
- emoji: 😻
4
- colorFrom: gray
5
- colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 6.0.2
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: Multi-label classification of 15 thoracic diseases from ches
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # Multi-Label Chest X-Ray Disease Classification
2
+
3
+ **Deep learning system for automated detection of 15 thoracic diseases from chest X-ray images using EfficientNetB0 with advanced training techniques.**
4
+
5
+ [![Python](https://img.shields.io/badge/Python-3.8+-blue.svg)](https://www.python.org/)
6
+ [![TensorFlow](https://img.shields.io/badge/TensorFlow-2.10-orange.svg)](https://www.tensorflow.org/)
7
+ [![License](https://img.shields.io/badge/License-MIT-green.svg)](LICENSE)
8
+
9
+ ---
10
+
11
+ ## 📊 Performance
12
+
13
+ | Metric | Value | Benchmark (Wang et al. 2017) |
14
+ |--------|-------|------------------------------|
15
+ | **Mean AUC** | **0.784** | 0.740 |
16
+ | **Improvement** | **+5.9%** | Baseline |
17
+ | **Top Disease (Edema)** | **0.884 AUC** | - |
18
+ | **Recall (Medical Priority)** | **80.3%** | - |
19
+
20
+ **Real Talk:** This isn't radiologist-level (CheXNet: 0.841 AUC), but it beats the original ChestX-ray14 paper. For a 3rd-year undergrad project, this is solid work. The dataset has 10-20% label noise (NLP-extracted, not radiologist-verified), which caps performance.
21
+
22
+ ---
23
+
24
+ ## 🎯 Dataset
25
+
26
+ **ChestX-ray14 (NIH Clinical Center)**
27
+ - 112,120 frontal-view chest X-ray images
28
+ - 30,805 unique patients
29
+ - 15 disease classes (multi-label)
30
+ - **Download:** [NIH Box](https://nihcc.app.box.com/v/ChestXray-NIHCC)
31
+
32
+ **Diseases:** Atelectasis, Cardiomegaly, Consolidation, Edema, Effusion, Emphysema, Fibrosis, Hernia, Infiltration, Mass, Nodule, Pleural Thickening, Pneumonia, Pneumothorax, No Finding
33
+
34
+ **⚠️ Dataset Issues (Be Aware):**
35
+ - Labels extracted via NLP from radiology reports → 10-20% noise
36
+ - Extreme class imbalance (Hernia: 110 samples vs No Finding: 60K)
37
+ - Multi-label complexity (avg 1.5 diseases per image)
38
+
39
+ ---
40
+
41
+ ## 🏗️ Architecture
42
+
43
+ ```
44
+ Input (224x224x3)
45
+
46
+ EfficientNetB0 (ImageNet pretrained)
47
+ ├── All 237 layers trainable (full fine-tuning)
48
+ └── Mixed Precision (FP16) for speed
49
+
50
+ Global Average Pooling
51
+
52
+ Dense(512, ReLU) → Dropout(0.3)
53
+
54
+ Dense(256, ReLU) → Dropout(0.2)
55
+
56
+ Dense(15, Sigmoid) [Multi-label output]
57
+ ```
58
+
59
+ **Why This Works:**
60
+ - **EfficientNetB0:** SOTA efficiency (5.3M params, 0.39B FLOPs)
61
+ - **Full fine-tuning:** Medical imaging ≠ ImageNet → adapt all layers
62
+ - **Mixed precision:** 30-40% speedup, no accuracy loss
63
+
64
+ 📖 **[See detailed architecture diagrams and training pipeline →](ARCHITECTURE.md)**
65
+
66
+ ---
67
+
68
+ ## 🔧 Training Strategy
69
+
70
+ ### **1. Focal Loss (Lin et al. 2020)**
71
+ ```python
72
+ focal_loss = BinaryFocalCrossentropy(alpha=0.25, gamma=2.0)
73
+ ```
74
+ **Why:** Handles extreme class imbalance better than BCE. Focuses on hard-to-classify samples (rare diseases).
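
To make the "focuses on hard-to-classify samples" point concrete, here is a minimal plain-Python sketch of the standard per-sample binary focal loss, FL = -alpha_t * (1 - p_t)^gamma * log(p_t). It is an illustration of the formula only, not the Keras implementation used in training, and `binary_focal_loss` is a hypothetical helper name.

```python
import math

def binary_focal_loss(y_true, p, alpha=0.25, gamma=2.0, eps=1e-7):
    """Per-sample binary focal loss; illustrative helper, not the Keras code."""
    p = min(max(p, eps), 1.0 - eps)            # avoid log(0)
    p_t = p if y_true == 1 else 1.0 - p        # probability given to the true class
    alpha_t = alpha if y_true == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

easy = binary_focal_loss(1, 0.95)   # confident, correct prediction
hard = binary_focal_loss(1, 0.10)   # positive case the model nearly missed
print(f"easy={easy:.6f} hard={hard:.6f}")
```

The `(1 - p_t) ** gamma` factor shrinks the loss of well-classified samples by orders of magnitude, so the gradient is dominated by the hard cases.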
75
+
76
+ ### **2. Balanced Oversampling**
77
+ - Rare diseases (Hernia: 110 → 2000 samples) oversampled
78
+ - Prevents model from ignoring minority classes
79
+ - **Trade-off:** Increased training time (+4%), but +12% AUC on rare diseases
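
One simple way to get this kind of oversampling is to repeat rows of a rare class at random until a target count is reached. The sketch below assumes sampling with replacement; the notebook's actual procedure is not shown here, and `oversample_indices` is a hypothetical helper. The 110 and 2000 figures mirror the Hernia example above.

```python
import random

def oversample_indices(indices, target, seed=0):
    """Return `indices` padded with random repeats until it has `target` entries."""
    if not indices or len(indices) >= target:
        return list(indices)
    rng = random.Random(seed)
    extra = [rng.choice(indices) for _ in range(target - len(indices))]
    return list(indices) + extra

hernia_rows = list(range(110))   # stand-in for the 110 Hernia image rows
balanced = oversample_indices(hernia_rows, target=2000)
print(len(balanced), len(set(balanced)))
```

Every original row is kept once and the remainder are duplicates, which is why oversampling adds training time without adding new information.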
80
+
81
+ ### **3. Class Weights**
82
+ - Soft weighting (50% reduction factor) to avoid overfitting rare classes
83
+ - Complements Focal Loss for balanced learning
84
+
85
+ ### **4. Medical-Appropriate Augmentation**
86
+ ```
87
+ - Horizontal flip (anatomically valid)
88
+ - Brightness ±10% (X-ray exposure variation)
89
+ - Contrast ±10% (detector sensitivity)
90
+ - Random zoom 0.9-1.0 (positioning variation)
91
+ ```
92
+ **No rotation:** Chest X-rays have fixed orientation (heart on left).
93
+
94
+ ### **5. Test-Time Augmentation (TTA)**
95
+ - 6 predictions per image (1 original + 5 augmented)
96
+ - Average predictions → +0.6% AUC boost
97
+ - **Cost:** 6x inference time (use for critical cases only)
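
The averaging step can be sketched independently of the model. The example below assumes a generic `predict_fn` that returns per-class probabilities and uses a single deterministic horizontal flip; the function names and the toy model are hypothetical, and the real pipeline uses its own set of augmentations.

```python
def hflip(image):
    """Horizontally flip an image stored as a list of pixel rows."""
    return [row[::-1] for row in image]

def tta_predict(predict_fn, image, augmentations):
    """Average per-class probabilities over the original image and each augmented view."""
    views = [image] + [aug(image) for aug in augmentations]
    all_probs = [predict_fn(v) for v in views]
    n_classes = len(all_probs[0])
    return [sum(p[c] for p in all_probs) / len(all_probs) for c in range(n_classes)]

# Toy stand-in for model.predict: one "class" scored by the top-left pixel.
toy_model = lambda img: [img[0][0]]
image = [[0.2, 0.8],
         [0.5, 0.5]]
averaged = tta_predict(toy_model, image, [hflip])
print(averaged)
```

Each extra view costs one more forward pass, which is where the linear growth in inference time comes from.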
98
+
99
+ ### **6. Threshold Optimization**
100
+ - Default 0.5 → Optimized 0.2-0.45 per disease
101
+ - Target: 80% recall (medical priority)
102
+ - **Result:** False positives increase, but missing diseases is worse
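
A per-disease threshold for a recall target can be chosen from validation scores by taking the highest threshold that still captures the required share of true positives. This is a sketch under that assumption; the notebook's exact search is not shown here, and `threshold_for_recall` is a hypothetical helper.

```python
import math

def threshold_for_recall(y_true, scores, target_recall=0.80):
    """Highest threshold t such that flagging score >= t reaches the target recall."""
    positives = sorted((s for y, s in zip(y_true, scores) if y == 1), reverse=True)
    if not positives:
        return 0.5  # nothing to calibrate on; keep the default threshold
    # Number of positives that must score at or above the threshold.
    needed = max(1, math.ceil(round(target_recall * len(positives), 9)))
    return positives[needed - 1]

y_true = [1, 1, 1, 1, 1, 0, 0, 0]
scores = [0.9, 0.7, 0.45, 0.3, 0.1, 0.6, 0.2, 0.05]
t = threshold_for_recall(y_true, scores)
recall = sum(s >= t for y, s in zip(y_true, scores) if y == 1) / sum(y_true)
print(t, recall)
```

Lowering the threshold this way also lets more negatives through, which is the precision cost discussed in the limitations.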
103
+
104
+ ---
105
+
106
+ ## 📈 Results Breakdown
107
+
108
+ ### **Top Performing Diseases:**
109
+ | Disease | AUC | Recall | Precision | Why Good? |
110
+ |---------|-----|--------|-----------|-----------|
111
+ | Edema | 0.884 | 80% | 43% | Clear radiological features |
112
+ | Cardiomegaly | 0.865 | 80% | 39% | Large, distinct heart silhouette |
113
+ | Effusion | 0.852 | 82% | 46% | High prevalence (2.5K samples) |
114
+
115
+ ### **Worst Performing Diseases:**
116
+ | Disease | AUC | Recall | Precision | Why Bad? |
117
+ |---------|-----|--------|-----------|----------|
118
+ | Hernia | 0.612 | 75% | 18% | Only 110 samples (extreme rarity) |
119
+ | Pneumonia | 0.698 | 79% | 22% | Overlaps with Infiltration (label noise) |
120
+ | Nodule | 0.704 | 78% | 28% | Small, subtle features |
121
+
122
+ ### **Honest Assessment:**
123
+ - **AUC 0.78** is good for noisy labels, but not clinic-ready
124
+ - **80% recall** is appropriate for screening (catch diseases early)
125
+ - **40% precision** means high false positives (radiologist review needed)
126
+ - This is a **screening tool**, not a diagnostic system
127
+
128
+ ---
129
+
130
+ ## ⚠️ Limitations (Critical)
131
+
132
+ ### **1. False Positive Rate (The Elephant in the Room)**
133
+ - **Precision: 40-45%** → 55-60% of flagged findings are false positives
134
+ - **Why:** Low thresholds (0.2-0.4) to maximize recall
135
+ - **Clinical impact:** Radiologist must review all positives (intended use)
136
+
137
+ ### **2. Dataset Label Noise**
138
+ - ChestX-ray14 uses NLP extraction (not radiologist-verified)
139
+ - Estimated 10-20% mislabeling rate
140
+ - Some "diseases" are actually descriptions (e.g., "No Finding")
141
+
142
+ ### **3. Class Imbalance Persists**
143
+ - Even with oversampling, rare diseases underperform
144
+ - Hernia (110 samples) vs No Finding (60K) → 500x difference
145
+ - Model biased toward common diseases
146
+
147
+ ### **4. No External Validation**
148
+ - Trained and tested on same hospital (NIH Clinical Center)
149
+ - Performance will drop on external datasets (domain shift)
150
+ - Real-world deployment requires multi-site validation
151
+
152
+ ### **5. Not Radiologist-Level**
153
+ - CheXNet (2017): 0.841 AUC with DenseNet-121
154
+ - This model: 0.784 AUC with EfficientNetB0
155
+ - **Gap:** 0.057 AUC (5.7 points, absolute) → Needs more data, better labels, or ensemble
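
The two headline comparisons use different conventions: the improvement over Wang et al. is relative to the baseline AUC, while the gap to CheXNet is an absolute difference in AUC. A quick check using only the figures quoted in this README:

```python
baseline_auc = 0.740   # Wang et al. 2017
this_auc = 0.784       # this model
chexnet_auc = 0.841    # CheXNet

relative_gain = (this_auc - baseline_auc) / baseline_auc
absolute_gap = chexnet_auc - this_auc
print(f"relative gain over baseline: {relative_gain:.1%}")
print(f"absolute gap to CheXNet: {absolute_gap:.3f} AUC")
```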
156
+
157
+ ---
158
+
159
+ ## 🚀 Live Demo
160
+
161
+ **Try it online:** [🤗 Hugging Face Space](https://huggingface.co/spaces/emiraran/chest-xray-classification)
162
+
163
+ Upload a chest X-ray and get instant predictions! No setup required.
164
+
165
+ ---
166
+
167
+ ## 💻 Local Usage
168
+
169
+ ### **Installation**
170
+ ```bash
171
+ pip install -r requirements.txt
172
+ ```
173
+
174
+ ### **Quick Inference (No Grad-CAM)**
175
+ ```bash
176
+ python demo.py images/00000001_000.png
177
+ ```
178
+
179
+ ### **Full Inference (With Grad-CAM)**
180
+ ```bash
181
+ python demo_with_gradcam.py images/00000001_000.png
182
+ # Output: Disease predictions + gradcam_*.png heatmaps
183
+ ```
184
+
185
+ ### **Programmatic Usage**
186
+ ```python
+ from demo import ChestXRayPredictor
+
+ # Initialize predictor
+ predictor = ChestXRayPredictor(
+     model_path='best_model_final.h5',
+     thresholds_path='optimal_thresholds.pkl',
+     label_encoder_path='label_encoder.pkl'
+ )
+
+ # Get predictions. The result is assumed to be a list of dicts with
+ # 'disease', 'probability' (formatted string, e.g. "84.2%") and
+ # 'confidence' ('HIGH' or 'MEDIUM') keys, sorted by probability.
+ predictions = predictor.predict('sample_xray.png', use_tta=False)
+ for p in predictions:
+     print(f"{p['disease']:<20} {p['probability']:>6} [{p['confidence']}]")
214
+ ```
215
+
216
+ ---
217
+
218
+ ## 📁 Project Structure
219
+
220
+ ```
221
+ chest-xray-classification/
222
+ ├── chest_xray_analysis.ipynb # Main notebook (training + evaluation)
223
+ ├── README.md # This file
224
+ ├── ARCHITECTURE.md # Detailed architecture diagrams & pipeline
225
+ ├── .gitignore # Ignore large files
226
+ ├── requirements.txt # Python dependencies
227
+ ├── demo.py # Local inference script
228
+ ├── demo_with_gradcam.py # Local demo with Grad-CAM visualization
229
+ ├── gradcam_utils.py # Grad-CAM implementation
230
+ ├── app.py # Gradio web interface for HF Spaces
231
+ ├── best_model_final.h5 # Model weights (NOT in repo - download separately)
232
+ ├── optimal_thresholds.pkl # Disease-specific thresholds (NOT in repo)
233
+ ├── label_encoder.pkl # Disease name mapping (NOT in repo)
234
+ └── images/ # Dataset (NOT in repo - download from NIH)
235
+ ```
236
+
237
+ **Note:** Model files excluded due to size. Train the model using the notebook to generate weights.
238
+
239
+ ---
240
+
241
+ ## 🔬 Technical Details
242
+
243
+ ### **Training Configuration**
244
+ ```yaml
245
+ Epochs: 50 (early stopping at epoch 46)
246
+ Batch Size: 64
247
+ Learning Rate: 1e-5 (reduced to 3.1e-7 via ReduceLROnPlateau)
248
+ Optimizer: Adam
249
+ Loss: Binary Focal Crossentropy (α=0.25, γ=2.0)
250
+ Mixed Precision: FP16
251
+ Training Time: ~3 hours (NVIDIA RTX GPU)
252
+ ```
253
+
254
+ ### **Data Split**
255
+ - **Patient-level split** (not image-level) to prevent data leakage
256
+ - Train: 89,826 images (24,644 patients)
257
+ - Test: 22,294 images (6,161 patients)
258
+ - **Why patient-level?** Same patient may have multiple X-rays → prevent memorization
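
A patient-level split can be sketched by shuffling patient IDs rather than image IDs, then assigning every image to the side its patient landed on. The helper name, the seed, and the toy file names below are hypothetical; the 20% test fraction matches the 6,161 of 30,805 patients reported above.

```python
import random

def patient_level_split(image_to_patient, test_fraction=0.2, seed=42):
    """Split image IDs so that all images of a patient land on the same side."""
    patients = sorted(set(image_to_patient.values()))
    random.Random(seed).shuffle(patients)
    n_test = max(1, round(len(patients) * test_fraction))
    test_patients = set(patients[:n_test])
    train = [img for img, pid in image_to_patient.items() if pid not in test_patients]
    test = [img for img, pid in image_to_patient.items() if pid in test_patients]
    return train, test

# Toy mapping: 10 patients with 3 images each.
images = {f"{p:08d}_{k:03d}.png": p for p in range(1, 11) for k in range(3)}
train, test = patient_level_split(images)
train_patients = {images[i] for i in train}
test_patients = {images[i] for i in test}
print(len(train), len(test), train_patients & test_patients)
```

The empty intersection at the end is the property that matters: no patient contributes images to both sets.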
259
+
260
+ ### **Callbacks**
261
+ - **ModelCheckpoint:** Save best val_auc model
262
+ - **ReduceLROnPlateau:** Halve LR if val_loss plateaus (patience=5)
263
+ - **EarlyStopping:** Stop if val_auc plateaus (patience=10)
264
+
265
+ ---
266
+
267
+ ## 🎨 Grad-CAM Visualization
268
+
269
+ **NEW!** See where the model looks when making predictions:
270
+
271
+ ```bash
272
+ # Generate Grad-CAM heatmaps for top 3 predictions
273
+ python demo_with_gradcam.py images/00000001_000.png
274
+
275
+ # Output: gradcam_edema.png, gradcam_cardiomegaly.png, gradcam_effusion.png
276
+ ```
277
+
278
+ **What is Grad-CAM?**
279
+ - Gradient-weighted Class Activation Mapping
280
+ - Shows important regions for each disease prediction
281
+ - Red = model focuses here, Blue = model ignores
282
+ - **Use case:** Validate model isn't using spurious correlations (e.g., text artifacts)
283
+
284
+ **Reference:** Selvaraju et al. (2017) - [Grad-CAM: Visual Explanations from Deep Networks](https://arxiv.org/abs/1610.02391)
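
The core Grad-CAM computation is small enough to show without a deep learning framework. The sketch below takes activations and gradients as nested lists of shape H x W x C, and is a didactic version rather than the code in `gradcam_utils.py`; the guard against an all-zero map is an addition for this example.

```python
def gradcam_heatmap(features, grads):
    """features, grads: H x W x C nested lists (activations and d(score)/d(activation))."""
    h, w, c = len(features), len(features[0]), len(features[0][0])
    # Channel weights: gradient averaged over all spatial positions.
    weights = [sum(grads[i][j][k] for i in range(h) for j in range(w)) / (h * w)
               for k in range(c)]
    # Weighted sum of channels, then ReLU to keep only positive evidence.
    heat = [[max(0.0, sum(weights[k] * features[i][j][k] for k in range(c)))
             for j in range(w)] for i in range(h)]
    peak = max(max(row) for row in heat)
    if peak == 0:
        return heat  # nothing positive to normalize
    return [[v / peak for v in row] for row in heat]

features = [[[1.0, 0.0], [0.0, 2.0]],
            [[0.5, 0.5], [0.0, 0.0]]]
grads = [[[1.0, -1.0], [1.0, -1.0]],
         [[1.0, -1.0], [1.0, -1.0]]]
heatmap = gradcam_heatmap(features, grads)
print(heatmap)
```

Only the top-left cell survives here because it is the only position where the positively weighted channel outweighs the negatively weighted one.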
285
+
286
+ ---
287
+
288
+ ## 📚 References
289
+
290
+ 1. **Wang et al. (2017)** - ChestX-ray8: Hospital-scale Chest X-ray Database and Benchmarks
291
+ [Paper](https://arxiv.org/abs/1705.02315) | [Dataset](https://nihcc.app.box.com/v/ChestXray-NIHCC)
292
+
293
+ 2. **Rajpurkar et al. (2017)** - CheXNet: Radiologist-Level Pneumonia Detection on Chest X-Rays
294
+ [Paper](https://arxiv.org/abs/1711.05225)
295
+
296
+ 3. **Tan & Le (2019)** - EfficientNet: Rethinking Model Scaling for CNNs
297
+ [Paper](https://arxiv.org/abs/1905.11946)
298
+
299
+ 4. **Selvaraju et al. (2017)** - Grad-CAM: Visual Explanations from Deep Networks
300
+ [Paper](https://arxiv.org/abs/1610.02391)
301
+
302
+ 5. **Lin et al. (2020)** - Focal Loss for Dense Object Detection
303
+ [Paper](https://arxiv.org/abs/1708.02002)
304
+
305
+ ---
306
+
307
+ ## 🎓 For Recruiters / Academic Review
308
+
309
+ ### **What's Good:**
310
+ ✅ Beats published benchmark (+5.9% AUC)
311
+ ✅ SOTA techniques (Focal Loss, TTA, Mixed Precision, Full Fine-Tuning)
312
+ ✅ Medical-aware design (recall priority, patient-level split)
313
+ ✅ Comprehensive evaluation (ROC, PR curves, confusion matrices)
314
+ ✅ Honest limitation discussion (no BS marketing)
315
+
316
+ ### **What's Missing (Acknowledgment):**
317
+ ❌ External validation (single hospital data)
318
+ ❌ Radiologist comparison (no ground truth verification)
319
+ ❌ Grad-CAM visualization (explainability)
320
+ ❌ Ensemble methods (single model only)
321
+ ❌ Production deployment (no API, no containerization)
322
+
323
+ ### **Suitable For:**
324
+ - 🎓 Undergraduate/Graduate ML coursework
325
+ - 📝 Academic paper (with external validation)
326
+ - 💼 Portfolio project for ML engineer roles
327
+ - 🏥 Research prototype (NOT clinical deployment)
328
+
329
+ ### **NOT Suitable For:**
330
+ - ❌ Clinical decision-making (FDA/CE approval required)
331
+ - ❌ Standalone diagnosis (must be radiologist-assisted)
332
+ - ❌ Real-time emergency screening (inference time ~200ms per image)
333
+
334
+ ---
335
+
336
+ ## 🤝 Contributing
337
+
338
+ This is an academic project. If you find issues or have improvements:
339
+ 1. Fork the repo
340
+ 2. Create feature branch (`git checkout -b feature/improvement`)
341
+ 3. Commit changes (`git commit -m 'Add improvement'`)
342
+ 4. Push to branch (`git push origin feature/improvement`)
343
+ 5. Open Pull Request
344
+
345
+ ---
346
+
347
+ ## 📄 License
348
+
349
+ MIT License - See [LICENSE](LICENSE) file for details.
350
+
351
+ **Dataset License:** NIH ChestX-ray14 dataset is public domain (U.S. Government work). Please cite the original paper if you use this work.
352
+
353
+ ---
354
+
355
+ ## 🙏 Acknowledgments
356
+
357
+ - NIH Clinical Center for ChestX-ray14 dataset
358
+ - Original paper authors (Wang et al., 2017)
359
+ - TensorFlow team for EfficientNet implementation
360
+ - Medical imaging community for open research
361
+
362
+ ---
363
+
364
+ ## 📧 Contact
365
+
366
+ **Author:** Emir Muhammet Aran
367
+ **Institution:** Computer Engineering Student
368
+ **GitHub:** [github.com/emirmuhammmetaran](https://github.com/emirmuhammmetaran)
369
+
370
+ ---
371
+
372
+ ## ⚡ Quick Start
373
+
374
+ ```bash
375
+ # 1. Clone repo
376
+ git clone https://github.com/emirmuhammmetaran/chest-xray-classification.git
377
+ cd chest-xray-classification
378
+
379
+ # 2. Install dependencies
380
+ pip install -r requirements.txt
381
+
382
+ # 3. Download dataset from NIH
383
+ # https://nihcc.app.box.com/v/ChestXray-NIHCC
384
+
385
+ # 4. Run notebook
386
+ jupyter notebook chest_xray_analysis.ipynb
387
+
388
+ # 5. Train model (or use pre-trained weights)
389
+ # Training takes ~3 hours on GPU
390
+ ```
391
+
392
+ ---
393
+
394
+ **Last Updated:** December 2025
395
+ **Status:** ✅ Training complete | 📊 AUC 0.784 | 🎓 Academic project
396
+
397
+ ---
398
+
399
+ ## 🔥 Honest Takeaway
400
+
401
+ **This model works, but it's not magic.**
402
+
403
+ - It beats the 2017 baseline → Good engineering
404
+ - It has 60% false positives → Needs radiologist review
405
+ - It costs $0.50/1000 images (GPU inference) → Economical screening
406
+ - It's NOT FDA-approved → Research only
407
+
408
+ **Use case:** Pre-screen X-rays → flag suspicious cases → radiologist reviews positives.
409
+ **Don't use for:** Standalone diagnosis, emergency triage, legal liability scenarios.
410
+
411
+ **Bottom line:** Solid ML engineering with realistic expectations. That's how you build trust in AI.
app.py ADDED
@@ -0,0 +1,347 @@
1
+ """
2
+ Chest X-Ray Disease Classification - Hugging Face Demo
3
+ =======================================================
4
+
5
+ Multi-label classification of 15 thoracic diseases from chest X-rays.
6
+
7
+ Author: Emir Muhammet Aran
8
+ Model: EfficientNetB0 (AUC 0.784)
9
+ Dataset: NIH ChestX-ray14
10
+ """
11
+
12
+ import gradio as gr
13
+ import tensorflow as tf
14
+ import numpy as np
15
+ import pickle
16
+ from PIL import Image
17
+ import warnings
18
+ warnings.filterwarnings('ignore')
19
+ from gradcam_utils import generate_gradcam_for_top_predictions, get_last_conv_layer_name
20
+
21
+
22
+ # ============================================================================
23
+ # MODEL LOADING
24
+ # ============================================================================
25
+
+ def build_model(num_classes=15):
+     """Rebuild EfficientNetB0 architecture"""
+     from tensorflow.keras import layers
+     from tensorflow.keras.applications import EfficientNetB0
+
+     IMG_SIZE = 224
+
+     inputs = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
+
+     base_model = EfficientNetB0(
+         include_top=False,
+         weights=None,
+         input_tensor=inputs,
+         pooling='avg'
+     )
+
+     x = base_model.output
+     x = layers.Dense(512, activation='relu')(x)
+     x = layers.Dropout(0.3)(x)
+     x = layers.Dense(256, activation='relu')(x)
+     x = layers.Dropout(0.2)(x)
+     outputs = layers.Dense(num_classes, activation='sigmoid', dtype='float32')(x)
+
+     model = tf.keras.Model(inputs=inputs, outputs=outputs)
+     return model
51
+
52
+
53
+ # Load model components
54
+ print("Loading model...")
55
+ model = build_model(num_classes=15)
56
+ model.load_weights('best_model_final.h5')
57
+
58
+ with open('optimal_thresholds.pkl', 'rb') as f:
59
+ optimal_thresholds = pickle.load(f)
60
+
61
+ with open('label_encoder.pkl', 'rb') as f:
62
+ label_encoder = pickle.load(f)
63
+
64
+ print("✅ Model loaded successfully!")
65
+
66
+
67
+ # ============================================================================
68
+ # PREDICTION FUNCTION
69
+ # ============================================================================
70
+
71
+ def predict_xray(image, use_tta=False):
72
+ """
73
+ Predict diseases from chest X-ray image.
74
+
75
+ Args:
76
+ image: PIL Image or numpy array
77
+ use_tta: Use Test-Time Augmentation (slower but more accurate)
78
+
79
+ Returns:
80
+ HTML formatted results
81
+ """
82
+ try:
83
+ # Preprocess image
84
+ if isinstance(image, np.ndarray):
85
+ image = Image.fromarray(image)
86
+
87
+ # Resize and normalize
88
+ image = image.convert('RGB')
89
+ image = image.resize((224, 224))
90
+ img_array = np.array(image) / 255.0
91
+ img_array = np.expand_dims(img_array, axis=0).astype(np.float32)
92
+
93
+ # Predict
94
+ if use_tta:
95
+ # Test-Time Augmentation (5 predictions)
96
+ predictions = []
97
+ predictions.append(model.predict(img_array, verbose=0)[0])
98
+
99
+ for _ in range(4):
100
+ # Random horizontal flip
101
+ aug_img = tf.image.random_flip_left_right(img_array)
102
+ aug_img = tf.image.random_brightness(aug_img, max_delta=0.1)
103
+ aug_img = tf.clip_by_value(aug_img, 0.0, 1.0)
104
+ predictions.append(model.predict(aug_img.numpy(), verbose=0)[0])
105
+
106
+ probs = np.mean(predictions, axis=0)
107
+ else:
108
+ probs = model.predict(img_array, verbose=0)[0]
109
+
110
+ # Apply thresholds and format results
111
+ results = []
112
+ for disease, idx in label_encoder.items():
113
+ prob = float(probs[idx])
114
+ threshold = optimal_thresholds[disease]
115
+
116
+ if prob >= threshold:
117
+ confidence_score = min((prob - threshold) / (1 - threshold), 1.0)
118
+ confidence = 'HIGH' if confidence_score > 0.5 else 'MEDIUM'
119
+
120
+ results.append({
121
+ 'disease': disease,
122
+ 'probability': prob,
123
+ 'confidence': confidence
124
+ })
125
+
126
+ # Sort by probability
127
+ results = sorted(results, key=lambda x: x['probability'], reverse=True)
128
+
129
+ # Generate Grad-CAM for top 3 predictions if enabled
130
+ gradcam_images = None
131
+ if use_tta and results: # Use TTA checkbox to toggle Grad-CAM
132
+ try:
133
+ last_conv_layer = get_last_conv_layer_name(model)
134
+ gradcam_images = generate_gradcam_for_top_predictions(
135
+ image, model, results, label_encoder, top_k=min(3, len(results)),
136
+ last_conv_layer_name=last_conv_layer
137
+ )
138
+ except Exception as e:
139
+ print(f"Grad-CAM generation failed: {e}")
140
+ gradcam_images = None
141
+
142
+ # Format output
143
+ if not results:
144
+ html_output = """
145
+ <div style="padding: 20px; background: #d4edda; border: 2px solid #28a745; border-radius: 10px;">
146
+ <h2 style="color: #155724; margin-top: 0;">✅ NO ABNORMALITIES DETECTED</h2>
147
+ <p style="color: #155724;">All disease probabilities are below the optimized thresholds.</p>
148
+ <p style="color: #666; font-size: 0.9em; margin-bottom: 0;">
149
+ <strong>Note:</strong> This model prioritizes recall (80%), so low-probability findings are filtered out.
150
+ </p>
151
+ </div>
152
+ """
153
+ else:
154
+ html_output = f"""
155
+ <div style="padding: 20px; background: #fff3cd; border: 2px solid #ffc107; border-radius: 10px;">
156
+ <h2 style="color: #856404; margin-top: 0;">⚠️ {len(results)} POTENTIAL FINDING(S) DETECTED</h2>
157
+ <div style="margin: 15px 0;">
158
+ """
159
+
160
+ for i, r in enumerate(results, 1):
161
+ prob_pct = f"{r['probability'] * 100:.1f}%"
162
+ conf_color = '#28a745' if r['confidence'] == 'HIGH' else '#ffc107'
163
+
164
+ html_output += f"""
165
+ <div style="padding: 12px; margin: 8px 0; background: white; border-left: 4px solid {conf_color}; border-radius: 5px;">
166
+ <div style="display: flex; justify-content: space-between; align-items: center;">
167
+ <span style="font-weight: bold; font-size: 1.1em;">{i}. {r['disease']}</span>
168
+ <span style="background: {conf_color}; color: white; padding: 4px 12px; border-radius: 12px; font-size: 0.85em;">
169
+ {r['confidence']}
170
+ </span>
171
+ </div>
172
+ <div style="margin-top: 8px;">
173
+ <span style="color: #666;">Probability: </span>
174
+ <span style="font-weight: bold; color: #333;">{prob_pct}</span>
175
+ </div>
176
+ </div>
177
+ """
178
+
179
+ html_output += """
180
+ </div>
181
+ </div>
182
+ """
183
+
184
+ # Add disclaimer
185
+ html_output += """
186
+ <div style="margin-top: 20px; padding: 15px; background: #f8d7da; border: 2px solid #f5c6cb; border-radius: 10px;">
187
+ <h3 style="color: #721c24; margin-top: 0; font-size: 1em;">⚠️ IMPORTANT DISCLAIMER</h3>
188
+ <p style="color: #721c24; margin: 8px 0; font-size: 0.9em;">
189
+ <strong>This is a research prototype. NOT for clinical diagnosis.</strong>
190
+ </p>
191
+ <ul style="color: #721c24; margin: 8px 0; font-size: 0.85em; padding-left: 20px;">
192
+ <li>Model achieves 0.784 AUC (80% recall, 40% precision)</li>
193
+ <li>High false positive rate by design (prioritizes catching diseases)</li>
194
+ <li>Dataset has 10-20% label noise (NLP-extracted labels)</li>
195
+ <li>Always consult a qualified radiologist for medical diagnosis</li>
196
+ </ul>
197
+ </div>
198
+ """
199
+
200
+ # Return both HTML and Grad-CAM images
201
+ if gradcam_images:
202
+ return html_output, gradcam_images[0][1], gradcam_images[1][1] if len(gradcam_images) > 1 else None, gradcam_images[2][1] if len(gradcam_images) > 2 else None
203
+ else:
204
+ return html_output, None, None, None
205
+
206
+ except Exception as e:
207
+ error_html = f"""
208
+ <div style="padding: 20px; background: #f8d7da; border: 2px solid #f5c6cb; border-radius: 10px;">
209
+ <h2 style="color: #721c24; margin-top: 0;">❌ ERROR</h2>
210
+ <p style="color: #721c24;">Failed to process image: {str(e)}</p>
211
+ <p style="color: #666; font-size: 0.9em;">
212
+ Please ensure the image is a valid chest X-ray (PNG/JPEG format).
213
+ </p>
214
+ </div>
215
+ """
216
+ return error_html, None, None, None
217
+
218
+
219
+ # ============================================================================
220
+ # GRADIO INTERFACE
221
+ # ============================================================================
222
+
223
+ # Custom CSS
224
+ custom_css = """
225
+ #component-0 {
226
+ max-width: 900px;
227
+ margin: auto;
228
+ }
229
+ .output-html {
230
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
231
+ }
232
+ """
233
+
234
+ # Example images (optional - add if you have sample X-rays)
235
+ examples = [
236
+ # ["examples/normal.png"],
237
+ # ["examples/pneumonia.png"],
238
+ ]
239
+
240
+ # Create Gradio interface
241
+ with gr.Blocks(css=custom_css, title="Chest X-Ray Disease Classifier") as demo:
242
+ gr.Markdown(
243
+ """
244
+ # 🏥 Chest X-Ray Disease Classification
245
+
246
+ **Multi-label detection of 15 thoracic diseases using EfficientNetB0**
247
+
248
+ Upload a frontal chest X-ray image to detect potential abnormalities.
249
+
250
+ **Performance:** Mean AUC 0.784 | 80% Recall | Trained on 112K X-rays (NIH ChestX-ray14)
251
+
252
+ ---
253
+ """
254
+ )
255
+
256
+ with gr.Row():
257
+ with gr.Column(scale=1):
258
+ image_input = gr.Image(
259
+ label="Upload Chest X-Ray",
260
+ type="pil",
261
+ height=400
262
+ )
263
+
264
+ tta_checkbox = gr.Checkbox(
265
+ label="Enable Grad-CAM Visualization",
266
+ value=False,
267
+ info="Show where the model looks (enables TTA for better accuracy)"
268
+ )
269
+
270
+ predict_btn = gr.Button(
271
+ "🔍 Analyze X-Ray",
272
+ variant="primary",
273
+ size="lg"
274
+ )
275
+
276
+ with gr.Column(scale=1):
277
+ output_html = gr.HTML(
278
+ label="Results",
279
+ elem_classes="output-html"
280
+ )
281
+
282
+ # Grad-CAM visualizations
283
+ with gr.Row(visible=True):
284
+ gradcam_1 = gr.Image(label="🔥 Grad-CAM #1 (Top Prediction)", type="pil")
285
+ gradcam_2 = gr.Image(label="🔥 Grad-CAM #2", type="pil")
286
+ gradcam_3 = gr.Image(label="🔥 Grad-CAM #3", type="pil")
287
+
288
+ # Examples section (if you have sample images)
289
+ if examples:
290
+ gr.Examples(
291
+ examples=examples,
292
+ inputs=image_input,
293
+ outputs=[output_html, gradcam_1, gradcam_2, gradcam_3],
294
+ fn=predict_xray,
295
+ cache_examples=False
296
+ )
297
+
298
+ gr.Markdown(
299
+ """
300
+ ---
301
+
302
+ ## 📊 About This Model
303
+
304
+ **Architecture:** EfficientNetB0 with full fine-tuning (237 layers)
305
+ **Training:** Focal Loss + Balanced Sampling + Mixed Precision (FP16)
306
+ **Dataset:** NIH ChestX-ray14 (112,120 images from 30,805 patients)
307
+
308
+ **Detected Diseases (15 classes):**
309
+ - Atelectasis, Cardiomegaly, Consolidation, Edema, Effusion
310
+ - Emphysema, Fibrosis, Hernia, Infiltration, Mass
311
+ - Nodule, Pleural Thickening, Pneumonia, Pneumothorax, No Finding
312
+
313
+ **Performance by Disease:**
314
+ - Best: Edema (0.884 AUC), Cardiomegaly (0.865 AUC), Effusion (0.852 AUC)
315
+ - Worst: Hernia (0.612 AUC - only 110 training samples)
316
+
317
+ **Limitations:**
318
+ - About 60% of flagged findings are false positives (precision ~40%), by design to maximize recall
319
+ - Dataset has label noise (NLP-extracted from reports)
320
+ - Single-site training (NIH) - may not generalize to other hospitals
321
+ - NOT FDA-approved or clinically validated
322
+
323
+ ---
324
+
325
+ ## 🔗 Links
326
+
327
+ - **Dataset:** [NIH ChestX-ray14 on Kaggle](https://www.kaggle.com/datasets/nih-chest-xrays/data)
328
+ - **Code:** [GitHub Repository](https://github.com/emirmuhammmetaran/chest-xray-classification)
329
+ - **Paper:** [Wang et al. 2017](https://arxiv.org/abs/1705.02315)
330
+
331
+ ---
332
+
333
+ **Built by:** Emir Muhammet Aran | **Institution:** Computer Engineering Student
334
+ **Last Updated:** December 2025
335
+ """
336
+ )
337
+
338
+ # Connect button to prediction function
339
+ predict_btn.click(
340
+ fn=predict_xray,
341
+ inputs=[image_input, tta_checkbox],
342
+ outputs=[output_html, gradcam_1, gradcam_2, gradcam_3]
343
+ )
344
+
345
+ # Launch app
346
+ if __name__ == "__main__":
347
+ demo.launch()
best_model_final.h5 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c786618c34a3bb1afa575f26d2f7b814d2ec5fd7a70d354308cd593f8c5ab913
3
+ size 19900048
gradcam_utils.py ADDED
@@ -0,0 +1,187 @@
1
+ """
2
+ Grad-CAM Implementation for Chest X-Ray Classification
3
+ ========================================================
4
+
5
+ Visualizes which regions of the X-ray the model focuses on when making predictions.
6
+
7
+ Reference: Selvaraju et al. (2017) - Grad-CAM: Visual Explanations from Deep Networks
8
+ """
9
+
10
+ import tensorflow as tf
11
+ import numpy as np
12
+ import cv2
13
+ from PIL import Image
14
+
15
+
+ def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
+     """
+     Generate Grad-CAM heatmap for a given image and prediction.
+
+     Args:
+         img_array: Preprocessed image (batch_size, height, width, channels)
+         model: Trained Keras model
+         last_conv_layer_name: Name of last convolutional layer
+         pred_index: Target class index (if None, uses predicted class)
+
+     Returns:
+         heatmap: Normalized heatmap (0-1 range)
+     """
+     # Create a model that maps the input image to the activations of the last conv layer
+     # as well as the output predictions
+     grad_model = tf.keras.models.Model(
+         [model.inputs],
+         [model.get_layer(last_conv_layer_name).output, model.output]
+     )
+
+     # Compute the gradient of the top predicted class for our input image
+     # with respect to the activations of the last conv layer
+     with tf.GradientTape() as tape:
+         last_conv_layer_output, preds = grad_model(img_array)
+         if pred_index is None:
+             pred_index = tf.argmax(preds[0])
+         class_channel = preds[:, pred_index]
+
+     # Gradient of the output neuron with regard to the output feature map of the last conv layer
+     grads = tape.gradient(class_channel, last_conv_layer_output)
+
+     # Vector where each entry is the mean intensity of the gradient over a specific feature map channel
+     pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
+
+     # Multiply each channel in the feature map array by "how important this channel is"
+     last_conv_layer_output = last_conv_layer_output[0]
+     heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
+     heatmap = tf.squeeze(heatmap)
+
+     # Normalize the heatmap between 0 & 1 for visualization
+     heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
+     return heatmap.numpy()
58
+
59
+
+ def overlay_heatmap_on_image(img, heatmap, alpha=0.4, colormap=cv2.COLORMAP_JET):
+     """
+     Overlay Grad-CAM heatmap on original image.
+
+     Args:
+         img: Original PIL Image or numpy array (RGB or grayscale, 0-255 range)
+         heatmap: Grad-CAM heatmap (0-1 range)
+         alpha: Transparency of heatmap overlay (0-1)
+         colormap: OpenCV colormap (default: JET - red=hot, blue=cold)
+
+     Returns:
+         superimposed_img: PIL Image with heatmap overlay
+     """
+     # Convert PIL to numpy if needed
+     if isinstance(img, Image.Image):
+         img = np.array(img)
+
+     # Expand grayscale input to 3 channels so it can be blended with the RGB heatmap
+     if img.ndim == 2:
+         img = np.stack([img] * 3, axis=-1)
+
+     # Resize heatmap to match image size (cv2.resize expects (width, height))
+     heatmap_resized = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
+
+     # Convert heatmap to RGB
+     heatmap_colored = np.uint8(255 * heatmap_resized)
+     heatmap_colored = cv2.applyColorMap(heatmap_colored, colormap)
+     heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
+
+     # Superimpose the heatmap on original image
+     superimposed_img = heatmap_colored * alpha + img * (1 - alpha)
+     superimposed_img = np.uint8(superimposed_img)
+
+     return Image.fromarray(superimposed_img)
+
+
+ def generate_gradcam_for_disease(image, model, disease_name, label_encoder,
+                                  last_conv_layer_name='top_conv', img_size=224):
+     """
+     Generate Grad-CAM visualization for a specific disease prediction.
+
+     Args:
+         image: PIL Image
+         model: Trained model
+         disease_name: Name of disease to visualize
+         label_encoder: Disease name -> index mapping
+         last_conv_layer_name: Name of last conv layer in EfficientNetB0
+         img_size: Input image size
+
+     Returns:
+         overlaid_image: PIL Image with Grad-CAM overlay
+         heatmap: Raw heatmap array
+     """
+     # Preprocess image
+     img_resized = image.convert('RGB').resize((img_size, img_size))
+     img_array = np.array(img_resized) / 255.0
+     img_array = np.expand_dims(img_array, axis=0).astype(np.float32)
+
+     # Get disease index
+     disease_idx = label_encoder[disease_name]
+
+     # Generate heatmap
+     heatmap = make_gradcam_heatmap(img_array, model, last_conv_layer_name, disease_idx)
+
+     # Overlay on original image
+     overlaid_image = overlay_heatmap_on_image(img_resized, heatmap, alpha=0.4)
+
+     return overlaid_image, heatmap
+
+
+ def generate_gradcam_for_top_predictions(image, model, predictions, label_encoder,
+                                          top_k=3, last_conv_layer_name='top_conv'):
+     """
+     Generate Grad-CAM for top K predicted diseases.
+
+     Args:
+         image: PIL Image
+         model: Trained model
+         predictions: List of prediction dicts from main app
+         label_encoder: Disease name -> index mapping
+         top_k: Number of top predictions to visualize
+         last_conv_layer_name: Name of last conv layer
+
+     Returns:
+         gradcam_images: List of (disease_name, overlaid_image, probability) tuples
+     """
+     gradcam_images = []
+
+     # Sort predictions by probability
+     sorted_preds = sorted(predictions, key=lambda x: x['probability'], reverse=True)[:top_k]
+
+     for pred in sorted_preds:
+         disease_name = pred['disease']
+         probability = pred['probability']
+
+         # Generate Grad-CAM
+         overlaid_img, _ = generate_gradcam_for_disease(
+             image, model, disease_name, label_encoder, last_conv_layer_name
+         )
+
+         gradcam_images.append((disease_name, overlaid_img, probability))
+
+     return gradcam_images
+
+
+ def get_last_conv_layer_name(model):
+     """
+     Automatically find the last convolutional layer in the model.
+
+     For EfficientNetB0, it's typically 'top_conv' or the last Conv2D layer.
+
+     Args:
+         model: Keras model
+
+     Returns:
+         layer_name: Name of last conv layer
+     """
+     # Try common names first
+     common_names = ['top_conv', 'block7a_project_conv', 'conv_head']
+     for name in common_names:
+         try:
+             model.get_layer(name)
+             return name
+         except ValueError:
+             # get_layer raises ValueError when no layer has this name
+             pass
+
+     # Search backwards for Conv2D layer
+     for layer in reversed(model.layers):
+         if isinstance(layer, tf.keras.layers.Conv2D):
+             return layer.name
+
+     raise ValueError("No convolutional layer found in model!")
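The core arithmetic of `make_gradcam_heatmap` can be followed without TensorFlow. Below is a minimal NumPy sketch, assuming the last conv layer's activations and the class-score gradients for one image have already been extracted as `(H, W, C)` arrays. The function name `gradcam_from_arrays` and the toy values are illustrative and are not part of this commit.

```python
import numpy as np


def gradcam_from_arrays(activations, gradients):
    """Grad-CAM heatmap from one image's activations and gradients, both shaped (H, W, C)."""
    # Channel importance: average each channel's gradient over the spatial grid.
    weights = gradients.mean(axis=(0, 1))  # shape (C,)
    # Weighted sum of the feature maps across channels.
    cam = activations @ weights  # shape (H, W)
    # Keep only regions that push the class score up (ReLU).
    cam = np.maximum(cam, 0.0)
    # Scale to the 0-1 range; an all-zero map is returned unchanged.
    peak = cam.max()
    return cam / peak if peak > 0 else cam


# Toy input: a 2x2 spatial grid with 2 channels (hypothetical values).
activations = np.zeros((2, 2, 2))
activations[0, 0, 0] = 2.0
activations[1, 1, 0] = 1.0
activations[0, 1, 1] = 3.0
# Channel 0 has gradient +1 everywhere, channel 1 has gradient -1 everywhere.
gradients = np.stack([np.ones((2, 2)), -np.ones((2, 2))], axis=-1)

heatmap = gradcam_from_arrays(activations, gradients)
print(heatmap)
```

In this toy case the second channel has a negative weight, so the location it activates is zeroed by the ReLU, and the remaining values are scaled relative to the strongest location.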
label_encoder.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1a741eb91d4f54ad79acc4df4cb6f2ab3b91de9bae12e1639b84037a56c5d008
+ size 234
optimal_thresholds.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b72ad6220549ae86dcb8800ffe62fc9362cbeeb75d32db2e484bae77aab25338
+ size 514
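The name `optimal_thresholds.pkl` suggests per-class decision thresholds for the multi-label output, but only the Git LFS pointer is shown here, not the file's contents. The sketch below is therefore an assumption about how such thresholds might be applied: the dict layout, the helper name `apply_thresholds`, the 0.5 fallback, and all numbers are hypothetical.

```python
def apply_thresholds(probabilities, thresholds, default=0.5):
    """Return the labels whose probability meets that label's own threshold."""
    return [
        label
        for label, prob in probabilities.items()
        # Labels without a stored threshold fall back to the default cut-off.
        if prob >= thresholds.get(label, default)
    ]


# Hypothetical model outputs and thresholds, for illustration only.
probabilities = {"Effusion": 0.42, "Nodule": 0.30, "Pneumonia": 0.10}
thresholds = {"Effusion": 0.35, "Nodule": 0.45}

positive = apply_thresholds(probabilities, thresholds)
print(positive)
```

Per-class thresholds let a rare finding be flagged at a lower probability than a common one, which a single global cut-off cannot do.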
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ tensorflow>=2.10.0,<2.11.0
+ numpy>=1.23.0,<1.24.0
+ pandas>=2.0.0
+ matplotlib>=3.7.0
+ seaborn>=0.12.0
+ scikit-learn>=1.3.0
+ jupyter>=1.0.0
+ Pillow>=9.5.0
+ opencv-python>=4.7.0
+ gradio>=4.0.0