mk322 committed on
Commit
a4e88c8
·
verified ·
1 Parent(s): 8e9dca4

Upload fae_spatial.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. fae_spatial.py +205 -0
fae_spatial.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FAE with CNN spatial pooling for token reduction.
2
+
3
+ Encoder: CNN downsample (24×24 → H'×W') + self-attention + project to latent_dim
4
+ Decoder: project up + ViT layers at compressed resolution + CNN upsample (H'×W' → 24×24)
5
+
6
+ pool_factor=2: 576 → 144 tokens (s2)
7
+ pool_factor=4: 576 → 36 tokens (s4)
8
+ """
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import math
14
+ import sys, os
15
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
16
+ from utils import RMSNorm
17
+ from models.feature_decoder import RotaryPositionalEmbedding2D, ViTDecoderBlock
18
+
19
+
20
class CNNDownsample(nn.Module):
    """Reduce spatial resolution with a stack of strided convolutions.

    Builds log2(pool_factor) stages; each stage is a 3x3 stride-2 conv
    followed by GELU, halving H and W while keeping channels fixed.
    """

    def __init__(self, dim, pool_factor):
        super().__init__()
        assert pool_factor in (2, 4), f"pool_factor must be 2 or 4, got {pool_factor}"
        stages = []
        for _ in range(int(math.log2(pool_factor))):
            stages.append(nn.Conv2d(dim, dim, kernel_size=3, stride=2, padding=1))
            stages.append(nn.GELU())
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """x: [B, C, H, W] → [B, C, H/pool_factor, W/pool_factor]"""
        return self.net(x)
40
+
41
+
42
class CNNUpsample(nn.Module):
    """Increase spatial resolution with a stack of transposed convolutions.

    Builds log2(pool_factor) stages; each stage is a 4x4 stride-2
    transposed conv followed by GELU, doubling H and W while keeping
    channels fixed.
    """

    def __init__(self, dim, pool_factor):
        super().__init__()
        assert pool_factor in (2, 4), f"pool_factor must be 2 or 4, got {pool_factor}"
        stages = []
        for _ in range(int(math.log2(pool_factor))):
            stages.append(nn.ConvTranspose2d(dim, dim, kernel_size=4, stride=2, padding=1))
            stages.append(nn.GELU())
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """x: [B, C, H', W'] → [B, C, H'*pool_factor, W'*pool_factor]"""
        return self.net(x)
62
+
63
+
64
class FAESpatialEncoder(nn.Module):
    """FAE encoder that compresses tokens spatially before projecting to latent.

    Pipeline: reshape tokens to a 2D grid → CNN downsample → one pre-norm
    self-attention block + SwiGLU FFN at the compressed resolution → linear
    projection to latent_dim → optional VAE heads (mu / logvar) with
    reparameterized sampling during training.

    Input:  [B, grid_size^2, embed_dim]  (e.g. [B, 576, 1152])
    Output: [B, (grid_size/pool_factor)^2, latent_dim]
    """

    def __init__(self, embed_dim=1152, latent_dim=32, num_heads=16,
                 pool_factor=2, grid_size=24, use_vae=True):
        super().__init__()
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.pool_factor = pool_factor
        self.grid_size = grid_size
        self.compressed_grid = grid_size // pool_factor
        self.use_vae = use_vae

        # Spatial token reduction via strided convs.
        self.downsample = CNNDownsample(embed_dim, pool_factor)

        # Pre-norm self-attention over the compressed token sequence.
        self.norm1 = RMSNorm(embed_dim)
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

        # SwiGLU feed-forward (gated: w2(silu(w1(x)) * w3(x))).
        self.norm2 = RMSNorm(embed_dim)
        ffn_dim = int(embed_dim * 2.7)
        self.w1 = nn.Linear(embed_dim, ffn_dim, bias=False)
        self.w2 = nn.Linear(ffn_dim, embed_dim, bias=False)
        self.w3 = nn.Linear(embed_dim, ffn_dim, bias=False)

        # Per-token latent projection.
        self.proj = nn.Linear(embed_dim, latent_dim)

        # VAE heads (only created when the variational path is enabled).
        if use_vae:
            self.mu_head = nn.Linear(latent_dim, latent_dim)
            self.logvar_head = nn.Linear(latent_dim, latent_dim)

    def forward(self, x):
        """Encode a token sequence to (sampled) latents.

        Args:
            x: [B, N, embed_dim] with N = grid_size^2.
        Returns:
            (z_sample, mu, logvar), each [B, N_compressed, latent_dim].
            Without VAE, returns (z, z, zeros) so the signature is uniform.
        """
        batch, _, dim = x.shape

        # Tokens → 2D feature map → downsample → back to token sequence.
        grid = x.transpose(1, 2).reshape(batch, dim, self.grid_size, self.grid_size)
        tokens = self.downsample(grid).flatten(2).transpose(1, 2)

        # Residual self-attention (pre-norm).
        attn_in = self.norm1(tokens)
        attn_out, _ = self.self_attn(attn_in, attn_in, attn_in)
        tokens = tokens + attn_out

        # Residual SwiGLU FFN (pre-norm).
        gated = self.norm2(tokens)
        tokens = tokens + self.w2(F.silu(self.w1(gated)) * self.w3(gated))

        z = self.proj(tokens)

        if not self.use_vae:
            return z, z, torch.zeros_like(z)

        mu = self.mu_head(z)
        logvar = self.logvar_head(z)

        # Reparameterization trick while training; deterministic mean at eval.
        if self.training:
            std = torch.exp(0.5 * logvar)
            return mu + std * torch.randn_like(std), mu, logvar
        return mu, mu, logvar
143
+
144
+
145
class FAESpatialDecoder(nn.Module):
    """FAE decoder that runs ViT blocks at compressed resolution, then upsamples.

    Pipeline: project latents up to output_dim → ViT decoder layers with 2D
    RoPE at the compressed grid → RMSNorm → reshape to 2D → transposed-conv
    upsample back to the full grid → final RMSNorm.

    Input:  [B, (grid_size/pool_factor)^2, latent_dim]
    Output: [B, grid_size^2, output_dim]  (e.g. [B, 576, 1152])
    """

    def __init__(self, latent_dim=32, output_dim=1152, num_layers=6,
                 num_heads=16, ffn_mult=2.7, pool_factor=2, grid_size=24):
        super().__init__()
        self.output_dim = output_dim
        self.pool_factor = pool_factor
        self.grid_size = grid_size
        self.compressed_grid = grid_size // pool_factor

        # Latent → full model width.
        self.input_proj = nn.Linear(latent_dim, output_dim)

        # 2D rotary embeddings sized for the compressed grid.
        head_dim = output_dim // num_heads
        self.rope = RotaryPositionalEmbedding2D(head_dim, grid_size=self.compressed_grid)

        # Transformer stack operating on the compressed token sequence.
        self.layers = nn.ModuleList(
            ViTDecoderBlock(output_dim, num_heads, ffn_mult) for _ in range(num_layers)
        )
        self.pre_upsample_norm = RMSNorm(output_dim)

        # Transposed-conv spatial upsampling back to full resolution.
        self.upsample = CNNUpsample(output_dim, pool_factor)

        # Normalize the upsampled features before returning.
        self.final_norm = RMSNorm(output_dim)

    def forward(self, z):
        """Decode compressed latents back to the full token grid.

        Args:
            z: [B, N_compressed, latent_dim].
        Returns:
            [B, grid_size^2, output_dim].
        """
        batch = z.shape[0]
        h = self.input_proj(z)  # [B, N_compressed, output_dim]

        rope_cos, rope_sin = self.rope(h.shape[1], h.device)
        for block in self.layers:
            h = block(h, rope_cos, rope_sin)
        h = self.pre_upsample_norm(h)

        # Token sequence → 2D map → upsample → token sequence.
        cg = self.compressed_grid
        fmap = h.transpose(1, 2).reshape(batch, self.output_dim, cg, cg)
        fmap = self.upsample(fmap)  # [B, output_dim, grid_size, grid_size]
        out = fmap.flatten(2).transpose(1, 2)  # [B, N_full, output_dim]

        return self.final_norm(out)