import json
import os
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
import rasterio
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchgeo.datasets import NonGeoDataset


class MethaneTextDataset(NonGeoDataset):
    """Image-level methane classification dataset with per-folder captions.

    Each sample lives in its own sub-folder of ``root_dir`` and consists of
    two GeoTIFFs: ``labelbinary.tif`` (band 1 is read as the mask) and
    ``sCube.tif`` (a multi-band cube whose first band is discarded).

    Args:
        root_dir: Directory that contains the sample sub-folders.
        paths: Sub-folder names (relative to ``root_dir``) to include.
        captions: Mapping from sub-folder name to caption text. Folders
            without an entry get the caption ``'No caption'``.
        transform: Optional callable invoked as ``transform(image=hwc_array)``
            that returns a mapping whose ``'image'`` entry is an ``[H, W, C]``
            NumPy array (the calling convention used by e.g. albumentations).
        image_size: ``(height, width)`` every sample is resized to.
    """

    def __init__(
        self,
        root_dir: str,
        paths: Sequence[str],
        captions: Mapping[str, str],
        transform: Optional[Callable[..., Any]] = None,
        image_size: Tuple[int, int] = (224, 224),
    ) -> None:
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.image_size = tuple(image_size)
        # One (label_path, scube_path) pair per usable sample folder.
        self.data_paths: List[Tuple[str, str]] = []
        self.captions_dict = captions

        # Collect paths for labelbinary.tif and sCube.tif in selected folders
        for folder_name in paths:
            subdir_path = os.path.join(root_dir, folder_name)
            if not os.path.isdir(subdir_path):
                # A mistyped folder name would otherwise shrink the dataset
                # without any trace.
                print(f"Warning: {subdir_path} is not a directory. Skipping.")
                continue

            label_path = os.path.join(subdir_path, 'labelbinary.tif')
            scube_path = os.path.join(subdir_path, 'sCube.tif')
            if os.path.exists(label_path) and os.path.exists(scube_path):
                self.data_paths.append((label_path, scube_path))
            else:
                print(f"Warning: Missing files in {subdir_path}. Expected labelbinary.tif and sCube.tif.")

    def __len__(self) -> int:
        """Return the number of sample folders that had both files."""
        return len(self.data_paths)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Load, resize and label the sample at position ``idx``.

        Returns:
            A dict with keys:

            * ``'S2L2A'``: float tensor ``[C, H, W]`` holding the cube without
              its first band, resized to ``image_size``, NaNs replaced by 0.
            * ``'label'``: float one-hot tensor of shape ``[2]``; index 1 is
              set when the resized mask has any value greater than 0.
            * ``'caption'``: caption for the sample's folder, or
              ``'No caption'`` when the folder has no entry.
        """
        label_path, scube_path = self.data_paths[idx]
        # The sample folder is the parent directory of the label file. Using
        # os.path keeps this correct for any path separator.
        folder_name = os.path.basename(os.path.dirname(label_path))

        # Load the label image (single band)
        with rasterio.open(label_path) as label_src:
            label_image = label_src.read(1)  # 2-D array [H, W]

        # Load the sCube image (multi-band) and drop the first band.
        # Original author's note: the cube is expected to be [13, 512, 512].
        with rasterio.open(scube_path) as scube_src:
            scube_image = scube_src.read()[1:, :, :]

        # Convert to PyTorch tensors
        scube_tensor = torch.from_numpy(scube_image).float()
        label_tensor = torch.from_numpy(label_image).float()

        # F.interpolate needs a batch dimension (and a channel dimension for
        # the 2-D mask). Only the batch dimension is squeezed away afterwards,
        # so the mask ends up as [1, H, W].
        scube_tensor = F.interpolate(
            scube_tensor.unsqueeze(0),
            size=self.image_size,
            mode='bilinear',
            align_corners=False,
        ).squeeze(0)
        # NOTE(review): nearest-neighbour downsampling can drop very small
        # positive regions before the `any()` test below. Kept as-is so that
        # existing labels do not change.
        label_tensor = F.interpolate(
            label_tensor.unsqueeze(0).unsqueeze(0),
            size=self.image_size,
            mode='nearest',
        ).squeeze(0)

        label_tensor = label_tensor.clip(0, 1)  # Clip values to [0, 1]
        # NOTE(review): NaNs are replaced after interpolation, so a NaN input
        # pixel also zeroes the output pixels it contributed to.
        scube_tensor = torch.nan_to_num(scube_tensor, nan=0.0)

        # Image-level binary target, encoded one-hot over two classes.
        contains_methane = (label_tensor > 0).any().long()
        one_hot_label = F.one_hot(contains_methane, num_classes=2).float()

        # Apply transformations (if any)
        if self.transform is not None:
            # The transform works on channel-last NumPy arrays.
            transformed = self.transform(image=np.array(scube_tensor.permute(1, 2, 0)))
            chw_image = transformed['image'].transpose(2, 0, 1)  # back to [C, H, W]
            # Convert back to a tensor so 'S2L2A' has the same type whether or
            # not a transform is configured. The transposed view is not
            # contiguous, hence the explicit copy.
            scube_tensor = torch.from_numpy(np.ascontiguousarray(chw_image))

        caption = self.captions_dict.get(folder_name, 'No caption')

        return {'S2L2A': scube_tensor, 'label': one_hot_label, 'caption': caption}