File size: 8,303 Bytes
bc90483
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.

from typing import Callable, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from torchvision.transforms import functional as Fv

import dinov3.distributed as distributed


def precompute_forward_number_for_sliding_inference(
    test_dataloader,
    dataset_len: int,
    eval_crop_size: int,
    eval_stride: int,
):
    """Compute the largest number of sliding-window forwards any sample needs.

    Every batch yielded by `test_dataloader` is inspected, the number of
    square crops required to tile each image is computed, and the per-sample
    maximum is stored in a tensor of length `dataset_len`. That tensor is then
    max-reduced across all processes and its global maximum is returned.

    Args:
        test_dataloader: iterable yielding `(batch_img, target)` pairs, where
            `target` unpacks into `(index, _)` and `index.item()` gives the
            sample position in the dataset (this implies one sample per batch).
            Samples with a negative index are skipped.
        dataset_len (int): length of the per-sample counter tensor.
        eval_crop_size (int): side length of the square crop window.
        eval_stride (int): stride of the window along both axes.

    Returns:
        int: the maximum crop count over all samples and all processes.
    """
    # int32 instead of int8: a large image with a small stride easily needs
    # more than 127 crops, which an int8 counter cannot represent.
    # NOTE(review): the value of `distributed.get_rank()` is used as the device
    # index; on multi-node jobs this presumably should be the local rank - confirm.
    image_crop_nums = torch.zeros(dataset_len, device=distributed.get_rank(), dtype=torch.int32)
    print("Computing the number of forwards for sliding window evaluation")
    for batch_img, target in test_dataloader:
        # Dataset is wrapped in DatasetWithEnumeratedTargets
        # and has index information
        index, _ = target
        sample_index = index.item()
        # Only keep samples with non-negative indices
        if sample_index < 0:
            continue
        batch_image_crops = []
        for img in batch_img:
            # Number of window positions per axis, i.e. the number of forwards
            # sliding inference performs for this image.
            h_img, w_img = img.shape[-2:]
            h_grids = max(h_img - eval_crop_size + eval_stride - 1, 0) // eval_stride + 1
            w_grids = max(w_img - eval_crop_size + eval_stride - 1, 0) // eval_stride + 1
            batch_image_crops.append(h_grids * w_grids)
        # Record the result in the tensor that is shared across processes below
        image_crop_nums[sample_index] = max(batch_image_crops)
    # Each process only saw its own samples; merge with an element-wise max
    dist.all_reduce(image_crop_nums, op=dist.ReduceOp.MAX)
    return torch.max(image_crop_nums).item()


def make_inference(

    x: torch.Tensor,

    segmentation_model: nn.Module,

    inference_mode: str = "whole",

    decoder_head_type: str = "linear",

    rescale_to=(512, 512),

    n_output_channels: int = 256,

    crop_size: Optional[Tuple[int]] = None,

    stride: Optional[Tuple[int]] = None,

    apply_horizontal_flip: bool = False,

    num_max_forward: int = 1,

    output_activation: Callable | None = None,

):
    """Make inference on a given image, and reverts horizontal flip TTA if applicable.

    If `inference_mode` = whole, one single prediction is made for the image.

    If `inference_mode` = slide, the image is cropped into multiple slices and the latter are

    used to make prediction following a sliding window method.



    Args:

        x (tensor): input image to make inference on.

        dense_predictor (nn.Module): model to use for evaluating on dense tasks.

            requires a `predict` method.

        inference_mode (str, optional): Do inference on the whole image (mode="whole"), or by

            adopting a sliding window approach to aggregate the results on

            smaller patches of the input image (mode="slide"). Defaults to "whole".

        rescale_to (tuple, optional): Resizing the output of the model prediction to the

            shape of the ground truth. Defaults to (512, 512).

        n_output_channels (int): number of output classes

        crop_size (tuple, optional): [h_crop, w_crop]

        stride (tuple, optional): [h_stride, w_stride]

        apply_horizontal_flip (bool): Determines if horizontal flip TTA was applied for

            the prediction. Defaults to False.

        output_activation (callable): Output activation to use on top of the predictions.

            - softmax is used when each pixel belongs to a single class (multiclass),

            - sigmoid is used when pixel can belong to multiple classes (multilabel). Defaults to None (identity).

    Returns:

        Tensor: The segmentation results created from the input image.

    """
    assert inference_mode in ["whole", "slide"]
    if inference_mode == "slide":
        # crop size and stride are needed for sliding inference
        assert crop_size is not None
        assert stride is not None
        pred = F.interpolate(
            slide_inference(
                x,
                segmentation_model,
                decoder_head_type,
                n_output_channels=n_output_channels,
                crop_size=crop_size,
                stride=stride,
                num_max_forward=num_max_forward,
            ),
            size=rescale_to,
            mode="bilinear",
            align_corners=False,
        )
    else:
        pred = segmentation_model.predict(
            F.interpolate(
                x,
                size=(512, 512),
                mode="bilinear",
                align_corners=False,
            ),
            rescale_to=rescale_to,
        )
        if decoder_head_type == "m2f":
            mask_pred, mask_cls = pred["pred_masks"], pred["pred_logits"]
            mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
            mask_pred = mask_pred.sigmoid()
            pred = torch.einsum("bqc,bqhw->bchw", mask_cls.to(torch.float), mask_pred.to(torch.float))
    if apply_horizontal_flip:
        pred = Fv.hflip(pred)
    if output_activation:
        pred = output_activation(pred)
    return pred


