File size: 14,001 Bytes
bc1ad01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random

# --- Helper Modules ---

class LeakyReLU(nn.Module):
    """
    Leaky rectifier module with a configurable slope for negative inputs.

    Thin wrapper around ``F.leaky_relu``; the slope defaults to 0.2 and the
    activation can optionally overwrite its input tensor.
    """
    def __init__(self, negative_slope=0.2, inplace=False):
        super().__init__()
        # Plain attributes; forward() reads them on every call.
        self.inplace = inplace
        self.negative_slope = negative_slope

    def forward(self, x):
        activated = F.leaky_relu(
            x,
            negative_slope=self.negative_slope,
            inplace=self.inplace,
        )
        return activated

class PixelNorm(nn.Module):
    """
    Rescales every feature vector along dim 1 to (approximately) unit RMS.

    The layer has no parameters, so the inherited constructor is sufficient.
    """
    def forward(self, x):
        # Mean of squares over dim 1, kept as a singleton axis for broadcasting.
        mean_square = x.pow(2).mean(dim=1, keepdim=True)
        # The small constant keeps rsqrt finite for all-zero vectors.
        inv_rms = torch.rsqrt(mean_square + 1e-8)
        return x * inv_rms

class ModulatedConv2d(nn.Module):
    """
    Style-modulated 2D convolution, the core building block of the
    StyleGAN2 synthesis network.

    A single shared kernel is scaled per sample along its input-channel axis
    by an affine projection of the style vector (modulation) and, optionally,
    re-normalised per output channel (demodulation). The per-sample kernels
    are then applied with one grouped convolution.

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels.
        kernel_size: Side length of the square kernel.
        style_dim: Dimensionality of the style vector.
        demodulate: If True, normalise each per-sample output filter.
        upsample: If True, bilinearly upsample the input by 2x before the
            convolution.
    """
    def __init__(self, in_channels, out_channels, kernel_size, style_dim, demodulate=True, upsample=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.style_dim = style_dim
        self.demodulate = demodulate
        self.upsample = upsample

        # Shared convolution kernel; the leading singleton axis broadcasts
        # against the batch of per-sample modulation scales.
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)
        )

        # Affine transform (A) from style vector (w) to per-input-channel scales
        self.modulation = nn.Linear(style_dim, in_channels, bias=True)

        # Bias of 1 so the initial modulation is centred on the identity scale
        nn.init.constant_(self.modulation.bias, 1.0)

        # Symmetric padding; preserves spatial size for odd kernel sizes
        self.padding = (kernel_size - 1) // 2

        # Upsampling filter (only created when needed)
        if self.upsample:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, x, style):
        """
        Apply the modulated convolution.

        Args:
            x: Feature maps of shape [batch_size, in_channels, H, W]. Any
                memory layout is accepted (e.g. channels_last or permuted).
            style: Style vectors of shape [batch_size, style_dim].

        Returns:
            Tensor of shape [batch_size, out_channels, H', W'], where H'/W'
            are doubled when ``upsample`` is enabled.
        """
        batch_size, in_channels_original, _, _ = x.shape

        # 1. Modulate: s has shape [batch_size, 1, in_channels, 1, 1]
        s = self.modulation(style).reshape(batch_size, 1, in_channels_original, 1, 1)

        # Per-sample kernels: [batch_size, out_channels, in_channels, k, k]
        w = self.weight * s

        # 2. Demodulate: normalise each output filter to prevent scale explosion
        if self.demodulate:
            d = torch.rsqrt(torch.sum(w**2, dim=[2, 3, 4], keepdim=True) + 1e-8)
            w = w * d

        # 3. Upsample (if applicable)
        if self.upsample:
            x = self.up(x)

        # Height and width *after* potential upsampling
        current_height, current_width = x.shape[2:]

        # 4. Convolution
        # Kernels differ per sample, so the batch is folded into the channel
        # axis and a grouped convolution runs with one group per sample.
        # reshape (not view) is required here: view raises on non-contiguous
        # inputs such as channels_last or permuted tensors.
        x = x.reshape(1, batch_size * in_channels_original, current_height, current_width)
        w = w.reshape(batch_size * self.out_channels, in_channels_original, self.kernel_size, self.kernel_size)

        x = F.conv2d(x, w, padding=self.padding, groups=batch_size)

        # Unfold the batch axis again: [batch_size, out_channels, h, w]
        _, _, new_height, new_width = x.shape
        x = x.reshape(batch_size, self.out_channels, new_height, new_width)

        return x

