| import torch | |
| import gc | |
| from pathlib import Path | |
| from tqdm.auto import tqdm | |
| from dataset import PromptDataset | |
def generate_class_images(pipeline, class_prompt, num_class_images, class_images_dir, sample_batch_size=2):
    """Top up ``class_images_dir`` with generated images until it holds ``num_class_images`` files.

    Regular files already in the directory count toward the target; only the
    shortfall is generated. New images are written as ``<n>.jpg``, where ``n``
    starts at the number of existing files and skips any name that is already
    taken, so existing images are never overwritten.

    Args:
        pipeline: Callable invoked as ``pipeline(prompts)`` on each batch of
            prompts; its result must expose an ``.images`` sequence whose items
            have a ``save(path)`` method (presumably PIL images from a
            diffusers pipeline -- confirm against the caller).
        class_prompt: Prompt handed to ``PromptDataset``; presumably repeated
            once per image to generate (verify against ``PromptDataset``).
        num_class_images: Target number of files in ``class_images_dir``.
        class_images_dir: Output directory as ``str`` or ``Path``. It is
            created, including missing parents, if it does not exist.
        sample_batch_size: Number of dataset items batched into each
            ``pipeline`` call.

    Note:
        ``del pipeline`` below only drops this function's reference; memory
        held by the pipeline is freed only once the caller drops its own
        references as well.
    """
    class_images_dir = Path(class_images_dir)
    class_images_dir.mkdir(parents=True, exist_ok=True)

    # Count only regular files so stray subdirectories don't satisfy the quota.
    cur_class_images = sum(1 for entry in class_images_dir.iterdir() if entry.is_file())
    num_new_images = num_class_images - cur_class_images

    if num_new_images > 0:
        sample_dataset = PromptDataset(class_prompt, num_new_images)
        sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=sample_batch_size)

        # Name files sequentially from the current count, skipping names that
        # already exist: with gaps in the numbering (e.g. a deleted file),
        # "count + i" would otherwise collide with and overwrite an image.
        next_index = cur_class_images
        for example in tqdm(sample_dataloader, desc="Generating class images"):
            images = pipeline(example["prompt"]).images
            for image in images:
                image_path = class_images_dir / f"{next_index}.jpg"
                while image_path.exists():
                    next_index += 1
                    image_path = class_images_dir / f"{next_index}.jpg"
                image.save(image_path)
                next_index += 1

    # Drop the local reference, collect garbage, and ask PyTorch's caching
    # allocator to release unoccupied cached CUDA memory.
    del pipeline
    gc.collect()
    torch.cuda.empty_cache()