import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.diffusion_utils import *
import utils.ResNet_for_32 as resnet_s
import utils.ResNet_for_224 as resnet_l


class ConditionalLinear(nn.Module):
    """Linear layer whose output is scaled element-wise by a learned per-timestep gain.

    Args:
        num_in: input feature size.
        num_out: output feature size.
        n_steps: number of rows in the timestep embedding table; the ``t``
            passed to :meth:`forward` must satisfy ``0 <= t < n_steps``.
    """

    def __init__(self, num_in, num_out, n_steps):
        super().__init__()
        self.num_out = num_out
        self.lin = nn.Linear(num_in, num_out)
        self.embed = nn.Embedding(n_steps, num_out)
        # Gains start in U(0, 1) instead of nn.Embedding's default N(0, 1) init.
        self.embed.weight.data.uniform_()

    def forward(self, x, t):
        """Return ``lin(x)`` multiplied element-wise by the embedding of ``t``."""
        out = self.lin(x)
        gamma = self.embed(t)
        return gamma.view(-1, self.num_out) * out


class ConditionalModel(nn.Module):
    """MLP over a label vector ``y``, conditioned on an image embedding and a timestep.

    Args:
        n_steps: number of diffusion timesteps. The timestep embeddings get
            ``n_steps + 1`` rows, so ``t`` may range over ``0..n_steps``.
        y_dim: size of the label vector (and of the output).
        fp_dim: size of the guidance features ``fp_x``; only used when
            ``guidance`` is True.
        feature_dim: size of the image embedding ``x_embed`` and of the hidden
            layers. The default ``None`` cannot build the BatchNorm layers, so
            a real size must be passed.
        guidance: if True, ``fp_x`` is concatenated to ``y`` before the first layer.
    """

    def __init__(self, n_steps, y_dim=10, fp_dim=128, feature_dim=None, guidance=True):
        super().__init__()
        n_steps = n_steps + 1
        self.y_dim = y_dim
        self.guidance = guidance
        self.norm = nn.BatchNorm1d(feature_dim)

        in_dim = y_dim + fp_dim if self.guidance else y_dim
        self.lin1 = ConditionalLinear(in_dim, feature_dim, n_steps)
        self.unetnorm1 = nn.BatchNorm1d(feature_dim)
        self.lin2 = ConditionalLinear(feature_dim, feature_dim, n_steps)
        self.unetnorm2 = nn.BatchNorm1d(feature_dim)
        self.lin3 = ConditionalLinear(feature_dim, feature_dim, n_steps)
        self.unetnorm3 = nn.BatchNorm1d(feature_dim)
        self.lin4 = nn.Linear(feature_dim, y_dim)

    def forward(self, x_embed, y, t, fp_x=None):
        """Map label vector ``y`` at timestep ``t`` to a ``y_dim``-sized output.

        Args:
            x_embed: image embedding; expected ``(batch, feature_dim)``.
            y: label vector; expected ``(batch, y_dim)``.
            t: integer timestep indices into the embedding tables.
            fp_x: guidance features, expected ``(batch, fp_dim)``; required
                when ``guidance`` is True.
        """
        x_embed = self.norm(x_embed)
        if self.guidance:
            y = torch.cat([y, fp_x], dim=-1)
        y = F.softplus(self.unetnorm1(self.lin1(y, t)))
        # The (normalized) image embedding gates the first hidden layer element-wise.
        y = x_embed * y
        y = F.softplus(self.unetnorm2(self.lin2(y, t)))
        y = F.softplus(self.unetnorm3(self.lin3(y, t)))
        return self.lin4(y)


