Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| # The Generator model | |
class Generator(nn.Module):
    """Text-conditioned image generator built from transposed convolutions.

    The text embedding is projected from ``embed_dim`` to ``embed_out_dim``
    features, concatenated with the noise vector along the channel axis and
    upsampled by six ``ConvTranspose2d`` stages. Starting from a 1x1 spatial
    input the stages yield 4 -> 8 -> 16 -> 32 -> 64 -> 128 pixels per side,
    and the final ``Tanh`` bounds the output to ``[-1, 1]``.

    Args:
        channels: Number of channels of the generated image.
        noise_dim: Number of channels of the noise input.
        embed_dim: Size of the incoming text embedding.
        embed_out_dim: Size of the projected text embedding.
    """

    def __init__(self, channels, noise_dim=100, embed_dim=1024, embed_out_dim=128):
        super().__init__()
        self.channels = channels
        self.noise_dim = noise_dim
        self.embed_dim = embed_dim
        self.embed_out_dim = embed_out_dim

        # Text embedding layers.
        # BatchNorm1d(1) treats the singleton middle axis of a
        # (batch, 1, embed_dim) input as the only channel, so one set of
        # statistics is shared across every projected feature.
        self.text_embedding = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_out_dim),
            nn.BatchNorm1d(1),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Generator architecture: one 1x1 -> 4x4 projection followed by five
        # stride-2 stages that each double the spatial resolution.
        # The layers are kept in one flat Sequential so the parameter names
        # (model.0, model.1, ...) stay stable for saved checkpoints.
        model = []
        model += self._create_layer(
            self.noise_dim + self.embed_out_dim, 512, 4, stride=1, padding=0
        )
        model += self._create_layer(512, 256, 4, stride=2, padding=1)
        model += self._create_layer(256, 128, 4, stride=2, padding=1)
        model += self._create_layer(128, 64, 4, stride=2, padding=1)
        model += self._create_layer(64, 32, 4, stride=2, padding=1)
        model += self._create_layer(
            32, self.channels, 4, stride=2, padding=1, output=True
        )
        self.model = nn.Sequential(*model)

    def _create_layer(self, size_in, size_out, kernel_size=4, stride=2, padding=1, output=False):
        """Return the modules making up one upsampling stage.

        Args:
            size_in: Number of input channels.
            size_out: Number of output channels.
            kernel_size: Kernel size of the transposed convolution.
            stride: Stride of the transposed convolution.
            padding: Padding of the transposed convolution.
            output: When true, the stage ends in ``Tanh``; otherwise it ends
                in ``BatchNorm2d`` followed by an in-place ``ReLU``.

        Returns:
            A list of modules, transposed convolution first.
        """
        layers = [
            nn.ConvTranspose2d(
                size_in, size_out, kernel_size,
                stride=stride, padding=padding, bias=False,
            )
        ]
        if output:
            layers.append(nn.Tanh())  # Tanh activation for the output layer
        else:
            # Batch normalization and ReLU for all other layers
            layers += [nn.BatchNorm2d(size_out), nn.ReLU(True)]
        return layers

    def forward(self, noise, text):
        """Generate images from noise conditioned on a text embedding.

        Args:
            noise: Tensor of shape ``(batch, noise_dim, 1, 1)``; a flat
                ``(batch, noise_dim)`` tensor is also accepted.
            text: Tensor of shape ``(batch, 1, embed_dim)``; a flat
                ``(batch, embed_dim)`` tensor is also accepted.

        Returns:
            The output of the transposed-convolution stack for the
            concatenated text/noise input.
        """
        # Accept flat inputs by adding the singleton axes the layers expect.
        # Inputs that already have them go through unchanged.
        if text.dim() == 2:
            text = text.unsqueeze(1)
        if noise.dim() == 2:
            noise = noise.reshape(noise.shape[0], noise.shape[1], 1, 1)

        # Apply text embedding to the input text
        text = self.text_embedding(text)
        # Reshape (batch, 1, embed_out_dim) -> (batch, embed_out_dim, 1, 1)
        text = text.view(text.shape[0], text.shape[2], 1, 1)
        # Text channels come first; the first convolution's weights depend on
        # this order.
        z = torch.cat([text, noise], 1)
        return self.model(z)