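# PyTorch Dataset wrappers that pair real images (COCO val2014, Flickr) and
# generated "fake" PNG images with text prompts. Real samples carry label 0
# and fake samples label 1.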
import json
import torchvision.transforms as transforms
import torchvision
import PIL.Image as Image
import os
import torch
import torch.nn as nn
from pathlib import Path
import numpy as np
from pycocotools.coco import COCO
import skimage.io as io
import pandas as pd
# COCO val2014 real images paired with their first caption; label 0 marks real samples.
class real(torch.utils.data.Dataset):
    def __init__(self, realroot, size, transform=None):
        self.transform = transforms.Compose([
            transforms.Resize((size, size)),
            #RandAugment(2, 14),
            #transforms.CenterCrop((size,size)),
            transforms.ToTensor()
        ])
        dataDir = 'your dir'
        dataType = 'val2014'
        self.annFile = '{}/annotations/captions_{}.json'.format(dataDir, dataType)
        self.coco = COCO(self.annFile)
        self.imgIds_list = sorted(self.coco.getImgIds())

    def __getitem__(self, item):
        imgIds = self.coco.getImgIds(imgIds=[self.imgIds_list[item]])
        img = self.coco.loadImgs(imgIds[np.random.randint(0, len(imgIds))])[0]
        # Images are fetched from the COCO server rather than read from disk.
        I = io.imread(img['coco_url'])
        real_image = Image.fromarray(I).convert('RGB')
        real_image = self.transform(real_image)
        annIds = self.coco.getAnnIds(imgIds=img['id'])
        anns = self.coco.loadAnns(annIds)
        label = 0
        return real_image, label, anns[0]['caption']

    def __len__(self):
        return len(self.imgIds_list)
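# Illustrative usage sketch (not part of the original file); it assumes the
# 'your dir' placeholder in `real` has been replaced with an actual COCO root:
#
#   coco_real = real(realroot=None, size=224)
#   loader = torch.utils.data.DataLoader(coco_real, batch_size=32, shuffle=True)
#   images, labels, captions = next(iter(loader))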
# Flickr real images; the annotation table lists several captions per image,
# so every 5th row is kept to take one caption per image. Label 0 = real.
class realflickr(torch.utils.data.Dataset):
    def __init__(self, realroot, size, transform=None):
        self.transform = transforms.Compose([
            transforms.Resize((size, size)),
            #RandAugment(2, 14),
            #transforms.CenterCrop((size,size)),
            transforms.ToTensor()
        ])
        annotations = pd.read_table('your dir', sep='\t', header=None,
                                    names=['image', 'caption'])
        self.prompt_list = np.array(annotations['caption'][::5])
        self.image_list = np.array(annotations['image'][::5])

    def __getitem__(self, item):
        # 'your dir' is a placeholder for the image root directory.
        real_image = Image.open(os.path.join('your dir', self.image_list[item])).convert('RGB')
        prompts = self.prompt_list[item]
        label = 0
        real_image = self.transform(real_image)
        return real_image, label, prompts

    def __len__(self):
        return len(self.image_list)
# Generated ("fake") PNG images; the prompt is recovered from the file name
# (hyphens become spaces). Label 1 = fake.
class fakereal(torch.utils.data.Dataset):
    def __init__(self, fakeroot, size, transform=None):
        self.transform = transforms.Compose([
            transforms.Resize((size, size)),
            #RandAugment(2, 14),
            #transforms.CenterCrop((size,size)),
            transforms.ToTensor()
        ])
        fake_images_path = Path(fakeroot)
        fake_images_list = list(fake_images_path.glob('*.png'))
        fake_images_list_str = [str(x) for x in fake_images_list]
        self.fake_images = fake_images_list_str

    def __getitem__(self, item):
        fake_image_path = self.fake_images[item]
        fake_image = Image.open(fake_image_path).convert('RGB')
        fake_image = self.transform(fake_image)
        label = 1
        # File name without extension, with hyphens turned back into spaces.
        prompts = Path(fake_image_path).stem.replace('-', ' ')
        return fake_image, label, prompts

    def __len__(self):
        return len(self.fake_images)
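# Sketch of how the real and fake datasets might be combined for a
# real-vs-fake classifier (an assumption about intended use; the paths and
# image size below are placeholders):
#
#   real_ds = real(realroot=None, size=224)
#   fake_ds = fakereal(fakeroot='path/to/generated/pngs', size=224)
#   train_ds = torch.utils.data.ConcatDataset([real_ds, fake_ds])
#   train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32, shuffle=True)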
# Generated PNG images; same file listing setup as fakereal.
class fakeclip(torch.utils.data.Dataset):
    def __init__(self, fakeroot, size, transform=None):
        self.transform = transforms.Compose([
            transforms.Resize((size, size)),
            #RandAugment(2, 14),
            #transforms.CenterCrop((size,size)),
            transforms.ToTensor()
        ])
        fake_images_path = Path(fakeroot)
        fake_images_list = list(fake_images_path.glob('*.png'))
        fake_images_list_str = [str(x) for x in fake_images_list]
        self.fake_images = fake_images_list_str