import random
from PIL import Image

import torch
from torch.utils.data import Dataset, IterableDataset
from torchvision import transforms

import datasets
from datasets import register
from utils.geometry import make_coord_scale_grid


class BaseWrapperCAE:
    """Dataset wrapper for CAE training: produces a fixed-size input view
    ('inp') and, optionally, a ground-truth patch ('gt') cropped from a
    randomly re-scaled copy of the image, concatenated with its coordinate
    and scale grids."""

    def __init__(
        self,
        dataset,
        resize_inp,
        return_gt=True,
        gt_glores_lb=None,
        gt_glores_ub=None,
        gt_patch_size=None,
        p_whole=0.0,
        p_max=0.0
    ):
        self.dataset = datasets.make(dataset)
        self.resize_inp = resize_inp
        self.return_gt = return_gt
        self.gt_glores_lb = gt_glores_lb
        self.gt_glores_ub = gt_glores_ub
        self.gt_patch_size = gt_patch_size
        self.p_whole = p_whole
        self.p_max = p_max
        # PIL image -> float tensor scaled to [-1, 1].
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(0.5, 0.5),
        ])

    def process(self, image):
        # Only square images are supported.
        assert image.size[0] == image.size[1]
        ret = {}

        # Fixed-resolution input view for the model.
        inp = image.resize((self.resize_inp, self.resize_inp), Image.LANCZOS)
        inp = self.transform(inp)
        ret.update({'inp': inp})
        if not self.return_gt:
            return ret

        if self.gt_glores_lb is None:
            # No resolution range given: supervise at native resolution.
            glo = self.transform(image)
        else:
            # Sample a global resolution r for the ground truth: with
            # probability p_whole the patch covers the whole image
            # (r equals the patch size); otherwise, with probability p_max,
            # use the largest allowed resolution; else draw r uniformly
            # from [gt_glores_lb, min(image size, gt_glores_ub)].
            if random.random() < self.p_whole:
                r = self.gt_patch_size
            elif random.random() < self.p_max:
                r = min(image.size[0], self.gt_glores_ub)
            else:
                r = random.randint(
                    self.gt_glores_lb,
                    max(self.gt_glores_lb, min(image.size[0], self.gt_glores_ub))
                )
            glo = image.resize((r, r), Image.LANCZOS)
            glo = self.transform(glo)

        # Random p x p crop from the (rescaled) ground-truth image.
        p = self.gt_patch_size
        ii = random.randint(0, glo.shape[1] - p)
        jj = random.randint(0, glo.shape[2] - p)
        gt_patch = glo[:, ii: ii + p, jj: jj + p]

        # Crop extent normalized to [0, 1] along each axis; build per-pixel
        # coordinate and scale grids over that extent.
        x0, y0 = ii / glo.shape[-2], jj / glo.shape[-1]
        x1, y1 = (ii + p) / glo.shape[-2], (jj + p) / glo.shape[-1]
        coord, scale = make_coord_scale_grid((p, p), range=[[x0, x1], [y0, y1]])
        ret['gt'] = torch.cat([
            gt_patch,               # 3 x p x p
            coord.permute(2, 0, 1), # 2 x p x p
            scale.permute(2, 0, 1), # 2 x p x p
        ], dim=0)

        return ret


@register('wrapper_cae')
class WrapperCAE(BaseWrapperCAE, Dataset):
    
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        data = self.dataset[idx]
        if isinstance(data, dict):
            # Process the image field, pass any extra fields through.
            ret = dict()
            ret.update(self.process(data.pop('image')))
            ret.update(data)
            return ret
        else:
            return self.process(data)


@register('wrapper_cae_iterable')
class WrapperCAEIterable(BaseWrapperCAE, IterableDataset):

    def __iter__(self):
        for data in self.dataset:
            if isinstance(data, dict):
                ret = dict()
                ret.update(self.process(data.pop('image')))
                ret.update(data)
                yield ret
            else:
                yield self.process(data)
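

# Usage sketch (illustrative, not part of the original file). The inner
# dataset spec below ('image_folder', 'root_path') is a hypothetical example
# of whatever spec datasets.make() accepts in this repository:
#
#   wrapper = WrapperCAE(
#       dataset={'name': 'image_folder',
#                'args': {'root_path': '/path/to/images'}},
#       resize_inp=256,
#       gt_glores_lb=256,
#       gt_glores_ub=1024,
#       gt_patch_size=256,
#   )
#   sample = wrapper[0]
#   sample['inp']  # 3 x 256 x 256 tensor in [-1, 1]
#   sample['gt']   # 7 x 256 x 256: RGB patch + coord (2) + scale (2) grids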