import os import pathlib import sys parent_path = pathlib.Path(__file__).absolute().parent.parent parent_path = os.path.abspath(parent_path) sys.path.append(parent_path) os.chdir(parent_path) print(f'>-------------> parent path {parent_path}') print(f'>-------------> current work dir {os.getcwd()}') import argparse import glob import multiprocessing from PIL import Image from os.path import join import torch from torch.utils.data import DataLoader from torchvision.datasets import VisionDataset from vqgan.load import encode_transform from generate.img_to_token import img_to_token CPU_COUNT = multiprocessing.cpu_count() class GoProDataset(VisionDataset): def __init__( self, root: str, target_root, transform=None, target_transform=None, transforms=None, transform_name=None ) -> None: super().__init__(root, transforms, transform, target_transform) self.target_root = target_root file_list = glob.glob(join(root, '*.png')) ids = [os.path.basename(i).split('.')[0] for i in file_list] self.ids = list(sorted(ids)) self.transform_name = transform_name def _load_image(self, id: int): path = join(self.root, f'{id}.png') return Image.open(path).convert("RGB") def _load_target(self, id: int): path = join(self.target_root, f'{id}.png') return Image.open(path).convert("RGB") def __getitem__(self, index: int): id = self.ids[index] image = self._load_image(id) target_img = self._load_target(id) images = self.transform(image) target_imgs = self.transform(target_img) data_list = [] if self.transform_name == 'six_crop_encode_transform': for _img, _target_img in zip(images, target_imgs): _data = torch.stack([_img, _target_img], dim=0) data_list.append(_data) else: _data = torch.stack([images, target_imgs], dim=0) data_list.append(_data) data = torch.cat(data_list, dim=0) return data def __len__(self) -> int: return len(self.ids) def convert_img_to_token(args, device=None): dataset = GoProDataset(args.input_data, args.target_data, transform=encode_transform, transform_name='encode_transform') data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_work) img_to_token(args, data_loader, args.output_path, device=device) def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--input_data", type=str, default="Rain13K_lmdb/input.lmdb") parser.add_argument("--target_data", type=str, default="Rain13K_lmdb/target.lmdb") parser.add_argument("--output_path", type=str, default="vq_token/Rain13K") parser.add_argument("--num_work", type=int, default=64) parser.add_argument("--batch_size", type=int, default=16) parser.add_argument("--dp_mode", action='store_true', default=False) parser.add_argument("--model_name_or_path", type=str, default="weight/vqgan-f16-8192-laion") args = parser.parse_args() return args if __name__ == '__main__': args = get_args() device = f'cuda:{0}' convert_img_to_token(args, device=device)