Spaces:
Sleeping
Sleeping
util
Browse files- utils/helper.py +259 -0
- utils/logger.py +12 -0
- utils/mask_generator.py +198 -0
utils/helper.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
import pickle
|
| 5 |
+
from ldm.util import default
|
| 6 |
+
import glob
|
| 7 |
+
import PIL
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
|
| 10 |
+
def load_file(filename):
    """Deserialize and return the pickled object stored at *filename*."""
    with open(filename, "rb") as fh:
        return pickle.load(fh)
|
| 14 |
+
|
| 15 |
+
def save_file(filename, x, mode="wb"):
    """Pickle *x* to *filename*; *mode* defaults to binary write."""
    with open(filename, mode) as fh:
        pickle.dump(x, fh)
|
| 18 |
+
|
| 19 |
+
def normalize_np(img):
    """Linearly rescale *img* from its own range into [0, 1].

    Fixes two defects of the original:
    - it mutated the caller's array in place (``img -=`` / ``img /=``);
      we now work on a float copy and leave the input untouched;
    - a constant image made ``np.max(img)`` zero after the shift and
      produced NaNs; such images now map to all zeros.
    """
    out = np.asarray(img, dtype=np.float64) - np.min(img)
    peak = np.max(out)
    if peak > 0:
        out /= peak
    return out
|
| 24 |
+
|
| 25 |
+
def clear_color(x):
    """Convert a 1xCxHxW tensor (complex allowed) to an HxWxC array in [0, 1]."""
    arr = torch.abs(x) if torch.is_complex(x) else x
    arr = arr.detach().cpu().squeeze().numpy()
    # CHW -> HWC, then rescale via the shared normalizer.
    return normalize_np(arr.transpose(1, 2, 0))
|
| 30 |
+
|
| 31 |
+
def to_img(sample):
    """Map an NCHW tensor in roughly [-1, 1] to NHWC pixel values in [0, 255]."""
    pixels = sample.detach().cpu().numpy().transpose(0, 2, 3, 1)
    return (pixels * 127.5 + 128).clip(0, 255)
|
| 33 |
+
|
| 34 |
+
def save_plot(dir_name, tensors, labels, file_name="loss.png"):
    """Plot each series in *tensors* against its index and save to dir_name/file_name.

    Fixes: colours now cycle (r, b, g, r, ...) so more than three series no
    longer raises IndexError, and the figure is closed after saving so
    repeated calls do not accumulate open matplotlib figures.
    """
    t = np.linspace(0, len(tensors[0]), len(tensors[0]))
    colours = ["r", "b", "g"]
    fig = plt.figure()
    for j, series in enumerate(tensors):
        plt.plot(t, series, color=colours[j % len(colours)], label=labels[j])
    plt.legend()
    plt.savefig(os.path.join(dir_name, file_name))
    plt.close(fig)  # prevent figure leak across calls
|
| 43 |
+
|
| 44 |
+
def save_samples(dir_name, sample, k=None, num_to_save = 5, file_name = None):
    """Write the first *num_to_save* images of *sample* as RGB PNGs under *dir_name*.

    *sample* is either an NCHW tensor (converted via to_img) or an NHWC
    numpy array. With file_name=None the name is derived from *k* and the
    image index; an explicit *file_name* is reused for every image.
    """
    if isinstance(sample, np.ndarray):
        imgs = sample.astype(np.uint8)
    else:
        imgs = to_img(sample).astype(np.uint8)

    for j in range(num_to_save):
        if file_name is not None:
            out_name = file_name
        elif k is not None:
            out_name = f'sample_{k+1}{j}.png'
        else:
            out_name = f'{j}.png'
        PIL.Image.fromarray(imgs[j], 'RGB').save(os.path.join(dir_name, out_name))
