| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import argparse |
| |
|
| | import torch |
| |
|
| |
|
def interpolate_pos_embed(checkpoint_model, new_size=16, image_size=224):
    """Resize a ViT position embedding in-place to match a new patch size.

    Looks up ``checkpoint_model["pos_embed"]`` (expected layout:
    ``[1, 1 + grid*grid, dim]`` with a single leading extra/CLS token),
    bicubic-interpolates the spatial grid from its current side length to
    ``image_size // new_size``, and writes the result back under the same key.
    No-op when the key is absent or the grid size already matches.

    Args:
        checkpoint_model: state_dict (or sub-dict) to mutate in place.
        new_size: target patch size in pixels (despite the name).
        image_size: input image side length in pixels.
    """
    if "pos_embed" not in checkpoint_model:
        return
    pos_embed = checkpoint_model["pos_embed"]
    print("pos_embed_checkpoint", pos_embed.size(), pos_embed.dtype)

    dim = pos_embed.shape[-1]
    n_extra = 1  # single CLS token ahead of the spatial grid
    src_grid = int((pos_embed.shape[-2] - n_extra) ** 0.5)
    dst_grid = int((int(image_size / new_size) ** 2) ** 0.5)
    if src_grid == dst_grid:
        return

    print(f"Position interpolate from {src_grid}x{src_grid} to {dst_grid}x{dst_grid}")
    extra_tokens = pos_embed[:, :n_extra]
    # [1, N, dim] -> [1, dim, grid, grid] for 2-D interpolation.
    grid = pos_embed[:, n_extra:].reshape(-1, src_grid, src_grid, dim).permute(0, 3, 1, 2)
    # Interpolate in float32 (bicubic may be unsupported / lossy in half
    # precision), then cast back to the checkpoint's dtype.
    src_dtype = grid.dtype
    grid = torch.nn.functional.interpolate(
        grid.to(torch.float32),
        size=(dst_grid, dst_grid),
        mode="bicubic",
        align_corners=False,
    ).to(src_dtype)
    # Back to token layout and re-attach the extra token(s).
    grid = grid.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model["pos_embed"] = torch.cat((extra_tokens, grid), dim=1)
| |
|
| |
|
if __name__ == "__main__":
    # CLI tool: convert an EVA checkpoint trained with 14x14 patch_embed
    # kernels into one usable with 16x16 patches, by bicubic-resizing the
    # patch-embedding kernel and the position embedding, then renaming keys
    # into a "backbone.net.*" state_dict and saving the result.
    parser = argparse.ArgumentParser(description="interpolate patch_embed kernel")
    parser.add_argument(
        "--input",
        # NOTE(review): default is dead code — argparse ignores `default`
        # when required=True. Same for --output below.
        default="/path/to/eva_psz14.pt",
        type=str,
        metavar="PATH",
        required=True,
        help="path to input EVA checkpoint with patch_embed kernel_size=14x14",
    )
    parser.add_argument(
        "--output",
        default="/path/to/eva_psz14to16.pt",
        type=str,
        metavar="PATH",
        required=True,
        help="path to output EVA checkpoint with patch_embed kernel_size=16x16",
    )
    # Image side length in pixels; used to derive the target pos_embed grid.
    parser.add_argument("--image_size", type=int, required=True)
    args = parser.parse_args()

    # Load on CPU so no GPU is needed for the conversion.
    checkpoint = torch.load(args.input, map_location=torch.device("cpu"))

    print(checkpoint.keys())
    # Normalize a DeepSpeed/DDP-style "module" wrapper to the "model" key so
    # the rest of the script only has to handle two layouts: nested under
    # "model", or a flat dict with "visual."-prefixed keys (presumably an
    # OpenCLIP-style checkpoint — verify against the source of the file).
    if "module" in checkpoint:
        checkpoint["model"] = checkpoint.pop("module")
    print(checkpoint.keys())

    # --- Resize the patch-embedding conv kernel to 16x16. ---
    if "model" in checkpoint:
        patch_embed = checkpoint["model"]["patch_embed.proj.weight"]
    else:
        patch_embed = checkpoint["visual.patch_embed.proj.weight"]
    C_o, C_in, H, W = patch_embed.shape
    # NOTE(review): unlike interpolate_pos_embed, the result is left in
    # float32 (.float() with no cast back to the original dtype) — confirm
    # whether saving the kernel in fp32 is intended.
    patch_embed = torch.nn.functional.interpolate(
        patch_embed.float(), size=(16, 16), mode="bicubic", align_corners=False
    )
    if "model" in checkpoint:
        checkpoint["model"]["patch_embed.proj.weight"] = patch_embed
    else:
        checkpoint["visual.patch_embed.proj.weight"] = patch_embed

    # --- Resize the position embedding to the new patch grid. ---
    if "model" in checkpoint:
        interpolate_pos_embed(checkpoint["model"], new_size=16, image_size=args.image_size)
    else:
        # Flat layout: temporarily alias "visual.pos_embed" to the key name
        # interpolate_pos_embed expects, then move the result back.
        checkpoint["pos_embed"] = checkpoint["visual.pos_embed"]
        interpolate_pos_embed(checkpoint, new_size=16, image_size=args.image_size)
        checkpoint["visual.pos_embed"] = checkpoint.pop("pos_embed")

    # --- Rename keys for consumption as a "backbone.net.*" state_dict. ---
    print("======== new state_dict ========")
    if "model" in checkpoint:
        # Prefix every key with "backbone.net." (list() snapshots the keys
        # since the dict is mutated during iteration).
        for k, v in list(checkpoint["model"].items()):
            checkpoint["model"]["backbone.net." + k] = checkpoint["model"].pop(k)
            print("rename", k, " ", "backbone.net." + k)
        for k, v in list(checkpoint["model"].items()):
            print(k, " ", v.shape)
    else:
        # Flat layout: drop the text tower and logit scale, keep only the
        # visual tower, renamed "visual.X" -> "backbone.net.X".
        for k, v in list(checkpoint.items()):
            if k.startswith("text.") or k == "logit_scale":
                checkpoint.pop(k)
                print("pop", k, " ", v.shape)
            if k.startswith("visual."):
                checkpoint["backbone.net." + k[7:]] = checkpoint.pop(k)
                print("rename", k, " ", "backbone.net." + k[7:])
        for k, v in list(checkpoint.items()):
            print(k, " ", v.shape)

    torch.save(checkpoint, args.output)
| |
|