"""
Pixel Decoder: ViT-MAE style decoder following the RAE architecture.

Takes 576×embed_dim ViT features and reconstructs 384×384×3 images
(with the default configuration).

Architecture: ViT-L decoder (24 layers, hidden=1024, heads=16, intermediate=4096).
"""

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# ─── Sincos Positional Embeddings ───────────────────────────────────────────


def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
    """Build fixed 2-D sine/cosine positional embeddings for a square grid.

    The first ``embed_dim // 2`` channels encode one grid axis and the second
    half encodes the other, each via :func:`get_1d_sincos_pos_embed`.

    Args:
        embed_dim: Embedding width. Must be divisible by 4, because each
            axis gets ``embed_dim // 2`` channels, which are split again into
            sin and cos halves.
        grid_size: Number of patches along each side of the square grid.
        add_cls_token: If True, prepend an all-zero row for a CLS token.

    Returns:
        ``np.ndarray`` of shape ``(grid_size**2, embed_dim)``, or
        ``(1 + grid_size**2, embed_dim)`` when ``add_cls_token`` is True.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by 4.
    """
    if embed_dim % 4 != 0:
        # Without this check the result is silently narrower than embed_dim
        # and only fails later, when it is added to the token sequence.
        raise ValueError(f"embed_dim must be divisible by 4, got {embed_dim}")

    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    # np.meshgrid defaults to "xy" indexing, so grid[0] varies along the
    # column (w) axis and grid[1] along the row (h) axis. The emb_h/emb_w
    # names below therefore do not literally match the axis they encode;
    # the channel layout is left unchanged so trained decoders see
    # identical embeddings.
    grid = np.meshgrid(grid_w, grid_h)
    grid = np.stack(grid, axis=0).reshape([2, 1, grid_size, grid_size])

    emb_h = get_1d_sincos_pos_embed(embed_dim // 2, grid[0].reshape(-1))
    emb_w = get_1d_sincos_pos_embed(embed_dim // 2, grid[1].reshape(-1))
    emb = np.concatenate([emb_h, emb_w], axis=1)  # (grid_size**2, embed_dim)

    if add_cls_token:
        emb = np.concatenate([np.zeros([1, embed_dim]), emb], axis=0)
    return emb


def get_1d_sincos_pos_embed(embed_dim, pos):
    """Encode 1-D positions with sine/cosine features.

    Args:
        embed_dim: Output width D; must be even (half sin, half cos).
        pos: Array of positions; it is flattened to shape ``(M,)``.

    Returns:
        ``np.ndarray`` of shape ``(M, D)`` laid out as
        ``[sin(pos * w), cos(pos * w)]`` with frequencies
        ``w_d = 1 / 10000**(2 * d / D)`` for ``d = 0 .. D/2 - 1``.

    Raises:
        ValueError: If ``embed_dim`` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError(f"embed_dim must be even, got {embed_dim}")

    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2) outer product
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)


# ─── Transformer Components ────────────────────────────────────────────────


class MAESelfAttention(nn.Module):
    """Multi-head self-attention with separate Q/K/V projections.

    Args:
        hidden_size: Model width C; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the Q/K/V projections use a bias.
        attn_drop: Dropout probability on attention weights (training only).
        proj_drop: Dropout probability on the output projection
            (training only).

    Raises:
        ValueError: If ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_size, num_heads, qkv_bias=True, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.query = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
        self.key = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
        self.value = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        self.attn_drop = attn_drop
        self.proj_drop = proj_drop

    def _split_heads(self, t, B, N):
        """Reshape (B, N, C) -> (B, num_heads, N, head_dim)."""
        return t.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, N, C); returns (B, N, C)."""
        B, N, C = x.shape
        q = self._split_heads(self.query(x), B, N)
        k = self._split_heads(self.key(x), B, N)
        v = self._split_heads(self.value(x), B, N)

        x = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.attn_drop if self.training else 0.0
        )
        x = x.permute(0, 2, 1, 3).reshape(B, N, C)  # merge heads back
        x = self.out_proj(x)
        # Guarded so the default p=0 path is a true no-op (no RNG consumed).
        if self.proj_drop > 0 and self.training:
            x = F.dropout(x, p=self.proj_drop, training=True)
        return x


class MAEBlock(nn.Module):
    """Standard ViT block: pre-norm self-attention + pre-norm FFN.

    Args:
        hidden_size: Model width.
        num_heads: Number of attention heads.
        intermediate_size: Hidden width of the feed-forward network.
        hidden_act: Accepted for config compatibility but currently unused;
            the activation is always ``nn.GELU()``.
        qkv_bias: Whether the attention Q/K/V projections use a bias.
        layer_norm_eps: Epsilon for both LayerNorms.
    """

    def __init__(self, hidden_size, num_heads, intermediate_size,
                 hidden_act="gelu", qkv_bias=True, layer_norm_eps=1e-6):
        super().__init__()
        self.layernorm_before = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.attention = MAESelfAttention(hidden_size, num_heads, qkv_bias=qkv_bias)
        self.layernorm_after = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.intermediate = nn.Linear(hidden_size, intermediate_size)
        self.output_proj = nn.Linear(intermediate_size, hidden_size)
        self.act_fn = nn.GELU()  # NOTE: hidden_act is not consulted

    def forward(self, x):
        # Self-attention with residual
        x = x + self.attention(self.layernorm_before(x))
        # FFN with residual
        h = self.layernorm_after(x)
        h = self.act_fn(self.intermediate(h))
        x = x + self.output_proj(h)
        return x


# ─── Main Pixel Decoder ────────────────────────────────────────────────────


class PixelDecoderMAE(nn.Module):
    """
    ViT-MAE style pixel decoder following RAE.

    Input:  [B, num_patches, input_dim] ViT features (or FAE-reconstructed
            features); num_patches = (img_size // patch_size)**2, i.e. 576
            with the defaults.
    Output: [B, num_channels, img_size, img_size] images (3×384×384 with the
            defaults). No output activation is applied, so values are not
            constrained to any fixed range.

    Architecture (ViT-L with the defaults):
      - Linear projection: input_dim → decoder_hidden_size, then LayerNorm
      - Trainable CLS token + fixed (non-trainable) sincos positional embeddings
      - decoder_num_layers Transformer blocks (24 by default)
      - LayerNorm + linear head → patch_size² × num_channels per token
      - Unpatchify → full image
    """

    def __init__(self, input_dim=1152, decoder_hidden_size=1024,
                 decoder_num_layers=24, decoder_num_heads=16,
                 decoder_intermediate_size=4096, patch_size=16,
                 img_size=384, num_channels=3, layer_norm_eps=1e-6):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.grid_size = img_size // patch_size  # 24 with the defaults
        self.num_patches = self.grid_size ** 2  # 576 with the defaults

        # Project encoder features to decoder dimension + normalize
        self.decoder_embed = nn.Linear(input_dim, decoder_hidden_size)
        self.embed_norm = nn.LayerNorm(decoder_hidden_size, eps=layer_norm_eps)

        # Trainable CLS token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size))

        # Fixed sincos positional embeddings (num_patches + 1 CLS row);
        # stored as a frozen Parameter so it appears in the state_dict.
        pos_embed = get_2d_sincos_pos_embed(
            decoder_hidden_size, self.grid_size, add_cls_token=True
        )
        self.decoder_pos_embed = nn.Parameter(
            torch.from_numpy(pos_embed).float().unsqueeze(0), requires_grad=False
        )

        # Transformer decoder blocks
        self.decoder_layers = nn.ModuleList([
            MAEBlock(
                hidden_size=decoder_hidden_size,
                num_heads=decoder_num_heads,
                intermediate_size=decoder_intermediate_size,
                layer_norm_eps=layer_norm_eps,
            )
            for _ in range(decoder_num_layers)
        ])
        self.decoder_norm = nn.LayerNorm(decoder_hidden_size, eps=layer_norm_eps)

        # Prediction head: project to pixel patches
        self.decoder_pred = nn.Linear(
            decoder_hidden_size, patch_size ** 2 * num_channels
        )

        self._init_weights()

    def _init_weights(self):
        """Initialize CLS token, input projection and prediction head.

        Other submodules (attention/FFN linears, LayerNorms) keep PyTorch's
        default initialization.
        """
        nn.init.normal_(self.cls_token, std=0.02)
        # Initialize decoder_embed like a linear layer
        nn.init.xavier_uniform_(self.decoder_embed.weight)
        if self.decoder_embed.bias is not None:
            nn.init.zeros_(self.decoder_embed.bias)
        # Initialize decoder_pred
        nn.init.xavier_uniform_(self.decoder_pred.weight)
        if self.decoder_pred.bias is not None:
            nn.init.zeros_(self.decoder_pred.bias)

    def unpatchify(self, x):
        """Reassemble per-patch pixel predictions into images.

        Each token's vector is read as a (p, p, c) patch; tokens are laid out
        row-major over the (grid_size, grid_size) grid.

        Args:
            x: [B, num_patches, patch_size² × num_channels]

        Returns:
            [B, num_channels, grid_size * patch_size, grid_size * patch_size]
        """
        p = self.patch_size
        h = w = self.grid_size
        c = self.num_channels
        x = x.reshape(-1, h, w, p, p, c)
        x = torch.einsum("nhwpqc->nchpwq", x)
        return x.reshape(-1, c, h * p, w * p)

    def forward(self, features, noise_tau=0.0):
        """
        Args:
            features: [B, num_patches, input_dim] ViT features (576 tokens
                with the defaults).
            noise_tau: Maximum noise level, applied AFTER normalization
                (where std≈1) and only in training mode. Each sample gets
                sigma ~ U(0, noise_tau) of Gaussian noise.

        Returns:
            images: [B, num_channels, img_size, img_size] reconstructed
                images. The output is not clamped or squashed; presumably the
                decoder is trained against targets in [-1, 1] — TODO confirm
                against the training code.
        """
        # Project to decoder dimension and normalize
        x = self.embed_norm(self.decoder_embed(features))  # [B, N, decoder_hidden]

        # Add noise after normalization (features now have std≈1, so tau=0.8
        # is meaningful). The per-sample scale is created in x's dtype so
        # reduced-precision (bf16/fp16) runs are not silently upcast.
        if noise_tau > 0 and self.training:
            noise_sigma = noise_tau * torch.rand(
                (x.size(0),) + (1,) * (x.dim() - 1),
                device=x.device,
                dtype=x.dtype,
            )
            x = x + noise_sigma * torch.randn_like(x)

        # Prepend CLS token
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # [B, N + 1, decoder_hidden]

        # Add positional embeddings
        x = x + self.decoder_pos_embed

        # Transformer blocks
        for layer in self.decoder_layers:
            x = layer(x)
        x = self.decoder_norm(x)

        # Predict pixel patches (remove CLS token)
        x = self.decoder_pred(x[:, 1:, :])  # [B, N, patch_size² × num_channels]

        # Unpatchify to full image
        img = self.unpatchify(x)  # [B, num_channels, img_size, img_size]
        return img


class PatchGANDiscriminator(nn.Module):
    """PatchGAN discriminator for adversarial loss.

    Three stride-2 and two stride-1 4×4 convolutions with LeakyReLU(0.2) and
    InstanceNorm on the inner layers. The final conv has no activation, so
    the output is a 1-channel map of raw logits (e.g. a 384×384 input yields
    a [B, 1, 46, 46] map).

    Args:
        in_channels: Number of input image channels.
        ndf: Base number of filters (doubled at each downsampling stage).
    """

    def __init__(self, in_channels=3, ndf=64):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, ndf, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, stride=2, padding=1),
            nn.InstanceNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, stride=2, padding=1),
            nn.InstanceNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, 4, stride=1, padding=1),
            nn.InstanceNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 8, 1, 4, stride=1, padding=1),
        )

    def forward(self, x):
        return self.model(x)