def slide_inference(
    inputs: torch.Tensor,
    segmentation_model: nn.Module,
    decoder_head_type: str = "linear",
    n_output_channels: int = 256,
    crop_size: Tuple = (512, 512),
    stride: Tuple = (341, 341),
    num_max_forward: int = 1,
):
    """Inference by sliding-window with overlap.

    The image is covered by windows of size `crop_size` placed every `stride`
    pixels; windows that would run past the border are shifted back inside the
    image. Predictions are summed on CPU and divided by the number of windows
    covering each pixel.

    If both h_crop > h_img and w_crop > w_img, the window is shrunk to a square
    of side min(h_img, w_img). If only one side is larger than the image, the
    window is simply clipped to the image along that axis.

    Args:
        inputs (tensor): the tensor should have a shape NxCxHxW; only N == 1
            is supported.
        segmentation_model (nn.Module): model to use for evaluating on dense
            tasks; must expose `predict(tensor, rescale_to=...)`.
        decoder_head_type (str): when equal to "m2f", each crop prediction is
            read as a dict with "pred_masks" and "pred_logits" entries and
            combined into a per-class map.
        n_output_channels (int): number of output channels
        crop_size (tuple): (h_crop, w_crop)
        stride (tuple): (h_stride, w_stride)
        num_max_forward (int): if larger than the number of windows, extra
            forwards on an all-zero crop are run so that the total number of
            `predict` calls reaches this value.

    Returns:
        Tensor: CPU tensor of shape (1, n_output_channels, H, W) holding the
        averaged predictions.
    """
    h_stride, w_stride = stride
    h_crop, w_crop = crop_size
    batch_size, C, h_img, w_img = inputs.shape
    if h_crop > h_img and w_crop > w_img:  # Meaning we are doing < 1.0 TTA
        h_crop = w_crop = min(h_img, w_img)
    assert batch_size == 1  # As of now, the code assumes that a single image is passed at a time at inference time
    h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
    w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1

    # Accumulators are created directly on CPU (same dtype as the input for the
    # sums) instead of being allocated on the input device and copied over.
    preds = inputs.new_zeros((1, n_output_channels, h_img, w_img), device="cpu")
    # int32 instead of int8: with a stride much smaller than the crop, a pixel
    # can be covered by more than 127 windows, which would wrap an int8 counter
    # and silently corrupt the average.
    count_mat = torch.zeros((1, 1, h_img, w_img), dtype=torch.int32, device="cpu")

    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            # Clamp the window to the image, then shift it back so that it
            # keeps its full size whenever the image is large enough.
            y2 = min(h_idx * h_stride + h_crop, h_img)
            x2 = min(w_idx * w_stride + w_crop, w_img)
            y1 = max(y2 - h_crop, 0)
            x1 = max(x2 - w_crop, 0)
            crop_img = inputs[:, :, y1:y2, x1:x2]
            crop_pred = segmentation_model.predict(crop_img, rescale_to=crop_img.shape[2:])
            if decoder_head_type == "m2f":
                mask_pred, mask_cls = crop_pred["pred_masks"], crop_pred["pred_logits"]
                mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
                mask_pred = mask_pred.sigmoid()
                crop_pred = torch.einsum("bqc,bqhw->bchw", mask_cls.to(torch.bfloat16), mask_pred.to(torch.bfloat16))
                del mask_cls, mask_pred
            # Add the crop into its own region only; this yields the same sums
            # as zero-padding the crop to full image size, without building and
            # transferring a full-size tensor for every window.
            preds[:, :, y1:y2, x1:x2] += crop_pred.cpu()
            count_mat[:, :, y1:y2, x1:x2] += 1
            del crop_img, crop_pred

    # Optional buffer to ensure each gpu does the same number of operations for sharded models
    for _ in range(h_grids * w_grids, num_max_forward):
        dummy_input = inputs.new_zeros((1, C, h_crop, w_crop))
        segmentation_model.predict(dummy_input, rescale_to=dummy_input.shape[2:])

    # Every pixel must have been covered by at least one window
    assert (count_mat == 0).sum() == 0
    return preds / count_mat