| |
| """Single-volume intracranial hemorrhage segmentation with SAMIHS. |
| |
| Input: a 3D brain CT NIfTI (.nii or .nii.gz) |
| Output: a binary uint8 NIfTI mask with the same shape/affine as the input. |
| |
| The wrapper intentionally keeps the original SAMIHS project unchanged and loads |
| its modules through --samihs-root. |
| """ |
|
|
from __future__ import annotations

import argparse
import json
import os
import sys
from pathlib import Path
from types import SimpleNamespace
from typing import Any

import nibabel as nib
import numpy as np
import torch
import torch.nn.functional as F
|
|
|
|
PACKAGE_ROOT = Path(__file__).resolve().parent
LOCAL_SAMIHS_SRC = PACKAGE_ROOT / "samihs_src"
# Looks like a developer-machine location; kept only as a last-resort default
# when neither SAMIHS_ROOT nor a vendored ``samihs_src`` copy is available.
_FALLBACK_SAMIHS_ROOT = Path("/data/wxh/Medical/to_cfff/metrics/brain_bleed/SAMIHS")
# Resolution order for the default SAMIHS project root (--samihs-root):
#   1. the SAMIHS_ROOT environment variable, if set and non-empty;
#   2. a vendored copy at <package dir>/samihs_src, if it exists;
#   3. the hard-coded fallback above.
_ENV_SAMIHS_ROOT = os.environ.get("SAMIHS_ROOT", "").strip()
if _ENV_SAMIHS_ROOT:
    DEFAULT_SAMIHS_ROOT = Path(_ENV_SAMIHS_ROOT).expanduser()
elif LOCAL_SAMIHS_SRC.exists():
    DEFAULT_SAMIHS_ROOT = LOCAL_SAMIHS_SRC
else:
    DEFAULT_SAMIHS_ROOT = _FALLBACK_SAMIHS_ROOT
# Checkpoint shipped alongside this module under weights/.
DEFAULT_CKPT = PACKAGE_ROOT / "weights" / "SAMIHS_09170527_2_0.483.pth"
|
|
|
|
| def _import_samihs(samihs_root: Path) -> tuple[Any, Any, Any]: |
| samihs_root = samihs_root.resolve() |
| if not samihs_root.exists(): |
| raise FileNotFoundError(f"SAMIHS root not found: {samihs_root}") |
| sys.path.insert(0, str(samihs_root)) |
| from utils.config import get_config |
| from models.model_dict import get_model |
| from utils.generate_prompts import get_click_prompt_eval |
|
|
| return get_config, get_model, get_click_prompt_eval |
|
|
|
|
| def _select_slice_axis(shape: tuple[int, int, int], requested: str) -> int: |
| if requested != "auto": |
| axis = int(requested) |
| if axis not in (0, 1, 2): |
| raise ValueError("--slice-axis must be auto, 0, 1, or 2") |
| return axis |
| |
| |
| return int(np.argmin(shape)) |
|
|
|
|
| def _normalize_slice(slc: np.ndarray, clip_low: float, clip_high: float) -> np.ndarray: |
| slc = slc.astype(np.float32, copy=False) |
| lo_p, hi_p = np.percentile(slc, [clip_low, clip_high]) |
| slc = np.clip(slc, lo_p, hi_p) |
| lo = float(np.min(slc)) |
| hi = float(np.max(slc)) |
| den = max(hi - lo, 1e-6) |
| return np.nan_to_num((slc - lo) / den, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32) |
|
|
|
|
def load_model(
    ckpt: str | Path = DEFAULT_CKPT,
    samihs_root: str | Path = DEFAULT_SAMIHS_ROOT,
    device: str = "cuda:0",
    encoder_input_size: int = 1024,
    low_image_size: int = 256,
    vit_name: str = "vit_b",
    task: str = "Unlabeled",
):
    """Build the SAMIHS network from a checkpoint, ready for inference.

    Args:
        ckpt: Path to the SAMIHS checkpoint file.
        samihs_root: Root of the original SAMIHS project whose modules are
            imported.
        device: Torch device string, e.g. ``"cuda:0"`` or ``"cpu"``.
        encoder_input_size: Forwarded to SAMIHS ``get_model`` through its
            ``args`` namespace.
        low_image_size: Forwarded to SAMIHS ``get_model`` through its
            ``args`` namespace.
        vit_name: Forwarded to SAMIHS ``get_model`` through its ``args``
            namespace.
        task: Configuration name passed to SAMIHS ``get_config``.

    Returns:
        ``(model, prompt_fn, device_obj)``:
        - the model, moved to ``device_obj`` and put in eval mode;
        - SAMIHS's ``get_click_prompt_eval`` function;
        - the ``torch.device`` built from ``device``.

    Raises:
        FileNotFoundError: If the checkpoint does not exist.
        RuntimeError: If a CUDA device is requested but CUDA is unavailable.
    """
    ckpt_path = Path(ckpt).resolve()
    if not ckpt_path.exists():
        raise FileNotFoundError(f"checkpoint not found: {ckpt_path}")

    # Fail early with a clear message instead of deep inside torch.
    wants_cuda = device.startswith("cuda")
    if wants_cuda and not torch.cuda.is_available():
        raise RuntimeError("CUDA requested but torch.cuda.is_available() is false")
    target = torch.device(device)

    get_config, get_model, prompt_fn = _import_samihs(Path(samihs_root))

    ckpt_str = str(ckpt_path)
    model_args = SimpleNamespace(
        encoder_input_size=int(encoder_input_size),
        low_image_size=int(low_image_size),
        vit_name=vit_name,
        sam_ckpt=ckpt_str,
    )
    opt = get_config(task)
    opt.mode, opt.modelname = "test", "SAMIHS"
    opt.load_path = ckpt_str
    opt.device = device

    network = get_model("SAMIHS", args=model_args, opt=opt, ckpt_path=ckpt_str)
    network = network.to(target)
    network.eval()
    return network, prompt_fn, target
