gendigit / src /generator.py
marlonsousa's picture
Upload 16 files
5be6b48 verified
raw
history blame contribute delete
989 Bytes
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets
import torchvision.transforms as transforms
class Gerador(nn.Module):
    """Conditional MLP generator mapping (noise, class label) to an image.

    The noise vector and the one-hot class label are concatenated and passed
    through three hidden ``Linear`` layers (32, 64, 128 units; each followed by
    LeakyReLU with slope 0.2 and dropout p=0.3) and a final ``Linear`` layer
    squashed by ``tanh``. Every output value therefore lies in ``[-1, 1]``.
    With the defaults the output has shape ``(batch, 28, 28)``.

    Args:
        noise_dim: Length of the noise vector ``ruido`` (default 100).
        num_classes: Number of classes, i.e. the width of the one-hot label
            ``rotulo`` (default 10).
        img_size: Side length of the square output image (default 28).
    """

    def __init__(self, noise_dim: int = 100, num_classes: int = 10, img_size: int = 28):
        super().__init__()
        self.noise_dim = noise_dim
        self.num_classes = num_classes
        self.img_size = img_size
        # Noise vector concatenated with the one-hot label.
        self.input_dim = noise_dim + num_classes
        # Attribute names are kept unchanged so existing state_dict checkpoints
        # (keys like "dense0.weight") still load into this class.
        self.dense0 = nn.Linear(self.input_dim, 32)
        self.dense1 = nn.Linear(32, 64)
        self.dense2 = nn.Linear(64, 128)
        self.dense3 = nn.Linear(128, img_size * img_size)
        self.dropout = nn.Dropout(0.3)

    def forward(self, ruido: torch.Tensor, rotulo: torch.Tensor) -> torch.Tensor:
        """Generate a batch of images conditioned on class labels.

        Args:
            ruido: Noise tensor of shape ``(batch, noise_dim)``.
            rotulo: Either a one-hot label tensor of shape
                ``(batch, num_classes)`` of any numeric dtype (e.g. the int64
                output of ``F.one_hot``), or a 1-D tensor of class indices of
                shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, img_size, img_size)`` with values in
            ``[-1, 1]``.
        """
        # Sizes are derived from the layers rather than from attributes set in
        # __init__, so instances created by older versions of this class
        # (e.g. unpickled whole-model checkpoints) keep working.
        if rotulo.dim() == 1:
            # Class indices: expand to one-hot so callers need not do it.
            n_classes = self.dense0.in_features - ruido.shape[1]
            rotulo = F.one_hot(rotulo.long(), num_classes=n_classes)
        # Integer one-hot labels must match the float noise before they reach
        # the Linear layers; also align devices for the concatenation.
        rotulo = rotulo.to(dtype=ruido.dtype, device=ruido.device)

        x = torch.cat((ruido, rotulo), dim=1)
        for camada in (self.dense0, self.dense1, self.dense2):
            x = self.dropout(F.leaky_relu(camada(x), 0.2))
        x = torch.tanh(self.dense3(x))

        side = int(round(self.dense3.out_features ** 0.5))
        return x.view(x.shape[0], side, side)