AZIIIIIIIIZ's picture
Upload 1039 files
d670799 verified
# Copyright (c) OpenMMLab. All rights reserved.
import mmengine.dist as dist
import numpy as np
import torch
import torch.nn.functional as F
from mmengine.logging import MMLogger
from scipy import interpolate
def all_gather_concat(data: torch.Tensor) -> torch.Tensor:
    """Gather tensors with different first-dimension size and concat to one
    tensor.

    Note:
        Only the first dimension should be different.

    Args:
        data (Tensor): Tensor to be gathered.

    Returns:
        torch.Tensor: The concatenated tensor. When the world size is 1 the
        input is returned unchanged.
    """
    if dist.get_world_size() == 1:
        return data

    # Exchange the per-rank lengths first so that every rank can pad its
    # tensor to a common size before the actual gather.
    local_length = data.size(0)
    length_tensor = torch.tensor(local_length, device=data.device)
    lengths = [int(length) for length in dist.all_gather(length_tensor)]
    total_length = sum(lengths)
    max_length = max(lengths)

    pad_length = max_length - local_length
    if pad_length:
        # ``new_zeros`` keeps the dtype and device of ``data``.
        padding = data.new_zeros((pad_length, *data.shape[1:]))
        data = torch.cat((data, padding))

    gather_list = dist.all_gather(data)

    # gather all data according to the default DDP sampler. For instance,
    # 8 samples on 2 GPUs, GPU0: [0,2,4,6], GPU1: [1,3,5,7], will be gathered
    # as [0,1,2,3,4,5,6,7]
    # Stacking along dim 1 and flattening the first two dims interleaves the
    # ranks sample by sample; it also works when every rank holds zero rows.
    interleaved = torch.stack(gather_list, dim=1).flatten(0, 1)

    # Drop the trailing rows that exceed the real number of samples.
    return interleaved[:total_length]
def interpolate_pos_embed_beit(state_dict, new_model):
    """Adapt the positional embeddings of a checkpoint to ``new_model``.

    The relative position bias tables are resized by
    ``interpolate_pos_relative_bias_beit``. If the checkpoint contains the
    absolute temporal position embeddings, their length is then matched to
    the new model by ``load_temp_embed_with_mismatch``.

    Args:
        state_dict (dict): The state_dict.
        new_model (nn.Module): The created model.

    Returns:
        dict: The state_dict with updated positional embeddings.
    """
    # Spatial part: relative position bias.
    target_state = new_model.state_dict()
    patch_embeddings = new_model.vision_encoder.embeddings.patch_embeddings
    state_dict = interpolate_pos_relative_bias_beit(
        state_dict_old=state_dict,
        state_dict_new=target_state,
        patch_shape_new=patch_embeddings.patch_shape,
    )

    # Temporal part: absolute position embeddings, only if present.
    temporal_pe_key = 'vision_encoder.embeddings.temporal_position_embeddings'
    if temporal_pe_key not in state_dict:
        return state_dict

    MMLogger.get_current_instance().info(
        f'interpolate temporal positional embeddings: {temporal_pe_key}')
    old_embed = state_dict[temporal_pe_key]
    new_embed = new_model.state_dict()[temporal_pe_key]
    state_dict[temporal_pe_key] = load_temp_embed_with_mismatch(
        temp_embed_old=old_embed,
        temp_embed_new=new_embed,
    )
    return state_dict
def load_temp_embed_with_mismatch(temp_embed_old,
                                  temp_embed_new,
                                  add_zero=True):
    """Add/Remove extra temporal_embeddings as needed.

    https://arxiv.org/abs/2104.00650 shows adding zero paddings works.

    Args:
        temp_embed_old (Tensor): Loaded embeddings with shape
            (1, num_frames_old, 1, d).
        temp_embed_new (Tensor): Embeddings of the new model with shape
            (1, num_frames_new, 1, d). It is never modified.
        add_zero (bool): Only used when the new model has more frames. If
            True, the first ``num_frames_old`` frames are taken from
            ``temp_embed_old`` and the remaining frames keep the values of
            ``temp_embed_new`` (these are zeros only if the new model
            initialises its temporal embeddings with zeros). If False, the
            trained embeddings are linearly interpolated to the new length.
            Defaults to True.

    Returns:
        Tensor: Embeddings with ``num_frames_new`` frames.
    """
    num_frms_new = temp_embed_new.shape[1]
    num_frms_old = temp_embed_old.shape[1]
    logger = MMLogger.get_current_instance()
    logger.info(
        f'Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}')

    if num_frms_new == num_frms_old:
        return temp_embed_old

    if num_frms_new < num_frms_old:
        # Fewer frames are needed: keep the leading ones.
        return temp_embed_old[:, :num_frms_new]

    if not add_zero:
        return interpolate_temporal_pos_embed(temp_embed_old, num_frms_new)

    # Work on a detached copy so the caller's tensor (typically taken from
    # the live model's state_dict) is not overwritten as a side effect.
    extended = temp_embed_new.detach().clone()
    extended[:, :num_frms_old] = temp_embed_old
    return extended
