"""
Janus - standalone inference for the lavsendahal/janus HuggingFace repo.

GAP variant: fully self-contained.
    pip install torch transformers safetensors nibabel scipy huggingface_hub

The masked-attn, gated-fusion and scalar-fusion variants require organ
segmentation masks (plus radiomics features for the fusion variants) produced
by the full Janus preprocessing pipeline. See: https://github.com/lavsendahal/janus

Usage:
    python inference.py ct.nii.gz
    python inference.py ct.nii.gz --variant gap --device cuda --top 10
    python inference.py --input_dir cts/ --csv ids.txt --output predictions.csv
"""
|
|
| import argparse |
| import json |
| import sys |
| from pathlib import Path |
|
|
| import nibabel as nib |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from huggingface_hub import hf_hub_download |
| from safetensors.torch import load_file |
| from scipy.ndimage import zoom as scipy_zoom |
| from transformers import AutoModel |
|
|
| |
# Hugging Face Hub repo holding "<variant>/model.safetensors" and
# "<variant>/config.json" for each model variant.
HF_REPO = "lavsendahal/janus"

# DINOv3 backbone checkpoints, keyed by the "backbone_variant" size code read
# from a variant's config.json.
DINOV3_IDS = dict(
    S="facebook/dinov3-vits16-pretrain-lvd1689m",
    B="facebook/dinov3-vitb16-pretrain-lvd1689m",
    L="facebook/dinov3-vitl16-pretrain-lvd1689m",
)

# CT preprocessing (see load_ct): HU clipping window, target voxel spacing in
# mm along (x, y, z), and target grid size in voxels along (x, y, z).
HU_MIN = -1000.0
HU_MAX = 1000.0
TARGET_SPACING = np.array((1.5, 1.5, 3.0))
TARGET_SHAPE_XYZ = np.array((224, 224, 160))

# Variants this standalone script cannot run: all of them need organ masks,
# and the fusion variants additionally need radiomics scalar features.
VARIANTS_NEEDING_SCALARS = {"gated-fusion", "scalar-fusion"}
VARIANTS_NEEDING_MASKS = VARIANTS_NEEDING_SCALARS | {"masked-attn"}

# ImageNet per-channel mean/std, applied to the three tri-slice channels.
IMN_MEAN = [0.485, 0.456, 0.406]
IMN_STD = [0.229, 0.224, 0.225]
|
|
|
|
| |
|
|
def load_ct(path: str) -> torch.Tensor:
    """
    Load a NIfTI CT, resample it to the training grid, return [1, 1, D, H, W].

    Pipeline (intended to match packer.py + dataset.py):
      1. Reorient to RAS canonical orientation.
      2. Resample to 1.5 x 1.5 x 3.0 mm (linear interpolation, order=1).
         Skipped when the spacing is already within 1% of the target.
      3. Resize to exactly 224 x 224 x 160 voxels (linear, order=1).
      4. Permute [X, Y, Z] -> [Z, Y, X] (= [D, H, W], matching the dataset permute).
      5. Clip HU to [-1000, 1000] and rescale linearly to [0, 1].

    Args:
        path: Path to a .nii / .nii.gz file.

    Returns:
        Contiguous float32 tensor of shape [1, 1, 160, 224, 224].

    Raises:
        ValueError: if the image is not 3-D after dropping trailing singleton axes.
    """
    nii = nib.as_closest_canonical(nib.load(path))
    vol = nii.get_fdata(dtype=np.float32)

    # Some exporters store a 3-D scan as 4-D with a trailing singleton axis.
    while vol.ndim > 3 and vol.shape[-1] == 1:
        vol = vol[..., 0]
    if vol.ndim != 3:
        raise ValueError(f"Expected a 3-D CT volume, got shape {vol.shape} from {path}")

    # Voxel size along each array axis = norm of the matching affine column.
    # This equals |diag(affine)| for axis-aligned images but stays correct when
    # the canonical affine still carries a rotation/shear (oblique scans).
    spacing = np.sqrt(np.sum(nii.affine[:3, :3] ** 2, axis=0))

    # Step 2: physical resampling to the target spacing.
    z1 = spacing / TARGET_SPACING
    if not np.allclose(z1, 1.0, atol=0.01):
        vol = scipy_zoom(vol, z1, order=1, mode="nearest")

    # Step 3: resize to the exact training grid. Compare shapes exactly; a ratio
    # tolerance would leave e.g. 223- or 225-voxel axes unresized.
    target_shape = tuple(int(s) for s in TARGET_SHAPE_XYZ)
    if vol.shape != target_shape:
        z2 = np.array(target_shape) / np.array(vol.shape)
        vol = scipy_zoom(vol, z2, order=1, mode="nearest")

    # Step 4: [X, Y, Z] -> [Z, Y, X] == [D, H, W].
    vol = np.transpose(vol, (2, 1, 0))

    # Step 5: HU window -> [0, 1].
    vol = np.clip(vol, HU_MIN, HU_MAX)
    vol = (vol - HU_MIN) / (HU_MAX - HU_MIN)

    vol = np.ascontiguousarray(vol, dtype=np.float32)
    return torch.from_numpy(vol).unsqueeze(0).unsqueeze(0)
