Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import lightning as pl | |
| import wandb | |
| import itertools | |
| from torch.optim.lr_scheduler import LambdaLR | |
| from torch.utils.data import DataLoader | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from src.classifier import Classifier | |
| from src.dataset import CustomDataset | |
class AttentionGate(nn.Module):
    """Channel-wise soft attention applied to a skip-connection feature map.

    Both inputs are projected with their own 1x1 convolution from
    ``in_channels`` to ``out_channels``. The projected gating signal is
    turned into weights with a softmax over the channel dimension, and the
    projected skip features are multiplied element-wise by those weights.

    Args:
        in_channels: Number of channels of both ``x`` and ``g``.
        out_channels: Number of channels of the returned tensor.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Shared settings of the two 1x1 (pointwise) projections.
        pointwise = dict(kernel_size=1, stride=1, padding=0)
        self.conv_gate = nn.Conv2d(in_channels, out_channels, **pointwise)
        self.conv_x = nn.Conv2d(in_channels, out_channels, **pointwise)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x, g):
        """Return the projection of ``x`` weighted by attention derived from ``g``.

        Args:
            x: Feature map to be gated.
            g: Gating signal; its projection must be broadcast-compatible
                with the projection of ``x``.
        """
        weights = self.softmax(self.conv_gate(g))
        projected = self.conv_x(x)
        return projected * weights
class ResUNetGenerator(nn.Module):
    """Encoder/decoder image generator with attention-gated skip connections.

    Four stride-2 convolutions halve the spatial size each time while
    widening the features to ``gf``, ``2*gf``, ``4*gf`` and ``8*gf``
    channels. Four stride-2 transposed convolutions bring the resolution
    back; after each of the first three, the matching encoder output is
    combined with the decoder output by an ``AttentionGate`` that doubles
    the channel count. The last layer ends in ``tanh``.

    Args:
        gf: Number of channels produced by the first encoder stage.
        channels: Number of channels of the input and output images.
    """

    def __init__(self, gf, channels):
        super().__init__()
        self.channels = channels

        # Encoder widths: gf, 2*gf, 4*gf, 8*gf.
        widths = [gf * 2**i for i in range(4)]

        # Downsampling layers
        self.conv1 = self._down_block(channels, widths[0])
        self.conv2 = self._down_block(widths[0], widths[1])
        self.conv3 = self._down_block(widths[1], widths[2])
        self.conv4 = self._down_block(widths[2], widths[3])

        # attn_layer[i] maps widths[i] channels to widths[i + 1] channels.
        self.attn_layer = nn.ModuleList([
            AttentionGate(narrow, wide)
            for narrow, wide in zip(widths[:-1], widths[1:])
        ])

        # Upsampling layers; every input width after the first equals the
        # output width of the attention gate that precedes it.
        self.deconv1 = self._up_block(widths[3], widths[2])
        self.deconv2 = self._up_block(widths[3], widths[1])
        self.deconv3 = self._up_block(widths[2], widths[0])
        self.deconv4 = nn.Sequential(
            nn.ConvTranspose2d(widths[1], channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    @staticmethod
    def _down_block(c_in, c_out):
        """Stride-2 convolution, LeakyReLU(0.2), single-group GroupNorm."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.GroupNorm(num_groups=1, num_channels=c_out),
        )

    @staticmethod
    def _up_block(c_in, c_out):
        """Stride-2 transposed convolution, ReLU, single-group GroupNorm."""
        return nn.Sequential(
            nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.GroupNorm(num_groups=1, num_channels=c_out),
        )

    def forward(self, x):
        """Translate ``x`` and return a tensor with values in ``[-1, 1]``."""
        # Downsampling
        skip1 = self.conv1(x)
        skip2 = self.conv2(skip1)
        skip3 = self.conv3(skip2)
        bottleneck = self.conv4(skip3)

        # Upsampling: each decoder output gates the encoder output of the
        # same resolution before being passed on.
        up = self.attn_layer[2](skip3, self.deconv1(bottleneck))
        up = self.attn_layer[1](skip2, self.deconv2(up))
        up = self.attn_layer[0](skip1, self.deconv3(up))
        return self.deconv4(up)

    def configure_optimizers(self):
        """Return an Adam optimizer over this module's parameters."""
        return torch.optim.Adam(self.parameters(), lr=0.0002, betas=(0.5, 0.999))
