| import argparse | |
| import itertools | |
| import math | |
| import os | |
| from contextlib import nullcontext | |
| import random | |
| import torch | |
| import PIL | |
| from accelerate import Accelerator | |
| from accelerate.logging import get_logger | |
| from accelerate.utils import set_seed | |
| from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel | |
| from diffusers.optimization import get_scheduler | |
| from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker | |
| from PIL import Image | |
| from torchvision import transforms | |
| from tqdm.auto import tqdm | |
| from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer | |
| import bitsandbytes as bnb | |
def image_grid(imgs, rows, cols):
    """Tile ``imgs`` into a ``rows`` x ``cols`` RGB grid image.

    Images are placed left-to-right, top-to-bottom. Every cell is sized from
    the first image's ``.size``, so all images are assumed to share one size.

    Args:
        imgs: Sequence of PIL images; its length must equal ``rows * cols``.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB image of size ``(cols * w, rows * h)``, where ``(w, h)`` is
        the size of ``imgs[0]``.

    Raises:
        ValueError: if ``len(imgs) != rows * cols`` or ``imgs`` is empty.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if len(imgs) != rows * cols:
        raise ValueError(
            f"expected rows*cols = {rows * cols} images, got {len(imgs)}"
        )
    if len(imgs) == 0:
        # Cell size is taken from imgs[0]; an empty grid has no size to use.
        raise ValueError("imgs must contain at least one image")

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid
# Directory holding the saved pipeline (model weights and the scheduler/ subfolder).
output_dir = './'
from diffusers import DPMSolverMultistepScheduler

# The original loaded fp16 weights but never moved the pipeline off the CPU.
# Half precision is meant for GPU inference, so use CUDA + fp16 when it is
# available and fall back to fp32 on CPU.
_device = "cuda" if torch.cuda.is_available() else "cpu"
_dtype = torch.float16 if _device == "cuda" else torch.float32

pipe = StableDiffusionPipeline.from_pretrained(
    output_dir,
    scheduler=DPMSolverMultistepScheduler.from_pretrained(
        output_dir, subfolder="scheduler"
    ),
    torch_dtype=_dtype,
).to(_device)
| import gradio as gr | |
def inference(prompt, num_samples):
    """Generate images for ``prompt`` with the module-level ``pipe``.

    Args:
        prompt: Text prompt passed to the pipeline.
        num_samples: Requested number of images. A UI slider may deliver this
            as a float, or as 0 when unconstrained. It is coerced to an int of
            at least 1 before being used as ``num_images_per_prompt``.

    Returns:
        A new list containing the images from ``pipe(...).images``
        (25 inference steps).
    """
    count = max(1, int(num_samples))
    result = pipe(prompt, num_images_per_prompt=count, num_inference_steps=25)
    return list(result.images)
# Gradio UI: a prompt box and a sample-count slider on the left, and a gallery
# of generated images on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="prompt")
            # Integer-valued and starting at 1, so the UI cannot request zero
            # or fractional image counts.
            samples = gr.Slider(label="Samples", minimum=1, step=1, value=1)
            run = gr.Button(value="Run")
        with gr.Column():
            gallery = gr.Gallery(show_label=False)
    run.click(inference, inputs=[prompt, samples], outputs=gallery)
    # Each example row supplies exactly one value per input: (prompt, samples).
    gr.Examples(
        [["Foods in tokyo", 1]],
        [prompt, samples],
        gallery,
        inference,
        cache_examples=False,
    )
demo.launch()