# run_sd_with_lora.py — generate images from a text prompt with Stable Diffusion,
# optionally applying LoRA weights. (Originally from the test_sd scripts collection.)
from diffusers import AutoPipelineForText2Image
from diffusers import TCDScheduler, LCMScheduler
from diffusers.utils import make_image_grid
import torch
from PIL import Image
import time
from pathlib import Path
import argparse
# Apple-silicon GPU backend; this script targets macOS (MPS).
device = torch.device("mps")
# All model repositories live one level up, under ../models/.
folder = Path("..") / "models"
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate images from a textual prompt using stable diffusion"
    )
    parser.add_argument("prompt", help="text prompt to render")
    parser.add_argument("--model", choices=["sd1.5", "sd2", "realistic", "sdxl", "sdxl-turbo"], default="sdxl")
    parser.add_argument("--n_images", type=int, default=1, help="images generated per prompt")
    parser.add_argument("--steps", type=int, help="inference steps (per-model default if omitted)")
    parser.add_argument("--guidance_scale", type=float, help="CFG scale (per-model default if omitted)")
    parser.add_argument("--negative_prompt", default="")
    parser.add_argument("--output", default="out.png")
    parser.add_argument("--img-size", type=int, default=512)
    parser.add_argument("--lora", type=str, default=None, help="path/repo of LoRA weights to load")
    parser.add_argument("--lora-scale", type=float, default=None)
    args = parser.parse_args()

    # Per-model defaults: (repo subdirectory under `folder`, guidance scale, steps).
    MODEL_CONFIGS = {
        "sdxl-turbo": ("stabilityai/sdxl-turbo", 0.0, 2),
        "sdxl": ("stabilityai/stable-diffusion-xl-base-1.0", 7.5, 20),
        "sd2": ("stabilityai/stable-diffusion-2-1-base", 7.5, 20),
        "sd1.5": ("runwayml/stable-diffusion-v1-5", 7.5, 20),
        "realistic": ("SG161222/Realistic_Vision_V3.0_VAE", 5, 10),  # guidance 3.5-7 recommended
    }
    try:
        subdir, default_guidance, default_steps = MODEL_CONFIGS[args.model]
    except KeyError:
        # Unreachable with argparse `choices`, kept as a defensive guard.
        raise ValueError(f"Unknown model: {args.model}") from None
    model_path = folder / subdir
    # Use `is None` (not `or`) so an explicit `--guidance_scale 0` is respected
    # (0.0 is falsy and the old `or` defaulting silently overrode it).
    if args.guidance_scale is None:
        args.guidance_scale = default_guidance
    if args.steps is None:
        args.steps = default_steps

    print("*" * 10, "configurations")
    print(f"model: {args.model}\nimage number: {args.n_images}\nsteps: {args.steps}\n"
          f"guidance_scale: {args.guidance_scale}\nnegative_prompt: {args.negative_prompt}\noutput: {args.output}\n"
          f"img-size: {args.img_size}\nlora: {args.lora}\nlora-scale: {args.lora_scale}\nprompt:{args.prompt}\n")
    print("*" * 10)

    t0 = time.time()
    pipe = AutoPipelineForText2Image.from_pretrained(model_path,
                                                     torch_dtype=torch.float16).to(device)
    t1 = time.time()
    print(f"load model time: {(t1 - t0):.3f}")

    # t2 marks the start of generation. It was previously assigned only inside
    # the LoRA branch, so running without --lora crashed with NameError below.
    t2 = t1
    if args.lora:
        # pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
        pipe.load_lora_weights(args.lora)
        # pipe.fuse_lora()
        t2 = time.time()
        print(f"load lora time: {(t2 - t1):.3f}")

    # Build the call once instead of duplicating it per branch. Also forward
    # --negative_prompt, which was parsed and printed but never actually used.
    pipe_kwargs = dict(
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        height=args.img_size,
        width=args.img_size,
        num_inference_steps=args.steps,
        num_images_per_prompt=args.n_images,
        guidance_scale=args.guidance_scale,
        # generator=torch.Generator(device="mps").manual_seed(0),
    )
    if args.lora_scale:
        # Scale applied to the LoRA layers during attention computation.
        pipe_kwargs["cross_attention_kwargs"] = {"scale": args.lora_scale}
    output = pipe(**pipe_kwargs)

    t3 = time.time()
    print(f"generate image time: {(t3 - t2):.3f}")
    img = make_image_grid(output.images, rows=1, cols=args.n_images)
    img.save(args.output)
    print(f"save image to: {args.output}")
    print(f"output image size: {img.size}")
    print(f"total time: {(time.time() - t0):.3f}")