Spaces:
Build error
Build error
| from torch.utils.data import TensorDataset | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| import glob | |
| import pickle | |
| import random | |
| import os | |
| import cv2 | |
| import tqdm | |
| import sys | |
| sys.path.append('..') | |
| # from utils.cap_aug import CAP_AUG | |
class FaceEmbed(TensorDataset):
    """Face dataset spanning several flat image folders.

    Each item is a 4-tuple:
        (arcface-sized source, base-sized source, base-sized target, same_person)
    where same_person is 1 when the target is a copy of the source and 0 when
    the target was drawn at random from the pooled folders.
    """

    def __init__(self, data_path_list, same_prob=0.8):
        """
        Args:
            data_path_list: folders scanned with the pattern '*.*g'
                (matches common raster extensions such as jpg/jpeg/png).
            same_prob: probability that the target equals the source.
        """
        self.same_prob = same_prob
        self.datasets = []
        self.N = []  # per-folder image counts; sum(self.N) is the dataset size
        for root in data_path_list:
            files = glob.glob(f'{root}/*.*g')
            self.datasets.append(files)
            self.N.append(len(files))

        # Two identical augmentation pipelines that differ only in output size:
        # 224 for the arcface branch, 256 for the base branch.
        self.transforms_arcface, self.transforms_base = (
            transforms.Compose([
                transforms.ColorJitter(0.2, 0.2, 0.2, 0.01),
                transforms.Resize((side, side)),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
            for side in (224, 256)
        )

    def __getitem__(self, item):
        # Locate which folder the flat index falls into.
        bucket = 0
        while item >= self.N[bucket]:
            item -= self.N[bucket]
            bucket += 1
        source_path = self.datasets[bucket][item]

        # cv2 loads BGR; reversing the channel axis yields RGB for PIL.
        Xs = Image.fromarray(cv2.imread(source_path)[:, :, ::-1])

        if random.random() > self.same_prob:
            # Different-person branch: uniform folder, then uniform image.
            target_path = random.choice(
                self.datasets[random.randint(0, len(self.datasets) - 1)])
            Xt = Image.fromarray(cv2.imread(target_path)[:, :, ::-1])
            same_person = 0
        else:
            Xt = Xs.copy()
            same_person = 1

        return (self.transforms_arcface(Xs), self.transforms_base(Xs),
                self.transforms_base(Xt), same_person)

    def __len__(self):
        """Total number of images across all folders."""
        return sum(self.N)
class FaceEmbedVGG2(TensorDataset):
    """VGGFace2-style dataset laid out as <data_path>/<identity>/<image>.

    Each item is a 4-tuple:
        (arcface-sized source, base-sized source, base-sized target, same_person)
    where same_person is 1 for the "same" branch and 0 when the target was
    drawn at random from the whole image pool.
    """

    def __init__(self, data_path, same_prob=0.8, same_identity=False):
        """
        Args:
            data_path: root folder containing one sub-folder per identity.
            same_prob: probability of the "same" branch in __getitem__.
            same_identity: when True, the "same" branch samples a (possibly
                different) image of the same identity instead of reusing the
                source image itself.
        """
        self.same_prob = same_prob
        self.same_identity = same_identity

        # '*.*g' matches common raster extensions (jpg/jpeg/png).
        self.images_list = glob.glob(f'{data_path}/*/*.*g')
        self.folders_list = glob.glob(f'{data_path}/*')

        # Map identity folder -> list of its image paths; only consulted when
        # same_identity is enabled.
        self.folder2imgs = {}
        for folder in tqdm.tqdm(self.folders_list):
            self.folder2imgs[folder] = glob.glob(f'{folder}/*')

        self.N = len(self.images_list)

        self.transforms_arcface = transforms.Compose([
            transforms.ColorJitter(0.2, 0.2, 0.2, 0.01),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        self.transforms_base = transforms.Compose([
            transforms.ColorJitter(0.2, 0.2, 0.2, 0.01),
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def __getitem__(self, item):
        image_path = self.images_list[item]
        # cv2 loads BGR; reverse the channel axis to get RGB for PIL.
        Xs = cv2.imread(image_path)[:, :, ::-1]
        Xs = Image.fromarray(Xs)

        if self.same_identity:
            # Fix: os.path.dirname instead of splitting on '/', so the key
            # matches folder2imgs on any OS path separator (the previous
            # '/'.join(image_path.split('/')[:-1]) broke on Windows paths).
            folder_name = os.path.dirname(image_path)

        if random.random() > self.same_prob:
            # NOTE(review): the random target may coincidentally share the
            # source identity yet still be labeled same_person=0; kept as-is
            # to preserve the original sampling behavior.
            image_path = random.choice(self.images_list)
            Xt = cv2.imread(image_path)[:, :, ::-1]
            Xt = Image.fromarray(Xt)
            same_person = 0
        else:
            if self.same_identity:
                image_path = random.choice(self.folder2imgs[folder_name])
                Xt = cv2.imread(image_path)[:, :, ::-1]
                Xt = Image.fromarray(Xt)
            else:
                Xt = Xs.copy()
            same_person = 1

        return self.transforms_arcface(Xs), self.transforms_base(Xs), self.transforms_base(Xt), same_person

    def __len__(self):
        """Total number of images found under data_path."""
        return self.N