# NOTE: the original export carried page-status lines here ("Spaces:",
# "Runtime error", "Runtime error"); preserved as a comment so the file
# parses as Python. The runtime error itself is the torchvision/albumentations
# mismatch in the __main__ demo below.
| import os | |
| import albumentations | |
| import numpy as np | |
| from PIL import Image | |
| from torch.utils.data import Dataset | |
class HRCWHU(Dataset):
    """HRCWHU cloud-detection dataset (binary: clear sky vs. cloud).

    Expected layout under ``root``::

        root/train.txt                  one image filename per line
        root/test.txt
        root/img_dir/{train,test}/<name>
        root/ann_dir/{train,test}/<name>

    The land-cover type of a sample is the filename prefix before the
    first underscore (e.g. ``water_1.tif`` -> ``water``).
    """

    METAINFO = dict(
        classes=('clear sky', 'cloud'),
        palette=((128, 192, 128), (255, 255, 255)),
        img_size=(3, 256, 256),  # C, H, W
        ann_size=(256, 256),     # H, W (single-channel mask)
        train_size=120,
        test_size=30,
    )

    def __init__(self, root, phase,
                 all_transform: 'albumentations.Compose' = None,
                 img_transform: 'albumentations.Compose' = None,
                 ann_transform: 'albumentations.Compose' = None,
                 seed: int = 42):
        """Build the file list for one split.

        Args:
            root: Dataset root directory (see class docstring for layout).
            phase: ``'train'`` selects the train split; anything else -> test.
            all_transform: Albumentations transform applied jointly to image
                and mask (called as ``t(image=..., mask=...)``).
            img_transform: Albumentations transform applied to the image only.
            ann_transform: Albumentations transform applied to the mask only.
            seed: Stored for reproducibility hooks; not used internally.
        """
        # Annotations above are string-quoted so albumentations is only
        # needed when a transform is actually passed, not at class-def time.
        self.root = root
        self.phase = phase
        self.all_transform = all_transform
        self.img_transform = img_transform
        self.ann_transform = ann_transform
        self.seed = seed
        self.data = self.load_data()

    def load_data(self):
        """Read ``<split>.txt`` and return (img_path, ann_path, lac_type) triples.

        Blank lines (e.g. a trailing newline) are skipped so they do not
        produce bogus paths.
        """
        data_list = []
        split = 'train' if self.phase == 'train' else 'test'
        split_file = os.path.join(self.root, f'{split}.txt')
        with open(split_file, 'r') as f:
            for line in f:
                image_file = line.strip()
                if not image_file:
                    continue
                img_path = os.path.join(self.root, 'img_dir', split, image_file)
                ann_path = os.path.join(self.root, 'ann_dir', split, image_file)
                # Land-cover type is encoded as the filename prefix.
                lac_type = image_file.split('_')[0]
                data_list.append((img_path, ann_path, lac_type))
        return data_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load one sample; returns a dict with image, mask and metadata."""
        img_path, ann_path, lac_type = self.data[idx]
        img = np.array(Image.open(img_path))
        ann = np.array(Image.open(ann_path))
        if self.all_transform:
            # Joint transform keeps image and mask spatially aligned.
            transformed = self.all_transform(image=img, mask=ann)
            img = transformed['image']
            ann = transformed['mask']
        if self.img_transform:
            img = self.img_transform(image=img)['image']
        if self.ann_transform:
            # BUG FIX: the original passed `img` here, so the annotation was
            # silently replaced by a transformed copy of the image.
            ann = self.ann_transform(image=ann)['image']
        return {
            'img': img,
            'ann': np.int64(ann),  # cast mask to int64 class indices
            'img_path': img_path,
            'ann_path': ann_path,
            'lac_type': lac_type,
        }
if __name__ == '__main__':
    import matplotlib.pyplot as plt
    import torch
    from albumentations.pytorch import ToTensorV2

    # FIX: the original demo passed torchvision transforms, but
    # HRCWHU.__getitem__ calls them albumentations-style
    # (``t(image=..., mask=...)``) — a guaranteed TypeError at runtime.
    # Use albumentations equivalents instead.
    all_transform = albumentations.Compose([albumentations.RandomCrop(256, 256)])
    # HWC uint8 ndarray -> CHW torch tensor (no value scaling), so the
    # img_size assertion below holds.
    img_transform = albumentations.Compose([ToTensorV2()])
    # Mask stays a 2-D ndarray; __getitem__ casts it to int64.
    ann_transform = None

    train_dataset = HRCWHU(root='data/hrcwhu', phase='train', all_transform=all_transform,
                           img_transform=img_transform, ann_transform=ann_transform)
    test_dataset = HRCWHU(root='data/hrcwhu', phase='test', all_transform=all_transform,
                          img_transform=img_transform, ann_transform=ann_transform)

    assert len(train_dataset) == train_dataset.METAINFO['train_size']
    assert len(test_dataset) == test_dataset.METAINFO['test_size']

    train_sample = train_dataset[0]
    test_sample = test_dataset[0]
    assert tuple(train_sample['img'].shape) == tuple(test_sample['img'].shape) \
        == train_dataset.METAINFO['img_size']
    assert tuple(train_sample['ann'].shape) == tuple(test_sample['ann'].shape) \
        == train_dataset.METAINFO['ann_size']

    # Visualise one sample: image next to its palette-colored annotation.
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    for train_sample in train_dataset:
        axs[0].imshow(train_sample['img'].permute(1, 2, 0))
        axs[0].set_title('Image')
        axs[1].imshow(torch.tensor(train_dataset.METAINFO['palette'])[train_sample['ann']])
        axs[1].set_title('Annotation')
        plt.suptitle(f'Land Cover Type: {train_sample["lac_type"].capitalize()}', y=0.8)
        plt.tight_layout()
        plt.savefig('HRCWHU_sample.png', bbox_inches="tight")
        # FIX: without the break every iteration overwrote the same PNG.
        break