|
|
|
|
@torch.no_grad()
def segment_array(
    volume: np.ndarray,
    model,
    get_click_prompt_eval,
    device: torch.device,
    slice_axis: int,
    encoder_input_size: int = 1024,
    batch_size: int = 8,
    threshold: float = 0.5,
    amp: bool = True,
    clip_low: float = 0.5,
    clip_high: float = 99.5,
    rotate_for_samihs: bool = True,
) -> np.ndarray:
    """Segment a 3D volume slice by slice with SAMIHS; return a binary uint8 mask.

    The volume is viewed with ``slice_axis`` moved last, and slices are
    processed in batches of ``batch_size``. For each batch:

    - Each slice is normalised by ``_normalize_slice``, bilinearly resized to
      an ``encoder_input_size`` x ``encoder_input_size`` square and, if
      ``rotate_for_samihs``, rotated with ``torch.rot90(k=3)``.
    - Click prompts are obtained from ``get_click_prompt_eval``.
    - The model's ``"masks"`` output is passed through a sigmoid. This runs
      under CUDA autocast when ``amp`` is set and ``device`` is CUDA.
    - The probabilities are rotated back with ``rot90(k=1)`` (if rotating),
      resized to the original in-plane size and thresholded with
      ``>= threshold``.

    Args:
        volume: 3D intensity array.
        model: SAMIHS network, called as ``model(images, prompts, bbox=None)``.
        get_click_prompt_eval: SAMIHS prompt generator.
        device: Device the batch is moved to.
        slice_axis: Axis of ``volume`` to iterate over.
        encoder_input_size: Side length of the square model input.
        batch_size: Number of slices per forward pass.
        threshold: Probability cut-off for the binary mask.
        amp: Enable CUDA autocast (ignored on non-CUDA devices).
        clip_low: Lower intensity percentile for ``_normalize_slice``.
        clip_high: Upper intensity percentile for ``_normalize_slice``.
        rotate_for_samihs: Apply the rot90 pre/post orientation correction.

    Returns:
        A uint8 array of 0/1 values with the same shape as ``volume``.

    Raises:
        ValueError: If ``volume`` is not 3D or ``batch_size`` is not positive.
    """
    if volume.ndim != 3:
        raise ValueError(f"Expected a 3D volume, got shape={volume.shape}")
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")

    moved = slice_axis != 2
    stack = np.moveaxis(volume, slice_axis, -1) if moved else volume
    height, width, depth = stack.shape
    out = np.zeros((height, width, depth), dtype=np.uint8)

    square = (encoder_input_size, encoder_input_size)
    use_autocast = amp and device.type == "cuda"
    # Minimal stand-in for the SAMIHS options object: the prompt generator is
    # only given a ``device`` attribute here.
    prompt_opt = type("OptLite", (), {"device": device})()

    def encode_slice(index: int) -> torch.Tensor:
        # One normalised slice -> (1, 1, S, S) CPU tensor in model orientation.
        plane = _normalize_slice(stack[:, :, index], clip_low=clip_low, clip_high=clip_high)
        tensor = F.interpolate(
            torch.from_numpy(plane)[None, None],
            square,
            mode="bilinear",
            align_corners=False,
        )
        return torch.rot90(tensor, k=3, dims=(-2, -1)) if rotate_for_samihs else tensor

    def predict(images: torch.Tensor) -> torch.Tensor:
        # Prompts are generated outside autocast; the forward pass and sigmoid
        # run inside it when enabled.
        prompts = get_click_prompt_eval({"image": images}, prompt_opt)
        if not use_autocast:
            return torch.sigmoid(model(images, prompts, bbox=None)["masks"])
        with torch.amp.autocast("cuda", enabled=True):
            return torch.sigmoid(model(images, prompts, bbox=None)["masks"])

    model.eval()
    for start in range(0, depth, batch_size):
        stop = min(start + batch_size, depth)
        images = torch.cat([encode_slice(i) for i in range(start, stop)], dim=0)
        images = images.to(device=device, dtype=torch.float32)
        images = torch.nan_to_num(images, nan=0.0, posinf=0.0, neginf=0.0)

        probs = predict(images)
        if rotate_for_samihs:
            probs = torch.rot90(probs, k=1, dims=(-2, -1))
        probs = F.interpolate(probs, (height, width), mode="bilinear", align_corners=False)
        probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
        # Presumably masks are (B, 1, H', W'); dim 1 is squeezed to get one
        # (height, width) plane per slice — TODO confirm against SAMIHS.
        mask_batch = (probs >= threshold).to(torch.uint8).squeeze(1).cpu().numpy()

        for k in range(start, stop):
            out[:, :, k] = mask_batch[k - start]

    return np.moveaxis(out, -1, slice_axis) if moved else out
