| import os | |
| import random | |
| import copy | |
| from PIL import Image | |
| import numpy as np | |
| import json | |
| from torch.utils.data import Dataset | |
| from torchvision.transforms import ToPILImage, Compose, RandomCrop, ToTensor | |
| from utils.image_utils import random_augmentation, crop_img | |
| from utils.degradation_utils import Degradation | |
class DerainDehazeDataset(Dataset):
    """Single-sample dataset wrapping one in-memory image and a text prompt.

    The wrapped image is returned as both the "degraded" input and the
    "clean" target: no separate ground-truth image is supplied, so the clean
    slot is filled from the same pixels. (Presumably this lets code that
    expects a paired (degraded, clean) sample run on a single image — verify
    against callers.)
    """

    def __init__(self, args, img, text_prompt, task="derain"):
        """Store the image and prompt to serve.

        Args:
            args: Options object; stored on ``self.args`` but not read here.
            img: Image to serve. Must provide ``.convert('RGB')`` — presumably
                a ``PIL.Image.Image``; confirm with callers.
            text_prompt: Prompt returned unchanged with the sample.
            task: Task label; kept on ``self.task`` for callers, not used by
                this class itself.
        """
        super().__init__()
        self.args = args
        self.task = task
        self.toTensor = ToTensor()
        self.img = img
        self.text_prompt = text_prompt

    def __getitem__(self, idx):
        """Return the single sample.

        Returns:
            A 5-tuple ``([[""]], "", degraded_img, clean_img, text_prompt)``:
            a nested placeholder name, an empty degradation label, the image
            converted to RGB, passed through ``crop_img(..., base=16)`` and
            ``ToTensor`` (once per slot), and the stored prompt.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
                Without this, plain ``for sample in dataset`` iteration (which
                probes ``__getitem__`` with 0, 1, 2, ... until IndexError)
                would never terminate.
        """
        if not -len(self) <= idx < len(self):
            raise IndexError(
                f"index {idx} out of range for dataset of length {len(self)}"
            )

        # Convert and crop once; both returned slots come from the same pixels.
        # NOTE(review): crop_img presumably trims H/W to a multiple of 16 —
        # confirm in utils.image_utils.
        cropped = crop_img(np.array(self.img.convert('RGB')), base=16)
        degraded_img = self.toTensor(cropped)
        clean_img = self.toTensor(cropped)

        degraded_name = [""]
        degradation = ""
        return [degraded_name], degradation, degraded_img, clean_img, self.text_prompt

    def __len__(self):
        """Return 1: this dataset always holds exactly one sample."""
        return 1