master_thesis_models / src /datamodules /focus_datamodule.py
Hannes Kuchelmeister
re-add code modifications
91867af
import os
from typing import List, Optional, Tuple
import pandas as pd
from skimage import io
import numpy as np
import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.transforms import transforms
class FocusDataSet(Dataset):
    """Dataset for z-stacked images of neglected tropical diseases.

    Every row of the metadata CSV describes one image: the ``image_path``
    column holds the image location relative to ``root_dir`` and the
    ``focus_height`` column holds the regression target.
    """

    def __init__(
        self, csv_file, root_dir, transform=None, in_memory=True, additional_col_list=None
    ):
        """Initialize focus stack dataset.

        Args:
            csv_file (string): Path to the csv file with annotations. It must
                contain the columns "image_path" and "focus_height".
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on the image of a sample.
            in_memory (bool): If True, every image listed in the csv file is
                read during construction and kept in ``self.images``;
                otherwise images are read on each access.
            additional_col_list (list of str, optional): Names of extra
                metadata columns whose values are copied into every sample
                under the column name. ``None`` means no extra columns.

        Raises:
            KeyError: If a required or requested column is not in the csv file.
        """
        self.metadata = pd.read_csv(csv_file)
        self.in_memory = in_memory
        self.root_dir = root_dir
        self.transform = transform

        # Resolve column names to positions once so that __getitem__ can use
        # purely positional ``iloc`` lookups.
        extra_columns = [] if additional_col_list is None else list(additional_col_list)
        self.additional_col_index = {
            name: self.metadata.columns.get_loc(name) for name in extra_columns
        }
        self.col_index_path = self.metadata.columns.get_loc("image_path")
        self.col_index_focus = self.metadata.columns.get_loc("focus_height")

        self.images = []
        if self.in_memory:
            # NOTE(review): stacking into one array presumably relies on all
            # images having the same shape -- confirm against the data.
            image_paths = self.metadata["image_path"].tolist()
            self.images = np.array([self._load_img(path) for path in image_paths])

    def _load_img(self, img_path):
        """Read the image stored at ``img_path`` relative to ``root_dir``."""
        path = os.path.join(self.root_dir, img_path)
        return io.imread(path)

    def __len__(self) -> int:
        """Get the length of the dataset.

        Returns:
            int: the number of rows in the metadata csv file
        """
        return len(self.metadata)

    def __getitem__(self, idx):
        """Get one item from the dataset.

        Args:
            idx (int): The index of the sample that is to be retrieved.
                Tensor indices are converted with ``tolist()`` first.

        Returns:
            dict: contains "image" (transformed if a transform was given),
            "focus_height" (float tensor) and one entry per additional column.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        if self.in_memory:
            image = self.images[idx]
        else:
            image = self._load_img(self.metadata.iloc[idx, self.col_index_path])

        # Compare against None explicitly: a transform object that defines
        # __len__ and happens to be empty must still be called.
        if self.transform is not None:
            image = self.transform(image)

        focus_height = torch.from_numpy(
            np.asarray(self.metadata.iloc[idx, self.col_index_focus])
        ).float()

        sample = {"image": image, "focus_height": focus_height}
        for attr, col_idx in self.additional_col_index.items():
            sample[attr] = self.metadata.iloc[idx, col_idx]
        return sample
class FocusDataModule(LightningDataModule):
    """LightningDataModule for the FocusStack dataset.

    Builds one ``FocusDataSet`` per split (train / validation / test) from
    three metadata csv files and exposes a ``DataLoader`` for each of them.
    Only the training split receives the optional augmentation transforms;
    validation and test always use the plain tensor conversion.
    """

    def __init__(
        self,
        data_dir: str = "data/",
        csv_train_file: str = "data/train_metadata.csv",
        csv_val_file: str = "data/validation_metadata.csv",
        csv_test_file: str = "data/test_metadata.csv",
        batch_size: int = 64,
        num_workers: int = 0,
        pin_memory: bool = False,
        in_memory: bool = True,
        augmentation: bool = False,
        # The list default is kept for signature compatibility (it is what
        # `save_hyperparameters` records); it is copied below, never mutated.
        additional_col_list: List[str] = [],
    ):
        """Store the configuration and build the image transforms.

        Args:
            data_dir: Root directory the image paths in the csv files are
                relative to.
            csv_train_file: Metadata csv file of the training split.
            csv_val_file: Metadata csv file of the validation split.
            csv_test_file: Metadata csv file of the test split.
            batch_size: Batch size used by all three dataloaders.
            num_workers: Number of worker processes per dataloader.
            pin_memory: Forwarded to every ``DataLoader``.
            in_memory: Forwarded to every ``FocusDataSet``.
            augmentation: If True, random flips and a random multiple-of-90
                degree rotation are appended to the training transforms.
            additional_col_list: Extra metadata columns that each dataset
                adds to its samples.
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        self.save_hyperparameters(logger=False)

        # Transforms shared by every split: array -> float tensor.
        base_steps = [
            transforms.ToTensor(),
            transforms.ConvertImageDtype(torch.float),
        ]
        self.base_transforms = transforms.Compose(list(base_steps))

        # Training transforms: the base steps plus optional augmentation.
        train_steps = list(base_steps)
        if augmentation:
            train_steps.extend(self._augmentation_steps())
        self.transforms = transforms.Compose(train_steps)

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

        self.in_memory = in_memory
        # Private copy, so later changes to the caller's list (or to the
        # shared default) cannot leak into this module; None means "no columns".
        self.additional_col_list = (
            [] if additional_col_list is None else list(additional_col_list)
        )

    @staticmethod
    def _augmentation_steps():
        """Return the augmentation transforms appended for training."""
        # Exactly one of the three rotation candidates is picked per sample,
        # and the picked rotation itself is only applied with p=0.5.
        rotation_candidates = [
            transforms.RandomApply([transforms.RandomRotation((angle, angle))], p=0.5)
            for angle in (90, 180, 270)
        ]
        return [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomChoice(rotation_candidates),
        ]

    def _build_dataset(self, csv_file, transform):
        """Create a FocusDataSet for one split with the shared settings."""
        return FocusDataSet(
            csv_file,
            self.hparams.data_dir,
            transform=transform,
            in_memory=self.in_memory,
            additional_col_list=self.additional_col_list,
        )

    def _build_dataloader(self, dataset, shuffle: bool):
        """Create a DataLoader for ``dataset`` with the shared settings."""
        return DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=shuffle,
        )

    def prepare_data(self):
        """This method is not implemented as of yet.

        Download data if needed. This method is called only from a single GPU.
        Do not use it to assign state (self.x = y).
        """
        pass

    def setup(self, stage: Optional[str] = None):
        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.

        This method is called by lightning twice for `trainer.fit()` and
        `trainer.test()`, so the datasets are only created on the first call.
        The `stage` argument is accepted for the Lightning interface but is
        not used: all three splits are always loaded together.
        """
        # Compare against None instead of relying on truthiness: `not dataset`
        # calls `__len__`, so an empty dataset would look "not loaded" and be
        # rebuilt (and, with in_memory=True, re-read from disk) on every call.
        already_loaded = not (
            self.data_train is None
            and self.data_val is None
            and self.data_test is None
        )
        if already_loaded:
            return

        self.data_train = self._build_dataset(
            self.hparams.csv_train_file, self.transforms
        )
        self.data_val = self._build_dataset(
            self.hparams.csv_val_file, self.base_transforms
        )
        self.data_test = self._build_dataset(
            self.hparams.csv_test_file, self.base_transforms
        )

    def train_dataloader(self):
        """Return the shuffled dataloader over the training split."""
        return self._build_dataloader(self.data_train, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled dataloader over the validation split."""
        return self._build_dataloader(self.data_val, shuffle=False)

    def test_dataloader(self):
        """Return the unshuffled dataloader over the test split."""
        return self._build_dataloader(self.data_test, shuffle=False)