"""
Contains an implementation of a dataset for preprocessed images. It expects the images
to be stored as quantized uint16 NumPy arrays and can optionally load masks and
reference values as well.

Author: Ole-Christian Galbo Engstrøm
E-mail: ocge@foss.dk
"""

from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from torchvision import tv_tensors

from .utils import dequantize_16_bit
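
# NOTE (assumption, not taken from .utils): `dequantize_16_bit` is expected to
# invert a linear 16-bit quantization of the stored images, i.e. something of
# the form
#     original ≈ stored_uint16.astype(np.float32) * scale + bias
# with one bias/scale pair per image. The exact formula lives in .utils.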


class PreprocessedImageDataset(Dataset):
    """
    This is the dataset class that we used to load the preprocessed images
    (already cropped or padded to the input size of 2360 x 1272 pixels)
    """
    def __init__(
        self,
        reference_values,
        images_path,
        quant_biases_path,
        quant_scales_path,
        masks_path,
        csv_path,
        split,
        split_column,
        transform,
    ):
        self.reference_values = reference_values
        self.csv_path = Path(csv_path)
        self.split = split
        self.split_column = split_column
        if images_path is not None:
            self.images_path = Path(images_path)
            # Path(None) raises TypeError, so a missing bias/scale path simply
            # disables dequantization in __getitem__.
            try:
                self.quant_scales_path = Path(quant_scales_path)
                self.quant_biases_path = Path(quant_biases_path)
            except TypeError:
                self.quant_scales_path = None
                self.quant_biases_path = None
        else:
            self.images_path = None
        if masks_path is not None:
            self.masks_path = Path(masks_path)
        else:
            self.masks_path = None
        self.transform = transform

        self.df = self._load_csv()
        if self.images_path is not None:
            self.image_file_names = self._load_file_names(self.images_path, "npy")
            if self.quant_scales_path is not None:
                self.bias_file_names = self._load_file_names(
                    self.quant_biases_path, "npy"
                )
                self.scale_file_names = self._load_file_names(
                    self.quant_scales_path, "npy"
                )
            else:
                self.bias_file_names = None
                self.scale_file_names = None
        if self.masks_path is not None:
            self.mask_file_names = self._load_file_names(self.masks_path, "npy")
            self.mask_files = self._load_mask_files()

    def __len__(self):
        if self.images_path is not None:
            return len(self.image_file_names)
        elif self.masks_path is not None:
            return len(self.mask_file_names)
        else:
            raise ValueError("Either images_path or masks_path must be provided.")

    def _load_csv(self):
        df = pd.read_csv(self.csv_path)
        # Use the split column to filter the dataframe. self.split is a list of integers
        df = df[df[self.split_column].isin(self.split)]
        df.set_index("meat_id", inplace=True)
        return df
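
    # A hypothetical example of the expected CSV layout (the column names other
    # than "meat_id" are placeholders; the split column and reference columns
    # are selected via `split_column` and `reference_values`):
    #
    #   meat_id,split,fat,moisture
    #   101,0,12.3,61.2
    #   102,1,10.8,63.0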

    def _load_file_names(self, path: Path, file_extension: str):
        # Files are sorted by name so that the image, bias, scale, and mask
        # lists stay aligned when indexed with the same position in __getitem__.
        file_names = []
        for f in sorted(path.glob(f"*.{file_extension}")):
            meat_id = int(f.stem.split("-")[0])
            if meat_id not in self.df.index.values:
                continue
            file_names.append(f)
        return file_names

    def _load_mask_files(self):
        # Masks are around 2 MB each, so we can load them all at once
        mask_files = {}
        for f in self.mask_file_names:
            key = f
            mask = np.load(f).astype(np.float32)
            mask_files[key] = mask
        return mask_files

    def __getitem__(self, idx):
        if self.images_path is not None:
            file_id = self.image_file_names[idx].stem
            meat_id = int(file_id.split("-")[0])
            img = np.load(self.image_file_names[idx])
            if self.quant_scales_path is not None:
                bias = np.load(self.bias_file_names[idx])
                scale = np.load(self.scale_file_names[idx])
                # Undo the 16-bit quantization using the per-image bias and scale.
                img = dequantize_16_bit(img, bias, scale)
            img = tv_tensors.Image(img)
        else:
            img = None

        if self.masks_path is not None:
            file_id = self.mask_file_names[idx].stem
            meat_id = int(file_id.split("-")[0])
            mask = self.mask_files[self.mask_file_names[idx]]
            mask = tv_tensors.Mask(mask)
        else:
            mask = None
        return_tuple = ()
        if img is not None:
            return_tuple += (img,)
        if mask is not None:
            return_tuple += (mask,)
        if self.transform is not None:
            # A transform called with a single input returns a single object
            # rather than a tuple (torchvision v2 behaviour), so the
            # single-element case is re-wrapped to keep the tuple structure.
            if len(return_tuple) == 1:
                return_tuple = (self.transform(*return_tuple),)
            else:
                return_tuple = self.transform(*return_tuple)

        if self.reference_values is not None:
            refs = self.df.loc[meat_id, self.reference_values].values
            refs = torch.tensor(refs, dtype=torch.float32)
            if self.masks_path is not None:
                # Extract mask from the return tuple
                mask = return_tuple[-1]
                # Couple the reference values with the mask
                mask_refs_tuple = ((mask, refs),)
                return_tuple = return_tuple[:-1] + mask_refs_tuple
            else:
                return_tuple += (refs,)
        return_tuple += (file_id,)
        return return_tuple
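

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original pipeline). All paths,
    # column names, and reference value names below are placeholders and must
    # be adapted to the actual data layout. Because this module uses a relative
    # import, the sketch assumes it is executed within its package, e.g. via
    # ``python -m <package>.<module>``.
    from torch.utils.data import DataLoader

    dataset = PreprocessedImageDataset(
        reference_values=["fat"],                # placeholder reference column(s)
        images_path="data/images",               # directory of *.npy uint16 images
        quant_biases_path="data/quant_biases",   # per-image quantization biases
        quant_scales_path="data/quant_scales",   # per-image quantization scales
        masks_path="data/masks",                 # directory of *.npy masks
        csv_path="data/references.csv",          # CSV with a "meat_id" column
        split=[0],                               # split indices to load
        split_column="split",                    # placeholder split column name
        transform=None,
    )
    loader = DataLoader(dataset, batch_size=1, shuffle=False)

    # With images, masks, and reference values all provided, each item is
    # (image, (mask, reference_values), file_id).
    for img, (mask, refs), file_id in loader:
        print(file_id, img.shape, mask.shape, refs)
        break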