Spaces:
Sleeping
Sleeping
| import sys, os | |
| sys.path.insert(0, os.path.dirname(__file__)) | |
| import numpy as np | |
| from config import * | |
| import torch | |
| from torch.utils.data import Dataset, DataLoader, random_split | |
| from PIL import Image | |
| import torchvision.transforms as T | |
| import random | |
class ImageDataset(Dataset):
    """Dataset that loads image files from disk as RGB tensors.

    Each item is the file at ``file_names[index]`` opened with PIL, converted
    to RGB and passed through ``T.ToTensor()``.
    """

    def __init__(self, file_names):
        """Store the image paths and build the tensor transform.

        Args:
            file_names: Sequence of image file paths; indexed by
                ``__getitem__`` and sized by ``__len__``.
        """
        self.file_names = file_names
        self.transform = T.Compose([T.ToTensor()])

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.file_names)

    def __getitem__(self, index):
        """Load the image at ``index`` and return it as a tensor.

        Args:
            index: Position in ``file_names``.

        Returns:
            The transformed image tensor.

        Raises:
            AssertionError: If the tensor shape is not
                ``(3, image_res, image_res)``.
        """
        file_path = self.file_names[index]
        # Image.open is lazy and holds the file handle; the context manager
        # releases it deterministically once the RGB copy has been made.
        with Image.open(file_path) as image:
            image_data = self.transform(image.convert("RGB"))
        # NOTE(review): image_res is not defined in this file; presumably it
        # is provided by `from config import *` -- confirm.
        assert image_data.shape == torch.Size([3, image_res, image_res]), f"Unexpected shape: {image_data.shape}"
        return image_data
def image_dataloader():
    """Build train and test DataLoaders over the ``.jpg`` files in ``resized_img_dir``.

    The files are wrapped in an ``ImageDataset`` and split 80/20 with
    ``random_split``. Both loaders use ``vae_batch_size`` and the same worker
    settings; only the training loader shuffles.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    # Sort the listing: os.listdir returns entries in arbitrary order, and a
    # stable order is a prerequisite for a reproducible split.
    image_paths = [
        os.path.join(resized_img_dir, name)
        for name in sorted(os.listdir(resized_img_dir))
        if name.endswith(".jpg")
    ]
    dataset = ImageDataset(image_paths)

    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    # NOTE: the split is unseeded, so it differs between runs. Pass
    # generator=torch.Generator().manual_seed(<seed>) to random_split if a
    # reproducible split is required.
    train_set, test_set = random_split(dataset, [train_size, test_size])

    loader_kwargs = dict(
        batch_size=vae_batch_size,
        worker_init_fn=seed_worker,
        # os.cpu_count() may return None when the count cannot be determined;
        # fall back to a single worker instead of failing inside min().
        num_workers=min(8, os.cpu_count() or 1),
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=4,
    )
    train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)
    return train_loader, test_loader
class LatentEmbeddingsDataset(Dataset):
    """Dataset pairing each stored latent with each of its embedding variants.

    Every latent file contributes ``num_variants`` items. Item ``index`` maps
    to latent ``index // num_variants`` and variant ``index % num_variants``.
    """

    def __init__(self, file_names, num_variants=5):
        """Store the latent file names and the number of variants per latent.

        Args:
            file_names: Sequence of latent file names relative to
                ``latent_scaled_dir``.
            num_variants: Number of embedding files per latent; embeddings
                are looked up as ``<stem>_<variant>.pt`` for variants
                ``0 .. num_variants - 1``.

        Raises:
            ValueError: If ``num_variants`` is less than 1.
        """
        if num_variants < 1:
            raise ValueError(f"num_variants must be >= 1, got {num_variants}")
        self.file_names = file_names
        self.num_variants = num_variants

    def __len__(self):
        """Return the number of (latent, embedding) pairs."""
        return len(self.file_names) * self.num_variants

    def __getitem__(self, index):
        """Load the latent and embedding tensors for ``index``.

        Args:
            index: Flat index over all (latent, variant) pairs.

        Returns:
            A ``(latent_data, embedding_data)`` tuple, both loaded on CPU.

        Raises:
            AssertionError: If the embedding shape is not ``(77, 1024)``.
        """
        latent_index, variant_index = divmod(index, self.num_variants)
        file_path = self.file_names[latent_index]
        # Strip the extension properly rather than assuming it is exactly
        # three characters long.
        stem = os.path.splitext(file_path)[0]
        latent_data = torch.load(
            f"{latent_scaled_dir}/{file_path}",
            map_location="cpu",
            weights_only=True,
        )
        embedding_data = torch.load(
            f"{embedding_dir}/{stem}_{variant_index}.pt",
            map_location="cpu",
            weights_only=True,
        )
        assert embedding_data.shape == torch.Size([77, 1024]), f"Unexpected embedding shape: {embedding_data.shape}"
        return (latent_data, embedding_data)
def latent_embedding_dataloader():
    """Build the shuffled training DataLoader over latent/embedding pairs.

    Collects the ``.pt`` files in ``latent_scaled_dir`` in sorted order, wraps
    them in a ``LatentEmbeddingsDataset`` and batches them with
    ``unet_batch_size``.

    Returns:
        The training ``DataLoader``.
    """
    file_names = sorted(
        name for name in os.listdir(latent_scaled_dir) if name.endswith(".pt")
    )
    dataset = LatentEmbeddingsDataset(file_names)
    return DataLoader(
        dataset,
        batch_size=unet_batch_size,
        shuffle=True,
        worker_init_fn=seed_worker,
        # os.cpu_count() may return None when the count cannot be determined;
        # fall back to a single worker instead of failing inside min().
        num_workers=min(8, os.cpu_count() or 1),
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=32,
    )
def seed_worker(worker_id):
    """Seed NumPy and ``random`` from torch's initial seed.

    Intended for use as a DataLoader ``worker_init_fn``. The seed is torch's
    ``initial_seed()`` reduced modulo ``2**32``, since ``np.random.seed`` only
    accepts 32-bit unsigned values.

    Args:
        worker_id: Worker index supplied by the DataLoader; not used here.
    """
    seed = torch.initial_seed() % (1 << 32)
    for apply_seed in (np.random.seed, random.seed):
        apply_seed(seed)