image-classification-model-space / model_diffusion.py
PrachiY's picture
initial commit
7a59163 verified
import torch.nn as nn
import torch.nn.functional as F
from utils.diffusion_utils import *
import utils.ResNet_for_32 as resnet_s
import utils.ResNet_for_224 as resnet_l
class ConditionalLinear(nn.Module):
    """Linear projection whose output is rescaled by a learned per-step gain.

    Every timestep index owns one gain vector of width ``num_out``; the
    forward pass multiplies the linear output element-wise by the gain
    vector looked up for each row's timestep.
    """

    def __init__(self, num_in, num_out, n_steps):
        """Build the projection and the timestep gain table.

        Args:
            num_in: Width of the input features.
            num_out: Width of the output features (and of each gain vector).
            n_steps: Number of rows in the gain table, i.e. how many distinct
                timestep indices can be looked up.
        """
        super().__init__()
        self.num_out = num_out
        self.lin = nn.Linear(num_in, num_out)
        self.embed = nn.Embedding(n_steps, num_out)
        # Replace the embedding's default initialisation with U[0, 1) gains.
        nn.init.uniform_(self.embed.weight)

    def forward(self, x, t):
        """Return ``embed(t) * lin(x)``.

        Args:
            x: Input features; the last dimension must be ``num_in``.
            t: Integer timestep indices used to look up the gain vectors.

        Returns:
            The linear output scaled element-wise by the looked-up gains.
        """
        projected = self.lin(x)
        # Flatten the lookup to (-1, num_out) so it lines up with the rows
        # of the projected features.
        gain = self.embed(t).view(-1, self.num_out)
        return gain * projected
class ConditionalModel(nn.Module):
    """Timestep-conditioned MLP that maps a label vector to a ``y_dim`` output.

    The network is three ``ConditionalLinear`` layers, each followed by
    ``BatchNorm1d`` and softplus, and a final plain ``nn.Linear`` back to
    ``y_dim``. The batch-normalised image embedding ``x_embed`` gates the
    hidden state multiplicatively after the first layer. When ``guidance`` is
    enabled, the extra feature vector ``fp_x`` is concatenated to the label
    vector before the first layer.
    """

    def __init__(self, n_steps, y_dim=10, fp_dim=128, feature_dim=None, guidance=True):
        """Create the layers.

        Args:
            n_steps: Number of diffusion steps. The timestep tables are built
                with ``n_steps + 1`` rows, so indices ``0..n_steps`` are valid.
            y_dim: Width of the label vector (input and output).
            fp_dim: Width of ``fp_x``; only used when ``guidance`` is true.
            feature_dim: Hidden width; must equal the width of ``x_embed``.
                Required despite the ``None`` default.
            guidance: Whether ``fp_x`` is concatenated to the label input.

        Raises:
            TypeError: If ``feature_dim`` is left as ``None``.
        """
        super(ConditionalModel, self).__init__()
        if feature_dim is None:
            # Without this check the failure surfaces inside BatchNorm1d with
            # a message that never mentions feature_dim.
            raise TypeError("ConditionalModel requires an integer feature_dim, got None")

        n_steps = n_steps + 1
        self.y_dim = y_dim
        self.guidance = guidance

        # Normalises the image embedding before it gates the hidden state.
        self.norm = nn.BatchNorm1d(feature_dim)

        # The first layer additionally sees fp_x when guidance is enabled.
        first_in_dim = y_dim + fp_dim if self.guidance else y_dim
        self.lin1 = ConditionalLinear(first_in_dim, feature_dim, n_steps)
        self.unetnorm1 = nn.BatchNorm1d(feature_dim)
        self.lin2 = ConditionalLinear(feature_dim, feature_dim, n_steps)
        self.unetnorm2 = nn.BatchNorm1d(feature_dim)
        self.lin3 = ConditionalLinear(feature_dim, feature_dim, n_steps)
        self.unetnorm3 = nn.BatchNorm1d(feature_dim)
        self.lin4 = nn.Linear(feature_dim, y_dim)

    def forward(self, x_embed, y, t, fp_x=None):
        """Run the network for one timestep.

        Args:
            x_embed: Image embedding of width ``feature_dim``.
            y: Label vector of width ``y_dim``.
            t: Integer timestep indices, one per row.
            fp_x: Extra features of width ``fp_dim``; required when the model
                was built with ``guidance=True``, otherwise unused.

        Returns:
            Output of the final linear layer, of width ``y_dim``.

        Raises:
            TypeError: If guidance is enabled and ``fp_x`` is ``None``.
        """
        if self.guidance and fp_x is None:
            # torch.cat would raise TypeError as well, but without naming the
            # missing argument.
            raise TypeError("fp_x is required when the model is built with guidance=True")

        x_embed = self.norm(x_embed)
        if self.guidance:
            y = torch.cat([y, fp_x], dim=-1)

        y = F.softplus(self.unetnorm1(self.lin1(y, t)))
        # Multiplicative conditioning on the image embedding.
        y = x_embed * y
        y = F.softplus(self.unetnorm2(self.lin2(y, t)))
        y = F.softplus(self.unetnorm3(self.lin3(y, t)))
        return self.lin4(y)
