File size: 3,703 Bytes
fbc7131
fb2991c
 
 
 
 
 
a738611
fb2991c
 
 
 
 
fbc7131
 
fb2991c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbc7131
 
fb2991c
fbc7131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb2991c
 
 
 
 
 
 
 
fbc7131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import itertools
from typing import  Dict, List, Any
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
from PIL import Image
from skimage.exposure import match_histograms, adjust_gamma
import numpy as np
import diffusers


class EndpointHandler():
    """Inference handler that renders multi-panel "stories" with Stable Diffusion.

    A text-to-image pipeline is loaded from ``path`` and an image-to-image
    pipeline is assembled from the very same module objects, so both share
    one copy of the model weights on the GPU.
    """

    def __init__(self, path=""):
        # Load the model in half precision and move it to the GPU.
        self.txt2img_pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16).to("cuda")
        self.txt2img_pipe.register_to_config(requires_safety_checker=False)
        self.txt2img_pipe.register_to_config(safety_checker = None)
        # Reuse the already-loaded components instead of loading the weights a second time.
        self.img2img_pipe = StableDiffusionImg2ImgPipeline(
            vae=self.txt2img_pipe.vae,
            text_encoder=self.txt2img_pipe.text_encoder,
            tokenizer=self.txt2img_pipe.tokenizer,
            unet=self.txt2img_pipe.unet,
            scheduler=self.txt2img_pipe.scheduler,
            safety_checker=self.txt2img_pipe.safety_checker,
            feature_extractor=self.txt2img_pipe.feature_extractor,
        ).to("cuda")

    def generate_story(self, prompts, do_img2img, num_inference_steps, guidance_scale, strength=0.9):
        """Render one image per prompt, in prompt order.

        The first prompt is always rendered text-to-image. When ``do_img2img``
        is true, every later prompt is rendered image-to-image starting from
        that first image; otherwise each later prompt is rendered
        independently text-to-image.

        Args:
            prompts: Non-empty sequence of prompt strings.
            do_img2img: Whether later panels start from the first image.
            num_inference_steps: Denoising steps passed to every pipeline call.
            guidance_scale: Guidance scale passed to every pipeline call.
            strength: img2img ``strength`` for the later panels; only used when
                ``do_img2img`` is true.

        Returns:
            List with one image per prompt.

        Raises:
            ValueError: if ``prompts`` is empty.
        """
        if not prompts:
            raise ValueError("prompts must contain at least one prompt")

        first_img = self.txt2img_pipe(prompts[0], num_inference_steps=num_inference_steps, guidance_scale=guidance_scale).images[0]
        ret = [first_img]
        for prompt in prompts[1:]:
            if do_img2img:
                # Every later panel is derived from the FIRST image, not the previous one.
                out = self.img2img_pipe(
                    prompt=prompt,
                    strength=strength,
                    image=first_img,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale)
            else:
                out = self.txt2img_pipe(
                    prompt=prompt,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale)
            ret.append(out.images[0])
        return ret

    def __call__(self, data: Any) -> "Image.Image":
        """Handle one request and return all generated stories as a single grid image.

        Keys read from ``data`` (only ``"prompts"`` is required):
            prompts: list of prompt strings, one per panel.
            num_images_per_prompt: number of stories, i.e. grid rows (default 1).
            do_img2img: derive later panels from the first one (default True).
            match_histogram: colour-match every panel to the first (default True).
            adjust_contrast: gamma-correct every panel (default True).
            num_inference_steps: denoising steps (default 25).
            guidance_scale: guidance scale (default 7.5).

        Returns:
            One image with a row per story and a column per prompt.
        """
        all_stories = self._build_stories(data)
        return stories_to_grid(all_stories)

    def _build_stories(self, data):
        """Generate and post-process ``num_images_per_prompt`` stories described by ``data``."""
        prompts = data["prompts"]
        num_images_per_prompt = data.get("num_images_per_prompt", 1)
        do_img2img = data.get("do_img2img", True)
        match_histogram = data.get("match_histogram", True)
        adjust_contrast = data.get("adjust_contrast", True)
        num_inference_steps = data.get("num_inference_steps", 25)
        # ``get`` (not ``pop``): reading options must not mutate the caller's payload.
        guidance_scale = data.get("guidance_scale", 7.5)

        all_stories = []
        for _ in range(num_images_per_prompt):
            story = self.generate_story(prompts, do_img2img, num_inference_steps, guidance_scale)
            if match_histogram:
                # Match every panel's colours to the first panel for a consistent look.
                reference = story[0]
                story = [match_colors(reference, img) for img in story]
            if adjust_contrast:
                story = [add_contrast(img) for img in story]
            all_stories.append(story)
        # The post-processed stories are returned as-is; they are not regenerated.
        return all_stories
                                      
def match_colors(src, tar):
    """Return ``tar`` recoloured so its per-channel histograms match those of ``src``.

    Args:
        src: Reference image whose colour distribution is copied.
        tar: Image to recolour.

    Returns:
        A new 8-bit PIL image with the same shape as ``tar``.
    """
    target = np.array(tar)
    reference = np.array(src)
    try:
        # scikit-image >= 0.19 takes ``channel_axis``; ``multichannel`` was
        # deprecated there and removed in later releases.
        matched = match_histograms(target, reference, channel_axis=-1)
    except TypeError:
        # Older scikit-image releases only understand ``multichannel``.
        matched = match_histograms(target, reference, multichannel=True)
    # The matched values may be fractional floats: round and clip instead of
    # truncating, so each one maps to the nearest valid 8-bit value.
    return Image.fromarray(np.clip(np.rint(matched), 0, 255).astype(np.uint8))

def add_contrast(image, gamma_cor=1.2):
    """Gamma-correct ``image`` with exponent ``gamma_cor`` and return a new PIL image.

    The correction is done by ``skimage.exposure.adjust_gamma``; per its
    documentation a gamma above 1 darkens the image, so the default of 1.2
    mildly darkens mid-tones.
    """
    pixels = np.array(image)
    corrected = adjust_gamma(pixels, gamma_cor)
    as_bytes = corrected.astype(np.uint8)
    return Image.fromarray(as_bytes)

def stories_to_grid(stories):
    """Arrange a list of stories into one image grid.

    Each story (a list of images) becomes one row; the column count is
    taken from the first story.
    """
    n_rows, n_cols = len(stories), len(stories[0])
    flat = list(itertools.chain.from_iterable(stories))
    return image_grid(flat, n_rows, n_cols)

def image_grid(imgs, rows, cols):
    """Paste ``imgs`` row-major into a ``rows`` x ``cols`` RGB grid image.

    Every cell takes the size of the first image; each image is pasted at
    the top-left corner of its cell.

    Raises:
        ValueError: if ``len(imgs) != rows * cols`` or ``imgs`` is empty.
    """
    # Explicit checks instead of ``assert``: asserts are stripped under ``python -O``.
    if len(imgs) != rows * cols:
        raise ValueError(f"expected rows*cols = {rows * cols} images, got {len(imgs)}")
    if not imgs:
        raise ValueError("cannot build a grid from zero images")

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid