| import torch
|
| from dataset import CycleGANDataset
|
| from torch.utils.data import DataLoader
|
| import torch.nn as nn
|
| import torch.optim as optim
|
| from models import Generator, Discriminator
|
| from tqdm import tqdm
|
| from torchvision.utils import save_image
|
| import albumentations as A
|
| from albumentations.pytorch import ToTensorV2
|
| import os
|
|
|
|
|
# --- Runtime -----------------------------------------------------------------
# Train on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Dataset locations -------------------------------------------------------
TRAIN_DIR_HORSE = "data/horse2zebra/trainA"
TRAIN_DIR_ZEBRA = "data/horse2zebra/trainB"
VAL_DIR_HORSE = "data/horse2zebra/testA"
VAL_DIR_ZEBRA = "data/horse2zebra/testB"

# --- Optimisation hyper-parameters -------------------------------------------
BATCH_SIZE = 1
LEARNING_RATE = 1e-5
# Weight of the identity term.
# NOTE(review): no identity loss is computed in the training loop visible in
# this file, so this value currently has no effect - confirm before tuning it.
LAMBDA_IDENTITY = 0.0
# Weight applied to each cycle-consistency (L1) term of the generator loss.
LAMBDA_CYCLE = 10
NUM_WORKERS = 1
NUM_EPOCHS = 10

# --- Checkpointing -----------------------------------------------------------
LOAD_MODEL = False
SAVE_MODEL = True
CHECKPOINT_GEN_H = "genh.pth.tar"
CHECKPOINT_GEN_Z = "genz.pth.tar"
CHECKPOINT_CRITIC_H = "critich.pth.tar"
CHECKPOINT_CRITIC_Z = "criticz.pth.tar"
|
|
|
# Augmentation pipeline shared by both image domains.  ``additional_targets``
# registers a second image slot ("image0") so that one call applies the same
# pipeline to a pair of images.
transforms = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5),
        # With mean=std=0.5 and max_pixel_value=255 this rescales 8-bit pixel
        # values to the range [-1, 1].
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
        ToTensorV2(),
    ],
    additional_targets={"image0": "image"},
)
|
|
|
def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler):
    """Run one pass over ``loader``, updating both discriminators and generators.

    For every batch the discriminators are updated first (real images scored
    against ones, detached fakes against zeros), then the generators are
    updated with an adversarial term plus cycle-consistency terms weighted by
    ``LAMBDA_CYCLE``.

    Args:
        disc_H: Discriminator applied to horse images (real and generated).
        disc_Z: Discriminator applied to zebra images (real and generated).
        gen_Z: Generator mapping a horse batch to a fake zebra batch.
        gen_H: Generator mapping a zebra batch to a fake horse batch.
        loader: Iterable yielding ``(horse, zebra)`` tensor pairs.
        opt_disc: Optimizer stepping the discriminator parameters.
        opt_gen: Optimizer stepping the generator parameters.
        l1: Loss callable used for the cycle-consistency terms.
        mse: Loss callable used for the adversarial terms.
        d_scaler: Gradient scaler used for the discriminator step.
        g_scaler: Gradient scaler used for the generator step.

    Returns:
        None. Side effects: model/optimizer state is updated, and every 200
        batches sample images (and, when ``SAVE_MODEL`` is true, generator
        weights) are written under ``saved_images/``.
    """
    # Mixed precision is only enabled on CUDA; evaluate the flag once.
    use_amp = DEVICE == "cuda"

    # Running sums of the horse discriminator's mean output on real / fake
    # batches, shown (averaged) in the progress bar.
    H_reals = 0
    H_fakes = 0
    loop = tqdm(loader, leave=True)

    for idx, (horse, zebra) in enumerate(loop):
        horse = horse.to(DEVICE)
        zebra = zebra.to(DEVICE)

        # ---- Discriminator update ------------------------------------------
        with torch.cuda.amp.autocast(enabled=use_amp):
            fake_horse = gen_H(zebra)
            D_H_real = disc_H(horse)
            # detach(): the discriminator loss must not back-propagate into
            # the generator that produced the fake.
            D_H_fake = disc_H(fake_horse.detach())
            H_reals += D_H_real.mean().item()
            H_fakes += D_H_fake.mean().item()
            D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real))
            D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
            D_H_loss = D_H_real_loss + D_H_fake_loss

            fake_zebra = gen_Z(horse)
            D_Z_real = disc_Z(zebra)
            D_Z_fake = disc_Z(fake_zebra.detach())
            D_Z_real_loss = mse(D_Z_real, torch.ones_like(D_Z_real))
            D_Z_fake_loss = mse(D_Z_fake, torch.zeros_like(D_Z_fake))
            D_Z_loss = D_Z_real_loss + D_Z_fake_loss

            # Average the two discriminator losses into one objective.
            D_loss = (D_H_loss + D_Z_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- Generator update ----------------------------------------------
        with torch.cuda.amp.autocast(enabled=use_amp):
            # Adversarial terms: re-score the (non-detached) fakes with the
            # freshly updated discriminators and push the scores towards 1.
            D_H_fake = disc_H(fake_horse)
            D_Z_fake = disc_Z(fake_zebra)
            loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
            loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))

            # Cycle terms: translating a fake back should recover the input.
            cycle_zebra = gen_Z(fake_horse)
            cycle_horse = gen_H(fake_zebra)
            cycle_zebra_loss = l1(zebra, cycle_zebra)
            cycle_horse_loss = l1(horse, cycle_horse)

            G_loss = (
                loss_G_Z
                + loss_G_H
                + cycle_zebra_loss * LAMBDA_CYCLE
                + cycle_horse_loss * LAMBDA_CYCLE
            )

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        # ---- Periodic snapshot ---------------------------------------------
        if idx % 200 == 0:
            # Make the function usable on its own: do not rely on the caller
            # having created the output directory.
            os.makedirs("saved_images", exist_ok=True)
            # Honour the module-level switch instead of always writing weights.
            if SAVE_MODEL:
                torch.save(gen_H.state_dict(), "saved_images/genh.pth.tar")
                torch.save(gen_Z.state_dict(), "saved_images/genz.pth.tar")
            # ``x * 0.5 + 0.5`` inverts a mean=0.5/std=0.5 normalisation;
            # assumes generator outputs lie in [-1, 1] - TODO confirm.
            save_image(fake_horse * 0.5 + 0.5, f"saved_images/horse_{idx}.png")
            save_image(fake_zebra * 0.5 + 0.5, f"saved_images/zebra_{idx}.png")

        loop.set_postfix(H_real=H_reals / (idx + 1), H_fake=H_fakes / (idx + 1))
