# scy639's picture
# Upload folder using huggingface_hub
# 2b534de verified
"""SAMPLING ONLY."""
from confs import *
import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from src.Face_models.encoders.model_irse import Backbone
import torch.nn as nn
import torchvision.transforms.functional as TF
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
extract_into_tensor
def un_norm_clip(x1):
    """Undo CLIP's per-channel image normalisation.

    Each RGB channel is mapped back with ``value * std + mean`` using CLIP's
    statistics. Accepts a single image (3-D, channels first) or a batch
    (channels on dim 1). A new tensor is returned; ``x1`` is left untouched.
    """
    out = x1 * 1.0  # the multiplication allocates a new tensor, so x1 is never written to
    single_image = len(out.shape) == 3
    if single_image:
        out = out.unsqueeze(0)
    channel_stats = (
        (0.26862954, 0.48145466),
        (0.26130258, 0.4578275),
        (0.27577711, 0.40821073),
    )
    for channel, (std, mean) in enumerate(channel_stats):
        out[:, channel, :, :] = out[:, channel, :, :] * std + mean
    return out.squeeze(0) if single_image else out
class IDLoss(nn.Module):
    """Identity loss computed with a frozen IR-SE50 ArcFace backbone.

    Generated images (``y_hat``) and reference images (``y``) are embedded by
    the backbone, and each sample contributes ``1 - dot(f(y_hat_i), f(y_i))``.
    NOTE(review): this is a cosine distance only if ``Backbone`` returns
    L2-normalised embeddings; confirm against its implementation.
    """

    def __init__(self, path="Other_dependencies/arcface/model_ir_se50.pth", multiscale=False):
        """Build the backbone, load the ``path`` state dict and freeze all parameters.

        Args:
            path: checkpoint holding a state dict for ``Backbone``.
            multiscale: forwarded to the backbone call as ``multi_scale``.
        """
        super().__init__()
        print('Loading ResNet ArcFace')
        self.multiscale = multiscale
        self.face_pool_1 = torch.nn.AdaptiveAvgPool2d((256, 256))
        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
        # Load on CPU so a checkpoint saved from a GPU also loads on CPU-only hosts;
        # load_state_dict copies the values into the module's own parameters anyway.
        self.facenet.load_state_dict(torch.load(path, map_location="cpu"))
        self.face_pool_2 = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet.eval()
        self.set_requires_grad(False)

    def set_requires_grad(self, flag=True):
        """Set ``requires_grad`` to ``flag`` on every parameter of this module."""
        for p in self.parameters():
            p.requires_grad = flag

    def extract_feats(self, x, clip_img=True):
        """Return backbone features for an image batch.

        Args:
            x: image batch indexed as (B, C, H, W) below (presumably RGB; TODO confirm).
            clip_img: if True, ``x`` is un-normalised from CLIP statistics and then
                normalised with mean/std 0.5 before pooling.

        Returns:
            Whatever ``self.facenet(..., multi_scale=self.multiscale)`` returns.
        """
        if clip_img:
            # NOTE(review): nesting was ambiguous in the flattened source; the two
            # normalisation steps are assumed to belong together under clip_img.
            x = un_norm_clip(x)
            x = TF.normalize(x, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        x = self.face_pool_1(x) if x.shape[2] != 256 else x  # (1) resize to 256 if needed
        x = x[:, :, 35:223, 32:220]  # (2) crop the face region
        x = self.face_pool_2(x)  # (3) resize to 112 to fit the pre-trained model
        return self.facenet(x, multi_scale=self.multiscale)

    @staticmethod
    def _feature_losses(y_hat_feats_ms, y_feats_ms, n_samples):
        """Sum the batch-mean ``1 - dot`` loss over paired feature scales.

        Args:
            y_hat_feats_ms: iterable of per-scale features for the generated images.
            y_feats_ms: iterable of per-scale features for the reference images.
            n_samples: batch size; sample ``i`` is ``feats[i]`` at every scale.

        Returns:
            ``(loss_all, sim_improvement_all, per_sample_losses)``. ``sim_improvement``
            accumulates ``dot(y_hat_i, y_i) - dot(y_i, y_i)`` as Python floats;
            ``per_sample_losses`` has one entry per sample per scale.

        Raises:
            ValueError: if ``n_samples`` is 0 (the batch mean is undefined).
        """
        if n_samples == 0:
            raise ValueError("IDLoss requires a non-empty batch")
        loss_all = 0
        sim_improvement_all = 0
        seperate_losses = []
        for y_hat_feats, y_feats in zip(y_hat_feats_ms, y_feats_ms):
            loss = 0
            sim_improvement = 0
            for i in range(n_samples):
                sim_target = y_hat_feats[i].dot(y_feats[i])
                sim_views = y_feats[i].dot(y_feats[i])
                seperate_losses.append(1 - sim_target)
                loss += 1 - sim_target  # id loss
                sim_improvement += float(sim_target) - float(sim_views)
            loss_all += loss / n_samples
            sim_improvement_all += sim_improvement / n_samples
        return loss_all, sim_improvement_all, seperate_losses

    def forward(self, y_hat, y, clip_img=True, return_seperate=False):
        """Identity loss between ``y_hat`` and ``y``.

        Args:
            y_hat: generated images.
            y: reference images; their features are detached, so no gradient reaches them.
            clip_img: forwarded to ``extract_feats``.
            return_seperate: if True, the third return value is the list of per-sample
                losses (one entry per sample per feature scale); otherwise it is None.

        Returns:
            ``(loss, sim_improvement, per_sample_losses_or_None)``.
        """
        n_samples = y.shape[0]
        y_feats_ms = self.extract_feats(y, clip_img=clip_img)
        y_hat_feats_ms = self.extract_feats(y_hat, clip_img=clip_img)
        y_feats_ms = [y_f.detach() for y_f in y_feats_ms]
        loss_all, sim_improvement_all, seperate_losses = self._feature_losses(
            y_hat_feats_ms, y_feats_ms, n_samples)
        return loss_all, sim_improvement_all, (seperate_losses if return_seperate else None)
class DDIMSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
    """Wrap a diffusion ``model`` for DDIM sampling.

    Args:
        model: diffusion model; only ``num_timesteps`` is read here.
        schedule: stored as ``self.schedule``; not used further in this method.
        **kwargs: accepted and ignored.
    """
    super().__init__()
    self.model = model
    self.ddpm_num_timesteps = self.model.num_timesteps  # length of the DDPM schedule
    self.schedule = schedule
def register_buffer(self, name, attr):
    """Store ``attr`` on the sampler as attribute ``name``.

    Tensors are first moved to ``self.model.device`` so that schedule buffers
    sit on the same device as the model they are used with. Always targeting
    the default CUDA device broke CPU-only runs and models placed on a
    non-default GPU. If the model has no ``device`` attribute, CUDA is used.
    Non-tensor values are stored unchanged.
    """
    if isinstance(attr, torch.Tensor):
        target = getattr(self.model, "device", None)
        attr = attr.to(torch.device("cuda") if target is None else target)
    setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
    """Precompute every noise-schedule quantity the DDIM sampler reads.

    Args:
        ddim_num_steps: number of DDIM steps selected from the DDPM schedule.
        ddim_discretize: forwarded to ``make_ddim_timesteps`` as ``ddim_discr_method``.
        ddim_eta: DDIM eta; passed to ``make_ddim_sampling_parameters`` and used to
            scale ``ddim_sigmas_for_original_num_steps``.
        verbose: forwarded to both schedule helpers.

    Side effects: sets ``self.ddim_timesteps`` and registers (via
    ``register_buffer``) the DDPM buffers derived from the model's schedule,
    the DDIM ``ddim_sigmas`` / ``ddim_alphas`` / ``ddim_alphas_prev`` /
    ``ddim_sqrt_one_minus_alphas`` buffers and
    ``ddim_sigmas_for_original_num_steps``.
    """
    self.ddim_timesteps = make_ddim_timesteps(
        ddim_discr_method=ddim_discretize,
        num_ddim_timesteps=ddim_num_steps,
        num_ddpm_timesteps=self.ddpm_num_timesteps,
        verbose=verbose,
    )
    acp = self.model.alphas_cumprod
    assert acp.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'

    def to_torch(value):
        # Fresh float32 copy placed on the model's device.
        return value.clone().detach().to(torch.float32).to(self.model.device)

    acp_cpu = acp.cpu()
    ddpm_buffers = (
        ('betas', self.model.betas),
        ('alphas_cumprod', acp),
        ('alphas_cumprod_prev', self.model.alphas_cumprod_prev),
        # calculations for diffusion q(x_t | x_{t-1}) and others
        ('sqrt_alphas_cumprod', np.sqrt(acp_cpu)),
        ('sqrt_one_minus_alphas_cumprod', np.sqrt(1. - acp_cpu)),
        ('log_one_minus_alphas_cumprod', np.log(1. - acp_cpu)),
        ('sqrt_recip_alphas_cumprod', np.sqrt(1. / acp_cpu)),
        ('sqrt_recipm1_alphas_cumprod', np.sqrt(1. / acp_cpu - 1)),
    )
    for buf_name, value in ddpm_buffers:
        self.register_buffer(buf_name, to_torch(value))

    # ddim sampling parameters
    ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
        alphacums=acp_cpu,
        ddim_timesteps=self.ddim_timesteps,
        eta=ddim_eta,
        verbose=verbose,
    )
    ddim_buffers = (
        ('ddim_sigmas', ddim_sigmas),
        ('ddim_alphas', ddim_alphas),
        ('ddim_alphas_prev', ddim_alphas_prev),
        ('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)),
    )
    for buf_name, value in ddim_buffers:
        self.register_buffer(buf_name, value)

    prev, cur = self.alphas_cumprod_prev, self.alphas_cumprod
    sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
        (1 - prev) / (1 - cur) * (1 - cur / prev))
    self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample(self,
           S,
           batch_size,
           shape,
           conditioning=None,
           callback=None,
           normals_sequence=None,
           img_callback=None,
           quantize_x0=False,
           eta=0.,
           mask=None,
           x0=None,
           temperature=1.,
           noise_dropout=0.,
           score_corrector=None,
           corrector_kwargs=None,
           verbose=True,
           x_T=None,
           log_every_t=100,
           unconditional_guidance_scale=1.,
           unconditional_conditioning=None,
           # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
           **kwargs
           ):
    """Build the DDIM schedule for ``S`` steps and run ``ddim_sampling``.

    Args:
        S: number of DDIM steps.
        batch_size: leading dimension of the sampled latent.
        shape: ``(C, H, W)`` of one latent.
        conditioning: tensor or dict of tensors; a batch-size mismatch only prints
            a warning.
        eta: DDIM eta forwarded to ``make_schedule``.
        quantize_x0: forwarded as ``quantize_denoised``.
        normals_sequence: accepted and unused.
        **kwargs: forwarded unchanged to ``ddim_sampling``.

    Returns:
        ``(samples, intermediates)`` as produced by ``ddim_sampling``.
    """
    if conditioning is not None:
        if isinstance(conditioning, dict):
            first_entry = conditioning[list(conditioning.keys())[0]]
            n_cond = first_entry.shape[0]
            if n_cond != batch_size:
                print(f"Warning: Got {n_cond} conditionings but batch-size is {batch_size}")
        elif conditioning.shape[0] != batch_size:
            print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

    self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)

    channels, height, width = shape
    full_shape = (batch_size, channels, height, width)
    samples, intermediates = self.ddim_sampling(
        conditioning, full_shape,
        callback=callback,
        img_callback=img_callback,
        quantize_denoised=quantize_x0,
        mask=mask,
        x0=x0,
        ddim_use_original_steps=False,
        noise_dropout=noise_dropout,
        temperature=temperature,
        score_corrector=score_corrector,
        corrector_kwargs=corrector_kwargs,
        x_T=x_T,
        log_every_t=log_every_t,
        unconditional_guidance_scale=unconditional_guidance_scale,
        unconditional_conditioning=unconditional_conditioning,
        **kwargs,
    )
    return samples, intermediates
