File size: 2,645 Bytes
b39a019 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import torch
import pandas as pd
import numpy as np
from PIL import Image
class CrosswalkDataset:
    """In-memory image dataset for binary zebra-crossing classification.

    The annotation CSV must provide at least a ``filename`` and a ``class``
    column. Rows are grouped per file, every image is loaded eagerly at
    construction time, and each image receives a two-element label:
    ``[0, 1]`` when at least one of its annotations has class
    ``"ZebraStyle"``, otherwise ``[1, 0]`` (background / anything else).

    Args:
        annotation_path: Path of the annotation CSV file.
        image_path: Directory that contains the image files named in the
            ``filename`` column.
        transform: Optional callable applied to the image tensor in
            ``__getitem__`` (e.g. normalisation or other pre-processing).
    """

    def __init__(self, annotation_path, image_path, transform=None):
        self.annotations = pd.read_csv(annotation_path)
        self.image_dir = image_path
        self.transform = transform

        # Class name -> integer id, starting at 1 so that 0 stays free for
        # the background / "none" class. The binary labels built in
        # process_annotations() do not use this mapping; it is kept for a
        # future multi-class setup.
        self.unique_labels = sorted(self.annotations['class'].unique())
        self.type_mapping = {
            type_value: 1 + idx
            for idx, type_value in enumerate(self.unique_labels)
        }

        # image_data[i] and labels[i] always describe the same image.
        self.image_data = []
        self.labels = []
        self.process_annotations()

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.image_data)

    def __getitem__(self, index):
        """Return ``(image_tensor, label_tensor)`` for the given index.

        The image is converted to a float32 tensor and permuted from
        (H, W, C) to (C, H, W), the layout PyTorch expects. The label is a
        float32 tensor built from the stored two-element label.
        """
        image = self.image_data[index]
        class_label = self.labels[index]

        # NOTE(review): permute(2, 0, 1) requires a 3-D array, so images
        # decoded without a channel axis would fail here - confirm that the
        # dataset only contains multi-channel images.
        tensor_image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1)
        tensor_label = torch.tensor(class_label, dtype=torch.float32)

        if self.transform is not None:
            tensor_image = self.transform(tensor_image)
        return tensor_image, tensor_label

    def process_annotations(self):
        """Load every annotated image and derive its binary label.

        Appends one pixel array to ``self.image_data`` and one label to
        ``self.labels`` per distinct ``filename`` in the annotations.
        """
        for filename, group in self.annotations.groupby('filename'):
            completed_image_path = f"{self.image_dir}/{filename}"

            # Image.open() is lazy and keeps the file handle open; decode the
            # pixels inside the context manager so the handle is released
            # before the next image is opened.
            with Image.open(completed_image_path) as image:
                image_array = np.array(image)

            # A single "ZebraStyle" annotation is enough to mark the image as
            # positive; every other class counts as background.
            has_zebra = bool((group['class'] == "ZebraStyle").any())
            entity_annotations = [0, 1] if has_zebra else [1, 0]

            self.image_data.append(image_array)
            self.labels.append(entity_annotations)
# Build the training dataset at import time; the annotation CSV lives in the
# same directory as the training images.
crosswalk_dataset = CrosswalkDataset(
    annotation_path="Crosswalk.v7-crosswalk-t3.tensorflow/train/_annotations.csv",
    image_path="Crosswalk.v7-crosswalk-t3.tensorflow/train",
)
|