# math2tex / ScanSSD / data / __init__.py
# Committed by duycse1603 — "[Add] source" (6163604)
#from .gtdb import GTDB_CLASSES, GTDB_ROOT, GTDBAnnotationTransform, GTDBDetection
from .gtdb_new import GTDB_CLASSES, GTDB_ROOT, GTDBAnnotationTransform, GTDBDetection
from .config import *
import torch
import cv2
import numpy as np
def detection_collate(batch):
    """Custom collate fn for batches whose images have differing numbers of
    associated object annotations (bounding boxes).

    Images are stacked into a single tensor, while annotations and ids are
    kept as per-sample lists because their lengths may differ between images.

    Arguments:
        batch: (iterable) samples of the form ``(image, annotations, id)``,
            where ``image`` is a tensor (all images must share one shape so
            they can be stacked) and ``annotations`` is anything
            ``torch.FloatTensor`` accepts (e.g. a nested list or an ndarray).
            Any elements after index 2 are ignored.
    Return:
        A tuple containing:
            1) (tensor) batch of images stacked on a new 0 dim
            2) (list of FloatTensor) the annotations of each image, in batch
               order
            3) (list) the id of each image, in batch order
    """
    targets = []
    imgs = []
    ids = []
    # Single pass over ``batch`` so any iterable (not just a list) works.
    for sample in batch:
        imgs.append(sample[0])
        targets.append(torch.FloatTensor(sample[1]))
        ids.append(sample[2])
    return torch.stack(imgs, 0), targets, ids
def base_transform(image, size, mean):
    """Resize an image to a ``size`` x ``size`` square and subtract ``mean``.

    Arguments:
        image: (ndarray) image array accepted by ``cv2.resize``. It is first
            cast to float32 with ``astype`` (which copies), so the caller's
            array is not modified.
        size: (int) target width and height, in pixels, of the output.
        mean: value(s) subtracted from every resized pixel via numpy
            broadcasting. Presumably a per-channel mean; ``BaseTransform``
            passes it as a float32 array.
    Return:
        (ndarray) float32 array holding the resized, mean-subtracted image.
    """
    image = image.astype(np.float32)
    # cv2.resize returns an array of the same type as its source, so the
    # result is already float32 and needs no further cast (which would only
    # copy the whole image again).
    x = cv2.resize(image, (size, size), interpolation=cv2.INTER_AREA)
    x -= mean
    return x
class BaseTransform:
    """Callable wrapper around ``base_transform`` with a fixed size and mean.

    Instances are called as ``transform(image, boxes, labels)`` and return a
    3-tuple whose boxes and labels are passed through unchanged.
    """

    def __init__(self, size, mean):
        # Target side length forwarded to base_transform on every call.
        self.size = size
        # Mean kept as a float32 ndarray (always a fresh array).
        self.mean = np.array(mean, dtype=np.float32)

    def __call__(self, image, boxes=None, labels=None):
        processed = base_transform(image, self.size, self.mean)
        return processed, boxes, labels