| """Resolution-aware encoder for VibeToken. |
| |
| Vision Transformer-based encoder with flexible patch sizes for variable-resolution inputs. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional, Tuple |
| from einops import rearrange |
| from torch import Tensor, vmap |
| import numpy as np |
|
|
| from .blocks import ResidualAttentionBlock, _expand_token |
| from .embeddings import FuzzyEmbedding, to_2tuple |
|
|
|
|
class ResolutionEncoder(nn.Module):
    """Vision Transformer encoder with flexible resolution support.

    Encodes images into latent tokens using a ViT architecture. The patch
    embedding kernel can be resampled on the fly to a different patch size
    (FlexiViT-style pseudo-inverse resize), so a single set of weights can
    encode inputs with different patch sizes / token-grid sizes.
    """

    # Width / depth / head count for each supported ``vit_enc_model_size``.
    MODEL_CONFIGS = {
        "small": {"width": 512, "num_layers": 8, "num_heads": 8},
        "base": {"width": 768, "num_layers": 12, "num_heads": 12},
        "large": {"width": 1024, "num_layers": 24, "num_heads": 16},
    }

    def __init__(self, config):
        """Initialize ResolutionEncoder.

        Args:
            config: OmegaConf-style config. Model parameters are read from
                ``config.model.vq_model`` when that attribute exists, otherwise
                from ``config.model``; each individual field has a default.

        Raises:
            KeyError: If ``vit_enc_model_size`` is not a key of ``MODEL_CONFIGS``.
        """
        super().__init__()
        self.config = config

        vq_config = config.model.vq_model if hasattr(config.model, "vq_model") else config.model
        self.patch_size = getattr(vq_config, "vit_enc_patch_size", 32)
        self.model_size = getattr(vq_config, "vit_enc_model_size", "large")
        self.num_latent_tokens = getattr(vq_config, "num_latent_tokens", 256)
        self.token_size = getattr(vq_config, "token_size", 256)
        self.is_legacy = getattr(vq_config, "is_legacy", False)

        # In "vae" mode the output channel count is doubled.
        # NOTE(review): presumably a mean / log-variance pair consumed
        # downstream -- confirm with the quantizer.
        quantize_mode = getattr(vq_config, "quantize_mode", "vq")
        if quantize_mode == "vae":
            self.token_size = self.token_size * 2

        if self.model_size not in self.MODEL_CONFIGS:
            raise KeyError(
                f"Unknown vit_enc_model_size {self.model_size!r}; "
                f"expected one of {sorted(self.MODEL_CONFIGS)}"
            )
        model_cfg = self.MODEL_CONFIGS[self.model_size]
        self.width = model_cfg["width"]
        self.num_layers = model_cfg["num_layers"]
        self.num_heads = model_cfg["num_heads"]

        # Base patch embedding. Kernels for other patch sizes are derived from
        # this one at runtime by ``resize_patch_embed``.
        self.patch_embed = nn.Conv2d(
            in_channels=3,
            out_channels=self.width,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=True,
        )

        # Parameter creation order is kept unchanged so random initialisation
        # is reproducible for a given seed.
        scale = self.width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
        # NOTE(review): the meaning of 1024 is defined by FuzzyEmbedding
        # (likely a position / grid capacity) -- confirm in embeddings.py.
        self.positional_embedding = FuzzyEmbedding(1024, scale, self.width)
        self.latent_token_positional_embedding = nn.Parameter(
            scale * torch.randn(self.num_latent_tokens, self.width)
        )
        self.ln_pre = nn.LayerNorm(self.width)

        self.transformer = nn.ModuleList(
            [
                ResidualAttentionBlock(self.width, self.num_heads, mlp_ratio=4.0)
                for _ in range(self.num_layers)
            ]
        )

        self.ln_post = nn.LayerNorm(self.width)
        self.conv_out = nn.Conv2d(self.width, self.token_size, kernel_size=1, bias=True)

        # Pseudo-inverse resize matrices keyed by target patch size (H, W).
        # This is a plain dict, not a registered buffer, so ``Module.to()``
        # does not move it; ``resize_patch_embed`` relocates entries on use.
        self.pinvs = {}

    def _resize(self, x: Tensor, shape: Tuple[int, int]) -> Tensor:
        """Bilinearly resize a 2D tensor to ``shape`` (no antialiasing)."""
        x_resized = F.interpolate(
            x[None, None, ...], shape, mode="bilinear", antialias=False
        )
        return x_resized[0, 0, ...]

    def _calculate_pinv(
        self,
        old_shape: Tuple[int, int],
        new_shape: Tuple[int, int],
        device: torch.device,
    ) -> Tensor:
        """Return the pseudo-inverse of the bilinear resize matrix (FlexiViT).

        Row ``i`` of the resize matrix is the flattened bilinear resize of the
        ``i``-th one-hot basis image of shape ``old_shape``. The matrix
        therefore has shape ``(prod(old_shape), prod(new_shape))``, and its
        pseudo-inverse has shape ``(prod(new_shape), prod(old_shape))``.

        Args:
            old_shape: Spatial size of the original kernel.
            new_shape: Spatial size of the resampled kernel.
            device: Device on which the matrix is built.

        Returns:
            float32 tensor of shape ``(prod(new_shape), prod(old_shape))``.
        """
        num_basis = int(np.prod(old_shape))
        # Build in float32 explicitly so the result can be multiplied with the
        # ``.float()`` kernels in ``resize_patch_embed``, whatever torch's
        # default dtype is.
        basis = torch.eye(num_basis, device=device, dtype=torch.float32)
        # Each row of the identity, viewed row-major as ``old_shape``, is the
        # one-hot image for flat index ``i``.
        rows = [self._resize(b.view(old_shape), new_shape).reshape(-1) for b in basis]
        resize_matrix = torch.stack(rows)
        return torch.linalg.pinv(resize_matrix)

    def resize_patch_embed(self, patch_embed: Tensor, new_patch_size: Tuple[int, int]) -> Tensor:
        """Resize a patch-embedding kernel to a new patch size (FlexiViT).

        Each ``(H, W)`` slice of the kernel is mapped through the cached
        pseudo-inverse of the bilinear resize matrix.

        Args:
            patch_embed: Original weight tensor ``(out_ch, in_ch, H, W)``.
            new_patch_size: Target ``(H, W)`` patch size.

        Returns:
            Weight tensor ``(out_ch, in_ch, new_H, new_W)`` in the dtype of
            ``patch_embed``. If the size is unchanged, ``patch_embed`` itself
            is returned.
        """
        base_size = tuple(to_2tuple(self.patch_size))
        # Normalise to a tuple. The value is compared with ``base_size`` and
        # used as a dict key, and a list would be unequal / unhashable.
        new_patch_size = tuple(to_2tuple(new_patch_size))
        if base_size == new_patch_size:
            return patch_embed

        pinv = self.pinvs.get(new_patch_size)
        if pinv is None:
            pinv = self._calculate_pinv(base_size, new_patch_size, device=patch_embed.device)
            self.pinvs[new_patch_size] = pinv
        elif pinv.device != patch_embed.device:
            # The cache does not follow ``Module.to()``. Move the entry to
            # wherever the weight lives now, otherwise the matmul below fails
            # with a device mismatch.
            pinv = pinv.to(patch_embed.device)
            self.pinvs[new_patch_size] = pinv

        new_h, new_w = new_patch_size

        def resample_kernel(kernel: Tensor) -> Tensor:
            # ``kernel`` is one (H, W) slice. Resample in float32, then
            # restore the caller's dtype.
            resampled = pinv @ kernel.float().reshape(-1)
            return resampled.to(kernel.dtype).reshape(new_h, new_w)

        # Inner vmap maps over out_ch (dim 0), outer vmap maps over in_ch (dim 1).
        batched_resample = vmap(vmap(resample_kernel, 0, 0), 1, 1)
        return batched_resample(patch_embed)

    def apply_flexivit_patch_embed(self, x: Tensor, target_patch_size: Tuple[int, int]) -> Tensor:
        """Apply patch embedding with a flexible patch size.

        Args:
            x: Input image tensor ``(B, 3, H, W)``.
            target_patch_size: Target patch size ``(H, W)``.

        Returns:
            Patch embeddings ``(B, C, grid_H, grid_W)``.
        """
        patch_size = tuple(to_2tuple(target_patch_size))

        if patch_size == tuple(to_2tuple(self.patch_size)):
            weight = self.patch_embed.weight
        else:
            weight = self.resize_patch_embed(self.patch_embed.weight, patch_size)

        # stride == kernel size with no padding: non-overlapping patches.
        # Trailing pixels that do not fill a whole patch are dropped.
        return F.conv2d(x, weight, bias=self.patch_embed.bias, stride=patch_size)

    def forward(
        self,
        pixel_values: torch.Tensor,
        latent_tokens: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encode_patch_size: Optional[Tuple[int, int]] = None,
    ) -> torch.Tensor:
        """Encode images to latent tokens.

        Args:
            pixel_values: Input images ``(B, 3, H, W)``, values in [0, 1].
            latent_tokens: Learnable latent tokens ``(num_latent, width)``.
                ``num_latent`` must not exceed ``num_latent_tokens``, because
                their positional embedding is sliced to ``num_latent`` rows.
            attention_mask: Currently ignored. Every transformer block is
                called with ``attention_mask=None``.
            encode_patch_size: Optional patch size for encoding, given as an
                int or an ``(H, W)`` pair. Defaults to the base patch size.

        Returns:
            Encoded latent features ``(B, token_size, 1, num_latent)``.

        Raises:
            ValueError: If ``pixel_values`` is not 4-dimensional.
        """
        if pixel_values.ndim != 4:
            raise ValueError(
                f"pixel_values must be (B, 3, H, W), got shape {tuple(pixel_values.shape)}"
            )
        batch_size = pixel_values.shape[0]

        if encode_patch_size is None:
            target_patch_size = (self.patch_size, self.patch_size)
        elif isinstance(encode_patch_size, int):
            target_patch_size = (encode_patch_size, encode_patch_size)
        else:
            target_patch_size = tuple(encode_patch_size)

        x = self.apply_flexivit_patch_embed(pixel_values, target_patch_size)

        # Read the grid size from the actual conv output rather than
        # recomputing it, so the slice below always matches the token count.
        grid_height, grid_width = x.shape[-2], x.shape[-1]

        # (B, C, gh, gw) -> (B, gh*gw, C)
        x = x.flatten(2).transpose(1, 2)

        # Prepend the class token.
        x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)

        num_latent = latent_tokens.shape[0]
        latent_tokens = _expand_token(latent_tokens, x.shape[0]).to(x.dtype)
        latent_tokens = latent_tokens + self.latent_token_positional_embedding.to(x.dtype)[:num_latent]

        # Positional embedding covers the class token + patch tokens only.
        x = x + self.positional_embedding(grid_height, grid_width, train=False, dtype=x.dtype)

        x = torch.cat([x, latent_tokens], dim=1)

        x = self.ln_pre(x)
        # NOTE(review): the blocks appear to expect sequence-first (L, B, D)
        # input -- confirm against ResidualAttentionBlock.
        x = x.permute(1, 0, 2)
        for layer in self.transformer:
            x = layer(x, attention_mask=None)
        x = x.permute(1, 0, 2)

        # Keep only the latent tokens that follow [cls] + patch tokens.
        num_prefix_tokens = 1 + grid_height * grid_width
        latent_tokens = self.ln_post(x[:, num_prefix_tokens:])

        if self.is_legacy:
            # Legacy path: raw reshape (not a permute) of the (B, N, width)
            # memory into (B, width, N, 1). Presumably kept for compatibility
            # with older checkpoints -- do not "fix" without re-training.
            latent_tokens = latent_tokens.reshape(batch_size, self.width, num_latent, 1)
        else:
            latent_tokens = latent_tokens.reshape(batch_size, num_latent, self.width, 1).permute(0, 2, 1, 3)

        latent_tokens = self.conv_out(latent_tokens)
        latent_tokens = latent_tokens.reshape(batch_size, self.token_size, 1, num_latent)

        return latent_tokens
|
|