| --- |
| license: cc-by-4.0 |
| tags: |
| - medical-imaging |
| - lung-nodule |
| - ct |
| - concept-bottleneck |
| - knowledge-distillation |
| - explainable-ai |
| - pytorch |
| library_name: pytorch |
| pipeline_tag: image-classification |
| --- |
| |
| # Pulmo — 2.5D Concept-Bottleneck Multi-task Model for Lung Nodule Analysis |
|
|
| Pulmo is a lightweight, **explainable** model for chest-CT lung nodule analysis. |
| From a single 64³ patch (passed as its 7 central axial slices) it jointly predicts: |
|
|
| - **Detection** — nodule vs. non-nodule |
| - **Malignancy** — benign vs. malignant, via a **concept bottleneck** |
| - **8 radiological concepts** — subtlety, internal structure, calcification, sphericity, margin, lobulation, spiculation, texture |
| - **Segmentation** — nodule mask of the central slice |
|
|
| Because malignancy is computed as `Linear(8 concepts → 2)`, **every malignancy
| prediction is fully attributable to the 8 clinical concepts**: each logit is a
| weighted sum of the concept scores plus a bias, so you can read off exactly which
| concept (e.g. spiculation) drove the decision.
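
The attribution described above can be sketched in a few lines of NumPy. This is an illustration of the arithmetic only, not the code in `modeling.py`: the weights, bias and concept scores are random placeholders, the concept order is assumed to follow the list in this card, and the function name `concept_contributions` is hypothetical.

```python
import numpy as np

# Assumed order, copied from the concept list in this card; the real
# CONCEPT_NAMES in modeling.py may be ordered or spelled differently.
CONCEPTS = ["subtlety", "internal_structure", "calcification", "sphericity",
            "margin", "lobulation", "spiculation", "texture"]


def concept_contributions(weight, bias, concepts):
    """Split the malignant-minus-benign logit margin into per-concept terms.

    weight: (2, 8) matrix of the linear head, bias: (2,), concepts: (8,).
    """
    w_diff = weight[1] - weight[0]        # direction that favors "malignant"
    per_concept = w_diff * concepts       # one additive term per concept
    margin = per_concept.sum() + (bias[1] - bias[0])
    return dict(zip(CONCEPTS, per_concept)), margin


# Placeholder weights and concept scores, for illustration only.
rng = np.random.default_rng(0)
weight = rng.normal(size=(2, 8))
bias = np.zeros(2)
concepts = rng.uniform(size=8)

contrib, margin = concept_contributions(weight, bias, concepts)

# Reference computation: the plain linear head followed by a 2-way softmax.
logits = weight @ concepts + bias
softmax = np.exp(logits - logits.max())
softmax /= softmax.sum()

# For two classes, softmax(logits)[1] equals sigmoid(margin).
p_mal = 1.0 / (1.0 + np.exp(-margin))

top = max(contrib, key=lambda name: abs(contrib[name]))
print(f"p(malignant)={p_mal:.3f}, largest contribution: {top} ({contrib[top]:+.3f})")
```

A positive term pushes the prediction toward malignant and a negative term toward benign; the terms plus the bias gap sum exactly to the logit margin.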
|
|
| > ⚠️ **Research use only. Pulmo is not a medical device and must not be used for clinical diagnosis.** |
|
|
| ## How it was built |
|
|
| Pulmo is the **deployment student** of a knowledge-distillation pipeline: |
|
|
| 1. A ViT-Large encoder was pretrained with self-supervision (MAE / domain-adaptive pretraining) on lung CT.
| 2. A **3D teacher** (`UNet3D` + concept-bottleneck heads, ~CNN-only trunk) was fine-tuned on LUNA16/LIDC with focal loss, MixUp and aggressive augmentation. Teacher test-split results: detection AUC 0.998, malignancy AUC 0.986, segmentation Dice 0.857.
| 3. **This model** — a 2.5D student (~2M params) — was trained from scratch by **online distillation** (`loss = 0.5·hard + 0.5·soft`, temperature 3.0) to imitate the frozen teacher, for ~5–10× faster inference at a fraction of the size. |
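
The distillation objective in step 3 can be sketched for a single classification head as follows. This is a minimal NumPy illustration under stated assumptions, not the training code: it assumes the hard term is cross-entropy against ground-truth labels, the soft term is a KL divergence between temperature-softened teacher and student distributions, and the soft term is scaled by `T**2` (a common convention that this card does not confirm). The function names are hypothetical.

```python
import numpy as np


def log_softmax(z, axis=-1):
    """Numerically stable log-softmax."""
    z = z - z.max(axis=axis, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))


def distillation_loss(student_logits, teacher_logits, labels, alpha=0.5, T=3.0):
    """loss = alpha * hard + (1 - alpha) * soft, for (N, C) logits and (N,) labels."""
    # Hard term: cross-entropy of the student against the ground-truth labels.
    logp = log_softmax(student_logits)
    hard = -logp[np.arange(len(labels)), labels].mean()

    # Soft term: KL(teacher || student) on temperature-softened distributions.
    logp_s = log_softmax(student_logits / T)
    logp_t = log_softmax(teacher_logits / T)
    kl = (np.exp(logp_t) * (logp_t - logp_s)).sum(axis=1).mean()
    soft = kl * T ** 2  # assumed scaling, keeps gradient magnitudes comparable

    return alpha * hard + (1.0 - alpha) * soft


# Toy batch of two samples with two classes each (made-up numbers).
teacher = np.array([[2.0, -1.0], [0.5, 1.5]])
student = np.array([[1.0, 0.0], [0.0, 1.0]])
labels = np.array([0, 1])

loss = distillation_loss(student, teacher, labels)
print(f"distillation loss: {loss:.4f}")
```

When the student reproduces the teacher's logits exactly, the soft term vanishes and only the weighted hard term remains.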
|
|
| Full training notebooks (data prep → labels → patch precompute → concepts → teacher → distillation → evaluation → explainability): **[link to your notebooks repo here]** |
|
|
| ## Results (held-out internal test split) |
|
|
| | Task | Metric | Pulmo (2.5D student) | Teacher (3D) | |
| |---|---|---|---| |
| | Detection | AUC | 0.997 | 0.998 | |
| | Malignancy | AUC | 0.986 | 0.986 | |
| | Segmentation | Dice | 0.859 | 0.857 | |
|
|
| Patient-level 80/10/10 split of LUNA16. Metrics are patch-level on the internal |
| test split; the model has **not** been externally validated. |
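
For reference, the Dice score reported for segmentation measures overlap between a predicted and a reference binary mask. The sketch below shows the standard definition; how the evaluation notebooks threshold the predicted mask or handle empty masks is not stated in this card, so the `eps` smoothing is an assumption.

```python
import numpy as np


def dice(pred_mask, true_mask, eps=1e-7):
    """Dice = 2 * |A ∩ B| / (|A| + |B|) for two binary masks of equal shape."""
    pred = np.asarray(pred_mask).astype(bool)
    true = np.asarray(true_mask).astype(bool)
    intersection = np.logical_and(pred, true).sum()
    # eps avoids 0/0 when both masks are empty (assumed convention: score 1.0).
    return (2.0 * intersection + eps) / (pred.sum() + true.sum() + eps)


# Toy example: two 20x20 squares on a 64x64 slice, shifted by 2 rows.
true_mask = np.zeros((64, 64), dtype=np.uint8)
true_mask[20:40, 20:40] = 1          # 400 pixels
pred_mask = np.zeros((64, 64), dtype=np.uint8)
pred_mask[22:42, 20:40] = 1          # 400 pixels, 360 of them overlapping

score = dice(pred_mask, true_mask)   # 2 * 360 / 800
print(f"Dice: {score:.3f}")
```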
|
|
| ## Usage |
|
|
| ```python |
| import numpy as np |
| import torch |
| from huggingface_hub import hf_hub_download |
| from modeling import Student2p5D, CONCEPT_NAMES |
| |
| ckpt = hf_hub_download("ariyul/Pulmo", "student_2p5d_best.pth") |
| model = Student2p5D(n_slices=7, n_concepts=8, base=24) |
| # weights_only=False unpickles arbitrary Python objects: only load checkpoints you trust |
| state = torch.load(ckpt, map_location="cpu", weights_only=False) |
| model.load_state_dict(state["model_state_dict"], strict=True) |
| model.eval() |
| |
| # patch_3d: a 64x64x64 raw-HU crop centered on a candidate (Z, Y, X) |
| p = np.clip(patch_3d.astype(np.float32), -1000, 1000) |
| p = (p + 1000) / 2000.0 |
| x = torch.from_numpy(p[28:35][None]) # 7 central axial slices -> (1, 7, 64, 64) |
| |
| with torch.no_grad(): |
|     out = model(x) |
| # softmax over the 2 malignancy logits; index 1 = malignant |
| mal_p = torch.softmax(out["malignancy"][0], 0)[1].item() |
| ``` |
|
|
| See `inference_example.py` for the full example including the concept-level explanation. |
|
|
| ## Input / preprocessing |
|
|
| - Input tensor: `(B, 7, 64, 64)`, float32 in `[0, 1]` |
| - Clip HU to `[-1000, 1000]`, then rescale to `[0, 1]` with `(HU + 1000) / 2000`
| - Take the 7 central axial slices of a 64³ patch centered on the candidate world coordinate |
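
The steps above can be sketched end to end, starting from a candidate given in world coordinates. This is a rough illustration rather than the preprocessing used in training: it assumes the volume is already resampled the way the training pipeline expects, is stored in `(Z, Y, X)` order with an axis-aligned orientation, and that voxels outside the scan are padded with air (HU −1000). The helper names `world_to_voxel`, `crop_patch` and `to_model_input` are hypothetical, and the exact centering convention for an even-sized patch may differ from the author's.

```python
import numpy as np


def world_to_voxel(world_xyz, origin_xyz, spacing_xyz):
    """Map a world coordinate (mm, x/y/z) to integer voxel indices in (Z, Y, X) order."""
    idx_xyz = np.rint((np.asarray(world_xyz, dtype=float) - np.asarray(origin_xyz, dtype=float))
                      / np.asarray(spacing_xyz, dtype=float))
    return tuple(int(i) for i in idx_xyz[::-1])


def crop_patch(volume_zyx, center_zyx, size=64, pad_value=-1000):
    """Crop a size^3 patch around center_zyx, padding with pad_value outside the volume."""
    half = size // 2
    padded = np.pad(volume_zyx, half, mode="constant", constant_values=pad_value)
    z, y, x = center_zyx  # padding shifts indices by +half, so the patch starts at the raw index
    return padded[z:z + size, y:y + size, x:x + size]


def to_model_input(patch_3d, n_slices=7, hu_min=-1000.0, hu_max=1000.0):
    """Clip/rescale HU to [0, 1] and keep the central axial slices -> (1, n_slices, H, W)."""
    p = np.clip(patch_3d.astype(np.float32), hu_min, hu_max)
    p = (p - hu_min) / (hu_max - hu_min)
    start = (patch_3d.shape[0] - n_slices) // 2   # 28 for a 64-slice patch, i.e. p[28:35]
    return p[start:start + n_slices][None].astype(np.float32)


# Synthetic scan filled with air, with one bright voxel marking the candidate.
volume = np.full((100, 120, 120), -1000, dtype=np.int16)
origin_xyz, spacing_xyz = (-100.0, -120.0, -50.0), (1.0, 1.0, 1.0)
center = world_to_voxel((-90.0, -100.0, -40.0), origin_xyz, spacing_xyz)  # -> (10, 20, 10)
volume[center] = 400

patch = crop_patch(volume, center)
x = to_model_input(patch)
print(center, patch.shape, x.shape, x.dtype)
```

The resulting array can be wrapped with `torch.from_numpy(x)` before being passed to the model.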
|
|
| ## Files |
|
|
| - `student_2p5d_best.pth` — model weights |
| - `modeling.py` — `Student2p5D` definition (required to load the weights) |
| - `config.json` — architecture and preprocessing parameters |
| - `inference_example.py` — runnable example with concept explanation |
|
|
| ## Training data & citations |
|
|
| Trained on **LUNA16** (a curated subset of **LIDC-IDRI**). If you use Pulmo, please |
| also credit the underlying datasets: |
|
|
| - Setio et al., *Validation, comparison, and combination of algorithms for automatic |
| detection of pulmonary nodules in CT images: the LUNA16 challenge*, Medical Image |
| Analysis, 2017. |
| - Armato et al., *The Lung Image Database Consortium (LIDC) and Image Database Resource |
| Initiative (IDRI)*, Medical Physics, 2011. |
|
|
| ## Limitations |
|
|
| - Patch-level evaluation on a single internal split; no external/multi-center validation. |
| - Trained with LUNA16 preprocessing conventions (resampling, HU window); behavior on
| other acquisition protocols is untested.
| - Concept predictions are learned regressions of LIDC radiologist ratings, not |
| ground-truth measurements. |
|
|
| ## License |
|
|
| Model weights and code: CC BY 4.0. Underlying datasets carry their own licenses. |
|
|