import os
from pathlib import Path
from typing import Optional

import torch
from PIL import Image
from torch.utils.data import Dataset


def txt_dir_save_fn(image, metadata, root_path):
    """Save a generated image and its prompt side by side in ``root_path``.

    Writes ``<filename>.png`` (via ``PIL.Image.fromarray``) and
    ``<filename>.txt`` containing the prompt. ``filename`` and ``prompt``
    are taken from ``metadata``.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``. This presumably
            means an HxW or HxWxC uint8 array; TODO confirm against the caller.
        metadata: Mapping with at least ``"filename"`` and ``"prompt"`` keys,
            such as the dict built by ``TxtDirEvalDataset.__getitem__``.
        root_path: Directory to write both files into.
    """
    stem = metadata["filename"]
    Image.fromarray(image).save(os.path.join(root_path, f"{stem}.png"))
    # Explicit UTF-8 so non-ASCII prompts are written identically on every
    # platform instead of depending on the locale's default encoding.
    with open(os.path.join(root_path, f"{stem}.txt"), "w", encoding="utf-8") as f:
        f.write(metadata["prompt"])


class TxtDirEvalDataset(Dataset):
    """Evaluation dataset with one prompt per ``*.txt`` file in a directory.

    Each item is a ``(latent, prompt, metadata)`` tuple:

    * ``latent``: a float32 ``torch.randn`` tensor of ``latent_shape``.
      It is seeded with ``seed_offset + idx``, so the same index always
      yields the same noise.
    * ``prompt``: the stripped contents of the text file.
    * ``metadata``: a dict with ``prompt``, ``filename`` (the file stem),
      ``seed``, and ``save_fn`` (``txt_dir_save_fn``).
    """

    def __init__(
        self,
        txt_dir: str,
        latent_shape,
        limit: Optional[int] = 50,
        seed_offset: int = 0,
    ):
        """Collect prompt files from ``txt_dir``.

        Args:
            txt_dir: Directory scanned non-recursively for ``*.txt`` files.
                Files are ordered by ``sorted`` on their paths, which is
                lexicographic (``10.txt`` sorts before ``2.txt``).
            latent_shape: Shape passed to ``torch.randn`` for each latent.
            limit: Keep only the first ``limit`` files. ``None`` or a
                non-positive value keeps all of them.
            seed_offset: Added to the item index to form each item's seed.

        Raises:
            FileNotFoundError: If no ``*.txt`` files are found (this also
                covers a non-existent ``txt_dir``, where the glob is empty).
        """
        self.latent_shape = latent_shape
        self.seed_offset = seed_offset
        self.txt_paths = sorted(Path(txt_dir).glob("*.txt"))
        if limit is not None and limit > 0:
            self.txt_paths = self.txt_paths[:limit]
        if not self.txt_paths:
            raise FileNotFoundError(f"No .txt files found in {txt_dir}")

    def __len__(self):
        """Return the number of prompt files retained after ``limit``."""
        return len(self.txt_paths)

    def __getitem__(self, idx):
        """Return ``(latent, prompt, metadata)`` for the ``idx``-th file."""
        path = self.txt_paths[idx]
        # Read as UTF-8 explicitly; the locale default (e.g. cp1252 on
        # Windows) can garble or reject non-ASCII prompts.
        prompt = path.read_text(encoding="utf-8").strip()
        seed = self.seed_offset + idx
        # A dedicated generator keeps per-item noise reproducible and
        # independent of the global torch RNG state.
        generator = torch.Generator().manual_seed(seed)
        latent = torch.randn(self.latent_shape, generator=generator, dtype=torch.float32)
        metadata = dict(
            prompt=prompt,
            filename=path.stem,
            seed=seed,
            save_fn=txt_dir_save_fn,
        )
        return latent, prompt, metadata