Spaces:
Sleeping
Sleeping
| import cv2 | |
| import numpy as np | |
| import torchvision.datasets as datasets | |
| import torchvision.transforms as transforms | |
| import torchvision.transforms.functional as TF | |
| from random import random, choice | |
| from io import BytesIO | |
| from PIL import Image | |
| from PIL import ImageFile | |
| from scipy.ndimage.filters import gaussian_filter | |
| from torchvision.transforms import InterpolationMode | |
| from typing import Any, Callable, cast, Dict, List, Optional, Tuple | |
| import os | |
| from transformers import AutoTokenizer | |
| ImageFile.LOAD_TRUNCATED_IMAGES = True | |
| IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") | |
| def pil_loader(path: str) -> Image.Image: | |
| # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) | |
| with open(path, "rb") as f: | |
| img = Image.open(f) | |
| return img.convert("RGB") | |
| class ImageFolder2(datasets.DatasetFolder): | |
| def __init__( | |
| self, | |
| root: str, | |
| opt, | |
| transform: Optional[Callable] = None, | |
| ): | |
| super().__init__( | |
| root, | |
| transform=transform, | |
| extensions=IMG_EXTENSIONS, | |
| loader = pil_loader | |
| ) | |
| self.opt = opt | |
| self.tokenizer = AutoTokenizer.from_pretrained(self.opt.clip, model_max_length=77, padding_side="right", use_fast=False) | |
| self.tokenizer.pad_token_id = 0 | |
| def __getitem__(self, index: int) -> Tuple[Any, Any]: | |
| """ | |
| Args: | |
| index (int): Index | |
| Returns: | |
| tuple: (sample, target) where target is class_index of the target class. | |
| """ | |
| path, target = self.samples[index] | |
| textpath = path.replace(self.opt.imgroot, self.opt.textroot) | |
| textpath = os.path.splitext(textpath)[0] + '.txt' | |
| sample = self.loader(path) | |
| try: | |
| with open(textpath, 'r') as file: | |
| text = file.read() | |
| cates_len = len(self.opt.cates)//2 | |
| if target == 1: text = f'{" ".join(self.opt.cates[:cates_len])}. {text} {" ".join(self.opt.cates[:cates_len])}.' | |
| if target == 0: text = f'{" ".join(self.opt.cates[cates_len:])}. {text} {" ".join(self.opt.cates[cates_len:])}.' | |
| inputs = self.tokenizer([text], padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt") | |
| input_ids=inputs['input_ids'][0] | |
| attention_mask=inputs['attention_mask'][0] | |
| except: | |
| text, input_ids, attention_mask = ' ', ' ', ' ' | |
| if self.transform is not None: | |
| sample = self.transform(sample) | |
| if self.target_transform is not None: | |
| target = self.target_transform(target) | |
| return path, sample, text, input_ids, attention_mask, target | |
| def dataset_folder(opt, root): | |
| if opt.mode == 'binary': | |
| return binary_dataset(opt, root) | |
| if opt.mode == 'filename': | |
| return FileNameDataset(opt, root) | |
| raise ValueError('opt.mode needs to be binary or filename.') | |
| def binary_dataset(opt, root): | |
| if opt.isTrain: | |
| crop_func = transforms.RandomCrop(opt.cropSize) | |
| elif opt.no_crop: | |
| crop_func = transforms.Lambda(lambda img: img) | |
| else: | |
| crop_func = transforms.CenterCrop(opt.cropSize) | |
| if opt.isTrain and not opt.no_flip: | |
| flip_func = transforms.RandomHorizontalFlip() | |
| else: | |
| flip_func = transforms.Lambda(lambda img: img) | |
| if not opt.isTrain and opt.no_resize: | |
| rz_func = transforms.Lambda(lambda img: img) | |
| else: | |
| rz_func = transforms.Lambda(lambda img: translate_duplicate(img, opt.cropSize)) | |
| dset = ImageFolder2( | |
| root, | |
| opt, | |
| transforms.Compose([ | |
| rz_func, | |
| crop_func, | |
| flip_func, | |
| transforms.ToTensor(), | |
| transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]), | |
| ])) | |
| return dset | |
| class FileNameDataset(datasets.ImageFolder): | |
| def name(self): | |
| return 'FileNameDataset' | |
| def __init__(self, opt, root): | |
| self.opt = opt | |
| super().__init__(root) | |
| def __getitem__(self, index): | |
| # Loading sample | |
| path, target = self.samples[index] | |
| return path | |
| import math | |
| def translate_duplicate(img, cropSize): | |
| if min(img.size) < cropSize: | |
| width, height = img.size | |
| new_width = width * math.ceil(cropSize/width) | |
| new_height = height * math.ceil(cropSize/height) | |
| new_img = Image.new('RGB', (new_width, new_height)) | |
| for i in range(0, new_width, width): | |
| for j in range(0, new_height, height): | |
| new_img.paste(img, (i, j)) | |
| return new_img | |
| else: | |
| return img | |
| def data_augment(img, opt): | |
| img = np.array(img) | |
| if random() < opt.blur_prob: | |
| sig = sample_continuous(opt.blur_sig) | |
| gaussian_blur(img, sig) | |
| if random() < opt.jpg_prob: | |
| method = sample_discrete(opt.jpg_method) | |
| qual = sample_discrete(opt.jpg_qual) | |
| img = jpeg_from_key(img, qual, method) | |
| return Image.fromarray(img) | |
| def sample_continuous(s): | |
| if len(s) == 1: | |
| return s[0] | |
| if len(s) == 2: | |
| rg = s[1] - s[0] | |
| return random() * rg + s[0] | |
| raise ValueError("Length of iterable s should be 1 or 2.") | |
| def sample_discrete(s): | |
| if len(s) == 1: | |
| return s[0] | |
| return choice(s) | |
| def gaussian_blur(img, sigma): | |
| gaussian_filter(img[:,:,0], output=img[:,:,0], sigma=sigma) | |
| gaussian_filter(img[:,:,1], output=img[:,:,1], sigma=sigma) | |
| gaussian_filter(img[:,:,2], output=img[:,:,2], sigma=sigma) | |
| def cv2_jpg(img, compress_val): | |
| img_cv2 = img[:,:,::-1] | |
| encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), compress_val] | |
| result, encimg = cv2.imencode('.jpg', img_cv2, encode_param) | |
| decimg = cv2.imdecode(encimg, 1) | |
| return decimg[:,:,::-1] | |
| def pil_jpg(img, compress_val): | |
| out = BytesIO() | |
| img = Image.fromarray(img) | |
| img.save(out, format='jpeg', quality=compress_val) | |
| img = Image.open(out) | |
| # load from memory before ByteIO closes | |
| img = np.array(img) | |
| out.close() | |
| return img | |
| jpeg_dict = {'cv2': cv2_jpg, 'pil': pil_jpg} | |
| def jpeg_from_key(img, compress_val, key): | |
| method = jpeg_dict[key] | |
| return method(img, compress_val) | |
| rz_dict = {'bilinear': InterpolationMode.BILINEAR, | |
| 'bicubic': InterpolationMode.BICUBIC, | |
| 'lanczos': InterpolationMode.LANCZOS, | |
| 'nearest': InterpolationMode.NEAREST} | |
| def custom_resize(img, opt): | |
| interp = sample_discrete(opt.rz_interp) | |
| return TF.resize(img, opt.loadSize, interpolation=rz_dict[interp]) | |