# vibetoken/modeling/decoder.py
"""Resolution-aware decoder for VibeToken.
Vision Transformer-based decoder with flexible output resolutions.
"""
import torch
import torch.nn as nn
from typing import Optional, Tuple
from einops.layers.torch import Rearrange
from .blocks import ResidualAttentionBlock, ResizableBlur, _expand_token
from .embeddings import FuzzyEmbedding
class ResolutionDecoder(nn.Module):
    """Vision Transformer decoder with flexible resolution support.

    Decodes quantized latent tokens back to images, supporting variable
    output resolutions and patch sizes chosen at inference time.
    """

    # Transformer width / depth / head count for each named model size.
    MODEL_CONFIGS = {
        "small": {"width": 512, "num_layers": 8, "num_heads": 8},
        "base": {"width": 768, "num_layers": 12, "num_heads": 12},
        "large": {"width": 1024, "num_layers": 24, "num_heads": 16},
    }

    def __init__(self, config):
        """Initialize ResolutionDecoder.

        Args:
            config: OmegaConf-style config. Model parameters are read from
                ``config.model.vq_model`` when present, else ``config.model``;
                the default output size comes from
                ``config.dataset.preprocessing.crop_size`` when present.

        Raises:
            NotImplementedError: If the config sets ``is_legacy``.
            KeyError: If ``vit_dec_model_size`` is not a key of MODEL_CONFIGS.
        """
        super().__init__()
        self.config = config
        # Extract config values; tolerate both nested and flat config layouts.
        vq_config = config.model.vq_model if hasattr(config.model, 'vq_model') else config.model
        self.image_size = getattr(config.dataset.preprocessing, 'crop_size', 512) if hasattr(config, 'dataset') else 512
        self.patch_size = getattr(vq_config, 'vit_dec_patch_size', 32)
        self.model_size = getattr(vq_config, 'vit_dec_model_size', 'large')
        self.num_latent_tokens = getattr(vq_config, 'num_latent_tokens', 256)
        self.token_size = getattr(vq_config, 'token_size', 256)
        self.is_legacy = getattr(vq_config, 'is_legacy', False)
        if self.is_legacy:
            raise NotImplementedError("Legacy mode is not supported in this inference-only version")
        # Get model dimensions for the selected size preset.
        model_cfg = self.MODEL_CONFIGS[self.model_size]
        self.width = model_cfg["width"]
        self.num_layers = model_cfg["num_layers"]
        self.num_heads = model_cfg["num_heads"]
        # Project latent tokens (token_size) into the transformer width.
        self.decoder_embed = nn.Linear(self.token_size, self.width, bias=True)
        # Learned embeddings, initialized at 1/sqrt(width) scale.
        scale = self.width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
        # Resolution-aware positional embedding (supports arbitrary grids
        # up to 1024 positions per axis — see FuzzyEmbedding).
        self.positional_embedding = FuzzyEmbedding(1024, scale, self.width)
        # Mask token repeated over every output patch position.
        self.mask_token = nn.Parameter(scale * torch.randn(1, 1, self.width))
        self.latent_token_positional_embedding = nn.Parameter(
            scale * torch.randn(self.num_latent_tokens, self.width)
        )
        self.ln_pre = nn.LayerNorm(self.width)
        # Transformer layers.
        self.transformer = nn.ModuleList([
            ResidualAttentionBlock(self.width, self.num_heads, mlp_ratio=4.0)
            for _ in range(self.num_layers)
        ])
        # Output projection: width -> patch_size^2 * 3 pixel values per patch.
        self.ln_post = nn.LayerNorm(self.width)
        self.ffn = nn.Conv2d(
            self.width, self.patch_size * self.patch_size * 3,
            kernel_size=1, padding=0, bias=True
        )
        # Unfold per-patch pixel vectors back into a full-resolution image.
        self.rearrange = Rearrange(
            'b (p1 p2 c) h w -> b c (h p1) (w p2)',
            p1=self.patch_size, p2=self.patch_size
        )
        # Anti-aliased resize from the patch-grid resolution to the target.
        self.down_scale = ResizableBlur(channels=3, max_kernel_size=9, init_type="lanczos")
        self.conv_out = nn.Conv2d(3, 3, 3, padding=1, bias=True)

    def _select_patch_size(self, height: int, width: int) -> int:
        """Select an appropriate patch size for the target resolution.

        Prefers the smallest patch size whose grid contains between 256 and
        1024 patches (finest grid within the token budget); if none lands in
        that range, falls back to the size whose patch count is closest to
        either end of the range.

        Args:
            height: Target image height in pixels.
            width: Target image width in pixels.

        Returns:
            Selected patch size (8, 16, or 32).
        """
        min_patches, max_patches = 256, 1024
        candidate_sizes = [8, 16, 32]

        in_range = [
            ps for ps in candidate_sizes
            if min_patches <= (height // ps) * (width // ps) <= max_patches
        ]
        if in_range:
            # Smallest qualifying patch size -> finest grid within budget.
            return in_range[0]

        # No candidate lands in the budget: pick the one whose patch count
        # is closest to the target range (distance to the nearer endpoint).
        def _distance(ps: int) -> int:
            count = (height // ps) * (width // ps)
            return min(abs(count - min_patches), abs(count - max_patches))

        return min(candidate_sizes, key=_distance)

    def forward(
        self,
        z_quantized: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        decode_patch_size: Optional[int] = None,
    ) -> torch.Tensor:
        """Decode latent tokens to images.

        Args:
            z_quantized: Quantized latent features (B, C, H, W); the reshape
                below treats C*H as the token feature dim and W as the
                sequence length, so C*H must equal ``token_size``.
            attention_mask: Accepted for interface compatibility but
                currently ignored — the transformer layers are always run
                with ``attention_mask=None``.
            height: Target image height (defaults to ``self.image_size``).
            width: Target image width (defaults to ``self.image_size``).
            decode_patch_size: Optional custom patch size (int or (ph, pw)
                pair); when None, chosen via ``_select_patch_size``.

        Returns:
            Decoded images (B, 3, height, width). NOTE(review): values are
            produced by an unbounded final conv; presumably callers clamp
            to [0, 1] — confirm against the original docstring's claim.
        """
        N, C, H, W = z_quantized.shape
        # Flatten spatial latents into a token sequence and project to width.
        x = z_quantized.reshape(N, C * H, W).permute(0, 2, 1)  # (N, seq_len, C*H)
        x = self.decoder_embed(x)
        batchsize, seq_len, _ = x.shape
        # Default output size falls back to the configured crop size.
        if height is None:
            height = self.image_size
        if width is None:
            width = self.image_size
        # Determine patch size: explicit override wins over heuristic.
        if decode_patch_size is None:
            selected_patch_size = self._select_patch_size(height, width)
        else:
            selected_patch_size = decode_patch_size
        if isinstance(selected_patch_size, int):
            selected_patch_size = (selected_patch_size, selected_patch_size)
        grid_height = height // selected_patch_size[0]
        grid_width = width // selected_patch_size[1]
        # Mask tokens stand in for the output patches to be synthesized;
        # a class token is prepended at position 0.
        mask_tokens = self.mask_token.repeat(batchsize, grid_height * grid_width, 1).to(x.dtype)
        mask_tokens = torch.cat([
            _expand_token(self.class_embedding, mask_tokens.shape[0]).to(mask_tokens.dtype),
            mask_tokens
        ], dim=1)
        # Resolution-aware positional embeddings for the output grid.
        mask_tokens = mask_tokens + self.positional_embedding(
            grid_height, grid_width, train=False
        ).to(mask_tokens.dtype)
        # Positional embeddings for the latent tokens (truncated to seq_len).
        x = x + self.latent_token_positional_embedding[:seq_len]
        x = torch.cat([mask_tokens, x], dim=1)
        # Pre-norm, then (seq_len, B, width) layout expected by the blocks.
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # (seq_len, B, width)
        # Apply transformer layers (attention_mask intentionally unused).
        for layer in self.transformer:
            x = layer(x, attention_mask=None)
        x = x.permute(1, 0, 2)  # (B, seq_len, width)
        # Keep only the output-patch tokens: drop the class token (index 0)
        # and the trailing latent tokens.
        x = x[:, 1:1 + grid_height * grid_width]
        x = self.ln_post(x)
        # Reshape to spatial format and project each token to patch pixels.
        x = x.permute(0, 2, 1).reshape(batchsize, self.width, grid_height, grid_width)
        x = self.ffn(x.contiguous())
        x = self.rearrange(x)
        # Resize from grid*patch resolution down to the exact target size.
        _, _, org_h, org_w = x.shape
        x = self.down_scale(x, input_size=(org_h, org_w), target_size=(height, width))
        x = self.conv_out(x)
        return x