# diffusion_demo / handler.py
# Author: Slawek Biel
# Feature branch: grid_generation
# Commit: fbc7131
import itertools
from typing import Dict, List, Any
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
from PIL import Image
from skimage.exposure import match_histograms, adjust_gamma
import numpy as np
import diffusers
class EndpointHandler():
    """Inference-endpoint handler that turns a list of prompts into an image grid.

    Loads a Stable Diffusion txt2img pipeline from ``path`` and derives an
    img2img pipeline from the same components (shared VAE/UNet/encoder, so
    the model weights are loaded only once).
    """

    def __init__(self, path=""):
        # fp16 on CUDA; the safety checker is disabled via config below.
        self.txt2img_pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16).to("cuda")
        self.txt2img_pipe.register_to_config(requires_safety_checker=False)
        self.txt2img_pipe.register_to_config(safety_checker=None)
        # Reuse the txt2img components so both pipelines share one set of weights.
        self.img2img_pipe = StableDiffusionImg2ImgPipeline(
            vae=self.txt2img_pipe.vae,
            text_encoder=self.txt2img_pipe.text_encoder,
            tokenizer=self.txt2img_pipe.tokenizer,
            unet=self.txt2img_pipe.unet,
            scheduler=self.txt2img_pipe.scheduler,
            safety_checker=self.txt2img_pipe.safety_checker,
            feature_extractor=self.txt2img_pipe.feature_extractor,
        ).to("cuda")

    def generate_story(self, prompts, do_img2img, num_inference_steps, guidance_scale):
        """Generate one image per prompt; return them as a list ("story").

        The first prompt is always rendered with txt2img. When ``do_img2img``
        is true, every following prompt is rendered img2img-style, conditioned
        on the FIRST image (strength 0.9) so the story stays visually coherent.
        """
        first_img = self.txt2img_pipe(prompts[0], num_inference_steps=num_inference_steps, guidance_scale=guidance_scale).images[0]
        ret = [first_img]
        for prompt in prompts[1:]:
            if do_img2img:
                out = self.img2img_pipe(
                    prompt=prompt,
                    strength=0.9,
                    image=first_img,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale)
            else:
                out = self.txt2img_pipe(
                    prompt=prompt,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale)
            ret.append(out.images[0])
        return ret

    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
        """Handle one request.

        Expects ``data`` to be a dict with at least ``"prompts"`` (list of str);
        optional keys: num_images_per_prompt, do_img2img, match_histogram,
        adjust_contrast, num_inference_steps, guidance_scale.
        Returns a single PIL grid image: one row per generated story.
        """
        prompts = data["prompts"]
        num_images_per_prompt = data.get("num_images_per_prompt", 1)
        do_img2img = data.get("do_img2img", True)
        match_histogram = data.get("match_histogram", True)
        adjust_contrast = data.get("adjust_contrast", True)
        num_inference_steps = data.get("num_inference_steps", 25)
        # get (not pop): do not mutate the caller's request dict.
        guidance_scale = data.get("guidance_scale", 7.5)
        all_stories = []
        for _ in range(num_images_per_prompt):
            story = self.generate_story(prompts, do_img2img, num_inference_steps, guidance_scale)
            if match_histogram:
                # Match every frame's color distribution to the first frame.
                story = [match_colors(story[0], img) for img in story]
            if adjust_contrast:
                story = [add_contrast(img) for img in story]
            all_stories.append(story)
        # BUG FIX: the original rebuilt all_stories here with a second, full
        # regeneration pass, discarding the histogram/contrast post-processing
        # above and doubling the GPU work. Use the processed stories directly.
        return stories_to_grid(all_stories)
def match_colors(src, tar):
    """Match the color histogram of image ``tar`` to reference image ``src``.

    Both arguments are PIL images; returns a new PIL image whose per-channel
    histograms match ``src``. Used to keep frames of one story color-consistent.
    """
    # FIX: `multichannel=True` was deprecated in scikit-image 0.19 and removed
    # in 0.23 (raises TypeError). `channel_axis=-1` is the supported equivalent.
    ret = match_histograms(np.array(tar), np.array(src), channel_axis=-1)
    return Image.fromarray(ret.astype('uint8'))
def add_contrast(image, gamma_cor=1.2):
    """Apply gamma correction (default gamma 1.2) to a PIL image.

    Converts to a NumPy array, runs skimage's adjust_gamma, and returns
    the result as a new uint8 PIL image.
    """
    pixels = np.array(image)
    corrected = adjust_gamma(pixels, gamma_cor)
    return Image.fromarray(corrected.astype(np.uint8))
def stories_to_grid(stories):
    """Lay out a list of stories (lists of PIL images) as one grid image.

    Each story becomes one row; all stories must have equal length.
    """
    flattened = [frame for story in stories for frame in story]
    return image_grid(flattened, len(stories), len(stories[0]))
def image_grid(imgs, rows, cols):
    """Paste ``rows * cols`` equally-sized PIL images into one grid image.

    Images are placed row-major: imgs[0] top-left, filling each row left to
    right. All images are assumed to share the size of imgs[0].
    """
    assert len(imgs) == rows * cols
    tile_w, tile_h = imgs[0].size
    canvas = Image.new('RGB', size=(cols * tile_w, rows * tile_h))
    for idx, tile in enumerate(imgs):
        row, col = divmod(idx, cols)
        canvas.paste(tile, box=(col * tile_w, row * tile_h))
    return canvas