keysun89 commited on
Commit
bc1ad01
·
verified ·
1 Parent(s): 3eda9ab

Create generator.py

Browse files
Files changed (1) hide show
  1. generator.py +388 -0
generator.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ import random
6
+
7
+ # --- Helper Modules ---
8
+
9
+ class LeakyReLU(nn.Module):
10
+ """
11
+ Custom LeakyReLU implementation to allow for a fixed negative slope
12
+ and in-place operation.
13
+ """
14
+ def __init__(self, negative_slope=0.2, inplace=False):
15
+ super().__init__()
16
+ self.negative_slope = negative_slope
17
+ self.inplace = inplace
18
+
19
+ def forward(self, x):
20
+ return F.leaky_relu(x, self.negative_slope, self.inplace)
21
+
22
+ class PixelNorm(nn.Module):
23
+ """
24
+ Pixel-wise feature vector normalization.
25
+ """
26
+ def __init__(self):
27
+ super().__init__()
28
+
29
+ def forward(self, x):
30
+ # Epsilon added for numerical stability
31
+ return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
32
+
33
+ class ModulatedConv2d(nn.Module):
34
+ """
35
+ This is the core building block of the StyleGAN2 synthesis network.
36
+ It applies style modulation and demodulation.
37
+ """
38
+ def __init__(self, in_channels, out_channels, kernel_size, style_dim, demodulate=True, upsample=False):
39
+ super().__init__()
40
+ self.in_channels = in_channels
41
+ self.out_channels = out_channels
42
+ self.kernel_size = kernel_size
43
+ self.style_dim = style_dim
44
+ self.demodulate = demodulate
45
+ self.upsample = upsample
46
+
47
+ # Standard convolution weights
48
+ self.weight = nn.Parameter(
49
+ torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)
50
+ )
51
+
52
+ # Affine transform (A) from style vector (w)
53
+ self.modulation = nn.Linear(style_dim, in_channels, bias=True)
54
+
55
+ # Initialize modulation bias to 1 (identity transform)
56
+ nn.init.constant_(self.modulation.bias, 1.0)
57
+
58
+ # Padding for the convolution
59
+ self.padding = (kernel_size - 1) // 2
60
+
61
+ # Upsampling filter (if needed)
62
+ if self.upsample:
63
+ # Using a simple bilinear filter
64
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
65
+
66
+ def forward(self, x, style):
67
+ # Store initial batch_size and in_channels
68
+ batch_size, in_channels_original, _, _ = x.shape
69
+
70
+ # 1. Modulate (Style-based feature scaling)
71
+ # style shape: [batch_size, style_dim]
72
+ # s shape: [batch_size, 1, in_channels, 1, 1]
73
+ s = self.modulation(style).view(batch_size, 1, in_channels_original, 1, 1)
74
+
75
+ # Get conv weights and combine with modulation
76
+ # w shape: [batch_size, out_channels, in_channels, k, k]
77
+ w = self.weight * s
78
+
79
+ # 2. Demodulate (Normalize weights to prevent scale explosion)
80
+ if self.demodulate:
81
+ # Calculate per-weight normalization factor
82
+ d = torch.rsqrt(torch.sum(w**2, dim=[2, 3, 4], keepdim=True) + 1e-8)
83
+ w = w * d
84
+
85
+ # 3. Upsample (if applicable)
86
+ if self.upsample:
87
+ x = self.up(x)
88
+
89
+ # Get current height and width *after* potential upsampling
90
+ current_height = x.shape[2]
91
+ current_width = x.shape[3]
92
+
93
+ # 4. Convolution
94
+ # Because weights are now per-batch, we need to group convolutions
95
+ # We reshape x and w to use a single grouped convolution operation
96
+
97
+ x = x.view(1, batch_size * in_channels_original, current_height, current_width)
98
+ w = w.view(batch_size * self.out_channels, in_channels_original, self.kernel_size, self.kernel_size)
99
+
100
+ # padding='same' is not supported for strided/grouped conv, so we use manual padding
101
+ x = F.conv2d(x, w, padding=self.padding, groups=batch_size)
102
+
103
+ # Reshape back to [batch_size, out_channels, h, w]
104
+ _, _, new_height, new_width = x.shape
105
+ x = x.view(batch_size, self.out_channels, new_height, new_width)
106
+
107
+ return x
108
+
109
+ class NoiseInjection(nn.Module):
110
+ """
111
+ Adds scaled noise to the feature maps.
112
+ """
113
+ def __init__(self, channels):
114
+ super().__init__()
115
+ # Learned scaling factor for the noise
116
+ self.weight = nn.Parameter(torch.zeros(1, channels, 1, 1))
117
+
118
+ def forward(self, x, noise=None):
119
+ if noise is None:
120
+ batch, _, height, width = x.shape
121
+ noise = torch.randn(batch, 1, height, width, device=x.device, dtype=x.dtype)
122
+
123
+ return x + self.weight * noise
124
+
125
+ class ConstantInput(nn.Module):
126
+ """
127
+ A learned constant 4x4 feature map to start the synthesis process.
128
+ """
129
+ def __init__(self, channels, size=4):
130
+ super().__init__()
131
+ self.input = nn.Parameter(torch.randn(1, channels, size, size))
132
+
133
+ def forward(self, batch_size):
134
+ return self.input.repeat(batch_size, 1, 1, 1)
135
+
136
+ class ToRGB(nn.Module):
137
+ """
138
+ Projects feature maps to an RGB image.
139
+ Uses a 1x1 modulated convolution.
140
+ """
141
+ def __init__(self, in_channels, out_channels, style_dim):
142
+ super().__init__()
143
+ # 1x1 convolution
144
+ self.conv = ModulatedConv2d(in_channels, out_channels, 1, style_dim, demodulate=False, upsample=False)
145
+ self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
146
+
147
+ def forward(self, x, style, skip=None):
148
+ x = self.conv(x, style)
149
+ x = x + self.bias
150
+
151
+ if skip is not None:
152
+ # Upsample the previous RGB output and add
153
+ skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
154
+ x = x + skip
155
+
156
+ return x
157
+
158
+ # --- Main Generator Components ---
159
+
160
+ class MappingNetwork(nn.Module):
161
+ """
162
+ Maps the initial latent vector Z to the intermediate style vector W.
163
+ """
164
+ def __init__(self, z_dim, w_dim, num_layers=8):
165
+ super().__init__()
166
+ self.z_dim = z_dim
167
+ self.w_dim = w_dim
168
+
169
+ layers = [PixelNorm()]
170
+ for i in range(num_layers):
171
+ layers.extend([
172
+ nn.Linear(z_dim if i == 0 else w_dim, w_dim),
173
+ LeakyReLU(0.2, inplace=True)
174
+ ])
175
+
176
+ self.mapping = nn.Sequential(*layers)
177
+
178
+ def forward(self, z):
179
+ # z shape: [batch_size, z_dim]
180
+ w = self.mapping(z)
181
+ # w shape: [batch_size, w_dim]
182
+ return w
183
+
184
+ class SynthesisBlock(nn.Module):
185
+ """
186
+ A single block in the Synthesis Network (e.g., 8x8 -> 16x16).
187
+ Contains upsampling, modulated convolutions, noise, and activation.
188
+ """
189
+ def __init__(self, in_channels, out_channels, style_dim):
190
+ super().__init__()
191
+ # First modulated conv with upsampling
192
+ self.conv1 = ModulatedConv2d(in_channels, out_channels, 3, style_dim, upsample=True)
193
+ self.noise1 = NoiseInjection(out_channels)
194
+ self.activate1 = LeakyReLU(0.2, inplace=True)
195
+
196
+ # Second modulated conv
197
+ self.conv2 = ModulatedConv2d(out_channels, out_channels, 3, style_dim, upsample=False)
198
+ self.noise2 = NoiseInjection(out_channels)
199
+ self.activate2 = LeakyReLU(0.2, inplace=True)
200
+
201
+ def forward(self, x, w, noise1, noise2):
202
+ x = self.conv1(x, w)
203
+ x = self.noise1(x, noise1)
204
+ x = self.activate1(x)
205
+
206
+ x = self.conv2(x, w)
207
+ x = self.noise2(x, noise2)
208
+ x = self.activate2(x)
209
+
210
+ return x
211
+
212
+ class SynthesisNetwork(nn.Module):
213
+ """
214
+ Builds the image from the style vector W.
215
+ """
216
+ def __init__(self, w_dim, img_channels, img_resolution=256, start_res=4, num_blocks=None):
217
+ super().__init__()
218
+ self.w_dim = w_dim
219
+ self.img_channels = img_channels
220
+ self.start_res = start_res
221
+
222
+ if num_blocks is None:
223
+ self.num_blocks = int(math.log2(img_resolution) - math.log2(start_res))
224
+ self.img_resolution = img_resolution
225
+ else:
226
+ self.num_blocks = num_blocks
227
+ self.img_resolution = start_res * (2**self.num_blocks)
228
+ print(f"Synthesis network created with {self.num_blocks} blocks, output resolution: {self.img_resolution}x{self.img_resolution}")
229
+
230
+ channels = {
231
+ 4: 512,
232
+ 8: 512,
233
+ 16: 512,
234
+ 32: 512,
235
+ 64: 256,
236
+ 128: 128,
237
+ 256: 64,
238
+ 512: 32,
239
+ 1024: 16,
240
+ }
241
+
242
+ self.input = ConstantInput(channels[start_res])
243
+
244
+ self.conv1 = ModulatedConv2d(channels[start_res], channels[start_res], 3, w_dim, upsample=False)
245
+ self.noise1 = NoiseInjection(channels[start_res])
246
+ self.activate1 = LeakyReLU(0.2, inplace=True)
247
+
248
+ self.to_rgb1 = ToRGB(channels[start_res], img_channels, w_dim)
249
+
250
+ self.blocks = nn.ModuleList()
251
+ self.to_rgbs = nn.ModuleList()
252
+
253
+ in_c = channels[start_res]
254
+
255
+ for i in range(self.num_blocks):
256
+ current_res = start_res * (2**(i+1))
257
+ out_c = channels.get(current_res, 16)
258
+ if current_res > 1024:
259
+ print(f"Warning: Resolution {current_res}x{current_res} not in channel map. Using {out_c} channels.")
260
+
261
+ self.blocks.append(SynthesisBlock(in_c, out_c, w_dim))
262
+ self.to_rgbs.append(ToRGB(out_c, img_channels, w_dim))
263
+
264
+ in_c = out_c
265
+
266
+ # Number of style vectors needed: 1 for initial conv1, 1 for initial to_rgb, and 3 per block (conv1, conv2, to_rgb)
267
+ self.num_styles = self.num_blocks * 3 + 2 # Corrected num_styles
268
+
269
+ def forward(self, w, noise=None):
270
+ # w shape: [batch_size, num_styles, w_dim]
271
+ if w.ndim == 2:
272
+ w = w.unsqueeze(1).repeat(1, self.num_styles, 1)
273
+
274
+ batch_size = w.shape[0]
275
+
276
+ # --- Handle Noise (generate if None) ---
277
+ if noise is None:
278
+ noise_list = []
279
+ # Noise for the initial 4x4 conv (self.conv1)
280
+ noise_list.append(torch.randn(batch_size, 1, self.start_res, self.start_res, device=w.device))
281
+
282
+ current_res = self.start_res
283
+ # Iterate through the synthesis blocks to generate noise for each
284
+ for i in range(self.num_blocks):
285
+ current_res *= 2 # This is the resolution *after* the current block's upsampling
286
+ # Noise for the first conv of the current block (after upsampling)
287
+ noise_list.append(torch.randn(batch_size, 1, current_res, current_res, device=w.device))
288
+ # Noise for the second conv of the current block (same resolution)
289
+ noise_list.append(torch.randn(batch_size, 1, current_res, current_res, device=w.device))
290
+ noise = noise_list
291
+
292
+ # --- 4x4 Block ---
293
+ x = self.input(batch_size)
294
+ x = self.conv1(x, w[:, 0]) # Style for initial conv1
295
+ x = self.noise1(x, noise[0]) # Noise for initial conv1
296
+ x = self.activate1(x)
297
+
298
+ skip = self.to_rgb1(x, w[:, 1]) # Style for initial ToRGB
299
+
300
+ # --- Main blocks (8x8 to img_resolution) ---
301
+ current_noise_idx_in_list = 1 # index for noise_list: noise[0] was used above
302
+ current_style_idx_in_w = 2 # index for w: w[:,0] and w[:,1] were used above
303
+
304
+ for i, (block, to_rgb) in enumerate(zip(self.blocks, self.to_rgbs)):
305
+ # Styles for this block
306
+ w_block_conv1 = w[:, current_style_idx_in_w]
307
+ w_block_conv2 = w[:, current_style_idx_in_w + 1]
308
+ w_block_to_rgb = w[:, current_style_idx_in_w + 2]
309
+
310
+ # Noises for this block
311
+ n_block_conv1 = noise[current_noise_idx_in_list]
312
+ n_block_conv2 = noise[current_noise_idx_in_list + 1]
313
+
314
+ x = block(x, w_block_conv1, n_block_conv1, n_block_conv2)
315
+
316
+ skip = to_rgb(x, w_block_to_rgb, skip)
317
+
318
+ # Increment indices for next block
319
+ current_style_idx_in_w += 3
320
+ current_noise_idx_in_list += 2
321
+
322
+ return skip # Final RGB image
323
+
324
+ class Generator(nn.Module):
325
+ """
326
+ The complete StyleGAN2 Generator.
327
+ Combines the Mapping and Synthesis networks.
328
+ """
329
+ def __init__(self, z_dim, w_dim, img_resolution, img_channels,
330
+ mapping_layers=8, num_synthesis_blocks=None):
331
+ super().__init__()
332
+ self.z_dim = z_dim
333
+ self.w_dim = w_dim
334
+
335
+ self.mapping = MappingNetwork(z_dim, w_dim, mapping_layers)
336
+
337
+ self.synthesis = SynthesisNetwork(
338
+ w_dim, img_channels, img_resolution, num_blocks=num_synthesis_blocks
339
+ )
340
+
341
+ self.num_styles = self.synthesis.num_styles
342
+ self.img_resolution = self.synthesis.img_resolution # Get final resolution
343
+
344
+ # For truncation trick
345
+ self.register_buffer('w_avg', torch.zeros(w_dim))
346
+
347
+ def update_w_avg(self, new_w, momentum=0.995):
348
+ """Helper to update the moving average of W"""
349
+ self.w_avg = torch.lerp(new_w.mean(0), self.w_avg, momentum)
350
+
351
+ def forward(self, z, truncation_psi=0.7, use_truncation=True,
352
+ style_mix_prob=0.0, noise=None):
353
+
354
+ # --- 1. Get W vector(s) ---
355
+
356
+ # Check if we're doing style mixing
357
+ do_style_mix = False
358
+ if isinstance(z, list) and len(z) == 2:
359
+ do_style_mix = True
360
+ z1, z2 = z
361
+ w1 = self.mapping(z1) # [batch, w_dim]
362
+ w2 = self.mapping(z2) # [batch, w_dim]
363
+ else:
364
+ w = self.mapping(z) # [batch, w_dim]
365
+ w1 = w
366
+ w2 = w
367
+
368
+ # --- 2. Truncation Trick ---
369
+ if use_truncation:
370
+ w1 = torch.lerp(self.w_avg, w1, truncation_psi)
371
+ w2 = torch.lerp(self.w_avg, w2, truncation_psi)
372
+
373
+ # --- 3. Style Mixing ---
374
+ # w_final shape: [batch, num_styles, w_dim]
375
+ w_final = torch.empty(w.shape[0], self.num_styles, self.w_dim, device=w.device)
376
+
377
+ if do_style_mix and random.random() < style_mix_prob:
378
+ # Select a random crossover point
379
+ mix_cutoff = random.randint(1, self.num_styles - 1)
380
+ w_final[:, :mix_cutoff] = w1.unsqueeze(1) # [batch, cutoff, w_dim]
381
+ w_final[:, mix_cutoff:] = w2.unsqueeze(1) # [batch, num_styles-cutoff, w_dim]
382
+ else:
383
+ # No mixing, just use w1
384
+ w_final = w1.unsqueeze(1).repeat(1, self.num_styles, 1)
385
+
386
+ # --- 4. Synthesis ---
387
+ img = self.synthesis(w_final, noise)
388
+ return img