Upload 2 files
Browse files- contperceptual.py +111 -0
- stable-diffusion-main.zip +3 -0
contperceptual.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class LPIPSWithDiscriminator(nn.Module):
    """Autoencoder training loss: pixel L1 + LPIPS perceptual reconstruction
    term with a learned log-variance, a KL regularizer on the latent
    posterior, and an adversarial (patch discriminator) term.

    The same module serves both optimizers of the GAN setup: ``forward`` with
    ``optimizer_idx == 0`` returns the generator/autoencoder loss, and with
    ``optimizer_idx == 1`` the discriminator loss.

    Note: ``LPIPS``, ``NLayerDiscriminator``, ``weights_init``,
    ``hinge_d_loss``, ``vanilla_d_loss`` and ``adopt_weight`` come from the
    ``taming.modules.losses.vqperceptual`` star import.
    """

    def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_loss="hinge"):
        """
        Args:
            disc_start: global step at which the adversarial term is enabled.
            logvar_init: initial value of the learned output log-variance.
            kl_weight: weight of the KL term in the total generator loss.
            pixelloss_weight: stored as ``self.pixel_weight`` (kept for
                interface compatibility; not applied in ``forward`` here).
            disc_num_layers: depth of the NLayer (PatchGAN-style) discriminator.
            disc_in_channels: input channels of the discriminator.
            disc_factor: global scale of the adversarial term.
            disc_weight: scale applied on top of the adaptive weight.
            perceptual_weight: weight of the LPIPS term added to the pixel loss.
            use_actnorm: use ActNorm instead of BatchNorm in the discriminator.
            disc_conditional: if True, the discriminator additionally receives
                a conditioning tensor concatenated along the channel dim.
            disc_loss: "hinge" or "vanilla" discriminator loss.
        """
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        self.kl_weight = kl_weight
        self.pixel_weight = pixelloss_weight
        # LPIPS is a frozen, pretrained perceptual metric — keep it in eval mode.
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        # Learned scalar log-variance of the reconstruction likelihood
        # (heteroscedastic NLL: rec / exp(logvar) + logvar).
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm
                                                 ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        """Balance the adversarial term against the reconstruction term.

        Scales the generator loss so that its gradient norm at the decoder's
        last layer matches the NLL gradient norm, then clamps and detaches the
        result (the weight itself must not receive gradients).

        Raises:
            RuntimeError: propagated from ``torch.autograd.grad`` when no
                gradient path exists (e.g. under ``torch.no_grad()``).
        """
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            # Fall back to a layer registered on the module; presumably set by
            # the owning model — TODO confirm against caller.
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train",
                weights=None):
        """Compute the loss for one optimizer.

        Args:
            inputs: target images.
            reconstructions: decoder outputs.
            posteriors: latent distribution exposing ``.kl()``.
            optimizer_idx: 0 → generator/autoencoder loss, 1 → discriminator loss.
            global_step: current training step (gates the adversarial term).
            last_layer: decoder's last layer params for the adaptive weight.
            cond: optional conditioning concatenated to the discriminator input.
            split: tag used for the log keys ("train"/"val").
            weights: optional per-element weighting of the NLL term.

        Returns:
            ``(loss, log_dict)`` for the selected optimizer.

        Raises:
            ValueError: if ``optimizer_idx`` is neither 0 nor 1.
        """
        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss

        # Gaussian NLL with learned scalar log-variance; the + logvar term
        # keeps the variance from collapsing to zero.
        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss
        if weights is not None:
            weighted_nll_loss = weights * nll_loss
        # Sum over all elements, normalize by batch size only (not per-pixel
        # mean) — matches how kl_loss is reduced below.
        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        kl_loss = posteriors.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            g_loss = -torch.mean(logits_fake)

            if self.disc_factor > 0.0:
                try:
                    d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
                except RuntimeError:
                    # No gradient path (e.g. validation under no_grad) —
                    # disable the adversarial term for this pass.
                    assert not self.training
                    # Create on the inputs' device so downstream arithmetic
                    # never mixes CPU and accelerator tensors.
                    d_weight = torch.tensor(0.0, device=inputs.device)
            else:
                d_weight = torch.tensor(0.0, device=inputs.device)

            # Adversarial term is switched on only after discriminator_iter_start.
            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss

            log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
                   "{}/logvar".format(split): self.logvar.detach(),
                   "{}/kl_loss".format(split): kl_loss.detach().mean(),
                   "{}/nll_loss".format(split): nll_loss.detach().mean(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/d_weight".format(split): d_weight.detach(),
                   "{}/disc_factor".format(split): torch.tensor(disc_factor),
                   "{}/g_loss".format(split): g_loss.detach().mean(),
                   }
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update; detach reconstructions so
            # no gradients flow back into the generator.
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean()
                   }
            return d_loss, log

        # Previously this silently returned None, which surfaced as an opaque
        # unpacking error at the call site.
        raise ValueError(f"optimizer_idx must be 0 or 1, got {optimizer_idx}")
stable-diffusion-main.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:af742cf392f1b274a4f348a119c6709cb3dd24385c145eee7be3e814fd48ce47
|
| 3 |
+
size 44586516
|