@torch.no_grad()
def ddim_sampling(self, cond, shape,
                  x_T=None, ddim_use_original_steps=False,
                  callback=None, timesteps=None, quantize_denoised=False,
                  mask=None, x0=None, img_callback=None, log_every_t=100,
                  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                  unconditional_guidance_scale=1., unconditional_conditioning=None,
                  z_ref=None,
                  **kwargs):
    """Run the DDIM reverse process and return ``(img, intermediates)``.

    Args:
        cond: conditioning forwarded to ``p_sample_ddim``.
        shape: full latent shape; ``shape[0]`` is the batch size.
        x_T: starting latent; fresh Gaussian noise of ``shape`` when None.
        ddim_use_original_steps: iterate every DDPM step instead of the DDIM subset.
        timesteps: optional number of DDIM steps to keep (a prefix of
            ``self.ddim_timesteps``).
        mask, x0: when ``mask`` is given, ``x0`` is noised to the current step and
            blended in (``mask == 1`` takes the noised ``x0``).
        log_every_t: ``img``/``pred_x0`` are recorded when ``index % log_every_t == 0``
            and at the first step.
        z_ref: optional reference latent. With ``REFNET.CH9`` it is expanded to
            ``cat([z_ref, z_ref, zeros_1c])``. At every step its first 4 channels are
            replaced by the *clean* reference noised to that step's timestep.
        **kwargs: forwarded to ``p_sample_ddim``.

    Returns:
        ``(img, intermediates)`` with ``intermediates = {'x_inter': [...], 'pred_x0': [...]}``.
    """
    device = self.model.betas.device
    b = shape[0]
    img = torch.randn(shape, device=device) if x_T is None else x_T

    # Clean reference latent used as x_start for per-step noising (CH9 mode only).
    z_ref_clean = None
    if z_ref is not None and REFNET.CH9:
        z_ref_clean = z_ref[:, :4]
        zeros_1c = torch.zeros((z_ref.shape[0], 1, z_ref.shape[2], z_ref.shape[3]), device=z_ref.device)
        # torch.cat returns a new tensor, so the in-place channel writes below
        # never touch the caller's z_ref (and z_ref_clean stays clean).
        z_ref = torch.cat([z_ref, z_ref, zeros_1c], dim=1)

    if timesteps is None:
        timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
    elif not ddim_use_original_steps:
        subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
        timesteps = self.ddim_timesteps[:subset_end]

    intermediates = {'x_inter': [img], 'pred_x0': [img]}
    time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
    total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
    print(f"Running DDIM Sampling with {total_steps} timesteps")

    iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
    for i, step in enumerate(iterator):
        index = total_steps - i - 1
        ts = torch.full((b,), step, device=device, dtype=torch.long)

        if mask is not None:
            assert x0 is not None
            img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
            img = img_orig * mask + (1. - mask) * img

        if z_ref_clean is not None:
            # Always noise from the clean latent. Re-noising z_ref[:, :4] itself
            # would compound noise across steps and drift from the schedule.
            z_ref[:, :4] = self.model.q_sample(x_start=z_ref_clean, t=ts)

        # img and pred_x0 both B,4,64,64; cond/unconditional_conditioning tensors are B,1,768
        img, pred_x0 = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                          z_ref=z_ref,
                                          quantize_denoised=quantize_denoised, temperature=temperature,
                                          noise_dropout=noise_dropout, score_corrector=score_corrector,
                                          corrector_kwargs=corrector_kwargs,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          unconditional_conditioning=unconditional_conditioning, **kwargs)
        if callback:
            callback(i)
        if img_callback:
            img_callback(pred_x0, i)

        if index % log_every_t == 0 or index == total_steps - 1:
            intermediates['x_inter'].append(img)
            intermediates['pred_x0'].append(pred_x0)

    return img, intermediates
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                  unconditional_guidance_scale=1., unconditional_conditioning=None,
                  z_ref=None,
                  **kwargs):
    """One DDIM update from ``x`` at timestep ``t`` to the previous step.

    Steps:
      0. inputs are ``(x, c, t, [z_ref])``;
      1. ``x = concat(x, kwargs['z9'][:, 4:])`` on the channel axis;
      2. ``apply_model(x, t, c, z_ref=...)`` predicts eps, optionally with
         classifier-free guidance (merged into one doubled batch when
         ``MERGE_CFG_in_one_batch`` is set, otherwise two passes).
      (per the original note, this mirrors ``LatentDiffusion.p_losses()`` in ddpm.py)

    Only ``x[:, :4]`` is treated as the latent being denoised.

    Returns:
        ``(x_prev, pred_x0)``.

    Raises:
        KeyError: if ``kwargs`` has no ``'z9'``.
    """
    b, *_, device = *x.shape, x.device
    # Append the conditioning channels of z9 (everything after its first 4 channels).
    x = torch.cat([x, kwargs['z9'][:, 4:]], dim=1)

    if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
        e_t = self.model.apply_model(x, t, c, z_ref=z_ref)
    else:
        if MERGE_CFG_in_one_batch:
            # b,... -> 2b,... : [unconditional; conditional] in a single forward pass
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            z_ref_in = torch.cat([z_ref] * 2) if z_ref is not None else None
            c_in = torch.cat([unconditional_conditioning, c])
            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in, z_ref=z_ref_in).chunk(2)
        else:
            # first infer unconditional then conditional (reduces peak CUDA memory)
            e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, z_ref=z_ref)
            e_t = self.model.apply_model(x, t, c, z_ref=z_ref)
        e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

    if score_corrector is not None:
        assert self.model.parameterization == "eps"
        # corrector_kwargs defaults to None; ``**None`` would raise TypeError.
        e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **(corrector_kwargs or {}))

    alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
    alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
    sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
    sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

    # select parameters corresponding to the currently considered timestep
    a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
    a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
    sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
    sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

    # current prediction for x_0 (only the first 4 channels are the latent)
    x_latent = x[:, :4, :, :] if x.shape[1] != 4 else x
    pred_x0 = (x_latent - sqrt_one_minus_at * e_t) / a_t.sqrt()
    if quantize_denoised:
        pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

    # direction pointing to x_t
    dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
    noise = sigma_t * noise_like(dir_xt.shape, device, repeat_noise) * temperature
    if noise_dropout > 0.:
        noise = torch.nn.functional.dropout(noise, p=noise_dropout)
    x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
    return x_prev, pred_x0
