import glob
import json
import os
import random

import torch
import torchvision
from PIL import Image
from torch.utils.data.dataset import Dataset
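
# A minimal example of the annotation layout this loader assumes, reconstructed
# from how __getitem__ reads the JSON below; the field names ('objects', 'obb',
# 'xc', 'yc', 'w', 'h', 'theta') come from the code, the values are hypothetical
# and the real files may carry extra fields:
#
#   annots/img_0001.jpg.json
#   {
#       "objects": [
#           {"obb": {"xc": 120.5, "yc": 64.0, "w": 80.0, "h": 24.0, "theta": 15.0}}
#       ]
#   }
#
# 'theta' is stored in degrees and converted to radians in __getitem__.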
class SceneTextDataset(Dataset):
    def __init__(self, split, root_dir, augment=False):
        self.split = split
        self.root_dir = root_dir
        self.im_dir = os.path.join(root_dir, 'img')
        self.ann_dir = os.path.join(root_dir, 'annots')
        classes = [
            'text'
        ]
        classes = sorted(classes)
        classes = ['background'] + classes
        self.label2idx = {classes[idx]: idx for idx in range(len(classes))}
        self.idx2label = {idx: classes[idx] for idx in range(len(classes))}
        print(self.idx2label)
        # Sort the glob results so the train/val split below is deterministic
        # across runs (glob returns files in arbitrary order).
        self.images = sorted(glob.glob(os.path.join(self.im_dir, '*.jpg')))
        # Annotation files are expected to be named '<image name>.jpg.json'.
        self.annotations = [os.path.join(self.ann_dir, os.path.basename(im) + '.json') for im in self.images]
        # First 90% of images go to training; 'val' and any other split
        # (e.g. 'test') share the remaining 10%.
        if split == 'train':
            self.images = self.images[:int(0.9 * len(self.images))]
            self.annotations = self.annotations[:int(0.9 * len(self.annotations))]
        else:
            self.images = self.images[int(0.9 * len(self.images)):]
            self.annotations = self.annotations[int(0.9 * len(self.annotations)):]
        self.augment = augment
    def __len__(self):
        return len(self.images)
    def convert_xcycwh_to_xyxy(self, box):
        # (xc, yc, w, h) centre format -> (x1, y1, x2, y2) corner format.
        x, y, w, h = box
        x1 = x - w / 2
        y1 = y - h / 2
        x2 = x + w / 2
        y2 = y + h / 2
        return [x1, y1, x2, y2]
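
    # Worked example (hypothetical values):
    # convert_xcycwh_to_xyxy([100, 50, 40, 20]) returns [80.0, 40.0, 120.0, 60.0].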
    def convert_xcycwhtheta_to_xyxytheta(self, box):
        # Same conversion with the rotation angle carried through unchanged.
        # Currently unused; kept for the commented-out alternative in __getitem__.
        x, y, w, h, theta = box
        x1 = x - w / 2
        y1 = y - h / 2
        x2 = x + w / 2
        y2 = y + h / 2
        return [x1, y1, x2, y2, theta]
    def horizontal_flip(self, im, boxes, thetas):
        # Flip the image along the width (last) dimension of the CHW tensor.
        im = im.flip(-1)
        # Mirror the annotated boxes across the vertical axis; shape[-1] is the width.
        boxes[:, [0, 2]] = im.shape[-1] - boxes[:, [2, 0]]
        # Reflect the orientation: theta -> pi - theta, keeping angles in [0, pi).
        thetas = -thetas
        thetas = torch.where(thetas < 0, thetas + torch.pi, thetas)
        return im, boxes, thetas
    def vertical_flip(self, im, boxes, thetas):
        # Flip the image along the height dimension of the CHW tensor.
        im = im.flip(-2)
        # Mirror the annotated boxes across the horizontal axis; shape[-2] is the height.
        boxes[:, [1, 3]] = im.shape[-2] - boxes[:, [3, 1]]
        # Reflect the orientation: theta -> pi - theta, keeping angles in [0, pi).
        thetas = -thetas
        thetas = torch.where(thetas < 0, thetas + torch.pi, thetas)
        return im, boxes, thetas
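
    # Worked check of the flip convention on a hypothetical 6-pixel-wide image:
    # horizontal_flip maps a box [1, 1, 3, 2] to [6 - 3, 1, 6 - 1, 2] = [3, 1, 5, 2],
    # and an orientation of 30 degrees (pi/6) to 150 degrees (5*pi/6).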
    def __getitem__(self, index):
        im_path = self.images[index]
        # Convert to RGB so grayscale/palette JPEGs still yield 3-channel tensors.
        im = Image.open(im_path).convert('RGB')
        im_tensor = torchvision.transforms.ToTensor()(im)
        targets = {}
        ann_path = self.annotations[index]
        with open(ann_path, 'r') as f:
            im_info = json.load(f)
        # Read the oriented-box parameters (centre, size, angle) for every object.
        xc = [detec['obb']['xc'] for detec in im_info['objects']]
        yc = [detec['obb']['yc'] for detec in im_info['objects']]
        w = [detec['obb']['w'] for detec in im_info['objects']]
        h = [detec['obb']['h'] for detec in im_info['objects']]
        thetas = [detec['obb']['theta'] for detec in im_info['objects']]
        boxes = [self.convert_xcycwh_to_xyxy([xc[i], yc[i], w[i], h[i]]) for i in range(len(xc))]
        # boxes = [self.convert_xcycwhtheta_to_xyxytheta([xc[i], yc[i], w[i], h[i], thetas[i]]) for i in range(len(xc))]
        targets['bboxes'] = torch.as_tensor(boxes).float()
        # Every object is 'text' (class index 1); index 0 is reserved for background.
        targets['labels'] = torch.ones(len(im_info['objects']), dtype=torch.long)
        # Angles are stored in degrees in the JSON; convert to radians.
        targets['thetas'] = torch.deg2rad(torch.as_tensor(thetas).float())
        if self.split == 'train' and self.augment:
            # Apply each flip independently with probability 0.3.
            if random.random() > 0.7:
                im_tensor, targets['bboxes'], targets['thetas'] = self.horizontal_flip(im_tensor, targets['bboxes'], targets['thetas'])
            if random.random() > 0.7:
                im_tensor, targets['bboxes'], targets['thetas'] = self.vertical_flip(im_tensor, targets['bboxes'], targets['thetas'])
        return im_tensor, targets, im_path
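

# A minimal usage sketch, assuming the directory layout documented above
# ('data/scene_text' is a hypothetical path). Detection targets vary in length
# per image, so the default DataLoader collate cannot stack them; the
# list-preserving collate_fn below is one common workaround, not part of the
# original loader.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    def detection_collate(batch):
        # Keep images, target dicts and paths as per-sample tuples.
        return tuple(zip(*batch))

    dataset = SceneTextDataset('train', root_dir='data/scene_text', augment=True)
    loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=detection_collate)
    im_tensors, targets, im_paths = next(iter(loader))
    print(im_tensors[0].shape, targets[0]['bboxes'].shape, targets[0]['thetas'].shape)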