File size: 3,143 Bytes
4aabce3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a625e96
 
4aabce3
 
 
a625e96
 
4aabce3
 
 
 
 
 
 
 
 
 
 
a625e96
4aabce3
 
 
 
 
 
a625e96
 
4aabce3
a625e96
4aabce3
 
 
a625e96
 
4aabce3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import sys, os
sys.path.insert(0, os.path.dirname(__file__))
import numpy as np
from config import *
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import torchvision.transforms as T
import random

class ImageDataset(Dataset):
    """Dataset that reads image files from disk and yields them as RGB tensors."""

    def __init__(self, file_names):
        """Keep the list of image paths and build the tensor conversion.

        Args:
            file_names: Indexable collection of image file paths; item ``i``
                of the dataset is loaded from ``file_names[i]``.
        """
        self.file_names = file_names
        # Only a PIL -> tensor conversion is applied; no resizing or augmentation.
        self.transform = T.Compose([T.ToTensor()])

    def __len__(self):
        """Return how many image paths the dataset holds."""
        return len(self.file_names)

    def __getitem__(self, index):
        """Load the image at ``index`` and return it as a tensor.

        The file is opened with PIL, converted to RGB and passed through
        ``self.transform``. The result must have shape
        ``(3, image_res, image_res)``; anything else trips the assertion.
        """
        rgb_image = Image.open(self.file_names[index]).convert("RGB")
        image_data = self.transform(rgb_image)
        assert image_data.shape == torch.Size([3, image_res, image_res]), f"Unexpected shape: {image_data.shape}"
        return image_data
def image_dataloader():
    """Build train/test ``DataLoader``s over the ``.jpg`` files in ``resized_img_dir``.

    The files are wrapped in an ``ImageDataset`` and split 80/20 with
    ``random_split``. Both loaders use ``vae_batch_size``; only the training
    loader shuffles.

    Returns:
        tuple: ``(train_loader, test_loader)``.
    """
    # os.listdir returns entries in arbitrary order; sorting gives the split a
    # stable starting order regardless of filesystem.
    image_paths = [
        os.path.join(resized_img_dir, name)
        for name in sorted(os.listdir(resized_img_dir))
        if name.endswith(".jpg")
    ]
    dataset = ImageDataset(image_paths)

    # NOTE: no explicit generator is passed to random_split, so the split
    # depends on the state of the global torch RNG at call time.
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_set, test_set = random_split(dataset, [train_size, test_size])

    # os.cpu_count() may return None when the count cannot be determined;
    # fall back to a single worker instead of failing inside min().
    num_workers = min(8, os.cpu_count() or 1)

    # Settings shared by both loaders; only `shuffle` differs between them.
    loader_kwargs = dict(
        batch_size=vae_batch_size,
        worker_init_fn=seed_worker,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=4,
    )
    train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)
    return train_loader, test_loader

class LatentEmbeddingsDataset(Dataset):
    """Dataset pairing each stored latent with one of its embedding variants.

    Every latent file contributes ``num_variants`` items. Item ``index`` maps
    to latent ``index // num_variants`` and embedding variant
    ``index % num_variants``.
    """

    def __init__(self, file_names, num_variants=5):
        """Store the latent file names and the number of variants per latent.

        Args:
            file_names: Indexable collection of latent file names, relative to
                ``latent_scaled_dir``. Names are expected to end in ``.pt``.
            num_variants: How many embedding files exist per latent
                (``<stem>_0.pt`` .. ``<stem>_{num_variants - 1}.pt``).
                Defaults to 5.

        Raises:
            ValueError: If ``num_variants`` is less than 1.
        """
        if num_variants < 1:
            raise ValueError(f"num_variants must be >= 1, got {num_variants}")
        self.file_names = file_names
        self.num_variants = num_variants

    def __len__(self):
        """Return the total number of (latent, embedding) pairs."""
        return len(self.file_names) * self.num_variants

    def __getitem__(self, index):
        """Load and return the ``(latent, embedding)`` pair for ``index``.

        Both tensors are loaded onto the CPU with ``weights_only=True``. The
        embedding must have shape ``(77, 1024)``; anything else trips the
        assertion.
        """
        latent_index, variant_index = divmod(index, self.num_variants)
        file_path = self.file_names[latent_index]
        latent_data = torch.load(f"{latent_scaled_dir}/{file_path}", map_location="cpu", weights_only=True)
        # file_path[:-3] drops the last three characters (the ".pt" suffix) so
        # the variant index can be appended to the stem.
        embedding_data = torch.load(f"{embedding_dir}/{file_path[:-3]}_{variant_index}.pt", map_location="cpu", weights_only=True)
        assert embedding_data.shape == torch.Size([77, 1024]), f"Unexpected embedding shape: {embedding_data.shape}"
        return (latent_data, embedding_data)
def latent_embedding_dataloader():
    """Build the shuffled training ``DataLoader`` over latent/embedding pairs.

    Every ``.pt`` file in ``latent_scaled_dir`` (in sorted order) is wrapped in
    a ``LatentEmbeddingsDataset`` and batched with ``unet_batch_size``.

    Returns:
        DataLoader: The training loader.
    """
    file_names = sorted(name for name in os.listdir(latent_scaled_dir) if name.endswith(".pt"))
    dataset = LatentEmbeddingsDataset(file_names)

    # os.cpu_count() may return None when the count cannot be determined;
    # fall back to a single worker instead of failing inside min().
    num_workers = min(8, os.cpu_count() or 1)

    train_loader = DataLoader(
        dataset,
        batch_size=unet_batch_size,
        shuffle=True,
        worker_init_fn=seed_worker,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=32,
    )
    return train_loader

def seed_worker(worker_id):
    """Seed NumPy's and Python's global RNGs from the current torch seed.

    Passed as ``worker_init_fn`` to the ``DataLoader``s in this module.

    Args:
        worker_id: Worker index supplied by the ``DataLoader``; not used here,
            since the seed is derived from ``torch.initial_seed()``.
    """
    # Reduce the torch seed modulo 2**32 before handing it to the other RNGs.
    derived_seed = torch.initial_seed() % (1 << 32)
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)