import os
import glob

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import DataLoader
from torchvision.utils import make_grid, save_image
from PIL import Image

from gan_losses import get_gan_losses

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

| """## Load Data""" | |
| # data_variance = np.var(training_data.data / 255.0) | |
| data_variance = 1 | |
| def mkdir(dir): | |
| if not os.path.exists(dir): | |
| os.makedirs(dir) | |
| def read_image(img_path): | |
| img = cv2.imread(img_path) | |
| img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) | |
| img = img / 255.0 | |
| return img | |
class VectorQuantizer(nn.Module):
    """Discretizes encoder outputs by snapping each spatial vector to its nearest codebook entry."""

    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        # Codebook
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.uniform_(-1 / self._num_embeddings, 1 / self._num_embeddings)
        self._commitment_cost = commitment_cost

    def forward(self, inputs):
        # Convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape

        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)

        # Calculate distances to every codebook vector
        distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                     + torch.sum(self._embedding.weight ** 2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Loss: codebook loss plus weighted commitment loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self._commitment_cost * e_latent_loss

        # Straight-through estimator
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # Convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encoding_indices

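# Illustrative sketch (not part of the original script): how VectorQuantizer is
# typically applied to an encoder output. The shapes and hyperparameters below
# are assumptions chosen only to show the interface; the function is defined
# but never called here.
def _demo_vector_quantizer():
    vq = VectorQuantizer(num_embeddings=2, embedding_dim=3, commitment_cost=0.25)
    z = torch.randn(4, 3, 16, 16)  # BCHW encoder output; channels == embedding_dim
    loss, quantized, perplexity, indices = vq(z)
    # quantized keeps the input shape; indices holds one codebook id per spatial position
    assert quantized.shape == z.shape
    assert indices.shape == (4 * 16 * 16, 1)
    return loss, perplexity
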
class VectorQuantizerEMA(nn.Module):
    """Vector quantizer whose codebook is updated with exponential moving averages instead of gradients."""

    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
        super(VectorQuantizerEMA, self).__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.normal_()
        self._commitment_cost = commitment_cost

        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
        self._ema_w.data.normal_()

        self._decay = decay
        self._epsilon = epsilon

    def forward(self, inputs):
        # Convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape

        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)

        # Calculate distances to every codebook vector
        distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                     + torch.sum(self._embedding.weight ** 2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        # encoding_indices[encoding_indices == 3] = 4  # 1: background, 2: epithelial cells, 3: neutrophil, 4: connective, 5: plasma, 6: lymphocytes
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Use EMA to update the embedding vectors
        if self.training:
            self._ema_cluster_size = self._ema_cluster_size * self._decay + \
                                     (1 - self._decay) * torch.sum(encodings, 0)

            # Laplace smoothing of the cluster size
            n = torch.sum(self._ema_cluster_size.data)
            self._ema_cluster_size = (
                (self._ema_cluster_size + self._epsilon)
                / (n + self._num_embeddings * self._epsilon) * n)

            dw = torch.matmul(encodings.t(), flat_input)
            self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)

            self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))

        # Loss: only the commitment term, since the codebook is moved by the EMA statistics
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        loss = self._commitment_cost * e_latent_loss

        # Straight-through estimator
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # Convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encoding_indices

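# Illustrative sketch (not part of the original script): the EMA branch above
# follows the moving-average codebook update from the VQ-VAE paper
# (van den Oord et al., 2017): with decay g, per-code counts N_i and sums m_i
# are updated as N_i <- g*N_i + (1-g)*n_i and m_i <- g*m_i + (1-g)*sum(z assigned to i),
# and each embedding becomes e_i = m_i / N_i, with Laplace smoothing so empty
# clusters never divide by zero. The shapes and hyperparameters below are assumptions.
def _demo_vector_quantizer_ema():
    vq = VectorQuantizerEMA(num_embeddings=2, embedding_dim=3,
                            commitment_cost=0.25, decay=0.99)
    vq.train()  # EMA updates only run in training mode
    z = torch.randn(4, 3, 16, 16)
    loss, quantized, perplexity, _ = vq(z)
    # Only the commitment term contributes to the returned loss
    assert quantized.shape == z.shape
    return loss, perplexity
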
class Residual(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
        super(Residual, self).__init__()
        self._block = nn.Sequential(
            nn.ReLU(True),
            nn.Conv2d(in_channels=in_channels,
                      out_channels=num_residual_hiddens,
                      kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(in_channels=num_residual_hiddens,
                      out_channels=num_hiddens,
                      kernel_size=1, stride=1, bias=False)
        )

    def forward(self, x):
        return x + self._block(x)


class ResidualStack(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(ResidualStack, self).__init__()
        self._num_residual_layers = num_residual_layers
        self._layers = nn.ModuleList([Residual(in_channels, num_hiddens, num_residual_hiddens)
                                      for _ in range(self._num_residual_layers)])

    def forward(self, x):
        for i in range(self._num_residual_layers):
            x = self._layers[i](x)
        return F.relu(x)

class Encoder(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens, embedding_dim):
        super(Encoder, self).__init__()
        self._conv_1 = nn.Conv2d(in_channels=in_channels,
                                 out_channels=num_hiddens // 2,
                                 kernel_size=4,
                                 stride=2, padding=1)
        self._conv_2 = nn.Conv2d(in_channels=num_hiddens // 2,
                                 out_channels=num_hiddens,
                                 kernel_size=4,
                                 stride=2, padding=1)
        self._conv_3 = nn.Conv2d(in_channels=num_hiddens,
                                 out_channels=num_hiddens,
                                 kernel_size=3,
                                 stride=1, padding=1)
        self._residual_stack = ResidualStack(in_channels=num_hiddens,
                                             num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)
        self._pre_vq_conv = nn.Conv2d(in_channels=num_hiddens,
                                      out_channels=embedding_dim,
                                      kernel_size=1,
                                      stride=1)
        self.apply_tanh = nn.Tanh()  # defined but not used in forward; tanh is applied in the decoder instead

    def forward(self, inputs):
        x = self._conv_1(inputs)
        x = F.relu(x)

        x = self._conv_2(x)
        x = F.relu(x)

        x = self._conv_3(x)
        x = self._residual_stack(x)
        x = self._pre_vq_conv(x)
        return x

class Decoder(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Decoder, self).__init__()
        self._conv_1 = nn.Conv2d(in_channels=in_channels,
                                 out_channels=num_hiddens,
                                 kernel_size=3,
                                 stride=1, padding=1)
        self._residual_stack = ResidualStack(in_channels=num_hiddens,
                                             num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)
        self._conv_trans_1 = nn.ConvTranspose2d(in_channels=num_hiddens,
                                                out_channels=num_hiddens // 2,
                                                kernel_size=4,
                                                stride=2, padding=1)
        self._conv_trans_2 = nn.ConvTranspose2d(in_channels=num_hiddens // 2,
                                                out_channels=3,
                                                kernel_size=4,
                                                stride=2, padding=1)
        self.apply_tanh = nn.Tanh()  # keeps reconstructions in [-1, 1]

    def forward(self, inputs):
        x = self._conv_1(inputs)
        x = self._residual_stack(x)

        x = self._conv_trans_1(x)
        x = F.relu(x)

        x = self._conv_trans_2(x)
        return self.apply_tanh(x)

class VQModel(nn.Module):
    """VQ-VAE: encoder -> vector quantizer (EMA variant when decay > 0) -> decoder."""

    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
                 num_embeddings, embedding_dim, commitment_cost, decay=0):
        super(VQModel, self).__init__()

        self._encoder = Encoder(3, num_hiddens,
                                num_residual_layers,
                                num_residual_hiddens,
                                embedding_dim)
        if decay > 0.0:
            self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim,
                                              commitment_cost, decay)
        else:
            self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim,
                                           commitment_cost)
        self._decoder = Decoder(embedding_dim,
                                num_hiddens,
                                num_residual_layers,
                                num_residual_hiddens)

    def forward(self, x):
        z = self._encoder(x)
        loss, quantized, perplexity, _ = self._vq_vae(z)
        x_recon = self._decoder(quantized)
        return loss, x_recon, perplexity

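# Illustrative sketch (not part of the original script): one VQ-VAE training
# step using the data_variance defined above. It assumes images are normalized
# to roughly [-1, 1] to match the decoder's tanh output; the optimizer and the
# batch are supplied by the caller and are assumptions, not the original training loop.
def _demo_vqvae_train_step(model, images, optimizer):
    model.train()
    optimizer.zero_grad()
    vq_loss, x_recon, perplexity = model(images)
    # Reconstruction term normalized by the (here fixed) data variance, plus the VQ/commitment term
    recon_error = F.mse_loss(x_recon, images) / data_variance
    loss = recon_error + vq_loss
    loss.backward()
    optimizer.step()
    return loss.item(), perplexity.item()
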
def save_generated_images(image_names, batch_images, ind, mode, type):
    # Relies on a global `output_dir` being defined by the calling script
    current_output_dir = os.path.join(output_dir, mode, type)
    mkdir(current_output_dir)
    num_images = batch_images.shape[0]
    for i in range(num_images):
        save_image(batch_images[i], os.path.join(current_output_dir, image_names[i]))


def generate_images_from_diffusion_latents(model, latents_path, output_dir):
    latent_paths = glob.glob(os.path.join(latents_path, "*.pt"))
    for latent_path in latent_paths:
        # Load a saved latent, quantize it with the trained codebook, and decode it to an image
        latent = torch.load(latent_path, map_location=device)
        latent = latent.detach()
        _, quantized_latent, _, _ = model._vq_vae(latent)
        image = model._decoder(quantized_latent)
        image_name = os.path.basename(latent_path).split(".")[0] + ".png"
        save_image(image, os.path.join(output_dir, image_name))

class UNetDown(nn.Module):
    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super(UNetDown, self).__init__()
        layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_size))
        layers.append(nn.LeakyReLU(0.2))
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class UNetUp(nn.Module):
    def __init__(self, in_size, out_size, dropout=0.0):
        super(UNetUp, self).__init__()
        layers = [
            nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(out_size),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip_input):
        x = self.model(x)
        x = torch.cat((x, skip_input), 1)
        return x

class Pix2PixGenerator(nn.Module):
    """pix2pix U-Net generator: eight downsampling blocks, seven upsampling blocks with skip connections."""

    def __init__(self, in_channels=3, out_channels=3):
        super(Pix2PixGenerator, self).__init__()

        self.down1 = UNetDown(in_channels, 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512, dropout=0.5)
        self.down5 = UNetDown(512, 512, dropout=0.5)
        self.down6 = UNetDown(512, 512, dropout=0.5)
        self.down7 = UNetDown(512, 512, dropout=0.5)
        self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5)

        self.up1 = UNetUp(512, 512, dropout=0.5)
        self.up2 = UNetUp(1024, 512, dropout=0.5)
        self.up3 = UNetUp(1024, 512, dropout=0.5)
        self.up4 = UNetUp(1024, 512, dropout=0.5)
        self.up5 = UNetUp(1024, 256)
        self.up6 = UNetUp(512, 128)
        self.up7 = UNetUp(256, 64)

        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(128, out_channels, 4, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        # U-Net generator with skip connections from encoder to decoder
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)

        u1 = self.up1(d8, d7)
        u2 = self.up2(u1, d6)
        u3 = self.up3(u2, d5)
        u4 = self.up4(u3, d4)
        u5 = self.up5(u4, d3)
        u6 = self.up6(u5, d2)
        u7 = self.up7(u6, d1)
        return self.final(u7)

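# Illustrative sketch (not part of the original script): the generator maps an
# image to an image of the same spatial size. The 256x256 input below is an
# assumption; with eight stride-2 downsamplings, side lengths divisible by 256
# are required for the skip connections to line up.
def _demo_pix2pix_generator():
    gen = Pix2PixGenerator(in_channels=3, out_channels=3)
    x = torch.randn(1, 3, 256, 256)
    y = gen(x)
    assert y.shape == x.shape  # tanh output in [-1, 1]
    return y
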
batch_size = 32          # keep 16 for good results
num_training_updates = 30000

num_hiddens = 32         # original: 128; 32 used for masks
num_residual_hiddens = 32
num_residual_layers = 2  # original was 2

embedding_dim = 3
num_embeddings = 2       # number of codebook vectors

commitment_cost = 0.25
decay = 0.99

model_name = "dp_bimask_2dim_1024size_tanhindecoder.pt"

def create_mask(model_dir, latents_path, final_output_dir):
    # Rebuild the VQ-VAE, load the trained weights, and decode saved diffusion latents into mask images
    model = VQModel(num_hiddens, num_residual_layers, num_residual_hiddens,
                    num_embeddings, embedding_dim,
                    commitment_cost, decay).to(device)
    model.load_state_dict(torch.load(os.path.join(model_dir, model_name), map_location=device))
    model.eval()

    mkdir(final_output_dir)
    generate_images_from_diffusion_latents(model=model,
                                           latents_path=latents_path,
                                           output_dir=final_output_dir)
