import torch
from torch import nn
class Discriminator(nn.Module):
    """Fully connected ReLU network that maps each input vector to one value.

    The layer stack is ``Linear(input_dim, hidden_dim) -> ReLU``, followed by
    ``hidden_layers - 1`` further ``Linear(hidden_dim, hidden_dim) -> ReLU``
    pairs, and a final ``Linear(hidden_dim, 1)``. No activation is applied
    after the last layer, so the output is an unbounded real value.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden_dim: Width of every hidden layer.
        hidden_layers: Number of ``Linear -> ReLU`` pairs placed before the
            output layer. Values below 1 behave like 1, because the first
            pair is always created.
    """

    def __init__(
        self,
        input_dim: int = 2,
        hidden_dim: int = 256,
        hidden_layers: int = 6,
    ) -> None:
        super().__init__()
        # The first hidden layer is always present; it adapts input_dim to
        # hidden_dim.
        layers: list[nn.Module] = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
        # Remaining hidden layers keep the width constant.
        for _ in range(hidden_layers - 1):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.ReLU())
        # Output layer: a single unit with no activation.
        layers.append(nn.Linear(hidden_dim, 1))
        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the layer stack and return the result.

        Args:
            x: Tensor whose last dimension has size ``input_dim``.

        Returns:
            Tensor shaped like ``x`` except that the last dimension is 1.
        """
        return self.network(x)