class NoiseInjection(nn.Module):
    """
    Adds single-channel noise to a feature map, scaled by a learned
    per-channel weight that is initialised to zero.
    """
    def __init__(self, channels):
        super().__init__()
        # One scale per channel, broadcast over batch and spatial axes.
        initial_scale = torch.zeros(1, channels, 1, 1)
        self.weight = nn.Parameter(initial_scale)

    def forward(self, x, noise=None):
        # Caller-supplied noise is used as-is.
        if noise is not None:
            return x + self.weight * noise

        # Otherwise draw Gaussian noise with one channel, matching x's
        # batch size, spatial size, device and dtype.
        n, _, h, w = x.shape
        sampled = torch.randn(n, 1, h, w, device=x.device, dtype=x.dtype)
        return x + self.weight * sampled

class ConstantInput(nn.Module):
    """
    Learned constant feature map (``size`` x ``size``, 4x4 by default) that
    seeds the synthesis process.
    """
    def __init__(self, channels, size=4):
        super().__init__()
        # A single learned map; the leading axis is tiled per batch in forward().
        shape = (1, channels, size, size)
        self.input = nn.Parameter(torch.randn(shape))

    def forward(self, batch_size):
        # Tile only along the batch axis; repeat() returns a fresh copy.
        tiling = (batch_size, 1, 1, 1)
        return self.input.repeat(tiling)

class ToRGB(nn.Module):
    """
    Converts feature maps into an image with a 1x1 modulated convolution
    (no demodulation) plus a learned per-channel bias.

    When a lower-resolution image is passed as ``skip`` it is bilinearly
    upsampled by 2x and added to the result.
    """
    def __init__(self, in_channels, out_channels, style_dim):
        super().__init__()
        # 1x1 modulated convolution without demodulation or upsampling
        self.conv = ModulatedConv2d(
            in_channels,
            out_channels,
            1,
            style_dim,
            demodulate=False,
            upsample=False,
        )
        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))

    def forward(self, x, style, skip=None):
        rgb = self.conv(x, style) + self.bias

        # Without a previous image there is nothing to accumulate.
        if skip is None:
            return rgb

        # Bring the previous RGB output to the current resolution, then add.
        upsampled_skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
        return rgb + upsampled_skip

# --- Main Generator Components ---

class MappingNetwork(nn.Module):
    """
    MLP that turns a latent vector Z into an intermediate style vector W.

    The input is pixel-normalised, then passed through ``num_layers``
    Linear + LeakyReLU(0.2) pairs. The first Linear maps z_dim -> w_dim,
    every later one maps w_dim -> w_dim.
    """
    def __init__(self, z_dim, w_dim, num_layers=8):
        super().__init__()
        self.z_dim = z_dim
        self.w_dim = w_dim

        # Feature width before each layer, followed by the final width.
        widths = [z_dim] + [w_dim] * num_layers

        stack = [PixelNorm()]
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(LeakyReLU(0.2, inplace=True))

        self.mapping = nn.Sequential(*stack)

    def forward(self, z):
        # [batch_size, z_dim] -> [batch_size, w_dim]
        return self.mapping(z)

class SynthesisBlock(nn.Module):
    """
    A single block in the Synthesis Network (e.g., 8x8 -> 16x16).

    Runs an upsampling modulated convolution followed by a second modulated
    convolution; each is followed by noise injection and LeakyReLU(0.2).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
        style_dim: Dimensionality of the style vectors.
    """
    def __init__(self, in_channels, out_channels, style_dim):
        super().__init__()
        # First modulated conv with upsampling
        self.conv1 = ModulatedConv2d(in_channels, out_channels, 3, style_dim, upsample=True)
        self.noise1 = NoiseInjection(out_channels)
        self.activate1 = LeakyReLU(0.2, inplace=True)

        # Second modulated conv
        self.conv2 = ModulatedConv2d(out_channels, out_channels, 3, style_dim, upsample=False)
        self.noise2 = NoiseInjection(out_channels)
        self.activate2 = LeakyReLU(0.2, inplace=True)

    def forward(self, x, w, noise1, noise2, w2=None):
        """
        Args:
            x: Input feature maps.
            w: Style vector for the first (upsampling) convolution.
            noise1: Noise for the first convolution, or None to sample it.
            noise2: Noise for the second convolution, or None to sample it.
            w2: Style vector for the second convolution. Defaults to ``w``
                so existing four-argument callers keep their behaviour.

        Returns:
            Feature maps after both convolutions.
        """
        if w2 is None:
            w2 = w

        x = self.conv1(x, w)
        x = self.noise1(x, noise1)
        x = self.activate1(x)

        # The second convolution gets its own style so per-layer style
        # mixing can switch styles between the two convolutions.
        x = self.conv2(x, w2)
        x = self.noise2(x, noise2)
        x = self.activate2(x)

        return x