def sample_train(self,
                 S,
                 batch_size,
                 shape,
                 conditioning=None,
                 callback=None,
                 normals_sequence=None,
                 img_callback=None,
                 quantize_x0=False,
                 eta=0.,
                 mask=None,
                 x0=None,
                 temperature=1.,
                 noise_dropout=0.,
                 t=None,
                 score_corrector=None,
                 corrector_kwargs=None,
                 verbose=True,
                 x_T=None,
                 log_every_t=100,
                 unconditional_guidance_scale=1.,
                 unconditional_conditioning=None,
                 # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
                 **kwargs
                 ):
    """Training-time counterpart of ``sample`` (gradients are not disabled here).

    Builds the schedule for ``S`` steps, then runs ``ddim_sampling_train`` with
    ``ddim_num_steps=S`` and ``curr_t=t``.

    Args:
        S: number of DDIM steps.
        batch_size: leading dimension of the sampled latent.
        shape: ``(C, H, W)`` of one latent.
        conditioning: tensor or dict of tensors; a batch-size mismatch only prints
            a warning.
        t: per-sample start timesteps, forwarded as ``curr_t``.
        normals_sequence: accepted and unused.
        **kwargs: forwarded unchanged to ``ddim_sampling_train``.

    Returns:
        ``(samples, intermediates)`` as produced by ``ddim_sampling_train``.
    """
    if conditioning is not None:
        if isinstance(conditioning, dict):
            first_entry = conditioning[list(conditioning.keys())[0]]
            n_cond = first_entry.shape[0]
            if n_cond != batch_size:
                print(f"Warning: Got {n_cond} conditionings but batch-size is {batch_size}")
        elif conditioning.shape[0] != batch_size:
            print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

    self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)

    channels, height, width = shape
    full_shape = (batch_size, channels, height, width)
    samples, intermediates = self.ddim_sampling_train(
        conditioning, full_shape,
        callback=callback,
        img_callback=img_callback,
        quantize_denoised=quantize_x0,
        mask=mask,
        x0=x0,
        ddim_use_original_steps=False,
        noise_dropout=noise_dropout,
        temperature=temperature,
        score_corrector=score_corrector,
        corrector_kwargs=corrector_kwargs,
        x_T=x_T,
        ddim_num_steps=S,
        curr_t=t,
        log_every_t=log_every_t,
        unconditional_guidance_scale=unconditional_guidance_scale,
        unconditional_conditioning=unconditional_conditioning,
        **kwargs,
    )
    return samples, intermediates
