Spaces:
Running
Running
| from functools import partial | |
| import os | |
| import numpy as np | |
| from omegaconf import DictConfig | |
| import pytorch_lightning as pl | |
| import torch | |
| import torchvision.transforms as T | |
| from torch.utils.data import DataLoader | |
| from datasets import load_dataset, concatenate_datasets | |
| from models.loupe import LoupeImageProcessor, LoupeConfig | |
class DataModule(pl.LightningDataModule):
    """Lightning data module feeding a parquet image/mask dataset to Loupe.

    The dataset is loaded from ``cfg.dataset.data_dir`` and is expected to
    expose ``train``, ``validation`` and ``test`` splits. Samples are
    dictionaries with an ``"image"`` entry, a ``"mask"`` entry (``None`` marks
    a real image) and, for prediction, a ``"path"`` entry.

    Args:
        cfg: Run configuration. This class reads ``cfg.dataset.data_dir``,
            ``cfg.dataset.valid_size``, ``cfg.dataset.num_workers``,
            ``cfg.hparams.batch_size``, ``cfg.seed`` and ``cfg.stage``.
        model_config: Model configuration used to build the image processor.
    """

    def __init__(self, cfg: DictConfig, model_config: LoupeConfig) -> None:
        super().__init__()
        self.cfg = cfg
        self.model_config = model_config
        self.processor = LoupeImageProcessor(self.model_config)

    @staticmethod
    def _resolve_valid_size(valid_size, num_samples: int) -> int:
        """Turn the configured ``valid_size`` into an absolute sample count.

        Args:
            valid_size: Either an absolute number of samples (``int``) or a
                fraction of the validation split (``float`` in ``(0, 1]``).
            num_samples: Number of samples in the full validation split.

        Returns:
            Number of samples to keep for validation, strictly between 0 and
            ``num_samples`` so that both sides of the split are non-empty.

        Raises:
            ValueError: If ``valid_size`` has an unsupported type, is out of
                range, or resolves to an empty / all-consuming split.
        """
        if isinstance(valid_size, int):
            resolved = valid_size
        elif isinstance(valid_size, float):
            if not 0 < valid_size <= 1:
                raise ValueError(
                    f"Invalid valid_size: {valid_size}. "
                    "A float valid_size must be in the range (0, 1]."
                )
            resolved = int(valid_size * num_samples)
        else:
            raise ValueError(
                f"Invalid valid_size: {valid_size}. It should be either int or float."
            )

        # Explicit check instead of `assert`: assertions vanish under
        # `python -O`, and a size of 0 or `num_samples` cannot be split into
        # two non-empty parts.
        if not 0 < resolved < num_samples:
            raise ValueError(
                f"Invalid valid_size: {valid_size}. It resolves to {resolved} "
                f"samples, but must select more than 0 and fewer than "
                f"{num_samples} samples of the validation split."
            )
        return resolved

    def setup(self, stage: str) -> None:
        """Load the dataset and prepare the splits required by ``stage``.

        Args:
            stage: ``"fit"``, ``"validate"`` or ``None`` prepare ``trainset``
                and ``validset``; ``"test"`` and ``"predict"`` prepare
                ``testset``. Any other value only loads the dataset.

        Raises:
            ValueError: If ``cfg.dataset.valid_size`` is invalid.
        """
        dataset = load_dataset("parquet", data_dir=self.cfg.dataset.data_dir)

        if stage in (None, "validate", "fit"):
            full_validset = dataset["validation"]
            valid_size = self._resolve_valid_size(
                self.cfg.dataset.valid_size, len(full_validset)
            )

            # use a small subset to prevent too long validation time
            split = full_validset.train_test_split(
                test_size=valid_size, seed=self.cfg.seed, shuffle=True
            )
            additional_trainset = split["train"]
            self.validset = split["test"]

            # For the "cls_seg" / "test" stage configs, train on the part of
            # the validation split that was not held out, unless the config
            # explicitly asks for the real train split.
            train_on_heldout = self.cfg.stage.name in [
                "cls_seg",
                "test",
            ] and not getattr(self.cfg.stage, "train_on_trainset", False)
            self.trainset = additional_trainset if train_on_heldout else dataset["train"]
        elif stage == "test":
            # testing evaluates on the complete validation split
            self.testset = dataset["validation"]
        elif stage == "predict":
            self.testset = dataset["test"]

    def train_collate_fn(self, batch):
        """Collate function for the train dataloader.

        Args:
            batch: List of dictionaries containing "image" and "mask" keys.

        Returns:
            The processor outputs plus a ``"labels"`` long tensor of shape
            (N,), where 1 marks a sample that has a mask.
        """
        images = [sample["image"] for sample in batch]
        masks = [sample["mask"] for sample in batch]
        labels = [mask is not None for mask in masks]  # mask is None means it is real

        # masks are withheld from the processor when TTA is enabled
        use_masks = not getattr(self.cfg.stage, "enable_tta", False)
        processed = self.processor(
            images,
            masks if use_masks else None,
            self.model_config.enable_patch_cls,
            return_tensors="pt",
        )
        return {
            **processed,
            "labels": torch.tensor(labels, dtype=torch.long),  # (N,)
        }

    def train_dataloader(self):
        """Return the shuffled dataloader over ``self.trainset``."""
        return DataLoader(
            self.trainset,
            batch_size=self.cfg.hparams.batch_size,
            num_workers=self.cfg.dataset.num_workers,
            collate_fn=self.train_collate_fn,
            shuffle=True,
        )

    def val_dataloader(self):
        """Return the unshuffled dataloader over ``self.validset``."""
        return DataLoader(
            self.validset,
            batch_size=self.cfg.hparams.batch_size,
            num_workers=self.cfg.dataset.num_workers,
            collate_fn=self.test_collate_fn,
            shuffle=False,
        )

    def test_collate_fn(self, batch):
        """Collate function for valid and test dataloaders.

        Args:
            batch: List of dictionaries containing "image" and "mask" keys.

        Returns:
            The processor outputs plus ``"masks"``, a list with one mask per
            sample, and ``"labels"``, a long tensor of shape (N,) where 1
            marks a sample that has a mask.
        """
        images = [sample["image"] for sample in batch]
        masks = [sample["mask"] for sample in batch]
        labels = [mask is not None for mask in masks]  # mask is None means it is real

        outputs = self.processor(images, masks, return_tensors="pt")

        target_masks = []
        for image, mask in zip(images, masks):
            if mask is None:
                # Real image: all-zero mask at the image's own resolution.
                # note that in PIL image, the size is (W, H)
                width, height = image.size[0], image.size[1]
                target_masks.append(torch.zeros((height, width), dtype=torch.uint8))
            else:
                # convert to binary mask with 0 and 1
                target_masks.append(self.processor.convert_to_binary_masks(mask))

        return {
            **outputs,
            # list of N masks; the all-zero masks built above are (H_i, W_i),
            # converted masks presumably match - verify against the processor
            "masks": target_masks,
            "labels": torch.tensor(labels, dtype=torch.long),  # (N,)
        }

    def test_dataloader(self):
        """Return the unshuffled dataloader over ``self.testset``."""
        return DataLoader(
            self.testset,
            batch_size=self.cfg.hparams.batch_size,
            num_workers=self.cfg.dataset.num_workers,
            collate_fn=self.test_collate_fn,
            shuffle=False,
        )

    def predict_collate_fn(self, batch):
        """Collate function for predict dataloader.

        Args:
            batch: List of dictionaries containing "image" and "path" keys.

        Returns:
            The processor outputs plus ``"target_sizes"``, the reversed
            ``image.size`` of every image (i.e. (H, W) for PIL images), and
            ``"name"``, the base file name taken from each sample's path.
        """
        images = [sample["image"] for sample in batch]
        outputs = self.processor(images, return_tensors="pt")
        return {
            **outputs,
            "target_sizes": [image.size[::-1] for image in images],
            "name": [os.path.basename(sample["path"]) for sample in batch],
        }

    def predict_dataloader(self):
        """Return the unshuffled dataloader over ``self.testset``."""
        return DataLoader(
            self.testset,
            batch_size=self.cfg.hparams.batch_size,
            num_workers=self.cfg.dataset.num_workers,
            collate_fn=self.predict_collate_fn,
            shuffle=False,
        )