class Diffusion(nn.Module):
    """Diffusion over class-label vectors, conditioned on images.

    Args:
        fp_encoder: module producing the guidance features ``fp_x`` from images
            (used by :meth:`reverse` / :meth:`reverse_ddim` when ``fp_x`` is not given).
        num_timesteps: number of forward-diffusion steps.
        n_class: size of the label vector.
        fp_dim: size of the guidance features.
        device: device on which schedules and sub-modules are placed.
        beta_schedule: schedule name forwarded to ``make_beta_schedule``.
        feature_dim: size of the image embedding produced by ``diffusion_encoder``.
        encoder_type: one of ``'resnet18'``, ``'resnet34'``, ``'resnet50'``
            (from ``utils.ResNet_for_32``) or ``'resnet18_l'``, ``'resnet34_l'``,
            ``'resnet50_l'`` (from ``utils.ResNet_for_224``, ``pretrained=True``).
        ddim_num_steps: number of steps in the DDIM sampling schedule.

    Raises:
        ValueError: if ``encoder_type`` is not one of the names above.
    """

    def __init__(self, fp_encoder, num_timesteps=1000, n_class=10, fp_dim=512, device='cuda',
                 beta_schedule='cosine', feature_dim=2048, encoder_type='resnet50_l', ddim_num_steps=10):
        super().__init__()
        self.device = device
        self.num_timesteps = num_timesteps
        self.n_class = n_class

        betas = make_beta_schedule(schedule=beta_schedule, num_timesteps=self.num_timesteps,
                                   start=0.0001, end=0.02)
        betas = self.betas = betas.float().to(self.device)
        self.betas_sqrt = torch.sqrt(betas)
        alphas = 1.0 - betas
        self.alphas = alphas
        self.one_minus_betas_sqrt = torch.sqrt(alphas)
        self.alphas_cumprod = alphas.cumprod(dim=0)
        self.alphas_bar_sqrt = torch.sqrt(self.alphas_cumprod)
        self.one_minus_alphas_bar_sqrt = torch.sqrt(1 - self.alphas_cumprod)
        # Cumulative product shifted by one step, with a leading 1 for the first step.
        alphas_cumprod_prev = torch.cat([torch.ones(1).to(self.device), self.alphas_cumprod[:-1]], dim=0)
        self.alphas_cumprod_prev = alphas_cumprod_prev
        # Mean coefficients and variance of the Gaussian posterior q(y_{t-1} | y_t, y_0).
        self.posterior_mean_coeff_1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        self.posterior_mean_coeff_2 = torch.sqrt(alphas) * (1 - alphas_cumprod_prev) / (1 - self.alphas_cumprod)
        self.posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        self.logvar = betas.log()

        self.fp_dim = fp_dim
        # NOTE: fp_encoder is a registered sub-module, so a later self.train()
        # switches it back to training mode as well.
        self.fp_encoder = fp_encoder.eval()
        self.encoder_type = encoder_type
        self.diffusion_encoder = self._build_diffusion_encoder(encoder_type, feature_dim).to(self.device)

        self.model = ConditionalModel(self.num_timesteps, y_dim=self.n_class, fp_dim=fp_dim,
                                      feature_dim=feature_dim, guidance=True).to(self.device)

        self.ddim_num_steps = ddim_num_steps
        self.make_ddim_schedule(ddim_num_steps)

    @staticmethod
    def _build_diffusion_encoder(encoder_type, feature_dim):
        """Instantiate the image encoder named by ``encoder_type`` (not yet moved to a device)."""
        # Lambdas so that only the selected network is built (the *_l variants
        # are requested with pretrained=True).
        builders = {
            'resnet18': lambda: resnet_s.resnet18(num_input_channels=3, num_classes=feature_dim),
            'resnet34': lambda: resnet_s.resnet34(num_input_channels=3, num_classes=feature_dim),
            'resnet50': lambda: resnet_s.resnet50(num_input_channels=3, num_classes=feature_dim),
            'resnet18_l': lambda: resnet_l.resnet18(num_classes=feature_dim, pretrained=True),
            'resnet34_l': lambda: resnet_l.resnet34(num_classes=feature_dim, pretrained=True),
            'resnet50_l': lambda: resnet_l.resnet50(num_classes=feature_dim, pretrained=True),
        }
        try:
            build = builders[encoder_type]
        except KeyError:
            raise ValueError(
                f"encoder_type must be one of {sorted(builders)}, got {encoder_type!r}"
            ) from None
        return build()

    def make_ddim_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.):
        """Build the DDIM timestep sub-sequence and register the buffers used for DDIM sampling.

        ``ddim_eta`` is forwarded to ``make_ddim_sampling_parameters`` and also
        scales ``ddim_sigmas_for_original_num_steps``. Calling this again
        replaces the previously registered buffers of the same names.
        NOTE: ``self.ddim_num_steps`` is only set by ``__init__``, not here.
        """
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize,
                                                  num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.num_timesteps)
        assert self.alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        def to_torch(x):
            return x.clone().detach().to(torch.float32).to(self.device)

        self.register_buffer('sqrt_alphas_cumprod', to_torch(torch.sqrt(self.alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(torch.sqrt(1. - self.alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(torch.log(1. - self.alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(torch.sqrt(1. / self.alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(torch.sqrt(1. / self.alphas_cumprod - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
            alphacums=self.alphas_cumprod, ddim_timesteps=self.ddim_timesteps, eta=ddim_eta)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', torch.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                    1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    def load_diffusion_net(self, net_state_dicts):
        """Load weights from ``[model_sd, diffusion_encoder_sd]`` or ``[..., fp_encoder_sd]``.

        The third entry (fp_encoder) is loaded only when exactly three state dicts are given.
        """
        self.model.load_state_dict(net_state_dicts[0])
        self.diffusion_encoder.load_state_dict(net_state_dicts[1])
        if len(net_state_dicts) == 3:
            self.fp_encoder.load_state_dict(net_state_dicts[2])

    def forward_t(self, y_0_batch, x_batch, t, fp_x, fq_x=None):
        """Noise ``y_0_batch`` to timestep ``t`` with ``q_sample`` and run the conditional model.

        ``fq_x`` is forwarded unchanged to ``q_sample`` (its meaning is defined there).

        Returns:
            ``(output, e)``: the model output and the standard-normal noise ``e``
            used to build the noisy labels.
        """
        x_batch = x_batch.to(self.device)
        # randn_like already allocates on y_0_batch's device.
        e = torch.randn_like(y_0_batch)
        y_t_batch = q_sample(y_0_batch, self.alphas_bar_sqrt, self.one_minus_alphas_bar_sqrt, t,
                             noise=e, fq_x=fq_x)
        x_embed_batch = self.diffusion_encoder(x_batch)
        output = self.model(x_embed_batch, y_t_batch, t, fp_x)
        return output, e

    def reverse(self, images, only_last_sample=True, stochastic=True, fp_x=None, fq_x=None):
        """Sample labels for ``images`` with the full ``p_sample_loop`` (no gradients).

        If ``fp_x`` is None it is computed with ``self.fp_encoder``. ``fq_x``
        is forwarded unchanged to ``p_sample_loop``.
        """
        images = images.to(self.device)
        with torch.no_grad():
            if fp_x is None:
                fp_x = self.fp_encoder(images)
            # ConditionalModel expects the diffusion-encoder embedding, not raw
            # images, exactly as in forward_t and reverse_ddim.
            x_embed_batch = self.diffusion_encoder(images)
            label_t_0 = p_sample_loop(self.model, x_embed_batch, fp_x,
                                      self.num_timesteps, self.alphas,
                                      self.one_minus_alphas_bar_sqrt,
                                      only_last_sample=only_last_sample, stochastic=stochastic, fq_x=fq_x)
        return label_t_0

    def reverse_ddim(self, x_batch, stochastic=True, fp_x=None, fq_x=None):
        """Sample labels for ``x_batch`` with the DDIM schedule (no gradients).

        If ``fp_x`` is None it is computed with ``self.fp_encoder``. ``fq_x``
        is forwarded unchanged to ``ddim_sample_loop``.
        """
        x_batch = x_batch.to(self.device)
        with torch.no_grad():
            if fp_x is None:
                fp_x = self.fp_encoder(x_batch)
            x_embed_batch = self.diffusion_encoder(x_batch)
            label_t_0 = ddim_sample_loop(self.model, x_embed_batch, fp_x, self.ddim_timesteps, self.n_class,
                                         self.ddim_alphas, self.ddim_alphas_prev, self.ddim_sigmas,
                                         stochastic=stochastic, fq_x=fq_x)
        return label_t_0