|
|
import torch |
|
|
import torchvision.transforms as transforms |
|
|
import random |
|
|
|
|
|
def gamma_correction(x, gamma): |
|
|
minv = torch.min(x) |
|
|
x = x - minv |
|
|
|
|
|
maxv = torch.max(x) |
|
|
x = x / maxv |
|
|
|
|
|
x = x**gamma |
|
|
x = x * maxv + minv |
|
|
return x |
|
|
|
|
|
def random_aug(x): |
|
|
|
|
|
|
|
|
if random.random() <= 0.3: |
|
|
gamma = random.uniform(1.0, 1.5) |
|
|
x = gamma_correction(x, gamma) |
|
|
|
|
|
mean_v = tuple(x.view(x.size(0), -1).mean(-1)) |
|
|
re = transforms.RandomErasing(p=0.5, value=mean_v) |
|
|
x = re(x) |
|
|
|
|
|
if random.random() <= 0.3: |
|
|
l = [0,1,2] |
|
|
random.shuffle(l) |
|
|
x_c = torch.zeros_like(x) |
|
|
x_c[l] = x |
|
|
x = x_c |
|
|
|
|
|
if random.random() <= 0.5: |
|
|
if random.random() <= 0.5: |
|
|
x = torch.flip(x, [1]) |
|
|
else: |
|
|
x = torch.flip(x, [2]) |
|
|
|
|
|
if random.random() <= 0.5: |
|
|
degree = [90, 180, 270] |
|
|
d = random.choice(degree) |
|
|
x = torch.rot90(x, d//90, [1, 2]) |
|
|
|
|
|
return x |
|
|
|
|
|
class PseudoSampleGenerator(object): |
|
|
def __init__(self, n_way, n_support, n_pseudo): |
|
|
super(PseudoSampleGenerator, self).__init__() |
|
|
self.n_way = n_way |
|
|
self.n_support = n_support |
|
|
self.n_pseudo = n_pseudo |
|
|
self.n_pseudo_per_way = self.n_pseudo//self.n_way |
|
|
|
|
|
def generate(self, support_set): |
|
|
|
|
|
if(self.n_support<=5): |
|
|
times = self.n_pseudo//(self.n_way*self.n_support)+1 |
|
|
psedo_list = [] |
|
|
for i in range(support_set.size(0)): |
|
|
psedo_list.append(support_set[i]) |
|
|
for j in range(1, times): |
|
|
cur_x = support_set[i] |
|
|
cur_x = random_aug(cur_x) |
|
|
psedo_list.append(cur_x) |
|
|
psedo_set = torch.stack(psedo_list) |
|
|
|
|
|
psedo_set = psedo_set.reshape([self.n_way, self.n_pseudo_per_way+ self.n_support]+list(psedo_set.size()[1:])) |
|
|
|
|
|
|
|
|
else: |
|
|
|
|
|
support_set = support_set.view(self.n_way, self.n_support, 3, 224, 224) |
|
|
|
|
|
perm = torch.randperm(self.n_support) |
|
|
idx = perm[:15] |
|
|
|
|
|
selected_support_set = support_set[:, idx, :, :, :] |
|
|
|
|
|
selected_support_set = selected_support_set.view(self.n_way*15, 3, 224, 224) |
|
|
|
|
|
times =1 |
|
|
psedo_query_list = [] |
|
|
for i in range(selected_support_set.size(0)): |
|
|
for j in range(0, times): |
|
|
cur_x = selected_support_set[i] |
|
|
cur_x = random_aug(cur_x) |
|
|
psedo_query_list.append(cur_x) |
|
|
psedo_query_list = torch.stack(psedo_query_list) |
|
|
psedo_query_set = psedo_query_list.view(self.n_way, 15, 3, 224, 224) |
|
|
|
|
|
psedo_set = torch.cat((support_set, psedo_query_set), dim = 1) |
|
|
|
|
|
return psedo_set |
|
|
|