File size: 4,131 Bytes
b74998d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch


def position_grid_to_embed(
    pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100
) -> torch.Tensor:
    """
    Convert a 2D position grid (H x W x 2) to sinusoidal embeddings (H x W x C).

    The x and y coordinates are embedded independently with
    ``make_sincos_pos_embed`` (``embed_dim // 2`` channels each) and the two
    results are concatenated along the channel axis, x channels first.

    Args:
        pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates.
        embed_dim: Output channel dimension for embeddings. Must be divisible
            by 4: each coordinate receives ``embed_dim // 2`` channels, which
            are in turn split evenly between sine and cosine terms.
        omega_0: Base of the geometric frequency progression, forwarded to
            ``make_sincos_pos_embed``.

    Returns:
        Tensor of shape (H, W, embed_dim) with positional embeddings.

    Raises:
        ValueError: If ``pos_grid`` is not of shape (H, W, 2) or
            ``embed_dim`` is not divisible by 4.
    """
    # Explicit checks instead of ``assert``: asserts vanish under ``python -O``,
    # which would let a (H, W, 3) grid silently lose its third coordinate and
    # let a bad embed_dim fail later with an opaque reshape error.
    if pos_grid.dim() != 3 or pos_grid.shape[-1] != 2:
        raise ValueError(
            f"pos_grid must have shape (H, W, 2), got {tuple(pos_grid.shape)}"
        )
    if embed_dim % 4 != 0:
        raise ValueError(f"embed_dim must be divisible by 4, got {embed_dim}")

    H, W, _ = pos_grid.shape
    pos_flat = pos_grid.reshape(-1, 2)  # Flatten to (H*W, 2)

    half_dim = embed_dim // 2
    # Process x and y coordinates separately; each yields (H*W, D/2).
    emb_x = make_sincos_pos_embed(half_dim, pos_flat[:, 0], omega_0=omega_0)
    emb_y = make_sincos_pos_embed(half_dim, pos_flat[:, 1], omega_0=omega_0)

    # Combine along channels and restore the spatial layout.
    emb = torch.cat([emb_x, emb_y], dim=-1)  # (H*W, D)

    return emb.view(H, W, embed_dim)  # (H, W, D)


def make_sincos_pos_embed(
    embed_dim: int, pos: torch.Tensor, omega_0: float = 100
) -> torch.Tensor:
    """
    Build a 1D sine/cosine positional embedding for a set of positions.

    ``pos`` is flattened to a vector of M positions. Frequency i (for
    i in [0, embed_dim // 2)) is ``omega_0 ** (-i / (embed_dim / 2))``.
    The output row for each position holds all sine terms followed by all
    cosine terms.

    Args:
        embed_dim: Total embedding width D; must be even.
        pos: Tensor of positions of any shape; it is flattened.
        omega_0: Base of the geometric frequency progression.

    Returns:
        Float32 tensor of shape (M, D), where M = pos.numel().
    """
    assert embed_dim % 2 == 0
    target_device = pos.device

    # Frequencies are computed in float64, except on MPS, which is limited
    # to float32 (it does not support float64).
    if target_device.type == "mps":
        freq_dtype = torch.float32
    else:
        freq_dtype = torch.double

    half_dim = embed_dim // 2
    exponents = torch.arange(half_dim, dtype=freq_dtype, device=target_device)
    exponents = exponents / (embed_dim / 2.0)
    freqs = 1.0 / omega_0**exponents  # (D/2,)

    # Outer product positions x frequencies via broadcasting: (M, 1) * (D/2,).
    angles = pos.reshape(-1, 1) * freqs  # (M, D/2)

    return torch.cat([angles.sin(), angles.cos()], dim=1).float()  # (M, D)


# Inspired by https://github.com/microsoft/moge


def create_uv_grid(
    width: int,
    height: int,
    aspect_ratio: float = None,
    dtype: torch.dtype = None,
    device: torch.device = None,
) -> torch.Tensor:
    """
    Create a normalized UV grid of shape (height, width, 2).

    The plane spans [-span_x, span_x] horizontally and [-span_y, span_y]
    vertically, where span_x = aspect_ratio / sqrt(1 + aspect_ratio**2) and
    span_y = 1 / sqrt(1 + aspect_ratio**2), i.e. the extents are normalized
    by the plane's diagonal. Samples are placed at pixel centers, so the
    top-left sample is at (-span_x * (width - 1) / width,
    -span_y * (height - 1) / height) and the bottom-right sample is its
    mirror image.

    Args:
        width (int): Number of points horizontally.
        height (int): Number of points vertically.
        aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
        dtype (torch.dtype, optional): Data type of the resulting tensor.
        device (torch.device, optional): Device on which the tensor is created.

    Returns:
        torch.Tensor: A (height, width, 2) tensor where ``[..., 0]`` is the
        U (horizontal) coordinate, varying along the width axis, and
        ``[..., 1]`` is the V (vertical) coordinate, varying along the
        height axis.
    """
    # Derive aspect ratio if not explicitly provided
    if aspect_ratio is None:
        aspect_ratio = float(width) / float(height)

    # Compute normalized spans for X and Y
    diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
    span_x = aspect_ratio / diag_factor
    span_y = 1.0 / diag_factor

    # Linspace boundaries at the outermost pixel centers (half a pixel in
    # from each edge of the [-span, span] range).
    left_x = -span_x * (width - 1) / width
    right_x = span_x * (width - 1) / width
    top_y = -span_y * (height - 1) / height
    bottom_y = span_y * (height - 1) / height

    # Generate 1D coordinates
    x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
    y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)

    # "xy" (Cartesian) indexing yields (height, width) grids: uu varies along
    # columns, vv along rows. Stack them into (height, width, 2).
    uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
    uv_grid = torch.stack((uu, vv), dim=-1)

    return uv_grid