# Pulmo / inference_example.py
# ariyul's picture
# Add Pulmo: 2.5D concept-bottleneck student model distilled from 3D teacher for lung nodule detection, malignancy classification, and segmentation on LUNA16/LIDC-IDRI
# 8e23aec verified
"""
Pulmo inference example.
Shows how to:
1. download the weights from the Hub,
2. build the model,
3. run a forward pass on a 64^3 CT patch (or 7 axial slices),
4. read the concept-bottleneck explanation for the malignancy decision.
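Requires: torch, numpy, huggingface_hub, plus the local `modeling` module
(imported below for Student2p5D and CONCEPT_NAMES) on the import path.
    pip install torch numpy huggingface_hub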
"""
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from modeling import Student2p5D, CONCEPT_NAMES
# Hugging Face Hub repository that the checkpoint file is downloaded from.
REPO_ID = "ariyul/Pulmo"
# (low, high) intensity window: input values are clipped to this range and
# then rescaled linearly to [0, 1] during preprocessing.
HU_CLIP = (-1000, 1000)
# Edge length, in voxels, of the cubic CT patch this example works with (64^3).
ROI = 64
# Number of axial slices stacked as the model input (passed as n_slices).
N_SLICES = 7
def preprocess_patch(patch_3d):
    """Convert a raw-HU patch (Z, Y, X) into a (1, N_SLICES, Y, X) float32 tensor.

    The N_SLICES central axial slices are taken along Z (the original example
    states this matches how the model was trained), clipped to the HU_CLIP
    window and rescaled linearly to [0, 1].

    The centre is computed from the patch's actual depth, so a full 64^3 patch
    and a patch that already contains exactly N_SLICES slices are both handled.

    Args:
        patch_3d: array-like of intensities with three axes ordered (Z, Y, X).

    Returns:
        A float32 ``torch.Tensor`` of shape (1, N_SLICES, Y, X) with values in
        [0, 1].

    Raises:
        ValueError: if the input is not 3-D or has fewer than N_SLICES slices
            along Z (previously this silently produced a tensor with too few,
            possibly zero, slices).
    """
    volume = np.asarray(patch_3d)
    if volume.ndim != 3:
        raise ValueError(
            f"expected a 3-D (Z, Y, X) patch, got shape {volume.shape}"
        )

    depth = volume.shape[0]
    if depth < N_SLICES:
        raise ValueError(
            f"patch has {depth} axial slices but {N_SLICES} are required"
        )

    # Centre on the real depth rather than a fixed 64-voxel assumption; for a
    # 64-deep patch with 7 slices this is still indices 29..35.
    start = depth // 2 - N_SLICES // 2
    # Slice first so clipping/normalisation only touches the slices we keep.
    slab = volume[start:start + N_SLICES].astype(np.float32)

    hu_min, hu_max = HU_CLIP
    slab = (np.clip(slab, hu_min, hu_max) - hu_min) / (hu_max - hu_min)
    return torch.from_numpy(slab[None])  # add the leading batch axis
def main():
    """Download the checkpoint, run one forward pass and print the results.

    Steps: fetch ``student_2p5d_best.pth`` from REPO_ID, build ``Student2p5D``
    and load the weights strictly, push a random stand-in patch through the
    model, then print the detection / malignancy probabilities, the number of
    segmentation outputs above 0.5 and a per-concept breakdown of the
    malignancy decision.
    """
    if torch.cuda.is_available():
        target = "cuda"
    else:
        target = "cpu"

    weights_file = hf_hub_download(repo_id=REPO_ID, filename="student_2p5d_best.pth")
    net = Student2p5D(n_slices=N_SLICES, n_concepts=8, base=24).to(target)

    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects from the downloaded file; only use with a checkpoint you trust.
    checkpoint = torch.load(weights_file, map_location=target, weights_only=False)
    # The file may be a full training checkpoint or a bare state dict.
    if "model_state_dict" in checkpoint:
        weights = checkpoint["model_state_dict"]
    else:
        weights = checkpoint
    net.load_state_dict(weights, strict=True)
    net.eval()

    # --- Replace this with a real 64^3 patch cropped around a candidate ---
    fake_patch = np.random.randint(-1000, 400, size=(64, 64, 64)).astype(np.int16)
    batch = preprocess_patch(fake_patch).to(target)

    with torch.no_grad():
        outputs = net(batch)

    # Index 1 of each two-way softmax is reported as the positive class.
    det_logits = outputs["detection"][0]
    mal_logits = outputs["malignancy"][0]
    det_p = torch.softmax(det_logits, 0)[1].item()
    mal_p = torch.softmax(mal_logits, 0)[1].item()
    concepts = outputs["concepts"][0].cpu().numpy()
    seg_logits = outputs["segmentation"][0, 0]
    seg = torch.sigmoid(seg_logits).cpu().numpy()

    print(f"Nodule probability : {det_p:.3f}")
    print(f"Malignancy probability: {mal_p:.3f}")
    print(f"Segmented voxels (>0.5): {(seg > 0.5).sum()}")

    # Concept-bottleneck explanation:
    # malignancy logit = sum_i w_net[i] * concept[i], w_net = W[malign] - W[benign]
    head_weights = model_weights = net.malignancy_head.weight.detach().cpu().numpy()
    net_weight = head_weights[1] - model_weights[0]
    contrib = net_weight * concepts

    print("\nPer-concept contribution to the malignancy decision:")
    # Largest positive contribution first.
    ranking = np.argsort(contrib)[::-1]
    for i in ranking:
        print(f" {CONCEPT_NAMES[i]:18s} value={concepts[i]:+.2f} contribution={contrib[i]:+.3f}")
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()