| | import numpy as np |
| | import torch |
| | from torch.utils.data.dataset import Dataset |
| | import os |
| | from configs.paths import dataset_root |
| | import copy |
| | from .base import BASE |
| |
|
| | |
class COMMON(BASE):
    """Inference-only dataset over the image files found in a single folder.

    The folder is listed once, at construction time (non-recursively), and
    every entry whose name ends in ``.png``, ``.jpg`` or ``.jpeg`` (compared
    case-insensitively) is kept, sorted by file name.
    """

    def __init__(self, img_folder, **kwargs):
        """Scan ``img_folder`` for images and check the dataset mode.

        Args:
            img_folder: Directory whose entries are scanned for image files.
            **kwargs: Forwarded unchanged to ``BASE.__init__``.

        Raises:
            AssertionError: If ``self.mode`` is not ``'infer'``.
        """
        super(COMMON, self).__init__(**kwargs)
        self.dataset_path = img_folder

        # Match on the lower-cased name so that files such as ``photo.JPG``
        # or ``scan.PNG`` are not silently skipped.
        img_exts = ('.png', '.jpg', '.jpeg')
        self.img_names = sorted(
            img_name
            for img_name in os.listdir(self.dataset_path)
            if img_name.lower().endswith(img_exts)
        )

        # NOTE(review): ``self.mode`` is presumably set by ``BASE.__init__``
        # from ``kwargs`` -- it is not assigned anywhere in this class.
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped under ``python -O``; the exception type is kept as
        # AssertionError so existing callers observe the same error.
        if self.mode != 'infer':
            raise AssertionError(
                "COMMON only supports mode='infer', got %r" % (self.mode,)
            )

    def __len__(self):
        """Return the number of image files found in the folder."""
        return len(self.img_names)

    def get_raw_data(self, idx):
        """Return the path and name of the image selected by ``idx``.

        ``idx`` is wrapped modulo the number of images, so indices past the
        end cycle back to the beginning of the sorted file list.

        Args:
            idx: Integer sample index.

        Returns:
            A dict with keys ``'img_path'`` (folder joined with the file
            name), ``'img_name'`` (the bare file name) and ``'ds'`` (always
            the string ``'common'``).

        Raises:
            ZeroDivisionError: If no image files were found in the folder.
        """
        img_id = idx % len(self.img_names)
        img_name = self.img_names[img_id]
        img_path = os.path.join(self.dataset_path, img_name)
        raw_data = {
            'img_path': img_path,
            'img_name': img_name,
            'ds': 'common',
        }

        return raw_data
| |
|
| |
|
| |
|