def ddim_sampling_train(self, cond, shape,
                        x_T=None, ddim_use_original_steps=False, ddim_num_steps=None,
                        callback=None, timesteps=None, quantize_denoised=False,
                        mask=None, x0=None, img_callback=None, log_every_t=100, curr_t=None,
                        temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                        unconditional_guidance_scale=1., unconditional_conditioning=None, **kwargs):
    """Short DDIM trajectory from per-sample start times, used during training.

    For sample ``n`` the timestep grid is ``range(1, curr_t[n] - 1, skip_n)`` with
    ``skip_n = (curr_t[n] - 1) // ddim_num_steps`` (a stride of 0 is replaced by 1).
    Every grid is truncated to the shortest one in the batch, and to at most
    ``ddim_num_steps`` entries, then walked from the largest timestep down.

    Args:
        shape: latent shape used when ``x_T`` is None.
        x_T: starting tensor; channels ``[:4]`` are denoised and channels ``[4:]``
            are handed unchanged to every step as ``kwargs['rest']``.
        ddim_num_steps: requested number of steps (required).
        curr_t: per-sample start timesteps as a 1-D tensor (required).
        timesteps, log_every_t: accepted for signature compatibility; unused.
        **kwargs: forwarded to ``p_sample_ddim_train``.

    Returns:
        ``(img, intermediates)``. ``intermediates`` holds only the final step's
        ``img``/``pred_x0`` (per the original note, keeping every step ran out
        of memory). If the grid is empty, no step runs and both lists are empty.

    Raises:
        ValueError: if ``curr_t`` or ``ddim_num_steps`` is None.

    NOTE(review): ``p_sample_ddim_train`` picks the DDIM coefficients by step
    position ``index`` from ``self.ddim_alphas`` (built by ``make_schedule``),
    while the model is evaluated at this grid's per-sample timesteps. The two
    are not guaranteed to correspond; confirm this is intended.
    """
    if curr_t is None or ddim_num_steps is None:
        raise ValueError("ddim_sampling_train requires both curr_t and ddim_num_steps")

    device = self.model.betas.device
    img = torch.randn(shape, device=device) if x_T is None else x_T
    # Extra (conditioning) channels are re-attached by p_sample_ddim_train at every step.
    kwargs['rest'] = img[:, 4:, :, :]
    img = img[:, :4, :, :]

    # Per-sample timestep grids, truncated to a common length.
    t_np = curr_t.cpu().numpy()
    skip = (t_np - 1) // ddim_num_steps
    skip[skip == 0] = 1  # a zero stride would make range() raise
    grids = [range(1, t_np[n] - 1, skip[n]) for n in range(len(t_np))]
    min_length = min(min(len(grid) for grid in grids), ddim_num_steps)
    seq = np.array([grid[:min_length] for grid in grids])
    seq = torch.flip(torch.from_numpy(seq).to(device), dims=[1])  # largest timestep first

    intermediates = {'x_inter': [], 'pred_x0': []}
    total_steps = seq.shape[1]
    for i in range(total_steps):
        index = total_steps - i - 1
        ts = seq[:, i].long()

        if mask is not None:
            assert x0 is not None
            img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
            img = img_orig * mask + (1. - mask) * img

        img, pred_x0 = self.p_sample_ddim_train(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                                quantize_denoised=quantize_denoised, temperature=temperature,
                                                noise_dropout=noise_dropout, score_corrector=score_corrector,
                                                corrector_kwargs=corrector_kwargs,
                                                unconditional_guidance_scale=unconditional_guidance_scale,
                                                unconditional_conditioning=unconditional_conditioning, **kwargs)
        if callback:
            callback(i)
        if img_callback:
            img_callback(pred_x0, i)

        if i == total_steps - 1:
            intermediates['x_inter'].append(img)
            intermediates['pred_x0'].append(pred_x0)

    return img, intermediates
