# Uploaded by KPLabs via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 97a17c2 (verified).
import os
import rasterio
import torch
from torchgeo.datasets import NonGeoDataset
import torch.nn.functional as F
import numpy as np
class MethaneSimulatedDataset(NonGeoDataset):
    """Image-level methane classification dataset built from per-sample folders.

    Each sample lives in ``<root_dir>/<name>/`` and must contain
    ``<name>_mask.tif`` (the mask) and ``<name>_hsi.dat`` (the sCube image).
    Folders that do not exist or lack either file are skipped silently.

    Args:
        root_dir: Directory containing one sub-folder per sample.
        excel_file: Not used by this class; kept for interface compatibility.
        paths: Iterable of sample folder names, relative to ``root_dir``.
        transform: Optional albumentations-style callable, invoked as
            ``transform(image=hwc_array)`` and returning a dict with an
            ``'image'`` entry.
        num_bands: Number of leading bands read from each cube (default 12).
            Cubes with fewer bands yield all of their bands.
        image_size: Output spatial size passed to ``F.interpolate``
            (default ``(224, 224)``).

    Raises:
        ValueError: If ``num_bands`` is less than 1.
    """

    def __init__(self, root_dir, excel_file, paths, transform=None,
                 num_bands=12, image_size=(224, 224)):
        super().__init__()
        if num_bands < 1:
            raise ValueError(f"num_bands must be >= 1, got {num_bands}")
        self.root_dir = root_dir
        self.transform = transform
        self.num_bands = num_bands
        self.image_size = image_size
        # (label_path, scube_path) pairs for every folder that has both files.
        self.data_paths = []
        for folder_name in paths:
            subdir_path = os.path.join(root_dir, folder_name)
            if not os.path.isdir(subdir_path):
                continue
            # Files inside each sample folder are named after the folder itself.
            label_path = os.path.join(subdir_path, folder_name + '_mask.tif')
            scube_path = os.path.join(subdir_path, folder_name + '_hsi.dat')
            if os.path.exists(label_path) and os.path.exists(scube_path):
                self.data_paths.append((label_path, scube_path))

    def __len__(self):
        """Return the number of samples that have both a mask and a cube."""
        return len(self.data_paths)

    def __getitem__(self, idx):
        """Load, resize and optionally transform sample ``idx``.

        Returns:
            dict with
              - ``'image'``: float32 tensor (bands, H, W), where bands is at
                most ``num_bands`` and (H, W) is ``image_size`` unless the
                transform changes it;
              - ``'label'``: 0-dim int64 tensor, 1 if any mask pixel is > 0,
                else 0;
              - ``'sample'``: path of the sCube file.
        """
        label_path, scube_path = self.data_paths[idx]

        # Compute the image-level label on the full-resolution mask.
        # Downsampling the mask first (nearest neighbour) would skip source
        # pixels and could erase small plumes, turning a positive into a 0.
        with rasterio.open(label_path) as label_src:
            label_image = label_src.read()
        contains_methane = torch.tensor(
            int((label_image > 0).any()), dtype=torch.long
        )

        # Load sCube (I1/TOA data), reading only the leading bands from disk
        # rather than the full cube. rasterio band indexes are 1-based.
        with rasterio.open(scube_path) as scube_src:
            band_indexes = list(
                range(1, min(self.num_bands, scube_src.count) + 1)
            )
            scube_image = scube_src.read(band_indexes)

        # Cast in numpy first: torch.from_numpy rejects some raster dtypes
        # (e.g. uint16 on older torch releases).
        scube_tensor = torch.from_numpy(
            scube_image.astype(np.float32, copy=False)
        )
        # Replace NaNs before resizing; bilinear interpolation would otherwise
        # spread each NaN to its neighbouring output pixels.
        scube_tensor = torch.nan_to_num(scube_tensor, nan=0.0)
        scube_tensor = F.interpolate(
            scube_tensor.unsqueeze(0),
            size=self.image_size,
            mode='bilinear',
            align_corners=False,
        ).squeeze(0)

        if self.transform is not None:
            # Albumentations expects [H, W, C]
            hwc_image = np.ascontiguousarray(
                scube_tensor.permute(1, 2, 0).numpy()
            )
            transformed_image = self.transform(image=hwc_image)['image']
            if isinstance(transformed_image, torch.Tensor):
                # NOTE(review): assumes a tensor result is already
                # channel-first, as albumentations' ToTensorV2 produces — confirm.
                scube_tensor = transformed_image.float()
            else:
                # Back to [C, H, W] and to a tensor, so 'image' has the same
                # type whether or not a transform is configured.
                chw_image = np.asarray(transformed_image).transpose(2, 0, 1)
                scube_tensor = torch.from_numpy(
                    np.ascontiguousarray(chw_image)
                ).float()

        return {
            'image': scube_tensor,  # <--- Named 'image' for TorchGeo
            'label': contains_methane,  # <--- Index for CrossEntropy
            'sample': scube_path,
        }