# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.

from typing import Callable, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from torchvision.transforms import functional as Fv

import dinov3.distributed as distributed


def _sliding_window_layout(
    h_img: int,
    w_img: int,
    crop_size: Tuple[int, int],
    stride: Tuple[int, int],
) -> Tuple[int, int, int, int]:
    """Return `(h_crop, w_crop, h_grids, w_grids)` for a sliding-window pass.

    When the image is smaller than the crop in BOTH dimensions, the crop is
    shrunk to a square whose side is the shortest image side. The number of
    windows per axis is then the smallest count that covers the image with
    the given stride (always at least 1).
    """
    h_crop, w_crop = crop_size
    h_stride, w_stride = stride
    if h_crop > h_img and w_crop > w_img:
        # Meaning we are doing < 1.0 TTA
        h_crop = w_crop = min(h_img, w_img)
    h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
    w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
    return h_crop, w_crop, h_grids, w_grids


def _m2f_semantic_scores(outputs, dtype: torch.dtype) -> torch.Tensor:
    """Combine Mask2Former-style outputs into a per-class score map.

    `outputs["pred_logits"]` is soft-maxed over its last dimension and the
    last entry is dropped; `outputs["pred_masks"]` goes through a sigmoid.
    Both are cast to `dtype` and contracted over the query dimension
    (`"bqc,bqhw->bchw"`).
    """
    mask_pred, mask_cls = outputs["pred_masks"], outputs["pred_logits"]
    mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
    mask_pred = mask_pred.sigmoid()
    return torch.einsum("bqc,bqhw->bchw", mask_cls.to(dtype), mask_pred.to(dtype))


def precompute_forward_number_for_sliding_inference(
    test_dataloader,
    dataset_len: int,
    eval_crop_size: int,
    eval_stride: int,
) -> int:
    """Return the largest number of sliding-window forwards needed by any sample.

    The count is computed per image with the same window layout that
    `slide_inference` uses, then max-reduced across all ranks, so the result
    can be passed as `num_max_forward` to keep every rank doing the same
    number of forward passes.

    Args:
        test_dataloader: iterable yielding `(batch_img, target)` pairs, where
            `target` unpacks to `(index, _)`, `index` holds a single value
            (read with `.item()`), and `batch_img` is an iterable of image
            tensors whose last two dimensions are height and width.
        dataset_len (int): number of slots in the per-sample counter; every
            non-negative `index` must be smaller than this.
        eval_crop_size (int): side of the square crop.
        eval_stride (int): stride used for both axes.

    Returns:
        int: maximum `h_grids * w_grids` over all images and all ranks.
    """
    # int32 rather than int8: an int8 counter cannot represent more than 127
    # windows per image.
    # NOTE(review): the rank is used directly as the device index; presumably
    # this assumes one process per GPU on a single node - confirm for
    # multi-node runs.
    image_crop_nums = torch.zeros(dataset_len, device=distributed.get_rank(), dtype=torch.int32)
    print("Computing the number of forwards for sliding window evaluation")
    crop_size = (eval_crop_size, eval_crop_size)
    stride = (eval_stride, eval_stride)
    for batch_img, target in test_dataloader:
        # Dataset is wrapped in DatasetWithEnumeratedTargets
        # and has index information
        index, _ = target
        sample_idx = index.item()
        # Only keep samples with non-negative indices
        if sample_idx < 0:
            continue
        batch_image_crops = []
        for img in batch_img:
            # Number of crops to create, i.e. the number of forwards for this image
            h_img, w_img = img.shape[-2:]
            _, _, h_grids, w_grids = _sliding_window_layout(h_img, w_img, crop_size, stride)
            batch_image_crops.append(h_grids * w_grids)
        # add information to the global tensor
        image_crop_nums[sample_idx] = max(batch_image_crops)
    # Must run exactly once per rank, after the loop: ranks may see a
    # different number of batches.
    dist.all_reduce(image_crop_nums, op=dist.ReduceOp.MAX)
    return torch.max(image_crop_nums).item()


