"""Datasets and DataLoader factories for the image and latent/embedding data.

Directory paths, batch sizes and the image resolution (``resized_img_dir``,
``latent_scaled_dir``, ``embedding_dir``, ``vae_batch_size``,
``unet_batch_size``, ``image_res``) are not defined in this file; they are
expected to be supplied by the ``config`` star import below.
"""
import os
import sys

# Make modules that live next to this file (e.g. ``config``) importable.
sys.path.insert(0, os.path.dirname(__file__))

import numpy as np
# NOTE: keep the star import *before* the explicit imports that follow, so
# that names such as ``torch``, ``T`` or ``random`` always refer to the real
# modules even if ``config`` happens to export something with the same name.
from config import *  # noqa: F401,F403
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import torchvision.transforms as T
import random
from typing import Sequence, Tuple


def _num_workers() -> int:
    """Return the DataLoader worker count: at most 8, never less than 1.

    ``os.cpu_count()`` may return ``None`` when the count cannot be
    determined; fall back to a single worker instead of letting
    ``min(8, None)`` raise ``TypeError``.
    """
    return min(8, os.cpu_count() or 1)


def _make_loader(dataset: Dataset, batch_size: int, shuffle: bool,
                 prefetch_factor: int) -> DataLoader:
    """Build a DataLoader with the worker settings shared by all loaders."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        worker_init_fn=seed_worker,
        num_workers=_num_workers(),
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=prefetch_factor,
    )


class ImageDataset(Dataset):
    """Dataset that loads image files from disk as RGB tensors.

    Args:
        file_names: Paths of the image files to serve, one item per path.
    """

    def __init__(self, file_names: Sequence[str]):
        self.file_names = file_names
        self.transform = T.Compose([
            T.ToTensor(),
        ])

    def __len__(self) -> int:
        return len(self.file_names)

    def __getitem__(self, index: int) -> torch.Tensor:
        """Return the image at ``index`` as a ``[3, image_res, image_res]`` tensor.

        Raises:
            AssertionError: If the loaded image does not have the expected
                shape (i.e. it was not resized to ``image_res`` beforehand).
        """
        file_path = self.file_names[index]
        # Use a context manager so the underlying file handle is released
        # deterministically; ``convert`` returns an independent, fully loaded
        # image that stays valid after the file is closed.
        with Image.open(file_path) as img:
            image_data = self.transform(img.convert("RGB"))
        assert image_data.shape == torch.Size([3, image_res, image_res]), f"Unexpected shape: {image_data.shape}"
        return image_data


def image_dataloader() -> Tuple[DataLoader, DataLoader]:
    """Create train/test DataLoaders over the ``.jpg`` files in ``resized_img_dir``.

    The dataset is split 80% / 20% with ``random_split``.

    Returns:
        ``(train_loader, test_loader)``; only the train loader shuffles.
    """
    # Sorted so the file order does not depend on the filesystem, matching
    # latent_embedding_dataloader().
    image_paths = [
        os.path.join(resized_img_dir, path)
        for path in sorted(os.listdir(resized_img_dir))
        if path.endswith(".jpg")
    ]
    dataset = ImageDataset(image_paths)

    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    # NOTE: no seeded generator is passed, so train/test membership differs
    # from run to run.
    train_set, test_set = random_split(dataset, [train_size, test_size])

    train_loader = _make_loader(train_set, vae_batch_size, shuffle=True, prefetch_factor=4)
    test_loader = _make_loader(test_set, vae_batch_size, shuffle=False, prefetch_factor=4)
    return train_loader, test_loader


class LatentEmbeddingsDataset(Dataset):
    """Dataset pairing each stored latent with each of its embedding variants.

    Every latent file ``<stem>.pt`` in ``latent_scaled_dir`` is paired with
    the files ``<stem>_<k>.pt`` in ``embedding_dir`` for
    ``k in range(num_variants)``, so the dataset length is
    ``len(file_names) * num_variants``.

    Args:
        file_names: Latent file names (not full paths), each ending in ``.pt``.
        num_variants: Number of embedding variants stored per latent.
            Defaults to 5.

    Raises:
        ValueError: If ``num_variants`` is smaller than 1.
    """

    def __init__(self, file_names: Sequence[str], num_variants: int = 5):
        if num_variants < 1:
            raise ValueError(f"num_variants must be >= 1, got {num_variants}")
        self.file_names = file_names
        self.num_variants = num_variants

    def __len__(self) -> int:
        return len(self.file_names) * self.num_variants

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(latent, embedding)`` for the flat ``index``.

        Consecutive indices walk through all variants of one latent before
        moving on to the next latent.

        Raises:
            AssertionError: If the embedding is not of shape ``[77, 1024]``.
        """
        latent_index, variant_index = divmod(index, self.num_variants)
        file_path = self.file_names[latent_index]
        stem = os.path.splitext(file_path)[0]  # drop the ".pt" suffix

        latent_data = torch.load(
            os.path.join(latent_scaled_dir, file_path),
            map_location="cpu",
            weights_only=True,
        )
        embedding_data = torch.load(
            os.path.join(embedding_dir, f"{stem}_{variant_index}.pt"),
            map_location="cpu",
            weights_only=True,
        )
        assert embedding_data.shape == torch.Size([77, 1024]), f"Unexpected embedding shape: {embedding_data.shape}"
        return (latent_data, embedding_data)


def latent_embedding_dataloader() -> DataLoader:
    """Create a shuffling DataLoader over all ``.pt`` latents in ``latent_scaled_dir``.

    Returns:
        A DataLoader yielding batches of ``(latent, embedding)`` pairs.
    """
    file_names = sorted(
        path for path in os.listdir(latent_scaled_dir) if path.endswith(".pt")
    )
    dataset = LatentEmbeddingsDataset(file_names)
    return _make_loader(dataset, unet_batch_size, shuffle=True, prefetch_factor=32)


def seed_worker(worker_id: int) -> None:
    """Seed NumPy and ``random`` inside a DataLoader worker process.

    Intended for use as ``worker_init_fn``. The seed is derived from
    ``torch.initial_seed()`` and reduced modulo 2**32 because
    ``np.random.seed`` only accepts 32-bit seeds.

    Args:
        worker_id: Index of the worker (supplied by the DataLoader; unused).
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)