"""
Contains an implementation of a dataset for preprocessed images. It expects the images
to be stored as NumPy arrays, optionally quantized to uint16 (in which case per-image
quantization biases and scales are loaded to dequantize them). It also loads masks and
reference values.
Author: Ole-Christian Galbo Engstrøm
E-mail: ocge@foss.dk
"""
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from torchvision import tv_tensors
from .utils import dequantize_16_bit
class PreprocessedImageDataset(Dataset):
    """
    This is the dataset class that we used to load the preprocessed images
    (already cropped or padded to the input size of 2360 x 1272 pixels).

    Each sample is a tuple built from whichever sources are configured, in
    this order: the image (if ``images_path`` is given), the mask (if
    ``masks_path`` is given), the reference values (if ``reference_values``
    is given) and finally the file stem. When both masks and reference values
    are configured, they are returned coupled as a single ``(mask, refs)``
    pair.

    File names must start with an integer meat id, optionally followed by
    ``-suffix`` (e.g. ``123-left.npy``). Only files whose meat id belongs to
    the selected split of the CSV are used.

    Args:
        reference_values: Column name, or list of column names, of the CSV to
            return as a float32 tensor; ``None`` disables reference values.
        images_path: Directory of ``.npy`` images, or ``None``.
        quant_biases_path: Directory of per-image quantization biases, or
            ``None`` if the images are not quantized.
        quant_scales_path: Directory of per-image quantization scales, or
            ``None`` if the images are not quantized.
        masks_path: Directory of ``.npy`` masks, or ``None``.
        csv_path: CSV file containing a ``meat_id`` column, ``split_column``
            and the reference value columns.
        split: Collection of split values to keep (passed to ``Series.isin``).
        split_column: Name of the CSV column holding the split assignment.
        transform: Optional callable applied jointly to the image and/or mask.

    Raises:
        ValueError: If exactly one of ``quant_biases_path`` /
            ``quant_scales_path`` is given together with ``images_path``, or
            if the image, bias, scale and mask files do not pair up
            one-to-one by meat id.
    """

    def __init__(
        self,
        reference_values,
        images_path,
        quant_biases_path,
        quant_scales_path,
        masks_path,
        csv_path,
        split,
        split_column,
        transform,
    ):
        self.reference_values = reference_values
        self.csv_path = Path(csv_path)
        self.split = split
        self.split_column = split_column
        self.transform = transform

        self.images_path = Path(images_path) if images_path is not None else None
        self.quant_biases_path = None
        self.quant_scales_path = None
        if self.images_path is not None:
            has_biases = quant_biases_path is not None
            has_scales = quant_scales_path is not None
            if has_biases != has_scales:
                # A half-configured quantization would silently return the raw
                # quantized arrays instead of dequantized images.
                raise ValueError(
                    "quant_biases_path and quant_scales_path must either both be "
                    "provided or both be None."
                )
            if has_biases:
                self.quant_biases_path = Path(quant_biases_path)
                self.quant_scales_path = Path(quant_scales_path)
        self.masks_path = Path(masks_path) if masks_path is not None else None

        self.df = self._load_csv()

        self.bias_file_names = None
        self.scale_file_names = None
        if self.images_path is not None:
            self.image_file_names = self._load_file_names(self.images_path, "npy")
            if self.quant_scales_path is not None:
                self.bias_file_names = self._load_file_names(
                    self.quant_biases_path, "npy"
                )
                self.scale_file_names = self._load_file_names(
                    self.quant_scales_path, "npy"
                )
                self._check_aligned(
                    self.image_file_names, self.bias_file_names, "quantization bias"
                )
                self._check_aligned(
                    self.image_file_names, self.scale_file_names, "quantization scale"
                )
        if self.masks_path is not None:
            self.mask_file_names = self._load_file_names(self.masks_path, "npy")
            if self.images_path is not None:
                self._check_aligned(self.image_file_names, self.mask_file_names, "mask")
            self.mask_files = self._load_mask_files()

    def __len__(self):
        """Return the number of samples (images if configured, otherwise masks)."""
        if self.images_path is not None:
            return len(self.image_file_names)
        if self.masks_path is not None:
            return len(self.mask_file_names)
        raise ValueError("Either images_path or masks_path must be provided.")

    @staticmethod
    def _meat_id(file_path):
        """Return the integer meat id encoded before the first '-' of the file stem."""
        return int(file_path.stem.split("-")[0])

    def _check_aligned(self, reference_files, other_files, description):
        """Raise ValueError unless *other_files* pair up position-by-position
        with *reference_files* by meat id.

        ``__getitem__`` indexes every file list with the same ``idx``, so the
        lists must have the same length and the same meat id at each position.
        """
        reference_ids = [self._meat_id(f) for f in reference_files]
        other_ids = [self._meat_id(f) for f in other_files]
        if reference_ids != other_ids:
            raise ValueError(
                f"The {description} files do not match the image files one-to-one "
                f"({len(other_ids)} {description} files vs. {len(reference_ids)} "
                "image files, or differing meat ids at the same position)."
            )

    def _load_csv(self):
        """Read the CSV, keep the rows of the selected split(s), index by ``meat_id``."""
        df = pd.read_csv(self.csv_path)
        # self.split is a collection of split values (e.g. a list of integers).
        df = df[df[self.split_column].isin(self.split)]
        return df.set_index("meat_id")

    def _load_file_names(self, path: Path, file_extension: str):
        """Return the sorted files in *path* with *file_extension* whose meat id
        is present in the (split-filtered) dataframe index."""
        known_ids = self.df.index  # Index membership is a hash lookup, not a scan.
        return [
            f
            for f in sorted(path.glob(f"*.{file_extension}"))
            if self._meat_id(f) in known_ids
        ]

    def _load_mask_files(self):
        """Load every mask into memory as float32, keyed by its file path."""
        # Masks are around 2 MB each, so we can load them all at once
        return {f: np.load(f).astype(np.float32) for f in self.mask_file_names}

    def __getitem__(self, idx):
        """Return the sample at *idx* as a tuple.

        Layout: ``(image?, mask?, refs?, file_id)``, where ``mask`` and
        ``refs`` are merged into one ``(mask, refs)`` element when both are
        configured. ``file_id`` is the mask's file stem when masks are loaded,
        otherwise the image's.

        Raises:
            ValueError: If neither ``images_path`` nor ``masks_path`` was given.
        """
        if self.images_path is None and self.masks_path is None:
            raise ValueError("Either images_path or masks_path must be provided.")

        return_tuple = ()
        if self.images_path is not None:
            image_file = self.image_file_names[idx]
            file_id = image_file.stem
            meat_id = self._meat_id(image_file)
            img = np.load(image_file)
            if self.quant_scales_path is not None:
                bias = np.load(self.bias_file_names[idx])
                scale = np.load(self.scale_file_names[idx])
                img = dequantize_16_bit(img, bias, scale)
            return_tuple += (tv_tensors.Image(img),)
        if self.masks_path is not None:
            mask_file = self.mask_file_names[idx]
            # When images are loaded too, the mask's stem is used as the sample id.
            file_id = mask_file.stem
            meat_id = self._meat_id(mask_file)
            return_tuple += (tv_tensors.Mask(self.mask_files[mask_file]),)

        if self.transform is not None:
            if len(return_tuple) == 1:
                # A single input comes back as a bare object; keep it in a tuple.
                return_tuple = (self.transform(*return_tuple),)
            else:
                return_tuple = self.transform(*return_tuple)
        return_tuple = tuple(return_tuple)

        if self.reference_values is not None:
            refs = self.df.loc[meat_id, self.reference_values]
            # atleast_1d lets a single column name (scalar lookup) give a 1-element tensor.
            refs = torch.tensor(np.atleast_1d(np.asarray(refs, dtype=np.float32)))
            if self.masks_path is not None:
                # Couple the reference values with the (possibly transformed) mask.
                return_tuple = return_tuple[:-1] + ((return_tuple[-1], refs),)
            else:
                return_tuple += (refs,)
        return return_tuple + (file_id,)
|