""" Pulmo inference example. Shows how to: 1. download the weights from the Hub, 2. build the model, 3. run a forward pass on a 64^3 CT patch (or 7 axial slices), 4. read the concept-bottleneck explanation for the malignancy decision. Requires: torch, numpy, huggingface_hub pip install torch numpy huggingface_hub """ import numpy as np import torch from huggingface_hub import hf_hub_download from modeling import Student2p5D, CONCEPT_NAMES REPO_ID = "ariyul/Pulmo" HU_CLIP = (-1000, 1000) ROI = 64 N_SLICES = 7 def preprocess_patch(patch_3d): """64^3 raw-HU patch (Z, Y, X) -> (1, 7, 64, 64) float32 tensor in [0, 1]. The 7 central axial slices are extracted along Z, matching how the model was trained. """ p = np.clip(patch_3d.astype(np.float32), HU_CLIP[0], HU_CLIP[1]) p = (p - HU_CLIP[0]) / (HU_CLIP[1] - HU_CLIP[0]) c = ROI // 2 h = N_SLICES // 2 slices = p[c - h:c + h + 1] # (7, 64, 64) return torch.from_numpy(slices[None]) # (1, 7, 64, 64) def main(): device = "cuda" if torch.cuda.is_available() else "cpu" ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="student_2p5d_best.pth") model = Student2p5D(n_slices=N_SLICES, n_concepts=8, base=24).to(device) ckpt = torch.load(ckpt_path, map_location=device, weights_only=False) state = ckpt["model_state_dict"] if "model_state_dict" in ckpt else ckpt model.load_state_dict(state, strict=True) model.eval() # --- Replace this with a real 64^3 patch cropped around a candidate --- dummy_patch = np.random.randint(-1000, 400, size=(64, 64, 64)).astype(np.int16) x = preprocess_patch(dummy_patch).to(device) with torch.no_grad(): out = model(x) det_p = torch.softmax(out["detection"][0], 0)[1].item() mal_p = torch.softmax(out["malignancy"][0], 0)[1].item() concepts = out["concepts"][0].cpu().numpy() seg = torch.sigmoid(out["segmentation"][0, 0]).cpu().numpy() print(f"Nodule probability : {det_p:.3f}") print(f"Malignancy probability: {mal_p:.3f}") print(f"Segmented voxels (>0.5): {(seg > 0.5).sum()}") # Concept-bottleneck explanation: # malignancy logit = sum_i w_net[i] * concept[i], w_net = W[malign] - W[benign] W = model.malignancy_head.weight.detach().cpu().numpy() w_net = W[1] - W[0] contrib = w_net * concepts print("\nPer-concept contribution to the malignancy decision:") for i in np.argsort(contrib)[::-1]: print(f" {CONCEPT_NAMES[i]:18s} value={concepts[i]:+.2f} contribution={contrib[i]:+.3f}") if __name__ == "__main__": main()