# unet_chemical_map / src/datasets/preprocessed_image_dataset.py
# Uploaded by Sm00thix: "Initial upload of source and weights" (commit 78a947a)
"""
Contains an implementation of a dataset for preprocessed images. It expects the images
to be stored as quantized uint16 NumPy arrays. It also loads masks and reference values.
Author: Ole-Christian Galbo Engstrøm
E-mail: ocge@foss.dk
"""
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from torchvision import tv_tensors
from .utils import dequantize_16_bit
class PreprocessedImageDataset(Dataset):
    """
    Dataset of preprocessed images, i.e. images already cropped or padded to
    the input size of 2360 x 1272 pixels, with optional masks and reference
    values.

    Files are ``.npy`` arrays whose stem starts with the integer meat id
    followed by a hyphen (``"<meat_id>-<rest>.npy"``). Only files whose meat
    id belongs to the selected split of the CSV are used. Image, bias, scale
    and mask files are paired purely by their position in the sorted file
    lists.

    Each item is a tuple containing, in order:

    * the image as ``tv_tensors.Image`` (if ``images_path`` is given),
    * the mask as ``tv_tensors.Mask`` (if ``masks_path`` is given); when
      ``reference_values`` is also given, this entry is a ``(mask, refs)``
      pair instead,
    * the reference values as a float32 tensor (if ``reference_values`` is
      given and ``masks_path`` is not),
    * the file stem (``str``), taken from the mask file when masks are used
      and from the image file otherwise.

    Args:
        reference_values: CSV column name(s) to return as reference values,
            or None for no reference values. Presumably a list of column
            names so that ``df.loc`` yields a row of values - TODO confirm
            against callers.
        images_path: Directory with (optionally quantized) image files, or
            None.
        quant_biases_path: Directory with per-image quantization biases.
        quant_scales_path: Directory with per-image quantization scales.
            Images are dequantized only if both this and
            ``quant_biases_path`` are given.
        masks_path: Directory with mask files, or None. All masks are loaded
            into memory (as float32) at construction time.
        csv_path: CSV file with a ``meat_id`` column, ``split_column`` and
            the reference value columns.
        split: Values of ``split_column`` whose rows are kept.
        split_column: Name of the CSV column that assigns rows to splits.
        transform: Optional callable applied to the image and/or mask, which
            are passed as positional arguments. With two arguments it is
            expected to return the transformed pair as a tuple.

    Raises:
        ValueError: If the image, bias, scale and mask file lists in use have
            different lengths, since items are paired by index.
    """

    def __init__(
        self,
        reference_values,
        images_path,
        quant_biases_path,
        quant_scales_path,
        masks_path,
        csv_path,
        split,
        split_column,
        transform,
    ):
        self.reference_values = reference_values
        self.csv_path = Path(csv_path)
        self.split = split
        self.split_column = split_column
        self.transform = transform

        self.images_path = None if images_path is None else Path(images_path)
        # Dequantization needs both biases and scales. If either is missing,
        # images are returned exactly as stored on disk.
        if (
            self.images_path is not None
            and quant_scales_path is not None
            and quant_biases_path is not None
        ):
            self.quant_scales_path = Path(quant_scales_path)
            self.quant_biases_path = Path(quant_biases_path)
        else:
            self.quant_scales_path = None
            self.quant_biases_path = None
        self.masks_path = None if masks_path is None else Path(masks_path)

        self.df = self._load_csv()

        self.image_file_names = None
        self.bias_file_names = None
        self.scale_file_names = None
        if self.images_path is not None:
            self.image_file_names = self._load_file_names(self.images_path, "npy")
            if self.quant_scales_path is not None:
                self.bias_file_names = self._load_file_names(
                    self.quant_biases_path, "npy"
                )
                self.scale_file_names = self._load_file_names(
                    self.quant_scales_path, "npy"
                )

        self.mask_file_names = None
        self.mask_files = None
        if self.masks_path is not None:
            self.mask_file_names = self._load_file_names(self.masks_path, "npy")

        # Validate before loading masks so a misconfigured dataset fails fast.
        self._check_file_lists_aligned()

        if self.masks_path is not None:
            self.mask_files = self._load_mask_files()

    def __len__(self):
        """Return the number of items (image files, or mask files if no images)."""
        if self.images_path is not None:
            return len(self.image_file_names)
        elif self.masks_path is not None:
            return len(self.mask_file_names)
        else:
            raise ValueError("Either images_path or masks_path must be provided.")

    def _load_csv(self):
        """Read the CSV, keep the rows of the selected split, index by meat_id."""
        df = pd.read_csv(self.csv_path)
        # self.split is a collection of split values (e.g. a list of integers).
        return df[df[self.split_column].isin(self.split)].set_index("meat_id")

    def _load_file_names(self, path: Path, file_extension: str):
        """Return the sorted files in ``path`` whose meat id is in the split."""
        # Membership in a pandas Index is a hash lookup; `in index.values`
        # would scan the whole ndarray once per file.
        meat_ids = self.df.index
        return [
            f
            for f in sorted(path.glob(f"*.{file_extension}"))
            if int(f.stem.split("-")[0]) in meat_ids
        ]

    def _check_file_lists_aligned(self):
        """Raise ValueError if the file lists in use differ in length."""
        file_lists = {
            "images": self.image_file_names,
            "quantization biases": self.bias_file_names,
            "quantization scales": self.scale_file_names,
            "masks": self.mask_file_names,
        }
        lengths = {
            name: len(files) for name, files in file_lists.items() if files is not None
        }
        if len(set(lengths.values())) > 1:
            raise ValueError(
                "Image, quantization and mask files are paired by index and "
                f"must have equal counts, got {lengths}."
            )

    def _load_mask_files(self):
        """Load every mask as float32, keyed by its file path."""
        # Masks are around 2 MB each, so we can load them all at once
        return {f: np.load(f).astype(np.float32) for f in self.mask_file_names}

    def _load_image(self, idx):
        """Load image ``idx``, dequantizing it if quantization files are configured."""
        img = np.load(self.image_file_names[idx])
        if self.quant_scales_path is not None:
            bias = np.load(self.bias_file_names[idx])
            scale = np.load(self.scale_file_names[idx])
            img = dequantize_16_bit(img, bias, scale)
        return tv_tensors.Image(img)

    def __getitem__(self, idx):
        """Return the item tuple described in the class docstring."""
        if self.images_path is None and self.masks_path is None:
            raise ValueError("Either images_path or masks_path must be provided.")

        items = []
        if self.images_path is not None:
            file_id = self.image_file_names[idx].stem
            items.append(self._load_image(idx))
        if self.masks_path is not None:
            mask_file = self.mask_file_names[idx]
            # When masks are used, their file stem identifies the item.
            file_id = mask_file.stem
            items.append(tv_tensors.Mask(self.mask_files[mask_file]))
        return_tuple = tuple(items)

        if self.transform is not None:
            if len(return_tuple) == 1:
                return_tuple = (self.transform(*return_tuple),)
            else:
                return_tuple = self.transform(*return_tuple)

        if self.reference_values is not None:
            meat_id = int(file_id.split("-")[0])
            refs = torch.tensor(
                self.df.loc[meat_id, self.reference_values].values,
                dtype=torch.float32,
            )
            if self.masks_path is not None:
                # Couple the (possibly transformed) mask with its reference values.
                return_tuple = return_tuple[:-1] + ((return_tuple[-1], refs),)
            else:
                return_tuple += (refs,)
        return_tuple += (file_id,)
        return return_tuple