| import itertools | |
| from typing import Dict, List, Any | |
| import torch | |
| from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline | |
| from PIL import Image | |
| from skimage.exposure import match_histograms, adjust_gamma | |
| import numpy as np | |
| import diffusers | |
class EndpointHandler():
    """Inference handler that renders a sequence of prompts as an image grid.

    Each request produces ``num_images_per_prompt`` "stories"; a story is one
    image per prompt. The stories are stacked as the rows of a single grid.
    """

    def __init__(self, path=""):
        """Load the text-to-image pipeline from ``path`` and build an img2img
        pipeline that reuses the same model components.

        Both pipelines are moved to the ``"cuda"`` device.
        """
        self.txt2img_pipe = StableDiffusionPipeline.from_pretrained(
            path, torch_dtype=torch.float16
        ).to("cuda")
        # NOTE(review): these calls only touch the pipeline config; whether
        # they actually disable the loaded safety checker module should be
        # confirmed against the installed diffusers version.
        self.txt2img_pipe.register_to_config(requires_safety_checker=False)
        self.txt2img_pipe.register_to_config(safety_checker=None)
        # Build the img2img pipeline from the components already loaded above
        # rather than loading the checkpoint a second time.
        self.img2img_pipe = StableDiffusionImg2ImgPipeline(
            vae=self.txt2img_pipe.vae,
            text_encoder=self.txt2img_pipe.text_encoder,
            tokenizer=self.txt2img_pipe.tokenizer,
            unet=self.txt2img_pipe.unet,
            scheduler=self.txt2img_pipe.scheduler,
            safety_checker=self.txt2img_pipe.safety_checker,
            feature_extractor=self.txt2img_pipe.feature_extractor,
        ).to("cuda")

    def generate_story(self, prompts, do_img2img, num_inference_steps, guidance_scale):
        """Generate one image per prompt and return them as a list.

        The first prompt is always rendered with the text-to-image pipeline.
        For the remaining prompts, when ``do_img2img`` is true the img2img
        pipeline is used with the *first* image as the init image
        (``strength=0.9``); otherwise each prompt is rendered independently
        with the text-to-image pipeline.

        Args:
            prompts: Non-empty sequence of prompts (indexing ``prompts[0]``
                fails on an empty sequence).
            do_img2img: Whether follow-up images are derived from the first.
            num_inference_steps: Forwarded to the pipelines.
            guidance_scale: Forwarded to the pipelines.

        Returns:
            A list with one image per prompt, in prompt order.
        """
        first_img = self.txt2img_pipe(
            prompts[0],
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        ).images[0]
        ret = [first_img]
        for prompt in prompts[1:]:
            if do_img2img:
                out = self.img2img_pipe(
                    prompt=prompt,
                    strength=0.9,
                    image=first_img,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                )
            else:
                out = self.txt2img_pipe(
                    prompt=prompt,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                )
            ret.append(out.images[0])
        return ret

    def __call__(self, data: Dict[str, Any]) -> "Image.Image":
        """Handle one request and return the stories laid out as a grid image.

        Recognised keys of ``data``:
            prompts (required): sequence of prompts, one image each.
            num_images_per_prompt (default 1): number of stories (grid rows).
            do_img2img (default True): see ``generate_story``.
            match_histogram (default True): match every image's colors to the
                first image of its story.
            adjust_contrast (default True): apply gamma correction to every
                image.
            num_inference_steps (default 25), guidance_scale (default 7.5):
                forwarded to the pipelines.

        Raises:
            KeyError: if ``data`` has no ``"prompts"`` entry.
        """
        prompts = data["prompts"]
        num_images_per_prompt = data.get("num_images_per_prompt", 1)
        do_img2img = data.get("do_img2img", True)
        match_histogram = data.get("match_histogram", True)
        adjust_contrast = data.get("adjust_contrast", True)
        num_inference_steps = data.get("num_inference_steps", 25)
        # Read with .get like the other options so the caller's payload is
        # not mutated.
        guidance_scale = data.get("guidance_scale", 7.5)

        all_stories = []
        for _ in range(num_images_per_prompt):
            story = self.generate_story(prompts, do_img2img, num_inference_steps, guidance_scale)
            if match_histogram:
                reference = story[0]
                story = [match_colors(reference, img) for img in story]
            if adjust_contrast:
                story = [add_contrast(img) for img in story]
            all_stories.append(story)
        # The post-processed stories are returned as-is; regenerating them
        # here would discard the color/contrast adjustments and double the
        # generation cost.
        return stories_to_grid(all_stories)
def match_colors(src,tar):
    """Return ``tar`` with its per-channel histogram matched to ``src``.

    Args:
        src: Reference image whose color distribution is the target.
        tar: Image to be adjusted.

    Returns:
        A new PIL image built from the adjusted pixels, cast to ``uint8``.
    """
    target = np.array(tar)
    reference = np.array(src)
    try:
        # scikit-image 0.19 introduced ``channel_axis`` as the replacement
        # for the deprecated ``multichannel`` flag.
        ret = match_histograms(target, reference, channel_axis=-1)
    except TypeError:
        # Older scikit-image releases only understand ``multichannel``.
        ret = match_histograms(target, reference, multichannel=True)
    return Image.fromarray(ret.astype('uint8'))
def add_contrast(image, gamma_cor=1.2):
    """Apply gamma correction to ``image`` and return the result as a PIL image.

    Args:
        image: Input image (converted to an array with ``np.array``).
        gamma_cor: Gamma value passed to ``adjust_gamma``.

    Returns:
        A new PIL image built from the corrected pixels, cast to ``uint8``.
    """
    pixels = np.array(image)
    corrected = adjust_gamma(pixels, gamma_cor)
    return Image.fromarray(corrected.astype(np.uint8))
def stories_to_grid(stories):
    """Lay out ``stories`` as a grid with one story per row.

    The column count is taken from the first story. The stories are
    flattened in order and handed to ``image_grid``.
    """
    n_rows, n_cols = len(stories), len(stories[0])
    tiles = list(itertools.chain.from_iterable(stories))
    return image_grid(tiles, n_rows, n_cols)
def image_grid(imgs, rows, cols):
    """Paste ``imgs`` row by row into a single ``rows`` x ``cols`` RGB image.

    Each cell takes the size of the first image. ``imgs`` must contain
    exactly ``rows * cols`` images.
    """
    assert rows * cols == len(imgs)
    tile_w, tile_h = imgs[0].size
    canvas = Image.new('RGB', size=(cols * tile_w, rows * tile_h))
    for idx, tile in enumerate(imgs):
        # Row-major placement: idx -> (row, col).
        row, col = divmod(idx, cols)
        canvas.paste(tile, box=(col * tile_w, row * tile_h))
    return canvas