# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
import glob
import os
import random
from itertools import permutations
from os.path import isfile, join

import numpy as np
import PIL
import segyio
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision import datasets

random.seed(42)


def build_dataset(is_train, args):
    """Build an ImageFolder dataset rooted at ``args.data_path``/{train,val}.

    Args:
        is_train: selects the ``train`` subfolder when True, else ``val``.
        args: namespace providing ``data_path`` plus the transform options
            consumed by :func:`build_transform`.

    Returns:
        A ``torchvision.datasets.ImageFolder`` with the matching transform.
    """
    transform = build_transform(is_train, args)
    root = os.path.join(args.data_path, 'train' if is_train else 'val')
    dataset = datasets.ImageFolder(root, transform=transform)
    print(dataset)
    return dataset


def build_transform(is_train, args):
    """Return the timm training transform or the standard eval transform.

    Training uses timm's ``create_transform`` (RandAugment/color-jitter/
    random-erasing per ``args``); eval resizes so the center crop keeps the
    usual 224/256 crop ratio, then normalizes with ImageNet statistics.
    """
    mean = IMAGENET_DEFAULT_MEAN
    std = IMAGENET_DEFAULT_STD
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        return create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation='bicubic',
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )

    # eval transform
    crop_pct = 224 / 256 if args.input_size <= 224 else 1.0
    size = int(args.input_size / crop_pct)
    t = [
        # resize first to maintain the same ratio w.r.t. 224 images
        transforms.Resize(size, interpolation=PIL.Image.BICUBIC),
        transforms.CenterCrop(args.input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    return transforms.Compose(t)


## pretrain
class SeismicSet(data.Dataset):
    """Pre-training set: one raw float32 square patch per file under ``path``."""

    def __init__(self, path, input_size) -> None:
        super().__init__()
        self.get_file_list(path)
        self.input_size = input_size
        print(len(self.file_list))

    def __len__(self) -> int:
        return len(self.file_list)

    def __getitem__(self, index):
        d = np.fromfile(self.file_list[index], dtype=np.float32)
        d = d.reshape(1, self.input_size, self.input_size)
        # per-sample standardization; epsilon guards all-constant patches
        d = (d - d.mean()) / (d.std() + 1e-6)
        # dummy label keeps the (sample, target) convention of the loaders
        return d, torch.tensor([1])

    def get_file_list(self, path):
        """Collect every entry of ``path`` and shuffle the list in place."""
        entries = os.listdir(path)
        self.file_list = [os.path.join(path, name) for name in entries]
        # FIX: original returned random.shuffle(...), which is always None;
        # shuffle in place and return nothing, as stdlib convention dictates.
        random.shuffle(self.file_list)


def to_transforms(d, input_size):
    """Apply the MAE-style random-crop/flip augmentation pipeline to ``d``."""
    pipeline = transforms.Compose([
        transforms.RandomResizedCrop(input_size, scale=(0.2, 1.0),
                                     interpolation=3),  # 3 is bicubic
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    return pipeline(d)


### finetune
class FacesSet(data.Dataset):
    """Facies segmentation set: folder/seismic/<i>.dat with folder/label/<i>.dat.

    117 samples total; the first 100 are train, the remainder eval.
    """

    def __init__(self, folder, shape=[768, 768], is_train=True) -> None:
        super().__init__()
        # copy to avoid aliasing the (mutable) default argument across instances
        self.shape = list(shape)
        self.data_list = [folder + 'seismic/' + str(f) + '.dat' for f in range(117)]
        if is_train:
            self.data_list = self.data_list[:100]
        else:
            self.data_list = self.data_list[100:]
        self.label_list = [f.replace('/seismic/', '/label/') for f in self.data_list]

    def __getitem__(self, index):
        d = np.fromfile(self.data_list[index], np.float32)
        d = d.reshape([1] + self.shape)
        # labels on disk appear to be 1-based; shift to 0-based class ids
        l = np.fromfile(self.label_list[index], np.float32).reshape(self.shape) - 1
        l = l.astype(int)
        return torch.tensor(d), torch.tensor(l)

    def __len__(self):
        return len(self.data_list)


class SaltSet(data.Dataset):
    """Salt-body segmentation set: 4000 samples, 3500 train / 500 eval."""

    def __init__(self, folder, shape=[224, 224], is_train=True) -> None:
        super().__init__()
        self.shape = list(shape)
        self.data_list = [folder + 'seismic/' + str(f) + '.dat' for f in range(4000)]
        if is_train:
            self.data_list = self.data_list[:3500]
        else:
            self.data_list = self.data_list[3500:]
        self.label_list = [f.replace('/seismic/', '/label/') for f in self.data_list]

    def __getitem__(self, index):
        d = np.fromfile(self.data_list[index], np.float32)
        d = d.reshape([1] + self.shape)
        # labels here are already 0-based (no -1 shift, unlike FacesSet)
        l = np.fromfile(self.label_list[index], np.float32).reshape(self.shape)
        l = l.astype(int)
        return torch.tensor(d), torch.tensor(l)

    def __len__(self):
        return len(self.data_list)


class InterpolationSet(data.Dataset):
    """Self-supervised interpolation set: the sample is its own target."""

    def __init__(self, folder, shape=[224, 224], is_train=True) -> None:
        super().__init__()
        self.shape = list(shape)
        self.data_list = [folder + str(f) + '.dat' for f in range(6000)]
        if not is_train:
            # FIX: original read "folder+'U'+ + str(f)" — the stray unary '+'
            # applied to a str raised TypeError whenever is_train was False.
            self.data_list = [folder + 'U' + str(f) + '.dat'
                              for f in range(2000, 4000)]
        self.label_list = self.data_list

    def __getitem__(self, index):
        d = np.fromfile(self.data_list[index], np.float32)
        d = d.reshape([1] + self.shape)
        # autoencoding objective: input and target are the same tensor data
        return torch.tensor(d), torch.tensor(d)

    def __len__(self):
        return len(self.data_list)


class DenoiseSet(data.Dataset):
    """Denoising set: synthetic pairs for training, field data (self-paired) for eval."""

    def __init__(self, folder, shape=[224, 224], is_train=True) -> None:
        super().__init__()
        self.shape = list(shape)
        if is_train:
            self.data_list = [folder + 'seismic/' + str(f) + '.dat' for f in range(2000)]
            self.label_list = [f.replace('/seismic/', '/label/') for f in self.data_list]
        else:
            # field data has no clean reference, so the target is the input itself
            self.data_list = [folder + 'field/' + str(f) + '.dat' for f in range(4000)]
            self.label_list = self.data_list

    def __getitem__(self, index):
        d = np.fromfile(self.data_list[index], np.float32)
        d = d.reshape([1] + self.shape)
        l = np.fromfile(self.label_list[index], np.float32)
        l = l.reshape([1] + self.shape)
        return torch.tensor(d), torch.tensor(l)

    def __len__(self):
        return len(self.data_list)


class ReflectSet(data.Dataset):
    """Reflectivity-estimation set: synthetic for train, SEAM volumes for eval."""

    def __init__(self, folder, shape=[224, 224], is_train=True) -> None:
        super().__init__()
        self.shape = list(shape)
        if is_train:
            self.data_list = [folder + 'seismic/' + str(f) + '.dat' for f in range(2200)]
            self.label_list = [f.replace('/seismic/', '/label/') for f in self.data_list]
        else:
            self.data_list = [folder + 'SEAMseismic/' + str(f) + '.dat' for f in range(4000)]
            self.label_list = [f.replace('/SEAMseismic/', '/SEAMreflect/')
                               for f in self.data_list]

    def __getitem__(self, index):
        # standardize input and target independently (zero mean, unit std)
        d = np.fromfile(self.data_list[index], np.float32)
        d = d - d.mean()
        d = d / (d.std() + 1e-6)
        d = d.reshape([1] + self.shape)
        l = np.fromfile(self.label_list[index], np.float32)
        l = l - l.mean()
        l = l / (l.std() + 1e-6)
        l = l.reshape([1] + self.shape)
        return torch.tensor(d), torch.tensor(l)

    def __len__(self):
        return len(self.data_list)


class ThebeSet(data.Dataset):
    """Thebe fault dataset: full sections on disk, random 224x224 patches served.

    Files are ``folder/fault/<i>.npy`` and ``folder/seis/<i>.npy`` (1-based),
    split 75% train / 15% val / remainder test by file order.
    """

    def __init__(self, folder, shape=[224, 224], mode='train') -> None:
        super().__init__()
        self.folder = folder
        if not os.path.exists(folder):
            raise FileNotFoundError(f"The folder {folder} does not exist.")
        self.num_files = len(os.listdir(join(folder, 'fault')))
        self.shape = list(shape)
        self.fault_list = [folder + '/fault/{i}.npy'.format(i=i)
                           for i in range(1, self.num_files + 1)]
        self.seis_list = [folder + '/seis/{i}.npy'.format(i=i)
                          for i in range(1, self.num_files + 1)]
        self.train_size = int(0.75 * self.num_files)
        self.val_size = int(0.15 * self.num_files)
        self.test_size = self.num_files - self.train_size - self.val_size
        self.train_index = self.train_size
        self.val_index = self.train_index + self.val_size
        if mode == 'train':
            self.fault_list = self.fault_list[:self.train_index]
            self.seis_list = self.seis_list[:self.train_index]
        elif mode == 'val':
            self.fault_list = self.fault_list[self.train_index:self.val_index]
            self.seis_list = self.seis_list[self.train_index:self.val_index]
        elif mode == 'test':
            self.fault_list = self.fault_list[self.val_index:]
            self.seis_list = self.seis_list[self.val_index:]
        else:
            raise ValueError("Mode must be 'train', 'val', or 'test'.")

    def __len__(self):
        return len(self.fault_list)

    def retrieve_patch(self, fault, seis):
        """Crop the same random ``shape``-sized patch from both 2-D sections."""
        # sections are (probably) of size [3174, 1537] — patch must fit inside
        patch_height = self.shape[0]
        patch_width = self.shape[1]
        h, w = fault.shape
        if h < patch_height or w < patch_width:
            raise ValueError(f"Image dimensions must be at least {patch_height}x{patch_width}.")
        top = random.randint(0, h - patch_height)
        left = random.randint(0, w - patch_width)
        return (fault[top:top + patch_height, left:left + patch_width],
                seis[top:top + patch_height, left:left + patch_width])

    def random_transform(self, fault, seis):
        """Apply the same random horizontal/vertical flips to both arrays."""
        if random.random() > 0.5:
            fault = np.fliplr(fault)
            seis = np.fliplr(seis)
        if random.random() > 0.5:
            fault = np.flipud(fault)
            seis = np.flipud(seis)
        return fault, seis

    def __getitem__(self, index):
        # NOTE(review): normalization is done here on the patch; confirm this
        # matches the pre-treatment expected by the downstream model.
        fault = np.load(self.fault_list[index])
        seis = np.load(self.seis_list[index])
        fault, seis = self.retrieve_patch(fault, seis)
        fault, seis = self.random_transform(fault, seis)
        seis = (seis - seis.mean()) / (seis.std() + 1e-6)
        # .copy() is required: fliplr/flipud return negative-stride views
        fault = torch.tensor(fault.copy(), dtype=torch.float32).unsqueeze(0)
        seis = torch.tensor(seis.copy(), dtype=torch.float32).unsqueeze(0)
        return seis, fault


class FSegSet(data.Dataset):
    """Pre-cut fault-segmentation patches: folder/{train,val}/{fault,seis}/*.npy."""

    def __init__(self, folder, shape=[128, 128], mode='train') -> None:
        super().__init__()
        self.folder = folder
        if not os.path.exists(folder):
            raise FileNotFoundError(f"The folder {folder} does not exist.")
        self.shape = list(shape)
        self.mode = mode
        if mode == 'train':
            self.fault_path = join(self.folder, 'train/fault')
            self.seis_path = join(self.folder, 'train/seis')
        elif mode == 'val':
            self.fault_path = join(self.folder, 'val/fault')
            self.seis_path = join(self.folder, 'val/seis')
        else:
            raise ValueError("Mode must be 'train' or 'val'.")
        self.fault_list = [join(self.fault_path, f)
                           for f in os.listdir(self.fault_path) if f.endswith('.npy')]
        self.seis_list = [join(self.seis_path, f)
                          for f in os.listdir(self.seis_path) if f.endswith('.npy')]

    def __len__(self):
        return len(self.fault_list)

    def __getitem__(self, index):
        # patches are expected to already be 128x128 — TODO confirm on disk
        fault_img = np.load(self.fault_list[index])
        seis_img = np.load(self.seis_list[index])
        seis_img = (seis_img - seis_img.mean()) / (seis_img.std() + 1e-6)
        fault = torch.tensor(fault_img.copy(), dtype=torch.float32).unsqueeze(0)
        seis = torch.tensor(seis_img.copy(), dtype=torch.float32).unsqueeze(0)
        return seis, fault


class F3DFaciesSet(data.Dataset):
    """F3D facies volume: serves 2-D cross sections along all three axes.

    Labels are returned as a 6-channel binary mask (classes 0..5).
    """

    def __init__(self, folder, shape=[128, 128], mode='train', random_resize=False):
        super().__init__()
        self.folder = folder
        if not os.path.exists(folder):
            raise FileNotFoundError(f"The folder {folder} does not exist.")
        self.seises = np.load(join(folder, "{}/seismic.npy".format(mode)))
        self.labels = np.load(join(folder, "{}/labels.npy".format(mode)))
        self.image_shape = list(shape)
        # per-mode section sizes produced by slicing along each volume axis
        if mode == 'train':
            self.size_categories = [(401, 701), (701, 255), (401, 255)]
        elif mode == 'val':
            self.size_categories = [(601, 200), (200, 255), (601, 255)]
        elif mode == 'test':
            self.size_categories = [(701, 255), (200, 701), (200, 255)]
        else:
            raise ValueError("Mode must be 'train', 'val', or 'test'.")

    def __len__(self):
        # cross sections along each dimension: length is the sum of all dims
        return sum(self.seises.shape)

    def random_transform(self, label, seis):
        """Apply the same random horizontal/vertical flips to both arrays."""
        if random.random() > 0.5:
            label = np.fliplr(label)
            seis = np.fliplr(seis)
        if random.random() > 0.5:
            label = np.flipud(label)
            seis = np.flipud(seis)
        return label, seis

    def __getitem__(self, index):
        m1, m2, m3 = self.seises.shape
        # map the flat index onto one of the three slicing axes
        if index < m1:
            seis, label = self.seises[index, :, :], self.labels[index, :, :]
        elif index < m1 + m2:
            seis, label = self.seises[:, index - m1, :], self.labels[:, index - m1, :]
        elif index < m1 + m2 + m3:
            seis, label = (self.seises[:, :, index - m1 - m2],
                           self.labels[:, :, index - m1 - m2])
        else:
            raise IndexError("Index out of bounds")
        # argument order differs from the parameter names, but the transform is
        # symmetric (same flips applied to both), so the pairing is preserved
        seis, label = self.random_transform(seis, label)
        seis = (seis - seis.mean()) / (seis.std() + 1e-6)
        seis = torch.tensor(seis.copy(), dtype=torch.float32).unsqueeze(0)
        label = torch.tensor(label.copy(), dtype=torch.float32).unsqueeze(0)
        # expand [1, H, W] integer labels into a [6, H, W] one-hot mask.
        # NOTE(review): classes here are 0..5, while P3DFaciesSet uses 1..6 —
        # confirm the label bases really differ between the two datasets.
        label = label.squeeze(0)
        label = (label == torch.arange(6, device=label.device).view(6, 1, 1)).float()
        return seis, label


class P3DFaciesSet(data.Dataset):
    """Parihaka (P3D) facies SEG-Y volume: serves 2-D cross sections.

    Labels are returned as a 6-channel binary mask (classes 1..6).
    """

    def __init__(self, folder, shape=[128, 128], mode='train', random_resize=False):
        super().__init__()
        self.folder = folder
        if not os.path.exists(folder):
            raise FileNotFoundError(f"The folder {folder} does not exist.")
        self.random_resize = random_resize
        # Validation set will be the validation set from F3DSet
        if mode == 'val':
            mode = 'train'  # TEMPORARY SINCE P3D does not have labelled val set
        self.mode = mode
        self.image_shape = list(shape)
        self.s_path = join(folder, "{}/seismic.segy".format(mode))
        self.l_path = join(folder, "{}/labels.segy".format(mode))
        # NOTE(review): mode can no longer be 'val' after the remap above, so
        # the else branch below (F3D fallback) is currently unreachable.
        if mode != 'val':
            with segyio.open(self.s_path, ignore_geometry=True) as seis_file:
                self.seises = seis_file.trace.raw[:]
            if self.mode in ['val', 'train']:
                with segyio.open(self.l_path, ignore_geometry=True) as label_file:
                    self.labels = label_file.trace.raw[:]
            else:
                # the test files are unlabeled; use zero placeholders
                self.labels = np.zeros_like(self.seises)
        else:
            # HACK: hard-coded local path — unreachable today, but should be
            # made configurable if the 'val' remap above is ever removed.
            f3d_file_path = "C:\\Users\\abhalekar\\Desktop\\DATASETS\\F3D_facies_DATASET"
            self.seises = np.load(join(f3d_file_path, "val/seismic.npy"))
            self.labels = np.load(join(f3d_file_path, "val/labels.npy"))
        if mode == 'train':
            m1, m2, m3 = 590, 782, 1006
        elif mode == 'val':
            m1, m2, m3 = 601, 200, 255
        elif mode == 'test_1':
            m1, m2, m3 = 841, 334, 1006
        elif mode == 'test_2':
            m1, m2, m3 = 251, 782, 1006
        else:
            raise ValueError("Mode must be 'train', 'test_2', 'val', or 'test_1'.")
        self.size_categories = list(permutations([m1, m2, m3], 2))
        self.seises = self.seises.reshape(m1, m2, m3)
        self.labels = self.labels.reshape(m1, m2, m3)

    def __len__(self):
        # cross sections are taken along the first 2 dimensions ONLY
        return self.seises.shape[0] + self.seises.shape[1]

    def _random_transform(self, label, seis):
        """Apply the same random horizontal/vertical flips to both arrays."""
        if random.random() > 0.5:
            label = np.fliplr(label)
            seis = np.fliplr(seis)
        if random.random() > 0.5:
            label = np.flipud(label)
            seis = np.flipud(seis)
        # random rotation to 2D image label,seis
        #r_int = random.randint(0, 3)
        #label = np.rot90(label, r_int)
        #seis = np.rot90(seis, r_int)
        return label, seis

    def _random_resize(self, label, seis, min_size=(256, 256)):
        """Crop both arrays to the same random sub-window, at least ``min_size``."""
        r_height = random.randint(min_size[0], seis.shape[0])
        r_width = random.randint(min_size[1], seis.shape[1])
        r_pos_x = random.randint(0, seis.shape[0] - r_height)
        r_pos_y = random.randint(0, seis.shape[1] - r_width)
        label = label[r_pos_x:r_pos_x + r_height, r_pos_y:r_pos_y + r_width]
        seis = seis[r_pos_x:r_pos_x + r_height, r_pos_y:r_pos_y + r_width]
        return label, seis

    def __getitem__(self, index):
        m1, m2, m3 = self.seises.shape
        # map the flat index onto one of the slicing axes (the third axis is
        # handled here even though __len__ only exposes the first two)
        if index < m1:
            seis, label = self.seises[index, :, :], self.labels[index, :, :]
        elif index < m1 + m2:
            seis, label = self.seises[:, index - m1, :], self.labels[:, index - m1, :]
        elif index < m1 + m2 + m3:
            seis, label = (self.seises[:, :, index - m1 - m2],
                           self.labels[:, :, index - m1 - m2])
        else:
            raise IndexError("Index out of bounds")
        # argument order differs from the parameter names, but the transform is
        # symmetric (same flips applied to both), so the pairing is preserved
        seis, label = self._random_transform(seis, label)
        if self.random_resize:
            seis, label = self._random_resize(seis, label)
        seis = (seis - seis.mean()) / (seis.std() + 1e-6)
        seis = torch.tensor(seis.copy(), dtype=torch.float32).unsqueeze(0)
        label = torch.tensor(label.copy(), dtype=torch.float32).unsqueeze(0)
        # expand [1, H, W] integer labels into a [6, H, W] one-hot mask
        # (classes 1..6 here; F3DFaciesSet uses 0..5)
        label = label.squeeze(0)
        label = (label == torch.arange(1, 7, device=label.device).view(6, 1, 1)).float()
        return seis, label