File size: 5,033 Bytes
a3f0d6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Variable size position embedding utils for handling different image dimensions
# --------------------------------------------------------

import numpy as np
import torch
import torch.nn.functional as F


def get_2d_sincos_pos_embed_variable(embed_dim, grid_h, grid_w, cls_token=False):
    """
    Create 2D sine-cosine position embeddings for variable grid sizes

    Args:
        embed_dim: embedding dimension
        grid_h: height of the grid (number of patches in height)
        grid_w: width of the grid (number of patches in width)
        cls_token: whether to include class token

    Returns:
        pos_embed: [grid_h*grid_w, embed_dim] or [1+grid_h*grid_w, embed_dim] (w/ or w/o cls_token)
    """
    ys = np.arange(grid_h, dtype=np.float32)
    xs = np.arange(grid_w, dtype=np.float32)
    # Note: width coordinates go first into meshgrid.
    mesh = np.stack(np.meshgrid(xs, ys), axis=0).reshape([2, 1, grid_h, grid_w])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if cls_token:
        # Class token gets an all-zero embedding prepended to the patch grid.
        cls_pad = np.zeros([1, embed_dim])
        pos_embed = np.concatenate([cls_pad, pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a 2D coordinate grid: half the channels for the h axis,
    half for the w axis, concatenated to shape (H*W, embed_dim)."""
    assert embed_dim % 2 == 0

    half = embed_dim // 2
    emb_h = get_1d_sincos_pos_embed_from_grid(half, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(half, grid[1])  # (H*W, D/2)
    return np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Create 1D sine-cosine position embeddings for a set of positions.

    Args:
        embed_dim: output dimension for each position (must be even)
        pos: array of positions to be encoded, any shape; flattened to size (M,)

    Returns:
        out: (M, embed_dim) array; first half sin, second half cos
    """
    assert embed_dim % 2 == 0
    # BUGFIX: `np.float` was deprecated in NumPy 1.20 and removed in 1.24;
    # use the concrete dtype it aliased (float64).
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,) geometric frequency ladder

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


def interpolate_pos_embed_variable(original_pos_embed, target_h, target_w, cls_token=True):
    """
    Interpolate position embeddings for arbitrary target sizes

    Args:
        original_pos_embed: original positional embeddings [1, N, D]
        target_h: target height in patches
        target_w: target width in patches
        cls_token: whether the first token is a class token

    Returns:
        interpolated_pos_embed: [1, target_h*target_w + cls_token, D]

    Raises:
        ValueError: if the original patch count is not a perfect square
    """
    embed_dim = original_pos_embed.shape[-1]

    # Separate the (optional) class token from the spatial patch tokens.
    if cls_token:
        cls_embed = original_pos_embed[:, 0:1]   # [1, 1, D]
        spatial = original_pos_embed[:, 1:]      # [1, N-1, D]
    else:
        cls_embed = None
        spatial = original_pos_embed

    num_patches = spatial.shape[1]

    # The source grid is assumed square; recover its side length.
    side = int(np.sqrt(num_patches))
    if side * side != num_patches:
        raise ValueError(f"Original number of patches {num_patches} is not a perfect square")

    # [1, N, D] -> [1, D, side, side] for spatial interpolation.
    spatial = spatial.reshape(1, side, side, embed_dim).permute(0, 3, 1, 2)

    # Bicubic resize to the requested patch grid.
    spatial = F.interpolate(
        spatial,
        size=(target_h, target_w),
        mode='bicubic',
        align_corners=False,
    )

    # [1, D, th, tw] -> [1, th*tw, D] token sequence.
    spatial = spatial.permute(0, 2, 3, 1).flatten(1, 2)

    if cls_token:
        return torch.cat([cls_embed, spatial], dim=1)
    return spatial


def create_variable_pos_embed(embed_dim, height_patches, width_patches, cls_token=True):
    """
    Create positional embeddings for specific patch grid dimensions

    Args:
        embed_dim: embedding dimension
        height_patches: number of patches in height
        width_patches: number of patches in width
        cls_token: whether to include class token

    Returns:
        pos_embed: float tensor of shape [1, height_patches*width_patches (+1), embed_dim]
    """
    embed_np = get_2d_sincos_pos_embed_variable(
        embed_dim, height_patches, width_patches, cls_token=cls_token
    )
    # Convert to a float tensor and add the leading batch dimension.
    return torch.from_numpy(embed_np).float().unsqueeze(0)