# File size: 3,529 Bytes | 97a17c2
import os
import rasterio
import torch
from torchgeo.datasets import NonGeoDataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
import pandas as pd
import json
class MethaneTextDataset(NonGeoDataset):
    """Dataset of multi-band image cubes with an image-level label and a caption.

    Every sample lives in its own sub-folder of ``root_dir`` and is made of
    two raster files: ``labelbinary.tif`` (band 1 is read as the mask) and
    ``sCube.tif`` (all bands are read, the first one is discarded).  The name
    of the sub-folder is the key used to look up the sample's caption.

    Args:
        root_dir: Directory that contains the per-sample sub-folders.
        paths: Iterable of sub-folder names (relative to ``root_dir``) to
            include.  Entries that are not directories are skipped silently;
            directories missing one of the two files are skipped with a
            printed warning.
        captions: Mapping from sub-folder name to caption text.
        transform: Optional callable invoked as ``transform(image=<H, W, C
            array>)`` that returns a mapping with an ``'image'`` entry.
        image_size: Spatial size ``(height, width)`` the image cube is
            resized to.  Defaults to ``(224, 224)``.
    """

    def __init__(self, root_dir, paths, captions, transform=None, image_size=(224, 224)):
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.image_size = image_size
        self.captions_dict = captions
        # One (label_path, scube_path) tuple per usable sample folder.
        self.data_paths = []

        for folder_name in paths:
            subdir_path = os.path.join(root_dir, folder_name)
            if not os.path.isdir(subdir_path):
                continue

            label_path = os.path.join(subdir_path, 'labelbinary.tif')
            scube_path = os.path.join(subdir_path, 'sCube.tif')
            if os.path.exists(label_path) and os.path.exists(scube_path):
                self.data_paths.append((label_path, scube_path))
            else:
                print(f"Warning: Missing files in {subdir_path}. Expected labelbinary.tif and sCube.tif.")

    def __len__(self):
        """Return the number of sample folders that contain both files."""
        return len(self.data_paths)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Index into ``self.data_paths``.

        Returns:
            A dict with keys:

            * ``'S2L2A'`` -- the image cube without its first band, resized
              to ``image_size`` with NaNs replaced by 0, laid out as
              ``[C, H, W]``.
            * ``'label'`` -- float one-hot vector of length 2; index 1 is set
              when the mask contains at least one value greater than 0.
            * ``'caption'`` -- the caption registered for the sample's folder
              name, or ``'No caption'`` when there is none.
        """
        label_path, scube_path = self.data_paths[idx]
        # The caption key is the name of the folder holding the files.  Use
        # os.path so this works with any path separator and with doubled
        # separators, instead of splitting on "/".
        folder_name = os.path.basename(os.path.dirname(label_path))

        with rasterio.open(label_path) as label_src:
            label_mask = label_src.read(1)  # 2-D array [H, W]
        with rasterio.open(scube_path) as scube_src:
            scube_image = scube_src.read()  # 3-D array [bands, H, W]

        # Discard the first band and cast to float32 in NumPy, so dtypes that
        # torch.from_numpy cannot wrap directly are still accepted.
        scube_array = np.ascontiguousarray(scube_image[1:], dtype=np.float32)
        scube_tensor = torch.from_numpy(scube_array)

        # Resize first, then scrub NaNs (same order as before, so the numeric
        # result for the image is unchanged).
        scube_tensor = F.interpolate(
            scube_tensor.unsqueeze(0),
            size=self.image_size,
            mode='bilinear',
            align_corners=False,
        ).squeeze(0)
        scube_tensor = torch.nan_to_num(scube_tensor, nan=0.0)

        # Derive the image-level label from the full-resolution mask.
        # Previously the mask was nearest-neighbour downsampled before this
        # test; that keeps only a subset of the pixels, so a small positive
        # region could be skipped and the sample labelled negative.  The
        # resized mask was never returned, so nothing else depended on it.
        contains_methane = torch.tensor(bool((label_mask > 0).any()), dtype=torch.long)
        one_hot_label = F.one_hot(contains_methane, num_classes=2).float()

        if self.transform:
            # The transform receives a channels-last array [H, W, C].
            transformed = self.transform(image=np.array(scube_tensor.permute(1, 2, 0)))
            # Back to channels-first [C, H, W].
            # NOTE(review): this assumes transformed['image'] is a NumPy
            # array, so in this branch 'S2L2A' is an ndarray rather than a
            # torch.Tensor -- confirm that downstream code expects that.
            scube_tensor = transformed['image'].transpose(2, 0, 1)

        caption = self.captions_dict.get(folder_name, 'No caption')

        return {'S2L2A': scube_tensor, 'label': one_hot_label, 'caption': caption}