EarthEmbeddingExplorer / MajorTOM /MajorTOMDataset.py
ML4RS-Anonymous's picture
Upload all files
eb1aec4 verified
import os
import pandas as pd
import torch
from torch.utils.data import Dataset
from pathlib import Path
import rasterio as rio
from PIL import Image
import torchvision.transforms as transforms
class MajorTOM(Dataset):
    """MajorTOM Dataset (https://huggingface.co/Major-TOM)

    Every sample is read from ``local_dir/<row>/<grid_cell>/<product_id>/``,
    where ``<row>`` is the part of ``grid_cell`` before the first underscore.
    Each requested band is loaded from ``<band>.tif`` or ``<band>.png`` inside
    that directory.

    Args:
        df ((geo)pandas.DataFrame): Metadata dataframe; each row must expose
            ``product_id`` and ``grid_cell``.
        local_dir (string or path-like): Root directory of the local dataset
            version. Must be set before samples are read.
        tif_bands (string or list): A tif file name, or a list of tif file
            names, to be read (without the ``.tif`` extension). ``None`` or an
            empty list reads no tif files.
        png_bands (string or list): A png file name, or a list of png file
            names, to be read (without the ``.png`` extension). ``None`` or an
            empty list reads no png files.
        tif_transforms (list or None): Transforms composed and applied to each
            tif array. ``None`` disables transforms. When omitted, defaults to
            ``[transforms.ToTensor()]``.
        png_transforms (list or None): Transforms composed and applied to each
            png image. ``None`` disables transforms. When omitted, defaults to
            ``[transforms.ToTensor()]``.
    """

    # Marks "argument omitted". A dedicated sentinel is required because None
    # is already a meaningful value for the transform arguments (no transforms).
    _DEFAULT = object()

    def __init__(self,
                 df,
                 local_dir=None,
                 tif_bands=('B04', 'B03', 'B02'),
                 png_bands=('thumbnail',),
                 tif_transforms=_DEFAULT,
                 png_transforms=_DEFAULT,
                 ):
        super().__init__()
        self.df = df
        # Accept str as well as any other path-like object.
        self.local_dir = Path(local_dir) if local_dir is not None else None
        self.tif_bands = self._as_band_list(tif_bands)
        self.png_bands = self._as_band_list(png_bands)
        # NOTE(review): rasterio's read() is documented to return band-first
        # arrays, while ToTensor treats ndarrays as H x W x C -- confirm the
        # default tif transform yields the layout downstream code expects.
        self.tif_transforms = self._compose(tif_transforms)
        self.png_transforms = self._compose(png_transforms)

    @staticmethod
    def _as_band_list(bands):
        """Return ``bands`` as a fresh list owned by the instance."""
        if bands is None:
            return []
        if isinstance(bands, str):
            return [bands]
        # Copy so that instances never share (or mutate) the caller's or the
        # default sequence.
        return list(bands)

    def _compose(self, steps):
        """Build the transform pipeline; ``None`` means no transforms."""
        if steps is None:
            return None
        if steps is self._DEFAULT:
            # Built per instance instead of once at class-definition time.
            steps = [transforms.ToTensor()]
        return transforms.Compose(steps)

    def __len__(self):
        """Return the number of rows in the metadata dataframe."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the sample at positional index ``idx``.

        Returns:
            dict: ``'meta'`` maps to the dataframe row; every requested band
            name maps to its (optionally transformed) data.

        Raises:
            TypeError: If ``local_dir`` was never provided.
        """
        meta = self.df.iloc[idx]
        product_id = meta.product_id
        grid_cell = meta.grid_cell
        row = grid_cell.split('_')[0]

        if self.local_dir is None:
            raise TypeError(
                "local_dir must be set to the dataset root directory "
                "before samples can be read"
            )
        path = self.local_dir / Path(f"{row}/{grid_cell}/{product_id}")

        out_dict = {'meta': meta}

        for band in self.tif_bands:
            with rio.open(path / f'{band}.tif') as f:
                out = f.read()
            if self.tif_transforms is not None:
                out = self.tif_transforms(out)
            out_dict[band] = out

        for band in self.png_bands:
            with Image.open(path / f'{band}.png') as img:
                # Decode eagerly so the pixel data outlives the file handle,
                # which the context manager closes on exit.
                img.load()
            out = img
            if self.png_transforms is not None:
                out = self.png_transforms(out)
            out_dict[band] = out

        return out_dict