class Diffusion(nn.Module):
    """Diffusion model over label vectors, conditioned on an image batch.

    Holds three networks: ``fp_encoder`` (supplied by the caller and switched
    to eval mode), ``diffusion_encoder`` (a ResNet chosen by ``encoder_type``)
    and ``model`` (a ``ConditionalModel`` built with guidance enabled). The
    noise-schedule tensors are stored as plain attributes on ``device``; the
    DDIM quantities are registered as buffers by ``make_ddim_schedule``.
    """

    # Accepted ``encoder_type`` values. Names without a suffix are built from
    # ``resnet_s``; names ending in ``_l`` are built from ``resnet_l`` with
    # ``pretrained=True``.
    _SMALL_ENCODERS = ('resnet18', 'resnet34', 'resnet50')
    _LARGE_ENCODERS = ('resnet18_l', 'resnet34_l', 'resnet50_l')

    def __init__(self, fp_encoder, num_timesteps=1000, n_class=10, fp_dim=512, device='cuda', beta_schedule='cosine',
                 feature_dim=2048, encoder_type='resnet50_l', ddim_num_steps=10):
        """Build the noise schedule, the encoders and the denoiser.

        Args:
            fp_encoder: Module used to compute ``fp_x`` from the input batch
                when the caller does not supply it; put into eval mode here.
            num_timesteps: Number of diffusion steps.
            n_class: Width of the label vector handled by the denoiser.
            fp_dim: Width of the ``fp_x`` features fed to the denoiser.
            device: Device for the schedule tensors and the owned networks.
            beta_schedule: Schedule name forwarded to ``make_beta_schedule``.
            feature_dim: Output width of ``diffusion_encoder`` and hidden width
                of the denoiser.
            encoder_type: One of ``resnet18``, ``resnet34``, ``resnet50``,
                ``resnet18_l``, ``resnet34_l``, ``resnet50_l``.
            ddim_num_steps: Number of DDIM steps for the initial DDIM schedule.

        Raises:
            ValueError: If ``encoder_type`` is not one of the accepted names.
        """
        super().__init__()

        # Fail fast, before any schedule tensor or network is allocated.
        valid_encoders = self._SMALL_ENCODERS + self._LARGE_ENCODERS
        if encoder_type not in valid_encoders:
            raise ValueError(
                f"ResNet type should be one of {list(valid_encoders)}, got {encoder_type!r}")

        self.device = device
        self.num_timesteps = num_timesteps
        self.n_class = n_class

        # Per-step noise schedule and the quantities derived from it.
        betas = make_beta_schedule(schedule=beta_schedule, num_timesteps=self.num_timesteps, start=0.0001, end=0.02)
        betas = self.betas = betas.float().to(self.device)
        self.betas_sqrt = torch.sqrt(betas)
        alphas = 1.0 - betas
        self.alphas = alphas
        self.one_minus_betas_sqrt = torch.sqrt(alphas)
        self.alphas_cumprod = alphas.cumprod(dim=0)
        self.alphas_bar_sqrt = torch.sqrt(self.alphas_cumprod)
        self.one_minus_alphas_bar_sqrt = torch.sqrt(1 - self.alphas_cumprod)
        # Cumulative product shifted right by one step, with 1 as first entry.
        alphas_cumprod_prev = torch.cat([torch.ones(1).to(self.device), self.alphas_cumprod[:-1]], dim=0)
        self.alphas_cumprod_prev = alphas_cumprod_prev
        self.posterior_mean_coeff_1 = (betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - self.alphas_cumprod))
        self.posterior_mean_coeff_2 = (torch.sqrt(alphas) * (1 - alphas_cumprod_prev) / (1 - self.alphas_cumprod))
        self.posterior_variance = (betas * (1.0 - alphas_cumprod_prev) / (1.0 - self.alphas_cumprod))
        self.logvar = betas.log()

        self.fp_dim = fp_dim
        self.fp_encoder = fp_encoder.eval()
        self.encoder_type = encoder_type
        self.diffusion_encoder = self._build_diffusion_encoder(encoder_type, feature_dim).to(self.device)

        self.model = ConditionalModel(self.num_timesteps, y_dim=self.n_class, fp_dim=fp_dim,
                                      feature_dim=feature_dim, guidance=True).to(self.device)

        # Also records ``self.ddim_num_steps``.
        self.make_ddim_schedule(ddim_num_steps)

    def _build_diffusion_encoder(self, encoder_type, feature_dim):
        """Instantiate the ResNet named by an already-validated ``encoder_type``."""
        if encoder_type in self._SMALL_ENCODERS:
            factory = getattr(resnet_s, encoder_type)
            return factory(num_input_channels=3, num_classes=feature_dim)
        # 'resnetNN_l' -> resnet_l.resnetNN
        factory = getattr(resnet_l, encoder_type[:-len('_l')])
        return factory(num_classes=feature_dim, pretrained=True)

    def make_ddim_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.):
        """(Re)compute the DDIM timesteps and register the DDIM buffers.

        Args:
            ddim_num_steps: Number of DDIM sampling steps.
            ddim_discretize: Discretisation method forwarded to
                ``make_ddim_timesteps``.
            ddim_eta: Eta forwarded to ``make_ddim_sampling_parameters`` and
                used to scale ``ddim_sigmas_for_original_num_steps``.
        """
        # Keep the recorded step count in sync when the schedule is rebuilt.
        self.ddim_num_steps = ddim_num_steps
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.num_timesteps)

        assert self.alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        def to_torch(x):
            # Detached float32 copy on the model's device.
            return x.clone().detach().to(torch.float32).to(self.device)

        self.register_buffer('sqrt_alphas_cumprod', to_torch(torch.sqrt(self.alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(torch.sqrt(1. - self.alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(torch.log(1. - self.alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(torch.sqrt(1. / self.alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(torch.sqrt(1. / self.alphas_cumprod - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=self.alphas_cumprod,
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', torch.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                    1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    def load_diffusion_net(self, net_state_dicts):
        """Load weights from a sequence of state dicts.

        Args:
            net_state_dicts: ``[model, diffusion_encoder]`` state dicts, with
                an optional third entry for ``fp_encoder``. The third entry is
                loaded only when the sequence has exactly three items.
        """
        self.model.load_state_dict(net_state_dicts[0])
        self.diffusion_encoder.load_state_dict(net_state_dicts[1])
        if len(net_state_dicts) == 3:
            self.fp_encoder.load_state_dict(net_state_dicts[2])

    def forward_t(self, y_0_batch, x_batch, t, fp_x, fq_x=None):
        """Noise the labels to step ``t`` and run the denoiser on them.

        Args:
            y_0_batch: Clean label vectors.
            x_batch: Input batch for ``diffusion_encoder``; moved to
                ``self.device``.
            t: Timestep indices.
            fp_x: Features concatenated to the labels inside the denoiser.
            fq_x: Optional value forwarded unchanged to ``q_sample``.

        Returns:
            Tuple ``(output, e)``: the denoiser output and the Gaussian noise
            that was passed to ``q_sample``.
        """
        x_batch = x_batch.to(self.device)

        # randn_like already allocates on y_0_batch's device.
        e = torch.randn_like(y_0_batch)
        y_t_batch = q_sample(y_0_batch, self.alphas_bar_sqrt,
                             self.one_minus_alphas_bar_sqrt, t, noise=e, fq_x=fq_x)

        x_embed_batch = self.diffusion_encoder(x_batch)
        output = self.model(x_embed_batch, y_t_batch, t, fp_x)
        return output, e

    def reverse(self, images, only_last_sample=True, stochastic=True, fp_x=None, fq_x=None):
        """Sample labels with ``p_sample_loop`` over the full schedule.

        Args:
            images: Input batch; moved to ``self.device``. A 4-D batch is
                embedded with ``diffusion_encoder`` first; any other rank is
                passed to the sampler unchanged.
            only_last_sample: Forwarded to ``p_sample_loop``.
            stochastic: Forwarded to ``p_sample_loop``.
            fp_x: Precomputed features; computed with ``fp_encoder(images)``
                when ``None``.
            fq_x: Optional value forwarded unchanged to ``p_sample_loop``.

        Returns:
            Whatever ``p_sample_loop`` returns.
        """
        images = images.to(self.device)
        with torch.no_grad():
            if fp_x is None:
                fp_x = self.fp_encoder(images)

            # self.model normalises its first input with BatchNorm1d, which
            # rejects 4-D tensors, and the sampler is handed only self.model,
            # so a raw image batch has to be embedded here, as forward_t and
            # reverse_ddim do. Other ranks pass through untouched so callers
            # that already supply embeddings keep working.
            # NOTE(review): assumes p_sample_loop forwards its second argument
            # to the model as x_embed - confirm in utils.diffusion_utils.
            x_embed = self.diffusion_encoder(images) if images.dim() == 4 else images

            label_t_0 = p_sample_loop(self.model, x_embed, fp_x,
                                      self.num_timesteps, self.alphas,
                                      self.one_minus_alphas_bar_sqrt,
                                      only_last_sample=only_last_sample, stochastic=stochastic, fq_x=fq_x)
        return label_t_0

    def reverse_ddim(self, x_batch, stochastic=True, fp_x=None, fq_x=None):
        """Sample labels with ``ddim_sample_loop`` over the DDIM schedule.

        Args:
            x_batch: Input batch; moved to ``self.device`` and embedded with
                ``diffusion_encoder``.
            stochastic: Forwarded to ``ddim_sample_loop``.
            fp_x: Precomputed features; computed with ``fp_encoder(x_batch)``
                when ``None``.
            fq_x: Optional value forwarded unchanged to ``ddim_sample_loop``.

        Returns:
            Whatever ``ddim_sample_loop`` returns.
        """
        x_batch = x_batch.to(self.device)
        with torch.no_grad():
            if fp_x is None:
                fp_x = self.fp_encoder(x_batch)

            x_embed_batch = self.diffusion_encoder(x_batch)
            label_t_0 = ddim_sample_loop(self.model, x_embed_batch, fp_x, self.ddim_timesteps, self.n_class,
                                         self.ddim_alphas, self.ddim_alphas_prev, self.ddim_sigmas,
                                         stochastic=stochastic, fq_x=fq_x)
        return label_t_0