Model Card: CapsNet 4-Class Lung-Disease Classifier

Model name: capsnet_4class_lung_disease_classifier
Version: 1.0
Date: 2025-09-17

Overview

A Capsule Network (CapsNet) implemented in TensorFlow/Keras that classifies masked chest X-ray images into four categories (COVID, Lung_Opacity, Normal, Viral Pneumonia). The model uses routing-by-agreement and margin loss, and training was tracked with MLflow. Input images are resized to 256×256×3.
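Routing-by-agreement can be sketched in NumPy as follows. This is a minimal illustration of the dynamic-routing procedure from the original CapsNet paper (Sabour et al., 2017), not the code of this model's `DigitCaps` layer. The function names, the toy sizes and the 16-D output capsules are assumptions. It routes a single example without a batch dimension and uses the 3 routing iterations listed under Model Details.

```python
import numpy as np


def squash(s, axis=-1, eps=1e-8):
    """Shrink each vector's length into [0, 1) while keeping its direction."""
    sq_norm = np.sum(s ** 2, axis=axis, keepdims=True)
    scale = sq_norm / (1.0 + sq_norm) / np.sqrt(sq_norm + eps)
    return scale * s


def softmax(x, axis=-1):
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)


def dynamic_routing(u_hat, num_iterations=3):
    """Route predictions u_hat of shape (num_in, num_out, dim_out).

    Returns output capsules v of shape (num_out, dim_out).
    """
    num_in, num_out, _ = u_hat.shape
    b = np.zeros((num_in, num_out))               # routing logits
    for i in range(num_iterations):
        c = softmax(b, axis=1)                    # coupling coefficients per input capsule
        s = np.einsum("ij,ijd->jd", c, u_hat)     # weighted sum of predictions
        v = squash(s)                             # output capsules
        if i < num_iterations - 1:
            # Agreement: dot product between each prediction and the current output
            b = b + np.einsum("ijd,jd->ij", u_hat, v)
    return v


# Toy sizes (assumption): 32 input capsules, 4 classes, 16-D output capsules
rng = np.random.default_rng(0)
u_hat = rng.normal(size=(32, 4, 16))
v = dynamic_routing(u_hat, num_iterations=3)
lengths = np.linalg.norm(v, axis=-1)  # one length per class, each in [0, 1)
print(lengths)
```

The length of each output capsule serves as the class score. This is why the model's final `Length` layer produces one value per class.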

โš ๏ธ Not a medical device. Outputs are for research/education. Clinician review is required before any clinical use.

Intended Use

  • Primary use: Educational/research experiments on lung-disease image classification.
  • Users: ML practitioners and students familiar with Python/TensorFlow.
  • Out-of-scope: Direct clinical decision-making; deployment on patient data without formal validation and regulatory clearance.

Model Details

  • For technical details on Capsule Networks, see the introductory tutorial at https://www.digitalocean.com/community/tutorials/capsule-networks
  • Architecture: CapsNet with PrimaryCaps and DigitCaps; routing iterations: 3; first Conv2D kernel size: tuned over [5, 7, 9, 10, 11] after an exploratory sweep of 3-epoch runs over [3, 5, 7, 9, 10, 11].
  • Loss: margin_loss (capsule margin).
  • Optimizer: Adam with a learning-rate scheduler (lr_scheduler).
  • Metrics: accuracy.
  • Input shape: (256, 256, 1) per image, i.e. a (1, 256, 256, 1) tensor for a single image; number of classes: 4; batch size: 32.
  • Training schedule: up to 50 epochs with callbacks: EarlyStopping (mode='min'), ReduceLROnPlateau, custom StopAtValAccuracy(target=0.95), and ModelCheckpoint(save_best_only=True).
  • Framework: TensorFlow/Keras.
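The capsule margin loss can be sketched in NumPy as shown below. The constants m_plus=0.9, m_minus=0.1 and lam=0.5 are the defaults from the original CapsNet paper and are an assumption here, because this card does not state the values used by this model's `margin_loss`. The loss penalizes the true class's capsule if it is shorter than m_plus, and every other capsule if it is longer than m_minus.

```python
import numpy as np


def margin_loss(y_true, lengths, m_plus=0.9, m_minus=0.1, lam=0.5):
    """Capsule margin loss (constants assumed from the original CapsNet paper).

    y_true:  one-hot labels, shape (batch, num_classes)
    lengths: capsule lengths ||v_k||, shape (batch, num_classes)
    Returns the per-example loss summed over classes, averaged over the batch.
    """
    present = y_true * np.square(np.maximum(0.0, m_plus - lengths))
    absent = lam * (1.0 - y_true) * np.square(np.maximum(0.0, lengths - m_minus))
    return float(np.mean(np.sum(present + absent, axis=-1)))


y = np.array([[0, 0, 1, 0]], dtype=float)          # true class index 2 ("Normal")
confident = np.array([[0.05, 0.02, 0.95, 0.08]])   # inside both margins -> zero loss
wrong = np.array([[0.95, 0.02, 0.05, 0.08]])       # long capsule on the wrong class
print(margin_loss(y, confident), margin_loss(y, wrong))
```

The lam factor down-weights the absent-class term. In the original formulation, this keeps early training from shrinking all capsule lengths at once.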

Data

  • Training/validation: class-balanced masked chest X-rays (CXR), loaded as train_dataset and val_dataset. Long-run training used validation_split = 0.2; the kernel exploration used validation_split = 0.5.
  • Preprocessing: resize to 256×256; masking was performed upstream. Any additional normalization must match the notebook pipeline.
  • Label schema: 4 disease classes (variable disease_labels = ['COVID', 'Lung_Opacity', 'Normal', 'Viral Pneumonia']).
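A sketch of converting between class names, indices and one-hot targets for this label schema follows. It assumes that index i corresponds to disease_labels[i]. The list is in alphabetical order, which matches how Keras's `image_dataset_from_directory` orders class names inferred from directory names, but the card does not confirm that loader. The helper names are hypothetical.

```python
import numpy as np

disease_labels = ['COVID', 'Lung_Opacity', 'Normal', 'Viral Pneumonia']

# Hypothetical helpers: name <-> index <-> one-hot
label_to_index = {name: i for i, name in enumerate(disease_labels)}


def to_one_hot(names):
    """One-hot targets of shape (n, 4), the form a margin loss expects."""
    idx = np.array([label_to_index[n] for n in names])
    return np.eye(len(disease_labels))[idx]


def to_names(scores):
    """Map an (n, 4) array of per-class scores to predicted class names."""
    return [disease_labels[i] for i in np.argmax(scores, axis=-1)]


# The schema is alphabetically sorted, like directory-inferred class names
print(disease_labels == sorted(disease_labels))
```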

Licenses & Provenance

Training/validation dataset: https://www.kaggle.com/tawsifurrahman/covid19-radiography-database (License: Data files © Original Authors) (Data size: 1.15 GB)

Datasets for external tests: https://www.kaggle.com/datasets/pranavraikokte/covid19-image-dataset (License: https://creativecommons.org/licenses/by-sa/4.0/) and https://www.kaggle.com/datasets/omkarmanohardalvi/lungs-disease-dataset-4-types (License: Unknown)

Evaluation

  • Protocol: train on the balanced training set and validate on a held-out split; select the first-layer kernel size from [5, 7, 9, 10, 11] based on validation performance; evaluate the best checkpoint on the unseen test set.
  • Reported metrics: metrics for every training run are in capsnet_training_metrics_all_runs.csv.
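The kernel-size selection step amounts to picking the run with the best validation score. The sketch below assumes the per-run CSV has columns named `kernel_size` and `val_accuracy`. Those column names are hypothetical, so adjust them to the actual header of capsnet_training_metrics_all_runs.csv.

```python
import csv


def best_kernel_size(lines, metric="val_accuracy"):
    """Return (kernel_size, score) for the row with the highest validation metric.

    `lines` is any iterable of CSV lines, e.g. an open file.
    The column names 'kernel_size' and `metric` are assumptions about the header.
    """
    rows = csv.DictReader(lines)
    best = max(rows, key=lambda r: float(r[metric]))
    return int(best["kernel_size"]), float(best[metric])


# Usage (file name from this card; column names are assumptions):
# with open("capsnet_training_metrics_all_runs.csv", newline="") as f:
#     print(best_kernel_size(f))
```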

External test

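I excluded the Lung_Opacity class from the external tests because lung opacity often co-occurs with other diseases, which makes it hard to classify reliably. The reported external accuracy therefore covers only the remaining three classes (COVID, Normal, and Viral Pneumonia). It is representative of performance on those classes, not of four-class performance.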

Results CSV: test_on_external_dataset_capsnet_lung_disease_classifier_krnl9.csv
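Excluding a class from the external evaluation can be sketched as follows: images whose true label is Lung_Opacity are dropped before overall and per-class accuracy are computed. This is an illustration only, and the function name is hypothetical. The card does not say how Lung_Opacity predictions on the remaining images were counted; this sketch counts them as errors.

```python
import numpy as np

disease_labels = ['COVID', 'Lung_Opacity', 'Normal', 'Viral Pneumonia']


def accuracy_excluding(y_true, y_pred, exclude="Lung_Opacity"):
    """Overall and per-class accuracy after dropping images whose true label is `exclude`.

    y_true, y_pred: sequences of class names. Predicting the excluded class
    on a kept image still counts as an error.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    keep = y_true != exclude
    y_true, y_pred = y_true[keep], y_pred[keep]
    overall = float(np.mean(y_true == y_pred))
    per_class = {
        name: float(np.mean(y_pred[y_true == name] == name))
        for name in disease_labels
        if name != exclude and np.any(y_true == name)
    }
    return overall, per_class
```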


Risks, Bias, and Limitations

  • Domain shift: performance may degrade on images from other hospitals, scanners, or populations.
  • Label noise / class imbalance: training is balanced, but real-world prevalence may differ.
  • Confounders: text markers, devices, or preprocessing differences can leak non-pathology signals.
  • Fairness: if patient demographics were not controlled, subgroup performance may vary.
  • Regulatory: not cleared for clinical use.

Recommendations

  • Always use human-in-the-loop review.
  • Report per-class metrics and confidence scores; calibrate if needed.
  • Perform external validation on multiple sites before any operational use.
  • Track experiments with MLflow (mlruns_capsnet) and save confusion matrices.
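The per-class metrics and confusion matrix recommended above can be computed with NumPy alone, as sketched below. The sketch assumes integer class indices in the disease_labels order. Logging the matrix to MLflow is left out to keep it dependency-free.

```python
import numpy as np

disease_labels = ['COVID', 'Lung_Opacity', 'Normal', 'Viral Pneumonia']


def confusion_matrix(y_true, y_pred, num_classes=len(disease_labels)):
    """Rows are true classes, columns are predicted classes."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


def per_class_metrics(cm):
    """Precision and recall per class; undefined ratios (0/0) become NaN."""
    tp = np.diag(cm).astype(float)
    with np.errstate(divide="ignore", invalid="ignore"):
        precision = tp / cm.sum(axis=0)
        recall = tp / cm.sum(axis=1)
    return {
        name: {"precision": float(p), "recall": float(r)}
        for name, p, r in zip(disease_labels, precision, recall)
    }
```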

How to Use

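import numpy as np
import tensorflow as tf
from tensorflow.keras.utils import load_img, img_to_array

# modelbuilder provides the model architecture and its custom objects:
# capsnet_custom_objects = {"margin_loss": margin_loss, "PrimaryCaps": PrimaryCaps,
#                           "DigitCaps": DigitCaps, "Length": Length}
from modelbuilder import capsnet_custom_objects

disease_labels = ['COVID', 'Lung_Opacity', 'Normal', 'Viral Pneumonia']


def preprocess(path):
    # Minimal sketch: expects an already-masked CXR, loads it as grayscale,
    # resizes to 256x256 and applies no normalization.
    # Adjust to match the notebook pipeline if it normalizes pixel values.
    img = load_img(path, color_mode="grayscale", target_size=(256, 256))
    x = img_to_array(img)               # shape: (256, 256, 1)
    return np.expand_dims(x, axis=0)    # shape: (1, 256, 256, 1)


# Load trained Keras model
model = tf.keras.models.load_model("path/to/model.keras",
                                   custom_objects=capsnet_custom_objects)

x = preprocess("example_cxr.png")
pred = model.predict(x)[0]              # per-class scores (capsule lengths), shape: (4,)
pred_label = int(np.argmax(pred))
print(pred, pred_label, disease_labels[pred_label])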