|
|
|
def main():
    """Build the models, optimizers and data loader, then train.

    Creates two discriminators and two generators on ``DEVICE``, one Adam
    optimizer per model family, the horse/zebra training ``DataLoader`` and
    one gradient scaler per optimizer, then calls ``train_fn`` once per epoch
    for ``NUM_EPOCHS`` epochs.
    """
    disc_H = Discriminator(in_channels=3).to(DEVICE)
    disc_Z = Discriminator(in_channels=3).to(DEVICE)
    gen_Z = Generator(img_channels=3, num_residuals=9).to(DEVICE)
    gen_H = Generator(img_channels=3, num_residuals=9).to(DEVICE)

    # One optimizer drives both discriminators, another both generators, so
    # each family is stepped jointly.
    opt_disc = optim.Adam(
        list(disc_H.parameters()) + list(disc_Z.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )
    opt_gen = optim.Adam(
        list(gen_Z.parameters()) + list(gen_H.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    L1 = nn.L1Loss()
    MSE = nn.MSELoss()

    dataset = CycleGANDataset(
        root_horse=TRAIN_DIR_HORSE,
        root_zebra=TRAIN_DIR_ZEBRA,
        transform=transforms,
    )
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        # Pinned host memory only helps host-to-GPU copies; requesting it on a
        # CPU-only run just produces a warning.
        pin_memory=(DEVICE == "cuda"),
    )

    # Scalers are no-ops when disabled, so the same training code runs on CPU.
    g_scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))
    d_scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))

    # train_fn writes sample images and weights into this directory.
    os.makedirs("saved_images", exist_ok=True)

    for epoch in range(NUM_EPOCHS):
        # Report 1-based progress so the final epoch reads N/N.
        print(f"Epoch {epoch + 1}/{NUM_EPOCHS}")
        train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, L1, MSE, d_scaler, g_scaler)
|
|
|
# Start training only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|