| |
| |
| |
| |
| |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
    """
    Turn a grid of 2D coordinates into sinusoidal positional embeddings.

    Each coordinate axis is embedded independently with
    ``make_sincos_pos_embed`` using ``embed_dim // 2`` channels; the x-axis
    channels come first in the output, followed by the y-axis channels.

    Args:
        pos_grid: Tensor of shape (H, W, 2); the last axis holds (x, y).
        embed_dim: Number of output channels. The final ``view`` and the
            evenness assertion in ``make_sincos_pos_embed`` together require
            this to be a multiple of 4.
        omega_0: Frequency base forwarded unchanged to
            ``make_sincos_pos_embed``.

    Returns:
        Tensor of shape (H, W, embed_dim) holding the positional embeddings.
    """
    height, width, num_axes = pos_grid.shape
    assert num_axes == 2

    # Collapse the spatial dimensions so every row is a single (x, y) pair.
    coords = pos_grid.reshape(-1, num_axes)

    # Embed column 0 (x) and column 1 (y) separately, each with half the channels.
    half_dim = embed_dim // 2
    per_axis = [make_sincos_pos_embed(half_dim, coords[:, axis], omega_0=omega_0) for axis in (0, 1)]

    # Join the two halves channel-wise and restore the spatial layout.
    combined = torch.cat(per_axis, dim=-1)
    return combined.view(height, width, embed_dim)
|
|
|
|
def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
    """
    Build 1D sine/cosine positional embeddings for a set of positions.

    Args:
    - embed_dim: Number of output channels; must be even.
    - pos: Positions to embed. Any shape is accepted; it is flattened to (M,).
    - omega_0: Base of the geometric frequency schedule. Frequency ``k`` is
      ``1 / omega_0 ** (k / (embed_dim / 2))``.

    Returns:
    - emb: float32 tensor of shape (M, embed_dim). The first half of the
      channels holds the sines, the second half the cosines.
    """
    assert embed_dim % 2 == 0
    num_freqs = embed_dim // 2

    # Exponents k / (D / 2) for k = 0 .. D/2 - 1, kept in float64.
    exponents = torch.arange(num_freqs, dtype=torch.double, device=pos.device) / (embed_dim / 2.0)
    freqs = 1.0 / omega_0**exponents

    # Outer product of positions and frequencies -> (M, D/2) phase angles.
    angles = torch.einsum("m,d->md", pos.reshape(-1), freqs)

    # Sines first, then cosines; cast down to float32 for the caller.
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1).float()
|
|
|
|
| |
|
|
|
|
def create_uv_grid(
    width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
) -> torch.Tensor:
    """
    Create a normalized UV grid of shape (height, width, 2).

    The image plane spans [-x_span, x_span] horizontally and
    [-y_span, y_span] vertically, where ``x_span = a / sqrt(a**2 + 1)`` and
    ``y_span = 1 / sqrt(a**2 + 1)`` for aspect ratio ``a``, so the plane's
    corner lies at unit distance from the origin. Samples are placed at cell
    centres rather than on the plane border: the outermost samples sit at
    ``±x_span * (width - 1) / width`` and ``±y_span * (height - 1) / height``.

    Args:
        width (int): Number of points horizontally.
        height (int): Number of points vertically.
        aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
        dtype (torch.dtype, optional): Data type of the resulting tensor.
        device (torch.device, optional): Device on which the tensor is created.

    Returns:
        torch.Tensor: A (height, width, 2) tensor. ``[..., 0]`` is the
        horizontal (u) coordinate, varying along the second axis;
        ``[..., 1]`` is the vertical (v) coordinate, varying along the
        first axis.

    Raises:
        ZeroDivisionError: If ``width`` is 0, or if ``height`` is 0.
    """
    if aspect_ratio is None:
        aspect_ratio = float(width) / float(height)

    # Half-extents of the plane, scaled so that (span_x, span_y) has unit norm.
    diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
    span_x = aspect_ratio / diag_factor
    span_y = 1.0 / diag_factor

    # Shrink by (n - 1) / n so the first and last samples land on cell centres
    # instead of on the border of the plane.
    edge_x = span_x * (width - 1) / width
    edge_y = span_y * (height - 1) / height

    x_coords = torch.linspace(-edge_x, edge_x, steps=width, dtype=dtype, device=device)
    y_coords = torch.linspace(-edge_y, edge_y, steps=height, dtype=dtype, device=device)

    # "xy" indexing yields (height, width) grids: uu varies along columns,
    # vv varies along rows.
    uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
    return torch.stack((uu, vv), dim=-1)