ShapingHistory_Interpulation / VAE_model_tablets_class.py
DaniKaEp's picture
Adding app, data and model files
991bb01 verified
import torch
from torch import nn
from torch.nn import functional as F
import pytorch_lightning as pl
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single feature axis."""

    def forward(self, x):
        """Return ``x`` reshaped to ``(x.size(0), -1)``.

        ``reshape`` is used instead of ``view`` so that non-contiguous inputs
        (for example transposed or channels-last tensors) are accepted; for
        contiguous inputs it returns the same zero-copy view as before.
        """
        return x.reshape(x.size(0), -1)
class UnFlatten(nn.Module):
    """Reshape a flat ``(batch, features)`` tensor into a 4-D feature map.

    The default target shape ``(256, 16, 16)`` is the one this file's decoder
    expects as input (256 channels on a 16x16 grid).

    Args:
        channels: Number of channels of the restored feature map.
        height: Height of the restored feature map.
        width: Width of the restored feature map.
    """

    # Class-level fallback: an instance whose ``__dict__`` was populated
    # without running this ``__init__`` still resolves the original shape.
    target_shape = (256, 16, 16)

    def __init__(self, channels=256, height=16, width=16):
        super().__init__()
        self.target_shape = (channels, height, width)

    def forward(self, x):
        """Return ``x`` reshaped to ``(x.size(0), channels, height, width)``.

        ``reshape`` (rather than ``view``) also accepts non-contiguous input
        and is a zero-copy view whenever the input is contiguous.
        """
        return x.reshape(x.size(0), *self.target_shape)
class VAE(pl.LightningModule):
    """Convolutional beta-VAE with an optional classifier on the latent code.

    The encoder is five stride-2 convolutions followed by ``Flatten``; two
    linear layers map the flattened features to the latent mean and
    log-variance. ``fc3`` maps a latent sample back to ``h_dim`` features and
    the decoder mirrors the encoder with transposed convolutions.

    Args:
        image_channels: Number of channels of the input/output images.
        h_dim: Size of the flattened encoder output. The default
            ``16*16*256`` corresponds to 512x512 inputs.
        z_dim: Dimensionality of the latent space.
        lr: Learning rate handed to Adam.
        beta: Weight of the KL term in ``loss_function``.
        use_classification_loss: When true, a classifier head is attached to
            the latent sample and its loss is mixed into the training loss.
        num_classes: Number of classes; required when
            ``use_classification_loss`` is true.
        loss_type: ``"standard"``, ``"weighted"`` or ``"focal"``.
        class_weights: Per-class weights; required for ``"weighted"``.
        device: Optional device on which to place ``class_weights``.

    Raises:
        ValueError: If the classification settings are inconsistent
            (missing ``class_weights``/``num_classes`` or unknown
            ``loss_type``).
    """

    def __init__(self, image_channels=1, h_dim=16*16*256, z_dim=12, lr=1e-3, beta=1, use_classification_loss=True,
                 num_classes=None, loss_type="standard", class_weights=None, device=None):
        super().__init__()
        self.lr = lr
        self.beta = beta
        self.use_classification_loss = use_classification_loss

        # Encoder for 512x512 input; each stride-2 conv halves the resolution.
        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, 32, kernel_size=5, stride=2, padding=2),  # 256x256
            nn.BatchNorm2d(32),
            nn.LeakyReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=2),  # 128x128
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),  # 64x64
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),  # 32x32
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            nn.Conv2d(256, 256, kernel_size=5, stride=2, padding=2),  # 16x16
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            Flatten(),
        )

        self.fc1 = nn.Linear(h_dim, z_dim)  # latent mean (mu)
        self.fc2 = nn.Linear(h_dim, z_dim)  # latent log-variance
        self.fc3 = nn.Linear(z_dim, h_dim)  # latent sample -> decoder features

        # Decoder reconstructing a 512x512 output; each layer doubles the
        # resolution. Layer order is kept unchanged so that existing
        # state_dicts keep loading.
        self.decoder = nn.Sequential(
            UnFlatten(),
            nn.ConvTranspose2d(256, 256, kernel_size=5, stride=2, padding=2, output_padding=1),  # 32x32
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2, padding=2, output_padding=1),  # 64x64
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1),  # 128x128
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=5, stride=2, padding=2, output_padding=1),  # 256x256
            nn.BatchNorm2d(32),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(32, image_channels, kernel_size=5, stride=2, padding=2, output_padding=1),  # 512x512
            nn.BatchNorm2d(image_channels),
            nn.Sigmoid(),
        )

        self.loss_type = loss_type
        if use_classification_loss:
            if loss_type == "standard":
                self.criterion = nn.CrossEntropyLoss()
            elif loss_type == "weighted":
                if class_weights is None:
                    raise ValueError("For weighted loss, class_weights must be provided.")
                # Force float32 so integer / float64 weights do not break the
                # loss; as_tensor also accepts an existing tensor silently.
                weights = torch.as_tensor(class_weights, dtype=torch.float32)
                if device is not None:
                    weights = weights.to(device)
                self.class_weights = weights
                self.criterion = nn.CrossEntropyLoss(weight=self.class_weights)
            elif loss_type == "focal":
                # NOTE(review): FocalLoss is neither defined nor imported in
                # the visible part of this file -- confirm where it comes from.
                self.criterion = FocalLoss()
            else:
                raise ValueError(f"Unknown loss_type: {loss_type}")

            # Explicit raise instead of assert: asserts vanish under ``-O``.
            if num_classes is None:
                raise ValueError("num_classes must be provided if use_classification_loss is True.")
            # The head ends in Softmax, so it emits class probabilities.
            self.fc_classify = nn.Sequential(
                nn.Linear(z_dim, num_classes),
                nn.Softmax(dim=1),
            )

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + std * eps`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)  # already on std's device and dtype
        return mu + std * eps

    def bottleneck(self, h):
        """Map encoder features to a latent sample.

        Returns:
            ``(z, mu, logvar, class_scores)`` when the classifier head is
            enabled, otherwise ``(z, mu, logvar)``. ``class_scores`` are the
            softmax probabilities produced by ``fc_classify``.
        """
        mu = self.fc1(h)
        logvar = self.fc2(h)
        z = self.reparameterize(mu, logvar)
        if self.use_classification_loss:
            return z, mu, logvar, self.fc_classify(z)
        return z, mu, logvar

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            ``[reconstruction, mu, logvar]``, with the classifier output
            (softmax probabilities) appended as a fourth element when the
            classifier head is enabled.
        """
        z, *stats = self.bottleneck(self.encoder(x))
        reconstruction = self.decoder(self.fc3(z))
        return [reconstruction, *stats]

    def loss_function(self, recons, x, mu, logvar):
        """Return the sum-reduced MSE plus ``beta`` times the KL divergence.

        Both terms are summed over the batch (no averaging).
        """
        recons_loss = F.mse_loss(recons, x, reduction="sum")
        kld_per_sample = -0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim=1)
        kld_loss = torch.sum(kld_per_sample, dim=0)
        return recons_loss + self.beta * kld_loss

    def classification_loss(self, logits, labels):
        """Return the classification loss for ``logits`` against ``labels``.

        ``"standard"`` uses plain cross-entropy; ``"weighted"`` and
        ``"focal"`` use the criterion configured in ``__init__``.
        """
        if self.loss_type == "standard":
            return F.cross_entropy(logits, labels)
        return self.criterion(logits, labels)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def _shared_step(self, batch, stage):
        """Compute and log the losses for one batch of ``(x, y)``.

        ``stage`` is the prefix of the logged metric names (``"train"`` or
        ``"val"``). The value logged as ``<stage>_recon_loss`` is the output
        of ``loss_function`` (reconstruction + beta * KL).
        """
        x, y = batch
        outputs = self(x)
        recon, mu, logvar = outputs[:3]
        recon_loss = self.loss_function(recon, x, mu, logvar)

        # Without a classifier head the VAE loss is the whole objective;
        # previously ``total_loss`` was never assigned in that case.
        total_loss = recon_loss
        if self.use_classification_loss:
            class_probs = outputs[3]
            # fc_classify already applied Softmax and the cross-entropy
            # criteria apply log_softmax themselves. Feeding log-probabilities
            # makes that second softmax an identity (softmax(log p) == p), so
            # the result is the true cross-entropy instead of a doubly
            # squashed one. The clamp keeps log() finite.
            tiny = torch.finfo(class_probs.dtype).tiny
            class_log_probs = torch.log(class_probs.clamp_min(tiny))
            class_loss = self.classification_loss(class_log_probs, y)
            self.log(f"{stage}_class_loss", class_loss)
            total_loss = 0.5 * recon_loss + 0.5 * class_loss

        self.log(f"{stage}_recon_loss", recon_loss)
        self.log(f"{stage}_total_loss", total_loss)
        return total_loss

    def training_step(self, train_batch, batch_idx):
        """Lightning hook: return the training loss for ``train_batch``."""
        return self._shared_step(train_batch, "train")

    def representation(self, x):
        """Return a latent sample ``z`` for ``x`` (stochastic, not ``mu``)."""
        return self.bottleneck(self.encoder(x))[0]

    def validation_step(self, val_batch, batch_idx):
        """Lightning hook: return the validation loss for ``val_batch``."""
        return self._shared_step(val_batch, "val")