def p_sample_ddim_train(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                        temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                        unconditional_guidance_scale=1., unconditional_conditioning=None, return_features=False,
                        **kwargs):
    """One DDIM update used by ``ddim_sampling_train`` (gradients are not disabled).

    If present, ``kwargs['rest']`` is concatenated to ``x`` on the channel axis
    before the model call, and ``kwargs['z_ref']`` is passed to the model as
    ``z_ref``. Only ``x[:, :4]`` is treated as the latent being denoised.
    NOTE(review): ``return_features`` is forwarded to ``apply_model``; if that
    makes it return more than the eps tensor, the arithmetic below will fail.
    Confirm callers keep it False.

    Returns:
        ``(x_prev, pred_x0)``.

    Raises:
        NotImplementedError: if classifier-free guidance is requested
            (``unconditional_conditioning`` given and scale != 1).
    """
    b, *_, device = *x.shape, x.device
    if 'rest' in kwargs:
        x = torch.cat((x, kwargs['rest']), dim=1)
    z_ref = kwargs.pop('z_ref', None)

    if unconditional_conditioning is not None and unconditional_guidance_scale != 1.:
        # Classifier-free guidance is not implemented for the training-time sampler.
        # Fail explicitly; an ``assert`` would be stripped under ``python -O``.
        raise NotImplementedError("classifier-free guidance is not supported in p_sample_ddim_train")
    e_t = self.model.apply_model(x, t, c, return_features=return_features, z_ref=z_ref)

    if score_corrector is not None:
        assert self.model.parameterization == "eps"
        # corrector_kwargs defaults to None; ``**None`` would raise TypeError.
        e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **(corrector_kwargs or {}))

    alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
    alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
    sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
    sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

    # select parameters corresponding to the currently considered timestep
    a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
    a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
    sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
    sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

    # current prediction for x_0 (only the first 4 channels are the latent)
    x_latent = x[:, :4, :, :] if x.shape[1] != 4 else x
    pred_x0 = (x_latent - sqrt_one_minus_at * e_t) / a_t.sqrt()
    if quantize_denoised:
        pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

    # direction pointing to x_t
    dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
    noise = sigma_t * noise_like(dir_xt.shape, device, repeat_noise) * temperature
    if noise_dropout > 0.:
        noise = torch.nn.functional.dropout(noise, p=noise_dropout)
    x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
    return x_prev, pred_x0
@torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
use_original_steps=False):
assert 0