Spaces:
Sleeping
Sleeping
File size: 3,377 Bytes
37163a6 2279ae0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import random
from PIL import Image
import torch
from torch.utils.data import Dataset, IterableDataset
from torchvision import transforms
import datasets
from datasets import register
from utils.geometry import make_coord_scale_grid
from models.ldm.dac.audiotools import AudioSignal
import numpy as np
from models.ldm.dac.audiotools.data.datasets import AudioDataset, AudioLoader
from models.ldm.dac.audiotools import transforms as tfm
class BaseWrapperCAE:
    """Shared sample-preparation logic for the CAE dataset wrappers.

    Builds the wrapped dataset through ``datasets.make`` and converts each
    square image into a dict holding a resized input tensor and, optionally,
    a randomly cropped ground-truth patch concatenated with the coordinate and
    scale grids returned by ``make_coord_scale_grid``.
    """

    def __init__(
        self,
        dataset,
        resize_inp,
        return_gt=True,
        gt_glores_lb=None,
        gt_glores_ub=None,
        gt_patch_size=None,
        p_whole=0.0,
        p_max=0.0,
    ):
        """Build the wrapped dataset and store the sampling configuration.

        Args:
            dataset: Spec passed to ``datasets.make`` to build the wrapped
                dataset.
            resize_inp: Side length the input image is resized to.
            return_gt: When false, ``process`` returns only the ``'inp'``
                entry.
            gt_glores_lb: Lower bound of the randomly drawn ground-truth
                resolution. ``None`` means the image is used at its own
                resolution.
            gt_glores_ub: Upper bound of the ground-truth resolution (also
                capped by the image side length).
            gt_patch_size: Side length of the cropped ground-truth patch.
            p_whole: Probability of resizing the image to ``gt_patch_size``,
                so the crop covers the whole image.
            p_max: Probability, tested only when the ``p_whole`` draw fails,
                of using the largest allowed resolution.
        """
        self.dataset = datasets.make(dataset)

        # Input-side settings.
        self.resize_inp = resize_inp
        self.return_gt = return_gt

        # Ground-truth sampling settings.
        self.gt_patch_size = gt_patch_size
        self.gt_glores_lb = gt_glores_lb
        self.gt_glores_ub = gt_glores_ub
        self.p_whole = p_whole
        self.p_max = p_max

        # Image -> tensor, then normalization with mean 0.5 and std 0.5.
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(0.5, 0.5)
        self.transform = transforms.Compose([to_tensor, normalize])

    def process(self, image):
        """Turn one square image into a sample dict.

        Args:
            image: Image exposing ``size`` and ``resize`` (PIL-style); its
                width and height must be equal.

        Returns:
            A dict with ``'inp'`` (the resized, transformed input) and, when
            ``return_gt`` is set, ``'gt'``: the cropped patch concatenated
            along dim 0 with the permuted coordinate and scale grids.
        """
        assert image.size[0] == image.size[1]

        inp_side = self.resize_inp
        resized_inp = image.resize((inp_side, inp_side), Image.LANCZOS)
        sample = {'inp': self.transform(resized_inp)}
        if not self.return_gt:
            return sample

        if self.gt_glores_lb is not None:
            # Pick the resolution the crop is taken from. The second
            # random draw only happens when the first one fails.
            if random.random() < self.p_whole:
                res = self.gt_patch_size
            elif random.random() < self.p_max:
                res = min(image.size[0], self.gt_glores_ub)
            else:
                res = random.randint(
                    self.gt_glores_lb,
                    max(self.gt_glores_lb, min(image.size[0], self.gt_glores_ub)),
                )
            glo = self.transform(image.resize((res, res), Image.LANCZOS))
        else:
            glo = self.transform(image)

        # Random crop position: row offset is drawn first, then column offset.
        patch = self.gt_patch_size
        top = random.randint(0, glo.shape[1] - patch)
        left = random.randint(0, glo.shape[2] - patch)
        bottom, right = top + patch, left + patch
        gt_patch = glo[:, top:bottom, left:right]

        # Crop bounds expressed as fractions of the full ground-truth extent.
        height, width = glo.shape[-2:]
        row_range = [top / height, bottom / height]
        col_range = [left / width, right / width]
        coord, scale = make_coord_scale_grid(
            (patch, patch), range=[row_range, col_range]
        )

        # Original annotations give the stacked layout as 3 + 2 + 2 channels,
        # each p x p.
        stacked = [gt_patch, coord.permute(2, 0, 1), scale.permute(2, 0, 1)]
        sample['gt'] = torch.cat(stacked, dim=0)
        return sample
@register('wrapper_cae')
class WrapperCAE(BaseWrapperCAE, Dataset):
    """Map-style dataset that applies ``process`` to items of the wrapped dataset."""

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the processed sample at ``idx``.

        If the wrapped item is a dict, its ``'image'`` entry is processed and
        the remaining entries are merged on top of the processed result (so
        they win on key collisions). Any other item is treated as the image
        itself.

        Raises:
            KeyError: If a dict item has no ``'image'`` key.
        """
        data = self.dataset[idx]
        if not isinstance(data, dict):
            return self.process(data)

        # Pop from a shallow copy: removing 'image' from the dict owned by the
        # wrapped dataset would break any later access to the same item when
        # that dataset hands out the same dict object again.
        extras = dict(data)
        ret = dict(self.process(extras.pop('image')))
        ret.update(extras)
        return ret
@register('wrapper_cae_iterable')
class WrapperCAE(BaseWrapperCAE, IterableDataset):
    """Iterable-style dataset that applies ``process`` to a wrapped stream.

    NOTE(review): this class reuses the name ``WrapperCAE`` already used by the
    class decorated with ``@register('wrapper_cae')`` earlier in the file, so
    the module-level name ends up bound to this class. Presumably both stay
    reachable through their registry keys; confirm against ``register``.

    NOTE(review): no per-worker sharding happens here. With a multi-worker
    DataLoader each worker would iterate the full wrapped dataset unless that
    dataset shards itself; confirm.
    """

    def __iter__(self):
        """Yield one processed sample per item of the wrapped dataset.

        If an item is a dict, its ``'image'`` entry is processed and the
        remaining entries are merged on top of the processed result (so they
        win on key collisions). Any other item is treated as the image itself.

        Raises:
            KeyError: If a dict item has no ``'image'`` key.
        """
        for data in self.dataset:
            if not isinstance(data, dict):
                yield self.process(data)
                continue

            # Pop from a shallow copy: removing 'image' from the dict owned by
            # the wrapped dataset would break a later pass over the same items
            # when that dataset hands out the same dict objects again.
            extras = dict(data)
            ret = dict(self.process(extras.pop('image')))
            ret.update(extras)
            yield ret