# Flashscape-V0 / data_utils.py
# Provenance: Hugging Face upload by Fgdfgfthgr ("Upload 4 files",
# commit c5373a4 verified, 3.85 kB). Web-page header converted to
# comments so the file is valid Python.
import os
import random
import torch
import imageio.v3 as imageio
import numpy as np
import skimage.morphology as morph
import torchvision.transforms.v2.functional as T_F
from skimage.filters import sato
from pathlib import Path
from scipy.ndimage import zoom
from torchvision.datasets.folder import has_file_allowed_extension
def make_dataset_t(image_dir, extensions=(".tif", ".tiff")):
    """Collect (image, ridge-label, basins-label) path triples from *image_dir*.

    Every file in *image_dir* whose extension (case-insensitive) is in
    *extensions* contributes one triple; the derived ``Ridge_*`` and
    ``Basins_*`` label files themselves are skipped.  Label paths are
    constructed by prefixing the image filename — they are not checked for
    existence here.
    """
    root = Path(image_dir)
    suffixes = tuple(extensions)
    triples = []
    for candidate in sorted(root.iterdir()):
        name = candidate.name
        # Same check torchvision's has_file_allowed_extension performs.
        if not name.lower().endswith(suffixes):
            continue
        if name.startswith(('Ridge_', 'Basins_')):
            continue
        triples.append((candidate, root / f'Ridge_{name}', root / f'Basins_{name}'))
    return triples
def make_dataset_t_v(image_dir, extensions=(".tif", ".tiff")):
    """Build the (image, ridge, basins) triples and split them for training.

    The triples are gathered exactly as in ``make_dataset_t``, shuffled in
    place with the module ``random`` state, and partitioned 95% / 5%.

    Returns:
        (train_split, val_split) — two lists of path triples.
    """
    root = Path(image_dir)
    suffixes = tuple(extensions)
    triples = [
        (p, root / f'Ridge_{p.name}', root / f'Basins_{p.name}')
        for p in sorted(root.iterdir())
        # Extension check matches torchvision's has_file_allowed_extension;
        # label files are excluded so they don't become "images" themselves.
        if p.name.lower().endswith(suffixes)
        and not p.name.startswith(('Ridge_', 'Basins_'))
    ]
    random.shuffle(triples)
    cut = int(0.95 * len(triples))
    return triples[:cut], triples[cut:]
def augmentations(image, label1, label2):
    """Apply the same random flips/rotation to an image and its two labels.

    Each transform decision is sampled once and applied to all three tensors
    so the labels stay spatially aligned with the image:

    - 50% chance of a vertical flip,
    - 50% chance of a horizontal flip,
    - 75% chance of a rotation by a uniformly chosen 90/180/270 degrees.

    Returns the (image, label1, label2) tuple, transformed in lockstep.
    """
    if random.random() < 0.5:
        image, label1, label2 = T_F.vflip(image), T_F.vflip(label1), T_F.vflip(label2)
    if random.random() < 0.5:
        # BUG FIX: label2 was previously passed through T_F.vflip in this
        # branch while image/label1 were hflipped, desynchronising the basins
        # label from the image whenever the horizontal flip fired.
        image, label1, label2 = T_F.hflip(image), T_F.hflip(label1), T_F.hflip(label2)
    # The angle is drawn unconditionally (before the probability gate) so the
    # module-level RNG stream is consumed identically to the original code.
    angles = [90, 180, 270]
    angle = random.choice(angles)
    if random.random() < 0.75:
        image, label1, label2 = T_F.rotate(image, angle), T_F.rotate(label1, angle), T_F.rotate(label2, angle)
    return image, label1, label2
# Dataset-wide normalisation statistics applied as (img - mean) / std in both
# TrainDataset and ValDataset below.
# NOTE(review): presumably precomputed over the raw elevation TIFFs — confirm.
mean, std = (149.95293407563648, 330.8314960521203)
# Inclusive range from which a random water-level threshold is drawn when
# binarising the basins label (basins >= level).
# NOTE(review): assumed to share units with the raw basin image values — confirm.
target_water_level_range = [-100, 300]
class TrainDataset(torch.utils.data.Dataset):
    """Training dataset producing (image, ridge, basins, water_level) samples.

    Each image is normalised with the module-level ``mean``/``std``, the
    basins map is binarised against a freshly drawn random water level, and
    random flip/rotation augmentations are applied identically to image and
    labels via ``augmentations``.
    """

    def __init__(self, train_split):
        # ``train_split``: list of (image_path, ridge_path, basins_path) triples.
        self.train_split = train_split

    def __len__(self):
        return len(self.train_split)

    def __getitem__(self, index):
        img_path, ridge_path, basins_path = self.train_split[index]
        # Add a leading channel axis to each loaded array.
        image = torch.from_numpy(imageio.imread(str(img_path)))[None, :]
        image = (image - mean) / std
        ridge = torch.from_numpy(imageio.imread(str(ridge_path)))[None, :].to(torch.float16)
        raw_basins = torch.from_numpy(imageio.imread(str(basins_path)))[None, :]
        # A new threshold per fetch acts as an extra augmentation.
        level = random.randint(*target_water_level_range)
        basins = (raw_basins >= level).to(torch.float16)
        image, ridge, basins = augmentations(image, ridge, basins)
        return image, ridge, basins, torch.tensor(level, dtype=torch.float16)
class ValDataset(torch.utils.data.Dataset):
    """Validation dataset producing (image, ridge, basins, water_level) samples.

    Identical preprocessing to the training dataset — normalisation with the
    module-level ``mean``/``std`` and random water-level binarisation of the
    basins map — but with no flip/rotation augmentation.
    """

    def __init__(self, val_split):
        # ``val_split``: list of (image_path, ridge_path, basins_path) triples.
        self.val_split = val_split

    def __len__(self):
        return len(self.val_split)

    def __getitem__(self, index):
        img_path, ridge_path, basins_path = self.val_split[index]
        # Add a leading channel axis to each loaded array.
        image = torch.from_numpy(imageio.imread(str(img_path)))[None, :]
        image = (image - mean) / std
        ridge = torch.from_numpy(imageio.imread(str(ridge_path)))[None, :].to(torch.float16)
        raw_basins = torch.from_numpy(imageio.imread(str(basins_path)))[None, :]
        # NOTE(review): the threshold is still random here, so validation
        # metrics vary between runs — presumably intentional; confirm.
        level = random.randint(*target_water_level_range)
        basins = (raw_basins >= level).to(torch.float16)
        return image, ridge, basins, torch.tensor(level, dtype=torch.float16)
if __name__ == '__main__':
    # Smoke test: build the splits from a local ``dataset`` directory and
    # fetch one sample from each dataset.
    train_split, val_split = make_dataset_t_v('dataset')
    train_dataset = TrainDataset(train_split)
    val_dataset = ValDataset(val_split)
    # Idiom fix: index the datasets instead of calling __getitem__ directly.
    print(train_dataset[0])
    print(val_dataset[0])