Fgdfgfthgr commited on
Commit
c8c6825
·
verified ·
1 Parent(s): 3b47fa1

Delete data_utils.py

Browse files
Files changed (1) hide show
  1. data_utils.py +0 -99
data_utils.py DELETED
@@ -1,99 +0,0 @@
1
- import os
2
-
3
- import random
4
- import torch
5
- import imageio.v3 as imageio
6
- import numpy as np
7
- import skimage.morphology as morph
8
- import torchvision.transforms.v2.functional as T_F
9
-
10
- from skimage.filters import sato
11
- from pathlib import Path
12
- from scipy.ndimage import zoom
13
- from torchvision.datasets.folder import has_file_allowed_extension
14
-
15
-
16
- def make_dataset_t(image_dir, extensions=(".tif", ".tiff")):
17
- image_dir = Path(image_dir)
18
- images = [
19
- (path, image_dir / f'Ridge_{path.name}', image_dir / f'Basins_{path.name}')
20
- for path in sorted(image_dir.iterdir())
21
- if (has_file_allowed_extension(path.name, extensions)
22
- and (not path.name.startswith('Ridge_')) and (not path.name.startswith('Basins_')))
23
- ]
24
- return images
25
-
26
- def make_dataset_t_v(image_dir, extensions=(".tif", ".tiff")):
27
- image_dir = Path(image_dir)
28
- # Use list comprehension for faster filtering
29
- images = [
30
- (path, image_dir / f'Ridge_{path.name}', image_dir / f'Basins_{path.name}')
31
- for path in sorted(image_dir.iterdir())
32
- if (has_file_allowed_extension(path.name, extensions)
33
- and (not path.name.startswith('Ridge_')) and (not path.name.startswith('Basins_')))
34
- ]
35
-
36
- # Shuffle in place
37
- random.shuffle(images)
38
-
39
- # Calculate split index once
40
- split_idx = int(0.95 * len(images))
41
- return images[:split_idx], images[split_idx:]
42
-
43
- def augmentations(image, label1, label2):
44
- if random.random() < 0.5:
45
- image, label1, label2 = T_F.vflip(image), T_F.vflip(label1), T_F.vflip(label2)
46
- if random.random() < 0.5:
47
- image, label1, label2 = T_F.hflip(image), T_F.hflip(label1), T_F.vflip(label2)
48
- angles = [90, 180, 270]
49
- angle = random.choice(angles)
50
- if random.random() < 0.75:
51
- image, label1, label2 = T_F.rotate(image, angle), T_F.rotate(label1, angle), T_F.rotate(label2, angle)
52
- return image, label1, label2
53
-
54
- mean, std = (149.95293407563648, 330.8314960521203)
55
- target_water_level_range = [-100, 300]
56
-
57
- class TrainDataset(torch.utils.data.Dataset):
58
- def __init__(self, train_split):
59
- self.train_split = train_split
60
-
61
- def __len__(self):
62
- return len(self.train_split)
63
-
64
- def __getitem__(self, index):
65
- pair = self.train_split[index]
66
- img = torch.from_numpy(imageio.imread(str(pair[0])))[None, :]
67
- img = (img - mean) / std
68
- ridge = torch.from_numpy(imageio.imread(str(pair[1])))[None, :].to(torch.float16)
69
- basins = torch.from_numpy(imageio.imread(str(pair[2])))[None, :]
70
- water_level = random.randint(*target_water_level_range)
71
- basins = (basins >= water_level).to(torch.float16)
72
- img, ridge, basins = augmentations(img, ridge, basins)
73
- return img, ridge, basins, torch.tensor(water_level, dtype=torch.float16)
74
-
75
- class ValDataset(torch.utils.data.Dataset):
76
- def __init__(self, val_split):
77
- self.val_split = val_split
78
-
79
- def __len__(self):
80
- return len(self.val_split)
81
-
82
- def __getitem__(self, index):
83
- pair = self.val_split[index]
84
- img = torch.from_numpy(imageio.imread(str(pair[0])))[None, :]
85
- img = (img - mean) / std
86
- ridge = torch.from_numpy(imageio.imread(str(pair[1])))[None, :].to(torch.float16)
87
- basins = torch.from_numpy(imageio.imread(str(pair[2])))[None, :]
88
- target_level = random.randint(*target_water_level_range)
89
- basins = (basins >= target_level).to(torch.float16)
90
- return img, ridge, basins, torch.tensor(target_level, dtype=torch.float16)
91
-
92
- if __name__ == '__main__':
93
- train_split, val_split = make_dataset_t_v('dataset')
94
-
95
- train_dataset = TrainDataset(train_split)
96
- val_dataset = ValDataset(val_split)
97
-
98
- print(train_dataset.__getitem__(0))
99
- print(val_dataset.__getitem__(0))