| | import os |
| | import torch |
| | from torch.utils.data import Dataset |
| |
|
class FlatTileDataset(Dataset):
    """Map-style dataset with one sample per regular file in a flat directory.

    Each file is read with ``torch.load`` and is expected to contain a mapping
    with ``'tile_data'`` and ``'file_data'`` entries.

    Args:
        data_dir: Directory whose immediate regular files form the dataset.
            Subdirectories are skipped; there is no recursion.
    """

    def __init__(self, data_dir):
        super().__init__()
        self.data_dir = data_dir
        # os.listdir yields entries in arbitrary, filesystem-dependent order.
        # Sorting makes the index -> file mapping reproducible across runs and
        # machines (fixed train/val splits, index-seeded samplers, resuming).
        candidates = (os.path.join(data_dir, name) for name in os.listdir(data_dir))
        self.files = sorted(path for path in candidates if os.path.isfile(path))

    def __len__(self):
        """Return the number of files found in ``data_dir``."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load the file at position ``idx`` and return ``(tile_data, file_data)``.

        Returns:
            tile_data: ``torch.from_numpy(data['tile_data'][0])``. Presumably
                ``data['tile_data'][0]`` is a numpy array; TODO confirm
                against the code that writes these files.
            file_data: ``data['file_data']``, returned unchanged.
        """
        file_path = self.files[idx]
        # NOTE(review): torch.load unpickles arbitrary objects. Only point this
        # dataset at trusted files. Recent torch releases default to
        # weights_only=True, which may refuse pickled numpy arrays; confirm
        # against the installed torch version.
        data = torch.load(file_path)
        tile_data = torch.from_numpy(data['tile_data'][0])
        file_data = data['file_data']
        return tile_data, file_data