| import os |
| import glob |
| import argparse |
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import Dataset, DataLoader |
| from torchvision import transforms |
| from PIL import Image |
| import matplotlib.pyplot as plt |
| from safetensors.torch import save_file |
|
|
def plot_multiple_images(images, n_cols, epoch, output_dir=None):
    """Save a grid of images to ``epoch_<epoch>.png``.

    Args:
        images: Batch tensor that is permuted with ``(0, 2, 3, 1)`` before
            plotting, i.e. it is expected in (N, C, H, W) layout.
        n_cols: Number of grid columns. A falsy value places every image on
            a single row.
        epoch: Value interpolated into the output file name.
        output_dir: Directory that receives the PNG. When ``None`` the
            module-level ``args.images_output_path`` is used, which keeps
            existing three-argument callers working.
    """
    if output_dir is None:
        # Fall back to the CLI namespace created by the script entry point.
        output_dir = args.images_output_path

    n_cols = n_cols or len(images)
    n_rows = (len(images) - 1) // n_cols + 1

    # detach() makes the conversion safe even outside a no_grad() context.
    images = images.detach().permute(0, 2, 3, 1).cpu().numpy()
    if images.shape[-1] == 1:
        # Single-channel images are plotted as 2-D arrays.
        images = np.squeeze(images, axis=-1)

    plt.figure(figsize=(n_cols, n_rows))
    for index, image in enumerate(images):
        # Shift values from the [-1, 1] range to [0, 1] for display.
        image = (image + 1) / 2
        plt.subplot(n_rows, n_cols, index + 1)
        plt.imshow(image, cmap="binary")
        plt.axis("off")
    # os.path.join works whether or not output_dir ends with a separator.
    plt.savefig(os.path.join(output_dir, f"epoch_{epoch}.png"))
    plt.close()
|
|
class ImageDataset(Dataset):
    """Dataset that loads image files from disk as normalized tensors.

    Each file is opened with PIL, converted to the mode that matches
    ``image_channels``, resized to ``image_size`` x ``image_size``, turned
    into a tensor and normalized with mean 0.5 and std 0.5 per channel.
    """

    # PIL mode that yields exactly the requested number of channels.
    _PIL_MODES = {1: "L", 2: "LA", 3: "RGB", 4: "RGBA"}

    def __init__(self, file_paths, image_size, image_channels):
        """Store the file list and build the preprocessing pipeline.

        Args:
            file_paths: Sequence of image file paths.
            image_size: Target height and width after resizing.
            image_channels: Number of channels per image (1, 2, 3 or 4).

        Raises:
            ValueError: If ``image_channels`` has no matching PIL mode.
        """
        # Validate first so a bad channel count fails here rather than
        # as a shape mismatch inside the training loop.
        if image_channels not in self._PIL_MODES:
            raise ValueError(
                f"image_channels must be one of {sorted(self._PIL_MODES)}, "
                f"got {image_channels!r}"
            )
        self.file_paths = file_paths
        self.image_size = image_size
        self.image_channels = image_channels
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5] * image_channels, [0.5] * image_channels),
        ])

    def __len__(self):
        """Return the number of image files."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load, convert and transform the image at position ``idx``."""
        img_path = self.file_paths[idx]
        # convert() returns a fully loaded copy, so the file handle can be
        # released as soon as the with-block exits.
        with Image.open(img_path) as source:
            image = source.convert(self._PIL_MODES[self.image_channels])
        return self.transform(image)
|
|
def get_dataloader(inputs, batch_size, image_size, image_channels):
    """Build a shuffled DataLoader over image files.

    Args:
        inputs: Either a dict whose ``"paths"`` entry holds the image file
            paths (a pandas Series, array or plain sequence), or a directory
            path whose files are all used.
        batch_size: Number of images per batch.
        image_size: Target height and width passed to ``ImageDataset``.
        image_channels: Channel count passed to ``ImageDataset``.

    Returns:
        A ``DataLoader`` with ``shuffle=True`` and ``drop_last=True``, so
        every batch holds exactly ``batch_size`` images.
    """
    if isinstance(inputs, dict):
        paths = inputs["paths"]
        # Series and arrays expose tolist(); plain sequences do not.
        file_paths = paths.tolist() if hasattr(paths, "tolist") else list(paths)
    else:
        # Skip sub-directories, which cannot be opened as images, and sort
        # so the index-to-file mapping is reproducible.
        file_paths = sorted(
            path for path in glob.glob(os.path.join(inputs, "*"))
            if os.path.isfile(path)
        )

    dataset = ImageDataset(file_paths, image_size, image_channels)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=2,
    )
|
|
def discriminator_loss(real_output, fake_output, criterion):
    """Return the discriminator loss for one batch.

    The loss is ``criterion`` applied to the real scores against all-ones
    targets plus ``criterion`` applied to the fake scores against all-zeros
    targets.
    """
    targets_for_real = torch.ones_like(real_output)
    targets_for_fake = torch.zeros_like(fake_output)
    return (
        criterion(real_output, targets_for_real)
        + criterion(fake_output, targets_for_fake)
    )
