"""Training script for a CycleGAN translating between horse and zebra images.

Two generators (``gen_H``: zebra -> horse, ``gen_Z``: horse -> zebra) and two
discriminators (``disc_H``, ``disc_Z``) are trained jointly with a
least-squares adversarial loss, an L1 cycle-consistency loss and an optional
L1 identity loss.
"""
import os

import albumentations as A
import torch
import torch.nn as nn
import torch.optim as optim
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm

from dataset import CycleGANDataset
from models import Generator, Discriminator

# Hyperparameters
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TRAIN_DIR_HORSE = "data/horse2zebra/trainA"
TRAIN_DIR_ZEBRA = "data/horse2zebra/trainB"
VAL_DIR_HORSE = "data/horse2zebra/testA"
VAL_DIR_ZEBRA = "data/horse2zebra/testB"
BATCH_SIZE = 1
LEARNING_RATE = 1e-5
LAMBDA_IDENTITY = 0.0  # weight of the identity loss; 0 disables it entirely
LAMBDA_CYCLE = 10  # weight of the cycle-consistency loss
NUM_WORKERS = 1
NUM_EPOCHS = 10
# NOTE(review): LOAD_MODEL and the two CHECKPOINT_CRITIC_* names are not used
# anywhere in this file -- confirm whether another module reads them.
LOAD_MODEL = False
SAVE_MODEL = True
CHECKPOINT_GEN_H = "genh.pth.tar"
CHECKPOINT_GEN_Z = "genz.pth.tar"
CHECKPOINT_CRITIC_H = "critich.pth.tar"
CHECKPOINT_CRITIC_Z = "criticz.pth.tar"

# Directory receiving both sample images and generator checkpoints.
OUTPUT_DIR = "saved_images"

# Shared augmentation pipeline. ``additional_targets`` makes a second image,
# passed under the key "image0", receive the same transforms as "image".
transforms = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
        ToTensorV2(),
    ],
    additional_targets={"image0": "image"},
)


