Spaces:
Running
Running
| import os | |
| import torch | |
| from torch.utils.data import Dataset | |
| from src.data.transforms import augment_triplet | |
class SatelliteTripletDataset(Dataset):
    """PyTorch Dataset for loading pre-processed Satellite TIR triplets.

    Expects ``data_dir`` to contain ``.pt`` files, each holding a tensor of
    shape ``[3, 1, H, W]`` (or a list/tuple of three frames) representing the
    Past, Present, and Future Brightness Temperature frames.

    Args:
        data_dir: Directory scanned (non-recursively) for ``.pt`` files.
        augment: If True, each sample is passed through ``augment_triplet``
            before being returned.
    """

    def __init__(self, data_dir: str, augment: bool = True):
        self.data_dir = data_dir
        self.augment = augment
        # Sorted so the index -> file mapping is deterministic across runs.
        # Directories whose name happens to end in ".pt" are skipped, since
        # torch.load would fail on them later.
        self.triplet_files = sorted(
            name
            for name in os.listdir(data_dir)
            if name.endswith(".pt")
            and os.path.isfile(os.path.join(data_dir, name))
        )

    def __len__(self) -> int:
        """Return the number of ``.pt`` triplet files found in ``data_dir``."""
        return len(self.triplet_files)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Load sample ``idx`` and return ``(img0, img1, gt)``.

        Note the order differs from the storage order: the two input frames
        (Past, Future) come first and the middle frame (Present, the ground
        truth) comes last.

        Raises:
            TypeError: If the file does not contain a tensor, list, or tuple.
            ValueError: If the stored object provides fewer than three frames.
        """
        file_path = os.path.join(self.data_dir, self.triplet_files[idx])
        # map_location="cpu": always materialize on CPU regardless of the
        # device the tensors were saved from; this is what DataLoader worker
        # processes expect, and device transfer belongs in the training loop.
        # weights_only=True: restrict unpickling to tensors/containers so a
        # tampered .pt file cannot execute arbitrary code.
        triplet = torch.load(file_path, map_location="cpu", weights_only=True)
        self._check_frames(triplet, file_path)

        img0 = triplet[0]  # Past
        gt = triplet[1]  # Ground Truth / Present
        img1 = triplet[2]  # Future

        if self.augment:
            # augment_triplet receives and returns frames in (img0, img1, gt)
            # order, so the same transform is applied to all three together.
            img0, img1, gt = augment_triplet(img0, img1, gt)
        return img0, img1, gt

    @staticmethod
    def _check_frames(triplet, file_path: str) -> None:
        """Raise a descriptive error if ``triplet`` cannot supply three frames."""
        if isinstance(triplet, torch.Tensor):
            # A 0-d tensor has no leading frame axis at all.
            num_frames = triplet.shape[0] if triplet.dim() > 0 else 0
        elif isinstance(triplet, (list, tuple)):
            num_frames = len(triplet)
        else:
            raise TypeError(
                f"Expected a tensor of shape [3, 1, H, W] in {file_path!r}, "
                f"got {type(triplet).__name__}"
            )
        if num_frames < 3:
            raise ValueError(
                f"Expected 3 frames (past, present, future) in {file_path!r}, "
                f"got {num_frames}"
            )