| import torch |
| from torch.nn import BatchNorm1d, Dropout, LeakyReLU, Linear, Module, ReLU, Sequential |
|
|
|
|
class Discriminator(Module):
    """PacGAN-style critic that scores packs of ``pack`` samples at a time.

    ``forward`` reshapes its input into rows of ``pack * input_dim`` values, so
    the batch size must be a multiple of ``pack``. It emits one score per pack.
    """

    def __init__(self, input_dim, dis_dims, pack=10):
        """Build the critic.

        Args:
            input_dim: Number of features per individual sample.
            dis_dims: Iterable of hidden-layer widths. Each entry adds a
                ``Linear -> LeakyReLU(0.2) -> Dropout(0.5)`` stage.
            pack: Number of samples concatenated into one critic input row.
        """
        super().__init__()
        dim = input_dim * pack
        self.pack = pack
        self.packdim = dim
        seq = []
        for item in list(dis_dims):
            seq += [Linear(dim, item), LeakyReLU(0.2), Dropout(0.5)]
            dim = item
        # Final projection to a single score per pack.
        seq += [Linear(dim, 1)]
        self.seq = Sequential(*seq)

    def calc_gradient_penalty(self, real_data, fake_data, device=None, pac=None, lambda_=10):
        """Return the WGAN-GP gradient penalty of this critic.

        Real and fake rows are mixed with one random coefficient per pack.
        The penalty is ``lambda_ * mean((||grad D(x_hat)||_2 - 1) ** 2)``, with
        the gradient norm taken over each flattened pack.

        Args:
            real_data: Real samples, a 2-D tensor ``(batch, features)``.
            fake_data: Generated samples, same shape as ``real_data``. Neither
                tensor needs to require grad.
            device: Device on which the mixing coefficients are drawn.
                ``None`` (default) uses ``real_data.device``.
            pac: Samples per pack. ``None`` (default) uses ``self.pack``, so
                rows are grouped exactly as :meth:`forward` groups them.
            lambda_: Weight multiplied into the penalty.

        Returns:
            A scalar tensor built with ``create_graph=True``, so it can be
            back-propagated into the critic's parameters.

        Raises:
            ValueError: If the batch size is not a multiple of ``pac``.
        """
        if device is None:
            device = real_data.device
        if pac is None:
            pac = self.pack

        batch_size, n_features = real_data.size(0), real_data.size(1)
        if batch_size % pac != 0:
            raise ValueError(
                "Batch size should be divisible by pac={}, but provided {}".format(pac, batch_size)
            )

        # One interpolation coefficient per pack, broadcast to every feature
        # of every row inside that pack.
        alpha = torch.rand(batch_size // pac, 1, 1, device=device)
        alpha = alpha.repeat(1, pac, n_features)
        alpha = alpha.view(-1, n_features)

        interpolates = alpha * real_data + ((1 - alpha) * fake_data)
        # autograd.grad needs a differentiable input. If both inputs were
        # detached, `interpolates` is a leaf that does not track gradients yet.
        if not interpolates.requires_grad:
            interpolates.requires_grad_(True)

        disc_interpolates = self(interpolates)

        gradients = torch.autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(disc_interpolates),
            create_graph=True,
            retain_graph=True,
        )[0]

        # Norm over each flattened pack, matching how forward() groups rows.
        gradient_norm = gradients.reshape(-1, pac * n_features).norm(2, dim=1)
        return ((gradient_norm - 1) ** 2).mean() * lambda_

    def forward(self, input):
        """Score packs of samples.

        Args:
            input: Tensor whose first dimension (the batch size) must be a
                multiple of ``self.pack``.

        Returns:
            For a ``(batch, input_dim)`` input, a ``(batch // pack, 1)`` tensor
            of scores.

        Raises:
            ValueError: If the batch size is not divisible by ``self.pack``.
        """
        if input.size(0) % self.pack != 0:
            raise ValueError("Batch size should be divisible to {}, but provided {}".format(self.pack, input.size(0)))
        # reshape (not view) so non-contiguous inputs are accepted too.
        return self.seq(input.reshape(-1, self.packdim))
|
|
|
|
class Residual(Module):
    """Fully connected block whose output is concatenated with its input.

    For a 2-D ``(batch, i)`` input, ``forward`` returns ``(batch, o + i)``.
    The first ``o`` columns are ``ReLU(BatchNorm(Linear(x)))`` and the
    trailing ``i`` columns are ``x`` unchanged.
    """

    def __init__(self, i, o):
        super().__init__()
        # These attribute names appear in state_dict keys; keep them stable.
        self.fc = Linear(i, o)
        self.bn = BatchNorm1d(o)
        self.relu = ReLU()

    def forward(self, input):
        transformed = self.relu(self.bn(self.fc(input)))
        # Skip connection by concatenation along the feature axis.
        return torch.cat((transformed, input), 1)
|
|
|
|
class Generator(Module):
    """Stack of ``Residual`` blocks followed by a linear output head.

    Each ``Residual(width, hidden)`` concatenates its input to its output, so
    the feature width grows by every entry of ``gen_dims``. The final
    ``Linear`` projects the accumulated width to ``data_dim``.
    """

    def __init__(self, embedding_dim, gen_dims, data_dim):
        super().__init__()
        width = embedding_dim
        layers = []
        for hidden in gen_dims:
            layers.append(Residual(width, hidden))
            width = width + hidden
        layers.append(Linear(width, data_dim))
        self.seq = Sequential(*layers)

    def forward(self, input):
        return self.seq(input)
|
|