|
|
|
|
def _make_trislices(vol: torch.Tensor, stride: int) -> torch.Tensor:
    """
    Build 3-channel "tri-slice" frames: [B, 1, D, H, W] -> [B, T, 3, H, W].

    Frame t stacks depth slices (c - 1, c, c + 1) for centres
    c = 1, 1 + stride, 1 + 2 * stride, ... < D - 1, with neighbour indices
    clamped to [0, D - 1]. Volumes without an interior slice (D < 3) yield a
    single frame centred on D // 2, so even D = 1 is handled.

    Args:
        vol: Tensor of shape [B, 1, D, H, W]; only channel 0 is read.
        stride: Step between consecutive centre slices (values < 1 act as 1).

    Returns:
        New contiguous tensor [B, T, 3, H, W] with vol's dtype and device.

    Raises:
        ValueError: if the volume has no slices (D == 0).
    """
    depth = vol.shape[2]
    if depth < 1:
        raise ValueError("Cannot build tri-slices from a volume with zero depth")

    centre_list = list(range(1, depth - 1, max(1, stride))) or [depth // 2]
    centres = torch.tensor(centre_list, dtype=torch.long, device=vol.device)
    below = (centres - 1).clamp(min=0)
    above = (centres + 1).clamp(max=depth - 1)

    slices = vol[:, 0]  # [B, D, H, W]
    # Advanced indexing on the depth axis gives [B, T, H, W]; stacking on dim 2
    # yields [B, T, 3, H, W] in one contiguous allocation.
    return torch.stack((slices[:, below], slices[:, centres], slices[:, above]), dim=2)
|
|
|
|
| |
|
|
class _JanusGAP(nn.Module):
    """Minimal reproduction of JanusGAP for standalone inference.

    forward() cuts a [B, 1, D, H, W] volume into tri-slice frames, resizes each
    frame to ``image_size`` x ``image_size``, applies ImageNet mean/std
    normalisation, encodes every frame with the pretrained backbone, averages
    token features (first over tokens, then over frames) and maps the result to
    ``n_labels`` logits with a single linear layer.

    Args:
        n_labels: Number of output logits.
        backbone_id: Model id passed to ``AutoModel.from_pretrained``.
        image_size: Side length frames are resized to before the backbone.
        tri_stride: Centre-slice stride passed to ``_make_trislices``.
    """

    def __init__(self, n_labels: int, backbone_id: str, image_size: int = 224, tri_stride: int = 1):
        super().__init__()
        self.image_size = image_size
        self.tri_stride = tri_stride
        # NOTE: trust_remote_code lets the Hub repo run its own modelling code.
        self.backbone = AutoModel.from_pretrained(backbone_id, trust_remote_code=True)
        backbone_cfg = self.backbone.config
        # Register-token count (0 when the backbone config has no such field).
        self.num_reg = getattr(backbone_cfg, "num_register_tokens", 0)
        self.head = nn.Linear(backbone_cfg.hidden_size, n_labels)

        # ImageNet statistics shaped to broadcast over [B, T, 3, H, W]. As
        # buffers they follow .to(device) and appear in state_dict.
        channel_shape = (1, 1, 3, 1, 1)
        self.register_buffer("_mean", torch.tensor(IMN_MEAN).view(channel_shape))
        self.register_buffer("_std", torch.tensor(IMN_STD).view(channel_shape))

    def forward(self, vol: torch.Tensor) -> torch.Tensor:
        """Map a [B, 1, D, H, W] volume to [B, n_labels] logits."""
        batch, _, _, height, width = vol.shape
        side = self.image_size

        frames = _make_trislices(vol, self.tri_stride)  # [B, T, 3, H, W]
        n_frames = frames.size(1)
        n_images = batch * n_frames

        # Resize all frames as one flat image batch, then normalise per channel.
        resized = F.interpolate(
            frames.view(n_images, 3, height, width),
            size=(side, side),
            mode="bilinear",
            align_corners=False,
        )
        normed = (resized.view(batch, n_frames, 3, side, side) - self._mean) / self._std

        hidden = self.backbone(pixel_values=normed.view(n_images, 3, side, side)).last_hidden_state
        # Drop token 0 (presumably CLS) and, when registers exist, the LAST
        # num_reg tokens.
        # NOTE(review): upstream transformers DINOv2/DINOv3 models appear to put
        # register tokens right after CLS rather than at the end -- verify
        # against the training code before changing, since the head was fitted
        # to whatever this pooling produced.
        tokens = hidden[:, 1:, :]
        if self.num_reg > 0:
            tokens = tokens[:, : -self.num_reg, :]
        frame_feats = tokens.mean(dim=1).view(batch, n_frames, -1)
        return self.head(frame_feats.mean(dim=1))
|
|
|
|
| |
|
|
def _load_state_dict(weights_path: str) -> dict:
    """
    Read a safetensors checkpoint, renaming old-style backbone layer keys.

    Training checkpoints store transformer blocks as ``backbone.layer.X.*``,
    while the DINOv3 remote-code model in use expects
    ``backbone.model.layer.X.*``. Only keys starting with ``backbone.layer.``
    are renamed; embeddings/norm keys (``backbone.embeddings.*``,
    ``backbone.norm.*``) keep their names. A checkpoint without any
    ``backbone.layer.`` key is returned untouched.

    NOTE(review): the rename is decided from the checkpoint alone, not from the
    installed model's key layout -- confirm it still matches if the
    transformers/DINOv3 version changes.
    """
    legacy_prefix = "backbone.layer."
    state = load_file(weights_path)
    if all(not name.startswith(legacy_prefix) for name in state):
        return state
    # "backbone." occurs first at index 0, so a single replace rewrites the prefix.
    return {
        (name.replace("backbone.", "backbone.model.", 1) if name.startswith(legacy_prefix) else name): tensor
        for name, tensor in state.items()
    }
|
|
|
|
def predict(ct_path: str, variant: str = "gap", device: str = "cpu") -> dict[str, float]:
    """
    Run Janus inference on a single CT NIfTI file.

    The CT does not need to be resampled beforehand: load_ct reorients,
    resamples and normalises it to the training grid.

    Args:
        ct_path: Path to a NIfTI CT (.nii or .nii.gz).
        variant: One of 'gap' | 'masked-attn' | 'gated-fusion' | 'scalar-fusion'.
                 Only 'gap' runs here; the mask-based variants print setup
                 instructions and exit the process via _unsupported().
        device: torch device string, e.g. 'cpu' or 'cuda'.

    Returns:
        Dict mapping each label in the variant's config.json to its sigmoid
        probability in [0, 1].

    Raises:
        ValueError: if ``variant`` is not a known variant name, or if the
            config's label list does not match the number of model outputs.
    """
    if variant in VARIANTS_NEEDING_MASKS:
        _unsupported(variant)
    if variant != "gap":
        # Fail fast with a clear message instead of a Hub 404 on "<variant>/...".
        raise ValueError(
            f"Unknown variant {variant!r}; expected 'gap' or one of {sorted(VARIANTS_NEEDING_MASKS)}"
        )

    print(f"Downloading {variant} weights from {HF_REPO} ...")
    weights_path = hf_hub_download(HF_REPO, filename=f"{variant}/model.safetensors")
    config_path = hf_hub_download(HF_REPO, filename=f"{variant}/config.json")
    cfg = json.loads(Path(config_path).read_text())

    # Unknown or missing size codes fall back to the ViT-B backbone.
    variant_key = cfg.get("backbone_variant", "B")
    backbone_id = DINOV3_IDS.get(variant_key, DINOV3_IDS["B"])
    print(f"Loading Janus-GAP (n_labels={cfg['n_labels']}, backbone=DINOv3-{variant_key}) ...")
    model = _JanusGAP(
        n_labels=cfg["n_labels"],
        backbone_id=backbone_id,
        image_size=cfg["image_size"],
        tri_stride=cfg["tri_stride"],
    )
    model.load_state_dict(_load_state_dict(weights_path))
    model.eval().to(device)

    vol = load_ct(ct_path).to(device)
    print(f"CT volume: {tuple(vol.shape)}  device={device}")

    with torch.no_grad():
        logits = model(vol)

    probs = torch.sigmoid(logits)[0].cpu().tolist()
    # strict=True: a label/logit count mismatch is a config error, not something
    # to silently truncate.
    return dict(zip(cfg["labels"], probs, strict=True))
|
|
|
|
def _unsupported(variant: str) -> None:
    """
    Explain what a mask-based variant needs, then exit the process with status 1.

    Called for variants in VARIANTS_NEEDING_MASKS, which this standalone script
    cannot run. The explanation goes to stderr because it is a diagnostic, not a
    result. This function never returns normally (it calls sys.exit(1)).
    """
    lines = [
        "",
        f"Variant '{variant}' requires:",
        "  - 20-channel organ segmentation masks aligned to the CT volume",
    ]
    if variant in VARIANTS_NEEDING_SCALARS:
        lines += [
            "  - Macro-radiomics scalar features (organ volumes, HU statistics,",
            "    diameter measurements) extracted from those masks",
        ]
    lines += [
        "",
        "These are produced by the full Janus preprocessing pipeline.",
        "Source code and instructions: https://github.com/lavsendahal/janus",
        "  (Repository is currently private - request access if needed.)",
        "",
        "Once set up, run inference via the janus package directly:",
        f"  python -m janus.inference --variant {variant} --ct <path>",
    ]
    print("\n".join(lines), file=sys.stderr)
    sys.exit(1)
|
|
|
|
| |
|
|
def _default_device() -> str:
    """Return "cuda" if a CUDA device is available and initialises, else "cpu".

    A failed torch.cuda.init() (RuntimeError) prints a warning and falls back
    to the CPU instead of aborting.
    """
    if not torch.cuda.is_available():
        return "cpu"
    try:
        torch.cuda.init()
    except RuntimeError:
        print("Warning: CUDA detected but initialisation failed (driver too old?). Falling back to CPU.")
        return "cpu"
    return "cuda"
|
|
|
|
| def _collect_cases(input_dir: str, csv_path: str | None) -> list[tuple[str, Path]]: |
| """ |
| Return (case_id, nifti_path) pairs to process. |
| |
| If csv_path is given: read case IDs from it (one per line, or first column |
| if .csv), then look up matching files in input_dir. |
| If csv_path is not given: use every .nii / .nii.gz file in input_dir. |
| """ |
| root = Path(input_dir) |
| if not root.is_dir(): |
| raise ValueError(f"--input_dir '{input_dir}' is not a directory") |
|
|
| |
| index: dict[str, Path] = {} |
| for p in sorted(root.iterdir()): |
| if p.name.endswith(".nii.gz"): |
| index[p.name[: -len(".nii.gz")]] = p |
| elif p.name.endswith(".nii"): |
| index[p.name[: -len(".nii")]] = p |
|
|
| if csv_path is None: |
| return list(index.items()) |
|
|
| |
| csv_file = Path(csv_path) |
| if not csv_file.exists(): |
| raise ValueError(f"--csv '{csv_path}' not found") |
|
|
| if csv_file.suffix.lower() == ".csv": |
| import csv |
| with open(csv_file) as f: |
| reader = csv.reader(f) |
| header = next(reader, None) |
| ids = [row[0].strip() for row in reader if row] |
| else: |
| ids = [l.strip() for l in csv_file.read_text().splitlines() if l.strip()] |
|
|
| cases, missing = [], [] |
| for cid in ids: |
| if cid in index: |
| cases.append((cid, index[cid])) |
| else: |
| missing.append(cid) |
|
|
| if missing: |
| print(f"Warning: {len(missing)} IDs from CSV not found in input_dir: {missing[:5]}{'...' if len(missing) > 5 else ''}") |
|
|
| return cases |
|
|
|
|
def _run_batch(args: argparse.Namespace, device: str) -> None:
    """Batch mode: predict every selected case and write one CSV row per case.

    Per-case failures are reported and skipped so one bad file does not abort
    the run; the failed case IDs are listed at the end. The CSV always has a
    header row ("case_id" followed by the config's labels), even if every case
    failed.
    """
    import csv

    cases = _collect_cases(args.input_dir, args.csv)
    if not cases:
        print("No cases to process.")
        return
    print(f"Found {len(cases)} cases to process -> {args.output}")

    weights_path = hf_hub_download(HF_REPO, filename=f"{args.variant}/model.safetensors")
    config_path = hf_hub_download(HF_REPO, filename=f"{args.variant}/config.json")
    cfg = json.loads(Path(config_path).read_text())
    backbone_id = DINOV3_IDS.get(cfg.get("backbone_variant", "B"), DINOV3_IDS["B"])

    model = _JanusGAP(
        n_labels=cfg["n_labels"],
        backbone_id=backbone_id,
        image_size=cfg["image_size"],
        tri_stride=cfg["tri_stride"],
    )
    model.load_state_dict(_load_state_dict(weights_path))
    model.eval().to(device)

    labels = cfg["labels"]
    rows, failed = [], []
    total = len(cases)
    for done, (case_id, nifti_path) in enumerate(cases, start=1):
        try:
            vol = load_ct(str(nifti_path)).to(device)
            with torch.no_grad():
                probs = torch.sigmoid(model(vol))[0].cpu().tolist()
        except Exception as e:  # per-case boundary: record the failure, keep going
            print(f"  FAILED {case_id}: {e}")
            failed.append(case_id)
        else:
            rows.append({"case_id": case_id, **dict(zip(labels, probs))})
        if done % 50 == 0 or done == total:
            print(f"  {done}/{total}")

    # stdlib csv rather than pandas: pandas is not a documented dependency.
    with open(args.output, "w", newline="", encoding="utf-8") as fh:
        writer = csv.DictWriter(fh, fieldnames=["case_id", *labels], lineterminator="\n")
        writer.writeheader()
        writer.writerows(rows)

    print(f"\nSaved {len(rows)} predictions -> {args.output}")
    if failed:
        print(f"Failed cases ({len(failed)}): {failed}")


def _print_report(preds: dict, top: int, threshold: float | None) -> None:
    """Print predictions sorted by descending probability.

    ``top`` > 0 keeps only the first ``top`` rows. When ``threshold`` is not
    None, rows with probability >= threshold are flagged and listed at the end.
    """
    ranked = sorted(preds.items(), key=lambda kv: kv[1], reverse=True)
    if top:
        ranked = ranked[:top]

    print(f"\n  {'Disease':<38} {'Prob':>6}")
    print("-" * 50)
    for disease, prob in ranked:
        # Compare against None, not truthiness, so --threshold 0 still flags.
        flag = "  <--" if threshold is not None and prob >= threshold else ""
        bar = "#" * int(prob * 20)
        print(f"  {disease:<38} {prob:>6.3f}  {bar}{flag}")

    if threshold is not None:
        flagged = [d for d, p in preds.items() if p >= threshold]
        print(f"\nFindings above {threshold:.2f}: {flagged or 'none'}")


def main():
    """Command-line entry point.

    Single-file mode (positional ``ct``) prints a ranked probability table.
    Batch mode (``--input_dir``, optionally filtered by ``--csv``) writes one
    prediction row per case to ``--output``.
    """
    parser = argparse.ArgumentParser(
        description="Janus: 30-label abdominal CT disease classifier.",
        formatter_class=argparse.RawTextHelpFormatter,
    )

    # Input selection: one CT, or a directory for batch mode.
    parser.add_argument(
        "ct", nargs="?", default=None,
        help="Path to a single CT NIfTI file (.nii or .nii.gz).",
    )
    parser.add_argument(
        "--input_dir", default=None,
        help="Directory of NIfTI files to process in batch.",
    )
    parser.add_argument(
        "--csv", default=None,
        help="Text/CSV file listing case IDs to process (one per line / first column).\n"
             "Only used together with --input_dir. If omitted, all files in\n"
             "--input_dir are processed.",
    )
    parser.add_argument(
        "--output", default="predictions.csv",
        help="Output CSV path for batch mode (default: predictions.csv).",
    )

    # Model and runtime options.
    parser.add_argument(
        "--variant", default="gap",
        choices=["gap", "masked-attn", "gated-fusion", "scalar-fusion"],
        help=(
            "Model variant (default: gap).\n"
            "  gap           - CT image only (self-contained)\n"
            "  masked-attn   - requires organ segmentation masks\n"
            "  gated-fusion  - requires masks + radiomics features\n"
            "  scalar-fusion - requires masks + radiomics features"
        ),
    )
    parser.add_argument(
        "--device", default=None,
        help="Inference device (default: cuda if available, else cpu)",
    )
    parser.add_argument(
        "--top", type=int, default=0,
        help="(single-file) Show only top-N predictions (0 = all).",
    )
    parser.add_argument(
        "--threshold", type=float, default=None,
        help="(single-file) Flag diseases above this probability.",
    )
    args = parser.parse_args()

    if args.input_dir is None and args.ct is None:
        parser.error("Provide either a CT path (positional) or --input_dir for batch mode.")
    if args.input_dir is not None and args.ct is not None:
        parser.error("Provide either a CT path or --input_dir, not both.")
    if args.csv is not None and args.input_dir is None:
        parser.error("--csv can only be used together with --input_dir.")
    if args.top < 0:
        parser.error("--top must be >= 0.")
    if args.threshold is not None and not 0.0 <= args.threshold <= 1.0:
        parser.error("--threshold must be between 0 and 1.")

    if args.variant in VARIANTS_NEEDING_MASKS:
        _unsupported(args.variant)

    # Resolve the default device only after parsing, so --help and an explicit
    # --device never trigger CUDA initialisation.
    device = args.device or _default_device()

    if args.input_dir is not None:
        _run_batch(args, device)
        return

    preds = predict(args.ct, variant=args.variant, device=device)
    _print_report(preds, args.top, args.threshold)
|
|
|
|
# Run the CLI only when executed as a script; importing this module (e.g. to
# call predict()) does not parse arguments.
if __name__ == "__main__":
    main()
|
|