Prior2DSM / src /dinov3 /eval /segmentation /inference.py
osherr's picture
Upload 222 files
bc90483 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.
from typing import Callable, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from torchvision.transforms import functional as Fv
import dinov3.distributed as distributed
def precompute_forward_number_for_sliding_inference(
    test_dataloader,
    dataset_len: int,
    eval_crop_size: int,
    eval_stride: int,
):
    """Return the largest number of sliding-window crops required by any image.

    For each image yielded by ``test_dataloader`` the number of square
    ``eval_crop_size`` windows placed every ``eval_stride`` pixels is computed
    (``h_grids * w_grids``). The per-sample counts are written into a tensor
    indexed by the sample index, MAX-reduced across all ranks with
    ``dist.all_reduce``, and the global maximum is returned.

    Args:
        test_dataloader: iterable yielding ``(batch_img, target)`` pairs where
            ``target`` unpacks to ``(index, _)``. ``index.item()`` is called, so
            ``index`` must hold exactly one element per batch.
        dataset_len (int): length of the per-sample count tensor; every
            non-negative sample index must be smaller than this value.
        eval_crop_size (int): side of the square sliding window.
        eval_stride (int): step between consecutive windows (same on both axes).

    Returns:
        int: maximum number of crops (i.e. model forwards) over all samples
        seen by any rank.
    """
    # int32 instead of int8: the crop count of a single large image can easily
    # exceed 127 (e.g. a 4000x6000 image with crop 512 / stride 341 needs 216
    # forwards), which an int8 tensor cannot represent.
    # NOTE(review): the device index is taken from distributed.get_rank();
    # confirm this matches the local device index on multi-node runs.
    image_crop_nums = torch.zeros(dataset_len, device=distributed.get_rank(), dtype=torch.int32)
    print("Computing the number of forwards for sliding window evaluation")
    for batch_img, target in test_dataloader:
        # Dataset is wrapped in DatasetWithEnumeratedTargets
        # and has index information
        index, _ = target
        sample_index = index.item()
        # Only keep samples with non-negative indices
        if sample_index < 0:
            continue
        batch_image_crops = []
        for img in batch_img:
            # Number of crops to create, thus the number of forwards for this image
            h_img, w_img = img.shape[-2:]
            h_grids = max(h_img - eval_crop_size + eval_stride - 1, 0) // eval_stride + 1
            w_grids = max(w_img - eval_crop_size + eval_stride - 1, 0) // eval_stride + 1
            batch_image_crops.append(h_grids * w_grids)
        # Record the count in the tensor shared (after reduction) by all ranks
        image_crop_nums[sample_index] = max(batch_image_crops)
    # Each rank only filled the entries of the samples it iterated over; the
    # MAX reduction merges them so every rank ends up with the same values.
    dist.all_reduce(image_crop_nums, op=dist.ReduceOp.MAX)
    return torch.max(image_crop_nums).item()