class Discriminator(pl.LightningModule):
    """Convolutional discriminator producing a map of per-patch validity scores.

    Four stride-2 convolution stages (each followed by LeakyReLU(0.2) and an
    8-group GroupNorm) widen the features to ``df``, ``2*df``, ``4*df`` and
    ``8*df`` channels. A final stride-1 convolution reduces them to a single
    channel. No activation is applied to the output.

    Args:
        df: Number of channels produced by the first stage. Must be a
            multiple of 8 because every stage uses ``GroupNorm(8, ...)``.
        channels: Number of channels of the input images. Defaults to 1,
            the value that used to be hard-coded, so existing
            ``Discriminator(df)`` calls build the same network.
    """

    def __init__(self, df, channels=1):
        super().__init__()
        self.df = df
        self.channels = channels

        # Output width of each stage, and the input width feeding it.
        out_widths = [df * 2**i for i in range(4)]
        in_widths = [channels] + out_widths[:-1]

        # Define the layers for the discriminator
        self.conv_layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.LeakyReLU(0.2),
                nn.GroupNorm(8, c_out),
            )
            for c_in, c_out in zip(in_widths, out_widths)
        ])
        self.final_conv = nn.Conv2d(out_widths[-1], 1, kernel_size=4, stride=1, padding=1)

    def forward(self, x):
        """Return the single-channel validity map for the image batch ``x``."""
        out = x
        for conv_layer in self.conv_layers:
            out = conv_layer(out)
        return self.final_conv(out)

    def configure_optimizers(self):
        """Return an Adam optimizer over this module's parameters."""
        return torch.optim.Adam(self.parameters(), lr=0.0002, betas=(0.5, 0.999))