|
| 57 |
+
|
| 58 |
+
def save_inpaintings(dir_name, sample, y, mask_pixel, k=None, num_to_save = 5, file_name = None):
    """Blend measurements *y* (where mask_pixel == 1) with *sample* elsewhere and save PNGs.

    Naming follows the same convention as save_samples.
    """
    blended = y * mask_pixel + (1 - mask_pixel) * sample
    imgs = to_img(blended).astype(np.uint8)
    for j in range(num_to_save):
        if file_name is not None:
            out_name = file_name
        elif k is not None:
            out_name = f'sample_{k+1}{j}.png'
        else:
            out_name = f'{j}.png'
        PIL.Image.fromarray(imgs[j], 'RGB').save(os.path.join(dir_name, out_name))
|
| 70 |
+
|
| 71 |
+
def save_params(dir_name, mu_pos, logvar_pos, gamma, k):
    """Detach the three variational parameters and serialize them to dir_name/<k+1>.pt."""
    detached = [mu_pos.detach().cpu(), logvar_pos.detach().cpu(), gamma.detach().cpu()]
    torch.save(params_untrain(detached), os.path.join(dir_name, f'{k+1}.pt'))
|
| 75 |
+
|
| 76 |
+
def custom_to_np(img):
    """Return a detached, CPU-resident, contiguous view of *img* (values unchanged).

    Despite the name this still returns a torch tensor; the uint8/permute
    conversion remains commented out as in the original.
    """
    out = img.detach().cpu()
    #sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    #sample = sample.permute(0, 2, 3, 1)
    return out.contiguous()
|
| 82 |
+
|
| 83 |
+
def encoder_kl(diff, img):
    """Encode *img* with the first-stage KL autoencoder of *diff*.

    Returns the scaled posterior mean perturbed by scale_factor-weighted
    Gaussian noise, plus the (scaled) log-variance.
    """
    _, params = diff.encode_first_stage(img, return_all=True)
    scaled = diff.scale_factor * params
    mean, logvar = torch.chunk(scaled, 2, dim=1)
    # default(None, fn) always evaluates the fallback, i.e. fresh noise.
    noise = torch.randn_like(mean)
    return mean + diff.scale_factor * noise, logvar
|
| 90 |
+
|
| 91 |
+
def encoder_vq(diff, img):
    """Encode *img* with the first-stage VQ autoencoder of *diff* and add scaled noise."""
    latent = diff.scale_factor * diff.encode_first_stage(img)
    # default(None, fn) always evaluates the fallback, i.e. fresh noise.
    noise = torch.randn_like(latent)
    return latent + diff.scale_factor * noise
|
| 98 |
+
|
| 99 |
+
def clean_directory(dir_name):
    """Delete every file matched by the glob pattern *dir_name* (e.g. 'out/*').

    NOTE(review): despite the name this expects a glob pattern, not a bare
    directory path; os.remove will fail on matched subdirectories.
    """
    for path in glob.glob(dir_name):
        os.remove(path)
|
| 103 |
+
|
| 104 |
+
def params_train(params):
    """Mark every tensor in *params* as trainable (in place) and return the list."""
    for tensor in params:
        tensor.requires_grad = True
    return params
|
| 108 |
+
|
| 109 |
+
def params_untrain(params):
    """Freeze every tensor in *params* (requires_grad=False, in place) and return the list."""
    for tensor in params:
        tensor.requires_grad = False
    return params