def make_inference(
x: torch.Tensor,
segmentation_model: nn.Module,
inference_mode: str = "whole",
decoder_head_type: str = "linear",
rescale_to=(512, 512),
n_output_channels: int = 256,
crop_size: Optional[Tuple[int]] = None,
stride: Optional[Tuple[int]] = None,
apply_horizontal_flip: bool = False,
num_max_forward: int = 1,
output_activation: Callable | None = None,
):
"""Make inference on a given image, and reverts horizontal flip TTA if applicable.
If `inference_mode` = whole, one single prediction is made for the image.
If `inference_mode` = slide, the image is cropped into multiple slices and the latter are
used to make prediction following a sliding window method.
Args:
x (tensor): input image to make inference on.
dense_predictor (nn.Module): model to use for evaluating on dense tasks.
requires a `predict` method.
inference_mode (str, optional): Do inference on the whole image (mode="whole"), or by
adopting a sliding window approach to aggregate the results on
smaller patches of the input image (mode="slide"). Defaults to "whole".
rescale_to (tuple, optional): Resizing the output of the model prediction to the
shape of the ground truth. Defaults to (512, 512).
n_output_channels (int): number of output classes
crop_size (tuple, optional): [h_crop, w_crop]
stride (tuple, optional): [h_stride, w_stride]
apply_horizontal_flip (bool): Determines if horizontal flip TTA was applied for
the prediction. Defaults to False.
output_activation (callable): Output activation to use on top of the predictions.
- softmax is used when each pixel belongs to a single class (multiclass),
- sigmoid is used when pixel can belong to multiple classes (multilabel). Defaults to None (identity).
Returns:
Tensor: The segmentation results created from the input image.
"""
assert inference_mode in ["whole", "slide"]
if inference_mode == "slide":
# crop size and stride are needed for sliding inference
assert crop_size is not None
assert stride is not None
pred = F.interpolate(
slide_inference(
x,
segmentation_model,
decoder_head_type,
n_output_channels=n_output_channels,
crop_size=crop_size,
stride=stride,
num_max_forward=num_max_forward,
),
size=rescale_to,
mode="bilinear",
align_corners=False,
)
else:
pred = segmentation_model.predict(
F.interpolate(
x,
size=(512, 512),
mode="bilinear",
align_corners=False,
),
rescale_to=rescale_to,
)
if decoder_head_type == "m2f":
mask_pred, mask_cls = pred["pred_masks"], pred["pred_logits"]
mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
mask_pred = mask_pred.sigmoid()
pred = torch.einsum("bqc,bqhw->bchw", mask_cls.to(torch.float), mask_pred.to(torch.float))
if apply_horizontal_flip:
pred = Fv.hflip(pred)
if output_activation:
pred = output_activation(pred)
return pred
def slide_inference(
    inputs: torch.Tensor,
    segmentation_model: nn.Module,
    decoder_head_type: str = "linear",
    n_output_channels: int = 256,
    crop_size: Tuple = (512, 512),
    stride: Tuple = (341, 341),
    num_max_forward: int = 1,
):
    """Inference by sliding-window with overlap.

    The image is covered by a grid of windows; each window is passed to
    ``segmentation_model.predict`` and the predictions are summed on CPU, then
    divided by the number of windows covering each pixel.

    If h_crop > h_img or w_crop > w_img, the small patch will be used to
    decode without padding. When the crop exceeds the image on both axes, a
    square crop of side ``min(h_img, w_img)`` is used instead.

    Args:
        inputs (tensor): the tensor should have a shape NxCxHxW,
            which contains all images in the batch. Only N == 1 is supported.
        segmentation_model (nn.Module): model to use for evaluating on dense
            tasks. Requires a `predict(image, rescale_to=...)` method.
        decoder_head_type (str): when "m2f", each crop prediction is a dict
            with "pred_masks" and "pred_logits" that is converted to per-class
            maps (computed in bfloat16) before accumulation.
        n_output_channels (int): number of output channels
        crop_size (tuple): (h_crop, w_crop)
        stride (tuple): (h_stride, w_stride)
        num_max_forward (int): if larger than the number of windows, extra
            forwards on all-zero crops are run (results discarded) until this
            many forwards have been made.

    Returns:
        Tensor: CPU tensor of shape (1, n_output_channels, H, W) holding the
        averaged prediction for the input image.
    """
    h_stride, w_stride = stride
    h_crop, w_crop = crop_size
    batch_size, C, h_img, w_img = inputs.shape
    if h_crop > h_img and w_crop > w_img:  # Meaning we are doing < 1.0 TTA
        h_crop = w_crop = min(h_img, w_img)
    assert batch_size == 1  # As of now, the code assumes that a single image is passed at a time at inference time
    h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
    w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1

    # Accumulators are created directly on CPU (no transient device copy).
    preds = torch.zeros((1, n_output_channels, h_img, w_img), dtype=inputs.dtype, device="cpu")
    # int32 instead of int8: with a stride much smaller than the crop a pixel
    # can be covered by more than 127 windows, which would silently wrap
    # around in int8 and corrupt the average.
    count_mat = torch.zeros((1, 1, h_img, w_img), dtype=torch.int32, device="cpu")

    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            # Clamp the window to the image, then shift its origin back so the
            # last window of each row/column keeps the full crop size.
            y2 = min(h_idx * h_stride + h_crop, h_img)
            x2 = min(w_idx * w_stride + w_crop, w_img)
            y1 = max(y2 - h_crop, 0)
            x1 = max(x2 - w_crop, 0)
            crop_img = inputs[:, :, y1:y2, x1:x2]
            crop_pred = segmentation_model.predict(crop_img, rescale_to=crop_img.shape[2:])
            if decoder_head_type == "m2f":
                mask_pred, mask_cls = crop_pred["pred_masks"], crop_pred["pred_logits"]
                mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
                mask_pred = mask_pred.sigmoid()
                crop_pred = torch.einsum("bqc,bqhw->bchw", mask_cls.to(torch.bfloat16), mask_pred.to(torch.bfloat16))
                del mask_cls, mask_pred
            # Add the crop prediction into its window only; this is equivalent
            # to zero-padding it to the full image but avoids allocating a
            # full-size tensor for every crop.
            preds[:, :, y1:y2, x1:x2] += crop_pred.cpu()
            count_mat[:, :, y1:y2, x1:x2] += 1
            del crop_img, crop_pred

    # Optional buffer to ensure each gpu does the same number of operations for sharded models
    for _ in range(h_grids * w_grids, num_max_forward):
        dummy_input = inputs.new_zeros((1, C, h_crop, w_crop))
        segmentation_model.predict(dummy_input, rescale_to=dummy_input.shape[2:])

    # Every pixel must have been covered by at least one window.
    assert (count_mat == 0).sum() == 0
    return preds / count_mat