File size: 1,459 Bytes
6163604
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#from .gtdb import GTDB_CLASSES, GTDB_ROOT, GTDBAnnotationTransform, GTDBDetection
from .gtdb_new import GTDB_CLASSES, GTDB_ROOT, GTDBAnnotationTransform, GTDBDetection
from .config import *
import torch
import cv2
import numpy as np

def detection_collate(batch):
    """Collate detection samples whose annotation counts differ per image.

    Images can be stacked into one tensor, but each image may carry a
    different number of annotation rows, so the annotations are kept as a
    list with one tensor per sample instead of being stacked.

    Arguments:
        batch: (sequence) samples, each indexable so that item 0 is an
            image tensor, item 1 is the annotation data for that image and
            item 2 is an identifier for the sample.

    Return:
        A tuple containing:
            1) (tensor) the images stacked along a new leading dimension
            2) (list of tensors) one FloatTensor of annotations per sample
            3) (list) item 2 of every sample, in batch order
    """
    images = [entry[0] for entry in batch]
    annotations = [torch.FloatTensor(entry[1]) for entry in batch]
    sample_ids = [entry[2] for entry in batch]

    stacked_images = torch.stack(images, 0)
    return stacked_images, annotations, sample_ids


def base_transform(image, size, mean):
    """Resize an image to a square and subtract a mean from it.

    Arguments:
        image: array exposing ``astype`` that ``cv2.resize`` accepts.
        size: (int) side length of the square output, in pixels.
        mean: value(s) subtracted from the resized image; must broadcast
            against it.

    Return:
        A new float32 array of the resized, mean-subtracted image.
    """
    as_float = image.astype(np.float32)
    resized = cv2.resize(as_float, (size, size), interpolation=cv2.INTER_AREA)
    x = resized.astype(np.float32)
    # Subtract in place so the result keeps the float32 buffer's dtype.
    x -= mean
    return x


class BaseTransform:
    """Callable wrapper that applies ``base_transform`` with a fixed size and mean."""

    def __init__(self, size, mean):
        """Remember the target size and store the mean as a float32 array."""
        self.mean = np.array(mean, dtype=np.float32)
        self.size = size

    def __call__(self, image, boxes=None, labels=None):
        """Transform ``image``; hand ``boxes`` and ``labels`` back untouched.

        Return:
            A tuple ``(transformed_image, boxes, labels)``.
        """
        transformed = base_transform(image, self.size, self.mean)
        return transformed, boxes, labels