|
|
""" |
|
|
Contains an implementation of a dataset for preprocessed images. It expects the images |
|
|
to be stored as quantized uint16 NumPy arrays. It also loads masks and reference values. |
|
|
|
|
|
Author: Ole-Christian Galbo Engstrøm |
|
|
E-mail: ocge@foss.dk |
|
|
""" |
|
|
|
|
|
from pathlib import Path |
|
|
|
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import torch |
|
|
from torch.utils.data import Dataset |
|
|
from torchvision import tv_tensors |
|
|
|
|
|
from .utils import dequantize_16_bit |
|
|
|
|
|
|
|
|
class PreprocessedImageDataset(Dataset):
    """
    This is the dataset class that we used to load the preprocessed images
    (already cropped or padded to the input size of 2360 x 1272 pixels).

    Rows are taken from the CSV at ``csv_path``, keeping those whose
    ``split_column`` value is in ``split``; the resulting frame is indexed by
    its ``meat_id`` column. ``.npy`` files are matched to rows by the integer
    in front of the first ``-`` of their stem (e.g. ``123-0.npy`` -> 123), and
    files whose meat_id is not in the selected rows are ignored.

    Each item is a tuple assembled from the configured parts, in order:

    * the image (``tv_tensors.Image``) if ``images_path`` was given,
    * the mask (``tv_tensors.Mask``) if ``masks_path`` was given,
    * the reference values (``torch.float32`` tensor) if ``reference_values``
      was given -- when masks are present, mask and references are paired as
      a single ``(mask, refs)`` element instead of two separate elements,
    * the file stem (``file_id``) as the last element. When both images and
      masks are loaded, this is the mask's stem.
    """

    def __init__(
        self,
        reference_values,
        images_path,
        quant_biases_path,
        quant_scales_path,
        masks_path,
        csv_path,
        split,
        split_column,
        transform,
    ):
        """
        Args:
            reference_values: CSV column name(s) returned as targets, or
                ``None`` to return no targets.
            images_path: Directory containing the image ``.npy`` files, or
                ``None`` to load no images.
            quant_biases_path: Directory of per-image quantization biases
                (``.npy``). Images are dequantized only when this and
                ``quant_scales_path`` are both given.
            quant_scales_path: Directory of per-image quantization scales
                (``.npy``); see ``quant_biases_path``.
            masks_path: Directory containing the mask ``.npy`` files, or
                ``None`` to load no masks.
            csv_path: CSV file with a ``meat_id`` column, ``split_column``
                and the reference-value columns.
            split: A single split label or a list-like of split labels to keep.
            split_column: Name of the CSV column holding the split labels.
            transform: Optional callable applied to ``(img,)``, ``(mask,)`` or
                ``(img, mask)`` depending on what is loaded.

        Raises:
            ValueError: If the bias, scale or mask files do not pair up
                one-to-one (by meat_id, in sorted order) with the image files.
        """
        self.reference_values = reference_values
        self.csv_path = Path(csv_path)
        self.split = split
        self.split_column = split_column
        self.transform = transform

        self.images_path = Path(images_path) if images_path is not None else None
        # Dequantization needs both biases and scales, and only applies to
        # images; otherwise images are returned exactly as stored.
        if (
            self.images_path is not None
            and quant_scales_path is not None
            and quant_biases_path is not None
        ):
            self.quant_scales_path = Path(quant_scales_path)
            self.quant_biases_path = Path(quant_biases_path)
        else:
            self.quant_scales_path = None
            self.quant_biases_path = None
        self.masks_path = Path(masks_path) if masks_path is not None else None

        self.df = self._load_csv()

        self.image_file_names = None
        self.bias_file_names = None
        self.scale_file_names = None
        if self.images_path is not None:
            self.image_file_names = self._load_file_names(self.images_path, "npy")
            if self.quant_scales_path is not None:
                self.bias_file_names = self._load_file_names(
                    self.quant_biases_path, "npy"
                )
                self.scale_file_names = self._load_file_names(
                    self.quant_scales_path, "npy"
                )
                # __getitem__ pairs these lists purely by index, so a missing
                # or extra file would silently dequantize with the wrong values.
                self._check_aligned(
                    self.image_file_names, self.bias_file_names, "Bias files"
                )
                self._check_aligned(
                    self.image_file_names, self.scale_file_names, "Scale files"
                )

        self.mask_file_names = None
        self.mask_files = None
        if self.masks_path is not None:
            self.mask_file_names = self._load_file_names(self.masks_path, "npy")
            if self.image_file_names is not None:
                self._check_aligned(
                    self.image_file_names, self.mask_file_names, "Mask files"
                )
            # All masks are read into memory once, up front.
            self.mask_files = self._load_mask_files()

    def __len__(self):
        """Number of samples: image files if loaded, otherwise mask files."""
        if self.images_path is not None:
            return len(self.image_file_names)
        if self.masks_path is not None:
            return len(self.mask_file_names)
        raise ValueError("Either images_path or masks_path must be provided.")

    @staticmethod
    def _meat_id(file_path):
        """Parse the integer before the first '-' in the file stem."""
        return int(file_path.stem.split("-")[0])

    def _load_csv(self):
        """Read the CSV, keep the requested split(s) and index by meat_id."""
        df = pd.read_csv(self.csv_path)
        # Series.isin only accepts list-likes; allow a single split label too.
        splits = (
            self.split if pd.api.types.is_list_like(self.split) else [self.split]
        )
        df = df[df[self.split_column].isin(splits)]
        return df.set_index("meat_id")

    def _load_file_names(self, path: Path, file_extension: str):
        """Return the sorted files in ``path`` whose meat_id is in ``self.df``."""
        # One hash set instead of a linear scan of the index for every file.
        known_ids = set(self.df.index)
        return [
            f
            for f in sorted(path.glob(f"*.{file_extension}"))
            if self._meat_id(f) in known_ids
        ]

    def _check_aligned(self, reference, other, description):
        """Raise ValueError unless ``other`` matches ``reference`` by meat_id per index."""
        reference_ids = [self._meat_id(f) for f in reference]
        other_ids = [self._meat_id(f) for f in other]
        if reference_ids != other_ids:
            raise ValueError(
                f"{description} do not line up with the image files by meat_id "
                f"({len(other_ids)} vs {len(reference_ids)} files)."
            )

    def _load_mask_files(self):
        """Load every mask as float32, keyed by its file path."""
        return {f: np.load(f).astype(np.float32) for f in self.mask_file_names}

    def _load_image(self, idx):
        """Load image ``idx``, dequantize it if configured, and wrap it as an Image."""
        img = np.load(self.image_file_names[idx])
        if self.quant_scales_path is not None:
            bias = np.load(self.bias_file_names[idx])
            scale = np.load(self.scale_file_names[idx])
            img = dequantize_16_bit(img, bias, scale)
        return tv_tensors.Image(img)

    def __getitem__(self, idx):
        """Return the sample tuple for ``idx`` (see the class docstring for its layout)."""
        if self.images_path is None and self.masks_path is None:
            raise ValueError("Either images_path or masks_path must be provided.")

        img = None
        mask = None
        if self.images_path is not None:
            file_id = self.image_file_names[idx].stem
            img = self._load_image(idx)
        if self.masks_path is not None:
            mask_file = self.mask_file_names[idx]
            # When images are loaded as well, the mask's stem is reported.
            file_id = mask_file.stem
            mask = tv_tensors.Mask(self.mask_files[mask_file])
        meat_id = int(file_id.split("-")[0])

        return_tuple = tuple(part for part in (img, mask) if part is not None)

        if self.transform is not None:
            if len(return_tuple) == 1:
                # A single input comes back as a single object; re-wrap it.
                return_tuple = (self.transform(*return_tuple),)
            else:
                return_tuple = self.transform(*return_tuple)

        if self.reference_values is not None:
            refs = self.df.loc[meat_id, self.reference_values].values
            refs = torch.tensor(refs, dtype=torch.float32)
            if self.masks_path is not None:
                # Pair the (possibly transformed) mask with its references.
                mask = return_tuple[-1]
                return_tuple = return_tuple[:-1] + ((mask, refs),)
            else:
                return_tuple += (refs,)

        return_tuple += (file_id,)
        return return_tuple
|
|
|