TerraMind-Methane-Classification / intuition1_classification_finetuning /config /methane_simulated_dataset.py
| import os | |
| import rasterio | |
| import torch | |
| from torchgeo.datasets import NonGeoDataset | |
| import torch.nn.functional as F | |
| import numpy as np | |
class MethaneSimulatedDataset(NonGeoDataset):
    """Image-level methane classification dataset built from simulated scenes.

    Each sample lives in ``<root_dir>/<folder_name>/`` and consists of two
    rasters named after the folder:

    * ``<folder_name>_mask.tif`` -- the methane mask,
    * ``<folder_name>_hsi.dat`` -- the spectral cube (only the first
      ``NUM_BANDS`` bands are used).

    ``__getitem__`` returns a dict with the resized cube under ``'image'``,
    a 0/1 class index under ``'label'`` and the cube path under ``'sample'``.
    """

    # Number of leading spectral bands kept from the cube.
    NUM_BANDS = 12
    # Spatial size (H, W) every cube is resized to.
    IMAGE_SIZE = (224, 224)

    def __init__(self, root_dir, excel_file, paths, transform=None):
        """Index every sample folder that contains both required rasters.

        Args:
            root_dir: Directory containing one sub-folder per sample.
            excel_file: Accepted for interface compatibility; not used by
                this class.
            paths: Iterable of sample folder names (relative to ``root_dir``).
            transform: Optional callable invoked as ``transform(image=hwc)``
                that returns a mapping with an ``'image'`` entry
                (Albumentations-style).
        """
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.data_paths = []

        # Collect (mask, cube) path pairs; folders that are missing or
        # incomplete are skipped silently.
        for folder_name in paths:
            subdir_path = os.path.join(root_dir, folder_name)
            if not os.path.isdir(subdir_path):
                continue
            # File names are derived from the folder name.
            label_path = os.path.join(subdir_path, folder_name + '_mask.tif')
            scube_path = os.path.join(subdir_path, folder_name + '_hsi.dat')
            if os.path.exists(label_path) and os.path.exists(scube_path):
                self.data_paths.append((label_path, scube_path))

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.data_paths)

    def __getitem__(self, idx):
        """Load, resize and label the sample at position ``idx``.

        Returns:
            dict with keys:
                ``'image'``: float tensor of shape ``[C, H, W]`` with
                    ``C <= NUM_BANDS`` and ``(H, W) == IMAGE_SIZE`` (before
                    any user transform is applied).
                ``'label'``: 0-dim ``torch.long`` tensor, 1 if any mask pixel
                    is greater than zero, else 0.
                ``'sample'``: path of the spectral cube file.
        """
        label_path, scube_path = self.data_paths[idx]

        # Load label mask
        with rasterio.open(label_path) as label_src:
            label_image = label_src.read()

        # Load sCube (I1/TOA data)
        with rasterio.open(scube_path) as scube_src:
            scube_image = scube_src.read()

        # Keep only the leading bands
        scube_image = scube_image[:self.NUM_BANDS, :, :]

        # Derive the class index from the full-resolution mask. Resizing the
        # mask with nearest-neighbour first can drop small plumes when it is
        # downsampled, which would wrongly flip a positive sample to 0.
        contains_methane = torch.tensor(
            int((label_image > 0).any()), dtype=torch.long
        )

        # Cast in NumPy before wrapping: torch.from_numpy rejects some raster
        # dtypes (e.g. uint16 on older torch releases).
        scube_tensor = torch.from_numpy(scube_image.astype(np.float32))

        # Resize to the fixed spatial size
        scube_tensor = F.interpolate(
            scube_tensor.unsqueeze(0),
            size=self.IMAGE_SIZE,
            mode='bilinear',
            align_corners=False,
        ).squeeze(0)
        scube_tensor = torch.nan_to_num(scube_tensor, nan=0.0)

        # Apply transformations
        if self.transform:
            # Albumentations expects a contiguous [H, W, C] array
            hwc_image = scube_tensor.permute(1, 2, 0).contiguous().numpy()
            transformed_image = self.transform(image=hwc_image)['image']
            if isinstance(transformed_image, torch.Tensor):
                # NOTE(review): assumes a transform that already returns a
                # tensor (e.g. a to-tensor step) emits [C, H, W] -- confirm
                # against the configured pipeline.
                scube_tensor = transformed_image
            else:
                # Back to [C, H, W] and to a tensor, so 'image' has the same
                # type whether or not a transform is configured.
                chw_image = np.asarray(transformed_image).transpose(2, 0, 1)
                scube_tensor = torch.from_numpy(np.ascontiguousarray(chw_image))

        return {
            'image': scube_tensor,  # Named 'image' for TorchGeo
            'label': contains_methane,  # Class index for CrossEntropy
            'sample': scube_path,
        }