| | import os |
| | import rasterio |
| | import torch |
| | from torchgeo.datasets import NonGeoDataset |
| | from torch.utils.data import DataLoader |
| | import torch.nn.functional as F |
| | import numpy as np |
| | import pandas as pd |
| | import json |
| |
|
class MethaneTextDataset(NonGeoDataset):
    """Dataset of sCube rasters with a scene-level methane flag and a caption.

    Each sample is a sub-folder of ``root_dir`` that contains both
    ``labelbinary.tif`` and ``sCube.tif``. Indexing returns a dict with:

    * ``'S2L2A'``: the sCube raster without its first band, resized to
      ``image_size`` (channels first, NaNs replaced by 0).
    * ``'label'``: a float one-hot vector of length 2; index 1 is set when
      the resized label mask has any pixel greater than 0.
    * ``'caption'``: the entry of ``captions`` keyed by the folder name, or
      ``'No caption'`` when the folder has no entry.
    """

    # Class-level fallback so instances created without ``__init__`` still
    # resize to the historical 224x224.
    image_size = (224, 224)

    def __init__(self, root_dir, paths, captions, transform=None,
                 image_size=(224, 224)):
        """Collect the sample folders that contain both required rasters.

        Args:
            root_dir: Directory that holds one sub-folder per sample.
            paths: Iterable of sub-folder names (relative to ``root_dir``).
            captions: Mapping from folder name to caption text.
            transform: Optional callable invoked as ``transform(image=hwc)``
                with a height x width x channels NumPy array; it must return
                a mapping whose ``'image'`` entry is an array in that same
                layout.
            image_size: ``(height, width)`` every sample is resized to.
        """
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.captions_dict = captions
        self.image_size = tuple(image_size)
        self.data_paths = []

        for folder_name in paths:
            subdir_path = os.path.join(root_dir, folder_name)
            # Names that do not resolve to a directory are skipped silently.
            if not os.path.isdir(subdir_path):
                continue

            label_path = os.path.join(subdir_path, 'labelbinary.tif')
            scube_path = os.path.join(subdir_path, 'sCube.tif')
            if os.path.exists(label_path) and os.path.exists(scube_path):
                self.data_paths.append((label_path, scube_path))
            else:
                print(f"Warning: Missing files in {subdir_path}. Expected labelbinary.tif and sCube.tif.")

    def __len__(self):
        """Return the number of usable samples found at construction time."""
        return len(self.data_paths)

    def __getitem__(self, idx):
        """Load, resize and label the sample at position ``idx``.

        Args:
            idx: Index into the list of discovered samples.

        Returns:
            Dict with keys ``'S2L2A'`` (image tensor), ``'label'`` (one-hot
            float tensor of length 2) and ``'caption'``.
        """
        label_path, scube_path = self.data_paths[idx]

        # The sample folder name is the key into the captions mapping. Use
        # os.path so this also works where the path separator is not "/".
        folder_name = os.path.basename(os.path.dirname(label_path))

        with rasterio.open(label_path) as label_src:
            label_image = label_src.read(1)

        with rasterio.open(scube_path) as scube_src:
            # Band 0 is discarded; the reason is not visible in this file.
            # NOTE(review): presumably a non-spectral layer - confirm.
            scube_image = scube_src.read()[1:, :, :]

        # Cast in NumPy first: torch.from_numpy rejects some raster dtypes
        # (e.g. uint16 on older torch releases), float32 is always accepted.
        scube_tensor = torch.from_numpy(
            scube_image.astype(np.float32, copy=False)
        )
        label_tensor = torch.from_numpy(
            label_image.astype(np.float32, copy=False)
        )

        scube_tensor = F.interpolate(
            scube_tensor.unsqueeze(0),
            size=self.image_size,
            mode='bilinear',
            align_corners=False,
        ).squeeze(0)
        # NOTE(review): NaNs are zeroed after the bilinear resize, so a NaN
        # source pixel also zeroes the output pixels it contributes to.
        scube_tensor = torch.nan_to_num(scube_tensor, nan=0.0)

        # NOTE(review): the scene flag is taken from the nearest-resized mask,
        # so positives that fall between sampled pixels are lost - confirm
        # this is intended rather than testing the full-resolution mask.
        label_tensor = F.interpolate(
            label_tensor.unsqueeze(0).unsqueeze(0),
            size=self.image_size,
            mode='nearest',
        )
        contains_methane = (label_tensor > 0).any().long()
        one_hot_label = F.one_hot(contains_methane, num_classes=2).float()

        if self.transform:
            # The transform receives and returns channels-last arrays.
            transformed = self.transform(
                image=np.array(scube_tensor.permute(1, 2, 0))
            )
            channels_first = transformed['image'].transpose(2, 0, 1)
            # Convert back to a tensor so 'S2L2A' has the same type with and
            # without a transform; ascontiguousarray guards against strided
            # or flipped views that torch.from_numpy cannot wrap.
            scube_tensor = torch.from_numpy(
                np.ascontiguousarray(channels_first)
            )

        caption = self.captions_dict.get(folder_name, 'No caption')

        return {'S2L2A': scube_tensor, 'label': one_hot_label, 'caption': caption}