# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Variable size position embedding utils for handling different image dimensions
# --------------------------------------------------------
import numpy as np
import torch
import torch.nn.functional as F
def get_2d_sincos_pos_embed_variable(embed_dim, grid_h, grid_w, cls_token=False):
    """
    Build 2D sine-cosine positional embeddings for a (possibly non-square) patch grid.
    Args:
        embed_dim: embedding dimension
        grid_h: number of patches along the height
        grid_w: number of patches along the width
        cls_token: prepend an all-zero row for a class token if True
    Returns:
        pos_embed: [grid_h*grid_w, embed_dim] or [1+grid_h*grid_w, embed_dim] (w/ or w/o cls_token)
    """
    rows = np.arange(grid_h, dtype=np.float32)
    cols = np.arange(grid_w, dtype=np.float32)
    # width goes first in the meshgrid to match the row-major patch ordering
    mesh = np.stack(np.meshgrid(cols, rows), axis=0)
    mesh = mesh.reshape([2, 1, grid_h, grid_w])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not cls_token:
        return pos_embed
    # class token gets a zero embedding, placed ahead of the patch tokens
    return np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Turn a 2-axis coordinate grid into 2D sin-cos embeddings.
    grid[0] carries one axis of coordinates, grid[1] the other; each axis is
    encoded independently with half of the channels, then concatenated.
    Returns: (H*W, embed_dim) array.
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    halves = [get_1d_sincos_pos_embed_from_grid(half, axis) for axis in (grid[0], grid[1])]
    return np.concatenate(halves, axis=1)  # (H*W, D)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Compute 1D sine-cosine positional embeddings.

    Args:
        embed_dim: output dimension for each position (must be even)
        pos: array of positions to be encoded; flattened to size (M,)
    Returns:
        out: (M, embed_dim) array laid out as [sin | cos] halves
    """
    assert embed_dim % 2 == 0
    # BUG FIX: np.float is a deprecated alias removed in NumPy 1.24; it raised
    # AttributeError at runtime. np.float64 is the exact type it aliased.
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,) geometric frequency schedule
    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)
    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
def interpolate_pos_embed_variable(original_pos_embed, target_h, target_w, cls_token=True):
    """
    Resample positional embeddings to an arbitrary target patch grid.

    The spatial tokens are laid out on their (assumed square) source grid and
    bicubically interpolated to (target_h, target_w); a class token, if any,
    is carried through unchanged.

    Args:
        original_pos_embed: original positional embeddings [1, N, D]
        target_h: target height in patches
        target_w: target width in patches
        cls_token: whether the first token is a class token
    Returns:
        interpolated_pos_embed: [1, target_h*target_w + cls_token, D]
    Raises:
        ValueError: if the spatial token count is not a perfect square
    """
    dim = original_pos_embed.shape[-1]
    # Peel off the class token so only spatial tokens are interpolated.
    if cls_token:
        cls_part = original_pos_embed[:, 0:1]  # [1, 1, D]
        spatial = original_pos_embed[:, 1:]    # [1, N-1, D]
    else:
        cls_part = None
        spatial = original_pos_embed
    num_patches = spatial.shape[1]
    side = int(np.sqrt(num_patches))  # source grid is assumed square
    if side * side != num_patches:
        raise ValueError(f"Original number of patches {num_patches} is not a perfect square")
    # [1, N, D] -> [1, D, side, side], the layout F.interpolate expects
    grid = spatial.reshape(1, side, side, dim).permute(0, 3, 1, 2)
    grid = F.interpolate(
        grid,
        size=(target_h, target_w),
        mode='bicubic',
        align_corners=False
    )
    # [1, D, th, tw] -> [1, th*tw, D] token sequence
    spatial = grid.permute(0, 2, 3, 1).flatten(1, 2)
    if cls_part is None:
        return spatial
    return torch.cat([cls_part, spatial], dim=1)
def create_variable_pos_embed(embed_dim, height_patches, width_patches, cls_token=True):
    """
    Build a ready-to-use positional embedding tensor for a given patch grid.

    Args:
        embed_dim: embedding dimension
        height_patches: number of patches in height
        width_patches: number of patches in width
        cls_token: whether to include class token
    Returns:
        pos_embed: positional embeddings tensor, shape [1, tokens, embed_dim]
    """
    sincos = get_2d_sincos_pos_embed_variable(
        embed_dim, height_patches, width_patches, cls_token=cls_token
    )
    # float64 numpy -> float32 tensor, with a leading batch axis for broadcasting
    return torch.from_numpy(sincos).float().unsqueeze(0)