class CycleGAN(pl.LightningModule):
    """CycleGAN between image domains "N" and "P" with a frozen-classifier loss.

    Two generators (``g_NP``: N -> P, ``g_PN``: P -> N) and two
    discriminators (``d_N``, ``d_P``) are trained with manual optimization.
    The generator objective is the sum of a least-squares adversarial loss,
    an L1 cycle-consistency loss, an L1 identity loss and an MSE loss on the
    output of a pre-trained, frozen ``Classifier`` applied to the translated
    images.

    Args:
        train_dir: Root directory handed to ``CustomDataset`` for training.
        val_dir: Root directory handed to ``CustomDataset`` for validation.
        test_dataloader: Iterable yielding ``(img_N, img_P)`` batches, used
            only to render preview images during training.
        classifier_path: Checkpoint file containing a ``'state_dict'`` entry
            for ``Classifier``.
        checkpoint_dir: Directory where the best generator weights are
            written. It is not created here and must already exist.
        image_size: Side length passed to ``CustomDataset`` as ``img_res``.
        batch_size: Batch size of the training and validation loaders.
        channels: Number of image channels used by the generators.
        gf: Base channel width of the generators.
        df: Base channel width of the discriminators.
        lambda_cycle: Weight of the cycle-consistency loss.
        lambda_id: Identity-loss weight relative to ``lambda_cycle``; the
            effective weight is ``lambda_id * lambda_cycle``.
        classifier_weight: Weight of the classifier loss.
    """

    def __init__(self, train_dir, val_dir, test_dataloader, classifier_path, checkpoint_dir, image_size=512, batch_size=4, channels=1, gf=32, df=64, lambda_cycle=10.0, lambda_id=0.1, classifier_weight=1):
        super().__init__()
        self.image_size = image_size
        self.batch_size = batch_size
        self.channels = channels
        self.gf = gf
        self.df = df
        self.lambda_cycle = lambda_cycle
        self.lambda_id = lambda_id * lambda_cycle
        self.classifier_path = classifier_path
        self.classifier_weight = classifier_weight
        self.lowest_val_loss = float('inf')
        self.validation_step_outputs = []
        self.train_dir = train_dir
        self.val_dir = val_dir
        # NOTE(review): this instance attribute shadows the LightningModule
        # ``test_dataloader()`` hook, so ``trainer.test(model)`` without an
        # explicit dataloader will likely not work. Kept as is because
        # outside code may read ``model.test_dataloader``.
        self.test_dataloader = test_dataloader
        self.checkpoint_dir = checkpoint_dir

        # Initialize the generator, discriminator, and classifier models
        self.g_NP = ResUNetGenerator(gf, channels=self.channels)
        self.g_PN = ResUNetGenerator(gf, channels=self.channels)
        # NOTE(review): the discriminators are built with their default
        # input channel count, so ``channels != 1`` is presumably
        # unsupported end-to-end - confirm before using colour images.
        self.d_N = Discriminator(df)
        self.d_P = Discriminator(df)

        # Generators and discriminators are stepped by hand in training_step.
        self.automatic_optimization = False

        self.classifier = Classifier()
        # Load onto the CPU first so a checkpoint saved on a GPU can be
        # restored on any host; the module is moved to its device later.
        checkpoint = torch.load(classifier_path, map_location="cpu")
        self.classifier.load_state_dict(checkpoint['state_dict'])
        self.classifier.eval()
        self.freeze_classifier()

    def train(self, mode=True):
        """Switch train/eval mode while keeping the frozen classifier in eval mode.

        ``nn.Module.train`` recurses into every child, which would otherwise
        undo the ``classifier.eval()`` call made in ``__init__``.
        """
        super().train(mode)
        self.classifier.eval()
        return self

    def freeze_classifier(self):
        """Disable gradients for every classifier parameter."""
        print("freezing Classifier...")
        for p in self.classifier.parameters():
            p.requires_grad = False

    def _generator_losses(self, img_N, img_P):
        """Compute every generator-side loss term for one ``(N, P)`` batch.

        Returns:
            dict with the keys ``adversarial``, ``cycle``, ``identity``,
            ``classifier`` and ``total`` mapping to scalar tensors.
        """
        mse = nn.functional.mse_loss
        l1 = nn.functional.l1_loss

        # Translate images to the other domain
        fake_P = self.g_NP(img_N)
        fake_N = self.g_PN(img_P)
        # Translate images back to original domain
        reconstr_N = self.g_PN(fake_P)
        reconstr_P = self.g_NP(fake_N)
        # Identity mapping of images
        img_N_id = self.g_PN(img_N)
        img_P_id = self.g_NP(img_P)
        # Discriminators determine validity of translated images
        valid_N = self.d_N(fake_N)
        valid_P = self.d_P(fake_P)
        class_N = self.classifier(fake_N)
        class_P = self.classifier(fake_P)

        # Adversarial loss: the generators want both fakes scored as real (1).
        adversarial_loss = (
            mse(valid_N, torch.ones_like(valid_N))
            + mse(valid_P, torch.ones_like(valid_P))
        )
        # Cycle consistency loss
        cycle_loss = l1(reconstr_N, img_N) + l1(reconstr_P, img_P)
        # Identity loss
        identity_loss = l1(img_N_id, img_N) + l1(img_P_id, img_P)
        # Classifier loss.
        # NOTE(review): the targets are 1 for fake N and 0 for fake P;
        # confirm this matches the Classifier's label convention.
        class_loss = (
            mse(class_N, torch.ones_like(class_N))
            + mse(class_P, torch.zeros_like(class_P))
        )
        # Total generator loss
        total_loss = (
            adversarial_loss
            + self.lambda_cycle * cycle_loss
            + self.lambda_id * identity_loss
            + self.classifier_weight * class_loss
        )
        return {
            "adversarial": adversarial_loss,
            "cycle": cycle_loss,
            "identity": identity_loss,
            "classifier": class_loss,
            "total": total_loss,
        }

    def generator_training_step(self, img_N, img_P, opt):
        """Run one optimization step of both generators.

        Args:
            img_N: Batch of real images from domain N.
            img_P: Batch of real images from domain P.
            opt: Optimizer holding the generator parameters.

        Returns:
            Tuple ``(total_loss, adversarial_loss, cycle_loss)``.
        """
        self.toggle_optimizer(opt)
        losses = self._generator_losses(img_N, img_P)

        logged = {
            'adversarial_loss': losses["adversarial"],
            'reconstruction_loss': losses["cycle"],
            'identity_loss': losses["identity"],
            'class_loss': losses["classifier"],
            'generator_loss': losses["total"],
        }
        for name, value in logged.items():
            self.log(name, value, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        opt.zero_grad()
        self.manual_backward(losses["total"])
        opt.step()
        self.untoggle_optimizer(opt)
        return losses["total"], losses["adversarial"], losses["cycle"]

    def discriminator_training_step(self, img_N, img_P, opt):
        """Run one optimization step of both discriminators.

        Each discriminator is pushed towards 1 on real images of its domain
        and towards 0 on images translated into its domain.

        Args:
            img_N: Batch of real images from domain N.
            img_P: Batch of real images from domain P.
            opt: Optimizer holding the discriminator parameters.

        Returns:
            Tuple ``(dis_loss, mse_fake_N, mse_fake_P)``.
        """
        mse = nn.functional.mse_loss
        self.toggle_optimizer(opt)

        # The generators are not updated here, so no graph is needed
        # through them.
        with torch.no_grad():
            fake_N = self.g_PN(img_P)
            fake_P = self.g_NP(img_N)

        # Discriminator D_N: real N vs. translated P -> N
        pred_real_N = self.d_N(img_N)
        mse_real_N = mse(pred_real_N, torch.ones_like(pred_real_N))
        pred_fake_N = self.d_N(fake_N)
        mse_fake_N = mse(pred_fake_N, torch.zeros_like(pred_fake_N))

        # Discriminator D_P: real P vs. translated N -> P
        pred_real_P = self.d_P(img_P)
        mse_real_P = mse(pred_real_P, torch.ones_like(pred_real_P))
        pred_fake_P = self.d_P(fake_P)
        mse_fake_P = mse(pred_fake_P, torch.zeros_like(pred_fake_P))

        # Compute total discriminator loss
        dis_loss = 0.5 * (mse_real_N + mse_fake_N + mse_real_P + mse_fake_P)

        opt.zero_grad()
        # Backpropagate the full loss. Using only one term would leave d_N
        # without gradients and d_P without the real-image term.
        self.manual_backward(dis_loss)
        opt.step()
        self.untoggle_optimizer(opt)

        self.log('mse_fake_N', mse_fake_N, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('mse_fake_P', mse_fake_P, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('discriminator_loss', dis_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return dis_loss, mse_fake_N, mse_fake_P

    def training_step(self, batch, batch_idx):
        """Update the generators, then the discriminators, on one batch."""
        img_N, img_P = batch
        # Same order as returned by configure_optimizers.
        optD, optG = self.optimizers()
        total_loss, adversarial_loss, cycle_loss = self.generator_training_step(img_N, img_P, optG)
        dis_loss, mse_fake_N, mse_fake_P = self.discriminator_training_step(img_N, img_P, optD)
        return {"generator_loss": total_loss, "adversarial_loss": adversarial_loss, "reconstruction_loss": cycle_loss, "discriminator_loss": dis_loss, "mse_fake_N": mse_fake_N, "mse_fake_P": mse_fake_P}

    def validation_step(self, batch, batch_idx):
        """Compute and log the generator losses on one validation batch.

        Returns:
            The total generator loss for the batch.
        """
        img_N, img_P = batch
        losses = self._generator_losses(img_N, img_P)
        total_loss = losses["total"]

        # Collected for the per-epoch average taken in on_validation_end.
        self.validation_step_outputs.append(total_loss.detach())

        logged = {
            'val_adversarial_loss': losses["adversarial"],
            'val_cycle_loss': losses["cycle"],
            'val_identity_loss': losses["identity"],
            'val_class_loss': losses["classifier"],
            'val_generator_loss': total_loss,
        }
        for name, value in logged.items():
            self.log(name, value, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return total_loss

    def on_validation_end(self):
        """Save both generators when the mean validation loss of this run improves."""
        if not self.validation_step_outputs:
            # Nothing was validated; torch.stack would fail on an empty list.
            return

        # Calculate average validation loss for this validation run only.
        avg_val_loss = torch.stack(self.validation_step_outputs).mean().item()
        # Reset so the next run does not average over earlier epochs and
        # the list does not grow for the whole training.
        self.validation_step_outputs.clear()

        # Check if current validation loss is lower than the lowest recorded validation loss
        if avg_val_loss < self.lowest_val_loss:
            self.lowest_val_loss = avg_val_loss
            # Save the generators' state dictionaries
            torch.save(self.g_NP.state_dict(), f"{self.checkpoint_dir}/g_NP_best.ckpt")
            torch.save(self.g_PN.state_dict(), f"{self.checkpoint_dir}/g_PN_best.ckpt")
            print(f"Model saved! loss reduced to {self.lowest_val_loss}")

    def configure_optimizers(self):
        """Return ``(optD, optG)``: Adam optimizers for discriminators and generators."""
        optG = torch.optim.Adam(itertools.chain(self.g_NP.parameters(), self.g_PN.parameters()), lr=2e-4, betas=(0.5, 0.999))
        optD = torch.optim.Adam(itertools.chain(self.d_N.parameters(), self.d_P.parameters()), lr=2e-4, betas=(0.5, 0.999))
        # NOTE(review): no learning-rate schedule is active. A linear-decay
        # LambdaLR used to be built here for optD only, but it was never
        # returned or stepped, so it had no effect and was removed. With
        # manual optimization a scheduler must be returned from this method
        # and stepped explicitly if decay is wanted.
        return optD, optG

    def _build_dataloader(self, root_dir, shuffle):
        """Create a DataLoader over ``CustomDataset`` rooted at ``root_dir``."""
        dataset = CustomDataset(
            root_dir=root_dir,
            train_N="0",
            train_P="1",
            img_res=(self.image_size, self.image_size),
        )
        # Set up DataLoader for parallel processing and GPU acceleration
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle, num_workers=4, pin_memory=True)

    def train_dataloader(self):
        """Return the shuffled training DataLoader."""
        return self._build_dataloader(self.train_dir, shuffle=True)

    def val_dataloader(self):
        """Return the non-shuffled validation DataLoader."""
        return self._build_dataloader(self.val_dir, shuffle=False)

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """Every 100 batches, log a 2x3 grid of preview translations to Weights & Biases."""
        if batch_idx % 100 != 0:
            return

        # First batch of a fresh pass over the test loader; it only varies
        # between calls if that loader shuffles.
        img_N, img_P = next(iter(self.test_dataloader))

        # Pick a random image from the batch
        idx = np.random.randint(img_N.size(0))
        # Use the module's own device rather than assuming CUDA.
        img_N = img_N[idx].unsqueeze(0).to(self.device)
        img_P = img_P[idx].unsqueeze(0).to(self.device)

        # Previews need no gradients.
        with torch.no_grad():
            # Translate images to the other domain
            fake_P = self.g_NP(img_N)
            fake_N = self.g_PN(img_P)
            # Translate images back to original domain
            reconstr_N = self.g_PN(fake_P)
            reconstr_P = self.g_NP(fake_N)

        # Row 0: real N, translated P, reconstructed N.
        # Row 1: real P, translated N, reconstructed P.
        panels = [
            [(img_N, "Real N"), (fake_P, "Translated P"), (reconstr_N, "Reconstructed N")],
            [(img_P, "Real P"), (fake_N, "Translated N"), (reconstr_P, "Reconstructed P")],
        ]
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        for row, row_panels in enumerate(panels):
            for col, (image, title) in enumerate(row_panels):
                ax = axes[row, col]
                # (1, C, H, W) -> (H, W, C) for matplotlib.
                ax.imshow(image.squeeze(0).permute(1, 2, 0).cpu().detach().numpy(), cmap='gray')
                ax.set_title(title)
                ax.axis('off')

        # Log the figure in WandB
        wandb.log({"test_images": wandb.Image(fig)})
        plt.close(fig)