|
|
def generator_loss(fake_output, criterion):
    """Return ``criterion`` of the fake scores against all-ones targets."""
    wanted_scores = torch.ones_like(fake_output)
    return criterion(fake_output, wanted_scores)
|
|
def train_step(images, batch_size, codings_size, generator, discriminator, gen_optimizer, disc_optimizer, criterion, device):
    """Run one discriminator update followed by one generator update.

    Args:
        images: Batch of real images, already on ``device``.
        batch_size: Number of latent vectors to sample.
        codings_size: Length of each latent vector.
        generator: Module mapping latent vectors to images.
        discriminator: Module scoring images.
        gen_optimizer: Optimizer stepping the generator.
        disc_optimizer: Optimizer stepping the discriminator.
        criterion: Loss callable handed to the loss helpers.
        device: Device on which the latent noise is created.

    Returns:
        Tuple ``(generator_loss, discriminator_loss)`` as Python floats.
    """
    latent = torch.randn(batch_size, codings_size, device=device)

    # Discriminator update.
    disc_optimizer.zero_grad()
    fakes = generator(latent)
    scores_real = discriminator(images)
    # detach() keeps this backward pass from reaching the generator.
    scores_fake = discriminator(fakes.detach())
    d_loss = discriminator_loss(scores_real, scores_fake, criterion)
    d_loss.backward()
    disc_optimizer.step()

    # Generator update: score the same fakes again with the graph attached.
    gen_optimizer.zero_grad()
    g_loss = generator_loss(discriminator(fakes), criterion)
    g_loss.backward()
    gen_optimizer.step()

    return g_loss.item(), d_loss.item()
|
|
def train(dataloader, epochs, batch_size, codings_size, generator, discriminator, gen_optimizer, disc_optimizer, criterion, device):
    """Train the generator and discriminator for ``epochs`` epochs.

    After every epoch the losses of the last batch are printed. If the
    module-level ``args.images_output_path`` is truthy, a batch of samples
    is also generated in eval mode and saved via ``plot_multiple_images``.

    Args:
        dataloader: Iterable yielding batches of real images.
        epochs: Number of passes over ``dataloader``.
        batch_size: Number of latent vectors per training step and preview.
        codings_size: Length of each latent vector.
        generator: Generator module.
        discriminator: Discriminator module.
        gen_optimizer: Optimizer for the generator.
        disc_optimizer: Optimizer for the discriminator.
        criterion: Loss callable.
        device: Device that batches and noise are placed on.

    Raises:
        ValueError: If ``dataloader`` yields no batches in an epoch.
    """
    generator.train()
    discriminator.train()

    for epoch in range(epochs):
        gen_loss = disc_loss = None
        for image_batch in dataloader:
            image_batch = image_batch.to(device)
            gen_loss, disc_loss = train_step(
                image_batch, batch_size, codings_size, generator, discriminator,
                gen_optimizer, disc_optimizer, criterion, device)

        if gen_loss is None:
            # Without this guard the print below fails on unbound names.
            raise ValueError(
                "The dataloader produced no batches; the dataset may be "
                "empty or smaller than batch_size."
            )

        print(f"Epoch {epoch+1}/{epochs} - Gen Loss: {gen_loss:.4f}, Disc Loss: {disc_loss:.4f}")
        if args.images_output_path:
            # Sample in eval mode and without gradients, then resume training.
            generator.eval()
            with torch.no_grad():
                noise = torch.randn(batch_size, codings_size, device=device)
                display_images = generator(noise)
            plot_multiple_images(display_images, 8, epoch)
            generator.train()
