import torch from torch import nn from torch.nn import functional as F import pytorch_lightning as pl class Flatten(nn.Module): def forward(self, x): return x.view(x.size(0), -1) class UnFlatten(nn.Module): def forward(self, x): # Adjusted to match the output of the encoder return x.view(x.size(0), 256, 16, 16) # Adjusted dimensions class VAE(pl.LightningModule): def __init__(self, image_channels=1, h_dim=16*16*256, z_dim=12, lr=1e-3, beta=1, use_classification_loss=True, num_classes=None, loss_type="standard", class_weights=None, device=None): super(VAE, self).__init__() self.lr = lr self.beta = beta self.use_classification_loss = use_classification_loss # Adjusted encoder for 512x512 input self.encoder = nn.Sequential( nn.Conv2d(image_channels, 32, kernel_size=5, stride=2, padding=2), # 256x256 nn.BatchNorm2d(32), nn.LeakyReLU(), nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=2), # 128x128 nn.BatchNorm2d(64), nn.LeakyReLU(), nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2), # 64x64 nn.BatchNorm2d(128), nn.LeakyReLU(), nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2), # 32x32 nn.BatchNorm2d(256), nn.LeakyReLU(), nn.Conv2d(256, 256, kernel_size=5, stride=2, padding=2), # 16x16 nn.BatchNorm2d(256), nn.LeakyReLU(), Flatten() ) self.fc1 = nn.Linear(h_dim, z_dim) # For mu self.fc2 = nn.Linear(h_dim, z_dim) # For logvar self.fc3 = nn.Linear(z_dim, h_dim) # For reconstruction # Adjusted decoder for reconstructing 512x512 output self.decoder = nn.Sequential( UnFlatten(), nn.ConvTranspose2d(256, 256, kernel_size=5, stride=2, padding=2, output_padding=1), # 32x32 nn.BatchNorm2d(256), nn.LeakyReLU(), nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2, padding=2, output_padding=1), # 64x64 nn.BatchNorm2d(128), nn.LeakyReLU(), nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1), # 128x128 nn.BatchNorm2d(64), nn.LeakyReLU(), nn.ConvTranspose2d(64, 32, kernel_size=5, stride=2, padding=2, output_padding=1), # 256x256 nn.BatchNorm2d(32), nn.LeakyReLU(), nn.ConvTranspose2d(32, image_channels, kernel_size=5, stride=2, padding=2, output_padding=1), # 512x512 nn.BatchNorm2d(image_channels), nn.Sigmoid(), ) self.loss_type = loss_type if use_classification_loss: if loss_type == "standard": self.criterion = nn.CrossEntropyLoss() elif loss_type == "weighted": # Check if class weights are provided if class_weights is None: raise ValueError("For weighted loss, class_weights must be provided.") self.class_weights = torch.tensor(class_weights).to(device) self.criterion = nn.CrossEntropyLoss(weight=self.class_weights) elif loss_type == "focal": self.criterion = FocalLoss() else: raise ValueError(f"Unknown loss_type: {loss_type}") if self.use_classification_loss: assert num_classes is not None, "num_classes must be provided if use_classification_loss is True." self.fc_classify = nn.Sequential( nn.Linear(z_dim, num_classes), nn.Softmax(dim=1) ) def reparameterize(self, mu, logvar): std = logvar.mul(0.5).exp_() eps = torch.randn_like(std).to(std.device) z = mu + std * eps return z def bottleneck(self, h): mu, logvar = self.fc1(h), self.fc2(h) z = self.reparameterize(mu, logvar) if self.use_classification_loss: class_logits = self.fc_classify(z) return z, mu, logvar, class_logits return z, mu, logvar def forward(self, x): if self.use_classification_loss: z, mu, logvar, class_logits = self.bottleneck(self.encoder(x)) z = self.fc3(z) return [self.decoder(z), mu, logvar, class_logits] else: z, mu, logvar = self.bottleneck(self.encoder(x)) z = self.fc3(z) return [self.decoder(z), mu, logvar] def loss_function(self,recons,x,mu,logvar): # Account for the minibatch samples from the dataset; M_N = self.params['batch_size']/ self.num_train_imgs recons_loss =F.mse_loss(recons, x,reduction="sum") kld_loss = torch.sum(-0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim = 1), dim = 0) loss = recons_loss + self.beta * kld_loss return loss def classification_loss(self, logits, labels): if self.loss_type == "standard": return F.cross_entropy(logits, labels) else: # For both "weighted" and "focal" return self.criterion(logits, labels) def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=self.lr) def training_step(self, train_batch, batch_idx): x, y = train_batch outputs = self(x) recon, mu, logvar = outputs[:3] recon_loss = self.loss_function(recon, x, mu, logvar) if self.use_classification_loss: class_logits = outputs[3] class_loss = self.classification_loss(class_logits, y) self.log('train_class_loss', class_loss) total_loss = 0.5 * recon_loss + 0.5 * class_loss self.log('train_recon_loss', recon_loss) self.log('train_total_loss', total_loss) return total_loss def representation(self, x): return self.bottleneck(self.encoder(x))[0] def validation_step(self, val_batch, batch_idx): x, y = val_batch outputs = self(x) recon, mu, logvar = outputs[:3] recon_loss = self.loss_function(recon, x, mu, logvar) if self.use_classification_loss: class_logits = outputs[3] class_loss = self.classification_loss(class_logits, y) self.log('val_class_loss', class_loss) total_loss = 0.5 * recon_loss + 0.5 * class_loss self.log('val_recon_loss', recon_loss) self.log('val_total_loss', total_loss) return total_loss