File size: 940 Bytes
8381e8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import os
import torch
from torch.utils.data import Dataset

class FlatTileDataset(Dataset):
    def __init__(self, data_dir):
        super().__init__()
        self.data_dir = data_dir
        # List all files in the data_dir that are files (not directories)
        self.files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if os.path.isfile(os.path.join(data_dir, f))]

    def __len__(self):
        # Return the total number of files
        return len(self.files)

    def __getitem__(self, idx):
        # Get the file path for the given index
        file_path = self.files[idx]
        # Load the data from the file
        data = torch.load(file_path)
        # Assuming the data file is a dictionary with 'tile_data' and 'file_data' keys
        tile_data = torch.from_numpy(data['tile_data'][0])
        file_data = data['file_data']
        # Return the tile data and file data
        return tile_data, file_data