| |
|
|
| import torch |
| import cv2 |
| from torchvision import transforms |
| import numpy as np |
| import math |
|
|
|
|
| def visual(output, out_path): |
| output = (output + 1)/2 |
| output = torch.clamp(output, 0, 1) |
| if output.shape[1] == 1: |
| output = torch.cat([output, output, output], 1) |
| output = output[0].detach().cpu().permute(1,2,0).numpy() |
| output = (output*255).astype(np.uint8) |
| output = output[:,:,::-1] |
| cv2.imwrite(out_path, output) |
| |
| |
| def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05): |
| |
| lr_ramp = min(1, (1 - t) / rampdown) |
| lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi) |
| lr_ramp = lr_ramp * min(1, t / rampup) |
| return initial_lr * lr_ramp |
|
|
|
|
|
|
| def latent_noise(latent, strength): |
| noise = torch.randn_like(latent) * strength |
|
|
| return latent + noise |
|
|
| def noise_regularize_(noises): |
| loss = 0 |
|
|
| for noise in noises: |
| size = noise.shape[2] |
|
|
| while True: |
| loss = ( |
| loss |
| + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2) |
| + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2) |
| ) |
|
|
| if size <= 8: |
| break |
|
|
| noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2]) |
| noise = noise.mean([3, 5]) |
| size //= 2 |
|
|
| return loss |
|
|
|
|
| def noise_normalize_(noises): |
| for noise in noises: |
| mean = noise.mean() |
| std = noise.std() |
|
|
| noise.data.add_(-mean).div_(std) |
| |
|
|
| def tensor_to_numpy(x): |
| x = x[0].permute(1, 2, 0) |
| x = torch.clamp(x, -1 ,1) |
| x = (x+1) * 127.5 |
| x = x.cpu().detach().numpy().astype(np.uint8) |
| return x |
|
|
| def numpy_to_tensor(x): |
| x = (x / 255 - 0.5) * 2 |
| x = torch.from_numpy(x).unsqueeze(0).permute(0, 3, 1, 2) |
| x = x.cuda().float() |
| return x |
|
|
| def tensor_to_pil(x): |
| x = torch.clamp(x, -1 ,1) |
| x = (x+1) * 127.5 |
| return transforms.ToPILImage()(x.squeeze_(0)) |
|
|
|
|