| from diffusers import AutoPipelineForText2Image |
| from diffusers import TCDScheduler, LCMScheduler |
| from diffusers.utils import make_image_grid |
| import torch |
| from PIL import Image |
| import time |
| from pathlib import Path |
| import argparse |
|
|
# Pick the best available accelerator: Apple-silicon GPU (MPS) when present,
# otherwise CUDA, otherwise plain CPU. The original hard-coded "mps", which
# crashes at pipeline time on machines without the MPS backend.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Root directory holding the locally downloaded model repositories,
# relative to this script's working directory.
folder = Path("../models/")
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate images from a textual prompt using stable diffusion"
    )
    parser.add_argument("prompt")
    parser.add_argument("--model", choices=["sd1.5", "sd2", "realistic", "sdxl", "sdxl-turbo"], default="sdxl")
    parser.add_argument("--n_images", type=int, default=1)
    parser.add_argument("--steps", type=int)
    parser.add_argument("--guidance_scale", type=float)
    parser.add_argument("--negative_prompt", default="")
    parser.add_argument("--output", default="out.png")
    parser.add_argument("--img-size", type=int, default=512)
    parser.add_argument("--lora", type=str, default=None)
    parser.add_argument("--lora-scale", type=float, default=None)
    args = parser.parse_args()

    # Per-model configuration: (repo subdirectory, default guidance scale,
    # default inference steps). Explicit CLI values always take precedence.
    MODEL_DEFAULTS = {
        "sdxl-turbo": ("stabilityai/sdxl-turbo", 0.0, 2),
        "sdxl": ("stabilityai/stable-diffusion-xl-base-1.0", 7.5, 20),
        "sd2": ("stabilityai/stable-diffusion-2-1-base", 7.5, 20),
        "sd1.5": ("runwayml/stable-diffusion-v1-5", 7.5, 20),
        "realistic": ("SG161222/Realistic_Vision_V3.0_VAE", 5, 10),
    }
    try:
        repo, default_guidance, default_steps = MODEL_DEFAULTS[args.model]
    except KeyError:
        # Unreachable in practice (argparse `choices` already validates),
        # kept as a defensive guard matching the original behavior.
        raise ValueError(f"Unknown model: {args.model}") from None
    model_path = folder / repo
    # Use `is None` rather than `or` so an explicit `--guidance_scale 0` or
    # `--steps 0` on the CLI is honored instead of silently replaced.
    if args.guidance_scale is None:
        args.guidance_scale = default_guidance
    if args.steps is None:
        args.steps = default_steps

    print("*" * 10, "configurations")
    print(f"model: {args.model}\nimage number: {args.n_images}\nsteps: {args.steps}\n"
          f"guidance_scale: {args.guidance_scale}\nnegative_prompt: {args.negative_prompt}\noutput: {args.output}\n"
          f"img-size: {args.img_size}\nlora: {args.lora}\nlora-scale: {args.lora_scale}\nprompt:{args.prompt}\n")
    print("*" * 10)

    # Load the pipeline in fp16 and move it to the selected device.
    t0 = time.time()
    pipe = AutoPipelineForText2Image.from_pretrained(model_path,
                                                     torch_dtype=torch.float16).to(device)
    t1 = time.time()
    print(f"load model time: {(t1 - t0):.3f}")

    if args.lora:
        pipe.load_lora_weights(args.lora)
        t2 = time.time()
        print(f"load lora time: {(t2 - t1):.3f}")
    else:
        # BUGFIX: t2 was previously only set inside the LoRA branch, causing a
        # NameError in the generation-time print when --lora was not given.
        t2 = t1

    # Build the generation call once; the two previous call sites differed
    # only by cross_attention_kwargs.
    generate_kwargs = dict(
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,  # BUGFIX: was parsed/printed but never passed
        height=args.img_size,
        width=args.img_size,
        num_inference_steps=args.steps,
        num_images_per_prompt=args.n_images,
        guidance_scale=args.guidance_scale,
    )
    if args.lora_scale is not None:
        # `is not None` so an explicit `--lora-scale 0.0` is forwarded too.
        generate_kwargs["cross_attention_kwargs"] = {"scale": args.lora_scale}
    output = pipe(**generate_kwargs)

    t3 = time.time()
    print(f"generate image time: {(t3 - t2):.3f}")
    img = make_image_grid(output.images, rows=1, cols=args.n_images)
    img.save(args.output)
    print(f"save image to: {args.output}")
    print(f"output image size: {img.size}")
    print(f"total time: {(time.time() - t0):.3f}")
|
|