---
license: mit
tags:
- medical-imaging
- chest-xray
- federated-learning
- differential-privacy
- pytorch
- multi-label-classification
- radiology
- fairness
- explainability
language:
- en
datasets:
- ChestX-ray14
- CheXpert
metrics:
- roc_auc
---
# FedPrivNet: Privacy-Aware Hybrid Deep Learning Model for Chest X-Ray Classification
<p align="center">
<img src="training_curves.png" width="800"/>
</p>
## Overview
**FedPrivNet** is a custom hybrid deep learning architecture designed
for multi-label chest X-ray disease classification across **14 thoracic
pathologies**. The model is specifically engineered as the backbone for
a privacy-preserving Federated Learning pipeline, with three core design
principles: **Differential Privacy compatibility**, **Demographic
Fairness**, and **Communication Efficiency**.
Developed by **Md. Sajjad Ullah**, FedPrivNet is a novel hybrid
architecture that combines ResNet-18 and DenseNet-121 feature
extraction with DP-safe GroupNorm normalization and a trainable
spatial attention gate for built-in explainability.
---
## Model Architecture
FedPrivNet is a dual-branch hybrid network with the following novel components:
| Component | Description |
|---|---|
| **Branch A** | ResNet-18 pretrained backbone (layers 1–3) |
| **Branch B** | DenseNet-121 pretrained backbone (blocks 1–3) |
| **Fusion Module** | 1×1 Conv fusion: 1280 → 512 → 256 channels |
| **DPResBlock ×3** | Custom residual blocks with GroupNorm (DP-safe) |
| **SEBlock** | Squeeze-and-Excitation channel attention in every block |
| **SpatialAttentionGate** | Novel trainable spatial mask for built-in XAI |
| **Classifier Head** | GAP → Dropout(0.3) → Linear → Sigmoid (14 classes) |
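Each 1×1 convolution in the Fusion Module applies the same channel-mixing matrix at every pixel. The NumPy sketch below traces the 1280 → 512 → 256 channel path under stated assumptions: both branch outputs are taken to share the 28×28 grid from the architecture diagram, they are concatenated along channels, and a ReLU follows each projection (the activation and any GroupNorm between projections are not specified in this card). `conv1x1` and the random weights are hypothetical stand-ins, not FedPrivNet's trained parameters.

```python
import numpy as np

def conv1x1(x, weight, bias):
    """1x1 convolution on a (C_in, H, W) map: the same (C_out, C_in) matrix
    is applied to the channel vector at every pixel."""
    return np.einsum("oc,chw->ohw", weight, x) + bias[:, None, None]

rng = np.random.default_rng(0)
H = W = 28  # spatial size taken from the architecture diagram

# Hypothetical branch outputs: ResNet-18 (256 ch) and DenseNet-121 (1024 ch)
feat_resnet = rng.standard_normal((256, H, W))
feat_densenet = rng.standard_normal((1024, H, W))

# Concatenate along channels: 256 + 1024 = 1280
fused = np.concatenate([feat_resnet, feat_densenet], axis=0)

# Two 1x1 projections, 1280 -> 512 -> 256 (ReLU is an assumption)
w1, b1 = rng.standard_normal((512, 1280)) * 0.01, np.zeros(512)
w2, b2 = rng.standard_normal((256, 512)) * 0.01, np.zeros(256)
hidden = np.maximum(conv1x1(fused, w1, b1), 0.0)
out = np.maximum(conv1x1(hidden, w2, b2), 0.0)

print(fused.shape, hidden.shape, out.shape)
# (1280, 28, 28) (512, 28, 28) (256, 28, 28)
```

Because a 1×1 convolution never looks at neighboring pixels, it changes only the channel count and leaves the spatial grid untouched.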
> **Key Design Decision:** All BatchNorm layers are replaced with
> **GroupNorm** throughout the entire network. This makes FedPrivNet
> natively compatible with **DP-SGD** (Opacus), which requires
> per-sample gradient computation. BatchNorm breaks this requirement
> because it normalizes each sample with statistics pooled across the
> whole batch, so one sample's output depends on the others.
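The incompatibility can be seen numerically. BatchNorm pools statistics across the batch, so one sample's normalized output (and therefore its gradient) changes when the *other* samples change, while GroupNorm normalizes each sample independently. A minimal NumPy sketch of both normalizations, omitting the learnable affine parameters; the tensor shapes and `num_groups=4` are arbitrary illustration choices:

```python
import numpy as np

def batch_norm(x, eps=1e-5):
    # x: (N, C, H, W). Statistics are pooled over the batch axis N,
    # so each sample's output depends on every other sample in the batch.
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def group_norm(x, num_groups, eps=1e-5):
    # Statistics are computed per sample over channel groups and space,
    # so each sample is normalized independently of the rest of the batch.
    n, c, h, w = x.shape
    g = x.reshape(n, num_groups, c // num_groups, h, w)
    mean = g.mean(axis=(2, 3, 4), keepdims=True)
    var = g.var(axis=(2, 3, 4), keepdims=True)
    return ((g - mean) / np.sqrt(var + eps)).reshape(n, c, h, w)

rng = np.random.default_rng(0)
batch_a = rng.standard_normal((4, 8, 5, 5))
batch_b = batch_a.copy()
batch_b[1:] += 10.0  # change every sample except sample 0

# Sample 0 is identical in both batches; does its normalized output change?
bn_same = np.allclose(batch_norm(batch_a)[0], batch_norm(batch_b)[0])
gn_same = np.allclose(group_norm(batch_a, 4)[0], group_norm(batch_b, 4)[0])
print(bn_same, gn_same)  # False True
```

In practice, Opacus provides `opacus.validators.ModuleValidator`, whose `validate` method flags incompatible layers such as BatchNorm and whose `fix` method substitutes GroupNorm.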
### Architecture Diagram
           Input (224×224 Chest X-Ray)
                        │
                   ┌────┴────┐
                   │         │
              ResNet-18  DenseNet-121
                (256ch)   (1024ch)
                   │         │
                   └────┬────┘
                   Fusion Conv
                 (256ch, 28×28)
                        │
            DPResBlock-1 + SE (256ch)
                        │
        DPResBlock-2 + SE (512ch, 14×14)
                        │
            DPResBlock-3 + SE (512ch)
                        │
              SpatialAttentionGate
               (novel XAI module)
                        │
                  GAP → Dropout
                        │
                 Linear (512→14)
                        │
                     Sigmoid
                        │
                 14-class Output
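The last stages of the diagram (SE channel attention, the spatial attention gate and the classifier head) can be sketched in NumPy as follows. This illustrates the general techniques under assumptions and is not FedPrivNet's implementation: the SE reduction ratio `r = 16` is a common choice rather than a documented value; the spatial gate is modeled as a hypothetical 1×1 convolution to a single channel followed by a sigmoid (the card does not describe the gate's internals); SE is applied once here rather than inside every DPResBlock; and dropout is omitted because it is inactive at inference. All weights are random.

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def se_block(x, w1, w2):
    """Squeeze-and-Excitation on a (C, H, W) map."""
    s = x.mean(axis=(1, 2))          # squeeze: global average pool -> (C,)
    e = np.maximum(w1 @ s, 0.0)      # excitation: FC C -> C/r, then ReLU
    gate = sigmoid(w2 @ e)           # FC C/r -> C, per-channel weights in (0, 1)
    return x * gate[:, None, None]   # rescale each channel

def spatial_gate(x, w, b):
    """Hypothetical spatial gate: 1x1 conv to one channel, then sigmoid."""
    mask = sigmoid(np.einsum("c,chw->hw", w, x) + b)  # (H, W), values in (0, 1)
    return x * mask[None], mask

def head(x, w_fc, b_fc):
    """GAP -> Linear -> Sigmoid (dropout is a no-op at inference)."""
    pooled = x.mean(axis=(1, 2))     # (C,)
    return sigmoid(w_fc @ pooled + b_fc)

C, H, W, r, num_classes = 512, 14, 14, 16, 14
x = rng.standard_normal((C, H, W))   # stand-in for the DPResBlock-3 output

x = se_block(x, rng.standard_normal((C // r, C)) * 0.05,
             rng.standard_normal((C, C // r)) * 0.05)
x, attn_map = spatial_gate(x, rng.standard_normal(C) * 0.05, 0.0)
probs = head(x, rng.standard_normal((num_classes, C)) * 0.05, np.zeros(num_classes))

print(attn_map.shape, probs.shape)  # (14, 14) (14,)
```

The `(H, W)` mask returned by the gate is the kind of map an attention visualization would overlay on the input image after upsampling.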
---
## Performance
### Overall Metrics
| Metric | Value |
|---|---|
| **Mean AUC-ROC (14 classes)** | **0.7862** |
| Train AUC-ROC | 0.8913 |
| Gender Demographic Parity Gap | **0.0059** |
| Age Demographic Parity Gap | **0.0133** |
| Total Parameters | 17,560,527 |
| Best Epoch | 29 / 30 |
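The card does not state exactly how the parity gaps were computed. One common definition, used here purely as an assumption, is the difference in positive-prediction rates between demographic groups at a 0.5 threshold, averaged over the labels; with more than two groups (e.g. hypothetical age buckets), the sketch takes the largest minus the smallest rate. `demographic_parity_gap` and the toy data are illustrative only.

```python
import numpy as np

def demographic_parity_gap(probs, groups, threshold=0.5):
    """Mean over classes of (max - min) positive-prediction rate across groups.

    probs:  (N, K) predicted probabilities
    groups: (N,) group label per sample (e.g. sex, or an age bucket)
    """
    preds = probs > threshold
    rates = np.stack([preds[groups == g].mean(axis=0) for g in np.unique(groups)])
    return float((rates.max(axis=0) - rates.min(axis=0)).mean())

# Toy example: 4 patients, 2 classes, two groups "F" and "M"
probs = np.array([[0.9, 0.2],
                  [0.4, 0.7],
                  [0.8, 0.1],
                  [0.6, 0.6]])
groups = np.array(["F", "F", "M", "M"])
print(demographic_parity_gap(probs, groups))  # 0.25
```

A gap of 0 means every group receives positive predictions at the same rate for every label.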
### Per-Class AUC-ROC
| Disease | AUC-ROC |
|---|---|
| Cardiomegaly | **0.8847** |
| Emphysema | **0.8698** |
| Pneumothorax | **0.8650** |
| Edema | 0.8358 |
| Effusion | 0.8297 |
| Hernia | 0.8138 |
| Mass | 0.7785 |
| Fibrosis | 0.7707 |
| Atelectasis | 0.7617 |
| Pleural Thickening | 0.7384 |
| Consolidation | 0.7360 |
| Nodule | 0.7148 |
| Pneumonia | 0.7122 |
| Infiltration | 0.6963 |
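The 14 per-class values above average to the reported mean AUC-ROC of 0.7862, consistent with an unweighted (macro) mean. The sketch below reproduces that average and shows how one per-class AUC can be computed as a Mann-Whitney statistic. In practice, `sklearn.metrics.roc_auc_score(y_true, y_score, average="macro")` computes the same macro average for multi-label targets; `roc_auc` here is a hypothetical helper written for illustration.

```python
def roc_auc(labels, scores):
    """AUC as the probability that a random positive outranks a random
    negative (Mann-Whitney U statistic); ties count as half."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

per_class_auc = {
    "Cardiomegaly": 0.8847, "Emphysema": 0.8698, "Pneumothorax": 0.8650,
    "Edema": 0.8358, "Effusion": 0.8297, "Hernia": 0.8138, "Mass": 0.7785,
    "Fibrosis": 0.7707, "Atelectasis": 0.7617, "Pleural_Thickening": 0.7384,
    "Consolidation": 0.7360, "Nodule": 0.7148, "Pneumonia": 0.7122,
    "Infiltration": 0.6963,
}

# Macro average: unweighted mean of the 14 per-class AUCs
mean_auc = sum(per_class_auc.values()) / len(per_class_auc)
print(f"{mean_auc:.4f}")  # 0.7862

print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```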
### Spatial Attention Visualization
<p align="center">
<img src="spatial_attention_visualization.png" width="800"/>
</p>
---
## Baseline Comparison
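FedPrivNet outperforms both individual backbone models. The baselines
were trained under the same federated learning conditions (3 training
epochs, same dataset partition, same hyperparameters). The FedPrivNet
value in this table equals the best-epoch result reported under
Overall Metrics (epoch 29 of 30).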
| Model | Mean AUC-ROC | Parameters | Notes |
|---|---|---|---|
| ResNet-18 | 0.7060 | 11.2M | Single backbone |
| DenseNet-121 | 0.6992 | 7.0M | Single backbone |
| **FedPrivNet (Ours)** | **0.7862** | 17.6M | Dual-backbone fusion |
The dual-backbone fusion improves mean AUC-ROC by **0.080** over
ResNet-18 and **0.087** over DenseNet-121 (8.0 and 8.7 percentage
points), which supports the hybrid architecture design choice.
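The improvement over each baseline can be expressed as an absolute difference in AUC or relative to the baseline's own score. A quick check of both, using only the values in the table:

```python
fedprivnet = 0.7862
baselines = {"ResNet-18": 0.7060, "DenseNet-121": 0.6992}

gains = {}
for name, auc in baselines.items():
    absolute = fedprivnet - auc        # difference in AUC points
    relative = 100 * absolute / auc    # percent of the baseline AUC
    gains[name] = (absolute, relative)
    print(f"{name}: +{absolute:.4f} AUC ({relative:.1f}% relative)")
# ResNet-18: +0.0802 AUC (11.4% relative)
# DenseNet-121: +0.0870 AUC (12.4% relative)
```

Reporting the absolute difference avoids conflating percentage points with relative percentages.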
<p align="center">
<img src="baseline_comparison.png" width="800"/>
</p>
---
## Ablation Study
Each architectural component was systematically removed to measure
its individual contribution. All variants were trained for 3 epochs
from random initialization on the same federated partition.
| Variant | Mean AUC-ROC | Δ AUC vs. Full | Attn | SE |
|---|---|---|---|---|
| Full FedPrivNet | 0.6778 | (reference) | ✅ | ✅ |
| Without Spatial Attention Gate | 0.6850 | +0.0072 | ❌ | ✅ |
| Without SE Blocks | 0.7018 | +0.0239 | ✅ | ❌ |
| Without Attention & SE | 0.6958 | +0.0179 | ❌ | ❌ |
**Key finding:** In this short 3-epoch, randomly initialized setting,
every ablated variant scored *higher* than the full model, and removing
the SE Blocks produced the largest change (+0.024 mean AUC). These runs
therefore do not show a performance benefit from either component at
this training budget. The Spatial Attention Gate is retained as the
built-in XAI module that provides diagnostic focus.
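The sign of each change follows directly from the table: a positive difference means the ablated variant scored above the full model. The sketch recomputes the differences from the rounded AUCs shown, so two of them come out 0.0001 larger than the table's +0.0239 and +0.0179, which were presumably computed from unrounded scores.

```python
full = 0.6778
variants = {
    "Without Spatial Attention Gate": 0.6850,
    "Without SE Blocks": 0.7018,
    "Without Attention & SE": 0.6958,
}

# Positive value: the ablated variant scored ABOVE the full model.
deltas = {name: round(auc - full, 4) for name, auc in variants.items()}
for name, delta in deltas.items():
    print(f"{name}: {delta:+.4f}")
# Without Spatial Attention Gate: +0.0072
# Without SE Blocks: +0.0240
# Without Attention & SE: +0.0180
```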
<p align="center">
<img src="ablation_study.png" width="800"/>
</p>
---
## Training Details
| Setting | Value |
|---|---|
| Framework | PyTorch 2.0 |
| Optimizer | AdamW (lr=1e-4, weight_decay=1e-5) |
| Scheduler | CosineAnnealingLR |
| Loss Function | Binary Cross Entropy |
| Batch Size | 128 |
| Epochs | 30 |
| Image Size | 224 Γ— 224 |
| GPU | NVIDIA GeForce RTX 4090 |
| Training Samples | 100,000 (50k ChestX-ray14/CXR8 + 50k CheXpert) |
| Validation Samples | 10,000 (ChestX-ray14/CXR8 test set) |
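Two settings in this table are easy to state exactly: the cosine-annealing learning rate and the multi-label binary cross-entropy on post-sigmoid probabilities. The sketch assumes the scheduler spans the full 30 epochs (`t_max=30`), steps once per epoch and decays to `eta_min=0`; the card does not state these values. In PyTorch the corresponding objects would be `torch.optim.AdamW(params, lr=1e-4, weight_decay=1e-5)`, `torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)` and `torch.nn.BCELoss()` (or `BCEWithLogitsLoss` if the model returned logits).

```python
import math

def cosine_lr(epoch, base_lr=1e-4, t_max=30, eta_min=0.0):
    """Closed form of cosine annealing without warm restarts."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))

def multilabel_bce(probs, targets, eps=1e-7):
    """Binary cross-entropy averaged over all (sample, class) entries.
    Expects probabilities (post-sigmoid), as produced by the classifier head."""
    total, count = 0.0, 0
    for p_row, y_row in zip(probs, targets):
        for p, y in zip(p_row, y_row):
            p = min(max(p, eps), 1 - eps)  # clamp to avoid log(0)
            total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
            count += 1
    return total / count

print([f"{cosine_lr(e):.2e}" for e in (0, 15, 29, 30)])
print(round(multilabel_bce([[0.5, 0.5]], [[1, 0]]), 4))  # 0.6931
```

The learning rate starts at 1e-4, reaches half that value at the midpoint and approaches zero in the final epochs.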
---
## Datasets
| Dataset | Institution | Images | Labels |
|---|---|---|---|
| **ChestX-ray14** | NIH Clinical Center | 112,120 | 14 thoracic diseases |
| **CheXpert-small** | Stanford University | 224,316 | 14 observations |
---
## Usage
### Load Pretrained Model
```python
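import torch
import torch.nn as nn            # required by the FedPrivNet class definition
import torch.nn.functional as F  # required by the FedPrivNet class definition
from huggingface_hub import hf_hub_download

# The FedPrivNet class is not included in this card; define or import it
# before running this snippet.
weights_path = hf_hub_download(
    repo_id="mdsajjadullah/FedPrivNet-ChestXray14-CheXpert",
    filename="FedPrivNet_best.pth",
)
model = FedPrivNet(num_classes=14, dropout=0.3, pretrained=False)

# weights_only=False unpickles arbitrary objects; only load trusted files.
checkpoint = torch.load(weights_path, map_location='cpu', weights_only=False)
model.load_state_dict(checkpoint['model_state'])
model.eval()
print(f"Model loaded | Val AUC: {checkpoint['val_auc']:.4f}")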
```
### Inference
```python
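import torch
from torchvision import transforms
from PIL import Image

# Output index -> disease label
DISEASE_LABELS = [
    'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration',
    'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax',
    'Consolidation', 'Edema', 'Emphysema', 'Fibrosis',
    'Pleural_Thickening', 'Hernia'
]

# Resize, convert to a [0, 1] tensor, then apply ImageNet mean/std normalization
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])

# Grayscale X-rays are converted to 3-channel RGB for the pretrained backbones
image = Image.open("chest_xray.png").convert("RGB")
x = transform(image).unsqueeze(0)  # add batch dimension -> (1, 3, 224, 224)

with torch.no_grad():
    # probs: (1, 14) sigmoid outputs; attn_map: spatial attention mask
    probs, attn_map = model(x, return_attention=True)

# Report findings above a 0.5 decision threshold
for label, prob in zip(DISEASE_LABELS, probs[0].tolist()):
    if prob > 0.5:
        print(f"  {label}: {prob:.4f}")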
```
---
## Federated Learning Integration
FedPrivNet is the backbone of a full FL pipeline that combines:
- FedPrivNet backbone
- Opacus 1.4 (DP-SGD, ε ∈ {1, 3, 5, 10})
- Fairness Regularizer (Demographic Parity + Equal Opportunity)
- Top-k Gradient Compression (k ∈ {0.30, 0.40, 0.50})
- Grad-CAM + SHAP Explainability Analysis
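Two of the pipeline stages can be sketched in NumPy. The DP-SGD step follows the standard recipe: clip each per-sample gradient to a maximum L2 norm, sum, add Gaussian noise scaled by a noise multiplier, and average. The compression step keeps the largest-magnitude fraction of an update. Both are illustrations, not the pipeline's code: `clip_norm=1.0` and `noise_multiplier=1.1` are arbitrary, mapping a target ε to a noise multiplier requires a privacy accountant (Opacus offers `PrivacyEngine.make_private_with_epsilon` for this), and reading k as the fraction of entries *kept* is an assumption.

```python
import numpy as np

def dp_sgd_noisy_grad(per_sample_grads, clip_norm, noise_multiplier, rng):
    """One DP-SGD aggregation: clip each sample's gradient to L2 norm
    <= clip_norm, sum, add N(0, (noise_multiplier * clip_norm)^2) noise,
    then average over the batch."""
    norms = np.linalg.norm(per_sample_grads, axis=1, keepdims=True)
    scale = np.minimum(1.0, clip_norm / np.maximum(norms, 1e-12))
    clipped = per_sample_grads * scale
    noise = rng.normal(0.0, noise_multiplier * clip_norm, size=clipped.shape[1])
    return (clipped.sum(axis=0) + noise) / len(per_sample_grads)

def top_k_sparsify(update, k_fraction):
    """Keep the k_fraction of entries with the largest magnitude; zero the rest."""
    k = max(1, int(round(k_fraction * update.size)))
    keep = np.argsort(np.abs(update))[-k:]
    sparse = np.zeros_like(update)
    sparse[keep] = update[keep]
    return sparse

rng = np.random.default_rng(0)
grads = rng.standard_normal((32, 1000))  # toy: 32 samples, 1000 parameters
noisy = dp_sgd_noisy_grad(grads, clip_norm=1.0, noise_multiplier=1.1, rng=rng)
compressed = top_k_sparsify(noisy, k_fraction=0.30)
print(np.count_nonzero(compressed))  # 300
```

Clipping bounds each patient's influence on the update, which is what lets the added noise yield a formal privacy guarantee; sparsifying afterwards reduces what each client has to transmit.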
---
## Citation
```bibtex
@misc{ullah2026fedprivnet,
author = {Md. Sajjad Ullah},
title = {FedPrivNet: Privacy-Aware Hybrid Deep Learning Model
for Chest X-Ray Classification},
year = {2026},
publisher = {Hugging Face},
url = {https://huggingface.co/mdsajjadullah/FedPrivNet-ChestXray14-CheXpert}
}
```
---
## Author
**Md. Sajjad Ullah**
Department of Computer Science and Engineering
University of Asia Pacific, Bangladesh
---
## License
MIT License