| import os | |
| from diffusers import StableDiffusionPipeline | |
| import torch | |
def generate_images(prompt, model_path="output/checkpoint-500/", num_images=1):
    """Generate images for ``prompt`` with a locally saved Stable Diffusion model.

    Args:
        prompt: Text prompt passed to the pipeline.
        model_path: Local directory to load the pipeline from. It must contain
            a diffusers ``model_index.json`` or one of the recognised weight
            files at its top level.
        num_images: Number of images to generate for ``prompt``.

    Returns:
        The ``images`` list of the pipeline output.

    Raises:
        EnvironmentError: If ``model_path`` contains none of the expected files.
    """
    # A pipeline saved with ``save_pretrained`` keeps ``model_index.json`` at its
    # root and its weights in per-component subfolders, so it must be accepted
    # alongside the single-file weight names checked previously.
    required_files = [
        'model_index.json',
        'pytorch_model.bin',
        'model.safetensors',
        'tf_model.h5',
        'model.ckpt.index',
        'flax_model.msgpack',
    ]
    if not any(os.path.exists(os.path.join(model_path, name)) for name in required_files):
        raise EnvironmentError(
            f"Error no file named {', '.join(required_files)} found in directory {model_path}. "
            "Ensure your model is correctly saved."
        )

    # Half precision only on GPU: many CPU kernels lack float16
    # implementations, so the CPU fallback must load in float32.
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    dtype = torch.float16 if use_cuda else torch.float32

    pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype)
    pipe = pipe.to(device)
    # The pipeline's per-prompt batch argument is ``num_images_per_prompt``;
    # a ``num_images`` keyword does not control how many images are produced.
    images = pipe(prompt, num_images_per_prompt=num_images).images
    return images