import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import VOCSegmentation
from torchvision.transforms import transforms


def _config_option(config, name, default):
    """Return ``config.<name>``, or *default* when it is missing or ``None``."""
    value = getattr(config, name, None)
    return default if value is None else value


def _mask_to_long(mask):
    """Cast a mask tensor to ``torch.long``.

    Defined at module level (rather than as a lambda) so the transform
    pipeline stays picklable when DataLoader workers are spawned.
    """
    return mask.long()


class SegmentationDataModule(pl.LightningDataModule):
    """Lightning data module serving Pascal VOC 2012 segmentation pairs.

    Images and masks are resized to 128x128. The ``trainval`` image set is
    split 80/20 into training and validation subsets.

    Args:
        data_dir: Root directory passed to ``VOCSegmentation``.
        config: Object exposing ``batch_size``, ``shuffle`` and
            ``num_workers``. Two optional attributes are also honoured:
            ``download`` (default ``False``) and ``split_seed`` (default 42).
    """

    def __init__(self, data_dir: str, config):
        super().__init__()
        self.config = config
        self.data_dir = data_dir

        # Image pipeline: resize, then convert to a float tensor in [0, 1].
        image_transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
        ])
        # Mask pipeline: the pixel values are class indices, so they must not
        # be interpolated or rescaled. ToTensor() would divide them by 255 and
        # the subsequent long() cast would truncate nearly every label to 0;
        # PILToTensor() keeps the raw integer values instead.
        mask_transform = transforms.Compose([
            transforms.Resize(
                (128, 128),
                interpolation=transforms.InterpolationMode.NEAREST,
            ),
            transforms.PILToTensor(),
            _mask_to_long,
        ])
        self.transform = DualTransform(image_transform, mask_transform)

    def prepare_data(self):
        """Check that the dataset is on disk, downloading it only if enabled.

        Downloading stays off unless ``config.download`` is truthy, so the
        default behaviour is to raise if the data is not already present.
        """
        VOCSegmentation(
            root=self.data_dir,
            year='2012',
            image_set='trainval',
            download=bool(_config_option(self.config, "download", False)),
        )

    def setup(self, stage=None):
        """Build the dataset and split it into train/validation subsets.

        Args:
            stage: Stage name supplied by the trainer; unused here.
        """
        self.dataset = VOCSegmentation(
            root=self.data_dir,
            year='2012',
            image_set='trainval',
            transforms=self.transform,
        )

        train_len = int(0.8 * len(self.dataset))
        val_len = len(self.dataset) - train_len

        # setup() can run more than once (per stage / per process). A seeded
        # generator makes every call produce the same partition, so validation
        # samples never end up in another call's training subset.
        seed = int(_config_option(self.config, "split_seed", 42))
        split_generator = torch.Generator().manual_seed(seed)
        self.train_dataset, self.val_dataset = random_split(
            self.dataset,
            [train_len, val_len],
            generator=split_generator,
        )

    def train_dataloader(self):
        """Return the DataLoader over the training subset."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.config.batch_size,
            shuffle=self.config.shuffle,
            num_workers=self.config.num_workers,
        )

    def val_dataloader(self):
        """Return the DataLoader over the validation subset (never shuffled)."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.config.batch_size,
            shuffle=False,
            num_workers=self.config.num_workers,
        )

    # A test_dataloader can be added here if needed.


class DualTransform:
    """Apply separate transforms to an image and its segmentation mask.

    Args:
        image_transform: Callable applied to the image.
        mask_transform: Callable applied to the mask.
    """

    def __init__(self, image_transform, mask_transform):
        self.image_transform = image_transform
        self.mask_transform = mask_transform

    def __call__(self, image, mask):
        """Return ``(image_transform(image), mask_transform(mask))``."""
        return self.image_transform(image), self.mask_transform(mask)