|
| 113 |
+
|
| 114 |
+
def time_descretization(sigma_min=0.002, sigma_max=80, rho=7, num_t_steps=18, device="cuda"):
    """Return the EDM (Karras et al.) sigma discretization in ascending order.

    Steps interpolate between sigma_max and sigma_min in sigma**(1/rho)
    space, then the order is reversed so index 0 holds sigma_min.

    The new *device* parameter (default "cuda", matching the previous
    hard-coded .cuda() call) generalizes the function to CPU use.
    """
    step_indices = torch.arange(num_t_steps, dtype=torch.float64, device=device)
    t_steps = (sigma_max ** (1 / rho)
               + step_indices / (num_t_steps - 1)
               * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    # Descending schedule -> ascending ("forward") schedule.
    return torch.flip(t_steps, dims=[0])
|
| 121 |
+
|
| 122 |
+
def get_optimizers(means, variances, gamma_param, lr_init_gamma=0.01):
    """Build the three Adam optimizers (means, variances, gamma) and their StepLR schedulers.

    Returns ([optimizer_means, optimizer_vars, optimizer_gamma],
             [matching StepLR schedulers]).
    """
    lr, step_size, decay = 0.1, 10, 0.99  # was 0.999 for right-half: [0.01, 10, 0.99]
    betas = (0.9, 0.99)

    opt_means = torch.optim.Adam([means], lr=lr, betas=betas)
    opt_vars = torch.optim.Adam([variances], lr=0.001, betas=betas)  # 0.001 for lsun
    opt_gamma = torch.optim.Adam([gamma_param], lr=lr_init_gamma, betas=betas)  # 0.01

    optimizers = [opt_means, opt_vars, opt_gamma]
    schedulers = [
        torch.optim.lr_scheduler.StepLR(opt, step_size=step_size, gamma=decay)
        for opt in optimizers
    ]
    return optimizers, schedulers
|
| 134 |
+
|
| 135 |
+
def check_directory(filename_list):
    """Ensure every directory in *filename_list* exists, creating parents as needed.

    os.makedirs(..., exist_ok=True) replaces the check-then-mkdir pattern,
    which was racy (TOCTOU) and failed on nested paths whose parents did
    not exist yet.
    """
    for dirname in filename_list:
        os.makedirs(dirname, exist_ok=True)
|
| 139 |
+
|
| 140 |
+
def s_file(filename, x, mode="wb"):
    """Pickle *x* to *filename*. NOTE(review): duplicates save_file; kept for existing callers."""
    with open(filename, mode) as fh:
        pickle.dump(x, fh)
|
| 143 |
+
|
| 144 |
+
def r_file(filename, mode="rb"):
    """Unpickle and return the object stored at *filename*. NOTE(review): duplicates load_file."""
    with open(filename, mode) as fh:
        return pickle.load(fh)
|
| 148 |
+
|
| 149 |
+
def sample_from_gaussian(mu, alpha, sigma):
    """Draw alpha*mu + sigma*eps with eps ~ N(0, I), shaped like *mu*."""
    eps = torch.randn_like(mu)
    return alpha * mu + sigma * eps
|
| 152 |
+
|
| 153 |
+
'''
|
| 154 |
+
def make_batch(image, mask=None, device=None):
|
| 155 |
+
image = torch.permute(image, (0,3,1,2))
|
| 156 |
+
batch_size = image.shape[0]
|
| 157 |
+
if mask is None :
|
| 158 |
+
mask = torch.zeros_like(image)
|
| 159 |
+
mask[0, :, :256, :128] = 1
|
| 160 |
+
else :
|
| 161 |
+
mask = torch.tensor(mask)
|
| 162 |
+
masked_image = (mask)*image #+ mask*noise*0.2
|
| 163 |
+
mask = mask[:,0,:,:].reshape(batch_size,1,image.shape[2], image.shape[3])
|
| 164 |
+
batch = {"image": image, "mask": mask, "masked_image": masked_image}
|
| 165 |
+
for k in batch:
|
| 166 |
+
batch[k] = batch[k].to(device)
|
| 167 |
+
return batch
|
| 168 |
+
|
| 169 |
+
def get_sigma_t_steps(net, n_steps=3, kwargs=None):
|
| 170 |
+
sigma_min = kwargs["sigma_min"]
|
| 171 |
+
sigma_max = kwargs["sigma_max"]
|
| 172 |
+
sigma_min = max(sigma_min, net.sigma_min)
|
| 173 |
+
sigma_max = min(sigma_max, net.sigma_max)
|
| 174 |
+
|
| 175 |
+
##Get the time-steps based on iddpm discretization
|
| 176 |
+
num_steps = n_steps #11 # kwargs["num_steps"]
|
| 177 |
+
C_2 = kwargs["C_2"]
|
| 178 |
+
C_1 = kwargs["C_1"]
|
| 179 |
+
M = kwargs["M"]
|
| 180 |
+
step_indices = torch.arange(num_steps, dtype=torch.float64).cuda()
|
| 181 |
+
u = torch.zeros(M + 1, dtype=torch.float64).cuda()
|
| 182 |
+
alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
|
| 183 |
+
for j in torch.arange(M, 0, -1, device=step_indices.device): # M, ..., 1
|
| 184 |
+
u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
|
| 185 |
+
u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
|
| 186 |
+
sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
|
| 187 |
+
#print(sigma_steps)
|
| 188 |
+
|
| 189 |
+
##get noise schedule
|
| 190 |
+
sigma = lambda t: t
|
| 191 |
+
sigma_deriv = lambda t: 1
|
| 192 |
+
sigma_inv = lambda sigma: sigma
|
| 193 |
+
|
| 194 |
+
##scaling schedule
|
| 195 |
+
s = lambda t: 1
|
| 196 |
+
s_deriv = lambda t: 0
|
| 197 |
+
|
| 198 |
+
##compute some final time steps based on the corresponding noise levels.
|
| 199 |
+
t_steps = sigma_inv(net.round_sigma(sigma_steps))
|
| 200 |
+
|
| 201 |
+
return t_steps, sigma_inv, sigma, s, sigma_deriv
|
| 202 |
+
|
| 203 |
+
def data_replicate(data, K):
|
| 204 |
+
if len(data.shape)==2: data_batch = torch.Tensor.repeat(data,[K,1])
|
| 205 |
+
else: data_batch = torch.Tensor.repeat(data,[K,1,1,1])
|
| 206 |
+
return data_batch
|
| 207 |
+
|
| 208 |
+
'''
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def sample_T(self, x0, eta=0.4, t_steps_hierarchy=None):
    """Run a DDIM-style reverse-diffusion pass from x0 over t_steps_hierarchy,
    saving the decoded intermediate reconstruction of each step as a PNG
    under out_temp2/.

    Parameters
    ----------
    self : object exposing .model with sqrt_one_minus_alphas_cumprod,
        a callable .model.model(x, t) noise predictor, and
        .decode_first_stage — presumably an LDM wrapper; TODO confirm.
    x0 : latent tensor; noised to the first hierarchy level before the loop.
    eta : DDIM stochasticity coefficient (0 = deterministic DDIM).
    t_steps_hierarchy : descending list of integer timesteps — TODO confirm
        ordering expected by the caller.

    Returns None (output is the saved PNG sequence only).

    NOTE(review): requires CUDA (hard-coded .cuda()) and a batch size of 10
    (hard-coded torch.ones(10)); the noise is applied twice in the x_t
    update (both `torch.randn_like(x_t) * sigma_t` terms) — looks
    intentional-but-suspicious, left untouched.
    """
    # Dead experimental code kept verbatim from the original (no-op string).
    '''
    sigma_discretization_edm = time_descretization(sigma_min=0.002, sigma_max = 999, rho = 7, num_t_steps = 10)/1000
    T_max = 1000
    beta_start = 1 # 0.0015*T_max
    beta_end = 15 # 0.0155*T_max
    def var(t):
        return 1.0 - (1.0) * torch.exp(- beta_start * t - 0.5 * (beta_end - beta_start) * t * t)
    '''
    t_steps_hierarchy = torch.tensor(t_steps_hierarchy).cuda()
    # Noise x0 to the first (largest) timestep of the hierarchy.
    var_t = (self.model.sqrt_one_minus_alphas_cumprod[t_steps_hierarchy[0]].reshape(1, 1 ,1 ,1))**2 # self.var(t_steps_hierarchy[0])
    x_t = torch.sqrt(1 - var_t) * x0 + torch.sqrt(var_t) * torch.randn_like(x0)

    os.makedirs("out_temp2/", exist_ok=True)
    for i, t in enumerate(t_steps_hierarchy):
        # Broadcast the scalar timestep to the (hard-coded) batch of 10.
        t_hat = torch.ones(10).cuda() * (t)
        e_out = self.model.model(x_t, t_hat)  # predicted noise epsilon
        var_t = (self.model.sqrt_one_minus_alphas_cumprod[t].reshape(1, 1 ,1 ,1))**2
        #score_out = - e_out / torch.sqrt()
        a_t = 1 - var_t  # alpha_bar at the current step
        #beta_t = 1 - a_t/a_prev
        #std_pos = ((1 - a_prev)/(1 - a_t)).sqrt()*torch.sqrt(beta_t)
        # Standard epsilon -> x0 inversion.
        pred_x0 = (x_t - torch.sqrt(1 - a_t) * e_out) / a_t.sqrt()

        if i != len(t_steps_hierarchy) - 1:
            # DDIM update toward the next (smaller) timestep in the hierarchy.
            var_t1 = (self.model.sqrt_one_minus_alphas_cumprod[t_steps_hierarchy[i+1]].reshape(1, 1 ,1 ,1))**2
            a_prev = 1 - var_t1 # var(t_steps_hierarchy[i+1]/1000) # torch.full((10, 1, 1, 1), alphas[t_steps_hierarchy[i+1]]).cuda()
            sigma_t = eta * torch.sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))
            dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_out
            # NOTE(review): sigma_t-scaled noise is added twice here.
            x_t = a_prev.sqrt() * pred_x0 + dir_xt + torch.randn_like(x_t) * sigma_t + sigma_t*torch.randn_like(x_t)

        #x_t= (x_t - torch.sqrt( 1 - a_t/a_prev) * e_out ) / (a_t/a_prev).sqrt() + std_pos*torch.randn_like(x_t)

        # More dead experimental code kept verbatim (no-op string).
        '''
        def pred_mean(pred_x0, z_t):
            posterior_mean_coef1 = beta_t * torch.sqrt(a_prev) / (1. - a_t)
            posterior_mean_coef2 = (1. - a_prev) * torch.sqrt(a_t/a_prev) / (1. - a_t)
            return posterior_mean_coef1*pred_x0 + posterior_mean_coef2*z_t

        x_t = torch.sqrt(a_prev) * pred_x0 # pred_mean(pred_x0, x_t) #+ 0.4*torch.sqrt(beta_t) *torch.randn_like(x_t)
        '''
        # Decode the current x0 estimate and dump it as this step's PNG.
        recon = self.model.decode_first_stage(pred_x0)
        image_path = os.path.join("out_temp2/", f'{i}.png')
        image_np = (recon.detach() * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()[0]
        PIL.Image.fromarray(image_np, 'RGB').save(image_path)

    return
|
| 258 |
+
|
| 259 |
+
|
utils/logger.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
def get_logger():
    """Return the shared 'DPS' logger, configured at INFO level.

    The stream handler is attached only on the first call: the original
    added a fresh StreamHandler on every call, so each log record was
    printed once per get_logger() invocation.
    """
    logger = logging.getLogger(name='DPS')
    logger.setLevel(logging.INFO)

    if not logger.handlers:  # configure exactly once
        formatter = logging.Formatter("%(asctime)s [%(name)s] >> %(message)s")
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

    return logger
|
utils/mask_generator.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image, ImageDraw
|
| 4 |
+
import math
|
| 5 |
+
import random
|
| 6 |
+
import torch
|
| 7 |
+
#import tensorflow as tf
|
| 8 |
+
np.random.seed(10)
|
| 9 |
+
def random_sq_bbox(img, mask_shape, image_size=256, margin=(16, 16)):
    """Generate a random square inpainting mask (1 inside the box, 0 elsewhere).

    *img* is NHWC; returns an NCHW mask plus the box edges
    (top, bottom, left, right). The box is placed uniformly at random
    while respecting *margin* on every side.
    """
    batch, height, width, channels = img.shape
    box_h, box_w = mask_shape
    margin_h, margin_w = margin

    # Random top-left corner inside the allowed region.
    top = np.random.randint(margin_h, image_size - margin_h - box_h)
    left = np.random.randint(margin_w, image_size - margin_w - box_w)

    # Build the mask directly: ones inside the box (equivalent to the
    # original's ones-canvas / zero-box / invert sequence).
    mask = torch.zeros([batch, channels, height, width], device=img.device)
    mask[..., top:top + box_h, left:left + box_w] = 1
    return mask, top, top + box_h, left, left + box_w
|
| 29 |
+
|
| 30 |
+
def RandomBrush(
    max_tries,
    s,
    min_num_vertex = 4,
    max_num_vertex = 18,
    mean_angle = 2*math.pi / 5,
    angle_range = 2*math.pi / 15,
    min_width = 12,
    max_width = 48):
    """Draw up to *max_tries* random zig-zag brush strokes on an s x s canvas.

    Returns a uint8 array where stroked pixels are 1 and the rest 0.

    Bug fixed: PIL's Image.transpose returns a NEW image; the original
    discarded that return value, so the two intended random PIL flips were
    silent no-ops. The results are now assigned back to *mask*.
    """
    H, W = s, s
    average_radius = math.sqrt(H*H+W*W) / 8
    mask = Image.new('L', (W, H), 0)
    for _ in range(np.random.randint(max_tries)):
        num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
        angle_min = mean_angle - np.random.uniform(0, angle_range)
        angle_max = mean_angle + np.random.uniform(0, angle_range)
        angles = []
        vertex = []
        for i in range(num_vertex):
            # Alternate direction so the stroke zig-zags.
            if i % 2 == 0:
                angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
            else:
                angles.append(np.random.uniform(angle_min, angle_max))

        h, w = mask.size
        vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
        for i in range(num_vertex):
            # Step length ~ N(avg_radius, avg_radius/2), clipped to [0, 2*avg].
            r = np.clip(
                np.random.normal(loc=average_radius, scale=average_radius//2),
                0, 2*average_radius)
            new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
            new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
            vertex.append((int(new_x), int(new_y)))

        draw = ImageDraw.Draw(mask)
        width = int(np.random.uniform(min_width, max_width))
        draw.line(vertex, fill=1, width=width)
        # Round off each joint so the stroke has no sharp corners.
        for v in vertex:
            draw.ellipse((v[0] - width//2,
                          v[1] - width//2,
                          v[0] + width//2,
                          v[1] + width//2),
                         fill=1)
    # FIX: assign the flipped images back (transpose is not in-place).
    if np.random.random() > 0.5:
        mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
    if np.random.random() > 0.5:
        mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
    mask = np.asarray(mask, np.uint8)
    if np.random.random() > 0.5:
        mask = np.flip(mask, 0)
    if np.random.random() > 0.5:
        mask = np.flip(mask, 1)
    return mask
|
| 83 |
+
|
| 84 |
+
def RandomMask(s, hole_range=[0,1]):
    """Sample an s x s binary mask (1 = keep, 0 = hole) as (1, s, s) float32.

    Rectangles of random size/position are zeroed out; sampling repeats
    until the hole ratio lies strictly inside *hole_range*.
    """
    coef = min(hole_range[0] + hole_range[1], 1.0)
    while True:
        mask = np.ones((s, s), np.uint8)

        def cut_box(max_size):
            # Zero one randomly-placed, randomly-sized rectangle (may
            # extend past the border; the slice clamps it).
            w, h = np.random.randint(max_size), np.random.randint(max_size)
            ww, hh = w // 2, h // 2
            x, y = np.random.randint(-ww, s - w + ww), np.random.randint(-hh, s - h + hh)
            mask[max(y, 0): min(y + h, s), max(x, 0): min(x + w, s)] = 0

        for _ in range(np.random.randint(int(10 * coef))):
            cut_box(s // 2)
        for _ in range(np.random.randint(int(5 * coef))):
            cut_box(s)
        ##comment the following line for lower masking ratios
        #mask = np.logical_and(mask, 1 - RandomBrush(int(20 * coef), s))
        hole_ratio = 1 - np.mean(mask)
        if hole_range is not None and (hole_ratio <= hole_range[0] or hole_ratio >= hole_range[1]):
            continue  # ratio outside target band: resample
        return mask[np.newaxis, ...].astype(np.float32)
|
| 104 |
+
|
| 105 |
+
def BatchRandomMask(batch_size, s, hole_range=[0, 1]):
    """Stack *batch_size* independent RandomMask draws into a (B, 1, s, s) array."""
    masks = [RandomMask(s, hole_range=hole_range) for _ in range(batch_size)]
    return np.stack(masks, axis=0)
|
| 107 |
+
|
| 108 |
+
def random_rotation(shape):
    """Build n half-plane masks at random angles, returned as (n, 1, p, q).

    A padded canvas with its lower half zeroed is rotated by a random
    angle, then the central p x q window is cropped out.
    """
    cutoff = 100  # padding so the rotated half-plane covers the crop; was 30
    n, channels, p, q = shape
    mask = np.zeros((n, p, q))

    for i in range(n):
        angle = np.random.choice(360, 1)
        canvas = np.ones((p + cutoff, q + cutoff))
        canvas[int((p + cutoff) / 2):, :] = 0  # zero the lower half

        im = Image.fromarray(canvas)
        im = im.rotate(angle)

        # Crop the center of the image
        left = (p + cutoff - p) / 2
        top = (q + cutoff - q) / 2
        right = (p + cutoff + p) / 2
        bottom = (q + cutoff + q) / 2
        im = im.crop((left, top, right, bottom))

        mask[i] = np.array(im)

    #mask = np.repeat(mask.reshape([n,1,p,q]), channels, axis=1)
    return mask.reshape([n, 1, p, q])
|
| 134 |
+
|
| 135 |
+
class mask_generator:
    """Factory producing inpainting masks of kind 'box', 'random', 'half', or 'extreme'."""

    def __init__(self, mask_type, mask_len_range=None, mask_prob_range=None,
                 image_size=256, margin=(16, 16)):
        """
        (mask_len_range): given in (min, max) tuple.
        Specifies the range of box size in each dimension
        (mask_prob_range): for the case of random masking,
        specify the probability of individual pixels being masked
        """
        assert mask_type in ['box', 'random', 'half', 'extreme']
        self.mask_type = mask_type
        self.mask_len_range = mask_len_range
        self.mask_prob_range = mask_prob_range
        self.image_size = image_size
        self.margin = margin

    def _retrieve_box(self, img):
        # Sample a random box size within mask_len_range, then place it randomly.
        l, h = self.mask_len_range
        l, h = int(l), int(h)
        mask_h = np.random.randint(l, h)
        mask_w = np.random.randint(l, h)
        mask, t, tl, w, wh = random_sq_bbox(img,
                                            mask_shape=(mask_h, mask_w),
                                            image_size=self.image_size,
                                            margin=self.margin)
        return mask, t, tl, w, wh

    def generate_center_mask(self, shape):
        """Return a (1, 1, H, W) float mask with a centered H/4 x H/4 block of ones."""
        assert len(shape) == 2
        assert shape[1] % 2 == 0
        center = shape[0] // 2
        center_size = shape[0] // 4
        half_resol = center_size // 2  # for now
        ret = torch.zeros(shape, dtype=torch.float32)
        ret[
            center - half_resol: center + half_resol,
            center - half_resol: center + half_resol,
        ] = 1
        ret = ret.unsqueeze(0).unsqueeze(0)
        return ret

    def __call__(self, img):
        """Produce a mask for *img* according to self.mask_type."""
        if self.mask_type == 'random':
            mask = BatchRandomMask(1, self.image_size, hole_range=self.mask_prob_range)
            return mask
        elif self.mask_type == "half":
            mask = random_rotation((1, 3, self.image_size, self.image_size))
            # BUG FIX: this branch previously fell through without a return,
            # so 'half' callers silently received None.
            return mask
        elif self.mask_type == 'box':
            #mask, t, th, w, wl = self._retrieve_box(img)
            mask = self.generate_center_mask((self.image_size, self.image_size))
            return mask
        elif self.mask_type == 'extreme':
            mask, t, th, w, wl = self._retrieve_box(img)
            mask = 1. - mask  # keep only the box's complement
            return mask
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
'''
|
| 193 |
+
def tf_mask_generator(s, tf_hole_range):
|
| 194 |
+
def random_mask_generator(hole_range):
|
| 195 |
+
while True:
|
| 196 |
+
yield RandomMask(s, hole_range=hole_range)
|
| 197 |
+
return tf.data.Dataset.from_generator(random_mask_generator, tf.float32, tf.TensorShape([1, s, s]), (tf_hole_range,))
|
| 198 |
+
'''
|