mk322 committed on
Commit
a41cb10
·
verified ·
1 Parent(s): 546159e

Upload pixel_decoder_mae.py with huggingface_hub

Files changed (1)
  1. pixel_decoder_mae.py +235 -0
pixel_decoder_mae.py ADDED
@@ -0,0 +1,235 @@
"""
Pixel Decoder: ViT-MAE style decoder following RAE architecture.
Takes 576×embed_dim ViT features and reconstructs 384×384×3 images.
Architecture: ViT-L decoder (24 layers, hidden=1024, heads=16, intermediate=4096).
"""

import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# ─── Sincos Positional Embeddings ───────────────────────────────────────────

def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)
    grid = np.stack(grid, axis=0).reshape([2, 1, grid_size, grid_size])

    emb_h = get_1d_sincos_pos_embed(embed_dim // 2, grid[0].reshape(-1))
    emb_w = get_1d_sincos_pos_embed(embed_dim // 2, grid[1].reshape(-1))
    emb = np.concatenate([emb_h, emb_w], axis=1)

    if add_cls_token:
        emb = np.concatenate([np.zeros([1, embed_dim]), emb], axis=0)
    return emb


def get_1d_sincos_pos_embed(embed_dim, pos):
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega

    pos = pos.reshape(-1)
    out = np.einsum("m,d->md", pos, omega)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)
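

# Quick shape reference for the helpers above (illustrative values only):
#   get_1d_sincos_pos_embed(512, pos) -> [len(pos), 512], sin half then cos half.
#   get_2d_sincos_pos_embed(1024, 24, add_cls_token=True) -> [577, 1024]:
#   one zero row for the CLS token, then 576 rows concatenating 512-dim
#   height and 512-dim width embeddings over the 24×24 patch grid.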


# ─── Transformer Components ────────────────────────────────────────────────

class MAESelfAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, qkv_bias=True, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.query = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
        self.key = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
        self.value = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        self.attn_drop = attn_drop

    def forward(self, x):
        B, N, C = x.shape
        q = self.query(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = self.key(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = self.value(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop if self.training else 0.0)
        x = x.permute(0, 2, 1, 3).reshape(B, N, C)
        return self.out_proj(x)


class MAEBlock(nn.Module):
    """Standard ViT block: pre-norm self-attention + pre-norm FFN."""
    def __init__(self, hidden_size, num_heads, intermediate_size, hidden_act="gelu",
                 qkv_bias=True, layer_norm_eps=1e-6):
        super().__init__()
        self.layernorm_before = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.attention = MAESelfAttention(hidden_size, num_heads, qkv_bias=qkv_bias)
        self.layernorm_after = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.intermediate = nn.Linear(hidden_size, intermediate_size)
        self.output_proj = nn.Linear(intermediate_size, hidden_size)
        self.act_fn = nn.GELU()

    def forward(self, x):
        # Self-attention with residual
        x = x + self.attention(self.layernorm_before(x))
        # FFN with residual
        h = self.layernorm_after(x)
        h = self.act_fn(self.intermediate(h))
        x = x + self.output_proj(h)
        return x
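

# Minimal smoke test for a single block (hypothetical values, kept as a
# comment so importing this module stays side-effect free):
#   block = MAEBlock(hidden_size=1024, num_heads=16, intermediate_size=4096)
#   y = block(torch.randn(2, 577, 1024))   # -> [2, 577, 1024], shape-preserving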


# ─── Main Pixel Decoder ────────────────────────────────────────────────────

class PixelDecoderMAE(nn.Module):
    """
    ViT-MAE style pixel decoder following RAE.

    Input: [B, 576, input_dim] ViT features (or FAE-reconstructed features)
    Output: [B, 3, 384, 384] reconstructed images

    Architecture (ViT-L):
    - Linear projection: input_dim → decoder_hidden_size
    - Trainable CLS token + sincos positional embeddings
    - 24 Transformer blocks
    - LayerNorm + linear head → patch_size² × 3 per token
    - Unpatchify → full image
    """

    def __init__(self, input_dim=1152, decoder_hidden_size=1024,
                 decoder_num_layers=24, decoder_num_heads=16,
                 decoder_intermediate_size=4096, patch_size=16,
                 img_size=384, num_channels=3, layer_norm_eps=1e-6):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.grid_size = img_size // patch_size  # 24
        self.num_patches = self.grid_size ** 2  # 576

        # Project encoder features to decoder dimension + normalize
        self.decoder_embed = nn.Linear(input_dim, decoder_hidden_size)
        self.embed_norm = nn.LayerNorm(decoder_hidden_size, eps=layer_norm_eps)

        # Trainable CLS token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size))

        # Fixed sincos positional embeddings (576 patches + 1 CLS)
        pos_embed = get_2d_sincos_pos_embed(decoder_hidden_size, self.grid_size, add_cls_token=True)
        self.decoder_pos_embed = nn.Parameter(
            torch.from_numpy(pos_embed).float().unsqueeze(0),
            requires_grad=False
        )

        # Transformer decoder blocks
        self.decoder_layers = nn.ModuleList([
            MAEBlock(
                hidden_size=decoder_hidden_size,
                num_heads=decoder_num_heads,
                intermediate_size=decoder_intermediate_size,
                layer_norm_eps=layer_norm_eps,
            )
            for _ in range(decoder_num_layers)
        ])

        self.decoder_norm = nn.LayerNorm(decoder_hidden_size, eps=layer_norm_eps)

        # Prediction head: project to pixel patches
        self.decoder_pred = nn.Linear(
            decoder_hidden_size, patch_size ** 2 * num_channels
        )

        self._init_weights()

    def _init_weights(self):
        nn.init.normal_(self.cls_token, std=0.02)
        # Initialize decoder_embed like a linear layer
        nn.init.xavier_uniform_(self.decoder_embed.weight)
        if self.decoder_embed.bias is not None:
            nn.init.zeros_(self.decoder_embed.bias)
        # Initialize decoder_pred
        nn.init.xavier_uniform_(self.decoder_pred.weight)
        if self.decoder_pred.bias is not None:
            nn.init.zeros_(self.decoder_pred.bias)

    def unpatchify(self, x):
        """
        x: [B, num_patches, patch_size²×3]
        Returns: [B, 3, H, W]
        """
        p = self.patch_size
        h = w = self.grid_size
        c = self.num_channels

        x = x.reshape(-1, h, w, p, p, c)
        x = torch.einsum("nhwpqc->nchpwq", x)
        return x.reshape(-1, c, h * p, w * p)
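
    # Shape walk-through for unpatchify with the defaults (illustrative only):
    #   [B, 576, 16*16*3] --reshape--> [B, 24, 24, 16, 16, 3]
    #   --einsum "nhwpqc->nchpwq"--> [B, 3, 24, 16, 24, 16]
    #   --reshape--> [B, 3, 384, 384]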

    def forward(self, features, noise_tau=0.0):
        """
        Args:
            features: [B, 576, input_dim] ViT features
            noise_tau: max noise level applied AFTER normalization (where std≈1)
        Returns:
            images: [B, 3, 384, 384] reconstructed images in [-1, 1]
        """
        # Project to decoder dimension and normalize
        x = self.embed_norm(self.decoder_embed(features))  # [B, 576, decoder_hidden]

        # Add noise after normalization (features now have std≈1, so tau=0.8 is meaningful)
        if noise_tau > 0 and self.training:
            noise_sigma = noise_tau * torch.rand(
                (x.size(0),) + (1,) * (len(x.shape) - 1), device=x.device
            )
            x = x + noise_sigma * torch.randn_like(x)

        # Prepend CLS token
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # [B, 577, decoder_hidden]

        # Add positional embeddings
        x = x + self.decoder_pos_embed

        # Transformer blocks
        for layer in self.decoder_layers:
            x = layer(x)

        x = self.decoder_norm(x)

        # Predict pixel patches (remove CLS token)
        x = self.decoder_pred(x[:, 1:, :])  # [B, 576, patch_size²×3]

        # Unpatchify to full image
        img = self.unpatchify(x)  # [B, 3, 384, 384]

        return img
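

# Sketch of intended usage, inferred from the docstrings above (values are
# illustrative; kept as a comment so the module has no import-time side effects):
#   decoder = PixelDecoderMAE(input_dim=1152)
#   imgs = decoder(torch.randn(2, 576, 1152))   # -> [2, 3, 384, 384]
#   # In training mode, noise_tau > 0 perturbs the normalized tokens with a
#   # per-sample sigma drawn uniformly from [0, noise_tau]:
#   imgs_noisy = decoder.train()(torch.randn(2, 576, 1152), noise_tau=0.8)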


class PatchGANDiscriminator(nn.Module):
    """PatchGAN discriminator for adversarial loss."""

    def __init__(self, in_channels=3, ndf=64):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, ndf, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, stride=2, padding=1),
            nn.InstanceNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, stride=2, padding=1),
            nn.InstanceNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, 4, stride=1, padding=1),
            nn.InstanceNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 8, 1, 4, stride=1, padding=1),
        )

    def forward(self, x):
        return self.model(x)
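

if __name__ == "__main__":
    # Hypothetical smoke test (illustrative usage only): run random inputs
    # through both modules and confirm the documented output shapes.
    decoder = PixelDecoderMAE()
    disc = PatchGANDiscriminator()
    feats = torch.randn(1, 576, 1152)
    with torch.no_grad():
        img = decoder(feats)   # expect [1, 3, 384, 384]
        logits = disc(img)     # per-patch real/fake logits, e.g. [1, 1, 46, 46]
    print(img.shape, logits.shape)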