Model Card: CapsNet 4-Class Lung-Disease Classifier
Model name: capsnet_4class_lung_disease_classifier
Version: 1.0
Date: 2025-09-17
Overview
A Capsule Network (CapsNet) implemented in TensorFlow/Keras to classify four lung-disease categories from masked chest X-ray images. The model uses routing-by-agreement and margin loss, and was trained with MLflow tracking. Input images are resized to 256×256×3.
⚠️ Not a medical device. Outputs are for research/education. Clinician review is required before any clinical use.
Intended Use
- Primary use: Educational/research experiments on lung-disease image classification.
- Users: ML practitioners and students familiar with Python/TensorFlow.
- Out-of-scope: Direct clinical decision-making; deployment on patient data without formal validation and regulatory clearance.
Model Details
- For technical details on Capsule Networks, see the tutorial at https://www.digitalocean.com/community/tutorials/capsule-networks
- Architecture: CapsNet with PrimaryCaps and DigitCaps; routing iterations: 3; first Conv2D kernel size: tuned over [5, 7, 9, 10, 11] after an exploratory sweep (3-epoch runs over [3, 5, 7, 9, 10, 11]).
- Loss: margin_loss (capsule margin).
- Optimizer: Adam with a learning-rate scheduler (lr_scheduler).
- Metrics: accuracy.
- Input shape: (1, 256, 256, 1); #classes: 4; batch size: 32.
- Training schedule: up to 50 epochs with callbacks: EarlyStopping(mode='min'), ReduceLROnPlateau, custom StopAtValAccuracy(target=0.95), and ModelCheckpoint(save_best_only=True).
- Framework: TensorFlow/Keras.
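The margin loss listed above is not spelled out in this card. The following is a minimal sketch of the standard capsule margin loss, written in plain Python for readability rather than as a Keras loss. The constants m_plus = 0.9, m_minus = 0.1, and lam = 0.5 are the commonly used defaults and are assumptions here; the function name margin_loss_sketch is hypothetical and is not the model's own margin_loss.

```python
def margin_loss_sketch(y_true, lengths, m_plus=0.9, m_minus=0.1, lam=0.5):
    """Capsule margin loss for a single sample (illustrative only).

    y_true:  one-hot list of targets, e.g. [0, 0, 1, 0]
    lengths: capsule output lengths in [0, 1], one per class
    m_plus, m_minus, lam: assumed defaults, not confirmed for this model
    """
    total = 0.0
    for t, v in zip(y_true, lengths):
        # Penalize a present class whose capsule is shorter than m_plus.
        present = t * max(0.0, m_plus - v) ** 2
        # Penalize an absent class whose capsule is longer than m_minus.
        absent = lam * (1 - t) * max(0.0, v - m_minus) ** 2
        total += present + absent
    return total


# A confident, correct prediction for class index 2 incurs zero loss.
print(margin_loss_sketch([0, 0, 1, 0], [0.05, 0.05, 0.95, 0.05]))  # 0.0
```

In a training setup the same expression would typically be computed on batched tensors and averaged over the batch.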
Data
- Training/validation: Balanced masked CXR (loaded as train_dataset and val_dataset). Long-run training used validation_split = 0.2; kernel exploration used validation_split = 0.5.
- Preprocessing: resize to 256×256; masking performed upstream. Any additional normalization should match the notebook pipeline.
- Label schema: 4 disease classes (variable disease_labels = ['COVID', 'Lung_Opacity', 'Normal', 'Viral Pneumonia']).
Licenses & provenance:
Training/Validation dataset: https://www.kaggle.com/tawsifurrahman/covid19-radiography-database (License: Data files © Original Authors) (Data size: 1.15 GB)
Dataset for external tests: https://www.kaggle.com/datasets/pranavraikokte/covid19-image-dataset (License: https://creativecommons.org/licenses/by-sa/4.0/) and https://www.kaggle.com/datasets/omkarmanohardalvi/lungs-disease-dataset-4-types (License: Unknown)
Evaluation
- Protocol: train on the balanced training set, validate with a held-out split; select first-layer kernel size from [5,7,9,10,11] based on validation performance; evaluate best checkpoint on the unseen test set.
- Reported metrics: Metrics for each trained model are in capsnet_training_metrics_all_runs.csv
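The kernel-size selection described in the protocol could be reproduced from the metrics CSV roughly as follows. This is a sketch under assumptions: the column names kernel_size and val_accuracy are hypothetical (the card does not document the CSV header), and both helper functions are illustrative rather than part of the project.

```python
import csv


def best_kernel_size(rows, kernel_col="kernel_size", metric_col="val_accuracy"):
    """Return (kernel_size, best_metric) for the highest validation metric.

    rows: iterable of dicts, e.g. from csv.DictReader.
    Column names are hypothetical; adjust them to the CSV's real header.
    """
    best = {}
    for row in rows:
        k = int(row[kernel_col])
        score = float(row[metric_col])
        # Keep the best value seen so far for each kernel size.
        if k not in best or score > best[k]:
            best[k] = score
    return max(best.items(), key=lambda kv: kv[1])


def best_kernel_from_csv(path="capsnet_training_metrics_all_runs.csv"):
    """Read the metrics file and pick the best-scoring kernel size."""
    with open(path, newline="") as f:
        return best_kernel_size(csv.DictReader(f))
```

Selecting on validation accuracy only, and touching the test set once with the chosen checkpoint, keeps the test estimate free of selection bias.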
External test
I excluded the Lung_Opacity class from the external tests because lung opacity often co-occurs with other diseases, which makes it hard to classify unambiguously. The accuracy reported for the remaining three classes is still representative.
Results CSV: test_on_external_dataset_capsnet_lung_disease_classifier_krnl9.csv
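The card does not document exactly how the excluded class was handled at prediction time. One plausible approach, sketched below as an assumption, is to drop samples labeled Lung_Opacity and to take the argmax over the remaining three class scores only. The function name accuracy_excluding is hypothetical.

```python
disease_labels = ["COVID", "Lung_Opacity", "Normal", "Viral Pneumonia"]


def accuracy_excluding(probs, true_labels, excluded="Lung_Opacity"):
    """Accuracy over samples whose true label is not `excluded`.

    probs:       list of 4-element score lists, ordered like disease_labels
    true_labels: list of label strings
    The excluded class is also removed from the candidate predictions
    (an assumption; the author's exact procedure is not documented).
    """
    keep = [i for i, name in enumerate(disease_labels) if name != excluded]
    correct = total = 0
    for scores, truth in zip(probs, true_labels):
        if truth == excluded:
            continue  # sample is not part of the external test
        pred_idx = max(keep, key=lambda i: scores[i])
        correct += disease_labels[pred_idx] == truth
        total += 1
    return correct / total if total else float("nan")
```

If predictions of the excluded class were instead counted as errors, the reported accuracy would be lower, so the choice should be stated alongside the results.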
Risks, Bias, and Limitations
- Domain shift: performance may degrade on images from other hospitals, scanners, or populations.
- Label noise / class imbalance: training is balanced, but real-world prevalence may differ.
- Confounders: text markers, devices, or preprocessing differences can leak non-pathology signals.
- Fairness: if patient demographics were not controlled, subgroup performance may vary.
- Regulatory: not cleared for clinical use.
Recommendations
- Always use human-in-the-loop review.
- Report per-class metrics and confidence scores; calibrate if needed.
- Perform external validation on multiple sites before any operational use.
- Track experiments with MLflow (mlruns_capsnet) and save confusion matrices.
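The per-class metrics and confusion matrices recommended above can be computed with a few lines of plain Python. This is a minimal sketch, assuming integer class indices 0 to 3 in the order of disease_labels; the function names are illustrative and not part of the project code.

```python
def confusion_matrix(y_true, y_pred, n_classes=4):
    """Rows are true classes, columns are predicted classes."""
    cm = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm


def per_class_metrics(cm):
    """Return a list of (precision, recall) per class; 0.0 if undefined."""
    n = len(cm)
    out = []
    for k in range(n):
        tp = cm[k][k]
        predicted = sum(cm[r][k] for r in range(n))  # column sum
        actual = sum(cm[k])                          # row sum
        precision = tp / predicted if predicted else 0.0
        recall = tp / actual if actual else 0.0
        out.append((precision, recall))
    return out
```

Reporting recall per class matters here because a single overall accuracy can hide a class that the model rarely detects.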
How to Use
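import numpy as np
import tensorflow as tf
from tensorflow.keras.utils import load_img, img_to_array

# modelbuilder holds the model architecture, including the custom objects
# needed to deserialize the model:
#   {"margin_loss": margin_loss, "PrimaryCaps": PrimaryCaps,
#    "DigitCaps": DigitCaps, "Length": Length}
from modelbuilder import capsnet_custom_objects

# Load trained Keras model
model = tf.keras.models.load_model("path/to/model.keras",
                                   custom_objects=capsnet_custom_objects)

def preprocess(path):
    # Assumed preprocessing: grayscale, 256x256, plus a batch axis.
    # Any normalization should match the notebook pipeline.
    img = load_img(path, color_mode="grayscale", target_size=(256, 256))
    return np.expand_dims(img_to_array(img), axis=0)

x = preprocess("example_cxr.png")
# x.shape -> (1, 256, 256, 1)
pred = model.predict(x)[0]  # shape: (4,)
pred_label = np.argmax(pred)
print(pred, pred_label)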
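The snippet above prints an integer class index. A small helper can turn that into a readable label using the disease_labels list from the Data section. This sketch assumes the model's output order matches that list, which the card does not state explicitly; describe_prediction is a hypothetical name.

```python
disease_labels = ["COVID", "Lung_Opacity", "Normal", "Viral Pneumonia"]


def describe_prediction(pred):
    """Map a 4-element score vector to (label, score).

    Assumes output index i corresponds to disease_labels[i].
    """
    idx = max(range(len(pred)), key=lambda i: pred[i])
    return disease_labels[idx], float(pred[idx])


print(describe_prediction([0.05, 0.10, 0.80, 0.05]))  # ('Normal', 0.8)
```

The helper accepts a plain list or a NumPy array such as the pred vector returned by model.predict.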
