ifwi
committed on
Commit
·
3c454af
1
Parent(s):
63a2e47
add more
Browse files
- ldm/models/diffusion/ddpm.py +30 -25
- ldm/modules/attention.py +37 -29
ldm/models/diffusion/ddpm.py
CHANGED
|
@@ -47,6 +47,7 @@ def disabled_train(self, mode=True):
|
|
| 47 |
def uniform_on_device(r1, r2, shape, device):
|
| 48 |
return (r1 - r2) * torch.rand(*shape, device=device) + r2
|
| 49 |
|
|
|
|
| 50 |
class DDPM(pl.LightningModule):
|
| 51 |
# classic DDPM with Gaussian diffusion, in image space
|
| 52 |
def __init__(self,
|
|
@@ -124,7 +125,8 @@ class DDPM(pl.LightningModule):
|
|
| 124 |
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
|
| 125 |
if reset_ema:
|
| 126 |
assert self.use_ema
|
| 127 |
-
print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
|
|
|
|
| 128 |
self.model_ema = LitEma(self.model)
|
| 129 |
if reset_num_ema_updates:
|
| 130 |
print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
|
|
@@ -573,7 +575,7 @@ class LatentDiffusion(DDPM):
|
|
| 573 |
self.scale_factor = scale_factor
|
| 574 |
else:
|
| 575 |
self.register_buffer('scale_factor', torch.tensor(scale_factor))
|
| 576 |
-
|
| 577 |
self.instantiate_first_stage(first_stage_config)
|
| 578 |
self.instantiate_cond_stage(cond_stage_config)
|
| 579 |
self.cond_stage_forward = cond_stage_forward
|
|
@@ -586,7 +588,7 @@ class LatentDiffusion(DDPM):
|
|
| 586 |
self.proj_out = None
|
| 587 |
if self.use_pbe_weight:
|
| 588 |
print("learnable vector gene")
|
| 589 |
-
self.learnable_vector = nn.Parameter(torch.randn((1,1,768)), requires_grad=True)
|
| 590 |
else:
|
| 591 |
self.learnable_vector = None
|
| 592 |
|
|
@@ -608,7 +610,7 @@ class LatentDiffusion(DDPM):
|
|
| 608 |
print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
|
| 609 |
assert self.use_ema
|
| 610 |
self.model_ema.reset_num_updates()
|
| 611 |
-
|
| 612 |
def make_cond_schedule(self, ):
|
| 613 |
self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
|
| 614 |
ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
|
|
@@ -646,7 +648,7 @@ class LatentDiffusion(DDPM):
|
|
| 646 |
self.first_stage_model.train = disabled_train
|
| 647 |
for param in self.first_stage_model.parameters():
|
| 648 |
param.requires_grad = False
|
| 649 |
-
|
| 650 |
def instantiate_cond_stage(self, config):
|
| 651 |
if not self.cond_stage_trainable:
|
| 652 |
if config == "__is_first_stage__":
|
|
@@ -791,14 +793,15 @@ class LatentDiffusion(DDPM):
|
|
| 791 |
|
| 792 |
@torch.no_grad()
|
| 793 |
def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
|
| 794 |
-
cond_key=None, return_original_cond=False, bs=None, return_x=False, no_latent=False, is_controlnet=False):
|
|
|
|
| 795 |
x = super().get_input(batch, k)
|
| 796 |
if bs is not None:
|
| 797 |
x = x[:bs]
|
| 798 |
x = x.to(self.device)
|
| 799 |
if no_latent:
|
| 800 |
-
_,_,h,w = x.shape
|
| 801 |
-
x = resize(x, (h//8, w//8))
|
| 802 |
return [x, None]
|
| 803 |
encoder_posterior = self.encode_first_stage(x)
|
| 804 |
z = self.get_first_stage_encoding(encoder_posterior).detach()
|
|
@@ -815,12 +818,12 @@ class LatentDiffusion(DDPM):
|
|
| 815 |
xc = batch
|
| 816 |
else:
|
| 817 |
xc = super().get_input(batch, cond_key).to(self.device)
|
| 818 |
-
else:
|
| 819 |
xc = x
|
| 820 |
if not self.cond_stage_trainable or force_c_encode:
|
| 821 |
if self.kwargs["use_imageCLIP"]:
|
| 822 |
-
xc = resize(xc, (224,224))
|
| 823 |
-
xc = self.imagenet_norm((xc+1)/2)
|
| 824 |
c = xc
|
| 825 |
else:
|
| 826 |
if isinstance(xc, dict) or isinstance(xc, list):
|
|
@@ -830,8 +833,8 @@ class LatentDiffusion(DDPM):
|
|
| 830 |
c = c.float()
|
| 831 |
else:
|
| 832 |
if self.kwargs["use_imageCLIP"]:
|
| 833 |
-
xc = resize(xc, (224,224))
|
| 834 |
-
xc = self.imagenet_norm((xc+1)/2)
|
| 835 |
c = xc
|
| 836 |
if bs is not None:
|
| 837 |
c = c[:bs]
|
|
@@ -847,7 +850,7 @@ class LatentDiffusion(DDPM):
|
|
| 847 |
if self.use_positional_encodings:
|
| 848 |
pos_x, pos_y = self.compute_latent_shifts(batch)
|
| 849 |
c = {'pos_x': pos_x, 'pos_y': pos_y}
|
| 850 |
-
|
| 851 |
out = [z, c]
|
| 852 |
if return_first_stage_outputs:
|
| 853 |
xrec = self.decode_first_stage(z)
|
|
@@ -872,6 +875,7 @@ class LatentDiffusion(DDPM):
|
|
| 872 |
return output
|
| 873 |
else:
|
| 874 |
return output.sample
|
|
|
|
| 875 |
def decode_first_stage_train(self, z, predict_cids=False, force_not_quantize=False):
|
| 876 |
if predict_cids:
|
| 877 |
if z.dim() == 4:
|
|
@@ -905,12 +909,11 @@ class LatentDiffusion(DDPM):
|
|
| 905 |
# pbe negative condition
|
| 906 |
else:
|
| 907 |
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
|
| 908 |
-
self.u_cond_prop=random.uniform(0, 1)
|
| 909 |
c["c_crossattn"] = [self.get_learned_conditioning(c["c_crossattn"])]
|
| 910 |
if self.u_cond_prop < self.u_cond_percent:
|
| 911 |
-
c["c_crossattn"] = [self.learnable_vector.repeat(x.shape[0],1,1)]
|
| 912 |
return self.p_losses(x, c, t, *args, **kwargs)
|
| 913 |
-
|
| 914 |
|
| 915 |
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
| 916 |
if isinstance(cond, dict):
|
|
@@ -931,7 +934,7 @@ class LatentDiffusion(DDPM):
|
|
| 931 |
|
| 932 |
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
|
| 933 |
return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
|
| 934 |
-
extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
| 935 |
|
| 936 |
def _prior_bpd(self, x_start):
|
| 937 |
"""
|
|
@@ -946,6 +949,7 @@ class LatentDiffusion(DDPM):
|
|
| 946 |
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
|
| 947 |
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
|
| 948 |
return mean_flat(kl_prior) / np.log(2.0)
|
|
|
|
| 949 |
def p_losses(self, x_start, cond, t, noise=None):
|
| 950 |
loss_dict = {}
|
| 951 |
noise = default(noise, lambda: torch.randn_like(x_start))
|
|
@@ -969,11 +973,11 @@ class LatentDiffusion(DDPM):
|
|
| 969 |
if self.only_agn_simple_loss:
|
| 970 |
_, _, l_h, l_w = model_output.shape
|
| 971 |
m_agn = F.interpolate(super().get_input(self.batch, "agn_mask"), (l_h, l_w))
|
| 972 |
-
loss_simple = self.get_loss(model_output * (1-m_agn), target * (1-m_agn), mean=False).mean([1, 2, 3])
|
| 973 |
else:
|
| 974 |
loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
|
| 975 |
loss_dict.update({f'simple': loss_simple.mean()})
|
| 976 |
-
|
| 977 |
logvar_t = self.logvar[t].to(self.device)
|
| 978 |
loss = loss_simple / torch.exp(logvar_t) + logvar_t
|
| 979 |
# loss = loss_simple / torch.exp(self.logvar) + self.logvar
|
|
@@ -981,7 +985,7 @@ class LatentDiffusion(DDPM):
|
|
| 981 |
loss_dict.update({f'gamma': loss.mean()})
|
| 982 |
loss_dict.update({'logvar': self.logvar.data.mean()})
|
| 983 |
loss = self.l_simple_weight * loss.mean()
|
| 984 |
-
|
| 985 |
loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
|
| 986 |
loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
|
| 987 |
if self.original_elbo_weight != 0:
|
|
@@ -990,7 +994,7 @@ class LatentDiffusion(DDPM):
|
|
| 990 |
|
| 991 |
if model_loss is not None:
|
| 992 |
loss += model_loss
|
| 993 |
-
loss_dict.update({f"model loss" : model_loss})
|
| 994 |
loss_dict.update({f'{prefix}_loss': loss})
|
| 995 |
|
| 996 |
return loss, loss_dict
|
|
@@ -1540,7 +1544,7 @@ class LatentUpscaleDiffusion(LatentDiffusion):
|
|
| 1540 |
uc[k] = [uc_tmp]
|
| 1541 |
elif k == "c_adm": # todo: only run with text-based guidance?
|
| 1542 |
assert isinstance(c[k], torch.Tensor)
|
| 1543 |
-
#uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
|
| 1544 |
uc[k] = c[k]
|
| 1545 |
elif isinstance(c[k], list):
|
| 1546 |
uc[k] = [c[k][i] for i in range(len(c[k]))]
|
|
@@ -1807,7 +1811,7 @@ class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
|
|
| 1807 |
log = super().log_images(*args, **kwargs)
|
| 1808 |
depth = self.depth_model(args[0][self.depth_stage_key])
|
| 1809 |
depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \
|
| 1810 |
-
torch.amax(depth, dim=[1, 2, 3], keepdim=True)
|
| 1811 |
log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1.
|
| 1812 |
return log
|
| 1813 |
|
|
@@ -1816,6 +1820,7 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
|
|
| 1816 |
"""
|
| 1817 |
condition on low-res image (and optionally on some spatial noise augmentation)
|
| 1818 |
"""
|
|
|
|
| 1819 |
def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None,
|
| 1820 |
low_scale_config=None, low_scale_key=None, *args, **kwargs):
|
| 1821 |
super().__init__(concat_keys=concat_keys, *args, **kwargs)
|
|
@@ -1872,4 +1877,4 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
|
|
| 1872 |
def log_images(self, *args, **kwargs):
|
| 1873 |
log = super().log_images(*args, **kwargs)
|
| 1874 |
log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
|
| 1875 |
-
return log
|
|
|
|
| 47 |
def uniform_on_device(r1, r2, shape, device):
|
| 48 |
return (r1 - r2) * torch.rand(*shape, device=device) + r2
|
| 49 |
|
| 50 |
+
|
| 51 |
class DDPM(pl.LightningModule):
|
| 52 |
# classic DDPM with Gaussian diffusion, in image space
|
| 53 |
def __init__(self,
|
|
|
|
| 125 |
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
|
| 126 |
if reset_ema:
|
| 127 |
assert self.use_ema
|
| 128 |
+
print(
|
| 129 |
+
f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
|
| 130 |
self.model_ema = LitEma(self.model)
|
| 131 |
if reset_num_ema_updates:
|
| 132 |
print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
|
|
|
|
| 575 |
self.scale_factor = scale_factor
|
| 576 |
else:
|
| 577 |
self.register_buffer('scale_factor', torch.tensor(scale_factor))
|
| 578 |
+
|
| 579 |
self.instantiate_first_stage(first_stage_config)
|
| 580 |
self.instantiate_cond_stage(cond_stage_config)
|
| 581 |
self.cond_stage_forward = cond_stage_forward
|
|
|
|
| 588 |
self.proj_out = None
|
| 589 |
if self.use_pbe_weight:
|
| 590 |
print("learnable vector gene")
|
| 591 |
+
self.learnable_vector = nn.Parameter(torch.randn((1, 1, 768)), requires_grad=True)
|
| 592 |
else:
|
| 593 |
self.learnable_vector = None
|
| 594 |
|
|
|
|
| 610 |
print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
|
| 611 |
assert self.use_ema
|
| 612 |
self.model_ema.reset_num_updates()
|
| 613 |
+
|
| 614 |
def make_cond_schedule(self, ):
|
| 615 |
self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
|
| 616 |
ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
|
|
|
|
| 648 |
self.first_stage_model.train = disabled_train
|
| 649 |
for param in self.first_stage_model.parameters():
|
| 650 |
param.requires_grad = False
|
| 651 |
+
|
| 652 |
def instantiate_cond_stage(self, config):
|
| 653 |
if not self.cond_stage_trainable:
|
| 654 |
if config == "__is_first_stage__":
|
|
|
|
| 793 |
|
| 794 |
@torch.no_grad()
|
| 795 |
def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
|
| 796 |
+
cond_key=None, return_original_cond=False, bs=None, return_x=False, no_latent=False,
|
| 797 |
+
is_controlnet=False):
|
| 798 |
x = super().get_input(batch, k)
|
| 799 |
if bs is not None:
|
| 800 |
x = x[:bs]
|
| 801 |
x = x.to(self.device)
|
| 802 |
if no_latent:
|
| 803 |
+
_, _, h, w = x.shape
|
| 804 |
+
x = resize(x, (h // 8, w // 8))
|
| 805 |
return [x, None]
|
| 806 |
encoder_posterior = self.encode_first_stage(x)
|
| 807 |
z = self.get_first_stage_encoding(encoder_posterior).detach()
|
|
|
|
| 818 |
xc = batch
|
| 819 |
else:
|
| 820 |
xc = super().get_input(batch, cond_key).to(self.device)
|
| 821 |
+
else:
|
| 822 |
xc = x
|
| 823 |
if not self.cond_stage_trainable or force_c_encode:
|
| 824 |
if self.kwargs["use_imageCLIP"]:
|
| 825 |
+
xc = resize(xc, (224, 224))
|
| 826 |
+
xc = self.imagenet_norm((xc + 1) / 2)
|
| 827 |
c = xc
|
| 828 |
else:
|
| 829 |
if isinstance(xc, dict) or isinstance(xc, list):
|
|
|
|
| 833 |
c = c.float()
|
| 834 |
else:
|
| 835 |
if self.kwargs["use_imageCLIP"]:
|
| 836 |
+
xc = resize(xc, (224, 224))
|
| 837 |
+
xc = self.imagenet_norm((xc + 1) / 2)
|
| 838 |
c = xc
|
| 839 |
if bs is not None:
|
| 840 |
c = c[:bs]
|
|
|
|
| 850 |
if self.use_positional_encodings:
|
| 851 |
pos_x, pos_y = self.compute_latent_shifts(batch)
|
| 852 |
c = {'pos_x': pos_x, 'pos_y': pos_y}
|
| 853 |
+
|
| 854 |
out = [z, c]
|
| 855 |
if return_first_stage_outputs:
|
| 856 |
xrec = self.decode_first_stage(z)
|
|
|
|
| 875 |
return output
|
| 876 |
else:
|
| 877 |
return output.sample
|
| 878 |
+
|
| 879 |
def decode_first_stage_train(self, z, predict_cids=False, force_not_quantize=False):
|
| 880 |
if predict_cids:
|
| 881 |
if z.dim() == 4:
|
|
|
|
| 909 |
# pbe negative condition
|
| 910 |
else:
|
| 911 |
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
|
| 912 |
+
self.u_cond_prop = random.uniform(0, 1)
|
| 913 |
c["c_crossattn"] = [self.get_learned_conditioning(c["c_crossattn"])]
|
| 914 |
if self.u_cond_prop < self.u_cond_percent:
|
| 915 |
+
c["c_crossattn"] = [self.learnable_vector.repeat(x.shape[0], 1, 1)]
|
| 916 |
return self.p_losses(x, c, t, *args, **kwargs)
|
|
|
|
| 917 |
|
| 918 |
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
| 919 |
if isinstance(cond, dict):
|
|
|
|
| 934 |
|
| 935 |
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
|
| 936 |
return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
|
| 937 |
+
extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
| 938 |
|
| 939 |
def _prior_bpd(self, x_start):
|
| 940 |
"""
|
|
|
|
| 949 |
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
|
| 950 |
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
|
| 951 |
return mean_flat(kl_prior) / np.log(2.0)
|
| 952 |
+
|
| 953 |
def p_losses(self, x_start, cond, t, noise=None):
|
| 954 |
loss_dict = {}
|
| 955 |
noise = default(noise, lambda: torch.randn_like(x_start))
|
|
|
|
| 973 |
if self.only_agn_simple_loss:
|
| 974 |
_, _, l_h, l_w = model_output.shape
|
| 975 |
m_agn = F.interpolate(super().get_input(self.batch, "agn_mask"), (l_h, l_w))
|
| 976 |
+
loss_simple = self.get_loss(model_output * (1 - m_agn), target * (1 - m_agn), mean=False).mean([1, 2, 3])
|
| 977 |
else:
|
| 978 |
loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
|
| 979 |
loss_dict.update({f'simple': loss_simple.mean()})
|
| 980 |
+
|
| 981 |
logvar_t = self.logvar[t].to(self.device)
|
| 982 |
loss = loss_simple / torch.exp(logvar_t) + logvar_t
|
| 983 |
# loss = loss_simple / torch.exp(self.logvar) + self.logvar
|
|
|
|
| 985 |
loss_dict.update({f'gamma': loss.mean()})
|
| 986 |
loss_dict.update({'logvar': self.logvar.data.mean()})
|
| 987 |
loss = self.l_simple_weight * loss.mean()
|
| 988 |
+
|
| 989 |
loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
|
| 990 |
loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
|
| 991 |
if self.original_elbo_weight != 0:
|
|
|
|
| 994 |
|
| 995 |
if model_loss is not None:
|
| 996 |
loss += model_loss
|
| 997 |
+
loss_dict.update({f"model loss": model_loss})
|
| 998 |
loss_dict.update({f'{prefix}_loss': loss})
|
| 999 |
|
| 1000 |
return loss, loss_dict
|
|
|
|
| 1544 |
uc[k] = [uc_tmp]
|
| 1545 |
elif k == "c_adm": # todo: only run with text-based guidance?
|
| 1546 |
assert isinstance(c[k], torch.Tensor)
|
| 1547 |
+
# uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
|
| 1548 |
uc[k] = c[k]
|
| 1549 |
elif isinstance(c[k], list):
|
| 1550 |
uc[k] = [c[k][i] for i in range(len(c[k]))]
|
|
|
|
| 1811 |
log = super().log_images(*args, **kwargs)
|
| 1812 |
depth = self.depth_model(args[0][self.depth_stage_key])
|
| 1813 |
depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \
|
| 1814 |
+
torch.amax(depth, dim=[1, 2, 3], keepdim=True)
|
| 1815 |
log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1.
|
| 1816 |
return log
|
| 1817 |
|
|
|
|
| 1820 |
"""
|
| 1821 |
condition on low-res image (and optionally on some spatial noise augmentation)
|
| 1822 |
"""
|
| 1823 |
+
|
| 1824 |
def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None,
|
| 1825 |
low_scale_config=None, low_scale_key=None, *args, **kwargs):
|
| 1826 |
super().__init__(concat_keys=concat_keys, *args, **kwargs)
|
|
|
|
| 1877 |
def log_images(self, *args, **kwargs):
|
| 1878 |
log = super().log_images(*args, **kwargs)
|
| 1879 |
log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
|
| 1880 |
+
return log
|
ldm/modules/attention.py
CHANGED
|
@@ -12,20 +12,23 @@ from ldm.modules.diffusionmodules.util import checkpoint
|
|
| 12 |
try:
|
| 13 |
import xformers
|
| 14 |
import xformers.ops
|
|
|
|
| 15 |
XFORMERS_IS_AVAILBLE = True
|
| 16 |
except:
|
| 17 |
XFORMERS_IS_AVAILBLE = False
|
| 18 |
|
| 19 |
# CrossAttn precision handling
|
| 20 |
import os
|
|
|
|
| 21 |
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
|
| 22 |
|
|
|
|
| 23 |
def exists(val):
|
| 24 |
return val is not None
|
| 25 |
|
| 26 |
|
| 27 |
def uniq(arr):
|
| 28 |
-
return{el: True for el in arr}.keys()
|
| 29 |
|
| 30 |
|
| 31 |
def default(val, d):
|
|
@@ -33,6 +36,7 @@ def default(val, d):
|
|
| 33 |
return val
|
| 34 |
return d() if isfunction(d) else d
|
| 35 |
|
|
|
|
| 36 |
class GEGLU(nn.Module):
|
| 37 |
def __init__(self, dim_in, dim_out):
|
| 38 |
super().__init__()
|
|
@@ -110,12 +114,12 @@ class SpatialSelfAttention(nn.Module):
|
|
| 110 |
k = self.k(h_)
|
| 111 |
v = self.v(h_)
|
| 112 |
|
| 113 |
-
b,c,h,w = q.shape
|
| 114 |
q = rearrange(q, 'b c h w -> b (h w) c')
|
| 115 |
k = rearrange(k, 'b c h w -> b c (h w)')
|
| 116 |
w_ = torch.einsum('bij,bjk->bik', q, k)
|
| 117 |
|
| 118 |
-
w_ = w_ * (int(c)**(-0.5))
|
| 119 |
w_ = torch.nn.functional.softmax(w_, dim=2)
|
| 120 |
|
| 121 |
v = rearrange(v, 'b c h w -> b c (h w)')
|
|
@@ -124,7 +128,8 @@ class SpatialSelfAttention(nn.Module):
|
|
| 124 |
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
|
| 125 |
h_ = self.proj_out(h_)
|
| 126 |
|
| 127 |
-
return x+h_
|
|
|
|
| 128 |
|
| 129 |
class CrossAttention(nn.Module):
|
| 130 |
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., **kwargs):
|
|
@@ -143,7 +148,6 @@ class CrossAttention(nn.Module):
|
|
| 143 |
nn.Linear(inner_dim, query_dim),
|
| 144 |
nn.Dropout(dropout)
|
| 145 |
)
|
| 146 |
-
|
| 147 |
|
| 148 |
def forward(self, x, context=None, mask=None):
|
| 149 |
h = self.heads
|
|
@@ -153,26 +157,27 @@ class CrossAttention(nn.Module):
|
|
| 153 |
v = self.to_v(context)
|
| 154 |
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
| 155 |
|
| 156 |
-
if _ATTN_PRECISION =="fp32":
|
| 157 |
-
with torch.autocast(enabled=False, device_type = 'cuda'):
|
| 158 |
q, k = q.float(), k.float()
|
| 159 |
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
| 160 |
else:
|
| 161 |
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
| 162 |
-
|
| 163 |
del q, k
|
| 164 |
if exists(mask):
|
| 165 |
mask = rearrange(mask, 'b ... -> b (...)')
|
| 166 |
max_neg_value = -torch.finfo(sim.dtype).max
|
| 167 |
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
| 168 |
sim.masked_fill_(~mask, max_neg_value)
|
| 169 |
-
|
| 170 |
-
sim = sim.softmax(dim=-1)
|
| 171 |
-
|
| 172 |
out = einsum('b i j, b j d -> b i d', sim, v)
|
| 173 |
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
| 174 |
return self.to_out(out)
|
| 175 |
|
|
|
|
| 176 |
class MemoryEfficientCrossAttention(nn.Module):
|
| 177 |
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
| 178 |
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, zero_init=False, **kwargs):
|
|
@@ -195,7 +200,6 @@ class MemoryEfficientCrossAttention(nn.Module):
|
|
| 195 |
|
| 196 |
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
| 197 |
self.attention_op: Optional[Any] = None
|
| 198 |
-
|
| 199 |
|
| 200 |
def forward(self, x, context=None, mask=None, **kwargs):
|
| 201 |
q = self.to_q(x)
|
|
@@ -221,23 +225,25 @@ class MemoryEfficientCrossAttention(nn.Module):
|
|
| 221 |
.reshape(b, out.shape[1], self.heads * self.dim_head)
|
| 222 |
)
|
| 223 |
return self.to_out(out)
|
| 224 |
-
|
|
|
|
| 225 |
class BasicTransformerBlock(nn.Module):
|
| 226 |
ATTENTION_MODES = {
|
| 227 |
"softmax": CrossAttention, # vanilla attention
|
| 228 |
"softmax-xformers": MemoryEfficientCrossAttention
|
| 229 |
}
|
|
|
|
| 230 |
def __init__(
|
| 231 |
-
self,
|
| 232 |
-
dim,
|
| 233 |
-
n_heads,
|
| 234 |
-
d_head,
|
| 235 |
-
dropout=0.,
|
| 236 |
-
context_dim=None,
|
| 237 |
-
gated_ff=True,
|
| 238 |
checkpoint=True,
|
| 239 |
disable_self_attn=False
|
| 240 |
-
):
|
| 241 |
super().__init__()
|
| 242 |
attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
|
| 243 |
assert attn_mode in self.ATTENTION_MODES
|
|
@@ -247,24 +253,25 @@ class BasicTransformerBlock(nn.Module):
|
|
| 247 |
context_dim=context_dim if self.disable_self_attn else None)
|
| 248 |
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
| 249 |
self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
|
| 250 |
-
heads=n_heads, dim_head=d_head, dropout=dropout)
|
| 251 |
self.norm1 = nn.LayerNorm(dim)
|
| 252 |
self.norm2 = nn.LayerNorm(dim)
|
| 253 |
self.norm3 = nn.LayerNorm(dim)
|
| 254 |
self.checkpoint = checkpoint
|
| 255 |
|
| 256 |
-
def forward(self, x, context=None,hint=None):
|
| 257 |
if hint is None:
|
| 258 |
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
| 259 |
else:
|
| 260 |
return checkpoint(self._forward, (x, context, hint), self.parameters(), self.checkpoint)
|
| 261 |
|
| 262 |
-
def _forward(self, x, context=None,hint=None):
|
| 263 |
-
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None,hint=hint) + x
|
| 264 |
x = self.attn2(self.norm2(x), context=context) + x
|
| 265 |
x = self.ff(self.norm3(x)) + x
|
| 266 |
return x
|
| 267 |
|
|
|
|
| 268 |
class SpatialTransformer(nn.Module):
|
| 269 |
"""
|
| 270 |
Transformer block for image-like data.
|
|
@@ -274,6 +281,7 @@ class SpatialTransformer(nn.Module):
|
|
| 274 |
Finally, reshape to image
|
| 275 |
NEW: use_linear for more efficiency instead of the 1x1 convs
|
| 276 |
"""
|
|
|
|
| 277 |
def __init__(self, in_channels, n_heads, d_head,
|
| 278 |
depth=1, dropout=0., context_dim=None,
|
| 279 |
disable_self_attn=False, use_linear=False,
|
|
@@ -296,7 +304,7 @@ class SpatialTransformer(nn.Module):
|
|
| 296 |
self.transformer_blocks = nn.ModuleList(
|
| 297 |
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
|
| 298 |
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
|
| 299 |
-
for d in range(depth)]
|
| 300 |
)
|
| 301 |
if not use_linear:
|
| 302 |
self.proj_out = zero_module(nn.Conv2d(inner_dim,
|
|
@@ -308,7 +316,7 @@ class SpatialTransformer(nn.Module):
|
|
| 308 |
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
| 309 |
self.use_linear = use_linear
|
| 310 |
|
| 311 |
-
def forward(self, x, context=None,hint=None):
|
| 312 |
# note: if no context is given, cross-attention defaults to self-attention
|
| 313 |
if not isinstance(context, list):
|
| 314 |
context = [context]
|
|
@@ -321,10 +329,10 @@ class SpatialTransformer(nn.Module):
|
|
| 321 |
if self.use_linear:
|
| 322 |
x = self.proj_in(x)
|
| 323 |
for i, block in enumerate(self.transformer_blocks):
|
| 324 |
-
x = block(x, context=context[i]
|
| 325 |
if self.use_linear:
|
| 326 |
x = self.proj_out(x)
|
| 327 |
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
|
| 328 |
if not self.use_linear:
|
| 329 |
x = self.proj_out(x)
|
| 330 |
-
return x + x_in
|
|
|
|
| 12 |
try:
|
| 13 |
import xformers
|
| 14 |
import xformers.ops
|
| 15 |
+
|
| 16 |
XFORMERS_IS_AVAILBLE = True
|
| 17 |
except:
|
| 18 |
XFORMERS_IS_AVAILBLE = False
|
| 19 |
|
| 20 |
# CrossAttn precision handling
|
| 21 |
import os
|
| 22 |
+
|
| 23 |
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
|
| 24 |
|
| 25 |
+
|
| 26 |
def exists(val):
|
| 27 |
return val is not None
|
| 28 |
|
| 29 |
|
| 30 |
def uniq(arr):
|
| 31 |
+
return {el: True for el in arr}.keys()
|
| 32 |
|
| 33 |
|
| 34 |
def default(val, d):
|
|
|
|
| 36 |
return val
|
| 37 |
return d() if isfunction(d) else d
|
| 38 |
|
| 39 |
+
|
| 40 |
class GEGLU(nn.Module):
|
| 41 |
def __init__(self, dim_in, dim_out):
|
| 42 |
super().__init__()
|
|
|
|
| 114 |
k = self.k(h_)
|
| 115 |
v = self.v(h_)
|
| 116 |
|
| 117 |
+
b, c, h, w = q.shape
|
| 118 |
q = rearrange(q, 'b c h w -> b (h w) c')
|
| 119 |
k = rearrange(k, 'b c h w -> b c (h w)')
|
| 120 |
w_ = torch.einsum('bij,bjk->bik', q, k)
|
| 121 |
|
| 122 |
+
w_ = w_ * (int(c) ** (-0.5))
|
| 123 |
w_ = torch.nn.functional.softmax(w_, dim=2)
|
| 124 |
|
| 125 |
v = rearrange(v, 'b c h w -> b c (h w)')
|
|
|
|
| 128 |
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
|
| 129 |
h_ = self.proj_out(h_)
|
| 130 |
|
| 131 |
+
return x + h_
|
| 132 |
+
|
| 133 |
|
| 134 |
class CrossAttention(nn.Module):
|
| 135 |
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., **kwargs):
|
|
|
|
| 148 |
nn.Linear(inner_dim, query_dim),
|
| 149 |
nn.Dropout(dropout)
|
| 150 |
)
|
|
|
|
| 151 |
|
| 152 |
def forward(self, x, context=None, mask=None):
|
| 153 |
h = self.heads
|
|
|
|
| 157 |
v = self.to_v(context)
|
| 158 |
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
| 159 |
|
| 160 |
+
if _ATTN_PRECISION == "fp32":
|
| 161 |
+
with torch.autocast(enabled=False, device_type='cuda'):
|
| 162 |
q, k = q.float(), k.float()
|
| 163 |
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
| 164 |
else:
|
| 165 |
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
| 166 |
+
|
| 167 |
del q, k
|
| 168 |
if exists(mask):
|
| 169 |
mask = rearrange(mask, 'b ... -> b (...)')
|
| 170 |
max_neg_value = -torch.finfo(sim.dtype).max
|
| 171 |
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
| 172 |
sim.masked_fill_(~mask, max_neg_value)
|
| 173 |
+
|
| 174 |
+
sim = sim.softmax(dim=-1)
|
| 175 |
+
|
| 176 |
out = einsum('b i j, b j d -> b i d', sim, v)
|
| 177 |
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
| 178 |
return self.to_out(out)
|
| 179 |
|
| 180 |
+
|
| 181 |
class MemoryEfficientCrossAttention(nn.Module):
|
| 182 |
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
| 183 |
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, zero_init=False, **kwargs):
|
|
|
|
| 200 |
|
| 201 |
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
| 202 |
self.attention_op: Optional[Any] = None
|
|
|
|
| 203 |
|
| 204 |
def forward(self, x, context=None, mask=None, **kwargs):
|
| 205 |
q = self.to_q(x)
|
|
|
|
| 225 |
.reshape(b, out.shape[1], self.heads * self.dim_head)
|
| 226 |
)
|
| 227 |
return self.to_out(out)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
class BasicTransformerBlock(nn.Module):
|
| 231 |
ATTENTION_MODES = {
|
| 232 |
"softmax": CrossAttention, # vanilla attention
|
| 233 |
"softmax-xformers": MemoryEfficientCrossAttention
|
| 234 |
}
|
| 235 |
+
|
| 236 |
def __init__(
|
| 237 |
+
self,
|
| 238 |
+
dim,
|
| 239 |
+
n_heads,
|
| 240 |
+
d_head,
|
| 241 |
+
dropout=0.,
|
| 242 |
+
context_dim=None,
|
| 243 |
+
gated_ff=True,
|
| 244 |
checkpoint=True,
|
| 245 |
disable_self_attn=False
|
| 246 |
+
):
|
| 247 |
super().__init__()
|
| 248 |
attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
|
| 249 |
assert attn_mode in self.ATTENTION_MODES
|
|
|
|
| 253 |
context_dim=context_dim if self.disable_self_attn else None)
|
| 254 |
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
| 255 |
self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
|
| 256 |
+
heads=n_heads, dim_head=d_head, dropout=dropout)
|
| 257 |
self.norm1 = nn.LayerNorm(dim)
|
| 258 |
self.norm2 = nn.LayerNorm(dim)
|
| 259 |
self.norm3 = nn.LayerNorm(dim)
|
| 260 |
self.checkpoint = checkpoint
|
| 261 |
|
| 262 |
+
def forward(self, x, context=None, hint=None):
|
| 263 |
if hint is None:
|
| 264 |
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
| 265 |
else:
|
| 266 |
return checkpoint(self._forward, (x, context, hint), self.parameters(), self.checkpoint)
|
| 267 |
|
| 268 |
+
def _forward(self, x, context=None, hint=None):
|
| 269 |
+
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None, hint=hint) + x
|
| 270 |
x = self.attn2(self.norm2(x), context=context) + x
|
| 271 |
x = self.ff(self.norm3(x)) + x
|
| 272 |
return x
|
| 273 |
|
| 274 |
+
|
| 275 |
class SpatialTransformer(nn.Module):
|
| 276 |
"""
|
| 277 |
Transformer block for image-like data.
|
|
|
|
| 281 |
Finally, reshape to image
|
| 282 |
NEW: use_linear for more efficiency instead of the 1x1 convs
|
| 283 |
"""
|
| 284 |
+
|
| 285 |
def __init__(self, in_channels, n_heads, d_head,
|
| 286 |
depth=1, dropout=0., context_dim=None,
|
| 287 |
disable_self_attn=False, use_linear=False,
|
|
|
|
| 304 |
self.transformer_blocks = nn.ModuleList(
|
| 305 |
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
|
| 306 |
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
|
| 307 |
+
for d in range(depth)]
|
| 308 |
)
|
| 309 |
if not use_linear:
|
| 310 |
self.proj_out = zero_module(nn.Conv2d(inner_dim,
|
|
|
|
| 316 |
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
| 317 |
self.use_linear = use_linear
|
| 318 |
|
| 319 |
+
def forward(self, x, context=None, hint=None):
|
| 320 |
# note: if no context is given, cross-attention defaults to self-attention
|
| 321 |
if not isinstance(context, list):
|
| 322 |
context = [context]
|
|
|
|
| 329 |
if self.use_linear:
|
| 330 |
x = self.proj_in(x)
|
| 331 |
for i, block in enumerate(self.transformer_blocks):
|
| 332 |
+
x = block(x, context=context[i])
|
| 333 |
if self.use_linear:
|
| 334 |
x = self.proj_out(x)
|
| 335 |
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
|
| 336 |
if not self.use_linear:
|
| 337 |
x = self.proj_out(x)
|
| 338 |
+
return x + x_in
|