class SynthesisNetwork(nn.Module):
    """
    Builds the image from the style vector W.

    Starts from a learned constant at ``start_res`` and applies
    ``num_blocks`` SynthesisBlocks, each doubling the resolution. RGB
    outputs of every stage are accumulated through skip connections.

    Args:
        w_dim: Dimensionality of the style vectors.
        img_channels: Number of channels of the produced image.
        img_resolution: Target resolution; used to derive the block count
            when ``num_blocks`` is None.
        start_res: Resolution of the learned constant input; must be a key
            of the internal channel table.
        num_blocks: Explicit number of upsampling blocks. When given, the
            output resolution becomes ``start_res * 2**num_blocks`` and
            ``img_resolution`` is ignored.
    """
    def __init__(self, w_dim, img_channels, img_resolution=256, start_res=4, num_blocks=None):
        super().__init__()
        self.w_dim = w_dim
        self.img_channels = img_channels
        self.start_res = start_res

        if num_blocks is None:
            # One block per doubling between start_res and img_resolution
            self.num_blocks = int(math.log2(img_resolution) - math.log2(start_res))
            self.img_resolution = img_resolution
        else:
            self.num_blocks = num_blocks
            self.img_resolution = start_res * (2**self.num_blocks)
            print(f"Synthesis network created with {self.num_blocks} blocks, output resolution: {self.img_resolution}x{self.img_resolution}")

        # Feature channels used at each resolution
        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256,
            128: 128,
            256: 64,
            512: 32,
            1024: 16,
        }

        self.input = ConstantInput(channels[start_res])

        self.conv1 = ModulatedConv2d(channels[start_res], channels[start_res], 3, w_dim, upsample=False)
        self.noise1 = NoiseInjection(channels[start_res])
        self.activate1 = LeakyReLU(0.2, inplace=True)

        self.to_rgb1 = ToRGB(channels[start_res], img_channels, w_dim)

        self.blocks = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()

        in_c = channels[start_res]

        for i in range(self.num_blocks):
            current_res = start_res * (2**(i+1))
            # Resolutions missing from the table fall back to 16 channels
            out_c = channels.get(current_res, 16)
            if current_res > 1024:
                print(f"Warning: Resolution {current_res}x{current_res} not in channel map. Using {out_c} channels.")

            self.blocks.append(SynthesisBlock(in_c, out_c, w_dim))
            self.to_rgbs.append(ToRGB(out_c, img_channels, w_dim))

            in_c = out_c

        # Styles needed: 1 for the initial conv1, 1 for the initial to_rgb,
        # and 3 per block (conv1, conv2, to_rgb)
        self.num_styles = self.num_blocks * 3 + 2

    def forward(self, w, noise=None):
        """
        Args:
            w: Styles of shape [batch_size, num_styles, w_dim], or
                [batch_size, w_dim] to use one style for every layer.
            noise: Optional indexable sequence of noise inputs: index 0 for
                the initial convolution, then two per block. When None,
                Gaussian noise is sampled for every layer.

        Returns:
            The final RGB image produced by the last ToRGB layer.
        """
        if w.ndim == 2:
            w = w.unsqueeze(1).repeat(1, self.num_styles, 1)

        batch_size = w.shape[0]

        # --- Handle Noise (generate if None) ---
        if noise is None:
            # dtype follows w, matching what NoiseInjection does when it
            # samples noise itself.
            noise = [torch.randn(batch_size, 1, self.start_res, self.start_res,
                                 device=w.device, dtype=w.dtype)]

            current_res = self.start_res
            for _ in range(self.num_blocks):
                current_res *= 2  # resolution *after* this block's upsampling
                # Two noise maps per block, one for each convolution
                noise.append(torch.randn(batch_size, 1, current_res, current_res,
                                         device=w.device, dtype=w.dtype))
                noise.append(torch.randn(batch_size, 1, current_res, current_res,
                                         device=w.device, dtype=w.dtype))

        # --- Initial block at start_res ---
        x = self.input(batch_size)
        x = self.conv1(x, w[:, 0])     # style 0: initial conv
        x = self.noise1(x, noise[0])   # noise 0: initial conv
        x = self.activate1(x)

        skip = self.to_rgb1(x, w[:, 1])  # style 1: initial ToRGB

        # --- Main blocks ---
        for i, (block, to_rgb) in enumerate(zip(self.blocks, self.to_rgbs)):
            # Styles 0 and 1 and noise 0 were consumed above; each block
            # uses three consecutive styles and two consecutive noise maps.
            style_idx = 2 + 3 * i
            noise_idx = 1 + 2 * i

            x = block(
                x,
                w[:, style_idx],
                noise[noise_idx],
                noise[noise_idx + 1],
                w[:, style_idx + 1],
            )

            skip = to_rgb(x, w[:, style_idx + 2], skip)

        return skip  # Final RGB image

