wildrayzer / utils /pe_utils.py
Xuweiyi's picture
Initial upload: WildRayZer Gradio Space
13f3c3a verified
import torch
import math
import numpy as np
def get_1d_sincos_pos_emb_from_grid(embed_dim, pos, device="cpu"):
    """
    Generate 1D sinusoidal positional embeddings from grid positions.

    Args:
        embed_dim (int): The embedding dimension (must be even).
        pos (torch.Tensor): Positions to encode, of any shape (for example a
            flat [b * gh * gw] vector or a [batch_size, sequence_length]
            matrix). It is cast to float32 and moved to ``device``.
        device (str): Device for the output tensor.

    Returns:
        torch.Tensor: Sinusoidal positional embeddings of shape
        [*pos.shape, embed_dim]. Along the last axis the first half holds
        the sines and the second half holds the cosines.
    """
    assert embed_dim % 2 == 0, "Embedding dimension must be even for sine and cosine."

    half_dim = embed_dim // 2

    # Keep the positions on the same device as the frequencies; otherwise the
    # product below fails whenever `pos` lives somewhere other than `device`.
    pos = pos.to(device=device, dtype=torch.float32)

    # Frequencies 1 / 10000^(i / half_dim) for i = 0 .. half_dim - 1
    dim = torch.arange(half_dim, dtype=torch.float32, device=device)
    freq = 1.0 / (10000 ** (dim / half_dim))

    # Adding a trailing axis (instead of `pos[:, None]`) lets every position
    # broadcast against the frequency vector regardless of the input rank.
    angles = pos[..., None] * freq  # [*pos.shape, half_dim]

    # Concatenate sine and cosine along the last dimension
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
def get_2d_sincos_pos_embed(
    embed_dim,
    grid_size,
    cls_token: bool = False,
    extra_tokens: int = 0,
    scale: float = 1.0,
    base_size=None,
    device: str = "cpu",
):
    """
    Build 2D sine-cosine positional embeddings for an H x W grid.

    Args:
        embed_dim: embedding dimension, forwarded to
            get_2d_sincos_pos_embed_from_grid
        grid_size: a single int (square grid) or a tuple/list (H, W)
        cls_token: together with ``extra_tokens > 0`` enables the zero-row
            prefix; has no effect on its own
        extra_tokens: number of all-zero rows prepended to the result; only
            applied when ``cls_token`` is truthy
        scale: the integer grid coordinates are divided by this factor
        base_size: if given, row coordinates are additionally multiplied by
            ``base_size / H`` and column coordinates by ``base_size / W``
        device: output device

    Returns:
        Tensor of shape [H*W (+extra_tokens), embed_dim]. Grid tokens are in
        row-major order, i.e. token ``k`` sits at row ``k // W`` and column
        ``k % W``.
    """
    # Accept lists as well as tuples; anything else is treated as a scalar
    # side length of a square grid.
    if isinstance(grid_size, (tuple, list)):
        H, W = grid_size
    else:
        H = W = grid_size

    grid_h = np.arange(H, dtype=np.float32) / scale
    grid_w = np.arange(W, dtype=np.float32) / scale
    if base_size is not None:
        grid_h *= base_size / H
        grid_w *= base_size / W

    # meshgrid is called with (grid_w, grid_h), so the stacked array holds the
    # w coordinates in channel 0 and the h coordinates in channel 1, each of
    # shape [H, W].
    grid = np.meshgrid(grid_w, grid_h)
    grid = np.stack(grid, axis=0)

    # The reshape only relabels the axes as [2, 1, W, H]; it does not reorder
    # any element, and the helper below flattens each channel anyway.
    grid = grid.reshape([2, 1, W, H])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, device=device)

    if cls_token and extra_tokens > 0:
        pos_embed = torch.cat(
            [torch.zeros([extra_tokens, embed_dim], device=device), pos_embed], dim=0
        )
    return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, device="cpu"):
    """Assemble 2D sin-cos embeddings from a two-channel coordinate grid.

    Each of the two grid channels is flattened and encoded with
    get_1d_sincos_pos_emb_from_grid using ``embed_dim // 2`` features; the two
    encodings are then joined along the feature axis, channel 0 first.

    Args:
        embed_dim: total feature size; asserted to be even here. Because the
            1D helper also asserts that its own dimension is even, the value
            effectively has to be a multiple of 4.
        grid: numpy array or torch tensor whose first axis has (at least) the
            two coordinate channels, e.g. shape [2, 1, W, H]
        device: output device
    Returns:
        Tensor [number of grid points, embed_dim]
    """
    assert embed_dim % 2 == 0

    # Bring the grid onto the target device as float32, whatever its origin
    source = torch.from_numpy(grid) if isinstance(grid, np.ndarray) else grid
    coords = source.to(device=device, dtype=torch.float32)

    per_axis_dim = embed_dim // 2
    # Channel 0 fills the first half of the features, channel 1 the second
    halves = [
        get_1d_sincos_pos_emb_from_grid(per_axis_dim, coords[axis].reshape(-1), device)
        for axis in (0, 1)
    ]
    return torch.cat(halves, dim=1)
def rope(positions: torch.Tensor, d: int, device="cpu") -> torch.Tensor:
    """
    Compute RoPE-style interleaved sine-cosine embeddings for a batch of
    positions.

    positions: (B, N) tensor of positions (documented as normalized to
        [0, 1]; the range is not enforced). Integer tensors are accepted and
        produce a floating-point result.
    d: int, dimension of the embedding. It should be even; odd values are not
        validated here.
    device: device on which the embedding is computed and returned.

    Returns:
        embeddings: (B, N, d) tensor where channel 2*i holds
        sin(p * 10000^(-2*i/d)) and channel 2*i + 1 holds the matching
        cosine. The dtype equals ``positions.dtype`` for floating-point
        input; otherwise it is the promoted dtype of the computed angles.
    """
    # positions shape: [B, N]
    B, N = positions.shape
    half_d = d // 2

    # idx => [1, 1, half_d]; freqs has the same shape and broadcasts below
    idx = torch.arange(half_d, device=device).view(1, 1, -1)
    freqs = torch.pow(10000.0, -2.0 * idx / d)

    # angle => [B, N, half_d]
    angle = positions.unsqueeze(-1).to(device) * freqs

    # Allocating the buffer with an integer dtype would truncate every sine
    # and cosine on assignment, so non-floating inputs take the angle dtype.
    out_dtype = positions.dtype if positions.is_floating_point() else angle.dtype

    # Interleave sine and cosine along the last dimension => [B, N, d]
    embeddings = torch.empty(B, N, d, device=device, dtype=out_dtype)
    embeddings[..., 0::2] = angle.sin()
    embeddings[..., 1::2] = angle.cos()
    return embeddings