| import os |
| from pathlib import Path |
|
|
| import torch |
| from PIL import Image |
| from torch.utils.data import Dataset |
|
|
|
|
def txt_dir_save_fn(image, metadata, root_path):
    """Save a generated image and its prompt side by side under *root_path*.

    Writes ``<root_path>/<stem>.png`` (via ``PIL.Image.fromarray``) and
    ``<root_path>/<stem>.txt`` containing ``metadata["prompt"]``, where
    ``<stem>`` is ``metadata["filename"]``.

    Args:
        image: Array accepted by ``PIL.Image.fromarray`` — presumably an
            HxWxC uint8 array; TODO confirm against the caller.
        metadata: Mapping with at least the ``"filename"`` and ``"prompt"``
            keys, e.g. the metadata dict built by
            ``TxtDirEvalDataset.__getitem__``.
        root_path: Output directory (str or path-like). It is not created
            here, so it must already exist.
    """
    stem = metadata["filename"]
    Image.fromarray(image).save(os.path.join(root_path, f"{stem}.png"))
    # Explicit UTF-8 so non-ASCII prompts are written identically on every
    # platform instead of depending on the default locale encoding.
    with open(os.path.join(root_path, f"{stem}.txt"), "w", encoding="utf-8") as f:
        f.write(metadata["prompt"])
|
|
|
|
class TxtDirEvalDataset(Dataset):
    """Evaluation dataset serving one prompt per ``*.txt`` file in a directory.

    Each item is a ``(latent, prompt, metadata)`` triple. The latent comes from
    a freshly seeded CPU ``torch.Generator`` (seed ``seed_offset + index``).
    Because it never touches the global RNG, the same index always yields the
    same latent regardless of access order.
    """

    def __init__(
        self,
        txt_dir: str,
        latent_shape,
        limit: int = 50,
        seed_offset: int = 0,
    ):
        """Collect the prompt files from *txt_dir*.

        Args:
            txt_dir: Directory scanned non-recursively for ``*.txt`` files
                (any value accepted by ``pathlib.Path``).
            latent_shape: Shape handed to ``torch.randn`` for every latent.
            limit: Keep only the first ``limit`` files after sorting. ``None``
                or a value <= 0 disables the cap.
            seed_offset: Added to the item index to form the per-item seed.

        Raises:
            FileNotFoundError: If no ``.txt`` files are found.
        """
        self.latent_shape = latent_shape
        self.seed_offset = seed_offset
        # Sorting makes the index -> file (and so index -> seed) mapping
        # deterministic. It is lexicographic: "10.txt" sorts before "2.txt".
        self.txt_paths = sorted(Path(txt_dir).glob("*.txt"))
        if limit is not None and limit > 0:
            self.txt_paths = self.txt_paths[:limit]
        if not self.txt_paths:
            raise FileNotFoundError(f"No .txt files found in {txt_dir}")

    def __len__(self):
        """Return the number of prompt files served."""
        return len(self.txt_paths)

    def __getitem__(self, idx):
        """Return ``(latent, prompt, metadata)`` for item *idx*.

        ``latent`` is a float32 tensor of shape ``latent_shape`` drawn with seed
        ``seed_offset + idx``. ``prompt`` is the file's text with surrounding
        whitespace stripped. ``metadata`` holds ``prompt``, ``filename`` (the
        file stem), ``seed`` and ``save_fn`` (``txt_dir_save_fn``).

        Raises:
            IndexError: If *idx* is out of range.
        """
        # Normalise negative indices (and reject out-of-range ones) so that
        # ds[-1] and ds[len(ds) - 1] select the same file AND derive the same
        # seed; using the raw negative index would give a different latent.
        idx = range(len(self.txt_paths))[idx]
        path = self.txt_paths[idx]
        # Explicit UTF-8: the platform default encoding (e.g. on Windows)
        # would mis-decode non-ASCII prompts.
        prompt = path.read_text(encoding="utf-8").strip()
        seed = self.seed_offset + idx
        generator = torch.Generator().manual_seed(seed)
        latent = torch.randn(self.latent_shape, generator=generator, dtype=torch.float32)
        metadata = dict(
            prompt=prompt,
            filename=path.stem,
            seed=seed,
            save_fn=txt_dir_save_fn,
        )
        return latent, prompt, metadata
|
|