janus / inference.py
lavsendahal's picture
Upload inference.py with huggingface_hub
3620972 verified
"""
Janus — standalone inference for the lavsendahal/janus HuggingFace repo.
GAP variant: fully self-contained.
pip install torch transformers safetensors nibabel scipy huggingface_hub pandas
(pandas is only imported by the --input_dir batch mode.)
masked-attn / gated-fusion / scalar-fusion variants require organ segmentation
masks (and radiomics features for fusion variants) produced by the full Janus
preprocessing pipeline. See: https://github.com/lavsendahal/janus
Usage:
python inference.py ct.nii.gz
python inference.py ct.nii.gz --variant gap --device cuda --top 10
"""
import argparse
import json
import sys
from pathlib import Path
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from scipy.ndimage import zoom as scipy_zoom
from transformers import AutoModel
# ── constants ─────────────────────────────────────────────────────────────────
# Hugging Face Hub repository that hosts the per-variant weights and configs
# ("<variant>/model.safetensors" and "<variant>/config.json").
HF_REPO = "lavsendahal/janus"
# Backbone size code -> Hub model id of the DINOv3 ViT checkpoint. The code is
# read from the "backbone_variant" entry of a variant's config.json; unknown or
# missing codes fall back to "B" at the call sites.
DINOV3_IDS = {
    "S": "facebook/dinov3-vits16-pretrain-lvd1689m",
    "B": "facebook/dinov3-vitb16-pretrain-lvd1689m",
    "L": "facebook/dinov3-vitl16-pretrain-lvd1689m",
}
# Must match slurm_prepack.sub exactly
# HU window used by load_ct: intensities are clipped to [HU_MIN, HU_MAX] and
# then rescaled linearly to [0, 1].
HU_MIN, HU_MAX = -1000.0, 1000.0
TARGET_SPACING = np.array([1.5, 1.5, 3.0]) # mm, X/Y/Z in RAS space
TARGET_SHAPE_XYZ = np.array([224, 224, 160]) # voxels, X/Y/Z before axis permute
# Variants that cannot run in this standalone script (they need organ masks).
VARIANTS_NEEDING_MASKS = {"masked-attn", "gated-fusion", "scalar-fusion"}
# Subset of the above that additionally needs radiomics scalar features.
VARIANTS_NEEDING_SCALARS = {"gated-fusion", "scalar-fusion"}
# Per-channel mean / std subtracted from and divided into the 3-channel
# tri-slice frames in _JanusGAP.forward (the standard ImageNet RGB statistics).
IMN_MEAN = [0.485, 0.456, 0.406]
IMN_STD = [0.229, 0.224, 0.225]
# ── preprocessing ─────────────────────────────────────────────────────────────
def load_ct(path: str) -> torch.Tensor:
    """
    Load a NIfTI CT and convert it to the fixed grid the model consumes.

    Pipeline (intended to mirror packer.py + dataset.py):
      1. Reorient to the closest RAS canonical orientation.
      2. Resample to TARGET_SPACING (1.5 x 1.5 x 3.0 mm; linear, order=1).
      3. Resize to exactly TARGET_SHAPE_XYZ (224 x 224 x 160; linear, order=1).
      4. Permute [X, Y, Z] -> [Z, Y, X] (= [D, H, W]).
      5. Clip HU to [HU_MIN, HU_MAX] and rescale to [0, 1].

    Args:
        path: Path to a .nii / .nii.gz file readable by nibabel.

    Returns:
        float32 tensor of shape [1, 1, D, H, W], where (W, H, D) equals
        TARGET_SHAPE_XYZ.

    Raises:
        ValueError: if the image is not 3-D (a trailing singleton 4th axis is
            tolerated and dropped).
    """
    nii = nib.as_closest_canonical(nib.load(path))  # reorient to RAS
    vol = nii.get_fdata(dtype=np.float32)           # [X, Y, Z]

    # Some exporters store a single 3-D volume as [X, Y, Z, 1].
    if vol.ndim == 4 and vol.shape[3] == 1:
        vol = vol[..., 0]
    if vol.ndim != 3:
        raise ValueError(
            f"Expected a 3-D CT volume, got shape {tuple(vol.shape)} for '{path}'"
        )

    # Voxel size is the length of each affine column. Reading only the
    # diagonal under-estimates (or zeroes) the spacing of oblique acquisitions;
    # for axis-aligned affines both give the same numbers.
    spacing = np.linalg.norm(nii.affine[:3, :3], axis=0)  # mm per voxel [sx, sy, sz]

    # step 2 -> target spacing (skipped when already within 1 % of the target)
    zoom_to_spacing = spacing / TARGET_SPACING
    if not np.allclose(zoom_to_spacing, 1.0, atol=0.01):
        vol = scipy_zoom(vol, zoom_to_spacing, order=1, mode="nearest")

    # step 3 -> target shape. Compare shapes exactly: a tolerance on the zoom
    # factor would let e.g. 225 or 161 voxels through unresized.
    target_shape = tuple(int(n) for n in TARGET_SHAPE_XYZ)
    if tuple(vol.shape) != target_shape:
        zoom_to_shape = TARGET_SHAPE_XYZ / np.array(vol.shape)
        vol = scipy_zoom(vol, zoom_to_shape, order=1, mode="nearest")

    # step 4 -> permute [X, Y, Z] -> [Z, Y, X] (= [D, H, W])
    vol = np.transpose(vol, (2, 1, 0))

    # step 5 -> HU clip + normalise to [0, 1]
    vol = np.clip(vol, HU_MIN, HU_MAX)
    vol = (vol - HU_MIN) / (HU_MAX - HU_MIN)

    # The transpose leaves a strided view; hand torch a contiguous buffer.
    vol = np.ascontiguousarray(vol)
    return torch.from_numpy(vol).unsqueeze(0).unsqueeze(0)  # [1, 1, D, H, W]
