import glob
import json
import os
import random

import torch
import torchvision
from PIL import Image
from torch.utils.data.dataset import Dataset
from tqdm import tqdm


class SceneTextDataset(Dataset):
    """Scene-text detection dataset yielding images with oriented bounding boxes.

    Expects ``<root_dir>/img/*.jpg`` images and, for each image, an annotation
    file ``<root_dir>/annots/<image filename>.json`` whose ``objects`` list
    carries an ``obb`` dict with keys ``xc``, ``yc``, ``w``, ``h`` and
    ``theta`` (theta in degrees; converted to radians here).

    NOTE(review): the annotation name keeps the '.jpg' suffix
    (``foo.jpg.json``) — confirm this matches how the files are written.
    """

    def __init__(self, split, root_dir, augment=False):
        """
        :param split: 'train' takes the first 90% of images; 'val' (and any
            other value) takes the held-out 10% tail.
        :param root_dir: dataset root containing 'img' and 'annots' folders.
        :param augment: enable random flip augmentation (train split only).
        """
        self.split = split
        self.root_dir = root_dir
        self.im_dir = os.path.join(root_dir, 'img')
        self.ann_dir = os.path.join(root_dir, 'annots')

        classes = ['text']
        classes = sorted(classes)
        classes = ['background'] + classes  # index 0 reserved for background
        self.label2idx = {cls: idx for idx, cls in enumerate(classes)}
        self.idx2label = {idx: cls for idx, cls in enumerate(classes)}
        print(self.idx2label)

        # BUGFIX: sort so the 90/10 split is deterministic across runs and
        # filesystems — glob.glob returns files in arbitrary order.
        self.images = sorted(glob.glob(os.path.join(self.im_dir, '*.jpg')))
        self.annotations = [
            os.path.join(self.ann_dir, os.path.basename(im) + '.json')
            for im in self.images
        ]

        split_at = int(0.9 * len(self.images))
        if split == 'train':
            self.images = self.images[:split_at]
            self.annotations = self.annotations[:split_at]
        else:
            # 'val' and any other split value share the held-out 10% tail
            # (matches the original train/val/else behavior).
            self.images = self.images[split_at:]
            self.annotations = self.annotations[split_at:]
        self.augment = augment

    def __len__(self):
        return len(self.images)

    def convert_xcycwh_to_xyxy(self, box):
        """Convert a [xc, yc, w, h] box to corner form [x1, y1, x2, y2]."""
        x, y, w, h = box
        return [x - w / 2, y - h / 2, x + w / 2, y + h / 2]

    def convert_xcycwhtheta_to_xyxytheta(self, box):
        """Convert [xc, yc, w, h, theta] to [x1, y1, x2, y2, theta]."""
        x, y, w, h, theta = box
        return [x - w / 2, y - h / 2, x + w / 2, y + h / 2, theta]

    def horizontal_flip(self, im, boxes, thetas):
        """Mirror a CHW image tensor left-right, updating boxes and angles.

        Angles (radians, assumed in [0, pi)) reflect as theta -> pi - theta.
        """
        im = im.flip(-1)  # flip the width axis (last dim of CHW)
        # BUGFIX: reflect x-coordinates around the image WIDTH (shape[-1]);
        # shape[1] is the height of a CHW tensor and was only correct for
        # square images.
        boxes[:, [0, 2]] = im.shape[-1] - boxes[:, [2, 0]]
        # Reflect the angle and renormalize into [0, pi).
        thetas = -thetas
        thetas = torch.where(thetas < 0, thetas + torch.pi, thetas)
        return im, boxes, thetas

    def vertical_flip(self, im, boxes, thetas):
        """Mirror a CHW image tensor top-bottom, updating boxes and angles."""
        # BUGFIX: flip the height axis. The original im.flip(0) flipped the
        # CHANNEL dim of a CHW tensor (RGB -> BGR) and left pixels in place.
        im = im.flip(-2)
        # BUGFIX: reflect y-coordinates around the image HEIGHT (shape[-2]);
        # shape[0] is the channel count (3), not the height.
        boxes[:, [1, 3]] = im.shape[-2] - boxes[:, [3, 1]]
        # Reflect the angle and renormalize into [0, pi).
        thetas = -thetas
        thetas = torch.where(thetas < 0, thetas + torch.pi, thetas)
        return im, boxes, thetas

    def __getitem__(self, index):
        """Return (image_tensor, targets, image_path) for one sample.

        targets: dict with 'bboxes' (N,4 float xyxy), 'labels' (N long, all 1
        since 'text' is the only foreground class), 'thetas' (N float, radians).
        """
        im_path = self.images[index]
        # NOTE(review): no .convert('RGB') — a grayscale jpg would yield a
        # 1-channel tensor; confirm all inputs are RGB.
        im = Image.open(im_path)
        im_tensor = torchvision.transforms.ToTensor()(im)

        ann_path = self.annotations[index]
        with open(ann_path, 'r') as f:
            im_info = json.load(f)

        # Axis-aligned corner boxes plus a separate rotation angle per box.
        boxes = [
            self.convert_xcycwh_to_xyxy(
                [obj['obb']['xc'], obj['obb']['yc'], obj['obb']['w'], obj['obb']['h']]
            )
            for obj in im_info['objects']
        ]
        thetas = [obj['obb']['theta'] for obj in im_info['objects']]

        targets = {}
        targets['bboxes'] = torch.as_tensor(boxes).float()
        # Every object is 'text' -> label index 1 (0 is background).
        targets['labels'] = torch.ones(len(im_info['objects']), dtype=torch.long)
        targets['thetas'] = torch.deg2rad(torch.as_tensor(thetas).float())

        if self.split == 'train' and self.augment:
            # Each flip fires independently with probability 0.3.
            if random.random() > 0.7:
                im_tensor, targets['bboxes'], targets['thetas'] = \
                    self.horizontal_flip(im_tensor, targets['bboxes'], targets['thetas'])
            if random.random() > 0.7:
                im_tensor, targets['bboxes'], targets['thetas'] = \
                    self.vertical_flip(im_tensor, targets['bboxes'], targets['thetas'])
        return im_tensor, targets, im_path