Spaces:
Sleeping
Sleeping
| import torch | |
| import math | |
| import numpy as np | |
def get_1d_sincos_pos_emb_from_grid(embed_dim, pos, device="cpu"):
    """
    Generate 1D sinusoidal positional embeddings from grid positions.

    Each position p is encoded as
    ``[sin(p * w_0), ..., sin(p * w_{k-1}), cos(p * w_0), ..., cos(p * w_{k-1})]``
    with ``k = embed_dim // 2`` and ``w_i = 1 / 10000 ** (i / k)``
    (all sines first, then all cosines -- not interleaved).

    Args:
        embed_dim (int): The embedding dimension (must be even).
        pos (torch.Tensor): Grid positions of any shape, e.g. a flat
            ``[b * gh * gw]`` vector or a ``[batch_size, sequence_length]``
            matrix. Values are converted to float32.
        device (str | torch.device): Device for the output tensor; ``pos`` is
            moved there before computing.

    Returns:
        torch.Tensor: float32 embeddings of shape ``[*pos.shape, embed_dim]``
        (i.e. ``[len(pos), embed_dim]`` for a 1-D ``pos``).
    """
    assert embed_dim % 2 == 0, "Embedding dimension must be even for sine and cosine."
    half_dim = embed_dim // 2

    # Put positions on the same device as ``freq`` so they can be multiplied.
    pos = pos.to(device=device, dtype=torch.float32)

    # Frequencies decay geometrically from 1 down towards 1/10000.
    dim = torch.arange(half_dim, dtype=torch.float32, device=device)
    freq = 1.0 / (10000 ** (dim / half_dim))

    # [..., 1] * [half_dim] -> [..., half_dim]; ``...`` keeps any leading shape.
    angles = pos[..., None] * freq
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
def get_2d_sincos_pos_embed(
    embed_dim,
    grid_size,
    cls_token: bool = False,
    extra_tokens: int = 0,
    scale: float = 1.0,
    base_size=None,
    device: str = "cpu",
):
    """
    2D sine-cosine positional embeddings for an H x W token grid.

    Follows the official w-first grid construction. Coordinates are
    ``arange(H) / scale`` and ``arange(W) / scale``; when ``base_size`` is
    given they are further multiplied by ``base_size / H`` and
    ``base_size / W`` respectively.

    Args:
        embed_dim: total embedding dimension. Each grid axis receives
            ``embed_dim // 2`` channels.
        grid_size: int (square grid) or a 2-element tuple/list ``(H, W)``.
        cls_token: if True *and* ``extra_tokens > 0``, ``extra_tokens``
            all-zero rows are prepended; if False nothing is prepended.
        extra_tokens: number of zero rows to prepend (only when
            ``cls_token`` is True).
        scale: coordinates are divided by this factor.
        base_size: optional reference grid size used to rescale coordinates.
        device: output device.

    Returns:
        Tensor of shape ``[H*W (+extra_tokens), embed_dim]``.
    """
    # Lists are accepted as well as tuples; any other value is a square side.
    if isinstance(grid_size, (tuple, list)):
        H, W = grid_size
    else:
        H = W = grid_size

    grid_h = np.arange(H, dtype=np.float32) / scale
    grid_w = np.arange(W, dtype=np.float32) / scale
    if base_size is not None:
        # Coordinates then span [0, base_size / scale) regardless of H and W.
        grid_h *= base_size / H
        grid_w *= base_size / W

    # meshgrid(grid_w, grid_h) with the default 'xy' indexing returns two
    # (H, W) arrays; after stacking, grid[0] holds w coords, grid[1] holds h.
    grid = np.stack(np.meshgrid(grid_w, grid_h), axis=0)
    # Same reshape as the official code. A C-order reshape does not move
    # elements, so the flattened token order stays row-major over (h, w).
    grid = grid.reshape([2, 1, W, H])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, device=device)
    if cls_token and extra_tokens > 0:
        # Zero embeddings for the extra tokens go in front of the grid tokens.
        pos_embed = torch.cat(
            [torch.zeros([extra_tokens, embed_dim], device=device), pos_embed], dim=0
        )
    return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, device="cpu"):
    """Build 2D sin-cos embeddings from a stacked (w-first) coordinate grid.

    Args:
        embed_dim: total embedding width; must be a multiple of 4, because each
            of the two coordinate channels gets ``embed_dim // 2`` channels and
            the 1-D encoder requires an even width.
        grid: array-like (numpy array, torch tensor or nested list) whose first
            axis has length 2, e.g. the ``[2, 1, W, H]`` grid used by the
            official code. ``grid[0]`` and ``grid[1]`` are each flattened to
            one coordinate per token.
        device: output device.

    Returns:
        float32 tensor ``[num_tokens, embed_dim]`` (``[H*W, embed_dim]`` for an
        H x W grid): the first half encodes ``grid[0]``, the second ``grid[1]``.
    """
    # Fail here with a clear message instead of deep inside the 1-D helper.
    assert embed_dim % 4 == 0, "Embedding dimension must be divisible by 4 for 2D sin-cos embeddings."

    # as_tensor converts numpy arrays, tensors and lists in one call.
    grid_t = torch.as_tensor(grid, dtype=torch.float32, device=device)
    half_dim = embed_dim // 2

    # In the official implementation emb_h uses grid[0] and emb_w uses grid[1]
    # (for a w-first grid, grid[0] actually holds the w coordinates).
    emb_h = get_1d_sincos_pos_emb_from_grid(half_dim, grid_t[0].reshape(-1), device)
    emb_w = get_1d_sincos_pos_emb_from_grid(half_dim, grid_t[1].reshape(-1), device)
    return torch.cat([emb_h, emb_w], dim=1)
def rope(positions: torch.Tensor, d: int, device="cpu") -> torch.Tensor:
    """
    Compute interleaved sine-cosine embeddings of dimension ``d``.

    Despite the name, this returns absolute sinusoidal features; it does not
    rotate query/key vectors. For each position p and i in [0, d/2):
    ``out[..., 2*i] = sin(p * 10000 ** (-2*i/d))`` and
    ``out[..., 2*i + 1] = cos(p * 10000 ** (-2*i/d))``.

    Args:
        positions: tensor of positions, typically ``(B, N)`` floats in [0, 1];
            any shape is accepted.
        d: embedding dimension; must be even.
        device: device on which the embeddings are computed and returned.

    Returns:
        Tensor of shape ``(*positions.shape, d)`` (``(B, N, d)`` for the usual
        2-D input). Floating-point inputs keep their dtype; integer inputs
        yield the floating dtype of the computed angles.
    """
    assert d % 2 == 0, "Embedding dimension d must be even."
    half_d = d // 2

    # Inverse frequencies 10000^(-2i/d) for i = 0 .. half_d - 1.
    idx = torch.arange(half_d, device=device)
    freqs = torch.pow(10000.0, -2.0 * idx / d)

    # [..., 1] * [half_d] -> [..., half_d]
    angle = positions.to(device).unsqueeze(-1) * freqs

    # [..., half_d, 2] flattened to [..., d] so sin/cos alternate on the last axis.
    embeddings = torch.stack([angle.sin(), angle.cos()], dim=-1).flatten(-2)

    # Keep floating input dtypes (e.g. fp16), but never cast to an integer
    # dtype, which would truncate the sin/cos values.
    out_dtype = positions.dtype if positions.is_floating_point() else embeddings.dtype
    return embeddings.to(out_dtype)