def make_inference(
    x: torch.Tensor,
    segmentation_model: nn.Module,
    inference_mode: str = "whole",
    decoder_head_type: str = "linear",
    rescale_to=(512, 512),
    n_output_channels: int = 256,
    crop_size: Optional[Tuple[int, int]] = None,
    stride: Optional[Tuple[int, int]] = None,
    apply_horizontal_flip: bool = False,
    num_max_forward: int = 1,
    output_activation: Callable | None = None,
):
    """Run inference on a given image and revert horizontal-flip TTA if applicable.

    If `inference_mode` = whole, one single prediction is made for the image
    (after resizing it to 512x512).
    If `inference_mode` = slide, the image is cut into overlapping crops that
    are predicted one by one and averaged (see `slide_inference`); the result
    is then bilinearly resized to `rescale_to`.

    Args:
        x (tensor): input image to make inference on.
        segmentation_model (nn.Module): model to use for evaluating on dense
            tasks. Requires a `predict(image, rescale_to=...)` method.
        inference_mode (str, optional): Do inference on the whole image
            (mode="whole"), or by adopting a sliding window approach to
            aggregate the results on smaller patches of the input image
            (mode="slide"). Defaults to "whole".
        decoder_head_type (str, optional): when "m2f", the model output is a
            dict with "pred_masks" and "pred_logits" that is converted to a
            per-class score map; any other value uses the output as is.
            Defaults to "linear".
        rescale_to (tuple, optional): Resizing the output of the model
            prediction to the shape of the ground truth. Defaults to (512, 512).
        n_output_channels (int): number of output classes. Only used in
            "slide" mode.
        crop_size (tuple, optional): [h_crop, w_crop]. Required in "slide" mode.
        stride (tuple, optional): [h_stride, w_stride]. Required in "slide" mode.
        apply_horizontal_flip (bool): Determines if horizontal flip TTA was
            applied for the prediction; if so the prediction is flipped back.
            Defaults to False.
        num_max_forward (int): forwarded to `slide_inference`, which pads the
            number of forward passes up to this value. Only used in "slide"
            mode. Defaults to 1.
        output_activation (callable): Output activation to use on top of the
            predictions.
            - softmax is used when each pixel belongs to a single class (multiclass),
            - sigmoid is used when pixel can belong to multiple classes (multilabel).
            Defaults to None (identity).

    Returns:
        Tensor: The segmentation results created from the input image.
    """
    assert inference_mode in ["whole", "slide"]
    if inference_mode == "slide":
        # crop size and stride are needed for sliding inference
        assert crop_size is not None
        assert stride is not None
        # slide_inference already converts "m2f" outputs to a score map.
        pred = F.interpolate(
            slide_inference(
                x,
                segmentation_model,
                decoder_head_type,
                n_output_channels=n_output_channels,
                crop_size=crop_size,
                stride=stride,
                num_max_forward=num_max_forward,
            ),
            size=rescale_to,
            mode="bilinear",
            align_corners=False,
        )
    else:
        pred = segmentation_model.predict(
            F.interpolate(
                x,
                size=(512, 512),
                mode="bilinear",
                align_corners=False,
            ),
            rescale_to=rescale_to,
        )
        if decoder_head_type == "m2f":
            pred = _m2f_semantic_scores(pred, torch.float)
    if apply_horizontal_flip:
        pred = Fv.hflip(pred)
    if output_activation:
        pred = output_activation(pred)
    return pred


def slide_inference(
    inputs: torch.Tensor,
    segmentation_model: nn.Module,
    decoder_head_type: str = "linear",
    n_output_channels: int = 256,
    crop_size: Tuple = (512, 512),
    stride: Tuple = (341, 341),
    num_max_forward: int = 1,
):
    """Inference by sliding-window with overlap.

    Each window is predicted separately; predictions are summed on the CPU
    and divided by the number of windows covering each pixel.
    If the image is smaller than the crop in both dimensions, the crop is
    shrunk to a square of the shortest image side; windows that would run
    past the image border are shifted back inside, so nothing is padded.

    Args:
        inputs (tensor): the tensor should have a shape NxCxHxW,
            which contains all images in the batch. N must be 1.
        segmentation_model (nn.Module): model to use for evaluating on dense
            tasks. Requires a `predict(image, rescale_to=...)` method.
        decoder_head_type (str): when "m2f", each crop output is a dict with
            "pred_masks" and "pred_logits" that is converted to a per-class
            score map (in bfloat16) before accumulation.
        n_output_channels (int): number of output channels
        crop_size (tuple): (h_crop, w_crop)
        stride (tuple): (h_stride, w_stride)
        num_max_forward (int): if larger than the number of windows, extra
            forward passes on an all-zero crop are run (and discarded) until
            this many forwards have been done.

    Returns:
        Tensor: The output results from model of each input image, as a CPU
        tensor of shape (1, n_output_channels, H, W).
    """
    h_stride, w_stride = stride
    batch_size, C, h_img, w_img = inputs.shape
    h_crop, w_crop, h_grids, w_grids = _sliding_window_layout(h_img, w_img, crop_size, stride)
    # As of now, the code assumes that a single image is passed at a time at inference time
    assert batch_size == 1
    # Accumulators are created directly on the CPU. The counter is int32:
    # with dense strides a pixel can be covered by more than 127 windows,
    # which an int8 counter cannot represent.
    preds = inputs.new_zeros((1, n_output_channels, h_img, w_img), device="cpu")
    count_mat = torch.zeros((1, 1, h_img, w_img), dtype=torch.int32, device="cpu")
    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            # Clamp the window to the image, then shift its origin back so the
            # window keeps its full size whenever the image is large enough.
            y2 = min(h_idx * h_stride + h_crop, h_img)
            x2 = min(w_idx * w_stride + w_crop, w_img)
            y1 = max(y2 - h_crop, 0)
            x1 = max(x2 - w_crop, 0)
            crop_img = inputs[:, :, y1:y2, x1:x2]
            crop_pred = segmentation_model.predict(crop_img, rescale_to=crop_img.shape[2:])
            if decoder_head_type == "m2f":
                crop_pred = _m2f_semantic_scores(crop_pred, torch.bfloat16)
            # Add the crop into its own region only; this matches zero-padding
            # it to full size first, without moving the padding to the host.
            preds[:, :, y1:y2, x1:x2] += crop_pred.cpu()
            count_mat[:, :, y1:y2, x1:x2] += 1
            del crop_img, crop_pred
    # Optional buffer to ensure each gpu does the same number of operations for sharded models
    for _ in range(h_grids * w_grids, num_max_forward):
        dummy_input = inputs.new_zeros((1, C, h_crop, w_crop))
        _ = segmentation_model.predict(dummy_input, rescale_to=dummy_input.shape[2:])
    # Every pixel must have been covered by at least one window.
    assert (count_mat == 0).sum() == 0
    return preds / count_mat