# File size: 2,727 Bytes (revision 97a17c2)
import os
import rasterio
import torch
from torchgeo.datasets import NonGeoDataset
import torch.nn.functional as F
import numpy as np
class MethaneSimulatedDataset(NonGeoDataset):
    """Image-level methane-presence dataset built from mask/cube file pairs.

    Every entry of ``paths`` names a sub-folder ``<root_dir>/<name>`` that is
    expected to contain ``<name>_mask.tif`` (label raster) and
    ``<name>_hsi.dat`` (spectral cube, "sCube").  Folders that do not exist or
    that miss either file are skipped without an error.

    Each sample is a dict with:

    * ``'image'``  -- the first ``num_bands`` bands of the cube, resized to
      ``image_size`` with bilinear interpolation, NaNs replaced by 0.
    * ``'label'``  -- 0-dim ``int64`` tensor: 1 if any pixel of the resized
      mask is > 0, else 0.
    * ``'sample'`` -- path of the cube file the image was read from.
    """

    def __init__(self, root_dir, excel_file, paths, transform=None, *,
                 num_bands=12, image_size=(224, 224)):
        """Index the (mask, cube) file pairs available under ``root_dir``.

        Args:
            root_dir: Directory containing one sub-folder per sample.
            excel_file: Not read by this class; kept so existing callers that
                pass it positionally keep working.
            paths: Iterable of sub-folder names to look up inside ``root_dir``.
            transform: Optional callable invoked as ``transform(image=hwc)``
                with a channels-last NumPy array; it must return a mapping
                with an ``'image'`` entry.
            num_bands: Number of leading cube bands to keep (default 12).
            image_size: ``(height, width)`` every sample is resized to
                (default ``(224, 224)``).
        """
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.num_bands = num_bands
        self.image_size = tuple(image_size)
        self.data_paths = []

        for folder_name in paths:
            subdir_path = os.path.join(root_dir, folder_name)
            if not os.path.isdir(subdir_path):
                continue
            # File names are derived from the folder name itself.
            label_path = os.path.join(subdir_path, folder_name + '_mask.tif')
            scube_path = os.path.join(subdir_path, folder_name + '_hsi.dat')
            if os.path.exists(label_path) and os.path.exists(scube_path):
                self.data_paths.append((label_path, scube_path))

    def __len__(self) -> int:
        """Return the number of (mask, cube) pairs found at construction."""
        return len(self.data_paths)

    @staticmethod
    def _open_scube(scube_path):
        """Open the cube with the ENVI driver, else let GDAL auto-detect."""
        # Local import keeps this block self-contained; rasterio is already a
        # dependency of the file.
        from rasterio.errors import RasterioIOError

        try:
            return rasterio.open(scube_path, driver='ENVI')
        except RasterioIOError:
            # Only an open failure triggers the fallback; genuine read errors
            # are no longer masked by a blanket ``except Exception``.
            return rasterio.open(scube_path)

    def _read_scube(self, scube_path):
        """Read the first ``num_bands`` bands of the cube as a NumPy array."""
        with self._open_scube(scube_path) as scube_src:
            band_count = min(self.num_bands, scube_src.count)
            # rasterio band indexes are 1-based; reading only the wanted bands
            # avoids loading the whole cube just to slice it afterwards.
            return scube_src.read(indexes=list(range(1, band_count + 1)))

    @staticmethod
    def _to_float_tensor(array):
        """Convert a raster array to a float32 tensor."""
        # Cast in NumPy first: ``torch.from_numpy`` rejects some raster dtypes
        # (e.g. uint16 on older torch releases).
        return torch.from_numpy(np.ascontiguousarray(array, dtype=np.float32))

    def __getitem__(self, idx):
        """Load, resize and label the sample at position ``idx``.

        Returns:
            dict with keys ``'image'``, ``'label'`` and ``'sample'`` as
            described in the class docstring.

        Raises:
            IndexError: If ``idx`` is outside ``self.data_paths``.
        """
        label_path, scube_path = self.data_paths[idx]

        with rasterio.open(label_path) as label_src:
            label_tensor = self._to_float_tensor(label_src.read())
        scube_tensor = self._to_float_tensor(self._read_scube(scube_path))

        # F.interpolate expects a batch dimension, hence unsqueeze/squeeze.
        scube_tensor = F.interpolate(
            scube_tensor.unsqueeze(0),
            size=self.image_size,
            mode='bilinear',
            align_corners=False,
        ).squeeze(0)
        # NOTE(review): the label is derived from the *resized* mask, as
        # before; nearest-neighbour downsampling can drop very small positive
        # regions -- confirm this is intended.
        label_tensor = F.interpolate(
            label_tensor.unsqueeze(0),
            size=self.image_size,
            mode='nearest',
        ).squeeze(0)

        # NaNs are zeroed after resizing, matching the previous behaviour.
        scube_tensor = torch.nan_to_num(scube_tensor, nan=0.0)

        # Image-level class index (0 or 1) for cross-entropy style losses.
        contains_methane = (label_tensor > 0).any().long()

        if self.transform is not None:
            # The transform receives a channels-last (H, W, C) NumPy array.
            hwc_image = scube_tensor.permute(1, 2, 0).contiguous().numpy()
            augmented = self.transform(image=hwc_image)['image']
            if isinstance(augmented, torch.Tensor):
                # NOTE(review): tensor outputs are assumed to be channels-first
                # already (e.g. from a to-tensor transform) -- confirm.  The
                # old ``.transpose(2, 0, 1)`` call raised TypeError here.
                scube_tensor = augmented
            else:
                # Back to channels-first, and back to a tensor so 'image' has
                # the same type with or without a transform.
                chw_image = np.asarray(augmented).transpose(2, 0, 1)
                scube_tensor = torch.from_numpy(np.ascontiguousarray(chw_image))

        return {
            'image': scube_tensor,      # 'image' key expected by TorchGeo
            'label': contains_methane,  # class index for CE loss
            'sample': scube_path,
        }