""" Janus — standalone inference for the lavsendahal/janus HuggingFace repo. GAP variant: fully self-contained. pip install torch transformers safetensors nibabel scipy huggingface_hub masked-attn / gated-fusion / scalar-fusion variants require organ segmentation masks (and radiomics features for fusion variants) produced by the full Janus preprocessing pipeline. See: https://github.com/lavsendahal/janus Usage: python inference.py ct.nii.gz python inference.py ct.nii.gz --variant gap --device cuda --top 10 """ import argparse import json import sys from pathlib import Path import nibabel as nib import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from huggingface_hub import hf_hub_download from safetensors.torch import load_file from scipy.ndimage import zoom as scipy_zoom from transformers import AutoModel # ── constants ───────────────────────────────────────────────────────────────── HF_REPO = "lavsendahal/janus" DINOV3_IDS = { "S": "facebook/dinov3-vits16-pretrain-lvd1689m", "B": "facebook/dinov3-vitb16-pretrain-lvd1689m", "L": "facebook/dinov3-vitl16-pretrain-lvd1689m", } # Must match slurm_prepack.sub exactly HU_MIN, HU_MAX = -1000.0, 1000.0 TARGET_SPACING = np.array([1.5, 1.5, 3.0]) # mm, X/Y/Z in RAS space TARGET_SHAPE_XYZ = np.array([224, 224, 160]) # voxels, X/Y/Z before axis permute VARIANTS_NEEDING_MASKS = {"masked-attn", "gated-fusion", "scalar-fusion"} VARIANTS_NEEDING_SCALARS = {"gated-fusion", "scalar-fusion"} IMN_MEAN = [0.485, 0.456, 0.406] IMN_STD = [0.229, 0.224, 0.225] # ── preprocessing ───────────────────────────────────────────────────────────── def load_ct(path: str) -> torch.Tensor: """ Load any NIfTI CT, resample to training resolution, return [1, 1, D, H, W]. Pipeline (matches packer.py + dataset.py exactly): 1. Reorient to RAS canonical 2. Resample to 1.5 × 1.5 × 3.0 mm (bilinear, order=1) 3. Resize to 224 × 224 × 160 voxels (bilinear, order=1) 4. Permute [X, Y, Z] → [Z, Y, X] (= [D, H, W], matches dataset permute) 5. Clip HU to [-1000, 1000] and normalise to [0, 1] """ nii = nib.load(path) nii = nib.as_closest_canonical(nii) # reorient to RAS vol = nii.get_fdata(dtype=np.float32) # [X, Y, Z] affine = nii.affine spacing = np.abs(np.diag(affine)[:3]) # mm per voxel [sx, sy, sz] # step 1 → target spacing z1 = spacing / TARGET_SPACING if not np.allclose(z1, 1.0, atol=0.01): vol = scipy_zoom(vol, z1, order=1, mode="nearest") # step 2 → target shape z2 = TARGET_SHAPE_XYZ / np.array(vol.shape) if not np.allclose(z2, 1.0, atol=0.01): vol = scipy_zoom(vol, z2, order=1, mode="nearest") # step 3 → permute [X, Y, Z] → [Z, Y, X] (= [D, H, W]) vol = np.transpose(vol, (2, 1, 0)) # step 4 → HU clip + normalise vol = np.clip(vol, HU_MIN, HU_MAX) vol = (vol - HU_MIN) / (HU_MAX - HU_MIN) return torch.from_numpy(vol).unsqueeze(0).unsqueeze(0) # [1, 1, D, H, W] def _make_trislices(vol: torch.Tensor, stride: int) -> torch.Tensor: """[B, 1, D, H, W] → [B, T, 3, H, W]""" B, _, D, H, W = vol.shape centers = list(range(1, max(2, D - 1), max(1, stride))) if not centers: centers = [D // 2] T = len(centers) out = torch.empty(B, T, 3, H, W, device=vol.device, dtype=vol.dtype) for t, c in enumerate(centers): out[:, t, 0] = vol[:, 0, max(0, c - 1)] out[:, t, 1] = vol[:, 0, c] out[:, t, 2] = vol[:, 0, min(D - 1, c + 1)] return out # ── self-contained GAP model ────────────────────────────────────────────────── class _JanusGAP(nn.Module): """Minimal reproduction of JanusGAP for standalone inference.""" def __init__(self, n_labels: int, backbone_id: str, image_size: int = 224, tri_stride: int = 1): super().__init__() self.image_size = image_size self.tri_stride = tri_stride self.backbone = AutoModel.from_pretrained(backbone_id, trust_remote_code=True) hidden_dim = self.backbone.config.hidden_size self.num_reg = getattr(self.backbone.config, "num_register_tokens", 0) self.head = nn.Linear(hidden_dim, n_labels) mean = torch.tensor(IMN_MEAN).view(1, 1, 3, 1, 1) std = torch.tensor(IMN_STD).view(1, 1, 3, 1, 1) self.register_buffer("_mean", mean) self.register_buffer("_std", std) def forward(self, vol: torch.Tensor) -> torch.Tensor: B, _, D, H, W = vol.shape frames = _make_trislices(vol, self.tri_stride) # [B, T, 3, H, W] T = frames.size(1) frames = F.interpolate( frames.view(B * T, 3, H, W), size=(self.image_size, self.image_size), mode="bilinear", align_corners=False, ).view(B, T, 3, self.image_size, self.image_size) frames = (frames - self._mean) / self._std flat = frames.view(B * T, 3, self.image_size, self.image_size) out = self.backbone(pixel_values=flat) tokens = out.last_hidden_state[:, 1:, :] # drop CLS if self.num_reg > 0: tokens = tokens[:, :-self.num_reg, :] # drop register tokens pooled = tokens.mean(dim=1).view(B, T, -1).mean(dim=1) # [B, D] return self.head(pooled) # ── inference ───────────────────────────────────────────────────────────────── def _load_state_dict(weights_path: str) -> dict: """ Load safetensors and remap backbone transformer-layer keys if needed. The DINOv3 custom HF model changed its internal structure between versions: checkpoint (training): backbone.layer.X.* current AutoModel: backbone.model.layer.X.* embeddings and norm keys (backbone.embeddings.X, backbone.norm.X) are the same in both versions and must NOT be remapped. """ sd = load_file(weights_path) if not any(k.startswith("backbone.layer.") for k in sd): return sd # already in current format, no remapping needed remapped = {} for k, v in sd.items(): if k.startswith("backbone.layer."): remapped["backbone.model." + k[len("backbone."):]] = v else: remapped[k] = v return remapped def predict(ct_path: str, variant: str = "gap", device: str = "cpu") -> dict: """ Run Janus inference on a preprocessed CT NIfTI file. Args: ct_path: Path to a preprocessed NIfTI CT (.nii or .nii.gz). variant: One of 'gap' | 'masked-attn' | 'gated-fusion' | 'scalar-fusion'. Only 'gap' is supported in this standalone script. device: 'cpu' or 'cuda'. Returns: Dict mapping disease name → predicted probability [0, 1]. """ if variant in VARIANTS_NEEDING_MASKS: _unsupported(variant) print(f"Downloading {variant} weights from {HF_REPO} ...") weights_path = hf_hub_download(HF_REPO, filename=f"{variant}/model.safetensors") config_path = hf_hub_download(HF_REPO, filename=f"{variant}/config.json") cfg = json.loads(Path(config_path).read_text()) variant_key = cfg.get("backbone_variant", "B") backbone_id = DINOV3_IDS.get(variant_key, DINOV3_IDS["B"]) print(f"Loading Janus-GAP (n_labels={cfg['n_labels']}, backbone=DINOv3-{variant_key}) ...") model = _JanusGAP( n_labels = cfg["n_labels"], backbone_id = backbone_id, image_size = cfg["image_size"], tri_stride = cfg["tri_stride"], ) model.load_state_dict(_load_state_dict(weights_path)) model.eval().to(device) vol = load_ct(ct_path).to(device) print(f"CT volume: {tuple(vol.shape)} device={device}") with torch.no_grad(): logits = model(vol) probs = torch.sigmoid(logits)[0].cpu().tolist() return dict(zip(cfg["labels"], probs)) def _unsupported(variant: str) -> None: needs_scalars = variant in VARIANTS_NEEDING_SCALARS print(f"\nVariant '{variant}' requires:") print(" • 20-channel organ segmentation masks aligned to the CT volume") if needs_scalars: print(" • Macro-radiomics scalar features (organ volumes, HU statistics,") print(" diameter measurements) extracted from those masks") print() print("These are produced by the full Janus preprocessing pipeline.") print("Source code and instructions: https://github.com/lavsendahal/janus") print(" (Repository is currently private — request access if needed.)") print() print("Once set up, run inference via the janus package directly:") print(f" python -m janus.inference --variant {variant} --ct ") sys.exit(1) # ── CLI ─────────────────────────────────────────────────────────────────────── def _default_device() -> str: if torch.cuda.is_available(): try: torch.cuda.init() return "cuda" except RuntimeError: print("Warning: CUDA detected but initialisation failed (driver too old?). Falling back to CPU.") return "cpu" def _collect_cases(input_dir: str, csv_path: str | None) -> list[tuple[str, Path]]: """ Return (case_id, nifti_path) pairs to process. If csv_path is given: read case IDs from it (one per line, or first column if .csv), then look up matching files in input_dir. If csv_path is not given: use every .nii / .nii.gz file in input_dir. """ root = Path(input_dir) if not root.is_dir(): raise ValueError(f"--input_dir '{input_dir}' is not a directory") # index all NIfTI files in input_dir: stem → path index: dict[str, Path] = {} for p in sorted(root.iterdir()): if p.name.endswith(".nii.gz"): index[p.name[: -len(".nii.gz")]] = p elif p.name.endswith(".nii"): index[p.name[: -len(".nii")]] = p if csv_path is None: return list(index.items()) # read case IDs from csv/txt csv_file = Path(csv_path) if not csv_file.exists(): raise ValueError(f"--csv '{csv_path}' not found") if csv_file.suffix.lower() == ".csv": import csv with open(csv_file) as f: reader = csv.reader(f) header = next(reader, None) ids = [row[0].strip() for row in reader if row] else: ids = [l.strip() for l in csv_file.read_text().splitlines() if l.strip()] cases, missing = [], [] for cid in ids: if cid in index: cases.append((cid, index[cid])) else: missing.append(cid) if missing: print(f"Warning: {len(missing)} IDs from CSV not found in input_dir: {missing[:5]}{'...' if len(missing) > 5 else ''}") return cases def main(): parser = argparse.ArgumentParser( description="Janus: 30-label abdominal CT disease classifier.", formatter_class=argparse.RawTextHelpFormatter, ) # ── single-file mode ────────────────────────────────────────────────────── parser.add_argument( "ct", nargs="?", default=None, help="Path to a single CT NIfTI file (.nii or .nii.gz).", ) # ── batch mode ──────────────────────────────────────────────────────────── parser.add_argument( "--input_dir", default=None, help="Directory of NIfTI files to process in batch.", ) parser.add_argument( "--csv", default=None, help="Text/CSV file listing case IDs to process (one per line / first column).\n" "Only used together with --input_dir. If omitted, all files in\n" "--input_dir are processed.", ) parser.add_argument( "--output", default="predictions.csv", help="Output CSV path for batch mode (default: predictions.csv).", ) # ── shared ──────────────────────────────────────────────────────────────── parser.add_argument( "--variant", default="gap", choices=["gap", "masked-attn", "gated-fusion", "scalar-fusion"], help=( "Model variant (default: gap).\n" " gap — CT image only (self-contained)\n" " masked-attn — requires organ segmentation masks\n" " gated-fusion — requires masks + radiomics features\n" " scalar-fusion— requires masks + radiomics features" ), ) parser.add_argument( "--device", default=_default_device(), help="Inference device (default: cuda if available, else cpu)", ) parser.add_argument( "--top", type=int, default=0, help="(single-file) Show only top-N predictions (0 = all).", ) parser.add_argument( "--threshold", type=float, default=None, help="(single-file) Flag diseases above this probability.", ) args = parser.parse_args() # ── validate mode ───────────────────────────────────────────────────────── if args.input_dir is None and args.ct is None: parser.error("Provide either a CT path (positional) or --input_dir for batch mode.") if args.input_dir and args.ct: parser.error("Provide either a CT path or --input_dir, not both.") # ── batch mode ──────────────────────────────────────────────────────────── if args.input_dir: import pandas as pd cases = _collect_cases(args.input_dir, args.csv) if not cases: print("No cases to process.") return print(f"Found {len(cases)} cases to process → {args.output}") # load model once if args.variant in VARIANTS_NEEDING_MASKS: _unsupported(args.variant) weights_path = hf_hub_download(HF_REPO, filename=f"{args.variant}/model.safetensors") config_path = hf_hub_download(HF_REPO, filename=f"{args.variant}/config.json") cfg = json.loads(Path(config_path).read_text()) backbone_id = DINOV3_IDS.get(cfg.get("backbone_variant", "B"), DINOV3_IDS["B"]) model = _JanusGAP( n_labels = cfg["n_labels"], backbone_id = backbone_id, image_size = cfg["image_size"], tri_stride = cfg["tri_stride"], ) model.load_state_dict(_load_state_dict(weights_path)) model.eval().to(args.device) rows = [] failed = [] for i, (case_id, nifti_path) in enumerate(cases): try: vol = load_ct(str(nifti_path)).to(args.device) with torch.no_grad(): logits = model(vol) probs = torch.sigmoid(logits)[0].cpu().tolist() rows.append({"case_id": case_id, **dict(zip(cfg["labels"], probs))}) if (i + 1) % 50 == 0 or (i + 1) == len(cases): print(f" {i+1}/{len(cases)}") except Exception as e: print(f" FAILED {case_id}: {e}") failed.append(case_id) df = pd.DataFrame(rows) df.to_csv(args.output, index=False) print(f"\nSaved {len(rows)} predictions → {args.output}") if failed: print(f"Failed cases ({len(failed)}): {failed}") return # ── single-file mode ────────────────────────────────────────────────────── preds = predict(args.ct, variant=args.variant, device=args.device) ranked = sorted(preds.items(), key=lambda kv: -kv[1]) if args.top: ranked = ranked[: args.top] print(f"\n{'Disease':<40} {'Prob':>6}") print("─" * 50) for disease, prob in ranked: flag = " ◄" if args.threshold and prob >= args.threshold else "" bar = "█" * int(prob * 20) print(f" {disease:<38} {prob:.3f} {bar}{flag}") if args.threshold: flagged = [d for d, p in preds.items() if p >= args.threshold] print(f"\nFindings above {args.threshold:.2f}: {flagged or 'none'}") if __name__ == "__main__": main()