Spaces:
Build error
Build error
File size: 1,243 Bytes
bd88c34 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 | import numpy as np
import config
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
class MapDataset(Dataset):
    """Dataset that yields an ``(input, target)`` pair for each file in a directory.

    Every entry of ``root_dir`` is opened as an image with PIL. The same
    decoded array is used as both the input and the target; the two copies
    are then passed through the transforms defined in ``config``.
    """

    def __init__(self, root_dir: str) -> None:
        """Remember ``root_dir`` and index its entries.

        Args:
            root_dir: Directory whose entries are loaded as images. Every
                entry is listed without filtering, so the directory should
                contain image files only.
        """
        self.root_dir = root_dir
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and machines.
        self.list_files = sorted(os.listdir(self.root_dir))

    def __len__(self) -> int:
        """Return the number of entries found in ``root_dir``."""
        return len(self.list_files)

    def __getitem__(self, index: int):
        """Load the ``index``-th file and return ``(input_image, target_image)``.

        Args:
            index: Position in ``self.list_files``.

        Returns:
            A 2-tuple of whatever ``config.transform_only_input`` and
            ``config.transform_only_mask`` produce under their ``"image"``
            key (presumably tensors -- confirm against ``config``).
        """
        img_path = os.path.join(self.root_dir, self.list_files[index])

        # The context manager releases the file handle once the pixel data
        # has been copied into the array.
        with Image.open(img_path) as img:
            image = np.array(img)

        # Input and target both start from the same decoded image.
        input_image = image
        target_image = image

        # Read each result back under the key it was passed in with, so the
        # input/target pairing stays correct if the two ever differ.
        augmentations = config.both_transform(image=input_image, image0=target_image)
        input_image = augmentations["image"]
        target_image = augmentations["image0"]

        input_image = config.transform_only_input(image=input_image)["image"]
        target_image = config.transform_only_mask(image=target_image)["image"]
        return input_image, target_image
if __name__ == "__main__":
    # Smoke test: pull a single batch from the training folder, report the
    # input batch shape, dump both halves to disk, then stop the process.
    train_set = MapDataset("data/train/")
    batches = DataLoader(train_set, batch_size=5)

    first_batch = next(iter(batches), None)
    if first_batch is not None:
        inputs, targets = first_batch
        print(inputs.shape)
        save_image(inputs, "x.png")
        save_image(targets, "y.png")

        import sys

        sys.exit()