import os import pandas as pd import torch from torch.utils.data import Dataset from pathlib import Path import rasterio as rio from PIL import Image import torchvision.transforms as transforms class MajorTOM(Dataset): """MajorTOM Dataset (https://huggingface.co/Major-TOM) Args: df ((geo)pandas.DataFrame): Metadata dataframe local_dir (string): Root directory of the local dataset version tif_bands (list): A list of tif file names to be read png_bands (list): A list of png file names to be read """ def __init__(self, df, local_dir = None, tif_bands=['B04','B03','B02'], png_bands=['thumbnail'], tif_transforms=[transforms.ToTensor()], png_transforms=[transforms.ToTensor()] ): super().__init__() self.df = df self.local_dir = Path(local_dir) if isinstance(local_dir,str) else local_dir self.tif_bands = tif_bands if not isinstance(tif_bands,str) else [tif_bands] self.png_bands = png_bands if not isinstance(png_bands,str) else [png_bands] self.tif_transforms = transforms.Compose(tif_transforms) if tif_transforms is not None else None self.png_transforms = transforms.Compose(png_transforms) if png_transforms is not None else None def __len__(self): return len(self.df) def __getitem__(self, idx): meta = self.df.iloc[idx] product_id = meta.product_id grid_cell = meta.grid_cell row = grid_cell.split('_')[0] path = self.local_dir / Path("{}/{}/{}".format(row, grid_cell, product_id)) out_dict = {'meta' : meta} for band in self.tif_bands: with rio.open(path / '{}.tif'.format(band)) as f: out = f.read() if self.tif_transforms is not None: out = self.tif_transforms(out) out_dict[band] = out for band in self.png_bands: out = Image.open(path / '{}.png'.format(band)) if self.png_transforms is not None: out = self.png_transforms(out) out_dict[band] = out return out_dict