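"""Beta-VAE for 512x512 single-channel images, implemented in PyTorch Lightning,
with an optional classification head trained on the latent code."""
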
import torch
from torch import nn
from torch.nn import functional as F
import pytorch_lightning as pl

class Flatten(nn.Module):
    """Collapse (N, C, H, W) feature maps into (N, C*H*W) vectors."""

    def forward(self, x):
        return x.view(x.size(0), -1)


class UnFlatten(nn.Module):
    """Reshape flat vectors back to (N, 256, 16, 16) to match the encoder output."""

    def forward(self, x):
        return x.view(x.size(0), 256, 16, 16)
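
# The module instantiates FocalLoss() for loss_type="focal" but does not define
# or import it, so a minimal multi-class focal-loss sketch is assumed below.
# The default gamma and the logits-in / mean-reduction convention are assumptions.
class FocalLoss(nn.Module):
    def __init__(self, gamma=2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits, targets):
        # Scale per-sample cross-entropy by (1 - p_t)^gamma so that easy,
        # well-classified examples contribute less to the total loss.
        ce = F.cross_entropy(logits, targets, reduction="none")
        p_t = torch.exp(-ce)  # probability assigned to the true class
        return ((1.0 - p_t) ** self.gamma * ce).mean()
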
class VAE(pl.LightningModule):
    def __init__(self, image_channels=1, h_dim=16 * 16 * 256, z_dim=12, lr=1e-3, beta=1,
                 use_classification_loss=True, num_classes=None, loss_type="standard",
                 class_weights=None, device=None):
        super().__init__()
        self.lr = lr
        self.beta = beta
        self.use_classification_loss = use_classification_loss
        # Encoder for 512x512 input; each stride-2 conv halves the spatial size.
        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, 32, kernel_size=5, stride=2, padding=2),  # 256x256
            nn.BatchNorm2d(32),
            nn.LeakyReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=2),  # 128x128
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),  # 64x64
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),  # 32x32
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            nn.Conv2d(256, 256, kernel_size=5, stride=2, padding=2),  # 16x16
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            Flatten(),
        )
        self.fc1 = nn.Linear(h_dim, z_dim)  # mu
        self.fc2 = nn.Linear(h_dim, z_dim)  # logvar
        self.fc3 = nn.Linear(z_dim, h_dim)  # latent back to decoder input
        # Decoder mirrors the encoder to reconstruct a 512x512 output.
        self.decoder = nn.Sequential(
            UnFlatten(),
            nn.ConvTranspose2d(256, 256, kernel_size=5, stride=2, padding=2, output_padding=1),  # 32x32
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2, padding=2, output_padding=1),  # 64x64
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1),  # 128x128
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=5, stride=2, padding=2, output_padding=1),  # 256x256
            nn.BatchNorm2d(32),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(32, image_channels, kernel_size=5, stride=2, padding=2, output_padding=1),  # 512x512
            nn.BatchNorm2d(image_channels),
            nn.Sigmoid(),
        )
        self.loss_type = loss_type
        if use_classification_loss:
            assert num_classes is not None, "num_classes must be provided if use_classification_loss is True."
            if loss_type == "standard":
                self.criterion = nn.CrossEntropyLoss()
            elif loss_type == "weighted":
                if class_weights is None:
                    raise ValueError("For weighted loss, class_weights must be provided.")
                weight = torch.as_tensor(class_weights, dtype=torch.float32)
                if device is not None:
                    weight = weight.to(device)
                # nn.CrossEntropyLoss registers `weight` as a buffer, so Lightning
                # moves it along with the rest of the model.
                self.class_weights = weight
                self.criterion = nn.CrossEntropyLoss(weight=self.class_weights)
            elif loss_type == "focal":
                self.criterion = FocalLoss()
            else:
                raise ValueError(f"Unknown loss_type: {loss_type}")
            # The head emits raw logits; cross-entropy (and the focal loss)
            # expect unnormalized scores, so no Softmax is applied here.
            self.fc_classify = nn.Linear(z_dim, num_classes)
    def reparameterize(self, mu, logvar):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I).
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)  # already on std's device and dtype
        return mu + std * eps
    def bottleneck(self, h):
        mu, logvar = self.fc1(h), self.fc2(h)
        z = self.reparameterize(mu, logvar)
        if self.use_classification_loss:
            class_logits = self.fc_classify(z)
            return z, mu, logvar, class_logits
        return z, mu, logvar
    def forward(self, x):
        h = self.encoder(x)
        if self.use_classification_loss:
            z, mu, logvar, class_logits = self.bottleneck(h)
            return [self.decoder(self.fc3(z)), mu, logvar, class_logits]
        z, mu, logvar = self.bottleneck(h)
        return [self.decoder(self.fc3(z)), mu, logvar]
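
    @torch.no_grad()
    def sample(self, n, device=None):
        # Convenience helper (not in the original file): draw z ~ N(0, I) from
        # the prior and decode it into images. self.fc1.out_features == z_dim.
        device = device or self.device
        z = torch.randn(n, self.fc1.out_features, device=device)
        return self.decoder(self.fc3(z))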
    def loss_function(self, recons, x, mu, logvar):
        # Sum-reduced reconstruction error plus beta-weighted KL divergence
        # between q(z|x) = N(mu, diag(sigma^2)) and the standard normal prior.
        recons_loss = F.mse_loss(recons, x, reduction="sum")
        kld_loss = torch.sum(-0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim=1), dim=0)
        return recons_loss + self.beta * kld_loss
    def classification_loss(self, logits, labels):
        # self.criterion already encodes the chosen variant:
        # standard, weighted, or focal cross-entropy.
        return self.criterion(logits, labels)
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        outputs = self(x)
        recon, mu, logvar = outputs[:3]
        recon_loss = self.loss_function(recon, x, mu, logvar)
        total_loss = recon_loss  # fallback when no classification head is used
        if self.use_classification_loss:
            class_logits = outputs[3]
            class_loss = self.classification_loss(class_logits, y)
            self.log('train_class_loss', class_loss)
            total_loss = 0.5 * recon_loss + 0.5 * class_loss
        self.log('train_recon_loss', recon_loss)
        self.log('train_total_loss', total_loss)
        return total_loss
    def representation(self, x):
        # Latent code z for downstream use (sampled via the reparameterization trick).
        return self.bottleneck(self.encoder(x))[0]
    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        outputs = self(x)
        recon, mu, logvar = outputs[:3]
        recon_loss = self.loss_function(recon, x, mu, logvar)
        total_loss = recon_loss  # fallback when no classification head is used
        if self.use_classification_loss:
            class_logits = outputs[3]
            class_loss = self.classification_loss(class_logits, y)
            self.log('val_class_loss', class_loss)
            total_loss = 0.5 * recon_loss + 0.5 * class_loss
        self.log('val_recon_loss', recon_loss)
        self.log('val_total_loss', total_loss)
        return total_loss
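

if __name__ == "__main__":
    # Hypothetical smoke test (not part of the original file): run one forward
    # pass on a random 512x512 batch to check that the tensor shapes line up.
    model = VAE(image_channels=1, z_dim=12, num_classes=3)
    x = torch.randn(4, 1, 512, 512)
    recon, mu, logvar, class_logits = model(x)
    print(recon.shape)              # torch.Size([4, 1, 512, 512])
    print(mu.shape, logvar.shape)   # both torch.Size([4, 12])
    print(class_logits.shape)       # torch.Size([4, 3])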