| import os |
| import os.path |
| import sys |
| import torch |
| import torch.utils.data as data |
| import cv2 |
| import numpy as np |
|
|
class WiderFaceDetection(data.Dataset):
    """Dataset yielding ``(image, target)`` pairs described by a label file.

    The label file is plain text.  A line starting with ``#`` opens a new
    image record; the text from its third character onward (i.e. after
    ``"# "``) is the image path relative to the image directory.  Every
    following non-blank line, up to the next ``#`` line, is one annotation
    made of whitespace-separated floats.

    The image directory is derived from ``txt_path`` by replacing the
    substring ``'label.txt'`` with ``'images/'``.

    Args:
        txt_path: Path of the label file.
        preproc: Optional callable invoked as ``preproc(img, target)`` and
            expected to return an ``(img, target)`` pair whose ``img`` is a
            NumPy array (it is passed to ``torch.from_numpy``).
    """

    # Positions inside one parsed label line that are copied, in order, to
    # target columns 4..13.  Positions 6, 9, 12 and 15 are skipped
    # (presumably per-landmark flags -- TODO confirm against the label spec).
    _LANDMARK_FIELDS = (4, 5, 7, 8, 10, 11, 13, 14, 16, 17)

    def __init__(self, txt_path, preproc=None):
        self.preproc = preproc
        self.imgs_path = []  # one image path per '#' header line
        self.words = []      # per image: list of annotations (lists of floats)

        images_dir = txt_path.replace('label.txt', 'images/')
        labels = []
        seen_header = False
        # The context manager guarantees the file handle is released even if
        # a malformed line makes float() raise.
        with open(txt_path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Blank lines carry no data; skipping them avoids
                    # calling float('') on them.
                    continue
                if line.startswith('#'):
                    if seen_header:
                        # A new header closes the previous image's record.
                        self.words.append(labels)
                        labels = []
                    seen_header = True
                    self.imgs_path.append(images_dir + line[2:])
                else:
                    # split() without an argument tolerates repeated or
                    # trailing whitespace between the numbers.
                    labels.append([float(x) for x in line.split()])

        # Flush the annotations collected for the last image.
        self.words.append(labels)

    def __len__(self):
        """Return the number of images listed in the label file."""
        return len(self.imgs_path)

    def __getitem__(self, index):
        """Load the image at ``index`` and build its ``(n, 15)`` target array.

        Target columns for each annotation ``label``:

        * 0-1: ``label[0]``, ``label[1]``
        * 2-3: ``label[0] + label[2]``, ``label[1] + label[3]`` (the first
          two fields offset by the next two, i.e. an x/y/w/h box turned into
          two corners -- field naming assumed, TODO confirm)
        * 4-13: the ten fields listed in ``_LANDMARK_FIELDS``
        * 14: ``-1`` if column 4 is negative, otherwise ``1``

        Returns:
            ``(torch.Tensor, target)`` where the tensor wraps the (optionally
            preprocessed) image.  NOTE: when the image has no annotations a
            bare ``(0, 15)`` array is returned instead of a pair; this is
            kept for backward compatibility, and a collate function that
            iterates over the elements of each sample (like
            ``detection_collate`` in this file) ends up skipping it.

        Raises:
            OSError: If the image cannot be read.
            IndexError: If an annotation has fewer than 18 fields.
        """
        path = self.imgs_path[index]
        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports failure by returning None rather than by
            # raising; fail here with the offending path instead of with an
            # opaque AttributeError further down.
            raise OSError('Failed to read image: {}'.format(path))

        labels = self.words[index]
        if not labels:
            return np.zeros((0, 15))

        # Allocate the result once instead of growing it with np.append,
        # which copies the whole array on every iteration.
        target = np.zeros((len(labels), 15))
        for row, label in zip(target, labels):
            # ``row`` is a view into ``target``, so writes land in place.
            row[0] = label[0]
            row[1] = label[1]
            row[2] = label[0] + label[2]
            row[3] = label[1] + label[3]
            row[4:14] = [label[i] for i in self._LANDMARK_FIELDS]
            row[14] = -1 if row[4] < 0 else 1

        if self.preproc is not None:
            img, target = self.preproc(img, target)

        return torch.from_numpy(img), target
|
|
def detection_collate(batch):
    """Collate samples whose annotation arrays may differ in length.

    Each sample in ``batch`` is iterated element by element.  Elements that
    are torch tensors are treated as images; elements that are NumPy arrays
    are treated as annotations.  Anything else is ignored.

    Arguments:
        batch: Iterable of samples, each an iterable holding an image tensor
            and an annotation array.

    Return:
        A tuple ``(images, targets)`` where:
            1) ``images`` is the image tensors stacked along a new dim 0
            2) ``targets`` is a list with one float tensor per annotation
               array, in batch order
    """
    # Flatten the batch first so both outputs are built from the same
    # ordered sequence of elements.
    elements = [item for sample in batch for item in sample]

    images = [item for item in elements if torch.is_tensor(item)]
    annotations = [
        torch.from_numpy(item).float()
        for item in elements
        if isinstance(item, np.ndarray)
    ]

    return (torch.stack(images, 0), annotations)
|
|