File size: 9,032 Bytes
7bef20f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
"""Resolution-aware encoder for VibeToken.

Vision Transformer-based encoder with flexible patch sizes for variable-resolution inputs.
"""

from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, vmap

from .blocks import ResidualAttentionBlock, _expand_token
from .embeddings import FuzzyEmbedding, to_2tuple


class ResolutionEncoder(nn.Module):
    """Vision Transformer encoder with flexible resolution support.

    Encodes images into latent tokens using a ViT architecture with
    support for variable input resolutions and patch sizes. Non-default
    patch sizes are handled by resampling the patch-embedding kernel with
    a pseudo-inverse resize matrix (FlexiViT-style).
    """

    # Transformer width / depth / head-count presets.
    MODEL_CONFIGS = {
        "small": {"width": 512, "num_layers": 8, "num_heads": 8},
        "base": {"width": 768, "num_layers": 12, "num_heads": 12},
        "large": {"width": 1024, "num_layers": 24, "num_heads": 16},
    }

    def __init__(self, config):
        """Initialize ResolutionEncoder.

        Args:
            config: OmegaConf-style config. Encoder settings are read from
                ``config.model.vq_model`` when present, otherwise from
                ``config.model``; every field falls back to a default.

        Raises:
            KeyError: If ``vit_enc_model_size`` is not one of
                ``MODEL_CONFIGS`` ("small", "base", "large").
        """
        super().__init__()
        self.config = config

        # Extract config values, tolerating both nested and flat layouts.
        vq_config = config.model.vq_model if hasattr(config.model, 'vq_model') else config.model
        self.patch_size = getattr(vq_config, 'vit_enc_patch_size', 32)
        self.model_size = getattr(vq_config, 'vit_enc_model_size', 'large')
        self.num_latent_tokens = getattr(vq_config, 'num_latent_tokens', 256)
        self.token_size = getattr(vq_config, 'token_size', 256)
        self.is_legacy = getattr(vq_config, 'is_legacy', False)

        # VAE mode emits mean and std per token, so channels double.
        quantize_mode = getattr(vq_config, 'quantize_mode', 'vq')
        if quantize_mode == "vae":
            self.token_size = self.token_size * 2

        # Resolve transformer dimensions from the size preset.
        model_cfg = self.MODEL_CONFIGS[self.model_size]
        self.width = model_cfg["width"]
        self.num_layers = model_cfg["num_layers"]
        self.num_heads = model_cfg["num_heads"]

        # Patch embedding: non-overlapping patches via strided conv.
        self.patch_embed = nn.Conv2d(
            in_channels=3, out_channels=self.width,
            kernel_size=self.patch_size, stride=self.patch_size, bias=True
        )

        # CLIP-style scaled random init for the embeddings.
        scale = self.width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
        self.positional_embedding = FuzzyEmbedding(1024, scale, self.width)
        self.latent_token_positional_embedding = nn.Parameter(
            scale * torch.randn(self.num_latent_tokens, self.width)
        )
        self.ln_pre = nn.LayerNorm(self.width)

        # Transformer trunk.
        self.transformer = nn.ModuleList([
            ResidualAttentionBlock(self.width, self.num_heads, mlp_ratio=4.0)
            for _ in range(self.num_layers)
        ])

        # Output head: post-norm + 1x1 conv projection to token_size.
        self.ln_post = nn.LayerNorm(self.width)
        self.conv_out = nn.Conv2d(self.width, self.token_size, kernel_size=1, bias=True)

        # Cache of pseudo-inverse resize matrices, keyed by
        # (new_patch_size, device) so a tensor cached before a device move
        # (e.g. .to("cuda")) is never reused on the wrong device.
        self.pinvs = {}

    def _resize(self, x: Tensor, shape: Tuple[int, int]) -> Tensor:
        """Bilinearly resize a 2D tensor to ``shape`` (no antialiasing)."""
        x_resized = F.interpolate(
            x[None, None, ...], shape, mode="bilinear", antialias=False
        )
        return x_resized[0, 0, ...]

    def _calculate_pinv(
        self,
        old_shape: Tuple[int, int],
        new_shape: Tuple[int, int],
        device: torch.device,
    ) -> Tensor:
        """Compute the pseudo-inverse of the bilinear-resize matrix.

        Materializes the linear map that bilinear resizing applies to a
        flattened ``old_shape`` grid (one basis vector per row), then
        returns its Moore-Penrose pseudo-inverse. Used to resample
        patch-embedding kernels (FlexiViT).

        Args:
            old_shape: Source (H, W) grid shape.
            new_shape: Target (H, W) grid shape.
            device: Device on which to build the matrices.

        Returns:
            Tensor of shape (prod(new_shape), prod(old_shape)).
        """
        mat = []
        for i in range(np.prod(old_shape)):
            basis_vec = torch.zeros(old_shape, device=device)
            basis_vec[np.unravel_index(i, old_shape)] = 1.0
            mat.append(self._resize(basis_vec, new_shape).reshape(-1))
        resize_matrix = torch.stack(mat)
        return torch.linalg.pinv(resize_matrix)

    def resize_patch_embed(self, patch_embed: Tensor, new_patch_size: Tuple[int, int]) -> Tensor:
        """Resize the patch embedding kernel to a new patch size (FlexiViT).

        Args:
            patch_embed: Original weight tensor (out_ch, in_ch, H, W).
            new_patch_size: Target (H, W) patch size.

        Returns:
            Weight tensor of shape (out_ch, in_ch, *new_patch_size); the
            input is returned unchanged when the size already matches.
        """
        base_size = to_2tuple(self.patch_size)
        if base_size == new_patch_size:
            return patch_embed

        # Cache per (size, device): the module may be moved between
        # devices across calls, and a stale-device pinv would make the
        # matmul below fail with a device mismatch.
        cache_key = (new_patch_size, patch_embed.device)
        if cache_key not in self.pinvs:
            self.pinvs[cache_key] = self._calculate_pinv(
                base_size, new_patch_size, device=patch_embed.device
            )
        pinv = self.pinvs[cache_key]

        def resample_patch_embed(pe: Tensor) -> Tensor:
            # Resample one 2D kernel; compute in float32 for a stable
            # matmul, then restore the original dtype.
            h, w = new_patch_size
            original_dtype = pe.dtype
            resampled = pinv @ pe.float().reshape(-1)
            # Plain reshape replaces the equivalent einops
            # rearrange("(h w) -> h w").
            return resampled.to(original_dtype).reshape(h, w)

        # vmap over the out-channel (dim 0) and in-channel (dim 1) axes.
        v_resample = vmap(vmap(resample_patch_embed, 0, 0), 1, 1)
        return v_resample(patch_embed)

    def apply_flexivit_patch_embed(self, x: Tensor, target_patch_size: Tuple[int, int]) -> Tensor:
        """Apply patch embedding with a flexible patch size.

        Args:
            x: Input image tensor (B, 3, H, W).
            target_patch_size: Target patch size (H, W) or a single int.

        Returns:
            Patch embeddings (B, width, H // patch_h, W // patch_w).
        """
        patch_size = to_2tuple(target_patch_size)

        if patch_size == to_2tuple(self.patch_size):
            # Native patch size: use the trained kernel directly.
            weight = self.patch_embed.weight
        else:
            # Resample the kernel to the requested patch size.
            weight = self.resize_patch_embed(self.patch_embed.weight, patch_size)

        return F.conv2d(x, weight, bias=self.patch_embed.bias, stride=patch_size)

    def forward(
        self,
        pixel_values: torch.Tensor,
        latent_tokens: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encode_patch_size: Optional[Union[int, Tuple[int, int]]] = None,
    ) -> torch.Tensor:
        """Encode images to latent tokens.

        Args:
            pixel_values: Input images (B, 3, H, W), values in [0, 1].
            latent_tokens: Learnable latent tokens (num_latent, width).
            attention_mask: Accepted for interface compatibility but NOT
                used; the transformer always runs unmasked.
            encode_patch_size: Optional custom patch size (int or (H, W));
                defaults to the configured ``patch_size``.

        Returns:
            Encoded latent features (B, token_size, 1, num_latent).
        """
        batch_size, _, H, W = pixel_values.shape

        # Normalize the requested patch size to an (H, W) tuple.
        if encode_patch_size is None:
            target_patch_size = (self.patch_size, self.patch_size)
        elif isinstance(encode_patch_size, int):
            target_patch_size = (encode_patch_size, encode_patch_size)
        else:
            target_patch_size = encode_patch_size

        # Patchify: (B, width, grid_H, grid_W).
        x = self.apply_flexivit_patch_embed(pixel_values, target_patch_size)

        # Flatten spatial dims -> (B, num_patches, width).
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)

        # Prepend the class token.
        x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)

        # Grid dimensions of the patchified image.
        grid_height = H // target_patch_size[0]
        grid_width = W // target_patch_size[1]

        # Broadcast latent tokens across the batch and add their
        # positional embedding (truncated to the provided token count).
        num_latent = latent_tokens.shape[0]
        latent_tokens = _expand_token(latent_tokens, x.shape[0]).to(x.dtype)
        latent_tokens = latent_tokens + self.latent_token_positional_embedding.to(x.dtype)[:num_latent]

        # Resolution-aware positional embedding for [cls] + patches.
        x = x + self.positional_embedding(grid_height, grid_width, train=False, dtype=x.dtype)

        # Full sequence: [cls] + patches + latent tokens.
        x = torch.cat([x, latent_tokens], dim=1)

        # Pre-norm, then (seq_len, B, width) layout for the blocks.
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)

        # NOTE(review): the incoming attention_mask is intentionally
        # dropped here; every block receives attention_mask=None.
        for layer in self.transformer:
            x = layer(x, attention_mask=None)

        x = x.permute(1, 0, 2)  # back to (B, seq_len, width)

        # Keep only the trailing latent tokens ([cls] + patches precede).
        latent_tokens = x[:, 1 + grid_height * grid_width:]
        latent_tokens = self.ln_post(latent_tokens)

        # Arrange as (B, width, num_latent, 1) for the 1x1 conv. The
        # legacy path uses a raw reshape (which interleaves channels
        # differently from the permute path) and is preserved as-is for
        # checkpoint compatibility.
        if self.is_legacy:
            latent_tokens = latent_tokens.reshape(batch_size, self.width, num_latent, 1)
        else:
            latent_tokens = latent_tokens.reshape(batch_size, num_latent, self.width, 1).permute(0, 2, 1, 3)

        latent_tokens = self.conv_out(latent_tokens)
        latent_tokens = latent_tokens.reshape(batch_size, self.token_size, 1, num_latent)

        return latent_tokens