File size: 3,851 Bytes
c5373a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os

import random
import torch
import imageio.v3 as imageio
import numpy as np
import skimage.morphology as morph
import torchvision.transforms.v2.functional as T_F

from skimage.filters import sato
from pathlib import Path
from scipy.ndimage import zoom
from torchvision.datasets.folder import has_file_allowed_extension


def make_dataset_t(image_dir, extensions=(".tif", ".tiff")):
    """Collect (image, ridge-label, basins-label) path triples from *image_dir*.

    Every non-label TIFF file in the directory is paired with its two label
    files, which are expected to live alongside it under the naming scheme
    ``Ridge_<name>`` and ``Basins_<name>``.  Note the label paths are built
    by convention and are not checked for existence here.
    """
    root = Path(image_dir)
    triples = []
    for candidate in sorted(root.iterdir()):
        # Skip anything that is not a TIFF image.
        if not has_file_allowed_extension(candidate.name, extensions):
            continue
        # Skip the label files themselves — they are referenced, not listed.
        if candidate.name.startswith(('Ridge_', 'Basins_')):
            continue
        triples.append(
            (candidate, root / f'Ridge_{candidate.name}', root / f'Basins_{candidate.name}')
        )
    return triples

def make_dataset_t_v(image_dir, extensions=(".tif", ".tiff"), train_fraction=0.95):
    """Build a shuffled train/validation split of (image, ridge, basins) triples.

    Delegates directory scanning to :func:`make_dataset_t` (previously the
    listing logic was duplicated here verbatim), shuffles the triples in
    place, and splits them.

    Args:
        image_dir: Directory containing the images and their label files.
        extensions: Accepted image file extensions.
        train_fraction: Fraction of triples assigned to the training split
            (default 0.95, matching the previous hard-coded split).

    Returns:
        Tuple ``(train_triples, val_triples)``.

    Note:
        ``random.shuffle`` is unseeded here, so the split differs between
        runs unless the caller seeds the global RNG.
    """
    images = make_dataset_t(image_dir, extensions)

    # Shuffle in place so the split is not biased by filename order.
    random.shuffle(images)

    split_idx = int(train_fraction * len(images))
    return images[:split_idx], images[split_idx:]

def augmentations(image, label1, label2):
    """Apply the same random flips/rotation to an image and its two labels.

    Each transform must hit all three tensors identically so the labels stay
    aligned with the image: 50% vertical flip, 50% horizontal flip, then 75%
    chance of a rotation by a random multiple of 90 degrees.

    Returns the (possibly transformed) ``(image, label1, label2)`` triple.
    """
    if random.random() < 0.5:
        image, label1, label2 = T_F.vflip(image), T_F.vflip(label1), T_F.vflip(label2)
    if random.random() < 0.5:
        # BUG FIX: label2 was vertically flipped here while image/label1 were
        # horizontally flipped, desynchronizing the basins mask from the image.
        image, label1, label2 = T_F.hflip(image), T_F.hflip(label1), T_F.hflip(label2)
    angles = [90, 180, 270]
    # Angle is drawn unconditionally (before the probability gate) to keep the
    # RNG consumption pattern identical to the original implementation.
    angle = random.choice(angles)
    if random.random() < 0.75:
        image, label1, label2 = T_F.rotate(image, angle), T_F.rotate(label1, angle), T_F.rotate(label2, angle)
    return image, label1, label2

# Global image-intensity statistics used to normalize inputs in both datasets
# below ((img - mean) / std).  NOTE(review): presumably precomputed over this
# training set — confirm the values match the data actually in use.
mean, std = (149.95293407563648, 330.8314960521203)
# Inclusive [low, high] bounds for the randomly sampled water-level threshold
# applied to the basins label (passed to random.randint).
target_water_level_range = [-100, 300]

class TrainDataset(torch.utils.data.Dataset):
    """Training dataset yielding augmented (image, ridge, basins, level) samples.

    Each item loads an image and its two label files from disk, normalizes the
    image with the module-level ``mean``/``std``, thresholds the basins label
    at a freshly sampled water level, and applies random flip/rotation
    augmentations shared across image and labels.
    """

    def __init__(self, train_split):
        # List of (image_path, ridge_path, basins_path) triples.
        self.train_split = train_split

    def __len__(self):
        return len(self.train_split)

    def __getitem__(self, index):
        img_path, ridge_path, basins_path = self.train_split[index]

        # Load and normalize the input image; [None, :] adds a channel axis.
        image = torch.from_numpy(imageio.imread(str(img_path)))[None, :]
        image = (image - mean) / std

        ridge = torch.from_numpy(imageio.imread(str(ridge_path)))[None, :].to(torch.float16)

        # Binarize the basins label at a random water level each epoch.
        raw_basins = torch.from_numpy(imageio.imread(str(basins_path)))[None, :]
        water_level = random.randint(*target_water_level_range)
        basins = (raw_basins >= water_level).to(torch.float16)

        image, ridge, basins = augmentations(image, ridge, basins)
        return image, ridge, basins, torch.tensor(water_level, dtype=torch.float16)

class ValDataset(torch.utils.data.Dataset):
    """Validation dataset: like TrainDataset but without flip/rotate augmentation.

    NOTE(review): the water-level threshold is still sampled randomly per
    item, so validation metrics are not reproducible across runs unless the
    global RNG is seeded — confirm this is intentional.
    """

    def __init__(self, val_split):
        # List of (image_path, ridge_path, basins_path) triples.
        self.val_split = val_split

    def __len__(self):
        return len(self.val_split)

    def __getitem__(self, index):
        img_path, ridge_path, basins_path = self.val_split[index]

        # Load and normalize the input image; [None, :] adds a channel axis.
        image = torch.from_numpy(imageio.imread(str(img_path)))[None, :]
        image = (image - mean) / std

        ridge = torch.from_numpy(imageio.imread(str(ridge_path)))[None, :].to(torch.float16)

        # Binarize the basins label at a randomly sampled water level.
        raw_basins = torch.from_numpy(imageio.imread(str(basins_path)))[None, :]
        target_level = random.randint(*target_water_level_range)
        basins = (raw_basins >= target_level).to(torch.float16)

        return image, ridge, basins, torch.tensor(target_level, dtype=torch.float16)

if __name__ == '__main__':
    # Smoke test: build the splits from ./dataset and fetch one sample of each.
    train_split, val_split = make_dataset_t_v('dataset')

    print(TrainDataset(train_split)[0])
    print(ValDataset(val_split)[0])