# Loupe — src/data_module.py
# Source: Hugging Face repo page header (uploader: xxwyyds, "Upload 86 files",
# commit 891e05c, verified). Converted to comments so the file parses as Python.
from functools import partial
import os
import numpy as np
from omegaconf import DictConfig
import pytorch_lightning as pl
import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader
from datasets import load_dataset, concatenate_datasets
from models.loupe import LoupeImageProcessor, LoupeConfig
class DataModule(pl.LightningDataModule):
    """LightningDataModule for the Loupe image-forgery dataset.

    Loads parquet splits via ``datasets.load_dataset`` and exposes train /
    validation / test / predict dataloaders. Samples are dicts with an
    ``"image"`` key and a ``"mask"`` key where ``mask is None`` marks a real
    (unmanipulated) image; predict samples additionally carry a ``"path"``
    key used to derive an output file name.

    Args:
        cfg: Hydra/OmegaConf run config. Reads ``cfg.dataset.data_dir``,
            ``cfg.dataset.valid_size``, ``cfg.dataset.num_workers``,
            ``cfg.seed``, ``cfg.hparams.batch_size`` and ``cfg.stage.*``.
        model_config: Loupe model config; also parameterizes the image
            processor used by all collate functions.
    """

    def __init__(self, cfg: DictConfig, model_config: LoupeConfig) -> None:
        super().__init__()
        self.cfg = cfg
        self.model_config = model_config
        self.processor = LoupeImageProcessor(self.model_config)

    def _resolve_valid_size(self, total: int) -> int:
        """Turn ``cfg.dataset.valid_size`` into an absolute sample count.

        An ``int`` is taken as an absolute count (must satisfy
        ``0 < size < total``); a ``float`` is taken as a fraction of the
        validation split (must satisfy ``0 < size <= 1``). Explicit
        ``ValueError`` raises are used instead of ``assert`` so that the
        validation is not stripped under ``python -O``.

        Args:
            total: Number of samples in the raw validation split.

        Returns:
            The number of samples to keep for validation.

        Raises:
            ValueError: If ``valid_size`` is out of range or of a type other
                than ``int`` / ``float``.
        """
        valid_size = self.cfg.dataset.valid_size
        if isinstance(valid_size, int):
            if not 0 < valid_size < total:
                raise ValueError(
                    f"valid_size={valid_size} must be in (0, {total}) when given as an int."
                )
            return valid_size
        if isinstance(valid_size, float):
            if not 0 < valid_size <= 1:
                raise ValueError(
                    f"valid_size={valid_size} must be in (0, 1] when given as a float."
                )
            return int(valid_size * total)
        raise ValueError(
            f"Invalid valid_size: {valid_size}. It should be either int or float."
        )

    def setup(self, stage: str) -> None:
        """Load the parquet dataset and assign per-stage splits.

        For fit/validate, the raw validation split is shuffled and split
        into a small validation subset plus an "additional" train set; which
        train set is used depends on ``cfg.stage``. For test/predict, the
        corresponding raw split is used as-is.
        """
        dataset = load_dataset("parquet", data_dir=self.cfg.dataset.data_dir)
        if stage in [None, "validate", "fit"]:
            validset = dataset["validation"]
            valid_size = self._resolve_valid_size(len(validset))
            # use a small subset to prevent too long validation time
            additional_trainset, validset = validset.train_test_split(
                test_size=valid_size, seed=self.cfg.seed, shuffle=True
            ).values()
            self.validset = validset
            # In the cls_seg/test stages (unless train_on_trainset is set),
            # train on the leftover portion of the validation split rather
            # than the full train split.
            if self.cfg.stage.name in ["cls_seg", "test"] and not getattr(
                self.cfg.stage, "train_on_trainset", False
            ):
                self.trainset = additional_trainset
            else:
                self.trainset = dataset["train"]
        elif stage == "test":
            self.testset = dataset["validation"]
        elif stage == "predict":
            self.testset = dataset["test"]

    def train_collate_fn(self, batch):
        """Collate function for the train dataloader.

        Args:
            batch: List of dictionaries containing "image" and "mask" keys.

        Returns:
            Processor outputs plus a ``"labels"`` LongTensor of shape (N,)
            where 1 marks a manipulated image (mask present) and 0 a real one.
        """
        images = [x["image"] for x in batch]
        masks = [x["mask"] for x in batch]
        labels = [x is not None for x in masks]  # mask is None means it is real
        return {
            **self.processor(
                images,
                # When TTA is enabled, masks are withheld from the processor.
                masks if not getattr(self.cfg.stage, "enable_tta", False) else None,
                self.model_config.enable_patch_cls,
                return_tensors="pt",
            ),
            "labels": torch.tensor(labels, dtype=torch.long),  # (N,)
        }

    def train_dataloader(self):
        """Shuffled train dataloader using :meth:`train_collate_fn`."""
        return DataLoader(
            self.trainset,
            batch_size=self.cfg.hparams.batch_size,
            num_workers=self.cfg.dataset.num_workers,
            collate_fn=self.train_collate_fn,
            shuffle=True,
        )

    def val_dataloader(self):
        """Validation dataloader; reuses :meth:`test_collate_fn` (no shuffle)."""
        return DataLoader(
            self.validset,
            batch_size=self.cfg.hparams.batch_size,
            num_workers=self.cfg.dataset.num_workers,
            collate_fn=self.test_collate_fn,
            shuffle=False,
        )

    def test_collate_fn(self, batch):
        """
        Collate function for valid and test dataloaders.
        Args:
            batch: List of dictionaries containing "image" and "mask" keys.

        Returns:
            Processor outputs plus the per-sample ground-truth ``"masks"``
            (all-zero masks for real images) and binary ``"labels"`` (N,).
        """
        images = [x["image"] for x in batch]
        masks = [x["mask"] for x in batch]
        labels = [x is not None for x in masks]  # mask is None means it is real
        outputs = self.processor(images, masks, return_tensors="pt")
        for i, mask in enumerate(masks):
            if mask is None:
                # note that in PIL image, the size is (W, H)
                masks[i] = torch.zeros(
                    (images[i].size[1], images[i].size[0]),
                    dtype=torch.uint8,
                )
            else:
                # convert to binary mask with 0 and 1
                masks[i] = self.processor.convert_to_binary_masks(mask)
        return {
            **outputs,
            "masks": masks,  # a list of (N, H_i, W_i)
            "labels": (torch.tensor(labels, dtype=torch.long)),  # (N,)
        }

    def test_dataloader(self):
        """Test dataloader using :meth:`test_collate_fn`."""
        return DataLoader(
            self.testset,
            batch_size=self.cfg.hparams.batch_size,
            num_workers=self.cfg.dataset.num_workers,
            collate_fn=self.test_collate_fn,
        )

    def predict_collate_fn(self, batch):
        """
        Collate function for predict dataloader.
        Args:
            batch: List of dictionaries containing "image" and "mask" keys.

        Returns:
            Processor outputs plus ``"target_sizes"`` as (H, W) tuples and
            ``"name"`` with each sample's file basename.
        """
        images = [x["image"] for x in batch]
        outputs = self.processor(images, return_tensors="pt")
        return {
            **outputs,
            # PIL's .size is (W, H); reverse to the (H, W) convention.
            "target_sizes": [image.size[::-1] for image in images],
            "name": [os.path.basename(x["path"]) for x in batch],
        }

    def predict_dataloader(self):
        """Predict dataloader using :meth:`predict_collate_fn` (no shuffle)."""
        return DataLoader(
            self.testset,
            batch_size=self.cfg.hparams.batch_size,
            num_workers=self.cfg.dataset.num_workers,
            collate_fn=self.predict_collate_fn,
            shuffle=False,
        )