import torch
import einops
import torch.nn.functional as F


def get_mask(idx, array):
    '''
    idx: (b m n) index tensor; only its last-dim size, dtype and device are used
    array: (b m) per-row count of leading elements to mark True
    OUT: (b m n) boolean mask, True for the first array[b, m] positions along the last dim
    '''
    b, m = array.shape
    n = idx.size(-1)
    # Positions 0..n-1, broadcast to (b, m, n), compared against the per-row cutoff.
    A = torch.arange(n, dtype=idx.dtype, device=idx.device).unsqueeze(0).unsqueeze(0).expand(b, m, n)
    mask = A < array.unsqueeze(-1)
    return mask
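

# A minimal usage sketch for get_mask, assuming illustrative shapes; the names
# and values below are examples, not taken from the calling pipeline.
def _demo_get_mask():
    idx = torch.argsort(torch.randn(2, 8, 16), dim=-1, descending=True)  # (b, m, n)
    counts = torch.randint(1, 16, (2, 8))                                # (b, m)
    mask = get_mask(idx, counts)                                         # (b, m, n) bool
    # Each row is True exactly counts[b, m] times, at the leading positions.
    assert mask.sum(dim=-1).equal(counts)
    return mask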
def alloc(var, rest, budget, tp, maximum, times=0, fast=False):
    '''
    var: (b m) variance of each pixel, POSITIVE VALUES
    rest: (b m) budgets already allocated to each pixel
    budget: (b) remaining budget to be allocated
    tp: proportion type, 'plain' or 'softmax'
    maximum: maximum budget for each pixel
    OUT: (b m) updated per-pixel budgets
    '''
    b, m = var.shape
    # Turn variances into allocation proportions, excluding pixels already at the cap.
    if tp == 'plain':
        var_p = var * (rest < maximum)
        var_sum = var_p.sum(dim=-1, keepdim=True)
        proportion = var_p / var_sum
    elif tp == 'softmax':
        var_p = var.clone()
        var_p[rest >= maximum] = -float('inf')
        proportion = torch.nn.functional.softmax(var_p, dim=-1)
    allocation = torch.round(proportion * budget.unsqueeze(1))
    new_rest = torch.clamp(rest + allocation, 0, maximum)
    remain_budget = budget - (new_rest - rest).sum(dim=-1)

    # Rounding may over-spend; remove one unit from a random pixel at a time
    # until no batch entry is over budget.
    negative_remain = (remain_budget < 0)
    while negative_remain.sum() > 0:
        offset = torch.eye(m, device=rest.device)[
            torch.randint(m, (negative_remain.sum().int().item(),), device=rest.device)]
        new_rest[negative_remain] = torch.clamp(new_rest[negative_remain] - offset, 1, maximum)

        remain_budget = budget - (new_rest - rest).sum(dim=-1)
        negative_remain = (remain_budget < 0)

    # Under-spent entries: recurse a few times to redistribute the remainder,
    # then (unless fast=True) add single units at random until it is exhausted.
    if (remain_budget > 0).sum() > 0:
        if times < 3:
            new_rest[remain_budget > 0] = alloc(var[remain_budget > 0], new_rest[remain_budget > 0],
                                                remain_budget[remain_budget > 0], tp, maximum, times + 1, fast=fast)
        elif not fast:
            positive_remain = (remain_budget > 0)
            while positive_remain.sum() > 0:
                offset = torch.eye(m, device=rest.device)[
                    torch.randint(m, (positive_remain.sum().int().item(),), device=rest.device)]
                new_rest[positive_remain] = torch.clamp(new_rest[positive_remain] + offset, 1, maximum)

                remain_budget = budget - (new_rest - rest).sum(dim=-1)
                positive_remain = (remain_budget > 0)
    return new_rest
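

# A minimal usage sketch for alloc, assuming a uniform initial budget of one
# unit per pixel; all sizes here are illustrative.
def _demo_alloc():
    b, m = 2, 64
    var = torch.rand(b, m) + 1e-6      # positive per-pixel variance
    rest = torch.ones(b, m)            # one unit already allocated everywhere
    budget = torch.full((b,), 7 * m)   # remaining budget per batch entry
    new_rest = alloc(var, rest, budget, tp='plain', maximum=16)
    # Every pixel stays within [1, maximum] after allocation.
    assert new_rest.min() >= 1 and new_rest.max() <= 16
    return new_rest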
def flex(D_: torch.Tensor, X: torch.Tensor, idx: torch.Tensor, flex_type, topk_, current_iter, total_iters, X_diff,
         fast=False, return_maskarray=False):
    '''
    D_: (b m n) Gram matrix, sorted on last dim, descending
    X: (b numh numw he) c (sh sw) X_data (unused by the policies below)
    idx: (b m n) sorted index of D_
    X_diff: (b m) per-pixel statistic consumed by the 'interdiff_plain' policy
    current_iter, total_iters: unused here
    OUT: (b m n) binary mask, or (b m) per-row counts if return_maskarray=True
    '''
    b, m, n = D_.shape
    if flex_type is None or flex_type == 'none':
        # Fixed budget: every row keeps exactly topk_ entries.
        mask_array = topk_ * torch.ones((b, m), dtype=torch.int, device=D_.device)

    elif flex_type == 'gsort':
        D = D_.clone()
        # Push each row's maximum far down so it drops out of the global ranking;
        # it is re-added to the counts below.
        D -= (D == D.max(dim=-1, keepdim=True)[0]) * 100000
        val, g_idx = torch.sort(D.view(b, -1), dim=-1, descending=True)

        # Offset per batch entry so the sorted indices address the flattened (b*m*n) mask.
        g_idx += m * n * torch.arange(b, dtype=g_idx.dtype, device=g_idx.device).unsqueeze(-1)
        # Keep the topk_ * (m - 1) globally largest entries; the rest are masked out.
        non_topk_idx = g_idx[:, topk_ * (m - 1):]

        mask_ = torch.ones_like(D).bool()
        mask_.view(-1)[non_topk_idx.reshape(-1)] = False
        mask_array = mask_.sum(dim=-1)
        # Re-include the per-row maximum that was pushed out above.
        mask_array += 1

    elif flex_type == 'interdiff_plain':
        # Start from one element per row and let alloc distribute the remaining
        # topk_ budget proportionally to X_diff.
        rest = torch.ones_like(X_diff)
        budget = torch.ones(b, dtype=torch.int, device=idx.device) * (topk_ - 1) * idx.size(1)
        mask_array = alloc(X_diff, rest, budget, tp='plain', maximum=idx.size(-1), fast=fast)
    else:
        raise NotImplementedError(f'flex_type {flex_type} not implemented...')

    if return_maskarray:
        return mask_array

    # get_mask marks each row's first mask_array[b, m] sorted entries; the returned
    # mask is its complement (True where an entry falls outside the budget).
    mask = ~get_mask(idx, mask_array)

    return mask
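

# A minimal usage sketch for flex with the 'interdiff_plain' policy; the sorted
# Gram matrix and X_diff statistic are random stand-ins, and X may be None since
# this path never reads it.
def _demo_flex():
    b, m, n = 2, 16, 32
    D, idx = torch.sort(torch.randn(b, m, n), dim=-1, descending=True)
    X_diff = torch.rand(b, m)
    mask = flex(D, None, idx, 'interdiff_plain', topk_=4, current_iter=0,
                total_iters=1, X_diff=X_diff)
    return mask  # (b, m, n) bool, True where an entry falls outside the budget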
def cossim(X_sample, Y_sample, graph=None):
    # Pairwise cosine similarity between the m rows of X_sample and the n rows of
    # Y_sample; pairs excluded by `graph` are offset by -100 so they lose any
    # subsequent top-k or softmax.
    if graph is not None:
        return torch.einsum('a b m c, a b n c -> a b m n', F.normalize(X_sample, dim=-1),
                            F.normalize(Y_sample, dim=-1)) + (-100.) * (~graph)
    return torch.einsum('a b m c, a b n c -> a b m n', F.normalize(X_sample, dim=-1), F.normalize(Y_sample, dim=-1))
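

# A minimal usage sketch for cossim; shapes are illustrative, and `graph` is a
# random keep-mask just to show the masked variant.
def _demo_cossim():
    X = torch.randn(2, 3, 8, 16)           # (a, b, m, c)
    Y = torch.randn(2, 3, 5, 16)           # (a, b, n, c)
    sim = cossim(X, Y)                     # (2, 3, 8, 5), values in [-1, 1]
    keep = torch.rand(2, 3, 8, 5) > 0.5    # True = pair allowed
    sim_masked = cossim(X, Y, graph=keep)  # disallowed pairs offset by -100
    return sim, sim_masked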
def local_sampling(x, group_size, unfold_dict, output=0, tp='bhwc'):
    '''
    Partition x into non-overlapping groups and/or sample overlapping
    neighbourhoods with F.unfold.
    output: 0/1/2 -> grouped only, sampled only, both
    x_grouped: [(b numh numw), (sh sw), c]
    x_sampled: [(b l), (k0 k1), c]
    '''
    if isinstance(group_size, int):
        group_size = (group_size, group_size)

    if output != 1:
        if tp == 'bhwc':
            x_grouped = einops.rearrange(x, 'b (numh sh) (numw sw) c -> (b numh numw) (sh sw) c', sh=group_size[0],
                                         sw=group_size[1])
        elif tp == 'bchw':
            x_grouped = einops.rearrange(x, 'b c (numh sh) (numw sw) -> (b numh numw) (sh sw) c', sh=group_size[0],
                                         sw=group_size[1])

    if output == 0:
        return x_grouped

    # F.unfold expects channels-first input.
    if tp == 'bhwc':
        x = einops.rearrange(x, 'b h w c -> b c h w')

    x_sampled = einops.rearrange(F.unfold(x, **unfold_dict), 'b (c k0 k1) l -> (b l) (k0 k1) c',
                                 k0=unfold_dict['kernel_size'][0], k1=unfold_dict['kernel_size'][1])

    if output == 1:
        return x_sampled

    # Grouping and unfolding must yield the same number of windows.
    assert x_grouped.size(0) == x_sampled.size(0)
    return x_grouped, x_sampled
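

# A minimal usage sketch for local_sampling; the group size and unfold settings
# are illustrative and chosen so groups and unfold windows line up one-to-one.
def _demo_local_sampling():
    x = torch.randn(2, 16, 16, 32)  # (b, h, w, c)
    unfold_dict = dict(kernel_size=(8, 8), stride=4, padding=2)
    x_grouped, x_sampled = local_sampling(x, group_size=4, unfold_dict=unfold_dict,
                                          output=2, tp='bhwc')
    assert x_grouped.shape == (2 * 4 * 4, 16, 32)  # one 4x4 group per window
    assert x_sampled.shape == (2 * 4 * 4, 64, 32)  # one 8x8 window per group
    return x_grouped, x_sampled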
def global_sampling(x, group_size, sample_size, output=0, tp='bhwc'):
    '''
    Partition x into groups on a strided (dilated) grid and/or build one coarse
    global sample grid per image, shared across its groups.
    output: 0/1/2 -> grouped only, sampled only, both
    x_grouped: [(b numh numw), (sh sw), c]
    x_sampled: [(b numh numw), (sh sw), c]
    '''
    if isinstance(group_size, int):
        group_size = (group_size, group_size)
    if isinstance(sample_size, int):
        sample_size = (sample_size, sample_size)

    if output != 1:
        if tp == 'bchw':
            x_grouped = einops.rearrange(x, 'b c (sh numh) (sw numw) -> (b numh numw) (sh sw) c', sh=group_size[0],
                                         sw=group_size[1])
        elif tp == 'bhwc':
            x_grouped = einops.rearrange(x, 'b (sh numh) (sw numw) c -> (b numh numw) (sh sw) c', sh=group_size[0],
                                         sw=group_size[1])

    if output == 0:
        return x_grouped

    # Pull out one dilated sample grid per image.
    if tp == 'bchw':
        x_sampled = einops.rearrange(x, 'b c (sh extrah numh) (sw extraw numw) -> b extrah numh extraw numw c sh sw',
                                     sh=sample_size[0], sw=sample_size[1], extrah=1, extraw=1)
    elif tp == 'bhwc':
        x_sampled = einops.rearrange(x, 'b (sh extrah numh) (sw extraw numw) c -> b extrah numh extraw numw c sh sw',
                                     sh=sample_size[0], sw=sample_size[1], extrah=1, extraw=1)
    b_y, _, numh, _, numw, c_y, sh_y, sw_y = x_sampled.shape
    # Broadcast the shared sample grid to every group it covers.
    ratio_h, ratio_w = sample_size[0] // group_size[0], sample_size[1] // group_size[1]
    x_sampled = x_sampled.expand(b_y, ratio_h, numh, ratio_w, numw, c_y, sh_y, sw_y).reshape(
        -1, c_y, sh_y * sw_y).permute(0, 2, 1)

    if output == 1:
        return x_sampled

    assert x_grouped.size(0) == x_sampled.size(0)
    return x_grouped, x_sampled
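

# A minimal usage sketch for global_sampling; sizes are illustrative, with
# sample_size an integer multiple of group_size as the expand step requires.
def _demo_global_sampling():
    x = torch.randn(2, 16, 16, 32)  # (b, h, w, c)
    x_grouped, x_sampled = global_sampling(x, group_size=4, sample_size=8,
                                           output=2, tp='bhwc')
    assert x_grouped.shape == (2 * 4 * 4, 16, 32)  # dilated 4x4 groups
    assert x_sampled.shape == (2 * 4 * 4, 64, 32)  # shared 8x8 global sample
    return x_grouped, x_sampled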