# Source: hf_segm_test/data/datamodule.py (ryali93, commit d218927: "first model version")
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import VOCSegmentation
from torchvision.transforms import transforms
class SegmentationDataModule(pl.LightningDataModule):
    """Lightning data module serving Pascal VOC 2012 segmentation pairs.

    Images and masks are resized to 128x128 and the ``trainval`` image set is
    divided 80/20 into training and validation subsets.

    Args:
        data_dir: Root directory handed to ``VOCSegmentation``.
        config: Object exposing ``batch_size``, ``shuffle`` and
            ``num_workers`` attributes; they are read when the dataloaders
            are built.
    """

    # Fixed seed for the train/validation split, so that every call to
    # ``setup`` (and every process that calls it) produces the same partition.
    _SPLIT_SEED = 42

    def __init__(self, data_dir: str, config):
        super().__init__()
        self.config = config
        self.data_dir = data_dir

        # Transform for the image: resize, then scale to a float tensor.
        image_transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
        ])
        # Transform for the mask. Pixel values are class indices, so:
        #  * resize with NEAREST so no interpolated (non-existent) labels appear;
        #  * use PILToTensor instead of ToTensor, because ToTensor rescales
        #    uint8 data to [0, 1] and the integer cast would then collapse
        #    almost every label to 0.
        # More transformations can be appended here if needed.
        mask_transform = transforms.Compose([
            transforms.Resize(
                (128, 128),
                interpolation=transforms.InterpolationMode.NEAREST,
            ),
            transforms.PILToTensor(),
            SegmentationDataModule._to_long,
        ])
        self.transform = DualTransform(image_transform, mask_transform)

    @staticmethod
    def _to_long(mask):
        """Cast a mask tensor to int64.

        A named function is used rather than a lambda so the transform
        pipeline can be pickled when dataloader workers are spawned.
        """
        return mask.long()

    def prepare_data(self):
        """Instantiate the VOC 2012 ``trainval`` dataset from ``data_dir``.

        ``download=False`` is passed, so nothing is fetched here; the data is
        expected to be on disk already.
        NOTE(review): presumably the constructor fails when the files are
        missing, which makes this an early existence check - confirm against
        the torchvision documentation.
        """
        VOCSegmentation(root=self.data_dir, year='2012', image_set='trainval', download=False)

    def setup(self, stage=None):
        """Create the dataset and its 80/20 train/validation split.

        Args:
            stage: Accepted for Lightning compatibility; not used.
        """
        # Initialise the dataset with the paired image/mask transform.
        self.dataset = VOCSegmentation(
            root=self.data_dir,
            year='2012',
            image_set='trainval',
            transforms=self.transform,
        )
        # Split the dataset into training and validation subsets.
        train_len = int(0.8 * len(self.dataset))
        val_len = len(self.dataset) - train_len
        # A dedicated seeded generator keeps the partition identical across
        # repeated ``setup`` calls, so validation samples never drift into
        # the training subset.
        split_rng = torch.Generator().manual_seed(self._SPLIT_SEED)
        self.train_dataset, self.val_dataset = random_split(
            self.dataset, [train_len, val_len], generator=split_rng
        )

    def train_dataloader(self):
        """Return the training ``DataLoader`` configured from ``self.config``."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.config.batch_size,
            shuffle=self.config.shuffle,
            num_workers=self.config.num_workers,
        )

    def val_dataloader(self):
        """Return the validation ``DataLoader``; it is never shuffled."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.config.batch_size,
            shuffle=False,
            num_workers=self.config.num_workers,
        )

    # A test_dataloader can be added here if needed.
class DualTransform:
    """Apply one callable to an image and another to its mask.

    Args:
        image_transform: Callable invoked with the image.
        mask_transform: Callable invoked with the mask.
    """

    def __init__(self, image_transform, mask_transform):
        self.image_transform, self.mask_transform = image_transform, mask_transform

    def __call__(self, image, mask):
        """Return ``(transformed_image, transformed_mask)``.

        The image is transformed first, then the mask.
        """
        transformed_image = self.image_transform(image)
        transformed_mask = self.mask_transform(mask)
        return transformed_image, transformed_mask