| | import argparse |
| | import math |
| | import os |
| |
|
| | import torch |
| | from neural_compressor.utils.pytorch import load |
| | from PIL import Image |
| | from transformers import CLIPTextModel, CLIPTokenizer |
| |
|
| | from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel |
| |
|
| |
|
def parse_args():
    """Parse command-line options for the image-generation script.

    Returns:
        argparse.Namespace with fields:
            pretrained_model_name_or_path (str, required): model path or HF hub id.
            caption (str): text prompt used to generate images.
            images_num (int): number of images to generate.
            seed (int): RNG seed for reproducible generation.
            cuda_id (int): CUDA device index used when no quantized model is found.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "-c",
        "--caption",
        type=str,
        default="robotic cat with wings",
        help="Text used to generate images.",
    )
    parser.add_argument(
        "-n",
        "--images_num",
        type=int,
        default=4,
        # Fixed grammar in user-facing help text ("How much" -> "How many").
        help="How many images to generate.",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=42,
        help="Seed for random process.",
    )
    parser.add_argument(
        "-ci",
        "--cuda_id",
        type=int,
        default=0,
        help="cuda_id.",
    )
    return parser.parse_args()
| |
|
| |
|
def image_grid(imgs, rows, cols):
    """Paste ``imgs`` into a single ``rows`` x ``cols`` contact-sheet image.

    Args:
        imgs: sequence of PIL images; all are assumed to share the size of
            ``imgs[0]`` — TODO confirm with callers.
        rows: number of grid rows.
        cols: number of grid columns.

    Returns:
        A new RGB ``PIL.Image`` of size ``(cols * w, rows * h)``.

    Raises:
        ValueError: if ``len(imgs) != rows * cols``.
    """
    # `!=` instead of `not ... ==`; also dropped the unused grid_w/grid_h locals.
    if len(imgs) != rows * cols:
        raise ValueError("The specified number of rows and columns are not correct.")

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    # Fill left-to-right, top-to-bottom.
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
| |
|
| |
|
def generate_images(
    pipeline,
    prompt="robotic cat with wings",
    guidance_scale=7.5,
    num_inference_steps=50,
    num_images_per_prompt=1,
    seed=42,
):
    """Generate images with a Stable Diffusion pipeline and arrange them in a grid.

    Args:
        pipeline: a diffusers StableDiffusionPipeline (or compatible callable
            returning an object with an ``.images`` list).
        prompt: text prompt.
        guidance_scale: classifier-free guidance strength.
        num_inference_steps: number of denoising steps.
        num_images_per_prompt: how many images to produce.
        seed: seed for the torch generator (deterministic sampling).

    Returns:
        Tuple ``(grid, images)``: the contact-sheet PIL image and the list of
        individual PIL images.
    """
    # Seed a generator on the pipeline's device for reproducible results.
    generator = torch.Generator(pipeline.device).manual_seed(seed)
    images = pipeline(
        prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
        num_images_per_prompt=num_images_per_prompt,
    ).images
    # Bug fix: the original used rows = int(sqrt(n)), cols = n // rows, which
    # violates rows * cols == n for e.g. n = 5 (2 * 2 = 4) and made image_grid
    # raise ValueError. Picking the largest divisor of n that is <= sqrt(n)
    # guarantees rows * cols == n, and yields the same layout as before for
    # every n where the original worked (1, 4, 6, 8, 9, 12, ...).
    n = num_images_per_prompt
    _rows = max(d for d in range(1, int(math.sqrt(n)) + 1) if n % d == 0)
    grid = image_grid(images, rows=_rows, cols=n // _rows)
    return grid, images
| |
|
| |
|
# ---- Script entry point ----
# Load a Stable Diffusion pipeline (optionally swapping in an INT8-quantized
# UNet produced by Intel neural_compressor), generate images for the caption,
# and save both a contact sheet and the individual images next to the model.
args = parse_args()

tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

pipeline = StableDiffusionPipeline.from_pretrained(
    args.pretrained_model_name_or_path, text_encoder=text_encoder, vae=vae, unet=unet, tokenizer=tokenizer
)
# Disable the NSFW safety checker so all generated images are kept.
# NOTE(review): diffusers typically expects the second element to be a list of
# per-image booleans; returning a bare False matches the original code here —
# confirm downstream handling before changing it.
pipeline.safety_checker = lambda images, clip_input: (images, False)

if os.path.exists(os.path.join(args.pretrained_model_name_or_path, "best_model.pt")):
    # A quantized UNet checkpoint exists: load it via neural_compressor and
    # splice it into the pipeline (inference stays on the default device).
    unet = load(args.pretrained_model_name_or_path, model=unet)
    unet.eval()
    setattr(pipeline, "unet", unet)
else:
    # Robustness fix: the original unconditionally moved the UNet to CUDA and
    # crashed on CPU-only machines; fall back to CPU when CUDA is unavailable.
    device = (
        torch.device("cuda", args.cuda_id) if torch.cuda.is_available() else torch.device("cpu")
    )
    unet = unet.to(device)
pipeline = pipeline.to(unet.device)

grid, images = generate_images(pipeline, prompt=args.caption, num_images_per_prompt=args.images_num, seed=args.seed)

# Contact sheet is named after the caption with spaces replaced by underscores.
grid.save(os.path.join(args.pretrained_model_name_or_path, "{}.png".format("_".join(args.caption.split()))))
# Individual images go to <model_dir>/<caption_with_underscores>/1.png, 2.png, ...
dirname = os.path.join(args.pretrained_model_name_or_path, "_".join(args.caption.split()))
os.makedirs(dirname, exist_ok=True)
for idx, image in enumerate(images):
    image.save(os.path.join(dirname, "{}.png".format(idx + 1)))
| |
|