| from torch.utils.data import Dataset |
| import torch |
| import torchvision.transforms.functional as TF |
| from torchvision.transforms import Compose, Resize, CenterCrop |
| from torchvision.io import decode_jpeg, encode_jpeg |
| from glob import glob |
| import os.path as osp |
| from PIL import Image |
| import random |
| import os |
# NOTE(review): not referenced by any code in this file. The name suggests a
# target compression setting (decode_jpeg/encode_jpeg are imported above but
# also unused here) — confirm its purpose before relying on it.
TARGET_COMP: float = 0.1
|
|
class TMDistilDireDataset(Dataset):
    """Real/fake image dataset with optional precomputed DIRE maps and eps tensors.

    Layout read under ``root``::

        root/images/fakes/<name>.<ext>        fake images (isfake=True)
        root/images/reals/<name>.<ext>        real images (isfake=False)
        root/dire/{fakes,reals}/<name>.<ext>  DIRE image with the same file name
        root/eps/{fakes,reals}/<name>.pt      eps tensor, loaded with torch.load

    ``<ext>`` is jpg/jpeg/png/webp (case-insensitive). The ``dire``/``eps``
    locations are only used when ``prepared_dire`` is true.

    Items are ``((img, dire, eps, isfake), (img_path, dire_path, eps_path))``.
    ``img`` (and ``dire`` when prepared) is converted to RGB, mapped from
    [0, 1] to [-1, 1], resized so its shorter side is 256 and center-cropped to
    256x256. ``eps`` is returned as loaded; its trailing dims must match
    ``img``'s. Without prepared data, ``dire``/``eps`` are zero tensors shaped
    like ``img`` and both path fields are empty strings.

    Ordering: all fakes, then all reals, each group shuffled once at
    construction. Images that PIL cannot open are dropped.

    Args:
        root: dataset root directory.
        prepared_dire: whether DIRE/eps files exist and should be loaded.
    """

    # Accepted image extensions (compared lower-cased).
    _IMG_EXTS = ('jpg', 'jpeg', 'png', 'webp')

    def __init__(self, root, prepared_dire=True):
        self.root = root
        self.prepared_dire = prepared_dire
        # antialias must be a bool: the string 'True' only worked because any
        # non-empty string (including 'False') is truthy.
        self.transform = Compose([Resize(256, antialias=True), CenterCrop((256, 256))])

        self.__fake_img_paths = self._list_images('fakes')
        self.__real_img_paths = self._list_images('reals')

        self.fake_paths = [self._make_entry(p, 'fakes', True) for p in self.__fake_img_paths]
        self.real_paths = [self._make_entry(p, 'reals', False) for p in self.__real_img_paths]
        random.shuffle(self.fake_paths)
        random.shuffle(self.real_paths)

        # Keep only entries whose image file PIL can identify.
        self.img_paths = [
            entry for entry in self.fake_paths + self.real_paths
            if self._is_openable_image(entry[0])
        ]

    def _list_images(self, split):
        """Return files in ``root/images/<split>/`` with an accepted image extension."""
        pattern = osp.join(self.root, f'images/{split}/', '*')
        return [p for p in glob(pattern) if p.split('.')[-1].lower() in self._IMG_EXTS]

    def _make_entry(self, img_path, split, isfake):
        """Build the ``(img_path, dire_path, eps_path, isfake)`` tuple for one image."""
        if not self.prepared_dire:
            return (img_path, "", "", isfake)
        dire_path, eps_path = self._companion_paths(img_path, split)
        return (img_path, dire_path, eps_path, isfake)

    def _companion_paths(self, img_path, split):
        """Return ``(dire_path, eps_path)`` for an image in ``root/images/<split>/``.

        Paths are rebuilt from ``root`` and the image's file name, not by string
        replacement. A ``root`` that itself contains ``/images/`` or a dot
        (e.g. ``./data``) therefore no longer corrupts them. Only the final
        extension is swapped for ``.pt``.
        """
        name = osp.basename(img_path)
        dire_path = osp.join(self.root, 'dire', split, name)
        eps_path = osp.join(self.root, 'eps', split, osp.splitext(name)[0] + '.pt')
        return dire_path, eps_path

    @staticmethod
    def _is_openable_image(path):
        """Return True if PIL can identify ``path`` as an image."""
        try:
            # Image.open is lazy and leaves the file open; ``with`` closes it.
            with Image.open(path):
                return True
        except Exception:
            # Best-effort filter: unreadable files are skipped. Unlike a bare
            # ``except:``, this does not swallow KeyboardInterrupt/SystemExit.
            return False

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path, dire_path, eps_path, isfake = self.img_paths[idx]
        img = TF.to_tensor(Image.open(img_path).convert('RGB')) * 2 - 1
        img = self.transform(img)

        if self.prepared_dire:
            dire = TF.to_tensor(Image.open(dire_path).convert('RGB')) * 2 - 1
            dire = self.transform(dire)
            eps = torch.load(eps_path, weights_only=True, mmap=True)
            assert img.shape[1:] == dire.shape[1:] == eps.shape[1:], f"Shape mismatch: {img.shape[1:]}, {dire.shape[1:]}, {eps.shape[1:]}"
        else:
            # No precomputed maps: zero placeholders with the image's shape.
            dire = torch.zeros_like(img)
            eps = torch.zeros_like(img)

        return (img, dire, eps, isfake), (img_path, dire_path, eps_path)
