"""
Generator network for Car GAN.
Architecture: DCGAN-style with transposed convolutions.
Input: latent vector z (batch, latent_dim)
Output: RGB image (batch, 3, image_size, image_size)
"""
import torch
import torch.nn as nn
def _block(in_channels: int, out_channels: int, kernel: int = 4,
stride: int = 2, padding: int = 1) -> nn.Sequential:
"""Upsampling block: ConvTranspose β BatchNorm β ReLU."""
return nn.Sequential(
nn.ConvTranspose2d(in_channels, out_channels, kernel, stride, padding, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
class Generator(nn.Module):
    """DCGAN generator: latent vector -> RGB image with values in [-1, 1].

    The latent vector is projected to a (features * 2**n_up) x 4 x 4 feature
    map, then passed through a series of transposed-convolution stages that
    double the spatial size at each step (4 -> 8 -> 16 -> ... -> image_size),
    ending in a Tanh. Normalize training targets to [-1, 1] to match.

    Args:
        latent_dim: Size of the input noise vector z.
        features: Channel count of the last hidden stage; earlier stages
            use power-of-two multiples of this.
        channels: Output image channels (3 for RGB).
        image_size: Output resolution; must be a power of two >= 8.

    Raises:
        ValueError: If ``image_size`` is not a power of two >= 8. The
            original truncating ``int(log2(...))`` silently produced the
            wrong resolution for such sizes (e.g. 96 -> 64 output).
    """

    def __init__(self, latent_dim: int = 128, features: int = 64, channels: int = 3,
                 image_size: int = 64):
        super().__init__()
        # Validate up front: a non-power-of-two size would otherwise silently
        # yield an image of the wrong resolution.
        if image_size < 8 or image_size & (image_size - 1) != 0:
            raise ValueError(
                f"image_size must be a power of two >= 8, got {image_size}"
            )
        self.latent_dim = latent_dim
        self.image_size = image_size
        # Number of doubling steps from the 4x4 base up to image_size.
        # For a power of two, log2(n) == n.bit_length() - 1, so:
        self.n_up = image_size.bit_length() - 3  # 4 for 64px, 5 for 128px
        init_features = features * (2 ** self.n_up)
        # Initial projection: z (viewed as a 1x1 map) -> init_features x 4 x 4.
        self.project = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, init_features, 4, 1, 0, bias=False),
            nn.BatchNorm2d(init_features),
            nn.ReLU(inplace=True),
        )
        # Hidden stages halve the channel count while doubling resolution.
        layers = []
        in_f = init_features
        for _ in range(self.n_up - 1):
            layers.append(self._up_block(in_f, in_f // 2))
            in_f //= 2
        # Final stage: no BatchNorm; Tanh maps outputs into [-1, 1].
        layers.append(
            nn.Sequential(
                nn.ConvTranspose2d(in_f, channels, 4, 2, 1, bias=False),
                nn.Tanh(),
            )
        )
        self.main = nn.Sequential(*layers)
        self._init_weights()

    @staticmethod
    def _up_block(in_f: int, out_f: int) -> nn.Sequential:
        """One hidden stage: ConvTranspose2d (x2 upsample) -> BatchNorm -> ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_f, out_f, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_f),
            nn.ReLU(inplace=True),
        )

    def _init_weights(self) -> None:
        """DCGAN init: N(0, 0.02) conv weights, N(1, 0.02) BatchNorm scales."""
        for m in self.modules():
            if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight, 1.0, 0.02)
                nn.init.zeros_(m.bias)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Map a latent batch (N, latent_dim) to images (N, channels, H, W)."""
        # Reshape flat vectors into 1x1 spatial maps for the transposed convs.
        z = z.view(z.size(0), -1, 1, 1)
        return self.main(self.project(z))

    @torch.no_grad()
    def generate(self, n: int = 1, device: str = "cpu") -> torch.Tensor:
        """Convenience method: sample n images from random noise.

        Returns a tensor of shape (n, channels, image_size, image_size) with
        values in [-1, 1]. Gradients are disabled.
        """
        z = torch.randn(n, self.latent_dim, device=device)
        return self(z)