| import pytorch_lightning as pl | |
| from torch.utils.data import DataLoader | |
| from torchvision.transforms import v2 | |
| from .data import ImageDataset, TransformDino | |
class InferenceDataModel(pl.LightningDataModule):
    """Data module that yields transformed batches for prediction.

    ``setup`` wraps an ``ImageDataset`` in a non-shuffling ``DataLoader``;
    ``predict_dataloader`` then iterates that loader and passes every batch
    through a ``TransformDino`` pipeline before yielding it.

    Args:
        metadata_path: Forwarded as the first argument to ``ImageDataset``.
        images_root_path: Forwarded as the second argument to
            ``ImageDataset``.
        batch_size: Batch size handed to the ``DataLoader``.
        num_workers: Number of ``DataLoader`` worker processes. Defaults to
            4, the value that used to be hard-coded; pass 0 to load in the
            main process.
        model_name: Identifier handed to ``TransformDino``. Defaults to
            ``"facebook/dinov2-base"``, the value that used to be hard-coded.
    """

    def __init__(
        self,
        metadata_path,
        images_root_path,
        batch_size=32,
        *,
        num_workers=4,
        model_name="facebook/dinov2-base",
    ):
        super().__init__()
        self.metadata_path = metadata_path
        self.images_root_path = images_root_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.model_name = model_name

    def setup(self, stage=None):
        """Build the ``DataLoader`` used for prediction.

        Args:
            stage: Accepted for interface compatibility; not used here.
        """
        dataset = ImageDataset(self.metadata_path, self.images_root_path)
        # shuffle=False: batches are produced in dataset order.
        self.dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

    def predict_dataloader(self):
        """Yield each batch of ``self.dataloader`` after transformation.

        This is a generator: nothing (including building the transform)
        runs until the first item is requested, and each call gives a
        single pass over the data. ``setup`` must have run first, since it
        is what creates ``self.dataloader``.

        Yields:
            The result of applying the ``TransformDino`` pipeline to one
            batch from the underlying ``DataLoader``.
        """
        # NOTE(review): this returns a generator rather than a DataLoader,
        # and the transform is applied to whole collated batches -- confirm
        # both are what the Lightning version in use and TransformDino expect.
        transform = v2.Compose([TransformDino(self.model_name)])
        for batch in self.dataloader:
            yield transform(batch)