|
|
import pytorch_lightning as pl |
|
|
from torch.utils.data import DataLoader |
|
|
from dataset import MyDataset, load_filenames |
|
|
|
|
|
class DataModule(pl.LightningDataModule):
    """LightningDataModule that provides a shuffled training DataLoader.

    Args:
        img_dir: Passed to ``load_filenames`` to list the files and to
            ``MyDataset`` as ``img_dir`` (presumably the image directory).
        batch_size: Batch size of the training ``DataLoader``.
        img_size: Forwarded to ``MyDataset`` as ``img_size``.
        num_workers: Number of ``DataLoader`` worker processes; ``0`` loads
            data in the main process.
        file_num: Maximum number of filenames, taken from the start of what
            ``load_filenames`` returns, used for training. ``None`` uses all
            of them. Defaults to 1000 (the previously hard-coded value).
    """

    def __init__(self, img_dir, batch_size, img_size=112, num_workers=0,
                 file_num=1000):
        super().__init__()
        self.img_dir = img_dir
        self.batch_size = batch_size
        self.img_size = img_size
        self.num_workers = num_workers
        self.file_num = file_num

    def setup(self, stage=None):
        """Build ``self.train_dataset`` from the first ``file_num`` filenames.

        ``stage`` is accepted to match the Lightning hook signature but is
        not used: the same training dataset is built for every stage.
        """
        filenames = load_filenames(self.img_dir)
        # A slice with ``file_num=None`` keeps every filename.
        self.train_dataset = MyDataset(
            filenames[:self.file_num],
            img_dir=self.img_dir,
            img_size=self.img_size,
        )

    def train_dataloader(self):
        """Return a shuffling ``DataLoader`` over ``self.train_dataset``.

        ``setup`` must have been called first so that ``train_dataset`` exists.
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            # DataLoader raises ValueError if persistent_workers=True while
            # num_workers == 0 (the default here), so only request persistent
            # workers when worker processes are actually used.
            persistent_workers=self.num_workers > 0,
        )
|
|
|