def interpolate_temporal_pos_embed(temp_embed_old, num_frames_new):
    """Linearly resample temporal position embeddings along the frame axis.

    Args:
        temp_embed_old (Tensor): Embeddings with shape
            (1, num_frames_old, 1, d).
        num_frames_new (int): Number of frames of the output.

    Returns:
        Tensor: Embeddings with shape (1, num_frames_new, 1, d).
    """
    # Drop the singleton axis and move the frame axis last, because
    # ``F.interpolate`` resizes the trailing dimension of a 3-D input.
    frames_last = temp_embed_old.squeeze(2).transpose(1, 2)  # (1, d, T_old)
    resized = F.interpolate(
        frames_last, size=num_frames_new, mode='linear')  # (1, d, T_new)
    # Restore the original (1, T_new, 1, d) layout.
    return resized.transpose(1, 2).unsqueeze(2)
def interpolate_pos_relative_bias_beit(state_dict_old, state_dict_new,
                                       patch_shape_new):
    """Resize the relative position bias tables of a loaded state dict.

    Every key containing ``relative_position_index`` is removed from
    ``state_dict_old``. Every ``relative_position_bias_table`` whose spatial
    size differs from the one in ``state_dict_new`` is resampled with a
    bicubic spline. The source grid is spaced geometrically and the target
    grid uniformly. The trailing extra-token rows are kept unchanged.

    ref: https://github.com/microsoft/unilm/blob/master/beit/run_class_finetuning.py # noqa: E501

    Args:
        state_dict_old (dict): Loaded state dict. It is modified in place.
        state_dict_new (dict): State dict for model with new image size.
        patch_shape_new (Sequence[int]): New model patch_shape. Must be
            square.

    Returns:
        dict: ``state_dict_old`` with the updated tables.

    Raises:
        NotImplementedError: If ``patch_shape_new`` is not square.
    """

    def geometric_progression(a, r, n):
        # Sum of the first ``n`` terms of a geometric series.
        return a * (1.0 - r**n) / (1.0 - r)

    all_keys = list(state_dict_old.keys())
    for key in all_keys:
        if 'relative_position_index' in key:
            state_dict_old.pop(key)
            # The key is gone, so it must not be looked up again below.
            continue

        if 'relative_position_bias_table' not in key:
            continue

        rel_pos_bias = state_dict_old[key]
        src_num_pos, num_attn_heads = rel_pos_bias.size()
        dst_num_pos, _ = state_dict_new[key].size()
        dst_patch_shape = patch_shape_new
        if dst_patch_shape[0] != dst_patch_shape[1]:
            raise NotImplementedError()

        # Rows beyond the (2h - 1) * (2w - 1) spatial offsets are extra
        # tokens that are carried over without interpolation.
        num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (
            dst_patch_shape[1] * 2 - 1)
        src_size = int((src_num_pos - num_extra_tokens)**0.5)
        dst_size = int((dst_num_pos - num_extra_tokens)**0.5)
        if src_size == dst_size:
            continue

        # Split with a positive index: negative slicing (``[-0:]`` and
        # ``[:-0]``) selects the wrong rows when there are no extra tokens.
        num_spatial = src_num_pos - num_extra_tokens
        extra_tokens = rel_pos_bias[num_spatial:, :]
        rel_pos_bias = rel_pos_bias[:num_spatial, :]

        # Bisection for the ratio ``q`` such that ``src_size // 2``
        # geometrically growing steps span ``dst_size // 2``. The search is
        # limited to the interval [1.01, 1.5].
        left, right = 1.01, 1.5
        while right - left > 1e-6:
            q = (left + right) / 2.0
            gp = geometric_progression(1, q, src_size // 2)
            if gp > dst_size // 2:
                right = q
            else:
                left = q

        # Positive half of the source coordinates: 1, 1 + q, 1 + q + q^2, ...
        dis = []
        cur = 1
        for i in range(src_size // 2):
            dis.append(cur)
            cur += q**(i + 1)
        r_ids = [-d for d in reversed(dis)]

        x = np.asarray(r_ids + [0] + dis, dtype=np.float64)
        y = np.asarray(r_ids + [0] + dis, dtype=np.float64)

        t = dst_size // 2.0
        dx = np.arange(-t, t + 0.1, 1.0)
        dy = np.arange(-t, t + 0.1, 1.0)

        all_rel_pos_bias = []
        for i in range(num_attn_heads):
            z = rel_pos_bias[:, i].reshape(
                src_size, src_size).float().cpu().numpy()
            # ``interp2d`` has been removed from SciPy.
            # ``RectBivariateSpline`` is the replacement for regular grids;
            # rows of ``z`` are the y axis, so y is passed first.
            spline = interpolate.RectBivariateSpline(
                y, x, z.astype(np.float64), kx=3, ky=3)
            resized = torch.tensor(spline(dy, dx), dtype=torch.float32)
            all_rel_pos_bias.append(resized.contiguous().view(-1, 1).to(
                rel_pos_bias.device))

        rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
        new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
        state_dict_old[key] = new_rel_pos_bias
    return state_dict_old