---
license: cc-by-4.0
tags:
  - medical-imaging
  - lung-nodule
  - ct
  - concept-bottleneck
  - knowledge-distillation
  - explainable-ai
  - pytorch
library_name: pytorch
pipeline_tag: image-classification
---

# Pulmo — 2.5D Concept-Bottleneck Multi-task Model for Lung Nodule Analysis

Pulmo is a lightweight, **explainable** model for chest-CT lung nodule analysis.
From a single 64³ patch (passed as its 7 central axial slices) it jointly predicts:

- **Detection** — nodule vs. non-nodule
- **Malignancy** — benign vs. malignant, via a **concept bottleneck**
- **8 radiological concepts** — subtlety, internal structure, calcification, sphericity, margin, lobulation, spiculation, texture
- **Segmentation** — nodule mask of the central slice
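To make the 2.5D multi-task layout concrete, here is a minimal toy sketch. `Toy2p5DMultiTask` is hypothetical and is not the released `Student2p5D`. The 7 slices enter a 2D CNN as input channels, and four heads produce the outputs listed above. All layer sizes are assumptions, as are the output key names other than `"malignancy"`.

```python
import torch
import torch.nn as nn

class Toy2p5DMultiTask(nn.Module):
    """Hypothetical illustration only; NOT the released Student2p5D architecture."""

    def __init__(self, n_slices=7, n_concepts=8, width=16):
        super().__init__()
        # 2.5D: the axial slices are stacked as input channels of a 2D CNN
        self.trunk = nn.Sequential(
            nn.Conv2d(n_slices, width, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(width, width, kernel_size=3, padding=1), nn.ReLU(),
        )
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.det_head = nn.Linear(width, 2)                 # nodule vs. non-nodule
        self.concept_head = nn.Linear(width, n_concepts)    # one score per concept
        self.mal_head = nn.Linear(n_concepts, 2)            # bottleneck: sees ONLY the concepts
        self.seg_head = nn.Conv2d(width, 1, kernel_size=1)  # central-slice mask logits

    def forward(self, x):
        feats = self.trunk(x)                    # (B, width, H, W)
        pooled = self.pool(feats).flatten(1)     # (B, width)
        concepts = self.concept_head(pooled)     # (B, n_concepts)
        return {
            "detection": self.det_head(pooled),
            "concepts": concepts,
            "malignancy": self.mal_head(concepts),
            "segmentation": self.seg_head(feats),
        }

model = Toy2p5DMultiTask()
out = model(torch.rand(2, 7, 64, 64))  # batch of 2 random inputs in [0, 1]
for name, t in out.items():
    print(name, tuple(t.shape))
```

The key design point is that `mal_head` receives only the concept vector and never sees the image features.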

Because malignancy is computed as `Linear(8 concepts → 2)`, **every malignancy
prediction is fully attributable to the 8 clinical concepts**: each concept's
contribution to a malignancy logit is its learned weight times its predicted value,
so you can read off exactly which concept (e.g. spiculation) drove the decision.
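The attribution arithmetic can be sketched as follows. The head is linear, so the malignant-minus-benign logit margin decomposes exactly into one term per concept (weight difference times concept value) plus a bias difference. For two classes, P(malignant) equals the sigmoid of that margin. The head weights and concept values below are random stand-ins, and `concept_contributions` and the `CONCEPTS` ordering are hypothetical. The sketch also assumes that class index 1 is "malignant", as in the Usage code.

```python
import torch
import torch.nn as nn

# Hypothetical concept order for display; the repo's CONCEPT_NAMES may differ.
CONCEPTS = ["subtlety", "internal_structure", "calcification", "sphericity",
            "margin", "lobulation", "spiculation", "texture"]

def concept_contributions(mal_head: nn.Linear, concepts: torch.Tensor):
    """Split the (malignant - benign) logit margin into per-concept terms."""
    with torch.no_grad():
        w_diff = mal_head.weight[1] - mal_head.weight[0]  # (8,)
        b_diff = mal_head.bias[1] - mal_head.bias[0]      # scalar
        contrib = w_diff * concepts                       # (8,) one term per concept
        margin = contrib.sum() + b_diff                   # equals z_malignant - z_benign
    return contrib, margin

torch.manual_seed(0)
mal_head = nn.Linear(8, 2)  # stand-in: random weights, not the trained head
c = torch.rand(8)           # stand-in predicted concept vector for one nodule
contrib, margin = concept_contributions(mal_head, c)

with torch.no_grad():
    logits = mal_head(c)
    p_mal = torch.softmax(logits, 0)[1]
# For two classes, softmax(z)[1] == sigmoid(z1 - z0)
assert torch.allclose(torch.sigmoid(margin), p_mal, atol=1e-5)

# Rank concepts by how strongly they pushed toward (+) or away from (-) malignant
for name, v in sorted(zip(CONCEPTS, contrib.tolist()), key=lambda t: -abs(t[1])):
    print(f"{name:20s} {v:+.3f}")
```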

> ⚠️ **Research use only. Pulmo is not a medical device and must not be used for clinical diagnosis.**

## How it was built

Pulmo is the **deployment student** of a knowledge-distillation pipeline:

1. A ViT-Large encoder was pretrained with self-supervision (MAE / domain-adaptive pretraining) on lung CT.
2. A **3D teacher** (`UNet3D` with concept-bottleneck heads on a largely CNN-only trunk) was fine-tuned on LUNA16/LIDC with focal loss, MixUp and aggressive augmentation. Teacher test results: detection AUC 0.998, malignancy AUC 0.986, Dice 0.857.
3. **This model** — a 2.5D student (~2M params) — was trained from scratch by **online distillation** (`loss = 0.5·hard + 0.5·soft`, temperature 3.0) to imitate the frozen teacher, for ~5–10× faster inference at a fraction of the size.
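The distillation objective can be sketched as a standard Hinton-style knowledge-distillation loss. The sketch makes several assumptions. The soft term is read as a KL divergence between temperature-softened teacher and student distributions, scaled by T² (a common convention that the card does not confirm). The loss is shown for a single classification head, and the card does not say how it was applied across tasks. Finally, "online" is read as running the frozen teacher on the fly at each training step. `distillation_loss` and the linear stand-in models are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, targets, T=3.0, alpha=0.5):
    """alpha * hard + (1 - alpha) * soft, with T^2 scaling on the soft term (assumed)."""
    hard = F.cross_entropy(student_logits, targets)   # vs. ground-truth labels
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),     # student log-probs at temperature T
        F.softmax(teacher_logits / T, dim=1),         # teacher probs at temperature T
        reduction="batchmean",
    ) * (T * T)
    return alpha * hard + (1 - alpha) * soft

# Toy training step with stand-in linear "models"
torch.manual_seed(0)
teacher = nn.Linear(16, 2).eval()
for p in teacher.parameters():
    p.requires_grad_(False)                           # frozen teacher
student = nn.Linear(16, 2)
x = torch.randn(4, 16)
y = torch.tensor([0, 1, 1, 0])
with torch.no_grad():
    t_logits = teacher(x)                             # teacher run on the fly each step
loss = distillation_loss(student(x), t_logits, y)
loss.backward()                                       # gradients flow only into the student
print(float(loss))
```

With `alpha=0.5` and `T=3.0`, this matches the `0.5·hard + 0.5·soft` weighting and temperature stated above. The T² factor keeps the soft-term gradients on a scale comparable to the hard term as T grows.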

Full training notebooks (data prep → labels → patch precompute → concepts → teacher → distillation → evaluation → explainability): **[link to your notebooks repo here]**

## Results (held-out internal test split)

| Task | Metric | Pulmo (2.5D student) | Teacher (3D) |
|---|---|---|---|
| Detection | AUC | 0.997 | 0.998 |
| Malignancy | AUC | 0.986 | 0.986 |
| Segmentation | Dice | 0.859 | 0.857 |

Data were split 80/10/10 (train/validation/test) at the patient level on LUNA16.
Metrics are patch-level on the internal test split; the model has **not** been externally validated.
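A patient-level split keeps every patch from a given patient in a single partition. This prevents near-duplicate views of the same nodule from leaking between train and test. Below is a minimal sketch assuming each patch carries a patient identifier. `patient_level_split` is hypothetical, and the card does not specify the exact seeding or rounding behind the reported split.

```python
import random
from collections import defaultdict

def patient_level_split(patch_patient_ids, fractions=(0.8, 0.1, 0.1), seed=0):
    """Return {"train"|"val"|"test": [patch indices]} with no patient in two splits."""
    patients = sorted(set(patch_patient_ids))
    rng = random.Random(seed)
    rng.shuffle(patients)                      # shuffle patients, not patches
    n = len(patients)
    n_train = round(fractions[0] * n)
    n_val = round(fractions[1] * n)
    split_of = {}
    for i, pid in enumerate(patients):
        if i < n_train:
            split_of[pid] = "train"
        elif i < n_train + n_val:
            split_of[pid] = "val"
        else:
            split_of[pid] = "test"             # remainder goes to test
    splits = defaultdict(list)
    for idx, pid in enumerate(patch_patient_ids):
        splits[split_of[pid]].append(idx)      # every patch follows its patient
    return dict(splits)

# Toy usage: 20 hypothetical patients with 5 patches each
ids = [f"patient{i // 5:02d}" for i in range(100)]
splits = patient_level_split(ids)
print({k: len(v) for k, v in splits.items()})
```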

## Usage

```python
import numpy as np, torch
from huggingface_hub import hf_hub_download
from modeling import Student2p5D, CONCEPT_NAMES  # modeling.py from this repo must be importable

ckpt = hf_hub_download("ariyul/Pulmo", "student_2p5d_best.pth")
model = Student2p5D(n_slices=7, n_concepts=8, base=24)
# weights_only=False unpickles arbitrary objects: only load checkpoints you trust
state = torch.load(ckpt, map_location="cpu", weights_only=False)
model.load_state_dict(state["model_state_dict"], strict=True)
model.eval()

# patch_3d (not defined here, supply your own): a 64x64x64 raw-HU crop
# centered on a candidate, axis order (Z, Y, X)
assert patch_3d.shape == (64, 64, 64)
p = np.clip(patch_3d.astype(np.float32), -1000, 1000)
p = (p + 1000) / 2000.0                   # HU [-1000, 1000] -> [0, 1]
x = torch.from_numpy(p[28:35][None])      # 7 central axial slices -> (1, 7, 64, 64)

with torch.no_grad():
    out = model(x)
mal_p = torch.softmax(out["malignancy"][0], 0)[1].item()  # probability of the malignant class
```

See `inference_example.py` for the full example including the concept-level explanation.

## Input / preprocessing

- Input tensor: `(B, 7, 64, 64)`, float32 in `[0, 1]`
- HU clip `[-1000, 1000]`, then normalize to `[0, 1]`
- Take the 7 central axial slices of a 64³ patch centered on the candidate world coordinate
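The preprocessing steps above can be sketched end to end with NumPy. This is a minimal sketch under several assumptions. The volume is taken to be already resampled to the training convention, which is not specified here. The direction matrix is assumed to be the identity. Out-of-volume voxels are padded with air (-1000 HU). The candidate is assumed to land at index 32 of the patch, although the authors' exact crop convention may differ. `world_to_voxel`, `crop_patch` and `to_model_input` are hypothetical helpers.

```python
import numpy as np

def world_to_voxel(world_zyx, origin_zyx, spacing_zyx):
    """mm world coordinate -> integer voxel index; assumes an identity direction matrix."""
    world = np.asarray(world_zyx, dtype=np.float64)
    return np.rint((world - np.asarray(origin_zyx)) / np.asarray(spacing_zyx)).astype(int)

def crop_patch(volume_hu, center_zyx, size=64, pad_value=-1000):
    """Cut a size^3 cube around center_zyx; out-of-volume voxels are filled with air."""
    half = size // 2
    padded = np.pad(volume_hu, half, mode="constant", constant_values=pad_value)
    # After padding, original index c sits at c + half, so the cube spans
    # [c - half, c + half) in original coordinates (candidate at patch index 32)
    z, y, x = (int(c) + half for c in center_zyx)
    return padded[z - half:z + half, y - half:y + half, x - half:x + half]

def to_model_input(patch_3d):
    """Clip HU to [-1000, 1000], scale to [0, 1], keep the 7 central axial slices."""
    p = np.clip(patch_3d.astype(np.float32), -1000, 1000)
    p = (p + 1000) / 2000.0
    return p[28:35][None]  # (1, 7, 64, 64); same indices as the Usage example

# Toy usage on a synthetic volume with one bright voxel
vol = np.full((100, 120, 110), -1000, dtype=np.int16)
vol[50, 60, 55] = 500
center = world_to_voxel((125.0, 60.0, 27.5), origin_zyx=(0.0, 0.0, 0.0),
                        spacing_zyx=(2.5, 1.0, 0.5))
patch = crop_patch(vol, center)
x = to_model_input(patch)
print(center, patch.shape, x.shape)
```

If you read LUNA16 `.mhd` files with SimpleITK, `GetOrigin()` and `GetSpacing()` return (x, y, z) while `GetArrayFromImage` returns a (z, y, x) array. Reverse the origin and spacing before calling a (Z, Y, X) helper like the one above.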

## Files

- `student_2p5d_best.pth`: model weights
- `modeling.py`: `Student2p5D` definition (required to load the weights)
- `config.json`: architecture and preprocessing parameters
- `inference_example.py`: runnable example with concept explanation

## Training data & citations

Trained on **LUNA16** (a curated subset of **LIDC-IDRI**). If you use Pulmo, please
also credit the underlying datasets:

- Setio et al., *Validation, comparison, and combination of algorithms for automatic
  detection of pulmonary nodules in CT images: the LUNA16 challenge*, Medical Image
  Analysis, 2017.
- Armato et al., *The Lung Image Database Consortium (LIDC) and Image Database Resource
  Initiative (IDRI)*, Medical Physics, 2011.

## Limitations

- Patch-level evaluation on a single internal split; no external/multi-center validation.
- Trained on LUNA16 preprocessing conventions (resampling, HU window); behavior on
  other acquisition protocols is untested.
- Concept predictions are learned regressions of LIDC radiologist ratings, not
  ground-truth measurements.

## License

Model weights and code: CC BY 4.0. Underlying datasets carry their own licenses.