# NOTE(review): removed extraction residue ("Spaces:" header and duplicated
# "Runtime error" lines) that was not part of the Python source.
| import torch | |
| from torch.utils.data import Dataset, DataLoader | |
| import numpy as np | |
| from sklearn.model_selection import train_test_split | |
| import albumentations as A | |
| from albumentations.pytorch.transforms import ToTensorV2 | |
| from PIL import Image | |
| from pathlib import Path | |
| from random import randint | |
| from utils import * | |
class StampDataset(Dataset):
    """Dataset of stamp images with bounding-box annotations.

    Arguments:
        data -- list of dicts with keys 'file_path' (image path relative to
            image_folder), 'box_nb' (number of boxes on the image), and
            'boxes' (per-box coordinates of shape (4,)).
            Defaults to read_data() when not given.
        image_folder -- Path to the folder containing the images.
            Defaults to Path(IMAGE_FOLDER) when not given.
        transforms -- transform pipeline from the albumentations package,
            or None for no augmentation.
    """

    def __init__(self, data=None, image_folder=None, transforms=None):
        # Resolve defaults at call time, not definition time: the original
        # evaluated read_data() / Path(IMAGE_FOLDER) once at import, which
        # performs I/O on import and shares a single list across instances.
        self.data = read_data() if data is None else data
        self.image_folder = Path(IMAGE_FOLDER) if image_folder is None else image_folder
        self.transforms = transforms

    def __getitem__(self, idx):
        item = self.data[idx]
        image_fn = self.image_folder / item['file_path']
        boxes = item['boxes']
        box_nb = item['box_nb']
        # Per-box label tensor; column 0 is set to 1 for every box
        # (albumentations requires a label per bbox via 'label_fields').
        labels = torch.zeros((box_nb, 2), dtype=torch.int64)
        labels[:, 0] = 1
        img = np.array(Image.open(image_fn))
        try:
            if self.transforms:
                sample = self.transforms(**{
                    "image": img,
                    "bboxes": boxes,
                    "labels": labels,
                })
                img = sample['image']
                # Rebuild an (N, 4) tensor from the list of transformed boxes.
                boxes = torch.stack(
                    tuple(map(torch.tensor, zip(*sample['bboxes'])))
                ).permute(1, 0)
        except Exception:
            # Augmentation can legitimately fail (e.g. a random crop drops
            # every bbox, leaving zip(*[]) empty); fall back to a random
            # other sample rather than crashing the loader. Was a bare
            # `except:`, which also swallowed KeyboardInterrupt/SystemExit.
            return self.__getitem__(randint(0, len(self.data) - 1))
        target_tensor = boxes_to_tensor(boxes.type(torch.float32))
        return img, target_tensor

    def __len__(self):
        return len(self.data)
def collate_fn(batch):
    """Transpose a batch of (image, target) pairs into (images, targets)."""
    transposed = zip(*batch)
    return tuple(transposed)
def get_datasets(data_path=ANNOTATIONS_PATH, train_transforms=None, val_transforms=None):
    """
    Creates StampDataset objects for training, validation, and testing.

    Arguments:
    data_path -- string or Path, specifying path to annotations file
    train_transforms -- transforms to be applied during training
        (a default augmentation pipeline is built when None)
    val_transforms -- transforms to be applied during validation and testing
        (a default resize+normalize pipeline is built when None)

    Returns:
    (train_dataset, val_dataset, test_dataset) -- tuple of StampDataset
        objects for training, validation, and testing.
        NOTE: the original docstring claimed a two-tuple; the function has
        always returned three datasets.
    """
    data = read_data(data_path)
    if train_transforms is None:
        train_transforms = A.Compose([
            A.RandomCropNearBBox(max_part_shift=0.6, p=0.4),
            A.Resize(height=448, width=448),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Normalize(),
            ToTensorV2(p=1.0),
        ],
        bbox_params={
            "format": "coco",
            "label_fields": ["labels"],
        })
    if val_transforms is None:
        val_transforms = A.Compose([
            A.Resize(height=448, width=448),
            A.Normalize(),
            ToTensorV2(p=1.0),
        ],
        bbox_params={
            "format": "coco",
            "label_fields": ["labels"],
        })

    # 10% of the data is held out for test, then 20% of the remainder
    # for validation. NOTE(review): no random_state is fixed, so the
    # split differs between runs — confirm this is intended.
    train, test_data = train_test_split(data, test_size=0.1, shuffle=True)
    train_data, val_data = train_test_split(train, test_size=0.2, shuffle=True)

    train_dataset = StampDataset(train_data, transforms=train_transforms)
    val_dataset = StampDataset(val_data, transforms=val_transforms)
    test_dataset = StampDataset(test_data, transforms=val_transforms)
    return train_dataset, val_dataset, test_dataset
def get_loaders(batch_size=8, data_path=ANNOTATIONS_PATH, num_workers=0, train_transforms=None, val_transforms=None):
    """
    Creates DataLoader objects for training and validation.

    Arguments:
    batch_size -- integer specifying the number of images in the batch
    data_path -- string or Path, specifying path to annotations file
    num_workers -- number of DataLoader worker processes (0 = main process)
    train_transforms -- transforms to be applied during training
    val_transforms -- transforms to be applied during validation

    Returns:
    (train_loader, val_loader) -- tuple of DataLoader for training and validation
    """
    # Bug fix: train_transforms/val_transforms were accepted but never
    # forwarded to get_datasets, so caller-supplied transforms were
    # silently ignored and the defaults were always used.
    train_dataset, val_dataset, _ = get_datasets(
        data_path,
        train_transforms=train_transforms,
        val_transforms=val_transforms)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_fn,
        drop_last=True)
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        collate_fn=collate_fn)
    return train_loader, val_loader