|
|
|
|
def segment_nii(
    input_nii: str | Path,
    output_nii: str | Path,
    ckpt: str | Path = DEFAULT_CKPT,
    samihs_root: str | Path = DEFAULT_SAMIHS_ROOT,
    device: str = "cuda:0",
    encoder_input_size: int = 1024,
    batch_size: int = 8,
    threshold: float = 0.5,
    slice_axis: str = "auto",
    amp: bool = True,
    clip_low: float = 0.5,
    clip_high: float = 99.5,
    rotate_for_samihs: bool = True,
) -> dict[str, Any]:
    """Segment one NIfTI volume, write the mask NIfTI and return run metadata.

    The model is loaded on every call via ``load_model``. To process many
    volumes, load it once and call ``segment_array`` directly.

    Args:
        input_nii: Path of the input 3D NIfTI.
        output_nii: Path of the mask NIfTI to write (parent dirs are created).
        ckpt, samihs_root, device, encoder_input_size: Forwarded to
            ``load_model``.
        batch_size, threshold, amp, clip_low, clip_high, rotate_for_samihs:
            Forwarded to ``segment_array``.
        slice_axis: ``"auto"`` or ``"0"``/``"1"``/``"2"``, resolved by
            ``_select_slice_axis``.

    Returns:
        A JSON-serialisable dict with the resolved paths, run settings,
        foreground voxel count and a volume estimate.

        The volume estimate is in mL, derived from the header zooms, which
        are assumed to be in mm.

    Raises:
        FileNotFoundError: If ``input_nii`` does not exist.
    """
    src = Path(input_nii).resolve()
    dst = Path(output_nii).resolve()
    if not src.exists():
        raise FileNotFoundError(f"input NIfTI not found: {src}")

    img = nib.load(str(src))
    volume = img.get_fdata(dtype=np.float32)
    axis = _select_slice_axis(tuple(map(int, volume.shape)), slice_axis)

    model, prompt_fn, device_obj = load_model(
        ckpt=ckpt,
        samihs_root=samihs_root,
        device=device,
        encoder_input_size=encoder_input_size,
    )
    mask = segment_array(
        volume=volume,
        model=model,
        get_click_prompt_eval=prompt_fn,
        device=device_obj,
        slice_axis=axis,
        encoder_input_size=encoder_input_size,
        batch_size=batch_size,
        threshold=threshold,
        amp=amp,
        clip_low=clip_low,
        clip_high=clip_high,
        rotate_for_samihs=rotate_for_samihs,
    )

    # Reuse the input geometry (affine + header) so the mask overlays the CT.
    dst.parent.mkdir(parents=True, exist_ok=True)
    out_header = img.header.copy()
    out_header.set_data_dtype(np.uint8)
    nib.save(nib.Nifti1Image(mask.astype(np.uint8), img.affine, out_header), str(dst))

    zooms = img.header.get_zooms()[:3]
    voxel_mm3 = float(np.prod(zooms)) if len(zooms) == 3 else float("nan")
    foreground = mask.sum()
    return {
        "input_nii": str(src),
        "output_nii": str(dst),
        "ckpt": str(Path(ckpt).resolve()),
        "samihs_root": str(Path(samihs_root).resolve()),
        "shape": list(volume.shape),
        "slice_axis": axis,
        "encoder_input_size": int(encoder_input_size),
        "batch_size": int(batch_size),
        "threshold": float(threshold),
        "mask_nonzero_voxels": int(foreground),
        "voxel_volume_mm3_from_header": voxel_mm3,
        "mask_volume_ml_from_header": float(foreground) * voxel_mm3 / 1000.0,
    }
