import itertools
from typing import Any, Dict, List

import torch
from diffusers import StableDiffusionImg2ImgPipeline, StableDiffusionPipeline
from PIL import Image
from skimage.exposure import adjust_gamma, match_histograms
import numpy as np
import diffusers  # kept: may be relied on elsewhere / by the serving runtime


class EndpointHandler:
    """Inference endpoint that turns a list of prompts into a "story" image grid.

    A Stable Diffusion txt2img pipeline renders the first frame; subsequent
    frames are rendered either with img2img (seeded from the first frame) or
    plain txt2img, then optionally color-matched and contrast-adjusted, and
    finally pasted into one grid image.
    """

    def __init__(self, path: str = ""):
        # fp16 weights on CUDA; `path` is a local checkpoint dir or a hub model id.
        self.txt2img_pipe = StableDiffusionPipeline.from_pretrained(
            path, torch_dtype=torch.float16
        ).to("cuda")
        # NOTE(review): register_to_config only updates the config dict; the
        # safety-checker module itself stays attached.  If the intent is to
        # fully disable it, `self.txt2img_pipe.safety_checker = None` would
        # also be needed -- kept as-is to preserve existing behavior.
        self.txt2img_pipe.register_to_config(requires_safety_checker=False)
        self.txt2img_pipe.register_to_config(safety_checker=None)
        # Share every sub-module with the txt2img pipeline so the model is
        # loaded into GPU memory only once.
        self.img2img_pipe = StableDiffusionImg2ImgPipeline(
            vae=self.txt2img_pipe.vae,
            text_encoder=self.txt2img_pipe.text_encoder,
            tokenizer=self.txt2img_pipe.tokenizer,
            unet=self.txt2img_pipe.unet,
            scheduler=self.txt2img_pipe.scheduler,
            safety_checker=self.txt2img_pipe.safety_checker,
            feature_extractor=self.txt2img_pipe.feature_extractor,
        ).to("cuda")

    def generate_story(self, prompts, do_img2img, num_inference_steps, guidance_scale):
        """Generate one PIL image per prompt.

        The first prompt is always rendered with txt2img.  When ``do_img2img``
        is true, every following prompt is rendered with img2img seeded from
        the FIRST frame (not the previous one -- presumably to keep a
        consistent look across the story; confirm with the author).
        Returns a list of PIL images, one per prompt.
        """
        first_img = self.txt2img_pipe(
            prompts[0],
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        ).images[0]
        images = [first_img]
        for prompt in prompts[1:]:
            if do_img2img:
                out = self.img2img_pipe(
                    prompt=prompt,
                    strength=0.9,  # high strength: mostly redraws, keeps only composition hints
                    image=first_img,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                )
            else:
                out = self.txt2img_pipe(
                    prompt=prompt,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                )
            images.append(out.images[0])
        return images

    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
        """Handle one request.

        Expected keys in ``data``: "prompts" (required, list of str) plus the
        optional knobs read below.  Returns a single PIL grid image with one
        row per repetition (``num_images_per_prompt``) and one column per
        prompt.
        """
        prompts = data["prompts"]
        num_images_per_prompt = data.get("num_images_per_prompt", 1)
        do_img2img = data.get("do_img2img", True)
        match_histogram = data.get("match_histogram", True)
        adjust_contrast = data.get("adjust_contrast", True)
        num_inference_steps = data.get("num_inference_steps", 25)
        # .get instead of the original .pop: consistent with the other options
        # and avoids mutating the caller's request payload.
        guidance_scale = data.get("guidance_scale", 7.5)

        all_stories = []
        for _ in range(num_images_per_prompt):
            story = self.generate_story(
                prompts, do_img2img, num_inference_steps, guidance_scale
            )
            if match_histogram:
                # Match every frame's color distribution to the first frame.
                story = [match_colors(story[0], img) for img in story]
            if adjust_contrast:
                story = [add_contrast(img) for img in story]
            all_stories.append(story)
        # BUG FIX: the original rebuilt `all_stories` here with a second round
        # of generate_story calls, silently discarding the histogram/contrast
        # post-processing above and doubling the GPU work.
        return stories_to_grid(all_stories)


def match_colors(src, tar):
    """Return ``tar`` with its color histogram matched to ``src`` (PIL images in, PIL image out)."""
    src_arr = np.array(src)
    tar_arr = np.array(tar)
    try:
        # scikit-image >= 0.19 API; `multichannel` was deprecated and later removed.
        matched = match_histograms(tar_arr, src_arr, channel_axis=-1)
    except TypeError:
        # Fallback for older scikit-image versions that only know `multichannel`.
        matched = match_histograms(tar_arr, src_arr, multichannel=True)
    return Image.fromarray(matched.astype("uint8"))


def add_contrast(image, gamma_cor=1.2):
    """Apply gamma correction (gamma > 1 darkens midtones, boosting perceived contrast)."""
    adjusted = adjust_gamma(np.array(image), gamma_cor).astype(np.uint8)
    return Image.fromarray(adjusted)


def stories_to_grid(stories):
    """Flatten a list of stories (each a list of images) into one grid image, one row per story."""
    rows = len(stories)
    cols = len(stories[0])
    flat = list(itertools.chain.from_iterable(stories))
    return image_grid(flat, rows, cols)


def image_grid(imgs, rows, cols):
    """Paste ``imgs`` row-major onto a single rows x cols RGB canvas and return it."""
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid