File size: 7,401 Bytes
d670799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
# Copyright (c) OpenMMLab. All rights reserved.
import mmengine.dist as dist
import numpy as np
import torch
import torch.nn.functional as F
from mmengine.logging import MMLogger
from scipy import interpolate


def all_gather_concat(data: torch.Tensor) -> torch.Tensor:
    """Gather tensors whose first-dimension sizes differ across ranks and
    concatenate them into one tensor.

    Each rank's tensor is zero-padded along dim 0 up to the largest size on
    any rank, gathered, interleaved row by row across ranks, and the first
    ``sum(sizes)`` rows are kept.

    Note:
        Only the first dimension may differ between ranks.

    Args:
        data (Tensor): Tensor to be gathered.

    Returns:
        torch.Tensor: The concatenated tensor. With a single process,
        ``data`` itself is returned.
    """
    if dist.get_world_size() == 1:
        return data

    local_size = torch.tensor(data.size(0), device=data.device)
    # Convert to Python ints once so the arithmetic and the final slice do
    # not operate on (possibly device-resident) 0-d tensors.
    sizes = [int(size) for size in dist.all_gather(local_size)]
    total_length = sum(sizes)
    size_diff = max(sizes) - data.size(0)
    if size_diff:
        # all_gather needs identical shapes on every rank.
        padding = data.new_zeros((size_diff, *data.shape[1:]))
        data = torch.cat((data, padding))

    gather_list = dist.all_gather(data)

    # gather all data according to the default DDP sampler. For instance,
    # 8 samples on 2 GPUs, GPU0: [0,2,4,6], GPU1: [1,3,5,7], will be gathered
    # as [0,1,2,3,4,5,6,7]. Stacking on dim 1 yields
    # (max_len, world_size, ...); flattening the first two dims interleaves
    # the ranks row by row without a Python-level loop.
    all_data = torch.stack(gather_list, dim=1).flatten(0, 1)
    return all_data[:total_length]


def interpolate_pos_embed_beit(state_dict, new_model):
    """Adapt the positional embeddings of a loaded state dict to a model.

    The spatial positional information is relative (bias tables), and those
    tables are resized to the new model's patch grid. The temporal positional
    embedding is absolute; when present it is matched to the new model's
    number of frames via ``load_temp_embed_with_mismatch``.

    Args:
        state_dict (dict): The loaded state dict.
        new_model (nn.Module): The freshly created model.

    Returns:
        dict: The state dict with updated positional embeddings.
    """
    target_state = new_model.state_dict()
    embeddings = new_model.vision_encoder.embeddings
    state_dict = interpolate_pos_relative_bias_beit(
        state_dict_old=state_dict,
        state_dict_new=target_state,
        patch_shape_new=embeddings.patch_embeddings.patch_shape,
    )

    # Absolute temporal positional bias.
    temporal_pe_key = 'vision_encoder.embeddings.temporal_position_embeddings'
    if temporal_pe_key not in state_dict:
        return state_dict

    MMLogger.get_current_instance().info(
        f'interpolate temporal positional embeddings: {temporal_pe_key}')
    state_dict[temporal_pe_key] = load_temp_embed_with_mismatch(
        temp_embed_old=state_dict[temporal_pe_key],
        temp_embed_new=target_state[temporal_pe_key],
    )
    return state_dict


def load_temp_embed_with_mismatch(temp_embed_old,
                                  temp_embed_new,
                                  add_zero=True):
    """Add/remove temporal embeddings so the frame count matches the model.

    https://arxiv.org/abs/2104.00650 shows adding zero paddings works.

    Args:
        temp_embed_old (torch.Tensor): Embeddings from the checkpoint,
            shape (1, num_frames_old, 1, d).
        temp_embed_new (torch.Tensor): Embeddings of the new model, shape
            (1, num_frames_new, 1, d). It is never modified in place.
        add_zero (bool): Only used when the new model has more frames. If
            True, the old embeddings fill the leading frames of a copy of
            ``temp_embed_new`` and the remaining frames keep the new model's
            values. If False, the old embeddings are linearly interpolated to
            the new length.

    Returns:
        torch.Tensor: Temporal embeddings with ``num_frames_new`` frames.
    """
    num_frms_new = temp_embed_new.shape[1]
    num_frms_old = temp_embed_old.shape[1]
    logger = MMLogger.get_current_instance()
    logger.info(
        f'Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}')

    if num_frms_new > num_frms_old:
        if not add_zero:
            return interpolate_temporal_pos_embed(temp_embed_old, num_frms_new)
        # Clone first: tensors returned by ``Module.state_dict()`` share
        # storage with the module's parameters, so an in-place write would
        # silently alter the live model.
        # NOTE(review): the extra frames are zeros only if the model
        # zero-initializes its temporal embeddings — TODO confirm, or pad
        # with explicit zeros.
        padded = temp_embed_new.clone()
        padded[:, :num_frms_old] = temp_embed_old
        return padded
    if num_frms_new < num_frms_old:
        return temp_embed_old[:, :num_frms_new]
    return temp_embed_old


def interpolate_temporal_pos_embed(temp_embed_old, num_frames_new):
    """Linearly resample temporal positional embeddings along the frame axis.

    Args:
        temp_embed_old (torch.Tensor): Shape (1, num_frames_old, 1, d).
        num_frames_new (int): Target number of frames.

    Returns:
        torch.Tensor: Shape (1, num_frames_new, 1, d).
    """
    frames_last = temp_embed_old.squeeze(2).permute(0, 2, 1)
    # ``F.interpolate`` resizes the trailing axis: (1, d, num_frames_old)
    # becomes (1, d, num_frames_new).
    resized = F.interpolate(frames_last, size=num_frames_new, mode='linear')
    # Restore the (1, num_frames_new, 1, d) layout.
    return resized.permute(0, 2, 1).unsqueeze(2)


def _geometric_grid_ratio(src_size, dst_size):
    """Bisect for the ratio ``q`` in [1.01, 1.5] at which ``src_size // 2``
    geometrically growing steps (1, q, q**2, ...) span about ``dst_size // 2``.
    """
    left, right = 1.01, 1.5
    while right - left > 1e-6:
        q = (left + right) / 2.0
        # Sum of the geometric series 1 + q + ... + q**(n - 1).
        span = (1.0 - q**(src_size // 2)) / (1.0 - q)
        if span > dst_size // 2:
            right = q
        else:
            left = q
    return q


def _resize_rel_pos_bias_table(rel_pos_bias, src_size, dst_size,
                               num_extra_tokens):
    """Resample one (src_size**2 + extra, heads) bias table to a
    dst_size x dst_size grid, keeping the trailing extra-token rows as-is."""
    num_grid_pos = rel_pos_bias.size(0) - num_extra_tokens
    # Slice with an explicit boundary: ``[-0:]`` / ``[:-0]`` would select
    # everything / nothing when there are no extra tokens.
    extra_tokens = rel_pos_bias[num_grid_pos:, :]
    grid_bias = rel_pos_bias[:num_grid_pos, :]

    q = _geometric_grid_ratio(src_size, dst_size)
    dis = []
    cur = 1
    for i in range(src_size // 2):
        dis.append(cur)
        cur += q**(i + 1)
    # Symmetric source offsets, strictly increasing (q > 1), denser near 0.
    src_axis = np.array([-d for d in reversed(dis)] + [0] + dis)

    t = dst_size // 2.0
    dst_axis = np.arange(-t, t + 0.1, 1.0)

    resized_heads = []
    for head in range(grid_bias.size(1)):
        z = grid_bias[:, head].reshape(src_size, src_size)
        z = z.detach().cpu().float().numpy()
        # Equivalent of the removed ``interp2d(x, y, z, kind='cubic')``:
        # RectBivariateSpline indexes data and results as (x, y) whereas
        # interp2d used (y, x), hence the two transposes.
        spline = interpolate.RectBivariateSpline(
            src_axis, src_axis, z.T, kx=3, ky=3, s=0)
        values = spline(dst_axis, dst_axis).T
        resized_heads.append(
            torch.tensor(values, dtype=torch.float32).reshape(-1, 1).to(
                rel_pos_bias.device))

    grid_bias = torch.cat(resized_heads, dim=-1)
    return torch.cat((grid_bias, extra_tokens), dim=0)


def interpolate_pos_relative_bias_beit(state_dict_old, state_dict_new,
                                       patch_shape_new):
    """Resize BEiT relative position bias tables to a new patch grid.

    Every key containing ``relative_position_index`` is removed from
    ``state_dict_old``. Every ``relative_position_bias_table`` whose spatial
    size differs from the matching entry in ``state_dict_new`` is resampled
    with a bicubic spline over a geometrically spaced source grid. Trailing
    extra-token rows are copied unchanged.

    Args:
        state_dict_old (dict): Loaded state dict. It is modified in place.
        state_dict_new (dict): State dict of the model with the new image
            size. Only the target table sizes are read from it.
        patch_shape_new (tuple[int, int]): Patch grid shape of the new
            model. Only square grids are supported.

    Returns:
        dict: ``state_dict_old`` with the bias tables resized.

    Raises:
        NotImplementedError: If a bias table is present and
            ``patch_shape_new`` is not square.

    ref: https://github.com/microsoft/unilm/blob/master/beit/run_class_finetuning.py  # noqa: E501
    """
    for key in list(state_dict_old.keys()):
        if 'relative_position_index' in key:
            # Presumably the new model rebuilds its own index buffers for
            # its patch grid — TODO confirm.
            state_dict_old.pop(key)
            continue
        if 'relative_position_bias_table' not in key:
            continue

        rel_pos_bias = state_dict_old[key]
        src_num_pos, _ = rel_pos_bias.size()
        dst_num_pos, _ = state_dict_new[key].size()
        if patch_shape_new[0] != patch_shape_new[1]:
            raise NotImplementedError(
                'Only square patch shapes are supported, got '
                f'{patch_shape_new}.')
        # Rows beyond the (2H - 1) * (2W - 1) relative offsets are extra
        # tokens.
        num_extra_tokens = dst_num_pos - (patch_shape_new[0] * 2 - 1) * (
            patch_shape_new[1] * 2 - 1)
        src_size = int((src_num_pos - num_extra_tokens)**0.5)
        dst_size = int((dst_num_pos - num_extra_tokens)**0.5)
        if src_size != dst_size:
            state_dict_old[key] = _resize_rel_pos_bias_table(
                rel_pos_bias, src_size, dst_size, num_extra_tokens)
    return state_dict_old