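"""
Pulmo inference example.

Shows how to:
  1. download the weights from the Hugging Face Hub,
  2. build the model,
  3. run a forward pass on a 64^3 CT patch (reduced to its 7 central axial
     slices before it reaches the model),
  4. read the concept-bottleneck explanation for the malignancy decision.

The demo feeds a random synthetic patch; substitute a real raw-HU patch to
get meaningful predictions.

Requires: torch, numpy, huggingface_hub, plus the local ``modeling.py``
(providing ``Student2p5D`` and ``CONCEPT_NAMES``) on the import path.

    pip install torch numpy huggingface_hub
"""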
|
|
| import numpy as np |
| import torch |
| from huggingface_hub import hf_hub_download |
|
|
| from modeling import Student2p5D, CONCEPT_NAMES |
|
|
REPO_ID = "ariyul/Pulmo"
HU_CLIP = (-1000, 1000)  # HU window applied before scaling to [0, 1]
ROI = 64                 # nominal edge length (voxels) of the cubic input patch
N_SLICES = 7             # number of central axial slices fed to the 2.5D model


def preprocess_patch(patch_3d):
    """Convert a raw-HU CT patch into the model's 2.5D input tensor.

    The patch is clipped to ``HU_CLIP``, min-max scaled to [0, 1], and the
    ``N_SLICES`` central axial slices along Z are kept, matching how the
    model was trained.

    Args:
        patch_3d: array-like of shape (Z, Y, X) in Hounsfield units,
            nominally ``ROI`` voxels per side (64^3).

    Returns:
        ``torch.float32`` tensor of shape (1, N_SLICES, Y, X) with values
        in [0, 1].

    Raises:
        ValueError: if the patch is not 3-D or has fewer than ``N_SLICES``
            slices along Z.
    """
    p = np.asarray(patch_3d, dtype=np.float32)
    if p.ndim != 3:
        raise ValueError(f"expected a 3-D (Z, Y, X) patch, got shape {p.shape}")
    depth = p.shape[0]
    if depth < N_SLICES:
        raise ValueError(
            f"patch has {depth} slices along Z; need at least {N_SLICES}"
        )

    lo, hi = HU_CLIP
    p = (np.clip(p, lo, hi) - lo) / (hi - lo)

    # Centre the slab on the patch's own Z midpoint (equal to ROI // 2 for the
    # nominal 64^3 patch) and take exactly N_SLICES slices, so a patch of a
    # different depth cannot silently yield an off-centre or short slab.
    start = depth // 2 - N_SLICES // 2
    slab = p[start:start + N_SLICES]
    return torch.from_numpy(slab[None])
|
|
|
|
def _load_model(device):
    """Download the Pulmo checkpoint from the Hub and return the model in eval mode.

    Args:
        device: torch device string ("cuda" or "cpu") to place the model on.

    Returns:
        A ``Student2p5D`` with the checkpoint weights loaded strictly and
        ``eval()`` applied.
    """
    ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="student_2p5d_best.pth")
    model = Student2p5D(n_slices=N_SLICES, n_concepts=8, base=24).to(device)
    # SECURITY NOTE(review): weights_only=False unpickles arbitrary Python
    # objects from a file fetched over the network, which can execute code.
    # Keep it only if the checkpoint really stores non-tensor objects;
    # otherwise prefer weights_only=True (and consider pinning
    # hf_hub_download(revision=...)).
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    # Accept both a training checkpoint dict and a bare state_dict.
    state = ckpt["model_state_dict"] if "model_state_dict" in ckpt else ckpt
    model.load_state_dict(state, strict=True)
    model.eval()
    return model


def _print_concept_contributions(model, concepts):
    """Print each concept's weighted contribution to the malignancy decision.

    contribution_i = (W[1, i] - W[0, i]) * concepts[i], where W is
    ``model.malignancy_head.weight``. Assuming that head is a linear layer
    over the concept vector (as this example treats it), this is concept i's
    share of the logit margin (class 1 - class 0), excluding any bias term.
    Rows are printed from most positive to most negative contribution.
    """
    weights = model.malignancy_head.weight.detach().cpu().numpy()
    w_net = weights[1] - weights[0]
    contrib = w_net * concepts
    print("\nPer-concept contribution to the malignancy decision:")
    for i in np.argsort(contrib)[::-1]:
        print(f" {CONCEPT_NAMES[i]:18s} value={concepts[i]:+.2f} contribution={contrib[i]:+.3f}")


def main():
    """Run the demo: load weights, score a synthetic patch, explain the result."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = _load_model(device)

    # Synthetic stand-in for a real (Z, Y, X) raw-HU patch. The generator is
    # seeded so that repeated runs print the same numbers.
    rng = np.random.default_rng(0)
    dummy_patch = rng.integers(-1000, 400, size=(ROI, ROI, ROI), dtype=np.int16)
    x = preprocess_patch(dummy_patch).to(device)

    with torch.no_grad():
        out = model(x)

    # The example reads softmax index 1 as the positive class of each
    # two-way head.
    det_p = torch.softmax(out["detection"][0], 0)[1].item()
    mal_p = torch.softmax(out["malignancy"][0], 0)[1].item()
    concepts = out["concepts"][0].cpu().numpy()
    seg = torch.sigmoid(out["segmentation"][0, 0]).cpu().numpy()

    print(f"Nodule probability : {det_p:.3f}")
    print(f"Malignancy probability: {mal_p:.3f}")
    print(f"Segmented voxels (>0.5): {(seg > 0.5).sum()}")

    _print_concept_contributions(model, concepts)
|
|
|
|
if __name__ == "__main__":
    # Run the demo only when executed as a script, never on import.
    main()
|
|