# File size: 3,430 Bytes (97a17c2)
import os
import rasterio
import torch
from torchgeo.datasets import NonGeoDataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
import pandas as pd
def min_max_normalize(data, new_min=0, new_max=1):
    """Linearly rescale ``data`` into the range ``[new_min, new_max]``.

    The source range is taken from the *finite* entries only. Non-finite
    entries are replaced before scaling: NaN and -inf become the finite
    minimum, +inf becomes the finite maximum.

    Args:
        data: Array-like of numbers; converted to a float32 NumPy array.
        new_min: Lower bound of the output range.
        new_max: Upper bound of the output range.

    Returns:
        A float32 NumPy array with the same shape as ``data``. If the input
        is empty it is returned as an empty float32 array. If it has no
        finite entries, or all finite entries are equal, every element is
        ``new_min``.
    """
    data = np.array(data, dtype=np.float32)
    if data.size == 0:
        # Nothing to scale; min/max reductions would raise on an empty array.
        return data

    finite_mask = np.isfinite(data)
    if not finite_mask.any():
        # Only NaN/Inf present: there is no range to scale against.
        return np.full_like(data, new_min, dtype=np.float32)

    # Derive the range from finite values only. Using np.min/np.max on the
    # raw data would yield inf/NaN whenever such values are present, making
    # the replacement below a no-op.
    finite_values = data[finite_mask]
    old_min, old_max = finite_values.min(), finite_values.max()
    data = np.nan_to_num(data, nan=old_min, posinf=old_max, neginf=old_min)

    if old_max == old_min:  # Constant input: avoid a 0/0 style division
        return np.full_like(data, new_min, dtype=np.float32)

    # The small epsilon is kept from the original formula so results for
    # already-finite inputs are unchanged.
    scale = (new_max - new_min) / (old_max - old_min + 1e-10)
    return (data - old_min) * scale + new_min
class MethaneSimulatedDataset(NonGeoDataset):
    """Scene-level methane presence dataset built from per-folder rasters.

    For every name in ``paths`` the dataset looks for
    ``<root_dir>/<name>/<name>_mask.tif`` (label mask) and
    ``<root_dir>/<name>/<name>_hsi.dat`` (spectral cube). Only folders where
    both files exist are kept; the others are skipped silently.

    Each item is a dict with:
        * ``'S2L2A'``: the cube, cut to the first ``num_bands`` bands and
          resized to ``image_size``.
        * ``'label'``: a float one-hot vector of length 2; index 1 is set
          when any mask pixel is greater than 0.
        * ``'sample'``: the path of the cube file.
    """

    def __init__(self, root_dir, excel_file, paths, transform=None,
                 image_size=(224, 224), num_bands=12):
        """Index the sample folders that contain both required files.

        Args:
            root_dir: Directory that contains one sub-folder per sample.
            excel_file: Not used by this class; kept so existing call sites
                continue to work.
            paths: Iterable of sample folder names inside ``root_dir``.
            transform: Optional callable invoked as ``transform(image=hwc)``
                with a NumPy array in (H, W, C) layout; it must return a
                mapping whose ``'image'`` entry supports
                ``.transpose(2, 0, 1)``.
            image_size: Target spatial size passed to ``F.interpolate``.
            num_bands: Number of leading bands kept from the cube;
                ``None`` keeps all bands.
        """
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.image_size = image_size
        self.num_bands = num_bands
        self.data_paths = []

        # Collect (mask, cube) path pairs for the selected folders.
        for folder_name in paths:
            subdir_path = os.path.join(root_dir, folder_name)
            if not os.path.isdir(subdir_path):
                continue
            label_path = os.path.join(subdir_path, folder_name + '_mask.tif')
            scube_path = os.path.join(subdir_path, folder_name + '_hsi.dat')
            if os.path.exists(label_path) and os.path.exists(scube_path):
                self.data_paths.append((label_path, scube_path))

    def __len__(self):
        """Return the number of samples that have both files on disk."""
        return len(self.data_paths)

    def __getitem__(self, idx):
        """Load, resize and label the sample at position ``idx``.

        Returns:
            dict with keys ``'S2L2A'``, ``'label'`` and ``'sample'`` as
            described on the class. When a transform is configured,
            ``'S2L2A'`` is whatever ``transformed['image'].transpose(2, 0, 1)``
            yields rather than the resized torch tensor.
        """
        label_path, scube_path = self.data_paths[idx]

        # read() without band indexes returns a 3-D (bands, rows, cols) array,
        # so the mask arrives with a leading band axis.
        with rasterio.open(label_path) as label_src:
            label_image = label_src.read()
        with rasterio.open(scube_path) as scube_src:
            scube_image = scube_src.read()

        # Keep only the leading bands (band mapping is not handled here).
        scube_image = scube_image[:self.num_bands, :, :]

        # Cast in NumPy before handing over to torch: torch.from_numpy
        # rejects some raster dtypes (e.g. uint16 on many versions), while
        # the float32 values produced are the same as a later .float().
        scube_tensor = torch.from_numpy(np.asarray(scube_image, dtype=np.float32))
        label_tensor = torch.from_numpy(np.asarray(label_image, dtype=np.float32))

        # Add a batch axis for interpolate, then drop it again. Bilinear for
        # the continuous cube, nearest for the mask so label values are not
        # blended.
        scube_tensor = F.interpolate(
            scube_tensor.unsqueeze(0),
            size=self.image_size,
            mode='bilinear',
            align_corners=False,
        ).squeeze(0)
        label_tensor = F.interpolate(
            label_tensor.unsqueeze(0),
            size=self.image_size,
            mode='nearest',
        ).squeeze(0)

        label_tensor = label_tensor.clip(0, 1)  # Clip values to [0, 1]
        # NaNs are zeroed after resizing, so an output pixel whose bilinear
        # neighbourhood contained a NaN ends up as 0.
        scube_tensor = torch.nan_to_num(scube_tensor, nan=0.0)

        # Scene-level label: 1 if any mask pixel is positive, else 0,
        # then one-hot encoded over two classes.
        contains_methane = (label_tensor > 0).any().long()
        one_hot_label = F.one_hot(contains_methane, num_classes=2).float()

        # Apply transformations (if any). The transform sees (H, W, C) and
        # the result is moved back to (C, H, W).
        if self.transform:
            transformed = self.transform(image=np.array(scube_tensor.permute(1, 2, 0)))
            scube_tensor = transformed['image'].transpose(2, 0, 1)

        return {'S2L2A': scube_tensor, 'label': one_hot_label, 'sample': scube_path}
|