|
|
|
|
|
|
class TMIMGOnlyDataset(TMDistilDireDataset):
    """Image-only variant of ``TMDistilDireDataset``.

    ``dire`` and ``eps`` in each item are zero tensors shaped like the image.
    When ``istrain`` is true, the image is horizontally flipped with
    probability 0.3. The item structure matches the parent:
    ``((img, dire, eps, isfake), (img_path, dire_path, eps_path))``.
    """

    def __init__(self, root, istrain=True):
        super().__init__(root, prepared_dire=True)
        self.istrain = istrain

    def __getitem__(self, idx):
        entry = self.img_paths[idx]
        img_path, dire_path, eps_path, isfake = entry

        # NOTE(review): resize/crop runs on the PIL image *before* to_tensor
        # here, whereas the parent class resizes the tensor afterwards. The two
        # resize paths are not guaranteed to produce identical pixels.
        resized = self.transform(Image.open(img_path).convert('RGB'))
        img = TF.to_tensor(resized) * 2 - 1
        eps = torch.zeros_like(img)
        dire = torch.zeros_like(img)

        # The random draw happens on every call, whether training or not.
        flip = bool(torch.rand(1) < 0.3) and self.istrain
        if flip:
            img = TF.hflip(img)

        paths = (img_path, dire_path, eps_path)
        return (img, dire, eps, isfake), paths
| |
| |
|
|
|
|
class TMEPSOnlyDataset(TMDistilDireDataset):
    """Image + precomputed-eps variant of ``TMDistilDireDataset``.

    At construction, entries whose image or eps file does not exist, or whose
    eps file fails to load with ``torch.load``, are dropped. ``dire`` in each
    item is a zero tensor shaped like the image. When ``istrain`` is true,
    image and eps are flipped together horizontally with probability 0.3.
    """

    def __init__(self, root, istrain=True):
        super().__init__(root, prepared_dire=True)
        self.img_paths = [entry for entry in self.img_paths if self._has_loadable_eps(entry)]
        self.istrain = istrain

    @staticmethod
    def _has_loadable_eps(entry):
        """Return True if the entry's image exists and its eps file loads."""
        img_path, _dire_path, eps_path, _isfake = entry
        if not osp.exists(eps_path) or not osp.exists(img_path):
            return False
        try:
            # Loaded only to validate the file; the tensor is discarded.
            torch.load(eps_path, weights_only=True, mmap=True)
        except Exception as e:
            # Best-effort: report which file was skipped instead of a bare
            # error message with no path.
            print(f"Skipping unreadable eps file {eps_path}: {e}")
            return False
        return True

    def __getitem__(self, idx):
        img_path, dire_path, eps_path, isfake = self.img_paths[idx]
        img = TF.to_tensor(Image.open(img_path).convert('RGB')) * 2 - 1
        img = self.transform(img)
        eps = torch.load(eps_path, weights_only=True, mmap=True)
        dire = torch.zeros_like(img)

        # torch.rand is drawn on every call, even when not training.
        if torch.rand(1) < 0.3 and self.istrain:
            # Flip image and eps together so they stay spatially aligned.
            img = TF.hflip(img)
            eps = TF.hflip(eps)

        return (img, dire, eps, isfake), (img_path, dire_path, eps_path)
| |
| |
| |
class TMDireDataset(TMDistilDireDataset):
    """DIRE-map-only variant of ``TMDistilDireDataset``.

    Items are ``((dire, isfake), (dire_path,))``. ``dire`` is the stored DIRE
    image converted to RGB and mapped from [0, 1] to [-1, 1]. Unlike the
    parent class, no resize or center-crop is applied here.
    """

    def __init__(self, root):
        super().__init__(root, prepared_dire=True)

    def __getitem__(self, idx):
        record = self.img_paths[idx]
        dire_path, isfake = record[1], record[3]

        rgb = Image.open(dire_path).convert('RGB')
        dire = TF.to_tensor(rgb) * 2 - 1
        return (dire, isfake), (dire_path,)
| |