pixel_gen / src /data /dataset /txt_dir_eval.py
linxin02's picture
Upload lx_gan project
cef8b68 verified
Raw
History Blame Contribute Delete
1.43 kB
import os
from pathlib import Path
import torch
from PIL import Image
from torch.utils.data import Dataset
def txt_dir_save_fn(image, metadata, root_path):
    """Save one generated sample as ``<filename>.png`` plus ``<filename>.txt``.

    Args:
        image: Array-like object accepted by ``PIL.Image.fromarray``
            (presumably an HxWxC uint8 array -- confirm against the caller).
        metadata: Mapping that provides at least ``"filename"`` (the output
            stem, without extension) and ``"prompt"`` (the text to store
            next to the image).
        root_path: Directory the two files are written into. It is not
            created here, so it must already exist.
    """
    stem = metadata["filename"]
    Image.fromarray(image).save(os.path.join(root_path, f"{stem}.png"))
    # Pin the encoding: the platform default may be unable to represent
    # non-ASCII prompts (e.g. cp1252 on Windows) and would otherwise make the
    # written bytes depend on the machine the evaluation runs on.
    with open(os.path.join(root_path, f"{stem}.txt"), "w", encoding="utf-8") as f:
        f.write(metadata["prompt"])
class TxtDirEvalDataset(Dataset):
    """Evaluation dataset pairing prompt files with reproducible noise latents.

    Every ``*.txt`` file directly inside ``txt_dir`` is one sample. Files are
    visited in sorted path order, and each sample's latent is drawn from a
    generator seeded with ``seed_offset + index`` so repeated runs yield the
    same noise for the same prompt file.
    """

    def __init__(
        self,
        txt_dir: str,
        latent_shape,
        limit: int = 50,
        seed_offset: int = 0,
    ):
        """Collect the prompt files to evaluate.

        Args:
            txt_dir: Directory searched (non-recursively) for ``*.txt`` files.
            latent_shape: Shape handed to ``torch.randn`` for every sample.
            limit: Maximum number of prompt files to keep, counted from the
                start of the sorted list. ``None`` or a non-positive value
                keeps every file.
            seed_offset: Value added to the sample index to form its seed.

        Raises:
            FileNotFoundError: If ``txt_dir`` contains no ``.txt`` files.
        """
        self.latent_shape = latent_shape
        self.seed_offset = seed_offset
        # Sorting makes the index -> file (and therefore index -> seed)
        # mapping independent of filesystem enumeration order.
        self.txt_paths = sorted(Path(txt_dir).glob("*.txt"))
        if limit is not None and limit > 0:
            self.txt_paths = self.txt_paths[:limit]
        if not self.txt_paths:
            raise FileNotFoundError(f"No .txt files found in {txt_dir}")

    def __len__(self):
        """Return the number of prompt files kept after applying ``limit``."""
        return len(self.txt_paths)

    def __getitem__(self, idx):
        """Return ``(latent, prompt, metadata)`` for the sample at ``idx``.

        ``metadata`` carries the prompt, the file stem, the seed used for the
        latent, and the ``save_fn`` callable for writing the result to disk.

        Raises:
            IndexError: If ``idx`` is outside the dataset.
        """
        # Index with the caller's value first so out-of-range indices raise
        # IndexError exactly as plain list indexing does.
        path = self.txt_paths[idx]
        # Map negative indices onto their positive equivalent before deriving
        # the seed; otherwise ds[-1] and ds[len(ds) - 1] would share a prompt
        # but receive different seeds and latents.
        if idx < 0:
            idx += len(self.txt_paths)
        # Decode explicitly as UTF-8 instead of relying on the locale default.
        prompt = path.read_text(encoding="utf-8").strip()
        seed = self.seed_offset + idx
        generator = torch.Generator().manual_seed(seed)
        latent = torch.randn(self.latent_shape, generator=generator, dtype=torch.float32)
        metadata = dict(
            prompt=prompt,
            filename=path.stem,
            seed=seed,
            save_fn=txt_dir_save_fn,
        )
        return latent, prompt, metadata