# KPLabs's picture
# Upload folder using huggingface_hub
# 97a17c2 verified
import pandas as pd
import albumentations as A
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torchgeo.datamodules import NonGeoDataModule
from methane_simulated_dataset import MethaneSimulatedDataset
class MethaneSimulatedDataModule(NonGeoDataModule):
    """Data module serving train/validation splits of the simulated methane dataset.

    The module reads a fold-annotated Excel index, discards the rows that
    belong to ``test_fold``, rewrites every remaining filename so that it
    carries ``sim_tag``, and splits the result into train and validation
    path lists with ``train_test_split``.
    """

    def __init__(
        self,
        data_root: str,
        excel_file: str,
        batch_size: int = 8,
        num_workers: int = 0,
        val_split: float = 0.2,
        seed: int = 42,
        test_fold: int = 4,
        num_folds: int = 5,
        sim_tag: str = "toarefl",
        **kwargs
    ):
        """Store the configuration; no file is read until ``setup`` runs.

        Args:
            data_root: Passed to the dataset as ``root_dir``.
            excel_file: Spreadsheet read with ``pandas.read_excel``; ``setup``
                uses its ``Fold`` and ``Filename`` columns.
            batch_size: Batch size used by both dataloaders.
            num_workers: Worker count used by both dataloaders.
            val_split: Forwarded as ``test_size`` to ``train_test_split``.
            seed: Forwarded as ``random_state`` to ``train_test_split``.
            test_fold: Fold number excluded from the train/validation pool.
            num_folds: Folds are assumed to be numbered ``1..num_folds``.
            sim_tag: Tag spliced into each filename (e.g. ``'toarefl'`` or
                ``'boarefl'``).
            **kwargs: Forwarded unchanged to ``NonGeoDataModule.__init__``.
        """
        super().__init__(MethaneSimulatedDataset, batch_size, num_workers, **kwargs)
        self.data_root = data_root
        self.excel_file = excel_file
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.seed = seed
        self.test_fold = test_fold
        self.num_folds = num_folds
        self.sim_tag = sim_tag
        # Populated by setup().
        self.train_paths = []
        self.val_paths = []

    def _get_training_transforms(self):
        """Return the albumentations pipeline applied to training samples only."""
        # NOTE(review): A.Flip is reported as deprecated in newer albumentations
        # releases -- confirm against the pinned version before upgrading.
        return A.Compose([
            A.ElasticTransform(p=0.25),
            A.RandomRotate90(p=0.5),
            A.Flip(p=0.5),
            A.ShiftScaleRotate(rotate_limit=90, shift_limit_x=0.05, shift_limit_y=0.05, p=0.5),
        ])

    def _get_simulated_paths(self, paths):
        """Rewrite filenames so that they carry ``self.sim_tag``.

        A name with at least five ``_``-separated tokens becomes
        ``{tokens[0]}_{sim_tag}_{tokens[3]}_{tokens[4]}``; tokens 1 and 2 and
        anything past index 4 are dropped. Names with fewer tokens, and
        entries that are not strings, are passed through unchanged.

        Args:
            paths: Iterable of filename entries.

        Returns:
            A list with one output entry per input entry, in the same order.
        """
        simulated_paths = []
        for path in paths:
            # Best effort: an entry that cannot be tokenised (for example a
            # non-string spreadsheet cell) is kept as-is rather than dropped.
            tokens = path.split('_') if isinstance(path, str) else ()
            if len(tokens) >= 5:
                simulated_paths.append(
                    f"{tokens[0]}_{self.sim_tag}_{tokens[3]}_{tokens[4]}"
                )
            else:
                simulated_paths.append(path)
        return simulated_paths

    def setup(self, stage: str = None):
        """Compute the train/validation split and build the requested datasets.

        Args:
            stage: ``"fit"`` builds both datasets; ``"train"`` builds only the
                training dataset; ``"validate"``/``"val"`` build only the
                validation dataset. ``None`` (the default) builds both, so a
                bare ``setup()`` call leaves the module ready for either
                dataloader. Any other value only refreshes the path lists.

        Raises:
            RuntimeError: If the Excel file cannot be read.
        """
        # 1. Read Excel
        try:
            df = pd.read_excel(self.excel_file)
        except Exception as e:
            # Chain the cause so the underlying traceback is preserved.
            raise RuntimeError(f"Failed to load excel: {e}") from e

        # 2. Filter folds (exclude test_fold)
        train_pool_folds = [
            fold for fold in range(1, self.num_folds + 1) if fold != self.test_fold
        ]
        df_filtered = df[df['Fold'].isin(train_pool_folds)]
        raw_paths = df_filtered['Filename'].tolist()

        # 3. Apply path renaming logic
        paths = self._get_simulated_paths(raw_paths)

        # 4. Train/val split (deterministic for a fixed seed)
        self.train_paths, self.val_paths = train_test_split(
            paths,
            test_size=self.val_split,
            random_state=self.seed,
        )

        # 5. Instantiate datasets; stage=None means "every stage".
        if stage is None or stage in ("fit", "train"):
            self.train_dataset = MethaneSimulatedDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.train_paths,
                transform=self._get_training_transforms(),
            )
        if stage is None or stage in ("fit", "validate", "val"):
            self.val_dataset = MethaneSimulatedDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.val_paths,
                transform=None,
            )

    def train_dataloader(self):
        """Return a shuffled loader over the training dataset (last partial batch dropped)."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=True,
        )

    def val_dataloader(self):
        """Return an unshuffled loader over the validation dataset."""
        # NOTE(review): drop_last=True discards the final partial batch, so some
        # validation samples are never evaluated -- confirm this is intended.
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=True,
        )

    def on_after_batch_transfer(self, batch, dataloader_idx):
        """Run the parent hook, then wrap ``batch['image']`` as ``{'S2L2A': image}``.

        Batches without an ``'image'`` key are returned as the parent hook
        produced them.
        """
        # 1. Run the TorchGeo default hook first.
        batch = super().on_after_batch_transfer(batch, dataloader_idx)
        # 2. Wrap into the TerraMind-style modality dict {'S2L2A': ...}.
        if 'image' in batch:
            s2_data = batch['image']
            batch['image'] = {'S2L2A': s2_data}
        return batch