| import json |
| import os |
| import random |
|
|
| from torch.utils.data import Dataset |
|
|
| from PIL import Image |
| from PIL import ImageFile |
| ImageFile.LOAD_TRUNCATED_IMAGES = True |
| Image.MAX_IMAGE_PIXELS = None |
|
|
| from data.utils import pre_caption |
| import os,glob |
|
|
class pretrain_dataset(Dataset):
    """Image/caption dataset built from JSON annotation files.

    Every annotation entry is indexed with the keys ``'image'`` (a path
    handed to ``PIL.Image.open``) and ``'caption'`` (text handed to
    ``pre_caption``).

    The dataset always contains the entries of every file in ``ann_file``.
    When ``laion_path`` is given, the entries of exactly one ``*.json`` shard
    from that directory are appended; ``reload_laion`` swaps the shard.

    Args:
        ann_file: Iterable of paths to JSON files, each holding a list of
            annotation entries.
        laion_path: Directory containing ``*.json`` shard files, or a falsy
            value to disable the extra shards.
        transform: Callable applied to the RGB ``PIL.Image`` in
            ``__getitem__``.

    Raises:
        FileNotFoundError: If ``laion_path`` is given but contains no
            ``*.json`` file.
    """

    def __init__(self, ann_file, laion_path, transform):
        self.ann_pretrain = []
        for f in ann_file:
            print('loading '+f)
            self.ann_pretrain += self._load_json(f)

        self.laion_path = laion_path
        # Always define these so reload_laion() works without a LAION path.
        self.laion_files = []
        self.ann_laion = []

        if self.laion_path:
            # glob order is arbitrary; sorting keeps the epoch -> shard
            # mapping reproducible across runs and processes.
            self.laion_files = sorted(
                glob.glob(os.path.join(laion_path, '*.json'))
            )
            if not self.laion_files:
                raise FileNotFoundError(
                    'no *.json files found in laion_path: ' + str(laion_path)
                )

            print('loading '+self.laion_files[0])
            self.ann_laion = self._load_json(self.laion_files[0])

        self.annotation = self.ann_pretrain + self.ann_laion
        self.transform = transform

    @staticmethod
    def _load_json(path):
        """Parse the JSON file at ``path``, closing the file afterwards."""
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def reload_laion(self, epoch):
        """Replace the active shard with the one selected by ``epoch``.

        The shard index is ``epoch % len(self.laion_files)``, so shards are
        cycled through as ``epoch`` increases. Does nothing when the dataset
        was built without any shard files.
        """
        if not self.laion_files:
            return

        n = epoch % len(self.laion_files)
        print('loading '+self.laion_files[n])
        self.ann_laion = self._load_json(self.laion_files[n])

        self.annotation = self.ann_pretrain + self.ann_laion

    def __len__(self):
        """Return the number of currently loaded annotation entries."""
        return len(self.annotation)

    def __getitem__(self, index):
        """Return ``(image, caption)`` for the annotation at ``index``.

        The image is opened, converted to RGB and passed through
        ``self.transform``; the caption is passed through
        ``pre_caption(caption, 30)``.
        """
        ann = self.annotation[index]

        # convert() loads the pixel data into a new image, so the underlying
        # file can be closed right away instead of waiting for the GC.
        with Image.open(ann['image']) as raw:
            image = raw.convert('RGB')
        image = self.transform(image)
        # NOTE(review): 30 is presumably a maximum caption length in words --
        # confirm against data.utils.pre_caption.
        caption = pre_caption(ann['caption'], 30)

        return image, caption