import os

import numpy as np
import rasterio
import torch
import torch.nn.functional as F
from torchgeo.datasets import NonGeoDataset


class MethaneSimulatedDataset(NonGeoDataset):
    """Simulated methane scenes paired with a binary "contains methane" label.

    For every name in ``paths``, the directory ``<root_dir>/<name>/`` is
    included if it contains both ``<name>_mask.tif`` (plume mask) and
    ``<name>_hsi.dat`` (hyperspectral cube opened with rasterio). Directories
    missing either file are silently skipped.
    """

    def __init__(self, root_dir, excel_file, paths, transform=None, *,
                 num_bands=12, image_size=(224, 224)):
        """Collect (mask, cube) file pairs.

        Args:
            root_dir: Directory containing one sub-directory per sample.
            excel_file: Accepted for backward compatibility; not read here.
            paths: Iterable of sub-directory names under ``root_dir``.
            transform: Optional Albumentations-style callable, invoked as
                ``transform(image=hwc_array)`` and returning a dict with
                key ``'image'``.
            num_bands: Number of leading cube bands to keep (default 12).
            image_size: Output (height, width) of the cube, or a single int
                for a square output (default (224, 224)).
        """
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.num_bands = num_bands
        if isinstance(image_size, int):
            self.image_size = (image_size, image_size)
        else:
            self.image_size = tuple(image_size)
        self.data_paths = []

        for folder_name in paths:
            subdir_path = os.path.join(root_dir, folder_name)
            if not os.path.isdir(subdir_path):
                continue
            # File names are derived from the folder name.
            label_path = os.path.join(subdir_path, folder_name + '_mask.tif')
            scube_path = os.path.join(subdir_path, folder_name + '_hsi.dat')
            if os.path.exists(label_path) and os.path.exists(scube_path):
                self.data_paths.append((label_path, scube_path))

    def __len__(self):
        """Return the number of (mask, cube) pairs found."""
        return len(self.data_paths)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            dict with
              'image': float tensor (C, H, W), where C = min(num_bands, band
                  count) and (H, W) = image_size (unless the transform
                  changes the size). NaNs are replaced by 0 before resizing.
              'label': 0-dim int64 tensor, 1 if any full-resolution mask
                  pixel is > 0, else 0.
              'sample': path of the cube file.
        """
        label_path, scube_path = self.data_paths[idx]

        with rasterio.open(label_path) as label_src:
            label_image = label_src.read()

        with rasterio.open(scube_path) as scube_src:
            # Read only the kept bands instead of the whole cube.
            # rasterio band indexes are 1-based.
            band_indexes = list(range(1, min(self.num_bands, scube_src.count) + 1))
            scube_image = scube_src.read(indexes=band_indexes)

        # Decide the class on the full-resolution mask: nearest-neighbour
        # downsampling could drop small plumes and flip the label to 0.
        label_tensor = torch.from_numpy(label_image.astype(np.float32))
        contains_methane = (label_tensor > 0).any().long()

        # Cast first: torch.from_numpy cannot wrap every raster dtype
        # (e.g. uint16 on older torch builds).
        scube_tensor = torch.from_numpy(scube_image.astype(np.float32))
        # Zero NaNs *before* resizing so bilinear interpolation does not
        # spread them into neighbouring valid pixels.
        scube_tensor = torch.nan_to_num(scube_tensor, nan=0.0)
        scube_tensor = F.interpolate(
            scube_tensor.unsqueeze(0),
            size=self.image_size,
            mode='bilinear',
            align_corners=False,
        ).squeeze(0)

        if self.transform is not None:
            # Albumentations expects a channels-last (H, W, C) numpy array.
            hwc = np.ascontiguousarray(scube_tensor.permute(1, 2, 0).numpy())
            out = self.transform(image=hwc)['image']
            if isinstance(out, np.ndarray):
                # Back to (C, H, W) as a tensor, so 'image' has the same
                # type with or without a transform.
                scube_tensor = torch.from_numpy(
                    np.ascontiguousarray(out.transpose(2, 0, 1)))
            else:
                # NOTE(review): assumes tensor output is already (C, H, W),
                # as with Albumentations' ToTensorV2 — confirm.
                scube_tensor = torch.as_tensor(out)

        return {
            'image': scube_tensor,      # key name expected by TorchGeo
            'label': contains_methane,  # class index for CrossEntropy
            'sample': scube_path,
        }