| | import torch |
| | import torch.nn as nn |
| | import torch.optim as optim |
| | import torchvision |
| | import torchvision.transforms as transforms |
| | from torch.utils.data import DataLoader |
| | import matplotlib.pyplot as plt |
| | import numpy as np |
| | import streamlit as st |
| |
|
| | |
class Generator(nn.Module):
    """Fully-connected generator: maps a latent vector to a flat image in [-1, 1]."""

    def __init__(self, z_dim, img_dim):
        super(Generator, self).__init__()
        # Hidden widths from latent input up to the final projection.
        widths = [z_dim, 128, 256, 512]
        layers = []
        for w_in, w_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(w_in, w_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(widths[-1], img_dim))
        layers.append(nn.Tanh())  # squash pixels into [-1, 1]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Generate a batch of flat images from latent codes *x*."""
        return self.model(x)
| |
|
| | |
class Discriminator(nn.Module):
    """Fully-connected discriminator: scores a flat image as real (→1) or fake (→0)."""

    def __init__(self, img_dim):
        super(Discriminator, self).__init__()
        hidden = [img_dim, 512, 256]
        stack = []
        for d_in, d_out in zip(hidden, hidden[1:]):
            stack.extend([nn.Linear(d_in, d_out), nn.LeakyReLU(0.2)])
        # Final projection to a single probability.
        stack.extend([nn.Linear(hidden[-1], 1), nn.Sigmoid()])
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Return the per-sample probability that *x* is a real image."""
        return self.model(x)
| |
|
| | |
def train_gan(generator, discriminator, dataloader, n_epochs, z_dim, lr):
    """Adversarially train *generator* against *discriminator*.

    Args:
        generator: maps (batch, z_dim) latent noise to flat images.
        discriminator: maps flat images to a real/fake probability in (0, 1).
        dataloader: yields (images, labels) batches; labels are ignored.
        n_epochs: number of passes over the dataloader.
        z_dim: latent dimensionality sampled for the generator.
        lr: learning rate shared by both Adam optimizers.
    """
    loss_fn = nn.BCELoss()
    gen_optimizer = optim.Adam(generator.parameters(), lr=lr)
    disc_optimizer = optim.Adam(discriminator.parameters(), lr=lr)

    for epoch in range(n_epochs):
        for real_imgs, _ in dataloader:
            # Flatten to (batch, features); infer the feature count from the
            # batch itself instead of hard-coding 784 so any image size works.
            batch_size = real_imgs.size(0)
            real_imgs = real_imgs.view(batch_size, -1)

            z = torch.randn(batch_size, z_dim)
            fake_imgs = generator(z)

            real_labels = torch.ones(batch_size, 1)
            fake_labels = torch.zeros(batch_size, 1)

            # --- Discriminator step: real -> 1, fake (detached) -> 0 ---
            disc_real_loss = loss_fn(discriminator(real_imgs), real_labels)
            disc_fake_loss = loss_fn(discriminator(fake_imgs.detach()), fake_labels)
            disc_loss = disc_real_loss + disc_fake_loss

            disc_optimizer.zero_grad()
            disc_loss.backward()
            disc_optimizer.step()

            # --- Generator step: try to make the discriminator output 1 on fakes ---
            output = discriminator(fake_imgs)
            gen_loss = loss_fn(output, real_labels)

            gen_optimizer.zero_grad()
            gen_loss.backward()
            gen_optimizer.step()

        st.write(f'Epoch [{epoch+1}/{n_epochs}], Discriminator Loss: {disc_loss.item()}, Generator Loss: {gen_loss.item()}')
| |
|
| | |
def main():
    """Streamlit entry point: train a GAN on MNIST, then display sample images."""
    st.title("GAN Image Generator")

    # Model dimensions: 100-d latent noise, 28x28 = 784 flat MNIST pixels.
    z_dim = 100
    img_dim = 784

    generator = Generator(z_dim, img_dim)
    discriminator = Discriminator(img_dim)

    # Training hyperparameters.
    lr = 0.0002
    batch_size = 32
    n_epochs = 10

    # Scale pixels to [-1, 1] to match the generator's Tanh output range.
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

    mnist_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    dataloader = DataLoader(mnist_dataset, batch_size=batch_size, shuffle=True)

    train_gan(generator, discriminator, dataloader, n_epochs, z_dim, lr)

    st.header("Generated Images")
    # Inference only: disable autograd so sampling builds no computation graph.
    with torch.no_grad():
        z = torch.randn(10, z_dim)
        generated_imgs = generator(z).view(-1, 1, 28, 28)

    fig, axes = plt.subplots(1, 10, figsize=(20, 2))
    for i, ax in enumerate(axes):
        ax.imshow(generated_imgs[i].squeeze().numpy(), cmap='gray')
        ax.axis('off')
    st.pyplot(fig)
    plt.close(fig)  # release the figure so repeated Streamlit reruns don't leak
| |
|
# Run the app only when executed as a script (e.g. via `streamlit run`),
# not when imported as a module.
if __name__ == "__main__":
    main()
| |
|