# DenseLabelDev/third_parts/APE/tools/eva_interpolate_patch_14to16.py
# (mirrored via huggingface_hub upload by zhouyik, revision 032e687)
# --------------------------------------------------------
# EVA: Exploring the Limits of Masked Visual Representation Learning at Scale (https://arxiv.org/abs/2211.07636)
# Github source: https://github.com/baaivision/EVA
# Copyright (c) 2022 Beijing Academy of Artificial Intelligence (BAAI)
# Licensed under The MIT License [see LICENSE for details]
# By Yuxin Fang
# Based on timm, DINO, DeiT and BEiT codebases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import argparse
import torch
def interpolate_pos_embed(checkpoint_model, new_size=16, image_size=224):
    """Resize the ViT position embedding in ``checkpoint_model`` in place.

    If a ``"pos_embed"`` entry is present, its per-patch position tokens are
    bicubically interpolated from the checkpoint's original square grid to the
    grid implied by ``image_size`` and ``new_size``. The single leading extra
    token (class token) is carried over unchanged. If the grids already match,
    the entry is left untouched.

    Args:
        checkpoint_model (dict): State dict, modified in place; may contain
            ``"pos_embed"`` of shape ``(1, 1 + grid*grid, embed_dim)``.
        new_size (int): Target patch size (e.g. 16).
        image_size (int): Target input resolution; should be divisible by
            ``new_size``.
    """
    if "pos_embed" not in checkpoint_model:
        return
    pos_embed_checkpoint = checkpoint_model["pos_embed"]
    print("pos_embed_checkpoint", pos_embed_checkpoint.size(), pos_embed_checkpoint.dtype)
    embedding_size = pos_embed_checkpoint.shape[-1]
    # Floor division: was int(image_size / new_size), which relied on float
    # truncation; // expresses the intent directly.
    num_patches = (image_size // new_size) ** 2
    num_extra_tokens = 1  # class token only
    # height (== width) of the checkpoint's position-embedding grid
    orig_grid = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    # height (== width) of the target grid. Named distinctly instead of
    # shadowing the `new_size` parameter (which means *patch* size).
    new_grid = int(num_patches ** 0.5)
    if orig_grid == new_grid:
        # class_token and position tokens are already the right size
        return
    print(
        "Position interpolate from %dx%d to %dx%d"
        % (orig_grid, orig_grid, new_grid, new_grid)
    )
    # class token is kept unchanged; only the position tokens are interpolated
    extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
    pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
    pos_tokens = pos_tokens.reshape(-1, orig_grid, orig_grid, embedding_size).permute(
        0, 3, 1, 2
    )
    # Bicubic interpolation is not implemented for half precision on CPU, so
    # round-trip through float32 and restore the original dtype afterwards.
    ori_dtype = pos_tokens.dtype
    pos_tokens = pos_tokens.to(torch.float32)
    pos_tokens = torch.nn.functional.interpolate(
        pos_tokens, size=(new_grid, new_grid), mode="bicubic", align_corners=False
    )
    pos_tokens = pos_tokens.to(ori_dtype)
    pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model["pos_embed"] = torch.cat((extra_tokens, pos_tokens), dim=1)
if __name__ == "__main__":
    # CLI entry point: convert an EVA checkpoint trained with 14x14 patches to
    # 16x16 patches by resizing the patch-embedding kernel and pos_embed, then
    # rename all weights under a "backbone.net." prefix for the downstream
    # detector wrapper.
    parser = argparse.ArgumentParser(description="interpolate patch_embed kernel")
    parser.add_argument(
        "--input",
        default="/path/to/eva_psz14.pt",
        type=str,
        metavar="PATH",
        required=True,
        help="path to input EVA checkpoint with patch_embed kernel_size=14x14",
    )
    parser.add_argument(
        "--output",
        default="/path/to/eva_psz14to16.pt",
        type=str,
        metavar="PATH",
        required=True,
        help="path to output EVA checkpoint with patch_embed kernel_size=16x16",
    )
    parser.add_argument("--image_size", type=int, required=True)
    args = parser.parse_args()

    checkpoint = torch.load(args.input, map_location=torch.device("cpu"))
    print(checkpoint.keys())
    # Checkpoints that store weights under "module" (e.g. DeepSpeed-style)
    # are normalized to the "model" key handled below.
    if "module" in checkpoint:
        checkpoint["model"] = checkpoint.pop("module")
    print(checkpoint.keys())
    # interpolate patch_embed
    # Two layouts are supported: a nested {"model": state_dict} checkpoint, or
    # a flat CLIP-style state dict whose visual-tower keys carry a "visual."
    # prefix.
    if "model" in checkpoint:
        patch_embed = checkpoint["model"]["patch_embed.proj.weight"]
    else:
        patch_embed = checkpoint["visual.patch_embed.proj.weight"]
    C_o, C_in, H, W = patch_embed.shape
    # NOTE(review): .float() upcasts the kernel (bicubic is unsupported for
    # fp16 on CPU) and the original dtype is never restored, so the saved
    # kernel ends up in float32 — presumably intentional; confirm.
    patch_embed = torch.nn.functional.interpolate(
        patch_embed.float(), size=(16, 16), mode="bicubic", align_corners=False
    )
    if "model" in checkpoint:
        checkpoint["model"]["patch_embed.proj.weight"] = patch_embed
    else:
        checkpoint["visual.patch_embed.proj.weight"] = patch_embed
    # interpolate pos_embed too
    if "model" in checkpoint:
        interpolate_pos_embed(checkpoint["model"], new_size=16, image_size=args.image_size)
    else:
        # Temporarily alias "visual.pos_embed" to the bare "pos_embed" key the
        # helper expects, then move the result back.
        checkpoint["pos_embed"] = checkpoint["visual.pos_embed"]
        interpolate_pos_embed(checkpoint, new_size=16, image_size=args.image_size)
        checkpoint["visual.pos_embed"] = checkpoint.pop("pos_embed")
    print("======== new state_dict ========")
    if "model" in checkpoint:
        # Iterate over a snapshot (list(...)) because keys are popped and
        # re-inserted with the "backbone.net." prefix during iteration.
        for k, v in list(checkpoint["model"].items()):
            checkpoint["model"]["backbone.net." + k] = checkpoint["model"].pop(k)
            print("rename", k, " ", "backbone.net." + k)
        for k, v in list(checkpoint["model"].items()):
            print(k, " ", v.shape)
    else:
        # Flat CLIP-style dict: drop the text tower and logit scale, keep only
        # the visual encoder renamed under "backbone.net.".
        for k, v in list(checkpoint.items()):
            if k.startswith("text.") or k == "logit_scale":
                checkpoint.pop(k)
                print("pop", k, " ", v.shape)
            if k.startswith("visual."):
                # k[7:] strips the "visual." prefix
                checkpoint["backbone.net." + k[7:]] = checkpoint.pop(k)
                print("rename", k, " ", "backbone.net." + k[7:])
        for k, v in list(checkpoint.items()):
            print(k, " ", v.shape)
    torch.save(checkpoint, args.output)