File size: 2,957 Bytes
fab18b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""
Generator network for Car GAN.
Architecture: DCGAN-style with transposed convolutions.
Input:  latent vector z  (batch, latent_dim)
Output: RGB image        (batch, 3, image_size, image_size)
"""

import math

import torch
import torch.nn as nn


def _block(in_channels: int, out_channels: int, kernel: int = 4,
           stride: int = 2, padding: int = 1) -> nn.Sequential:
    """Upsampling block: ConvTranspose β†’ BatchNorm β†’ ReLU."""
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels, out_channels, kernel, stride, padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )


class Generator(nn.Module):
    """
    DCGAN Generator.

    Latent vector -> 4x4 projection -> series of x2 upsampling blocks ->
    final Tanh. Output values are in [-1, 1]; normalize your targets the
    same way.

    Args:
        latent_dim: Size of the input noise vector z.
        features: Channel width of the last upsampling stage; earlier
            stages use successive doublings of this value.
        channels: Number of output image channels (3 for RGB).
        image_size: Output spatial size. Must be a power of two >= 8,
            since the network doubles a 4x4 base until it reaches it.

    Raises:
        ValueError: If ``image_size`` is not a power of two >= 8.
    """

    def __init__(self, latent_dim: int = 128, features: int = 64, channels: int = 3,
                 image_size: int = 64):
        super().__init__()
        # Fail fast: int(math.log2(...)) silently truncates for sizes like
        # 48 or 100, which would yield an output smaller than image_size.
        if image_size < 8 or image_size & (image_size - 1) != 0:
            raise ValueError(
                f"image_size must be a power of two >= 8, got {image_size}"
            )
        self.latent_dim = latent_dim
        self.image_size = image_size

        # Number of x2 upsampling steps from the 4x4 base to image_size:
        # 4 -> 8 -> 16 -> 32 -> 64  (image_size=64 -> n_up=4, 128px -> 5).
        self.n_up = int(math.log2(image_size)) - 2

        # Initial projection: z -> (features * 2^n_up) x 4 x 4 feature map.
        init_features = features * (2 ** self.n_up)
        self.project = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, init_features, 4, 1, 0, bias=False),
            nn.BatchNorm2d(init_features),
            nn.ReLU(inplace=True),
        )

        # Intermediate upsampling stages, halving channels at each step.
        layers = []
        in_f = init_features
        for _ in range(self.n_up - 1):
            out_f = in_f // 2
            layers.append(_block(in_f, out_f))
            in_f = out_f

        # Final stage: no BatchNorm; Tanh maps outputs into [-1, 1] (DCGAN).
        layers.append(
            nn.Sequential(
                nn.ConvTranspose2d(in_f, channels, 4, 2, 1, bias=False),
                nn.Tanh(),
            )
        )

        self.main = nn.Sequential(*layers)
        self._init_weights()

    def _init_weights(self):
        """Apply DCGAN init: N(0, 0.02) convs, N(1, 0.02) BatchNorm scales."""
        for m in self.modules():
            if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight, 1.0, 0.02)
                nn.init.zeros_(m.bias)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Map latent vectors (batch, latent_dim) to images (batch, C, H, W)."""
        # Reshape to (batch, latent_dim, 1, 1) so ConvTranspose2d can project it.
        z = z.view(z.size(0), -1, 1, 1)
        x = self.project(z)
        return self.main(x)

    @torch.no_grad()
    def generate(self, n: int = 1, device: str = "cpu") -> torch.Tensor:
        """Convenience method: sample n images. Returns tensor in [-1, 1]."""
        z = torch.randn(n, self.latent_dim, device=device)
        return self(z)