|
|
|
|
def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser used by ``main``."""
    parser = argparse.ArgumentParser(
        description="Segment a 3D brain CT NIfTI into a binary ICH mask NIfTI using SAMIHS."
    )
    add = parser.add_argument
    # Input/output and model location.
    add("--input", required=True, help="Input 3D CT NIfTI path (.nii or .nii.gz)")
    add("--output", required=True, help="Output binary mask NIfTI path (.nii.gz recommended)")
    add("--ckpt", default=str(DEFAULT_CKPT), help="SAMIHS checkpoint path")
    add("--samihs-root", default=str(DEFAULT_SAMIHS_ROOT), help="Original SAMIHS project root")
    add("--device", default="cuda:0", help="Torch device, e.g. cuda:0 or cpu")
    # Inference settings.
    add("--encoder-input-size", type=int, default=1024, help="2D encoder input size")
    add("--batch-size", type=int, default=8, help="Slice-level batch size")
    add("--threshold", type=float, default=0.5, help="Probability threshold for binary mask")
    add(
        "--slice-axis",
        default="auto",
        choices=["auto", "0", "1", "2"],
        help="Axis treated as slice/depth; auto uses smallest dimension",
    )
    add("--clip-low", type=float, default=0.5, help="Per-slice lower percentile for intensity clipping")
    add("--clip-high", type=float, default=99.5, help="Per-slice upper percentile for intensity clipping")
    add("--no-amp", action="store_true", help="Disable CUDA mixed precision")
    add("--no-rotate", action="store_true", help="Disable SAMIHS rot90 pre/post correction")
    # Reporting.
    add("--metadata-json", default="", help="Optional path to write run metadata JSON")
    return parser


def main() -> None:
    """Command-line entry point: segment one NIfTI and print run metadata as JSON.

    When ``--metadata-json`` is given, the same metadata is also written to
    that path (parent directories are created as needed).
    """
    opts = _build_parser().parse_args()

    metadata = segment_nii(
        input_nii=opts.input,
        output_nii=opts.output,
        ckpt=opts.ckpt,
        samihs_root=opts.samihs_root,
        device=opts.device,
        encoder_input_size=opts.encoder_input_size,
        batch_size=opts.batch_size,
        threshold=opts.threshold,
        slice_axis=opts.slice_axis,
        amp=not opts.no_amp,
        clip_low=opts.clip_low,
        clip_high=opts.clip_high,
        rotate_for_samihs=not opts.no_rotate,
    )

    json_target = opts.metadata_json
    if json_target:
        target_path = Path(json_target).resolve()
        target_path.parent.mkdir(parents=True, exist_ok=True)
        target_path.write_text(json.dumps(metadata, indent=2), encoding="utf-8")
    print(json.dumps(metadata, indent=2, ensure_ascii=False))
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the CLI only when executed directly, not on import.
    main()
|
|