File size: 3,374 Bytes
1633fcc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import random
from torchvision import transforms
import torch

logger = logging.getLogger("dinov2")



def collate_data_and_cast_aug(
    samples_list,
    mask_ratio,
    mask_probability,
    dtype,
    n_tokens=None,
    mask_first_n=False,
    mask_generator=None,
    clone_batch=1,
):
    """Collate samples into a batch and build the accompanying token masks.

    Args:
        samples_list: Sequence of samples. ``sample[0]`` must be a tensor (all
            of the same shape, since they are stacked) and ``sample[1]`` a
            label that ``torch.LongTensor`` can ingest.
        mask_ratio: Either a two-element tuple/list ``(low, high)``, a plain
            ``int``/``float`` ratio, or any other indexable whose first element
            is the ratio. With ``(low, high)`` the interval is split into one
            sub-interval per generated mask and each mask draws its ratio
            uniformly from its own sub-interval. Otherwise the same ratio is
            used for every generated mask.
        mask_probability: Fraction of the batch that receives a generated
            mask; ``int(batch_size * mask_probability)`` samples are masked.
        dtype: Dtype the stacked crops are cast to.
        n_tokens: Multiplied by the ratio to obtain the integer passed to
            ``mask_generator`` (presumably the number of tokens per sample --
            confirm against the mask generator). Must be set whenever at least
            one mask is generated.
        mask_first_n: When false (and ``mask_probability > 0``) the list of
            masks is shuffled so masked and unmasked entries are interleaved;
            when true the generated masks stay at the front.
        mask_generator: Callable invoked as ``mask_generator(count)`` that also
            exposes ``get_none_mask()``; both results must be convertible by
            ``torch.BoolTensor`` and share one shape.
        clone_batch: Multiplier on the number of generated masks.
            NOTE(review): with ``clone_batch > 1`` the total number of masks is
            ``n_masked * clone_batch + (batch_size - n_masked)``, which differs
            from the number of stacked crops -- presumably the caller
            replicates the crops; verify.

    Returns:
        dict with the keys ``collated_global_crops``, ``collated_global_labels``,
        ``collated_masks`` (bool, one flattened row per mask),
        ``mask_indices_list`` (flat indices of the ``True`` entries),
        ``masks_weight`` (``1 / max(row_sum, 1)`` for every ``True`` entry),
        ``upperbound`` (int), ``n_masked_patches`` (shape ``(1,)`` long tensor)
        and the placeholders ``masks_enc`` / ``masks_pred`` (shape ``(1,)``
        int32 zeros).
    """
    # Only the first element of every sample is treated as a crop.
    n_global_crops = 1

    collated_global_crops = torch.stack([s[i] for i in range(n_global_crops) for s in samples_list])

    labels = torch.LongTensor([s[1] for s in samples_list])
    collated_global_labels = labels.repeat(n_global_crops)

    B = len(collated_global_crops)
    N = n_tokens
    n_samples_masked = int(B * mask_probability)

    masks_list = []
    upperbound = 0

    # Placeholder outputs; this function never fills them with real indices.
    masks_enc = torch.full((1,), 0, dtype=torch.int32)
    masks_pred = torch.full((1,), 0, dtype=torch.int32)

    # Number of generated (non-empty) masks to append.
    number_masks = n_samples_masked * clone_batch

    if isinstance(mask_ratio, (tuple, list)) and len(mask_ratio) == 2:
        # Per-mask ratio: mask i samples uniformly from the i-th sub-interval.
        probs = torch.linspace(*mask_ratio, number_masks + 1)
        for prob_min, prob_max in zip(probs[:-1], probs[1:]):
            masks_list.append(torch.BoolTensor(mask_generator(int(N * random.uniform(prob_min, prob_max)))))
            upperbound += int(N * prob_max)
    else:
        # A bare number is used as-is (indexing it used to raise TypeError);
        # any other container keeps the historical "first element" behaviour.
        ratio = mask_ratio if isinstance(mask_ratio, (int, float)) else mask_ratio[0]
        # Same ratio for every mask. The generator is still called once per
        # mask because it may return a different mask on each call.
        for _ in range(number_masks):
            masks_list.append(torch.BoolTensor(mask_generator(int(N * ratio))))
            upperbound += int(N * ratio)

    # Remaining samples get the generator's "no masking" mask.
    for _ in range(n_samples_masked, B):
        masks_list.append(torch.BoolTensor(mask_generator.get_none_mask()))

    if not mask_first_n and mask_probability > 0.0:  # shuffle masking -- not shuffling for mae-style
        random.shuffle(masks_list)

    collated_masks = torch.stack(masks_list).flatten(1)
    mask_indices_list = collated_masks.flatten().nonzero().flatten()

    # Each masked position is weighted by 1 / (number of masked positions in
    # its row); the clamp avoids a division by zero for rows with no mask.
    masks_weight = (1 / collated_masks.sum(-1).clamp(min=1.0)).unsqueeze(-1).expand_as(collated_masks)[collated_masks]

    return {
        "collated_global_crops": collated_global_crops.to(dtype),
        "collated_global_labels": collated_global_labels,
        "collated_masks": collated_masks,
        "mask_indices_list": mask_indices_list,
        "masks_weight": masks_weight,
        "upperbound": upperbound,
        "n_masked_patches": torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long),
        "masks_enc": masks_enc,
        "masks_pred": masks_pred,
    }