Spaces:
Runtime error
Runtime error
| #from .gtdb import GTDB_CLASSES, GTDB_ROOT, GTDBAnnotationTransform, GTDBDetection | |
| from .gtdb_new import GTDB_CLASSES, GTDB_ROOT, GTDBAnnotationTransform, GTDBDetection | |
| from .config import * | |
| import torch | |
| import cv2 | |
| import numpy as np | |
def detection_collate(batch):
    """Collate samples whose images carry differing numbers of annotations.

    All images in a batch can be stacked into a single tensor. Each image may
    have a different number of bounding boxes, though, so the annotations are
    returned as a list of per-sample tensors instead of being stacked.

    Arguments:
        batch: iterable of samples, each indexable as
            ``(image_tensor, annotations, image_id)``.

    Return:
        A 3-tuple:
            1) (tensor) the images stacked along a new dimension 0,
            2) (list of FloatTensor) one annotation tensor per sample,
            3) (list) the per-sample ids, in batch order.
    """
    image_list = []
    annotation_tensors = []
    sample_ids = []
    for sample in batch:
        image_list.append(sample[0])
        # Annotations stay separate tensors; their lengths differ per image.
        annotation_tensors.append(torch.FloatTensor(sample[1]))
        sample_ids.append(sample[2])
    return torch.stack(image_list, dim=0), annotation_tensors, sample_ids
def base_transform(image, size, mean):
    """Resize an image to a square and subtract a mean value from it.

    Args:
        image: NumPy image array accepted by ``cv2.resize`` (presumably
            HxWxC as loaded by OpenCV -- TODO confirm against the dataset
            loader).
        size: Side length, in pixels, of the output; the image is resized to
            ``(size, size)`` regardless of its original aspect ratio.
        mean: Value(s) subtracted from every pixel after resizing; must
            broadcast against the resized image (e.g. one value per channel).

    Returns:
        A new ``float32`` array holding the resized image minus ``mean``.
        The input ``image`` is not modified, because ``astype`` copies it
        first.
    """
    image = image.astype(np.float32)
    # cv2.resize returns an array of the same dtype as its input (float32
    # here), so a second astype() would only make a redundant full copy.
    x = cv2.resize(image, (size, size), interpolation=cv2.INTER_AREA)
    x -= mean
    return x
class BaseTransform:
    """Callable wrapper around ``base_transform`` with a fixed size and mean.

    Boxes and labels are passed through untouched, so an instance can be used
    wherever a ``transform(image, boxes, labels)`` callable is expected.
    """

    def __init__(self, size, mean):
        # Side length handed to base_transform for the square resize.
        self.size = size
        # Converted to a float32 array once here rather than on every call.
        self.mean = np.array(mean, dtype=np.float32)

    def __call__(self, image, boxes=None, labels=None):
        """Return ``(transformed_image, boxes, labels)``; only the image changes."""
        transformed = base_transform(image, self.size, self.mean)
        return transformed, boxes, labels