|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
Compute 2.5D FID using distributed GPU processing.
|
|
|
|
|
|
SHELL Usage Example:
|
|
|
-------------------
|
|
|
#!/bin/bash
|
|
|
|
|
|
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6
|
|
|
NUM_GPUS=7
|
|
|
|
|
|
torchrun --nproc_per_node=${NUM_GPUS} compute_fid_2-5d_ct.py \
|
|
|
--model_name "radimagenet_resnet50" \
|
|
|
--real_dataset_root "path/to/datasetA" \
|
|
|
--real_filelist "path/to/filelistA.txt" \
|
|
|
--real_features_dir "datasetA" \
|
|
|
--synth_dataset_root "path/to/datasetB" \
|
|
|
--synth_filelist "path/to/filelistB.txt" \
|
|
|
--synth_features_dir "datasetB" \
|
|
|
--enable_center_slices_ratio 0.4 \
|
|
|
--enable_padding True \
|
|
|
--enable_center_cropping True \
|
|
|
--enable_resampling_spacing "1.0x1.0x1.0" \
|
|
|
--ignore_existing True \
|
|
|
--num_images 100 \
|
|
|
--output_root "./features/features-512x512x512" \
|
|
|
--target_shape "512x512x512"
|
|
|
|
|
|
This script loads two datasets (real vs. synthetic) in 3D medical format (NIfTI)
|
|
|
and extracts feature maps via a 2.5D approach. It then computes the Frechet
|
|
|
Inception Distance (FID) across three orthogonal planes. Data parallelism
|
|
|
is implemented using torch.distributed with an NCCL backend.
|
|
|
|
|
|
Function Arguments (main):
|
|
|
--------------------------
|
|
|
real_dataset_root (str):
|
|
|
Root folder for the real dataset.
|
|
|
|
|
|
real_filelist (str):
|
|
|
Text file listing 3D images for the real dataset.
|
|
|
|
|
|
real_features_dir (str):
|
|
|
Subdirectory (under `output_root`) in which to store feature files
|
|
|
extracted from the real dataset.
|
|
|
|
|
|
synth_dataset_root (str):
|
|
|
Root folder for the synthetic dataset.
|
|
|
|
|
|
synth_filelist (str):
|
|
|
Text file listing 3D images for the synthetic dataset.
|
|
|
|
|
|
synth_features_dir (str):
|
|
|
Subdirectory (under `output_root`) in which to store feature files
|
|
|
extracted from the synthetic dataset.
|
|
|
|
|
|
enable_center_slices_ratio (float or None):
|
|
|
- If not None, only slices around the specified center ratio will be used
|
|
|
(analogous to "enable_center_slices=True" with that ratio).
|
|
|
- If None, no center-slice selection is performed
|
|
|
(analogous to "enable_center_slices=False").
|
|
|
|
|
|
enable_padding (bool):
|
|
|
Whether to pad images to `target_shape`.
|
|
|
|
|
|
enable_center_cropping (bool):
|
|
|
Whether to center-crop images to `target_shape`.
|
|
|
|
|
|
enable_resampling_spacing (str or None):
|
|
|
- If not None, resample images to the specified voxel spacing (e.g. "1.0x1.0x1.0")
|
|
|
(analogous to "enable_resampling=True" with that spacing).
|
|
|
- If None, resampling is skipped
|
|
|
(analogous to "enable_resampling=False").
|
|
|
|
|
|
ignore_existing (bool):
|
|
|
If True, ignore any existing .pt feature files and force re-extraction.
|
|
|
|
|
|
model_name (str):
|
|
|
Model identifier. Typically "radimagenet_resnet50" or "squeezenet1_1".
|
|
|
|
|
|
num_images (int):
|
|
|
Max number of images to process from each dataset (truncate if more are present).
|
|
|
|
|
|
output_root (str):
|
|
|
Folder where extracted .pt feature files, logs, and results are saved.
|
|
|
|
|
|
target_shape (str):
|
|
|
Target shape as "XxYxZ" for padding, cropping, or resampling operations.
|
|
|
"""
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
import os
|
|
|
import sys
|
|
|
import torch
|
|
|
import fire
|
|
|
import monai
|
|
|
import re
|
|
|
import torch.distributed as dist
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
from datetime import timedelta
|
|
|
from pathlib import Path
|
|
|
from monai.metrics.fid import FIDMetric
|
|
|
from monai.transforms import Compose
|
|
|
|
|
|
import logging
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level logger for this script. Configure a stdout handler exactly
# once; if the host application already installed handlers, respect them.
logger = logging.getLogger("fid_2-5d_ct")
if not logger.handlers:
    logging.basicConfig(level=logging.INFO, stream=sys.stdout)
    logger.setLevel(logging.INFO)
|
|
|
|
|
|
|
|
|
def drop_empty_slice(slices, empty_threshold: float):
    """
    Decide which 2D slices to keep by checking whether their maximum intensity
    reaches a given threshold.

    Args:
        slices (tuple or list of Tensors): Each element is (B, C, H, W).
        empty_threshold (float): If a slice's maximum value is below this
            threshold, the slice is considered "empty" and dropped.

    Returns:
        list[bool]: One flag per slice; True means keep, False means drop.
    """
    keep_flags = []
    n_drop = 0
    for s in slices:
        # max(unique(s)) is just the global max; torch.max avoids the
        # sort/dedup work that torch.unique performs.
        if torch.max(s) < empty_threshold:
            keep_flags.append(False)
            n_drop += 1
        else:
            keep_flags.append(True)

    # Guard against an empty input so the drop-rate log cannot divide by zero.
    if slices:
        logger.info(f"Empty slice drop rate {round((n_drop/len(slices))*100,1)}%")
    return keep_flags
|
|
|
|
|
|
|
|
|
def subtract_mean(x: torch.Tensor) -> torch.Tensor:
    """
    Subtract fixed per-channel means [0.406, 0.456, 0.485] (ImageNet means in
    BGR order) from a (B, C, H, W) or (B, C, H, W, D) tensor.

    The subtraction is performed in place on ``x``; the mutated tensor is also
    returned for convenience.
    """
    for channel, channel_mean in enumerate((0.406, 0.456, 0.485)):
        x[:, channel, ...] -= channel_mean
    return x
|
|
|
|
|
|
|
|
|
def spatial_average(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor:
    """
    Average a feature map over its spatial axes (everything after the batch
    and channel dimensions), producing a per-channel feature vector.

    Args:
        x (torch.Tensor): Input of rank 2 to 5, e.g. (B, C), (B, C, H),
            (B, C, H, W) or (B, C, H, W, D).
        keepdim (bool): Whether the reduced axes are kept as size-1 dims.

    Returns:
        torch.Tensor: Input averaged over its spatial dims; tensors of rank
        2 (or any unexpected rank) are returned unchanged.
    """
    ndim = x.dim()
    # Ranks 3-5 carry spatial axes starting at dim 2; reduce them all at once.
    if 3 <= ndim <= 5:
        return x.mean(list(range(2, ndim)), keepdim=keepdim)
    # Rank-2 (B, C) inputs -- and anything unexpected -- pass through as-is.
    return x
|
|
|
|
|
|
|
|
|
def medicalnet_intensity_normalisation(volume: torch.Tensor) -> torch.Tensor:
    """
    MedicalNet-style intensity normalization: (volume - mean) / (std + 1e-5),
    with statistics computed over the spatial axes for each (batch, channel).

    Supports (B, C, H, W) and (B, C, H, W, D) inputs; tensors of any other
    rank are returned unchanged.
    """
    ndim = volume.dim()
    if ndim not in (4, 5):
        return volume

    spatial_axes = list(range(2, ndim))
    mean = volume.mean(spatial_axes, keepdim=True)
    std = volume.std(spatial_axes, keepdim=True)
    return (volume - mean) / (std + 1e-5)
|
|
|
|
|
|
|
|
|
def radimagenet_intensity_normalisation(volume: torch.Tensor, norm2d: bool = False) -> torch.Tensor:
    """
    Intensity normalization for radimagenet_resnet models.

    Rescales intensities to [0, 1] -- per (H, W) slice when ``norm2d`` is set
    on a 4D input, otherwise over the whole tensor -- and then subtracts the
    fixed per-channel means via ``subtract_mean``.

    Args:
        volume (torch.Tensor): Input (B, C, H, W) or (B, C, H, W, D).
        norm2d (bool): If True (4D input only), each slice is normalized
            independently before the mean subtraction.

    Returns:
        torch.Tensor: Normalized tensor; other ranks are returned unchanged.
    """
    logger.info(f"norm2d: {norm2d}")
    ndim = volume.dim()

    if ndim == 4 and norm2d:
        # Per-slice min/max, reduced over H and W with kept dims so the
        # result broadcasts back over each (H, W) plane.
        slice_max = torch.amax(volume, dim=(2, 3), keepdim=True)
        slice_min = torch.amin(volume, dim=(2, 3), keepdim=True)
        volume = (volume - slice_min) / (slice_max - slice_min + 1e-10)
        return subtract_mean(volume)

    if ndim in (4, 5):
        # Global min/max over the entire tensor.
        lo = torch.min(volume)
        hi = torch.max(volume)
        volume = (volume - lo) / (hi - lo + 1e-10)
        return subtract_mean(volume)

    return volume
|
|
|
|
|
|
|
|
|
def _plane_features(
    image: torch.Tensor,
    feature_network: torch.nn.Module,
    slice_dim: int,
    center_slices: bool,
    center_slices_ratio: float,
    sample_every_k: int,
    drop_empty: bool,
    empty_threshold: float,
) -> torch.Tensor:
    """Slice `image` along `slice_dim`, filter/normalize the 2D slices, and return pooled features.

    This is the shared per-plane pipeline: optional center-window + stride
    selection along `slice_dim`, optional empty-slice dropping, radimagenet
    intensity normalization, backbone forward pass, and spatial pooling.
    """
    if center_slices:
        # Keep only the central `center_slices_ratio` window of the axis,
        # sub-sampled every `sample_every_k` slices.
        axis_len = image.shape[slice_dim]
        start = int((1.0 - center_slices_ratio) / 2.0 * axis_len)
        end = int((1.0 + center_slices_ratio) / 2.0 * axis_len)
        index = [slice(None)] * image.dim()
        index[slice_dim] = slice(start, end, sample_every_k)
        slices = torch.unbind(image[tuple(index)], dim=slice_dim)
    else:
        slices = torch.unbind(image, dim=slice_dim)

    if drop_empty:
        keep = drop_empty_slice(slices, empty_threshold)
    else:
        keep = [True for _ in range(len(slices))]

    # Stack all slices into one 2D batch, normalize, then drop empty ones.
    images_2d = torch.cat(slices, dim=0)
    images_2d = radimagenet_intensity_normalisation(images_2d)
    images_2d = images_2d[keep]

    features = feature_network.forward(images_2d)
    return spatial_average(features, keepdim=False)


def get_features_2p5d(
    image: torch.Tensor,
    feature_network: torch.nn.Module,
    center_slices: bool = False,
    center_slices_ratio: float = 1.0,
    sample_every_k: int = 1,
    xy_only: bool = True,
    drop_empty: bool = False,
    empty_threshold: float = -700,
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
    """
    Extract 2.5D features from a 3D image by slicing it along XY, YZ, ZX planes.

    Args:
        image (torch.Tensor): Input 5D tensor in shape (B, C, H, W, D).
        feature_network (torch.nn.Module): Model that processes 2D slices (C,H,W).
        center_slices (bool): Whether to slice only the center portion of each axis.
        center_slices_ratio (float): Ratio of slices to keep in the center if `center_slices` is True.
        sample_every_k (int): Downsampling factor along each axis when slicing.
        xy_only (bool): If True, return only the XY-plane features.
        drop_empty (bool): Drop slices that are deemed "empty" below `empty_threshold`.
        empty_threshold (float): Threshold to decide emptiness of slices.

    Returns:
        tuple of torch.Tensor or None: (XY_features, YZ_features, ZX_features).
        YZ/ZX entries are None when `xy_only` is True.
    """
    logger.info(f"center_slices: {center_slices}, ratio: {center_slices_ratio}")

    # Single-channel volumes are replicated to 3 channels for the 2D backbones.
    if image.shape[1] == 1:
        image = image.repeat(1, 3, 1, 1, 1)

    # RGB -> BGR channel order (matches the mean ordering in `subtract_mean`).
    image = image[:, [2, 1, 0], ...]

    with torch.no_grad():
        plane_kwargs = dict(
            center_slices=center_slices,
            center_slices_ratio=center_slices_ratio,
            sample_every_k=sample_every_k,
            drop_empty=drop_empty,
            empty_threshold=empty_threshold,
        )
        # XY plane: slices taken along the depth axis (dim 4).
        feature_image_xy = _plane_features(image, feature_network, slice_dim=4, **plane_kwargs)
        if xy_only:
            return feature_image_xy, None, None

        # YZ plane: slices along the height axis (dim 2).
        feature_image_yz = _plane_features(image, feature_network, slice_dim=2, **plane_kwargs)
        # ZX plane: slices along the width axis (dim 3).
        feature_image_zx = _plane_features(image, feature_network, slice_dim=3, **plane_kwargs)

    return feature_image_xy, feature_image_yz, feature_image_zx
|
|
|
|
|
|
|
|
|
def pad_to_max_size(tensor: torch.Tensor, max_size: int, padding_value: float = 0.0) -> torch.Tensor:
    """
    Pad a tensor along its first dimension so that dim 0 reaches `max_size`.

    Args:
        tensor (torch.Tensor): The feature tensor to pad.
        max_size (int): Desired size along the first dimension.
        padding_value (float): Fill value for the padded region.

    Returns:
        torch.Tensor: Tensor whose dim 0 equals `max_size`; all other
        dimensions are untouched.
    """
    # F.pad lists pad amounts last-dimension-first, so every trailing dim
    # gets a (0, 0) pair and the final pair grows dim 0 at its end.
    grow_by = max_size - tensor.shape[0]
    pad_spec = [0, 0] * (tensor.dim() - 1) + [0, grow_by]
    return F.pad(tensor, pad_spec, mode="constant", value=padding_value)
|
|
|
|
|
|
|
|
|
def _coerce_bool(value) -> bool:
    """Return `value` as a bool; CLI parsing may deliver booleans as strings like "True"."""
    if isinstance(value, bool):
        return value
    return value.lower() == "true"


def _read_file_partition(filelist: str, dataset_root: str, num_images: int, world_size: int, local_rank: int) -> list:
    """Read `filelist`, sort it, truncate to `num_images`, and return this rank's share.

    Each line is joined onto `dataset_root` and wrapped as {"image": path} for MONAI.
    """
    with open(filelist, "r") as f:
        lines = [line.strip() for line in f.readlines()]
    lines.sort()
    lines = lines[:num_images]
    filenames = [{"image": os.path.join(dataset_root, name)} for name in lines]
    return monai.data.partition_dataset(
        data=filenames, shuffle=False, num_partitions=world_size, even_divisible=False
    )[local_rank]


def _build_transforms(
    enable_resampling: bool,
    rs_spacing_tuple: tuple,
    enable_padding: bool,
    enable_center_cropping: bool,
    target_shape_tuple: tuple,
) -> Compose:
    """Build the MONAI preprocessing pipeline shared by both datasets.

    Pipeline: load -> channel-first -> RAS orientation -> optional resampling,
    padding, center-cropping -> HU clamping.
    """
    transform_list = [
        monai.transforms.LoadImaged(keys=["image"]),
        monai.transforms.EnsureChannelFirstd(keys=["image"]),
        monai.transforms.Orientationd(keys=["image"], axcodes="RAS"),
    ]
    if enable_resampling:
        transform_list.append(monai.transforms.Spacingd(keys=["image"], pixdim=rs_spacing_tuple, mode=["bilinear"]))
    if enable_padding:
        # Pad with air (-1000 HU) so new voxels look like CT background.
        transform_list.append(
            monai.transforms.SpatialPadd(keys=["image"], spatial_size=target_shape_tuple, mode="constant", value=-1000)
        )
    if enable_center_cropping:
        transform_list.append(monai.transforms.CenterSpatialCropd(keys=["image"], roi_size=target_shape_tuple))
    # Clamp intensities to the CT HU window [-1000, 1000] without rescaling.
    transform_list.append(
        monai.transforms.ScaleIntensityRanged(
            keys=["image"], a_min=-1000, a_max=1000, b_min=-1000, b_max=1000, clip=True
        )
    )
    return Compose(transform_list)


def _extract_dataset_features(
    loader,
    n_total: int,
    dataset_root: str,
    features_root: str,
    tag: str,
    device,
    local_rank: int,
    feature_network,
    enable_center_slices: bool,
    center_slices_ratio: float,
    ignore_existing: bool,
):
    """Extract (or load cached) 2.5D features for every image served by `loader`.

    Per-image features are cached as .pt files under `features_root`,
    mirroring the layout of `dataset_root`. Returns three stacked tensors:
    (xy, yz, zx).
    """
    feats_xy, feats_yz, feats_zx = [], [], []
    for idx, batch_data in enumerate(loader, start=1):
        img = batch_data["image"].to(device)
        fn = img.meta["filename_or_obj"][0]
        logger.info(f"[Rank {local_rank}] {tag} data {idx}/{n_total}: {fn}")

        out_fp = Path(fn.replace(dataset_root, features_root).replace(".nii.gz", ".pt"))
        out_fp.parent.mkdir(parents=True, exist_ok=True)

        if (not ignore_existing) and os.path.isfile(out_fp):
            # Reuse the cached features for this image.
            feats = torch.load(out_fp, weights_only=True)
        else:
            img_t = img.as_tensor()
            logger.info(f"image shape: {tuple(img_t.shape)}")
            feats = get_features_2p5d(
                img_t,
                feature_network,
                center_slices=enable_center_slices,
                center_slices_ratio=center_slices_ratio,
                xy_only=False,
            )
            logger.info(f"feats shapes: {feats[0].shape}, {feats[1].shape}, {feats[2].shape}")
            torch.save(feats, out_fp)

        feats_xy.append(feats[0])
        feats_yz.append(feats[1])
        feats_zx.append(feats[2])

    return torch.vstack(feats_xy), torch.vstack(feats_yz), torch.vstack(feats_zx)


def _all_gather_variable_rows(features: list, device, world_size: int) -> list:
    """All-gather feature tensors whose row counts differ across ranks.

    NCCL all_gather requires equal shapes, so each tensor is zero-padded to
    the largest per-rank row count, gathered, then trimmed back to the true
    per-rank sizes. Returns, per feature tensor, a list of per-rank tensors.
    """
    local_sizes = [torch.tensor([ft.shape[0]], dtype=torch.int64, device=device) for ft in features]

    # Exchange per-rank row counts for each feature tensor.
    all_sizes = []
    for local_size in local_sizes:
        rank_sizes = [torch.tensor([0], dtype=torch.int64, device=device) for _ in range(world_size)]
        dist.all_gather(rank_sizes, local_size)
        all_sizes.append(rank_sizes)

    gathered = []
    for ft_idx, ft in enumerate(features):
        max_size = max(all_sizes[ft_idx]).item()
        ft_padded = pad_to_max_size(ft, max_size)

        gather_list = [torch.empty_like(ft_padded) for _ in range(world_size)]
        dist.all_gather(gather_list, ft_padded)

        # Trim the zero padding back off using each rank's true row count.
        for rk in range(world_size):
            gather_list[rk] = gather_list[rk][: all_sizes[ft_idx][rk], :]
        gathered.append(gather_list)
    return gathered


def main(
    real_dataset_root: str = "path/to/datasetA",
    real_filelist: str = "path/to/filelistA.txt",
    real_features_dir: str = "datasetA",
    synth_dataset_root: str = "path/to/datasetB",
    synth_filelist: str = "path/to/filelistB.txt",
    synth_features_dir: str = "datasetB",
    enable_center_slices_ratio: float = None,
    enable_padding: bool = True,
    enable_center_cropping: bool = True,
    enable_resampling_spacing: str = None,
    ignore_existing: bool = False,
    model_name: str = "radimagenet_resnet50",
    num_images: int = 100,
    output_root: str = "./features/features-512x512x512",
    target_shape: str = "512x512x512",
):
    """
    Compute 2.5D FID using distributed GPU processing.

    Loads two datasets (real vs. synthetic) in 3D medical format (NIfTI),
    extracts feature maps via a 2.5D approach, then computes the Frechet
    Inception Distance (FID) across the three orthogonal planes. Data
    parallelism is implemented with torch.distributed (NCCL backend);
    results are logged on rank 0.

    Args:
        real_dataset_root (str): Root folder for the real dataset.
        real_filelist (str): Text file listing 3D images (one relative NIfTI
            path per line, e.g. "case001.nii.gz") for the real dataset;
            entries are appended to `real_dataset_root`.
        real_features_dir (str): Subdirectory (under `output_root`) in which
            to store feature files extracted from the real dataset.
        synth_dataset_root (str): Root folder for the synthetic dataset.
        synth_filelist (str): Text file listing 3D images for the synthetic
            dataset; same format as `real_filelist`.
        synth_features_dir (str): Subdirectory (under `output_root`) in which
            to store feature files extracted from the synthetic dataset.
        enable_center_slices_ratio (float or None): If not None, only slices
            around the specified center ratio are used; if None, no
            center-slice selection is performed.
        enable_padding (bool): Whether to pad images to `target_shape`.
        enable_center_cropping (bool): Whether to center-crop images to
            `target_shape`.
        enable_resampling_spacing (str or None): If not None, resample images
            to this voxel spacing (e.g. "1.0x1.0x1.0"); if None, skip
            resampling.
        ignore_existing (bool): If True, ignore existing .pt feature files
            and force re-extraction.
        model_name (str): Model identifier; "radimagenet_resnet50" or
            anything else for "squeezenet1_1".
        num_images (int): Max number of images to process from each dataset.
        output_root (str): Folder where extracted .pt feature files, logs,
            and results are saved.
        target_shape (str): Target shape as "XxYxZ" for padding, cropping,
            or resampling operations.

    Returns:
        None
    """
    # Initialize the NCCL process group; torchrun supplies env:// rendezvous info.
    dist.init_process_group(backend="nccl", init_method="env://", timeout=timedelta(seconds=7200))

    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(dist.get_world_size())
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)
    logger.info(f"[INFO] Running process on {device} of total {world_size} ranks.")

    # `fire` may pass CLI booleans through as strings; normalize them once.
    enable_padding = _coerce_bool(enable_padding)
    enable_center_cropping = _coerce_bool(enable_center_cropping)
    ignore_existing = _coerce_bool(ignore_existing)

    # Derived flags: a non-None ratio/spacing enables the corresponding step.
    enable_center_slices = enable_center_slices_ratio is not None
    enable_resampling = enable_resampling_spacing is not None

    if local_rank == 0:
        logger.info(f"Real dataset root: {real_dataset_root}")
        logger.info(f"Synth dataset root: {synth_dataset_root}")
        logger.info(f"enable_center_slices_ratio: {enable_center_slices_ratio}")
        logger.info(f"enable_center_slices: {enable_center_slices}")
        logger.info(f"enable_padding: {enable_padding}")
        logger.info(f"enable_center_cropping: {enable_center_cropping}")
        logger.info(f"enable_resampling_spacing: {enable_resampling_spacing}")
        logger.info(f"enable_resampling: {enable_resampling}")
        logger.info(f"ignore_existing: {ignore_existing}")

    # Feature-extraction backbone (pretrained 2D CNN applied slice-wise).
    if model_name == "radimagenet_resnet50":
        feature_network = torch.hub.load(
            "Warvito/radimagenet-models", model="radimagenet_resnet50", verbose=True, trust_repo=True
        )
    else:
        import torchvision

        feature_network = torchvision.models.squeezenet1_1(pretrained=True)
    feature_network.to(device)
    feature_network.eval()

    # Parse "XxYxZ" strings into numeric tuples.
    target_shape_tuple = tuple(int(x) for x in target_shape.split("x"))

    if enable_resampling:
        rs_spacing_tuple = tuple(float(x) for x in enable_resampling_spacing.split("x"))
        if local_rank == 0:
            logger.info(f"Resampling spacing: {rs_spacing_tuple}")
    else:
        rs_spacing_tuple = (1.0, 1.0, 1.0)

    # Ratio of 1.0 means "use all slices" when center-slicing is disabled.
    center_slices_ratio_final = enable_center_slices_ratio if enable_center_slices else 1.0
    if local_rank == 0:
        logger.info(f"center_slices_ratio: {center_slices_ratio_final}")

    # Per-rank partitions of both file lists.
    output_root_real = os.path.join(output_root, real_features_dir)
    real_filenames = _read_file_partition(real_filelist, real_dataset_root, num_images, world_size, local_rank)

    output_root_synth = os.path.join(output_root, synth_features_dir)
    synth_filenames = _read_file_partition(synth_filelist, synth_dataset_root, num_images, world_size, local_rank)

    transforms = _build_transforms(
        enable_resampling, rs_spacing_tuple, enable_padding, enable_center_cropping, target_shape_tuple
    )

    real_ds = monai.data.Dataset(data=real_filenames, transform=transforms)
    real_loader = monai.data.DataLoader(real_ds, num_workers=6, batch_size=1, shuffle=False)

    synth_ds = monai.data.Dataset(data=synth_filenames, transform=transforms)
    synth_loader = monai.data.DataLoader(synth_ds, num_workers=6, batch_size=1, shuffle=False)

    # Extract (or load cached) features for both datasets.
    real_features_xy, real_features_yz, real_features_zx = _extract_dataset_features(
        real_loader,
        len(real_filenames),
        real_dataset_root,
        output_root_real,
        "Real",
        device,
        local_rank,
        feature_network,
        enable_center_slices,
        center_slices_ratio_final,
        ignore_existing,
    )
    logger.info(
        f"Real feature shapes: {real_features_xy.shape}, " f"{real_features_yz.shape}, {real_features_zx.shape}"
    )

    synth_features_xy, synth_features_yz, synth_features_zx = _extract_dataset_features(
        synth_loader,
        len(synth_filenames),
        synth_dataset_root,
        output_root_synth,
        "Synth",
        device,
        local_rank,
        feature_network,
        enable_center_slices,
        center_slices_ratio_final,
        ignore_existing,
    )
    logger.info(
        f"Synth feature shapes: {synth_features_xy.shape}, " f"{synth_features_yz.shape}, {synth_features_zx.shape}"
    )

    features = [
        real_features_xy,
        real_features_yz,
        real_features_zx,
        synth_features_xy,
        synth_features_yz,
        synth_features_zx,
    ]

    # Gather every rank's features (row counts differ per rank).
    all_tensors_list = _all_gather_variable_rows(features, device, world_size)

    # Only rank 0 stacks the gathered features and computes/reports FID.
    if local_rank == 0:
        real_xy = torch.vstack(all_tensors_list[0])
        real_yz = torch.vstack(all_tensors_list[1])
        real_zx = torch.vstack(all_tensors_list[2])

        synth_xy = torch.vstack(all_tensors_list[3])
        synth_yz = torch.vstack(all_tensors_list[4])
        synth_zx = torch.vstack(all_tensors_list[5])

        logger.info(f"Final Real shapes: {real_xy.shape}, {real_yz.shape}, {real_zx.shape}")
        logger.info(f"Final Synth shapes: {synth_xy.shape}, {synth_yz.shape}, {synth_zx.shape}")

        fid = FIDMetric()
        logger.info(f"Computing FID for: {output_root_real} | {output_root_synth}")
        fid_res_xy = fid(synth_xy, real_xy)
        fid_res_yz = fid(synth_yz, real_yz)
        fid_res_zx = fid(synth_zx, real_zx)

        logger.info(f"FID XY: {fid_res_xy}")
        logger.info(f"FID YZ: {fid_res_yz}")
        logger.info(f"FID ZX: {fid_res_zx}")
        fid_avg = (fid_res_xy + fid_res_yz + fid_res_zx) / 3.0
        logger.info(f"FID Avg: {fid_avg}")

    dist.destroy_process_group()
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: python-fire maps `main`'s keyword arguments to command-line flags.
    fire.Fire(main)
|
|
|
|