Spaces:
Sleeping
Sleeping
| import torch | |
| from torch.utils.data import Dataset, DataLoader | |
| from torchvision.io import read_image, ImageReadMode | |
| import numpy as np | |
def denorm_img(img: torch.Tensor) -> torch.Tensor:
    """Undo per-channel mean/std normalization and clamp the result to [0, 1].

    The statistics are the usual ImageNet normalization constants.

    Args:
        img: Normalized image tensor whose channel axis is third from the end
            (e.g. ``(3, H, W)`` or ``(B, 3, H, W)``); the statistics are
            reshaped to ``(3, 1, 1)`` so they broadcast over that axis.

    Returns:
        ``img * std + mean`` clipped to the range ``[0, 1]``.
    """
    # Create the statistics on the input's device so GPU tensors work. The
    # dtype is left at the default float type so integer/half inputs are
    # promoted exactly as they were with ``torch.Tensor([...])``.
    std = torch.tensor([0.229, 0.224, 0.225], device=img.device).reshape(-1, 1, 1)
    mean = torch.tensor([0.485, 0.456, 0.406], device=img.device).reshape(-1, 1, 1)
    return torch.clip(img * std + mean, min=0, max=1)
class StyleContentDataset(Dataset):
    """Dataset yielding ``(style, content)`` image pairs matched by index.

    Args:
        style_imgs: Indexable sequence of style image paths passed to
            ``read_image``.
        content_imgs: Indexable sequence of content image paths passed to
            ``read_image``.
        transform: Optional callable applied to each image independently,
            after ``normalize``.
        normalize: Optional callable applied to each image first.
    """

    def __init__(self, style_imgs, content_imgs, transform=None, normalize=None):
        self.style_imgs = style_imgs
        self.content_imgs = content_imgs
        self.transform = transform
        self.normalize = normalize

    def __len__(self):
        # Only indices that are valid in both lists can form a pair.
        return min(len(self.style_imgs), len(self.content_imgs))

    @staticmethod
    def _load_rgb(path):
        """Read ``path`` as a 3-channel float tensor with values scaled to [0, 1]."""
        return read_image(path, ImageReadMode.RGB).float() / 255.0

    def __getitem__(self, idx):
        """Return the ``(style, content)`` pair at ``idx``.

        If loading either image raises ``RuntimeError``, a warning naming both
        paths is emitted and the pair at index 0 is returned instead; a
        failure while loading that fallback pair propagates.
        """
        try:
            style = self._load_rgb(self.style_imgs[idx])
            content = self._load_rgb(self.content_imgs[idx])
        except RuntimeError as err:
            import warnings

            warnings.warn(
                f"Failed to read image pair at index {idx} "
                f"(style={self.style_imgs[idx]!r}, "
                f"content={self.content_imgs[idx]!r}): {err}; "
                "falling back to index 0"
            )
            style = self._load_rgb(self.style_imgs[0])
            content = self._load_rgb(self.content_imgs[0])
        # Explicit None checks: a callable that happens to be falsy (e.g. an
        # empty container module) must still be applied.
        if self.normalize is not None:
            style = self.normalize(style)
            content = self.normalize(content)
        if self.transform is not None:
            style = self.transform(style)
            content = self.transform(content)
        return style, content
class DataStore:
    """Endless batch source that restarts its ``DataLoader`` when exhausted.

    Args:
        dataset: Dataset whose items unpack into ``(style, content)``
            (e.g. a ``StyleContentDataset``).
        batch_size: Number of pairs per batch.
        shuffle: Forwarded to ``DataLoader``.
        num_workers: Worker processes used by the ``DataLoader``; defaults to
            2, the value that was previously hard-coded.
    """

    def __init__(self, dataset: "StyleContentDataset", batch_size, shuffle=False, num_workers=2):
        self.dataset = dataset
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
        )
        self.iterator = iter(self.dataloader)

    def get(self):
        """Return the next ``(style, content)`` batch, starting a new pass if needed.

        Raises:
            StopIteration: if a freshly started pass yields no batches
                (e.g. the dataset is empty).
        """
        try:
            style, content = next(self.iterator)
        except StopIteration:
            # Current pass is exhausted: start a new one over the same loader.
            self.iterator = iter(self.dataloader)
            style, content = next(self.iterator)
        return style, content