Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| from copy import deepcopy | |
| from pathlib import Path | |
| from typing import Any, Dict, List | |
| # from logger import logger | |
| import numpy as np | |
| # import torch | |
| # import torch.utils.data as torchdata | |
| # import torchvision.transforms as tvf | |
| from omegaconf import DictConfig, OmegaConf | |
| import pytorch_lightning as pl | |
| from dataset.UAV.dataset import UavMapPair | |
| # from torch.utils.data import Dataset, DataLoader | |
| # from torchvision import transforms | |
| from torch.utils.data import Dataset, ConcatDataset | |
| from torch.utils.data import Dataset, DataLoader, random_split | |
| import torchvision.transforms as tvf | |
# Custom data module class, inheriting from pl.LightningDataModule.
class UavMapDatasetModule(pl.LightningDataModule):
    """LightningDataModule serving ``UavMapPair`` samples for train/val/test.

    Each split is a ``ConcatDataset`` holding one ``UavMapPair`` per city
    listed in the config (``train_citys``, ``val_citys``, ``test_citys``),
    all rooted at ``cfg.root``. The datasets are built eagerly when the module
    is constructed (``__init__`` calls ``init``).

    Config keys read here (accessed as attributes, so ``cfg`` is presumably an
    OmegaConf ``DictConfig`` despite the ``Dict`` annotation -- TODO confirm):
    ``root``, ``image_size``, ``augmentation.image.apply`` (and, when enabled,
    ``augmentation.image.{brightness,contrast,saturation,hue}``),
    ``train_citys``, ``val_citys``, ``test_citys``,
    ``train.{batch_size,num_workers}`` and ``val.{batch_size,num_workers}``.
    """

    def __init__(self, cfg: Dict[str, Any]):
        super().__init__()
        self.cfg = cfg

        # Deterministic preprocessing shared by every split.
        base_tfs = [tvf.ToTensor(), tvf.Resize(self.cfg.image_size)]

        # tvf.Compose keeps a reference to the list it receives, so each
        # pipeline gets its own list. Sharing one list (as before) meant that
        # appending ColorJitter below also added augmentation to val/test.
        self.val_tfs = tvf.Compose(list(base_tfs))

        train_tfs = list(base_tfs)
        if cfg.augmentation.image.apply:
            # Forward only the photometric jitter parameters to ColorJitter.
            args = OmegaConf.masked_copy(
                cfg.augmentation.image, ["brightness", "contrast", "saturation", "hue"]
            )
            train_tfs.append(tvf.ColorJitter(**args))
        self.train_tfs = tvf.Compose(train_tfs)

        self.init()

    def init(self):
        """Build ``train_dataset``, ``val_dataset`` and ``test_dataset``.

        The training split uses ``train_tfs``. Validation and test both use
        ``val_tfs``.
        """
        # NOTE(review): the training split is built with training=False, which
        # preserves the original behavior. Confirm whether UavMapPair should
        # receive training=True for this split.
        self.train_dataset = self._build_split(self.cfg.train_citys, self.train_tfs)
        self.val_dataset = self._build_split(self.cfg.val_citys, self.val_tfs)
        self.test_dataset = self._build_split(self.cfg.test_citys, self.val_tfs)

    def _build_split(self, cities, transform, training=False):
        """Concatenate one ``UavMapPair`` per city, all sharing ``transform``."""
        root = Path(self.cfg.root)
        return ConcatDataset([
            UavMapPair(root=root, city=city, training=training, transform=transform)
            for city in cities
        ])

    def _make_loader(self, dataset, loader_cfg, shuffle):
        """Wrap ``dataset`` in a pinned-memory DataLoader sized by ``loader_cfg``."""
        return DataLoader(
            dataset,
            batch_size=loader_cfg.batch_size,
            num_workers=loader_cfg.num_workers,
            shuffle=shuffle,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Return the shuffled training loader (``cfg.train`` batch size/workers)."""
        return self._make_loader(self.train_dataset, self.cfg.train, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled validation loader (``cfg.val`` batch size/workers)."""
        # Evaluation gains nothing from random order. Lightning warns when a
        # val/test sampler shuffles, and a fixed order keeps runs comparable.
        return self._make_loader(self.val_dataset, self.cfg.val, shuffle=False)

    def test_dataloader(self):
        """Return the unshuffled test loader; it reuses ``cfg.val`` batch size/workers."""
        return self._make_loader(self.test_dataset, self.cfg.val, shuffle=False)