# sd / app.py
# Author: Amarsaish
# Last commit: "Update app.py" (8f35b93, verified)
import argparse
import itertools
import math
import os
from contextlib import nullcontext
import random
import torch
import PIL
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
import bitsandbytes as bnb
def image_grid(imgs, rows, cols):
    """Paste ``rows * cols`` images into a single grid image.

    Images are laid out row-major: ``imgs[0]`` is top-left, ``imgs[cols - 1]``
    is top-right, ``imgs[cols]`` starts the second row, and so on. Cell
    spacing is taken from the size of the first image, so the images are
    assumed to all share that size.

    Args:
        imgs: Sequence of PIL images; must contain exactly ``rows * cols`` items.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB image of size ``(cols * w, rows * h)``, where ``(w, h)`` is
        the size of ``imgs[0]``.

    Raises:
        ValueError: If ``len(imgs) != rows * cols`` or ``imgs`` is empty.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if len(imgs) != rows * cols:
        raise ValueError(
            f"expected rows*cols={rows * cols} images, got {len(imgs)}"
        )
    if not imgs:
        # rows*cols == 0 would otherwise fail later on ``imgs[0]``.
        raise ValueError("imgs must contain at least one image")

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid
# Directory holding the saved diffusers pipeline; its scheduler config is read
# from the ``scheduler`` subfolder below.
output_dir = './'
from diffusers import DPMSolverMultistepScheduler

# float16 is only used on CUDA: many CPU kernels lack half-precision support,
# and without an explicit ``.to(device)`` the pipeline would stay on the CPU
# even when a GPU is available.
_device = "cuda" if torch.cuda.is_available() else "cpu"
_dtype = torch.float16 if _device == "cuda" else torch.float32

pipe = StableDiffusionPipeline.from_pretrained(
    output_dir,
    scheduler=DPMSolverMultistepScheduler.from_pretrained(output_dir, subfolder="scheduler"),
    torch_dtype=_dtype,
).to(_device)
import gradio as gr
def inference(prompt, num_samples, num_inference_steps=25):
    """Generate images for *prompt* using the module-level ``pipe``.

    Args:
        prompt: Text prompt passed to the pipeline.
        num_samples: Number of images to generate. Gradio sliders can deliver
            this as a float (e.g. ``1.0``) or as 0, so it is coerced to an
            ``int`` and clamped to at least 1.
        num_inference_steps: Number of denoising steps; defaults to 25, the
            value that was previously hard-coded.

    Returns:
        A new list containing the pipeline output's ``.images``.
    """
    # ``num_images_per_prompt`` acts as a batch size, so it must be a
    # positive integer rather than whatever raw value the UI produced.
    num_images = max(1, int(num_samples))
    result = pipe(
        prompt,
        num_images_per_prompt=num_images,
        num_inference_steps=num_inference_steps,
    )
    return list(result.images)
# Gradio UI: a prompt box and sample-count slider feed ``inference``, and the
# generated images are shown in a gallery.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="prompt")
            # Whole numbers with a floor of 1: ``inference`` uses this value as
            # the batch size. The upper cap guards against accidentally huge
            # batches; raise it if the hardware has headroom.
            samples = gr.Slider(label="Samples", minimum=1, maximum=4, step=1, value=1)
            run = gr.Button(value="Run")
        with gr.Column():
            gallery = gr.Gallery(show_label=False)
    run.click(inference, inputs=[prompt, samples], outputs=gallery)
    # Each example row supplies exactly one value per input: [prompt, samples].
    gr.Examples(
        examples=[["Foods in tokyo", 1]],
        inputs=[prompt, samples],
        outputs=gallery,
        fn=inference,
        cache_examples=False,
    )

# Only start the server when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()