File size: 3,534 Bytes
ee3e701
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
import pathlib
import sys

# Resolve the parent of the directory containing this script (the project
# root, presumably) and make it both importable and the working directory.
# This must run before the project-local imports further down (vqgan.*,
# generate.*), which presumably live under that root, and it makes the
# relative default paths in get_args() resolve against it.
parent_path = os.path.abspath(pathlib.Path(__file__).absolute().parent.parent)
sys.path.append(parent_path)
os.chdir(parent_path)

print(f'>-------------> parent path {parent_path}')
print(f'>-------------> current work dir {os.getcwd()}')

import argparse
import glob
import multiprocessing
from PIL import Image
from os.path import join

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import VisionDataset

from vqgan.load import encode_transform
from generate.img_to_token import img_to_token

# Number of CPUs reported by multiprocessing. Not referenced by the code
# visible in this file; the DataLoader worker count comes from --num_work.
CPU_COUNT: int = multiprocessing.cpu_count()


class GoProDataset(VisionDataset):
    """Paired image dataset matching ``root/<name>.png`` with ``target_root/<name>.png``.

    ``__getitem__`` loads both images as RGB, transforms them, and stacks
    input and target along a new leading dimension. A plain transform
    therefore yields a tensor of shape ``(2, *image_shape)``.

    When ``transform_name == 'six_crop_encode_transform'``, the transform is
    expected to return a sequence of crops. Crops are paired up and
    concatenated, giving ``(2 * n_crops, *crop_shape)`` ordered
    input, target, input, target, ...

    Args:
        root: Directory containing the input ``*.png`` images.
        target_root: Directory containing the target images, using the
            same file names as ``root``.
        transform: Callable applied to the input image. It is also applied
            to the target unless ``target_transform`` is given. Its output
            must be tensors (or a sequence of tensors for the six-crop
            variant) because the results are passed to ``torch.stack``.
        target_transform: Optional callable for the target image.
        transforms: Forwarded to ``VisionDataset``. Not used by
            ``__getitem__``.
        transform_name: Selects how the transform output is combined
            (see above).
    """

    def __init__(
            self,
            root: str,
            target_root: str,
            transform=None,
            target_transform=None,
            transforms=None,
            transform_name=None,
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)
        # VisionDataset expands '~' in root. Do the same for target_root so
        # both directories are resolved consistently.
        self.target_root = os.path.expanduser(target_root)

        # Glob self.root (already '~'-expanded by VisionDataset) and escape it,
        # so glob metacharacters such as '[' in the directory name are
        # matched literally.
        pattern = join(glob.escape(os.fspath(self.root)), '*.png')
        # splitext strips only the final '.png'. Names with extra dots
        # (e.g. 'frame.001.png') therefore round-trip through _load_image.
        self.ids = sorted(
            os.path.splitext(os.path.basename(path))[0]
            for path in glob.glob(pattern)
        )

        self.transform_name = transform_name

    @staticmethod
    def _open_rgb(directory, image_id: str):
        """Open ``<directory>/<image_id>.png`` and return it as an RGB image."""
        # convert() returns an independent in-memory image, so the file
        # handle can be closed as soon as the conversion is done.
        with Image.open(join(directory, f'{image_id}.png')) as img:
            return img.convert("RGB")

    def _load_image(self, image_id: str):
        """Load the input image with the given id from ``self.root``."""
        return self._open_rgb(self.root, image_id)

    def _load_target(self, image_id: str):
        """Load the target image with the given id from ``self.target_root``."""
        return self._open_rgb(self.target_root, image_id)

    def __getitem__(self, index: int):
        """Return the stacked (input, target) tensor(s) for sample ``index``."""
        image_id = self.ids[index]
        image = self._load_image(image_id)
        target_img = self._load_target(image_id)

        images = self.transform(image)
        # Fall back to the input transform, so both images get identical
        # preprocessing unless a dedicated target transform was supplied.
        if self.target_transform is not None:
            target_fn = self.target_transform
        else:
            target_fn = self.transform
        target_imgs = target_fn(target_img)

        if self.transform_name == 'six_crop_encode_transform':
            # One (input, target) pair per crop, interleaved along dim 0.
            data_list = [
                torch.stack([crop, target_crop], dim=0)
                for crop, target_crop in zip(images, target_imgs)
            ]
        else:
            data_list = [torch.stack([images, target_imgs], dim=0)]

        return torch.cat(data_list, dim=0)

    def __len__(self) -> int:
        return len(self.ids)


def convert_img_to_token(args, device=None):
    """Tokenize the paired input/target images described by ``args``.

    Builds a GoProDataset over ``args.input_data`` / ``args.target_data`` using
    ``encode_transform``. The dataset is wrapped in a shuffling DataLoader
    sized by ``args.batch_size`` and ``args.num_work``. The loader is then
    passed to ``img_to_token`` together with ``args.output_path`` and
    ``device``.
    """
    loader = DataLoader(
        GoProDataset(
            args.input_data,
            args.target_data,
            transform=encode_transform,
            transform_name='encode_transform',
        ),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_work,
    )
    img_to_token(args, loader, args.output_path, device=device)


def get_args(argv=None):
    """Parse the command-line options for image-to-token conversion.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which matches the original
            behavior. Pass a list to drive the parser programmatically.

    Returns:
        argparse.Namespace with input_data, target_data, output_path,
        num_work, batch_size, dp_mode and model_name_or_path.
    """
    parser = argparse.ArgumentParser()

    # NOTE(review): GoProDataset globs '*.png' inside these directories, so
    # the '.lmdb' defaults only work if those paths are directories of PNG
    # files. Confirm the intended defaults.
    parser.add_argument("--input_data", type=str, default="Rain13K_lmdb/input.lmdb")
    parser.add_argument("--target_data", type=str, default="Rain13K_lmdb/target.lmdb")
    parser.add_argument("--output_path", type=str, default="vq_token/Rain13K")

    # DataLoader worker processes; 64 may exceed the machine's CPU count.
    parser.add_argument("--num_work", type=int, default=64)
    parser.add_argument("--batch_size", type=int, default=16)
    # store_true already defaults to False.
    parser.add_argument("--dp_mode", action='store_true')
    parser.add_argument("--model_name_or_path", type=str, default="weight/vqgan-f16-8192-laion")

    return parser.parse_args(argv)


if __name__ == '__main__':
    # The device is fixed to the first CUDA GPU; there is no CPU fallback here.
    device = 'cuda:0'
    args = get_args()
    convert_img_to_token(args, device=device)