def _make_trislices(vol: torch.Tensor, stride: int) -> torch.Tensor:
    """
    Turn a volume into overlapping 3-slice frames: [B, 1, D, H, W] -> [B, T, 3, H, W].

    Frame t stacks the slices (c - 1, c, c + 1) along the depth axis as its
    three channels, where c runs over 1, 1 + stride, ... < D - 1. Neighbour
    indices are clamped to [0, D - 1]. Volumes too shallow to have an interior
    slice (D <= 2) yield a single frame centred on D // 2.

    Args:
        vol: Tensor of shape [B, 1, D, H, W]; only channel 0 is read.
        stride: Step between consecutive centre slices; values < 1 act as 1.

    Returns:
        Tensor of shape [B, T, 3, H, W] with the dtype and device of `vol`.

    Raises:
        ValueError: if `vol` is not 5-D or has an empty depth axis.
    """
    _, _, depth, _, _ = vol.shape  # raises ValueError unless vol is 5-D
    if depth < 1:
        raise ValueError("volume has an empty depth axis")

    # The explicit fallback covers D == 1 and D == 2, where the range is empty.
    centers = list(range(1, depth - 1, max(1, stride))) or [depth // 2]

    mid = torch.tensor(centers, device=vol.device, dtype=torch.long)
    below = (mid - 1).clamp(min=0)
    above = (mid + 1).clamp(max=depth - 1)

    planes = vol[:, 0]  # [B, D, H, W]
    # Each gather is [B, T, H, W]; stacking on dim 2 creates the channel axis.
    return torch.stack((planes[:, below], planes[:, mid], planes[:, above]), dim=2)
# ── self-contained GAP model ──────────────────────────────────────────────────
class _JanusGAP(nn.Module):
    """
    Minimal reproduction of JanusGAP for standalone inference.

    A pretrained backbone encodes every tri-slice frame of the volume; its
    output tokens are averaged per frame, the frame vectors are averaged per
    volume, and a single linear layer maps the result to `n_labels` logits.
    """

    def __init__(self, n_labels: int, backbone_id: str, image_size: int = 224, tri_stride: int = 1):
        """
        Args:
            n_labels: Number of output logits.
            backbone_id: Hub model id passed to AutoModel.from_pretrained.
            image_size: Side length frames are resized to before the backbone.
            tri_stride: Depth step between consecutive tri-slice frames.
        """
        super().__init__()
        self.image_size = image_size
        self.tri_stride = tri_stride

        self.backbone = AutoModel.from_pretrained(backbone_id, trust_remote_code=True)
        backbone_cfg = self.backbone.config
        width = backbone_cfg.hidden_size
        # Backbones without register tokens simply lack the attribute.
        self.num_reg = getattr(backbone_cfg, "num_register_tokens", 0)
        self.head = nn.Linear(width, n_labels)

        # Channel statistics, shaped to broadcast over [B, T, 3, H, W].
        # Registered as buffers so they follow .to(device) and appear in the
        # state dict under "_mean" / "_std".
        for buffer_name, stats in (("_mean", IMN_MEAN), ("_std", IMN_STD)):
            self.register_buffer(buffer_name, torch.tensor(stats).view(1, 1, 3, 1, 1))

    def forward(self, vol: torch.Tensor) -> torch.Tensor:
        """Map a [B, 1, D, H, W] volume to [B, n_labels] logits."""
        batch, _, _, height, width = vol.shape
        side = self.image_size

        frames = _make_trislices(vol, self.tri_stride)  # [B, T, 3, H, W]
        n_frames = frames.size(1)

        # Fold frames into the batch axis for the 2-D resize, then unfold.
        resized = F.interpolate(
            frames.view(batch * n_frames, 3, height, width),
            size=(side, side),
            mode="bilinear",
            align_corners=False,
        )
        normed = (resized.view(batch, n_frames, 3, side, side) - self._mean) / self._std

        encoded = self.backbone(pixel_values=normed.view(batch * n_frames, 3, side, side))
        tokens = encoded.last_hidden_state[:, 1:, :]  # drop CLS
        if self.num_reg > 0:
            # NOTE(review): this removes the LAST num_reg tokens. Confirm the
            # backbone really emits its register tokens at the end of the
            # sequence (some ViT implementations place them right after CLS).
            # Left unchanged because the released head was presumably trained
            # with exactly this pooling.
            tokens = tokens[:, :-self.num_reg, :]

        frame_vecs = tokens.mean(dim=1)                                # [B*T, hidden]
        volume_vec = frame_vecs.view(batch, n_frames, -1).mean(dim=1)  # [B, hidden]
        return self.head(volume_vec)
# ── inference ─────────────────────────────────────────────────────────────────
def _load_state_dict(weights_path: str) -> dict:
    """
    Read a safetensors checkpoint, renaming old-style backbone layer keys.

    Keys that start with "backbone.layer." are rewritten to start with
    "backbone.model.layer."; every other key (embeddings, norm, head, ...)
    is passed through untouched. If the checkpoint contains no such key it
    is returned exactly as loaded.

    Args:
        weights_path: Path to a .safetensors file.

    Returns:
        Mapping of parameter name -> tensor, ready for load_state_dict.
    """
    legacy_prefix = "backbone.layer."
    current_prefix = "backbone.model.layer."

    state = load_file(weights_path)

    has_legacy_keys = False
    for name in state:
        if name.startswith(legacy_prefix):
            has_legacy_keys = True
            break
    if not has_legacy_keys:
        return state  # already in current format, no remapping needed

    return {
        (current_prefix + name[len(legacy_prefix):] if name.startswith(legacy_prefix) else name): tensor
        for name, tensor in state.items()
    }
def predict(ct_path: str, variant: str = "gap", device: str = "cpu") -> dict:
    """
    Download a Janus variant, run it on one CT file and return probabilities.

    The weights and config are fetched from HF_REPO on every call, the volume
    is prepared by load_ct, and the logits are passed through a sigmoid.

    Args:
        ct_path: Path to a NIfTI CT (.nii or .nii.gz).
        variant: One of 'gap' | 'masked-attn' | 'gated-fusion' | 'scalar-fusion'.
            Only 'gap' is supported here; the others print an explanation and
            terminate the process via _unsupported.
        device: Torch device string, e.g. 'cpu' or 'cuda'.

    Returns:
        Dict mapping disease name → predicted probability in [0, 1].
    """
    if variant in VARIANTS_NEEDING_MASKS:
        _unsupported(variant)

    def _fetch(name: str) -> str:
        # Files live under "<variant>/<name>" inside the Hub repo.
        return hf_hub_download(HF_REPO, filename=f"{variant}/{name}")

    print(f"Downloading {variant} weights from {HF_REPO} ...")
    weights_path = _fetch("model.safetensors")
    config_path = _fetch("config.json")

    cfg = json.loads(Path(config_path).read_text())
    size_code = cfg.get("backbone_variant", "B")
    # Unknown size codes fall back to the "B" backbone.
    backbone_id = DINOV3_IDS.get(size_code, DINOV3_IDS["B"])

    print(f"Loading Janus-GAP (n_labels={cfg['n_labels']}, backbone=DINOv3-{size_code}) ...")
    model = _JanusGAP(
        n_labels=cfg["n_labels"],
        backbone_id=backbone_id,
        image_size=cfg["image_size"],
        tri_stride=cfg["tri_stride"],
    )
    model.load_state_dict(_load_state_dict(weights_path))
    model.eval().to(device)

    vol = load_ct(ct_path).to(device)
    print(f"CT volume: {tuple(vol.shape)} device={device}")

    with torch.no_grad():
        probabilities = torch.sigmoid(model(vol))[0].cpu().tolist()

    return {label: prob for label, prob in zip(cfg["labels"], probabilities)}
def _unsupported(variant: str) -> None:
    """
    Explain why `variant` cannot run in this standalone script, then exit.

    Prints the extra inputs the variant needs (organ masks, plus radiomics
    scalars for the variants in VARIANTS_NEEDING_SCALARS) and where to obtain
    the full pipeline. Never returns: the process ends with exit status 1.

    Args:
        variant: Name of the requested model variant.
    """
    lines = [
        f"\nVariant '{variant}' requires:",
        " • 20-channel organ segmentation masks aligned to the CT volume",
    ]
    if variant in VARIANTS_NEEDING_SCALARS:
        lines += [
            " • Macro-radiomics scalar features (organ volumes, HU statistics,",
            " diameter measurements) extracted from those masks",
        ]
    lines += [
        "",
        "These are produced by the full Janus preprocessing pipeline.",
        "Source code and instructions: https://github.com/lavsendahal/janus",
        " (Repository is currently private — request access if needed.)",
        "",
        "Once set up, run inference via the janus package directly:",
        f" python -m janus.inference --variant {variant} --ct <path>",
    ]
    for line in lines:
        print(line)
    sys.exit(1)
# ── CLI ───────────────────────────────────────────────────────────────────────
def _default_device() -> str:
    """
    Pick the default inference device.

    Returns "cuda" when torch reports CUDA as available and it initialises
    without a RuntimeError; otherwise "cpu". A failed initialisation prints a
    warning before falling back.
    """
    if not torch.cuda.is_available():
        return "cpu"
    try:
        torch.cuda.init()
    except RuntimeError:
        print("Warning: CUDA detected but initialisation failed (driver too old?). Falling back to CPU.")
        return "cpu"
    return "cuda"
def _collect_cases(input_dir: str, csv_path: str | None) -> list[tuple[str, Path]]:
"""
Return (case_id, nifti_path) pairs to process.
If csv_path is given: read case IDs from it (one per line, or first column
if .csv), then look up matching files in input_dir.
If csv_path is not given: use every .nii / .nii.gz file in input_dir.
"""
root = Path(input_dir)
if not root.is_dir():
raise ValueError(f"--input_dir '{input_dir}' is not a directory")
# index all NIfTI files in input_dir: stem β†’ path
index: dict[str, Path] = {}
for p in sorted(root.iterdir()):
if p.name.endswith(".nii.gz"):
index[p.name[: -len(".nii.gz")]] = p
elif p.name.endswith(".nii"):
index[p.name[: -len(".nii")]] = p
if csv_path is None:
return list(index.items())
# read case IDs from csv/txt
csv_file = Path(csv_path)
if not csv_file.exists():
raise ValueError(f"--csv '{csv_path}' not found")
if csv_file.suffix.lower() == ".csv":
import csv
with open(csv_file) as f:
reader = csv.reader(f)
header = next(reader, None)
ids = [row[0].strip() for row in reader if row]
else:
ids = [l.strip() for l in csv_file.read_text().splitlines() if l.strip()]
cases, missing = [], []
for cid in ids:
if cid in index:
cases.append((cid, index[cid]))
else:
missing.append(cid)
if missing:
print(f"Warning: {len(missing)} IDs from CSV not found in input_dir: {missing[:5]}{'...' if len(missing) > 5 else ''}")
return cases
def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser shared by single-file and batch mode."""
    parser = argparse.ArgumentParser(
        description="Janus: 30-label abdominal CT disease classifier.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    # ── single-file mode ──────────────────────────────────────────────────────
    parser.add_argument(
        "ct", nargs="?", default=None,
        help="Path to a single CT NIfTI file (.nii or .nii.gz).",
    )
    # ── batch mode ────────────────────────────────────────────────────────────
    parser.add_argument(
        "--input_dir", default=None,
        help="Directory of NIfTI files to process in batch.",
    )
    parser.add_argument(
        "--csv", default=None,
        help="Text/CSV file listing case IDs to process (one per line / first column).\n"
             "Only used together with --input_dir. If omitted, all files in\n"
             "--input_dir are processed.",
    )
    parser.add_argument(
        "--output", default="predictions.csv",
        help="Output CSV path for batch mode (default: predictions.csv).",
    )
    # ── shared ────────────────────────────────────────────────────────────────
    parser.add_argument(
        "--variant", default="gap",
        choices=["gap", "masked-attn", "gated-fusion", "scalar-fusion"],
        help=(
            "Model variant (default: gap).\n"
            " gap — CT image only (self-contained)\n"
            " masked-attn — requires organ segmentation masks\n"
            " gated-fusion — requires masks + radiomics features\n"
            " scalar-fusion— requires masks + radiomics features"
        ),
    )
    parser.add_argument(
        "--device", default=_default_device(),
        help="Inference device (default: cuda if available, else cpu)",
    )
    parser.add_argument(
        "--top", type=int, default=0,
        help="(single-file) Show only top-N predictions (0 = all).",
    )
    parser.add_argument(
        "--threshold", type=float, default=None,
        help="(single-file) Flag diseases above this probability.",
    )
    return parser


def _run_batch(args: argparse.Namespace, cases: list) -> None:
    """
    Run the model over `cases` and write one CSV row per successful case.

    The model is downloaded and built once. A case that raises is reported,
    recorded as failed and skipped, so one bad file does not abort the run.
    """
    import pandas as pd  # only batch mode needs pandas

    print(f"Found {len(cases)} cases to process → {args.output}")

    # load model once
    if args.variant in VARIANTS_NEEDING_MASKS:
        _unsupported(args.variant)
    weights_path = hf_hub_download(HF_REPO, filename=f"{args.variant}/model.safetensors")
    config_path = hf_hub_download(HF_REPO, filename=f"{args.variant}/config.json")
    cfg = json.loads(Path(config_path).read_text())
    backbone_id = DINOV3_IDS.get(cfg.get("backbone_variant", "B"), DINOV3_IDS["B"])
    model = _JanusGAP(
        n_labels=cfg["n_labels"],
        backbone_id=backbone_id,
        image_size=cfg["image_size"],
        tri_stride=cfg["tri_stride"],
    )
    model.load_state_dict(_load_state_dict(weights_path))
    model.eval().to(args.device)

    labels = cfg["labels"]
    total = len(cases)
    rows, failed = [], []
    for done, (case_id, nifti_path) in enumerate(cases, start=1):
        try:
            vol = load_ct(str(nifti_path)).to(args.device)
            with torch.no_grad():
                probs = torch.sigmoid(model(vol))[0].cpu().tolist()
        except Exception as e:  # per-case boundary: report it and keep going
            print(f" FAILED {case_id}: {e}")
            failed.append(case_id)
        else:
            rows.append({"case_id": case_id, **dict(zip(labels, probs))})
        # Report progress even when the current (or last) case failed.
        if done % 50 == 0 or done == total:
            print(f" {done}/{total}")

    # Explicit columns keep the header present even if every case failed.
    pd.DataFrame(rows, columns=["case_id", *labels]).to_csv(args.output, index=False)
    print(f"\nSaved {len(rows)} predictions → {args.output}")
    if failed:
        print(f"Failed cases ({len(failed)}): {failed}")


def _report_single(args: argparse.Namespace) -> None:
    """Predict one CT and print a ranked probability table."""
    preds = predict(args.ct, variant=args.variant, device=args.device)
    ranked = sorted(preds.items(), key=lambda kv: -kv[1])
    if args.top > 0:  # 0 (or a negative value) means "show everything"
        ranked = ranked[: args.top]

    # Compare against None so that an explicit `--threshold 0` is honoured.
    threshold = args.threshold
    use_threshold = threshold is not None

    print(f"\n{'Disease':<40} {'Prob':>6}")
    print("─" * 50)
    for disease, prob in ranked:
        flag = " ◄" if use_threshold and prob >= threshold else ""
        bar = "█" * int(prob * 20)  # 20-cell bar, one cell per 5 %
        print(f" {disease:<38} {prob:.3f} {bar}{flag}")

    if use_threshold:
        # Flag over all predictions, not only the displayed top-N.
        flagged = [d for d, p in preds.items() if p >= threshold]
        print(f"\nFindings above {threshold:.2f}: {flagged or 'none'}")


def main():
    """
    Command-line entry point.

    Exactly one of the positional CT path (single-file mode) or --input_dir
    (batch mode) must be given. Single-file mode prints a ranked table; batch
    mode writes a CSV to --output. Invalid argument combinations or paths end
    the program through argparse's error handling.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # ── validate mode ─────────────────────────────────────────────────────────
    if not args.input_dir and args.ct is None:
        parser.error("Provide either a CT path (positional) or --input_dir for batch mode.")
    if args.input_dir and args.ct:
        parser.error("Provide either a CT path or --input_dir, not both.")
    if args.csv and not args.input_dir:
        print("Warning: --csv is only used together with --input_dir; ignoring it.")

    # ── batch mode ────────────────────────────────────────────────────────────
    if args.input_dir:
        try:
            cases = _collect_cases(args.input_dir, args.csv)
        except ValueError as err:
            # Bad --input_dir / --csv is a usage error, not a crash.
            parser.error(str(err))
        if not cases:
            print("No cases to process.")
            return
        _run_batch(args, cases)
        return

    # ── single-file mode ──────────────────────────────────────────────────────
    _report_single(args)
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()