pixel / trainer.py
chatpig's picture
Upload 2 files
409857b verified
raw
history blame
9.7 kB
import os
import glob
import argparse
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from safetensors.torch import save_file
def plot_multiple_images(images, n_cols, epoch):
n_cols = n_cols or len(images)
n_rows = (len(images) - 1) // n_cols + 1
# Convert from CHW to HWC format for plotting
images = images.permute(0, 2, 3, 1).cpu().numpy()
if images.shape[-1] == 1:
images = np.squeeze(images, axis=-1)
plt.figure(figsize=(n_cols, n_rows))
for index, image in enumerate(images):
image = ((image + 1) / 2) # scale back
plt.subplot(n_rows, n_cols, index + 1)
plt.imshow(image, cmap="binary")
plt.axis("off")
plt.savefig(f'{args.images_output_path}epoch_{epoch}.png')
plt.close() # Close the figure to free memory
class ImageDataset(Dataset):
def __init__(self, file_paths, image_size, image_channels):
self.file_paths = file_paths
self.image_size = image_size
self.image_channels = image_channels
self.transform = transforms.Compose([
transforms.Resize((image_size, image_size)),
transforms.ToTensor(),
transforms.Normalize([0.5] * image_channels, [0.5] * image_channels) # Scale to [-1, 1]
])
def __len__(self):
return len(self.file_paths)
def __getitem__(self, idx):
img_path = self.file_paths[idx]
image = Image.open(img_path).convert('RGBA' if self.image_channels == 4 else 'RGB')
image = self.transform(image)
return image
def get_dataloader(inputs, batch_size, image_size, image_channels):
if type(inputs) == dict:
file_paths = inputs["paths"].tolist()
else:
file_paths = glob.glob(f"{inputs}/*")
dataset = ImageDataset(file_paths, image_size, image_channels)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=2)
return dataloader
def discriminator_loss(real_output, fake_output, criterion):
real_loss = criterion(real_output, torch.ones_like(real_output))
fake_loss = criterion(fake_output, torch.zeros_like(fake_output))
total_loss = real_loss + fake_loss
return total_loss
def generator_loss(fake_output, criterion):
return criterion(fake_output, torch.ones_like(fake_output))
def train_step(images, batch_size, codings_size, generator, discriminator, gen_optimizer, disc_optimizer, criterion, device):
noise = torch.randn(batch_size, codings_size, device=device)
# Train Discriminator
disc_optimizer.zero_grad()
generated_images = generator(noise)
real_output = discriminator(images)
fake_output = discriminator(generated_images.detach())
disc_loss = discriminator_loss(real_output, fake_output, criterion)
disc_loss.backward()
disc_optimizer.step()
# Train Generator
gen_optimizer.zero_grad()
fake_output = discriminator(generated_images)
gen_loss = generator_loss(fake_output, criterion)
gen_loss.backward()
gen_optimizer.step()
return gen_loss.item(), disc_loss.item()
def train(dataloader, epochs, batch_size, codings_size, generator, discriminator, gen_optimizer, disc_optimizer, criterion, device):
generator.train()
discriminator.train()
for epoch in range(epochs):
for image_batch in dataloader:
image_batch = image_batch.to(device)
gen_loss, disc_loss = train_step(image_batch, batch_size, codings_size, generator, discriminator,
gen_optimizer, disc_optimizer, criterion, device)
print(f"Epoch {epoch+1}/{epochs} - Gen Loss: {gen_loss:.4f}, Disc Loss: {disc_loss:.4f}")
if args.images_output_path:
generator.eval()
with torch.no_grad():
noise = torch.randn(batch_size, codings_size, device=device)
display_images = generator(noise)
plot_multiple_images(display_images, 8, epoch)
generator.train()
class Generator(nn.Module):
def __init__(self, codings_size, image_size, image_channels):
super(Generator, self).__init__()
self.fc = nn.Linear(codings_size, 6 * 6 * 256, bias=False)
self.bn1 = nn.BatchNorm1d(6 * 6 * 256)
self.leaky_relu = nn.LeakyReLU(0.2)
self.conv_transpose1 = nn.ConvTranspose2d(256, 128, kernel_size=5, stride=1, padding=2, bias=False)
self.bn2 = nn.BatchNorm2d(128)
self.conv_transpose2 = nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(64)
self.conv_transpose3 = nn.ConvTranspose2d(64, image_channels, kernel_size=5, stride=2, padding=2, output_padding=1, bias=False)
self.tanh = nn.Tanh()
def forward(self, x):
x = self.fc(x)
x = self.bn1(x)
x = self.leaky_relu(x)
x = x.view(-1, 256, 6, 6)
x = self.conv_transpose1(x)
x = self.bn2(x)
x = self.leaky_relu(x)
x = self.conv_transpose2(x)
x = self.bn3(x)
x = self.leaky_relu(x)
x = self.conv_transpose3(x)
x = self.tanh(x)
return x
class Discriminator(nn.Module):
def __init__(self, image_size, image_channels):
super(Discriminator, self).__init__()
self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=4, stride=2, padding=1)
self.leaky_relu1 = nn.LeakyReLU(0.2)
self.dropout1 = nn.Dropout(0.4)
self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
self.leaky_relu2 = nn.LeakyReLU(0.2)
self.dropout2 = nn.Dropout(0.4)
self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
self.leaky_relu3 = nn.LeakyReLU(0.2)
self.dropout3 = nn.Dropout(0.4)
self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(256, 1)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = self.conv1(x)
x = self.leaky_relu1(x)
x = self.dropout1(x)
x = self.conv2(x)
x = self.leaky_relu2(x)
x = self.dropout2(x)
x = self.conv3(x)
x = self.leaky_relu3(x)
x = self.dropout3(x)
x = self.global_avg_pool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
x = self.sigmoid(x)
return x
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data_path", default="./data/attributes.csv", help="Path to dataset (attributes.csv)")
parser.add_argument("--images_path", default="./data/images/", help="Path to images")
parser.add_argument("--model_output_path", default="./models/", help="Path to output the generator model")
parser.add_argument("--images_output_path", default="./gen_images/", help="Path to output generated images during training")
parser.add_argument("--codings_size", type=int, default=100, help="Size of the latent z vector")
parser.add_argument("--image_size", type=int, default=24, help="Images size")
parser.add_argument("--image_channels", type=int, default=4, help="Images channels")
parser.add_argument("--batch_size", type=int, default=16, help="Input batch size")
parser.add_argument("--epochs", type=int, default=50, help="Number of epochs")
args = parser.parse_args()
print(args)
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if args.images_output_path and (os.path.exists(args.images_output_path) == False):
print(f"Saving generated images during training at: {args.images_output_path}")
os.mkdir(args.images_output_path)
print("Loading the dataset...")
df = pd.read_csv(args.data_path)
df.id = df.id.apply(lambda x: f"{args.images_path}punk{x:03d}.png")
print("Creating PyTorch DataLoader...")
dataloader = get_dataloader({"paths": df.id}, args.batch_size, args.image_size, args.image_channels)
generator = Generator(args.codings_size, args.image_size, args.image_channels).to(device)
print("Generator architecture:")
print(generator)
discriminator = Discriminator(args.image_size, args.image_channels).to(device)
print("Discriminator architecture:")
print(discriminator)
gen_optimizer = optim.RMSprop(generator.parameters(), lr=0.001)
disc_optimizer = optim.RMSprop(discriminator.parameters(), lr=0.001)
criterion = nn.BCELoss()
print("Training model...")
train(dataloader, args.epochs, args.batch_size, args.codings_size, generator, discriminator,
gen_optimizer, disc_optimizer, criterion, device)
print(f"Saving model at: {args.model_output_path}...")
os.makedirs(args.model_output_path, exist_ok=True)
model_path = args.model_output_path if args.model_output_path.endswith('.safetensors') else os.path.join(args.model_output_path, 'generator_model.safetensors')
# Save the generator model in safetensors format
# Metadata is stored as strings in safetensors
metadata = {
'codings_size': str(args.codings_size),
'image_size': str(args.image_size),
'image_channels': str(args.image_channels)
}
save_file(generator.state_dict(), model_path, metadata=metadata)
print(f"Model saved to: {model_path}")