|
|
class Generator(nn.Module):
    """Transposed-convolution generator mapping latent vectors to images.

    A linear layer expands the latent vector to a 256-channel grid of side
    ``image_size // 4``. Two stride-2 transposed convolutions then double
    the side twice, and a final ``tanh`` bounds the output to [-1, 1].
    With ``image_size=24`` the base grid is 6x6, so layer shapes match the
    earlier hard-coded version.
    """

    def __init__(self, codings_size, image_size, image_channels):
        """Build the layers.

        Args:
            codings_size: Length of the latent input vector.
            image_size: Side length of the generated square image; must be
                a positive multiple of 4.
            image_channels: Number of output channels.

        Raises:
            ValueError: If ``image_size`` is not a positive multiple of 4.
        """
        super().__init__()

        if image_size <= 0 or image_size % 4 != 0:
            raise ValueError(
                f"image_size must be a positive multiple of 4, got {image_size!r}"
            )
        # Two stride-2 layers upsample by 4 overall.
        self.base_size = image_size // 4
        fc_features = self.base_size * self.base_size * 256

        self.fc = nn.Linear(codings_size, fc_features, bias=False)
        self.bn1 = nn.BatchNorm1d(fc_features)
        self.leaky_relu = nn.LeakyReLU(0.2)

        self.conv_transpose1 = nn.ConvTranspose2d(256, 128, kernel_size=5, stride=1, padding=2, bias=False)
        self.bn2 = nn.BatchNorm2d(128)

        self.conv_transpose2 = nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(64)

        self.conv_transpose3 = nn.ConvTranspose2d(64, image_channels, kernel_size=5, stride=2, padding=2, output_padding=1, bias=False)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Map a (batch, codings_size) tensor to a batch of images."""
        x = self.fc(x)
        x = self.bn1(x)
        x = self.leaky_relu(x)
        # Reshape the flat features into the 256-channel base grid.
        x = x.view(-1, 256, self.base_size, self.base_size)

        x = self.conv_transpose1(x)
        x = self.bn2(x)
        x = self.leaky_relu(x)

        x = self.conv_transpose2(x)
        x = self.bn3(x)
        x = self.leaky_relu(x)

        x = self.conv_transpose3(x)
        x = self.tanh(x)

        return x
|
|
class Discriminator(nn.Module):
    """Convolutional classifier producing one sigmoid score per image.

    Three stride-2 convolution stages, each followed by LeakyReLU(0.2) and
    Dropout(0.4), feed a global average pool, a linear layer and a sigmoid.
    """

    def __init__(self, image_size, image_channels):
        """Register the layers; ``image_size`` is accepted but not used."""
        super().__init__()

        # (in_channels, out_channels, kernel_size) for each stage, in order.
        stage_specs = (
            (image_channels, 64, 4),
            (64, 128, 4),
            (128, 256, 3),
        )
        for stage, (in_ch, out_ch, kernel) in enumerate(stage_specs, start=1):
            setattr(self, f"conv{stage}", nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=2, padding=1))
            setattr(self, f"leaky_relu{stage}", nn.LeakyReLU(0.2))
            setattr(self, f"dropout{stage}", nn.Dropout(0.4))

        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(256, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a (batch, 1) tensor of scores in (0, 1)."""
        for stage in (1, 2, 3):
            conv = getattr(self, f"conv{stage}")
            activation = getattr(self, f"leaky_relu{stage}")
            dropout = getattr(self, f"dropout{stage}")
            x = dropout(activation(conv(x)))

        pooled = self.global_avg_pool(x)
        # Collapse the 1x1 spatial dims so the linear layer sees (batch, 256).
        flat = pooled.view(pooled.size(0), -1)
        return self.sigmoid(self.fc(flat))
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse options, build the data pipeline and models,
    # train, then save the generator weights as a safetensors file.
    # NOTE: `args` must stay a module-level name because
    # plot_multiple_images() and train() read it as a global.
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", default="./data/attributes.csv", help="Path to dataset (attributes.csv)")
    parser.add_argument("--images_path", default="./data/images/", help="Path to images")
    parser.add_argument("--model_output_path", default="./models/", help="Path to output the generator model")
    parser.add_argument("--images_output_path", default="./gen_images/", help="Path to output generated images during training")
    parser.add_argument("--codings_size", type=int, default=100, help="Size of the latent z vector")
    parser.add_argument("--image_size", type=int, default=24, help="Images size")
    parser.add_argument("--image_channels", type=int, default=4, help="Images channels")
    parser.add_argument("--batch_size", type=int, default=16, help="Input batch size")
    parser.add_argument("--epochs", type=int, default=50, help="Number of epochs")
    args = parser.parse_args()
    print(args)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    if args.images_output_path:
        print(f"Saving generated images during training at: {args.images_output_path}")
        # makedirs handles nested paths and an already existing directory.
        os.makedirs(args.images_output_path, exist_ok=True)

    print("Loading the dataset...")
    df = pd.read_csv(args.data_path)
    # Turn each numeric id into the path of its image file.
    df["id"] = df["id"].apply(lambda x: os.path.join(args.images_path, f"punk{x:03d}.png"))

    print("Creating PyTorch DataLoader...")
    dataloader = get_dataloader({"paths": df["id"]}, args.batch_size, args.image_size, args.image_channels)

    generator = Generator(args.codings_size, args.image_size, args.image_channels).to(device)
    print("Generator architecture:")
    print(generator)

    discriminator = Discriminator(args.image_size, args.image_channels).to(device)
    print("Discriminator architecture:")
    print(discriminator)

    gen_optimizer = optim.RMSprop(generator.parameters(), lr=0.001)
    disc_optimizer = optim.RMSprop(discriminator.parameters(), lr=0.001)
    criterion = nn.BCELoss()

    print("Training model...")
    train(dataloader, args.epochs, args.batch_size, args.codings_size, generator, discriminator,
          gen_optimizer, disc_optimizer, criterion, device)

    print(f"Saving model at: {args.model_output_path}...")
    if args.model_output_path.endswith('.safetensors'):
        model_path = args.model_output_path
    else:
        model_path = os.path.join(args.model_output_path, 'generator_model.safetensors')
    # Create only the parent directory. Calling makedirs on a *.safetensors
    # path would create a directory with that name and block the save.
    os.makedirs(os.path.dirname(model_path) or ".", exist_ok=True)

    # safetensors metadata values must be strings.
    metadata = {
        'codings_size': str(args.codings_size),
        'image_size': str(args.image_size),
        'image_channels': str(args.image_channels)
    }
    save_file(generator.state_dict(), model_path, metadata=metadata)
    print(f"Model saved to: {model_path}")