"""Resolution-aware decoder for VibeToken.

Vision Transformer-based decoder with flexible output resolutions.
"""
import torch
import torch.nn as nn
from typing import Optional, Tuple
from einops.layers.torch import Rearrange
from .blocks import ResidualAttentionBlock, ResizableBlur, _expand_token
from .embeddings import FuzzyEmbedding
class ResolutionDecoder(nn.Module):
    """Vision Transformer decoder with flexible resolution support.

    Decodes latent tokens back to images with support for variable
    output resolutions and patch sizes.
    """

    # Per-size transformer hyperparameters (embedding width, depth, heads).
    MODEL_CONFIGS = {
        "small": {"width": 512, "num_layers": 8, "num_heads": 8},
        "base": {"width": 768, "num_layers": 12, "num_heads": 12},
        "large": {"width": 1024, "num_layers": 24, "num_heads": 16},
    }

    def __init__(self, config):
        """Initialize ResolutionDecoder.

        Args:
            config: OmegaConf-style config with model parameters.

        Raises:
            NotImplementedError: If ``is_legacy`` is set in the config
                (legacy mode is unsupported in this inference-only build).
        """
        super().__init__()
        self.config = config

        # Support both nested (config.model.vq_model) and flat config layouts.
        vq_config = config.model.vq_model if hasattr(config.model, 'vq_model') else config.model
        self.image_size = (
            getattr(config.dataset.preprocessing, 'crop_size', 512)
            if hasattr(config, 'dataset') else 512
        )
        self.patch_size = getattr(vq_config, 'vit_dec_patch_size', 32)
        self.model_size = getattr(vq_config, 'vit_dec_model_size', 'large')
        self.num_latent_tokens = getattr(vq_config, 'num_latent_tokens', 256)
        self.token_size = getattr(vq_config, 'token_size', 256)
        self.is_legacy = getattr(vq_config, 'is_legacy', False)
        if self.is_legacy:
            raise NotImplementedError("Legacy mode is not supported in this inference-only version")

        # Transformer dimensions for the chosen model size.
        model_cfg = self.MODEL_CONFIGS[self.model_size]
        self.width = model_cfg["width"]
        self.num_layers = model_cfg["num_layers"]
        self.num_heads = model_cfg["num_heads"]

        # Project quantized latent tokens into the transformer width.
        self.decoder_embed = nn.Linear(self.token_size, self.width, bias=True)

        # Learned embeddings, initialized at scale 1/sqrt(width).
        scale = self.width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
        self.positional_embedding = FuzzyEmbedding(1024, scale, self.width)
        self.mask_token = nn.Parameter(scale * torch.randn(1, 1, self.width))
        self.latent_token_positional_embedding = nn.Parameter(
            scale * torch.randn(self.num_latent_tokens, self.width)
        )
        self.ln_pre = nn.LayerNorm(self.width)

        # Transformer backbone.
        self.transformer = nn.ModuleList([
            ResidualAttentionBlock(self.width, self.num_heads, mlp_ratio=4.0)
            for _ in range(self.num_layers)
        ])

        # Output head: token features -> per-patch pixels -> full image.
        self.ln_post = nn.LayerNorm(self.width)
        self.ffn = nn.Conv2d(
            self.width, self.patch_size * self.patch_size * 3,
            kernel_size=1, padding=0, bias=True
        )
        self.rearrange = Rearrange(
            'b (p1 p2 c) h w -> b c (h p1) (w p2)',
            p1=self.patch_size, p2=self.patch_size
        )
        # Anti-aliasing blur applied when resizing down to the target size.
        self.down_scale = ResizableBlur(channels=3, max_kernel_size=9, init_type="lanczos")
        self.conv_out = nn.Conv2d(3, 3, 3, padding=1, bias=True)

    def _select_patch_size(self, height: int, width: int) -> int:
        """Select appropriate patch size based on target resolution.

        Prefers the smallest candidate patch size whose grid yields between
        256 and 1024 patches; if none qualifies, falls back to the candidate
        whose patch count lies closest to that range.

        Args:
            height: Target image height.
            width: Target image width.

        Returns:
            Selected patch size.
        """
        min_patches, max_patches = 256, 1024
        candidate_sizes = [8, 16, 32]

        in_range = [
            ps for ps in candidate_sizes
            if min_patches <= (height // ps) * (width // ps) <= max_patches
        ]
        if in_range:
            return in_range[0]

        # No candidate lands in the preferred range: pick the one whose
        # patch count is closest to either end of the range (ties resolve
        # to the smallest patch size, matching candidate order).
        def distance_to_range(ps: int) -> int:
            count = (height // ps) * (width // ps)
            return min(abs(count - min_patches), abs(count - max_patches))

        return min(candidate_sizes, key=distance_to_range)

    def forward(
        self,
        z_quantized: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        decode_patch_size: Optional[int] = None,
    ) -> torch.Tensor:
        """Decode latent tokens to images.

        Args:
            z_quantized: Quantized latent features (B, C, H, W). The token
                sequence is built along W with feature dim C * H, so
                C * H must equal ``token_size`` — TODO confirm upstream layout.
            attention_mask: Accepted for interface compatibility but
                currently unused; transformer layers always receive None.
            height: Target image height (defaults to configured image size).
            width: Target image width (defaults to configured image size).
            decode_patch_size: Optional custom patch size for decoding;
                chosen automatically from the target resolution if omitted.

        Returns:
            Decoded images (B, 3, height, width).
        """
        N, C, H, W = z_quantized.shape

        # Flatten latents into a token sequence along W and embed.
        x = z_quantized.reshape(N, C * H, W).permute(0, 2, 1)  # (N, W, C*H)
        x = self.decoder_embed(x)
        batchsize, seq_len, _ = x.shape

        # Default output size.
        if height is None:
            height = self.image_size
        if width is None:
            width = self.image_size

        # Resolve the decoding patch size and output grid.
        if decode_patch_size is None:
            selected_patch_size = self._select_patch_size(height, width)
        else:
            selected_patch_size = decode_patch_size
        if isinstance(selected_patch_size, int):
            selected_patch_size = (selected_patch_size, selected_patch_size)
        grid_height = height // selected_patch_size[0]
        grid_width = width // selected_patch_size[1]

        # [class token] + one mask token per output patch, with positions.
        mask_tokens = self.mask_token.repeat(batchsize, grid_height * grid_width, 1).to(x.dtype)
        mask_tokens = torch.cat([
            _expand_token(self.class_embedding, mask_tokens.shape[0]).to(mask_tokens.dtype),
            mask_tokens
        ], dim=1)
        mask_tokens = mask_tokens + self.positional_embedding(
            grid_height, grid_width, train=False
        ).to(mask_tokens.dtype)

        # Latent tokens get their own positional embedding, then are
        # appended after the mask tokens.
        x = x + self.latent_token_positional_embedding[:seq_len]
        x = torch.cat([mask_tokens, x], dim=1)

        # Pre-norm, then seq-first layout for the transformer layers.
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # (seq_len, B, width)
        for layer in self.transformer:
            # attention_mask is intentionally not forwarded (see docstring).
            x = layer(x, attention_mask=None)
        x = x.permute(1, 0, 2)  # (B, seq_len, width)

        # Keep only the mask-token outputs (drop class + latent tokens).
        x = x[:, 1:1 + grid_height * grid_width]
        x = self.ln_post(x)

        # Tokens -> spatial grid -> per-patch pixels -> full-resolution image.
        x = x.permute(0, 2, 1).reshape(batchsize, self.width, grid_height, grid_width)
        x = self.ffn(x.contiguous())
        x = self.rearrange(x)

        # Blur + resize from the patch-grid resolution to the exact target.
        _, _, org_h, org_w = x.shape
        x = self.down_scale(x, input_size=(org_h, org_w), target_size=(height, width))
        x = self.conv_out(x)
        return x