# Source file: NoMAISI / scripts / compute_fid_2-5d_ct.py
# Repository page metadata: uploaded by ft42 ("Upload 63 files", commit 599a397, verified).
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
# either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
"""
Compute 2.5D FID using distributed GPU processing.
SHELL Usage Example:
-------------------
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6
NUM_GPUS=7
torchrun --nproc_per_node=${NUM_GPUS} compute_fid_2-5d_ct.py \
--model_name "radimagenet_resnet50" \
--real_dataset_root "path/to/datasetA" \
--real_filelist "path/to/filelistA.txt" \
--real_features_dir "datasetA" \
--synth_dataset_root "path/to/datasetB" \
--synth_filelist "path/to/filelistB.txt" \
--synth_features_dir "datasetB" \
--enable_center_slices_ratio 0.4 \
--enable_padding True \
--enable_center_cropping True \
--enable_resampling_spacing "1.0x1.0x1.0" \
--ignore_existing True \
--num_images 100 \
--output_root "./features/features-512x512x512" \
--target_shape "512x512x512"
This script loads two datasets (real vs. synthetic) in 3D medical format (NIfTI)
and extracts feature maps via a 2.5D approach. It then computes the Frechet
Inception Distance (FID) across three orthogonal planes. Data parallelism
is implemented using torch.distributed with an NCCL backend.
Function Arguments (main):
--------------------------
real_dataset_root (str):
Root folder for the real dataset.
real_filelist (str):
Text file listing 3D images for the real dataset.
real_features_dir (str):
Subdirectory (under `output_root`) in which to store feature files
extracted from the real dataset.
synth_dataset_root (str):
Root folder for the synthetic dataset.
synth_filelist (str):
Text file listing 3D images for the synthetic dataset.
synth_features_dir (str):
Subdirectory (under `output_root`) in which to store feature files
extracted from the synthetic dataset.
enable_center_slices_ratio (float or None):
- If not None, only slices around the specified center ratio will be used
(analogous to "enable_center_slices=True" with that ratio).
- If None, no center-slice selection is performed
(analogous to "enable_center_slices=False").
enable_padding (bool):
Whether to pad images to `target_shape`.
enable_center_cropping (bool):
Whether to center-crop images to `target_shape`.
enable_resampling_spacing (str or None):
- If not None, resample images to the specified voxel spacing (e.g. "1.0x1.0x1.0")
(analogous to "enable_resampling=True" with that spacing).
- If None, resampling is skipped
(analogous to "enable_resampling=False").
ignore_existing (bool):
If True, ignore any existing .pt feature files and force re-extraction.
model_name (str):
Model identifier. Typically "radimagenet_resnet50" or "squeezenet1_1".
num_images (int):
Max number of images to process from each dataset (truncate if more are present).
output_root (str):
Folder where extracted .pt feature files, logs, and results are saved.
target_shape (str):
Target shape as "XxYxZ" for padding, cropping, or resampling operations.
"""
from __future__ import annotations
import os
import sys
import torch
import fire
import monai
import re
import torch.distributed as dist
import torch.nn.functional as F
from datetime import timedelta
from pathlib import Path
from monai.metrics.fid import FIDMetric
from monai.transforms import Compose
import logging
# ------------------------------------------------------------------------------
# Create logger
# ------------------------------------------------------------------------------
# Module-level logger shared by the helper functions and main() in this script.
logger = logging.getLogger("fid_2-5d_ct")
if not logger.handlers:
    # Configure logger only if it has no handlers (avoid reconfiguring in multi-rank scenarios)
    # Output goes to stdout at INFO level.
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    logger.setLevel(logging.INFO)
def drop_empty_slice(slices, empty_threshold: float):
    """
    Decide which 2D slices to keep by checking if their maximum intensity
    is below a certain threshold.

    Args:
        slices (tuple or list of Tensors): Each element is (B, C, H, W).
        empty_threshold (float): If the slice's maximum value is below this threshold,
            it is considered "empty".

    Returns:
        list[bool]: One flag per slice, True to keep it and False to drop it.
        An empty `slices` input yields an empty list.
    """
    # The maximum is taken directly; de-duplicating the values first (torch.unique)
    # would sort every slice without changing the result.
    # `not (max < threshold)` keeps slices whose maximum is NaN, like the plain comparison does.
    outputs = [not bool(torch.max(s) < empty_threshold) for s in slices]
    n_drop = outputs.count(False)
    # Guard the rate against an empty input so the log line cannot divide by zero.
    drop_rate = (n_drop / len(outputs)) * 100 if outputs else 0.0
    logger.info(f"Empty slice drop rate {round(drop_rate, 1)}%")
    return outputs