class Generator(nn.Module):
    """
    The complete StyleGAN2 Generator.
    Combines the Mapping and Synthesis networks.

    Args:
        z_dim: Dimensionality of the input latent Z.
        w_dim: Dimensionality of the intermediate style W.
        img_resolution: Target output resolution (forwarded to the
            synthesis network).
        img_channels: Number of image channels.
        mapping_layers: Number of layers in the mapping network.
        num_synthesis_blocks: Optional explicit number of synthesis blocks
            (forwarded as ``num_blocks``).
    """
    def __init__(self, z_dim, w_dim, img_resolution, img_channels,
                 mapping_layers=8, num_synthesis_blocks=None):
        super().__init__()
        self.z_dim = z_dim
        self.w_dim = w_dim

        self.mapping = MappingNetwork(z_dim, w_dim, mapping_layers)

        self.synthesis = SynthesisNetwork(
            w_dim, img_channels, img_resolution, num_blocks=num_synthesis_blocks
        )

        self.num_styles = self.synthesis.num_styles
        self.img_resolution = self.synthesis.img_resolution  # final resolution

        # Running average of W, used by the truncation trick
        self.register_buffer('w_avg', torch.zeros(w_dim))

    def update_w_avg(self, new_w, momentum=0.995):
        """
        Update the moving average of W.

        Computes ``momentum * w_avg + (1 - momentum) * mean(new_w)``.

        Args:
            new_w: Batch of style vectors, shape [batch, w_dim].
            momentum: Weight given to the previous average.
        """
        # Detach so the buffer never retains the autograd graph of new_w.
        batch_mean = new_w.detach().mean(0)
        self.w_avg = torch.lerp(batch_mean, self.w_avg, momentum)

    def forward(self, z, truncation_psi=0.7, use_truncation=True,
                style_mix_prob=0.0, noise=None):
        """
        Generate images from latents.

        Args:
            z: Latent tensor of shape [batch, z_dim], or a list/tuple of two
                such tensors to enable style mixing.
            truncation_psi: Interpolation factor between ``w_avg`` (0) and
                the mapped W (1).
            use_truncation: Whether to apply the truncation trick. Enabled
                by default.
            style_mix_prob: Probability of mixing the two styles when a
                pair of latents is supplied.
            noise: Optional noise inputs forwarded to the synthesis network.

        Returns:
            The image tensor produced by the synthesis network.
        """
        # --- 1. Get W vector(s) ---
        do_style_mix = isinstance(z, (list, tuple)) and len(z) == 2
        if do_style_mix:
            z1, z2 = z
            w1 = self.mapping(z1)  # [batch, w_dim]
            w2 = self.mapping(z2)  # [batch, w_dim]
        else:
            w1 = self.mapping(z)   # [batch, w_dim]
            w2 = w1

        # --- 2. Truncation Trick ---
        if use_truncation:
            w1 = torch.lerp(self.w_avg, w1, truncation_psi)
            if do_style_mix:
                w2 = torch.lerp(self.w_avg, w2, truncation_psi)

        # --- 3. Style Mixing ---
        # w_final shape: [batch, num_styles, w_dim]
        if do_style_mix and random.random() < style_mix_prob:
            # Random crossover: styles before the cutoff come from w1,
            # the rest from w2.
            mix_cutoff = random.randint(1, self.num_styles - 1)
            head = w1.unsqueeze(1).expand(-1, mix_cutoff, -1)
            tail = w2.unsqueeze(1).expand(-1, self.num_styles - mix_cutoff, -1)
            w_final = torch.cat([head, tail], dim=1)
        else:
            # No mixing, just use w1 for every layer
            w_final = w1.unsqueeze(1).repeat(1, self.num_styles, 1)

        # --- 4. Synthesis ---
        img = self.synthesis(w_final, noise)
        return img