File size: 2,246 Bytes
eb1aec4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import os
import pandas as pd
import torch
from torch.utils.data import Dataset
from pathlib import Path
import rasterio as rio
from PIL import Image
import torchvision.transforms as transforms
class MajorTOM(Dataset):
    """MajorTOM Dataset (https://huggingface.co/Major-TOM)

    Each sample is read from the directory
    ``local_dir / <row> / <grid_cell> / <product_id>``, where ``<row>`` is the
    part of ``grid_cell`` before the first underscore.

    Args:
        df ((geo)pandas.DataFrame): Metadata dataframe; every row must expose
            ``product_id`` and ``grid_cell`` attributes.
        local_dir (string or Path): Root directory of the local dataset version.
            A string is converted to ``pathlib.Path``; other objects are used
            as-is and must support the ``/`` operator.
        tif_bands (str or sequence of str): Name(s) of ``.tif`` files to read
            from each sample directory.
        png_bands (str or sequence of str): Name(s) of ``.png`` files to read
            from each sample directory.
        tif_transforms (callable, sequence of callables, or None): Transforms
            chained with ``transforms.Compose`` and applied to each tif raster.
            They receive the raster in ``(rows, cols, bands)`` layout, which is
            the ``(H, W, C)`` layout ``transforms.ToTensor`` expects. ``None``
            disables them and the raw ``(bands, rows, cols)`` array returned by
            rasterio is stored instead.
        png_transforms (callable, sequence of callables, or None): Transforms
            chained with ``transforms.Compose`` and applied to each PIL image.
            ``None`` disables them and the loaded PIL image is stored instead.
    """

    def __init__(self,
                 df,
                 local_dir=None,
                 tif_bands=('B04', 'B03', 'B02'),
                 png_bands=('thumbnail',),
                 tif_transforms=(transforms.ToTensor(),),
                 png_transforms=(transforms.ToTensor(),),
                 ):
        super().__init__()
        self.df = df
        self.local_dir = Path(local_dir) if isinstance(local_dir, str) else local_dir
        # Copy into fresh lists so instances never alias each other's band
        # lists, the caller's sequence, or the default value.
        self.tif_bands = [tif_bands] if isinstance(tif_bands, str) else list(tif_bands)
        self.png_bands = [png_bands] if isinstance(png_bands, str) else list(png_bands)
        self.tif_transforms = self._compose(tif_transforms)
        self.png_transforms = self._compose(png_transforms)

    @staticmethod
    def _compose(tfms):
        """Build a ``transforms.Compose`` from ``tfms``; ``None`` means no transform."""
        if tfms is None:
            return None
        if callable(tfms):
            # A single transform (or an already-built Compose) rather than a list.
            tfms = [tfms]
        # list(...) copies, so Compose does not keep a reference to the caller's list.
        return transforms.Compose(list(tfms))

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Load all configured bands for the ``idx``-th metadata row.

        Returns:
            dict: ``{'meta': <metadata row>, <band name>: <data>, ...}`` with
            one entry per tif band and per png band.
        """
        meta = self.df.iloc[idx]
        product_id = meta.product_id
        grid_cell = meta.grid_cell
        row = grid_cell.split('_')[0]
        path = self.local_dir / Path("{}/{}/{}".format(row, grid_cell, product_id))

        out_dict = {'meta': meta}

        for band in self.tif_bands:
            with rio.open(path / '{}.tif'.format(band)) as f:
                out = f.read()  # rasterio layout: (bands, rows, cols)
            if self.tif_transforms is not None:
                # torchvision's ToTensor treats ndarrays as (H, W, C); handing it
                # the (bands, rows, cols) array directly would scramble the axes.
                out = self.tif_transforms(out.transpose(1, 2, 0))
            out_dict[band] = out

        for band in self.png_bands:
            with Image.open(path / '{}.png'.format(band)) as img:
                # Image.open is lazy: read the pixels now so the file handle can
                # be released when the ``with`` block exits.
                img.load()
                out = img if self.png_transforms is None else self.png_transforms(img)
            out_dict[band] = out

        return out_dict
|