File size: 7,718 Bytes
7bef20f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
"""Resolution-aware decoder for VibeToken.

Vision Transformer-based decoder with flexible output resolutions.
"""

import torch
import torch.nn as nn
from typing import Optional, Tuple
from einops.layers.torch import Rearrange

from .blocks import ResidualAttentionBlock, ResizableBlur, _expand_token
from .embeddings import FuzzyEmbedding


class ResolutionDecoder(nn.Module):
    """Vision Transformer decoder with flexible resolution support.

    Decodes latent tokens back to images with support for variable output
    resolutions. The transformer grid's patch size is chosen per call
    (see ``_select_patch_size``), while the pixel-projection head always
    emits ``vit_dec_patch_size``-sized patches; a ``ResizableBlur`` then
    resamples the result to the requested (height, width).
    """

    # Transformer dimensions for each model-size preset.
    MODEL_CONFIGS = {
        "small": {"width": 512, "num_layers": 8, "num_heads": 8},
        "base": {"width": 768, "num_layers": 12, "num_heads": 12},
        "large": {"width": 1024, "num_layers": 24, "num_heads": 16},
    }

    def __init__(self, config):
        """Initialize ResolutionDecoder.

        Args:
            config: OmegaConf-style config. Decoder settings are read from
                ``config.model.vq_model`` when present, else ``config.model``;
                the nominal output resolution comes from
                ``config.dataset.preprocessing.crop_size`` (default 512).

        Raises:
            NotImplementedError: If ``is_legacy`` is set — legacy mode is not
                supported in this inference-only version.
            KeyError: If ``vit_dec_model_size`` is not a key of MODEL_CONFIGS.
        """
        super().__init__()
        self.config = config

        # Decoder settings may live under model.vq_model or directly on model.
        vq_config = config.model.vq_model if hasattr(config.model, 'vq_model') else config.model
        self.image_size = getattr(config.dataset.preprocessing, 'crop_size', 512) if hasattr(config, 'dataset') else 512
        self.patch_size = getattr(vq_config, 'vit_dec_patch_size', 32)
        self.model_size = getattr(vq_config, 'vit_dec_model_size', 'large')
        self.num_latent_tokens = getattr(vq_config, 'num_latent_tokens', 256)
        self.token_size = getattr(vq_config, 'token_size', 256)
        self.is_legacy = getattr(vq_config, 'is_legacy', False)

        if self.is_legacy:
            raise NotImplementedError("Legacy mode is not supported in this inference-only version")

        # Transformer dimensions for the chosen preset.
        model_cfg = self.MODEL_CONFIGS[self.model_size]
        self.width = model_cfg["width"]
        self.num_layers = model_cfg["num_layers"]
        self.num_heads = model_cfg["num_heads"]

        # Project quantized latent features into the transformer width.
        self.decoder_embed = nn.Linear(self.token_size, self.width, bias=True)

        # Learned embeddings, initialized at 1/sqrt(width) scale (ViT-style).
        scale = self.width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
        self.positional_embedding = FuzzyEmbedding(1024, scale, self.width)
        self.mask_token = nn.Parameter(scale * torch.randn(1, 1, self.width))
        self.latent_token_positional_embedding = nn.Parameter(
            scale * torch.randn(self.num_latent_tokens, self.width)
        )
        self.ln_pre = nn.LayerNorm(self.width)

        # Transformer trunk.
        self.transformer = nn.ModuleList([
            ResidualAttentionBlock(self.width, self.num_heads, mlp_ratio=4.0)
            for _ in range(self.num_layers)
        ])

        # Output head: token features -> per-token pixel patches -> image,
        # then blur-resample to the target size and refine with a 3x3 conv.
        self.ln_post = nn.LayerNorm(self.width)
        self.ffn = nn.Conv2d(
            self.width, self.patch_size * self.patch_size * 3,
            kernel_size=1, padding=0, bias=True
        )
        self.rearrange = Rearrange(
            'b (p1 p2 c) h w -> b c (h p1) (w p2)',
            p1=self.patch_size, p2=self.patch_size
        )
        self.down_scale = ResizableBlur(channels=3, max_kernel_size=9, init_type="lanczos")
        self.conv_out = nn.Conv2d(3, 3, 3, padding=1, bias=True)

    def _select_patch_size(self, height: int, width: int) -> int:
        """Select a patch size whose grid holds 256-1024 patches.

        Candidates 8/16/32 are tried smallest-first; the first whose
        (height // ps) * (width // ps) grid falls inside [256, 1024] wins.
        If none qualifies, the candidate whose patch count is closest to the
        nearer boundary of that range is returned (stable sort keeps the
        smallest patch size on ties).

        Args:
            height: Target image height.
            width: Target image width.

        Returns:
            Selected patch size (8, 16, or 32).
        """
        min_patches, max_patches = 256, 1024
        candidates = (8, 16, 32)

        possible_sizes = [
            ps for ps in candidates
            if min_patches <= (height // ps) * (width // ps) <= max_patches
        ]
        if possible_sizes:
            return possible_sizes[0]

        # No candidate fits the range: pick the one whose patch count is
        # closest to the nearer of the two range boundaries.
        patch_counts = [(ps, (height // ps) * (width // ps)) for ps in candidates]
        patch_counts.sort(key=lambda x: min(abs(x[1] - min_patches), abs(x[1] - max_patches)))
        return patch_counts[0][0]

    def forward(
        self,
        z_quantized: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        decode_patch_size: Optional[int] = None,
    ) -> torch.Tensor:
        """Decode latent tokens to images.

        Args:
            z_quantized: Quantized latent features (B, C, H, W); flattened so
                that C*H is the token feature dim — assumes C*H == token_size
                (the ``decoder_embed`` input width). TODO confirm with encoder.
            attention_mask: Accepted for interface compatibility but currently
                ignored — the transformer layers are always called with
                ``attention_mask=None``.
            height: Target image height (defaults to configured image_size).
            width: Target image width (defaults to configured image_size).
            decode_patch_size: Optional explicit patch size (int or pair);
                when None it is chosen by ``_select_patch_size``.

        Returns:
            Decoded images of shape (B, 3, height, width). No output
            activation is applied here, so the value range depends on the
            trained weights (callers reportedly expect [0, 1] — verify).
        """
        N, C, H, W = z_quantized.shape

        # Flatten spatial latents into a token sequence and project:
        # (N, C, H, W) -> (N, W, C*H) -> (N, W, width).
        x = z_quantized.reshape(N, C * H, W).permute(0, 2, 1)
        x = self.decoder_embed(x)

        batchsize, seq_len, _ = x.shape

        # Default output size from config.
        if height is None:
            height = self.image_size
        if width is None:
            width = self.image_size

        # Determine the patch size used for the transformer's output grid.
        if decode_patch_size is None:
            selected_patch_size = self._select_patch_size(height, width)
        else:
            selected_patch_size = decode_patch_size

        if isinstance(selected_patch_size, int):
            selected_patch_size = (selected_patch_size, selected_patch_size)

        grid_height = height // selected_patch_size[0]
        grid_width = width // selected_patch_size[1]

        # One mask token per output grid position, preceded by a class token.
        mask_tokens = self.mask_token.repeat(batchsize, grid_height * grid_width, 1).to(x.dtype)
        mask_tokens = torch.cat([
            _expand_token(self.class_embedding, mask_tokens.shape[0]).to(mask_tokens.dtype),
            mask_tokens
        ], dim=1)

        # Resolution-dependent positional embeddings for the output grid.
        mask_tokens = mask_tokens + self.positional_embedding(
            grid_height, grid_width, train=False
        ).to(mask_tokens.dtype)

        # Latent tokens get their own learned positions; sequence layout is
        # [class | grid mask tokens | latent tokens].
        x = x + self.latent_token_positional_embedding[:seq_len]
        x = torch.cat([mask_tokens, x], dim=1)

        # Pre-norm, then (seq_len, B, width) layout for the attention blocks.
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)

        # NOTE(review): attention_mask is intentionally not forwarded here.
        for layer in self.transformer:
            x = layer(x, attention_mask=None)

        x = x.permute(1, 0, 2)  # back to (B, seq_len, width)

        # Keep only the grid tokens (drop class token and latent tokens).
        x = x[:, 1:1 + grid_height * grid_width]
        x = self.ln_post(x)

        # Tokens -> feature map -> pixel patches. Note: self.rearrange uses the
        # configured patch_size, not selected_patch_size, so the intermediate
        # image is (grid_h * patch_size, grid_w * patch_size).
        x = x.permute(0, 2, 1).reshape(batchsize, self.width, grid_height, grid_width)
        x = self.ffn(x.contiguous())
        x = self.rearrange(x)

        # Blur-resample to the exact requested resolution, then refine.
        _, _, org_h, org_w = x.shape
        x = self.down_scale(x, input_size=(org_h, org_w), target_size=(height, width))
        x = self.conv_out(x)

        return x