flickr8k-backend / core /dataloader.py
Rohan3's picture
Updated: VAE, UNet, config, text embeddings, model and main
a625e96
import sys, os
sys.path.insert(0, os.path.dirname(__file__))
import numpy as np
from config import *
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import torchvision.transforms as T
import random
class ImageDataset(Dataset):
    """Dataset that loads image files from disk and returns them as tensors.

    Each item is the image at ``file_names[index]``, converted to RGB and
    passed through ``torchvision.transforms.ToTensor``.
    """

    def __init__(self, file_names):
        """Store the image paths and build the tensor transform.

        Args:
            file_names: Sequence of image file paths, indexed by position.
        """
        self.file_names = file_names
        self.transform = T.Compose([T.ToTensor()])

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.file_names)

    def __getitem__(self, index):
        """Load the image at ``index`` and return it as a tensor.

        Args:
            index: Position into ``file_names``.

        Returns:
            Tensor of shape ``[3, image_res, image_res]``.

        Raises:
            AssertionError: If the loaded tensor does not have the expected
                shape (e.g. the image on disk was not resized).
        """
        file_path = self.file_names[index]
        # The context manager releases the underlying file handle promptly,
        # including when decoding or conversion raises.
        with Image.open(file_path) as image:
            image_data = self.transform(image.convert("RGB"))
        expected_shape = torch.Size([3, image_res, image_res])
        # Sanity check on the pre-resized inputs (skipped when run with -O).
        assert image_data.shape == expected_shape, f"Unexpected shape: {image_data.shape}"
        return image_data
def image_dataloader():
    """Build train/test DataLoaders over the ``.jpg`` files in ``resized_img_dir``.

    Every ``.jpg`` file in the directory is wrapped in an ``ImageDataset``,
    which is then randomly split 80/20 into train and test subsets.

    Returns:
        tuple: ``(train_loader, test_loader)``. Only the train loader
        shuffles; both use ``vae_batch_size`` and ``seed_worker``.
    """
    image_paths = [
        os.path.join(resized_img_dir, name)
        for name in os.listdir(resized_img_dir)
        if name.endswith(".jpg")
    ]
    dataset = ImageDataset(image_paths)

    # NOTE: no generator is passed to random_split, so the train/test
    # membership is not pinned by this function and may differ between runs.
    # Pass a seeded torch.Generator via ``generator=`` for a fixed split.
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_set, test_set = random_split(dataset, [train_size, test_size])

    # os.cpu_count() may return None when the count cannot be determined;
    # fall back to a single worker instead of failing in min().
    num_workers = min(8, os.cpu_count() or 1)
    loader_kwargs = dict(
        batch_size=vae_batch_size,
        worker_init_fn=seed_worker,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=4,
    )
    train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)
    return train_loader, test_loader
class LatentEmbeddingsDataset(Dataset):
    """Dataset pairing each stored latent with each of its text embeddings.

    Every latent file contributes ``num_variants`` items, one per embedding
    variant, so the dataset length is ``len(file_names) * num_variants``.
    """

    def __init__(self, file_names, num_variants=5):
        """Store the latent file names and the number of variants per latent.

        Args:
            file_names: Sequence of latent file names, relative to
                ``latent_scaled_dir``. Names are assumed to end in ``.pt``.
            num_variants: Number of embedding files per latent. Defaults to 5.

        Raises:
            ValueError: If ``num_variants`` is less than 1.
        """
        if num_variants < 1:
            raise ValueError(f"num_variants must be >= 1, got {num_variants}")
        self.file_names = file_names
        self.num_variants = num_variants

    def __len__(self):
        """Return the total number of (latent, embedding) pairs."""
        return len(self.file_names) * self.num_variants

    def __getitem__(self, index):
        """Load the (latent, embedding) pair for a flat index.

        Args:
            index: Flat index; ``index // num_variants`` selects the latent
                file and ``index % num_variants`` selects the variant.

        Returns:
            tuple: ``(latent_data, embedding_data)`` loaded onto the CPU.

        Raises:
            AssertionError: If the embedding shape is not ``[77, 1024]``.
        """
        latent_index, variant_index = divmod(index, self.num_variants)
        file_path = self.file_names[latent_index]
        latent_data = torch.load(
            f"{latent_scaled_dir}/{file_path}",
            map_location="cpu",
            weights_only=True,
        )
        # file_path[:-3] drops the trailing three characters (the ".pt"
        # suffix) so the variant index can be inserted before the extension.
        embedding_data = torch.load(
            f"{embedding_dir}/{file_path[:-3]}_{variant_index}.pt",
            map_location="cpu",
            weights_only=True,
        )
        # Sanity check on the stored embeddings (skipped when run with -O).
        assert embedding_data.shape == torch.Size([77, 1024]), f"Unexpected embedding shape: {embedding_data.shape}"
        return (latent_data, embedding_data)
def latent_embedding_dataloader():
    """Build the shuffled training DataLoader over latent/embedding pairs.

    Collects every ``.pt`` file name in ``latent_scaled_dir`` in sorted
    order and wraps them in a ``LatentEmbeddingsDataset``.

    Returns:
        DataLoader: Shuffling loader that uses ``unet_batch_size`` and
        ``seed_worker``.
    """
    file_names = sorted(
        name for name in os.listdir(latent_scaled_dir) if name.endswith(".pt")
    )
    dataset = LatentEmbeddingsDataset(file_names)

    # os.cpu_count() may return None when the count cannot be determined;
    # fall back to a single worker instead of failing in min().
    num_workers = min(8, os.cpu_count() or 1)
    train_loader = DataLoader(
        dataset,
        batch_size=unet_batch_size,
        shuffle=True,
        worker_init_fn=seed_worker,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=32,
    )
    return train_loader
def seed_worker(worker_id):
    """Seed NumPy's and Python's RNGs from the current torch seed.

    Intended for use as a DataLoader ``worker_init_fn``. The seed is taken
    from ``torch.initial_seed()`` and reduced modulo 2**32 before being
    handed to both libraries.

    Args:
        worker_id: Worker index supplied by the DataLoader; not used here.
    """
    seed = torch.initial_seed() % (1 << 32)
    for seed_fn in (np.random.seed, random.seed):
        seed_fn(seed)