def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler):
    """Run one training epoch over ``loader``.

    Each iteration first updates both discriminators, then both generators.
    Every 200 iterations sample translations are written to ``OUTPUT_DIR``
    and, when ``SAVE_MODEL`` is true, so are the generator state dicts.

    Args:
        disc_H: Discriminator applied to real and generated horse images.
        disc_Z: Discriminator applied to real and generated zebra images.
        gen_Z: Generator mapping horse images to zebra images.
        gen_H: Generator mapping zebra images to horse images.
        loader: Iterable yielding ``(horse, zebra)`` tensor batches.
        opt_disc: Optimizer stepping the parameters of both discriminators.
        opt_gen: Optimizer stepping the parameters of both generators.
        l1: Loss callable used for the cycle and identity terms.
        mse: Loss callable used for the adversarial terms.
        d_scaler: Gradient scaler used for the discriminator update.
        g_scaler: Gradient scaler used for the generator update.
    """
    # Make the function usable on its own, not only when called from main().
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    use_amp = DEVICE == "cuda"
    # Running sums of the mean discriminator score on real / fake horses,
    # shown as running averages in the progress bar.
    H_reals = 0
    H_fakes = 0
    loop = tqdm(loader, leave=True)

    for idx, (horse, zebra) in enumerate(loop):
        horse = horse.to(DEVICE)
        zebra = zebra.to(DEVICE)

        # Train Discriminators H and Z
        with torch.cuda.amp.autocast(enabled=use_amp):
            fake_horse = gen_H(zebra)
            D_H_real = disc_H(horse)
            # detach(): the discriminator loss must not backpropagate into
            # the generator.
            D_H_fake = disc_H(fake_horse.detach())
            H_reals += D_H_real.mean().item()
            H_fakes += D_H_fake.mean().item()
            D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real))
            D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
            D_H_loss = D_H_real_loss + D_H_fake_loss

            fake_zebra = gen_Z(horse)
            D_Z_real = disc_Z(zebra)
            D_Z_fake = disc_Z(fake_zebra.detach())
            D_Z_real_loss = mse(D_Z_real, torch.ones_like(D_Z_real))
            D_Z_fake_loss = mse(D_Z_fake, torch.zeros_like(D_Z_fake))
            D_Z_loss = D_Z_real_loss + D_Z_fake_loss

            # put it together
            D_loss = (D_H_loss + D_Z_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train Generators H and Z
        with torch.cuda.amp.autocast(enabled=use_amp):
            # adversarial loss for both generators (fakes are NOT detached
            # here, so gradients reach the generators)
            D_H_fake = disc_H(fake_horse)
            D_Z_fake = disc_Z(fake_zebra)
            loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
            loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))

            # cycle loss: translating there and back should reproduce the input
            cycle_zebra = gen_Z(fake_horse)
            cycle_horse = gen_H(fake_zebra)
            cycle_zebra_loss = l1(zebra, cycle_zebra)
            cycle_horse_loss = l1(horse, cycle_horse)

            # add all together
            G_loss = (
                loss_G_Z
                + loss_G_H
                + cycle_zebra_loss * LAMBDA_CYCLE
                + cycle_horse_loss * LAMBDA_CYCLE
            )

            # identity loss: only computed when enabled, so the default
            # LAMBDA_IDENTITY = 0.0 costs no extra forward passes
            if LAMBDA_IDENTITY > 0:
                identity_zebra_loss = l1(zebra, gen_Z(zebra))
                identity_horse_loss = l1(horse, gen_H(horse))
                G_loss = (
                    G_loss
                    + identity_horse_loss * LAMBDA_IDENTITY
                    + identity_zebra_loss * LAMBDA_IDENTITY
                )

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % 200 == 0:
            if SAVE_MODEL:
                torch.save(gen_H.state_dict(), f"{OUTPUT_DIR}/{CHECKPOINT_GEN_H}")
                torch.save(gen_Z.state_dict(), f"{OUTPUT_DIR}/{CHECKPOINT_GEN_Z}")
            # x * 0.5 + 0.5 inverts the mean=0.5 / std=0.5 normalisation used
            # in `transforms` before the image is written to disk.
            save_image(fake_horse * 0.5 + 0.5, f"{OUTPUT_DIR}/horse_{idx}.png")
            save_image(fake_zebra * 0.5 + 0.5, f"{OUTPUT_DIR}/zebra_{idx}.png")

        loop.set_postfix(H_real=H_reals / (idx + 1), H_fake=H_fakes / (idx + 1))


def main():
    """Build models, optimizers and data loader, then train for NUM_EPOCHS."""
    disc_H = Discriminator(in_channels=3).to(DEVICE)
    disc_Z = Discriminator(in_channels=3).to(DEVICE)
    gen_Z = Generator(img_channels=3, num_residuals=9).to(DEVICE)
    gen_H = Generator(img_channels=3, num_residuals=9).to(DEVICE)

    # One optimizer per role, each covering both networks of that role.
    opt_disc = optim.Adam(
        list(disc_H.parameters()) + list(disc_Z.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )
    opt_gen = optim.Adam(
        list(gen_Z.parameters()) + list(gen_H.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    L1 = nn.L1Loss()
    MSE = nn.MSELoss()

    dataset = CycleGANDataset(
        root_horse=TRAIN_DIR_HORSE,
        root_zebra=TRAIN_DIR_ZEBRA,
        transform=transforms,
    )
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        # Pinned host memory only helps host-to-GPU copies; requesting it on a
        # CPU-only machine just produces a warning.
        pin_memory=(DEVICE == "cuda"),
    )

    # Scalers become no-ops when CUDA is unavailable.
    g_scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))
    d_scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    for epoch in range(NUM_EPOCHS):
        # 1-based so the final epoch reads NUM_EPOCHS/NUM_EPOCHS.
        print(f"Epoch {epoch + 1}/{NUM_EPOCHS}")
        train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, L1, MSE, d_scaler, g_scaler)


if __name__ == "__main__":
    main()