def subtract_mean(x: torch.Tensor) -> torch.Tensor:
    """
    Shift the first three channels of `x` by fixed, ImageNet-like offsets.

    Channel 0 is reduced by 0.406, channel 1 by 0.456 and channel 2 by 0.485.
    Channels are indexed on dim 1, i.e. the layout is (B, C, H, W) or (B, C, H, W, D).

    Note:
        The subtraction is performed in place: the tensor passed in is modified
        and that same tensor object is returned.
    """
    channel_offsets = (0.406, 0.456, 0.485)
    for channel, offset in enumerate(channel_offsets):
        x[:, channel, ...] -= offset
    return x
def spatial_average(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor:
    """
    Collapse every dimension after (B, C) by taking the mean.

    Used to turn a feature map into one feature vector per sample.

    Args:
        x (torch.Tensor): Input tensor. 3D, 4D and 5D inputs are averaged over
            all dimensions from index 2 onwards; any other rank is returned untouched.
        keepdim (bool): If True the averaged dimensions are kept with size 1,
            otherwise they are removed.

    Returns:
        torch.Tensor: The averaged tensor, or `x` itself when no averaging applies.
    """
    rank = x.dim()
    if rank not in (3, 4, 5):
        # 2D inputs are already (B, C); other ranks are passed through unchanged.
        return x
    trailing_axes = list(range(2, rank))
    return x.mean(trailing_axes, keepdim=keepdim)
def medicalnet_intensity_normalisation(volume: torch.Tensor) -> torch.Tensor:
    """
    Standardise intensities the MedicalNet way: (volume - mean) / (std + 1e-5).

    Mean and standard deviation are computed per sample and per channel over the
    spatial dimensions of a (B, C, H, W) or (B, C, H, W, D) tensor. Tensors of any
    other rank are returned unchanged.
    """
    rank = volume.dim()
    if rank not in (4, 5):
        return volume
    spatial_axes = list(range(2, rank))
    centre = volume.mean(spatial_axes, keepdim=True)
    spread = volume.std(spatial_axes, keepdim=True)
    # The small constant keeps the division finite for constant-valued inputs.
    return (volume - centre) / (spread + 1e-5)
def radimagenet_intensity_normalisation(volume: torch.Tensor, norm2d: bool = False) -> torch.Tensor:
    """
    Min-max rescale intensities for radimagenet_resnet, then subtract the channel means.

    Args:
        volume (torch.Tensor): Input (B, C, H, W) or (B, C, H, W, D). Any other rank
            is returned unchanged.
        norm2d (bool): Only honoured for 4D input. If True, the minimum and maximum are
            taken per (H, W) plane, so every plane is scaled on its own; otherwise a
            single minimum/maximum over the whole tensor is used. 5D input always
            uses the whole-tensor range.

    Returns:
        torch.Tensor: The rescaled tensor after `subtract_mean`, or `volume` itself
        for unsupported ranks.
    """
    logger.info(f"norm2d: {norm2d}")
    rank = volume.dim()
    if rank not in (4, 5):
        return volume

    if rank == 4 and norm2d:
        # Reduce over H first, then W, keeping singleton dims for broadcasting.
        upper = volume.max(dim=2, keepdim=True).values.max(dim=3, keepdim=True).values
        lower = volume.min(dim=2, keepdim=True).values.min(dim=3, keepdim=True).values
    else:
        # One global range shared by every element of the tensor.
        upper = torch.max(volume)
        lower = torch.min(volume)

    # The tiny constant protects against a zero range on constant inputs.
    rescaled = (volume - lower) / (upper - lower + 1e-10)
    return subtract_mean(rescaled)
def _extract_plane_features(
    image: torch.Tensor,
    feature_network: torch.nn.Module,
    axis: int,
    center_slices: bool,
    center_slices_ratio: float,
    sample_every_k: int,
    drop_empty: bool,
    empty_threshold: float,
) -> torch.Tensor:
    """Slice `image` along `axis`, run the 2D slices through the network and average spatially."""
    extent = image.shape[axis]
    if center_slices:
        start = int((1.0 - center_slices_ratio) / 2.0 * extent)
        end = int((1.0 + center_slices_ratio) / 2.0 * extent)
    else:
        start, end = 0, extent

    # The stride is applied whether or not the centre window is used.
    selector = [slice(None)] * image.dim()
    selector[axis] = slice(start, end, sample_every_k)
    slices = torch.unbind(image[tuple(selector)], dim=axis)

    # Emptiness is judged on the raw intensities, before normalisation.
    keep = drop_empty_slice(slices, empty_threshold) if drop_empty else None

    images_2d = torch.cat(slices, dim=0)
    images_2d = radimagenet_intensity_normalisation(images_2d)

    if keep is not None:
        # torch.cat stacks the B batch items of slice 0 first, then those of slice 1, ...
        # so each per-slice flag has to be repeated B times to line up with dim 0.
        mask = torch.tensor(keep, dtype=torch.bool, device=images_2d.device)
        images_2d = images_2d[mask.repeat_interleave(image.shape[0])]

    # Call the module rather than .forward() so registered hooks are honoured.
    features = feature_network(images_2d)
    return spatial_average(features, keepdim=False)


def get_features_2p5d(
    image: torch.Tensor,
    feature_network: torch.nn.Module,
    center_slices: bool = False,
    center_slices_ratio: float = 1.0,
    sample_every_k: int = 1,
    xy_only: bool = True,
    drop_empty: bool = False,
    empty_threshold: float = -700,
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
    """
    Extract 2.5D features from a 3D image by slicing it along XY, YZ, ZX planes.

    Args:
        image (torch.Tensor): Input 5D tensor in shape (B, C, H, W, D). A single-channel
            input is replicated to three channels; the channel order is then reversed.
        feature_network (torch.nn.Module): Model that processes batches of 2D slices.
        center_slices (bool): Whether to slice only the center portion of each axis.
        center_slices_ratio (float): Ratio of slices to keep in the center if `center_slices` is True.
        sample_every_k (int): Keep every k-th slice along the slicing axis. Applied
            with and without `center_slices`.
        xy_only (bool): If True, return only the XY-plane features.
        drop_empty (bool): Drop slices that are deemed "empty" below `empty_threshold`.
        empty_threshold (float): Threshold to decide emptiness of slices.

    Returns:
        tuple of torch.Tensor or None: (XY_features, YZ_features, ZX_features). The last
        two entries are None when `xy_only` is True.

    Raises:
        ValueError: If `image` is not a 5D tensor.
    """
    logger.info(f"center_slices: {center_slices}, ratio: {center_slices_ratio}")

    if image.dim() != 5:
        raise ValueError(f"Expected a 5D tensor (B, C, H, W, D), got shape {tuple(image.shape)}")

    # If there's only 1 channel, replicate to 3 channels
    if image.shape[1] == 1:
        image = image.repeat(1, 3, 1, 1, 1)

    # Reverse the channel order: (R, G, B) -> (B, G, R)
    image = image[:, [2, 1, 0], ...]

    shared = {
        "center_slices": center_slices,
        "center_slices_ratio": center_slices_ratio,
        "sample_every_k": sample_every_k,
        "drop_empty": drop_empty,
        "empty_threshold": empty_threshold,
    }

    with torch.no_grad():
        # XY planes: slice along D (dim 4)
        feature_image_xy = _extract_plane_features(image, feature_network, axis=4, **shared)
        if xy_only:
            return feature_image_xy, None, None

        # YZ planes: slice along H (dim 2)
        feature_image_yz = _extract_plane_features(image, feature_network, axis=2, **shared)

        # ZX planes: slice along W (dim 3)
        feature_image_zx = _extract_plane_features(image, feature_network, axis=3, **shared)

    return feature_image_xy, feature_image_yz, feature_image_zx
def pad_to_max_size(tensor: torch.Tensor, max_size: int, padding_value: float = 0.0) -> torch.Tensor:
    """
    Grow a tensor along its first dimension by appending constant-valued rows.

    Args:
        tensor (torch.Tensor): The feature tensor to pad.
        max_size (int): Desired size along the first dimension.
        padding_value (float): Value used for the appended entries.

    Returns:
        torch.Tensor: Tensor whose dim 0 has length `max_size`; all other dims are unchanged.
    """
    missing_rows = max_size - tensor.shape[0]
    # F.pad takes (before, after) pairs starting from the LAST dimension, so the
    # pair for dim 0 comes at the very end of the specification.
    pad_spec = [0, 0] * (tensor.dim() - 1)
    pad_spec.extend([0, missing_rows])
    return F.pad(tensor, pad_spec, "constant", padding_value)
def _parse_bool_flag(flag) -> bool:
    """Return `flag` as-is when it is a bool, otherwise compare its lowercase text with "true"."""
    if isinstance(flag, bool):
        return flag
    return flag.lower() == "true"


def _load_file_entries(filelist: str, dataset_root: str, num_images: int) -> list[dict[str, str]]:
    """Read a file list, sort it, keep the first `num_images` names and join them onto `dataset_root`."""
    with open(filelist, "r") as handle:
        # Blank lines are skipped: they would resolve to the dataset root itself
        # and would also take up slots in the `num_images` budget.
        names = sorted(line.strip() for line in handle if line.strip())
    return [{"image": os.path.join(dataset_root, name)} for name in names[:num_images]]


def _extract_dataset_features(
    loader,
    num_files: int,
    label: str,
    dataset_root: str,
    features_root: str,
    feature_network: torch.nn.Module,
    device: torch.device,
    rank: int,
    ignore_existing: bool,
    center_slices: bool,
    center_slices_ratio: float,
) -> list[torch.Tensor]:
    """Compute (or load cached) XY/YZ/ZX features for every volume in `loader` and stack them per plane."""
    per_plane: tuple[list, list, list] = ([], [], [])
    for idx, batch_data in enumerate(loader, start=1):
        img = batch_data["image"].to(device)
        fn = img.meta["filename_or_obj"][0]
        logger.info(f"[Rank {rank}] {label} data {idx}/{num_files}: {fn}")

        # Mirror the dataset layout under the feature directory, swapping the extension.
        out_fp = Path(fn.replace(dataset_root, features_root).replace(".nii.gz", ".pt"))
        out_fp.parent.mkdir(parents=True, exist_ok=True)

        if (not ignore_existing) and os.path.isfile(out_fp):
            # Cached tensors may have been saved from another GPU; move them to this rank's device.
            feats = torch.load(out_fp, map_location=device, weights_only=True)
        else:
            img_t = img.as_tensor()
            logger.info(f"image shape: {tuple(img_t.shape)}")
            feats = get_features_2p5d(
                img_t,
                feature_network,
                center_slices=center_slices,
                center_slices_ratio=center_slices_ratio,
                xy_only=False,
            )
            logger.info(f"feats shapes: {feats[0].shape}, {feats[1].shape}, {feats[2].shape}")
            torch.save(feats, out_fp)

        for bucket, plane_feats in zip(per_plane, feats):
            bucket.append(plane_feats)

    stacked = [torch.vstack(bucket) for bucket in per_plane]
    logger.info(f"{label} feature shapes: {stacked[0].shape}, {stacked[1].shape}, {stacked[2].shape}")
    return stacked


def _gather_from_all_ranks(local_features: torch.Tensor, world_size: int, device: torch.device) -> list[torch.Tensor]:
    """All-gather a feature matrix whose row count differs per rank; returns one trimmed tensor per rank."""
    # 1) Exchange row counts so every rank knows how much of each buffer is real data.
    local_size = torch.tensor([local_features.shape[0]], dtype=torch.int64, device=device)
    rank_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
    dist.all_gather(rank_sizes, local_size)
    sizes = [int(size.item()) for size in rank_sizes]

    # 2) all_gather needs equal shapes, so pad to the largest row count first.
    padded = pad_to_max_size(local_features, max(sizes))
    gathered = [torch.empty_like(padded) for _ in range(world_size)]
    dist.all_gather(gathered, padded)

    # 3) Strip the padding rows again.
    return [chunk[:size] for chunk, size in zip(gathered, sizes)]


def main(
    real_dataset_root: str = "path/to/datasetA",
    real_filelist: str = "path/to/filelistA.txt",
    real_features_dir: str = "datasetA",
    synth_dataset_root: str = "path/to/datasetB",
    synth_filelist: str = "path/to/filelistB.txt",
    synth_features_dir: str = "datasetB",
    enable_center_slices_ratio: float | None = None,
    enable_padding: bool = True,
    enable_center_cropping: bool = True,
    enable_resampling_spacing: str | None = None,
    ignore_existing: bool = False,
    model_name: str = "radimagenet_resnet50",
    num_images: int = 100,
    output_root: str = "./features/features-512x512x512",
    target_shape: str = "512x512x512",
):
    """
    Compute 2.5D FID using distributed GPU processing.

    This function loads two datasets (real vs. synthetic) in 3D medical format (NIfTI)
    and extracts feature maps via a 2.5D approach, then computes the Frechet Inception
    Distance (FID) across three orthogonal planes. Data parallelism is implemented
    using torch.distributed with an NCCL backend. The GPU is chosen from the
    LOCAL_RANK environment variable, while the data partition and the final FID
    computation are keyed on the global rank.

    Args:
        real_dataset_root (str):
            Root folder for the real dataset.
        real_filelist (str):
            Path to a text file listing 3D images (e.g., NIfTI files) for the real dataset.
            Each line in this file should contain a relative path (or filename) to a NIfTI file.
            For example, your "real_filelist.txt" could look like:
                case001.nii.gz
                case002.nii.gz
                case003.nii.gz
                ...
            These entries will be appended to `real_dataset_root`. Blank lines are ignored.
        real_features_dir (str):
            Name of the directory under `output_root` in which to store
            extracted features for the real dataset.
        synth_dataset_root (str):
            Root folder for the synthetic dataset.
        synth_filelist (str):
            Path to a text file listing 3D images (e.g., NIfTI files) for the synthetic dataset.
            The format is the same as the real dataset file list, for example:
                synth_case001.nii.gz
                synth_case002.nii.gz
                synth_case003.nii.gz
                ...
            These entries will be appended to `synth_dataset_root`. Blank lines are ignored.
        synth_features_dir (str):
            Name of the directory under `output_root` in which to store
            extracted features for the synthetic dataset.
        enable_center_slices_ratio (float or None):
            - If not None, only slices around the specified center ratio are used.
              (similar to "enable_center_slices=True" with that ratio in an earlier script).
            - If None, no center-slice selection is performed
              (similar to "enable_center_slices=False").
        enable_padding (bool):
            Whether to pad images to `target_shape`.
        enable_center_cropping (bool):
            Whether to center-crop images to `target_shape`.
        enable_resampling_spacing (str or None):
            - If not None, resample images to this voxel spacing (e.g. "1.0x1.0x1.0")
              (similar to "enable_resampling=True" with that spacing).
            - If None, skip resampling (similar to "enable_resampling=False").
        ignore_existing (bool):
            If True, ignore any existing .pt feature files and force re-computation.
        model_name (str):
            Model identifier. "radimagenet_resnet50" loads the RadImageNet ResNet-50;
            any other value loads torchvision's squeezenet1_1.
        num_images (int):
            Maximum number of images to load from each dataset (truncate if more are present).
        output_root (str):
            Parent folder where extracted .pt files and logs will be saved.
        target_shape (str):
            Target shape, e.g. "512x512x512", for padding, cropping, or resampling operations.

    Returns:
        None

    Raises:
        ValueError: If either file list yields fewer images than there are ranks.
    """
    # -------------------------------------------------------------------------
    # Initialize Process Group (Distributed)
    # -------------------------------------------------------------------------
    dist.init_process_group(backend="nccl", init_method="env://", timeout=timedelta(seconds=7200))
    # LOCAL_RANK selects the GPU on this node; the global rank identifies the process
    # across all nodes and is what data partitioning and "rank 0" decisions must use.
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(dist.get_rank())
    world_size = int(dist.get_world_size())
    is_main = rank == 0
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)
    logger.info(f"[INFO] Running process on {device} of total {world_size} ranks.")

    # Convert potential string bools to actual bools (if using Fire or similar)
    enable_padding = _parse_bool_flag(enable_padding)
    enable_center_cropping = _parse_bool_flag(enable_center_cropping)
    ignore_existing = _parse_bool_flag(ignore_existing)

    # Merge logic for center slices / resampling: a value means "enabled"
    enable_center_slices = enable_center_slices_ratio is not None
    enable_resampling = enable_resampling_spacing is not None

    # Print out some flags on rank 0
    if is_main:
        logger.info(f"Real dataset root: {real_dataset_root}")
        logger.info(f"Synth dataset root: {synth_dataset_root}")
        logger.info(f"enable_center_slices_ratio: {enable_center_slices_ratio}")
        logger.info(f"enable_center_slices: {enable_center_slices}")
        logger.info(f"enable_padding: {enable_padding}")
        logger.info(f"enable_center_cropping: {enable_center_cropping}")
        logger.info(f"enable_resampling_spacing: {enable_resampling_spacing}")
        logger.info(f"enable_resampling: {enable_resampling}")
        logger.info(f"ignore_existing: {ignore_existing}")

    # -------------------------------------------------------------------------
    # Load feature extraction model
    # -------------------------------------------------------------------------
    if model_name == "radimagenet_resnet50":
        feature_network = torch.hub.load(
            "Warvito/radimagenet-models", model="radimagenet_resnet50", verbose=True, trust_repo=True
        )
    else:
        import torchvision

        feature_network = torchvision.models.squeezenet1_1(pretrained=True)

    feature_network.to(device)
    feature_network.eval()

    # -------------------------------------------------------------------------
    # Parse shape/spacings
    # -------------------------------------------------------------------------
    target_shape_tuple = tuple(int(x) for x in target_shape.split("x"))

    if enable_resampling:
        rs_spacing_tuple = tuple(float(x) for x in enable_resampling_spacing.split("x"))
        if is_main:
            logger.info(f"Resampling spacing: {rs_spacing_tuple}")
    else:
        rs_spacing_tuple = (1.0, 1.0, 1.0)

    # Use the ratio if provided, otherwise 1.0
    center_slices_ratio_final = enable_center_slices_ratio if enable_center_slices else 1.0
    if is_main:
        logger.info(f"center_slices_ratio: {center_slices_ratio_final}")

    # -------------------------------------------------------------------------
    # Prepare Real / Synthetic file lists
    # -------------------------------------------------------------------------
    output_root_real = os.path.join(output_root, real_features_dir)
    output_root_synth = os.path.join(output_root, synth_features_dir)
    real_entries = _load_file_entries(real_filelist, real_dataset_root, num_images)
    synth_entries = _load_file_entries(synth_filelist, synth_dataset_root, num_images)

    # Every rank reads the same lists, so this check fails on all ranks together
    # instead of one rank crashing on an empty partition while the others wait
    # in all_gather until the process-group timeout expires.
    for label, entries in (("real", real_entries), ("synthetic", synth_entries)):
        if len(entries) < world_size:
            raise ValueError(
                f"The {label} file list provides {len(entries)} image(s) but {world_size} ranks were launched; "
                "every rank needs at least one image."
            )

    real_filenames = monai.data.partition_dataset(
        data=real_entries, shuffle=False, num_partitions=world_size, even_divisible=False
    )[rank]
    synth_filenames = monai.data.partition_dataset(
        data=synth_entries, shuffle=False, num_partitions=world_size, even_divisible=False
    )[rank]

    # -------------------------------------------------------------------------
    # Build MONAI transforms
    # -------------------------------------------------------------------------
    transform_list = [
        monai.transforms.LoadImaged(keys=["image"]),
        monai.transforms.EnsureChannelFirstd(keys=["image"]),
        monai.transforms.Orientationd(keys=["image"], axcodes="RAS"),
    ]
    if enable_resampling:
        transform_list.append(monai.transforms.Spacingd(keys=["image"], pixdim=rs_spacing_tuple, mode=["bilinear"]))
    if enable_padding:
        transform_list.append(
            monai.transforms.SpatialPadd(keys=["image"], spatial_size=target_shape_tuple, mode="constant", value=-1000)
        )
    if enable_center_cropping:
        transform_list.append(monai.transforms.CenterSpatialCropd(keys=["image"], roi_size=target_shape_tuple))
    # Input and output ranges are identical, so with clip=True this only clamps to [-1000, 1000].
    transform_list.append(
        monai.transforms.ScaleIntensityRanged(
            keys=["image"], a_min=-1000, a_max=1000, b_min=-1000, b_max=1000, clip=True
        )
    )
    transforms = Compose(transform_list)

    # -------------------------------------------------------------------------
    # Create DataLoaders
    # -------------------------------------------------------------------------
    real_ds = monai.data.Dataset(data=real_filenames, transform=transforms)
    real_loader = monai.data.DataLoader(real_ds, num_workers=6, batch_size=1, shuffle=False)
    synth_ds = monai.data.Dataset(data=synth_filenames, transform=transforms)
    synth_loader = monai.data.DataLoader(synth_ds, num_workers=6, batch_size=1, shuffle=False)

    # -------------------------------------------------------------------------
    # Extract features for Real and Synthetic datasets
    # -------------------------------------------------------------------------
    real_features = _extract_dataset_features(
        real_loader,
        num_files=len(real_filenames),
        label="Real",
        dataset_root=real_dataset_root,
        features_root=output_root_real,
        feature_network=feature_network,
        device=device,
        rank=rank,
        ignore_existing=ignore_existing,
        center_slices=enable_center_slices,
        center_slices_ratio=center_slices_ratio_final,
    )
    synth_features = _extract_dataset_features(
        synth_loader,
        num_files=len(synth_filenames),
        label="Synth",
        dataset_root=synth_dataset_root,
        features_root=output_root_synth,
        feature_network=feature_network,
        device=device,
        rank=rank,
        ignore_existing=ignore_existing,
        center_slices=enable_center_slices,
        center_slices_ratio=center_slices_ratio_final,
    )

    # -------------------------------------------------------------------------
    # Gather features across ranks
    # -------------------------------------------------------------------------
    # Order: real XY, real YZ, real ZX, synth XY, synth YZ, synth ZX.
    # All ranks must issue the collectives in this same order.
    all_tensors_list = [
        _gather_from_all_ranks(features, world_size, device) for features in (*real_features, *synth_features)
    ]

    # On rank 0, compute FID
    if is_main:
        real_xy, real_yz, real_zx, synth_xy, synth_yz, synth_zx = (
            torch.vstack(per_rank) for per_rank in all_tensors_list
        )

        logger.info(f"Final Real shapes: {real_xy.shape}, {real_yz.shape}, {real_zx.shape}")
        logger.info(f"Final Synth shapes: {synth_xy.shape}, {synth_yz.shape}, {synth_zx.shape}")

        fid = FIDMetric()
        logger.info(f"Computing FID for: {output_root_real} | {output_root_synth}")
        fid_res_xy = fid(synth_xy, real_xy)
        fid_res_yz = fid(synth_yz, real_yz)
        fid_res_zx = fid(synth_zx, real_zx)

        logger.info(f"FID XY: {fid_res_xy}")
        logger.info(f"FID YZ: {fid_res_yz}")
        logger.info(f"FID ZX: {fid_res_zx}")
        fid_avg = (fid_res_xy + fid_res_yz + fid_res_zx) / 3.0
        logger.info(f"FID Avg: {fid_avg}")

    dist.destroy_process_group()
# Script entry point: python-fire exposes main() on the command line, so its keyword
# arguments are supplied as `--name value` flags (see the usage example in the module docstring).
if __name__ == "__main__":
    fire.Fire(main)