gendigit / src /discriminator.py
marlonsousa's picture
Upload 16 files
5be6b48 verified
raw
history blame contribute delete
974 Bytes
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets
import torchvision.transforms as transforms
class Discriminador(nn.Module):
    """Conditional MLP discriminator scoring (image, digit label) pairs.

    Each image is flattened and concatenated with a one-hot encoding of
    its label.  The result passes through three Linear + LeakyReLU +
    Dropout stages (128 -> 64 -> 32 units) and a final Linear layer that
    produces one unnormalised score per sample.

    Args:
        image_dim: Number of pixels per flattened image (default 28 * 28).
        num_classes: Length of the one-hot label vector (default 10).
        dropout_p: Dropout probability applied after each hidden layer.
        negative_slope: Negative slope of the LeakyReLU activations.
    """

    def __init__(self, image_dim=28 * 28, num_classes=10, dropout_p=0.3,
                 negative_slope=0.2):
        super().__init__()
        self.image_dim = image_dim
        self.num_classes = num_classes
        self.negative_slope = negative_slope
        # Flattened image + one-hot label (794 with the defaults).
        self.input_dim = image_dim + num_classes
        # Attribute names are kept as in the original so state_dict keys
        # (e.g. previously saved checkpoints) still match.
        self.dense0 = nn.Linear(self.input_dim, 128)
        self.dense1 = nn.Linear(128, 64)
        self.dense2 = nn.Linear(64, 32)
        self.dense3 = nn.Linear(32, 1)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, X, rotulo):
        """Score a batch of images conditioned on their labels.

        Args:
            X: Batch of images whose first dimension is the batch size and
                whose remaining dimensions hold ``image_dim`` elements in
                total, e.g. (B, 1, 28, 28), (B, 28, 28) or (B, 784).
            rotulo: Labels, either a (B, num_classes) one-hot tensor of any
                numeric dtype, or a 1-D (B,) tensor of integer class indices.

        Returns:
            A (B, 1) tensor of raw scores.  No sigmoid is applied; this
            presumably pairs with ``BCEWithLogitsLoss`` in the training
            code — TODO confirm against the caller.
        """
        # reshape (unlike view) also works for non-contiguous inputs.
        X = X.reshape(X.shape[0], self.image_dim)
        if rotulo.dim() == 1:
            # Integer class indices: expand to one-hot vectors.
            rotulo = F.one_hot(rotulo.long(), self.num_classes)
        # Match the image dtype so concatenation works even for an integer
        # one-hot tensor (e.g. the raw int64 output of F.one_hot).
        rotulo = rotulo.to(dtype=X.dtype)
        X = torch.cat((X, rotulo), dim=1)  # concatenate image and one-hot label
        X = self.dropout(F.leaky_relu(self.dense0(X), self.negative_slope))
        X = self.dropout(F.leaky_relu(self.dense1(X), self.negative_slope))
        X = self.dropout(F.leaky_relu(self.dense2(X), self.negative_slope))
        X = self.dense3(X)
        return X