Spaces:
Runtime error
Runtime error
| import torch | |
| import tops | |
| import numpy as np | |
| from sg3_torch_utils.ops import conv2d_gradfix | |
# Module-level 0-dim zero tensor (default dtype, CPU).
# NOTE(review): not referenced by PLRegularization, which keeps its own running
# mean in ``self.pl_mean``; possibly used elsewhere in the project — confirm.
pl_mean_total = torch.zeros(())
class PLRegularization:
    """Path-length regularization loss for a style-based generator.

    Each call regenerates a shrunken batch and takes the gradient of a
    randomly projected generated image w.r.t. the generator's latent ``w``.
    It then penalizes the squared deviation of that gradient's norm from a
    running mean stored in ``self.pl_mean``.
    """

    def __init__(self, weight: float, batch_shrink: int, pl_decay: float, scale_by_mask: bool, **kwargs):
        """Store the regularization hyper-parameters.

        Args:
            weight: Multiplier applied to the returned penalty.
            batch_shrink: Integer divisor of the batch size; only the first
                ``batch_size // batch_shrink`` samples are regularized.
            pl_decay: Interpolation weight passed to ``lerp`` when updating
                the running mean of path lengths.
            scale_by_mask: If True, divide each sample's gradient by the
                fraction of known pixels in ``batch["mask"]``.
            **kwargs: Unused. NOTE(review): presumably lets a larger config
                dict be splatted in — confirm against callers.
        """
        # Running mean of path lengths: a 0-dim tensor on the training device.
        self.pl_mean = torch.zeros([], device=tops.get_device())
        self.pl_weight = weight
        self.batch_shrink = batch_shrink
        self.pl_decay = pl_decay
        self.scale_by_mask = scale_by_mask

    def __call__(self, G, batch, grad_scaler):
        """Compute the path-length penalty for a sub-batch.

        Args:
            G: Generator providing ``get_z(img)``, ``style_net(z)`` and a
                ``G(**batch, w=ws)`` call that returns a dict with ``"img"``.
            batch: Dict of tensors containing ``"img"`` (and ``"mask"`` when
                ``scale_by_mask`` is set). Every entry except ``"embed_map"``
                is truncated along dim 0; ``"embed_map"`` is passed through
                unchanged.
            grad_scaler: AMP gradient scaler exposing ``scale`` and
                ``get_scale`` (e.g. ``torch.cuda.amp.GradScaler``).

        Returns:
            ``(penalty, to_log)``. ``penalty`` is the flattened squared
            deviation multiplied by ``weight``. ``to_log`` is
            ``{"pl_penalty": <detached mean penalty>}``.
        """
        batch_size = batch["img"].shape[0] // self.batch_shrink
        # Slice every per-sample entry, but forward ``embed_map`` whole.
        # NOTE(review): presumably embed_map is shared across samples rather
        # than batched along dim 0 — confirm. It must be looked up in the
        # *input* dict; the filtered dict never contains it.
        sub_batch = {k: v[:batch_size] for k, v in batch.items() if k != "embed_map"}
        if "embed_map" in batch:
            sub_batch["embed_map"] = batch["embed_map"]
        batch = sub_batch
        z = G.get_z(batch["img"])

        with torch.cuda.amp.autocast(tops.AMP()):
            gen_ws = G.style_net(z)
            gen_img = G(**batch, w=gen_ws)["img"].float()
        # Random image-space direction, normalized by sqrt(H * W)
        # (assumes gen_img is NCHW — TODO confirm).
        pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
        with conv2d_gradfix.no_weight_gradients():
            # grad_outputs of ones == gradient of the sum over all elements of
            # gen_img * pl_noise. The output goes through the AMP scaler
            # (presumably to avoid fp16 gradient underflow); the scale is
            # divided back out right after.
            pl_grads = torch.autograd.grad(
                outputs=[grad_scaler.scale(gen_img * pl_noise)],
                inputs=[gen_ws],
                create_graph=True,
                grad_outputs=torch.ones_like(gen_img),
            )[0]

        pl_grads = pl_grads.float() / grad_scaler.get_scale()
        if self.scale_by_mask:
            # Fraction of known pixels per sample, shaped (N, 1) to broadcast
            # against pl_grads.
            scaling = batch["mask"].flatten(start_dim=1).mean(dim=1).view(-1, 1)
            pl_grads = pl_grads / scaling
        pl_lengths = pl_grads.square().sum(1).sqrt()
        pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
        # Only commit finite values to the running mean. An AMP overflow can
        # yield +inf path lengths; a NaN-only check would let inf through, and
        # an inf running mean turns every later lerp into NaN, so the
        # regularizer would never recover.
        if torch.isfinite(pl_mean).all():
            self.pl_mean.copy_(pl_mean.detach())
        pl_penalty = (pl_lengths - pl_mean).square()
        to_log = dict(pl_penalty=pl_penalty.mean().detach())
        return pl_penalty.view(-1) * self.pl_weight, to_log