import torch
import torch.nn.functional as F

from ..modules.reference_net import hack_inference_forward
from ..models.basemodel import CustomizedColorizer, CustomizedWrapper
from ..modules.lora import LoraModules
from ..util import exists, expand_to_batch_size, instantiate_from_config, get_crop_scale, resize_and_crop



class InferenceWrapper(CustomizedWrapper, CustomizedColorizer):
    def __init__(
            self,
            scalar_embedder_config,
            img_embedder_config,
            lora_config = None,
            logits_embed = False,
            *args,
            **kwargs
    ):
        CustomizedColorizer.__init__(self, version="sdxl", *args, **kwargs)
        CustomizedWrapper.__init__(self)

        self.scalar_embedder = instantiate_from_config(scalar_embedder_config)
        self.img_embedder = instantiate_from_config(img_embedder_config)
        self.loras = LoraModules(self, **lora_config) if exists(lora_config) else None
        self.logits_embed = logits_embed

        new_model_list = {
            "scalar_embedder": self.scalar_embedder,
            "img_embedder": self.img_embedder,
            # "style_encoder": self.style_encoder,
        }
        self.switch_cond_modules += list(new_model_list.keys())
        self.model_list.update(new_model_list)

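    # Gather every BasicTransformerBlock in the UNet via a depth-first traversal
    # and bucket their indices into resolution levels (high / low / bottom, plus
    # encoder / decoder halves), assigning a per-level attention scale factor.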
    def retrieve_attn_modules(self):
        scale_factor_levels = {"high": 0.5, "low": 0.25, "bottom": 0.25}

        from refnet.modules.transformer import BasicTransformerBlock
        from refnet.sampling import torch_dfs

        attn_modules = []
        for module in torch_dfs(self.model.diffusion_model):
            if isinstance(module, BasicTransformerBlock):
                attn_modules.append(module)

        self.attn_modules = {
            "high": [0, 1, 2, 3] + list(range(64, 70)),
            "low": list(range(4, 24)) + list(range(34, 64)),
            "bottom": list(range(24, 34)),
            "encoder": list(range(24)),
            "decoder": list(range(34, len(attn_modules))),
        }
        self.attn_modules["modules"] = attn_modules

        for k in ["high", "low", "bottom"]:
            scale_factor = scale_factor_levels[k]
            for attn in self.attn_modules[k]:
                attn_modules[attn].scale_factor = scale_factor

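    # Apply the encoder-level reference scale from the sampler's scale kwargs
    # to every attention block.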
    def adjust_reference_scale(self, scale_kwargs):
        for module in self.attn_modules["modules"]:
            module.reference_scale = scale_kwargs["scales"]["encoder"]

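    # Push the masked-attention hyperparameters (mask scale, mask threshold and
    # merge scale) down to the attention layers.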
    def adjust_masked_attn(self, scale, mask_threshold, merge_scale):
        for layer in self.attn_layers:
            layer.mask_scale = scale
            layer.mask_threshold = mask_threshold
            layer.merge_scale = merge_scale

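    # Upscale the spatial size so it covers the requested (height, width) while
    # keeping the original aspect ratio; returns the adjusted size as
    # one-element tensors.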
    def rescale_size(self, x: torch.Tensor, height, width):
        oh, ow = x.shape[2:]
        if oh < height or ow < width:
            dh, dw = height - oh, width - ow
            if dh > dw:
                iw = ow + int(dh * ow/oh)
                ih = height
            else:
                ih = oh + int(dw * oh/ow)
                iw = width
        else:
            ih, iw = oh, ow
        return torch.Tensor([ih]), torch.Tensor([iw])

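    # Fuse CLIP image features with img_embedder features (or its logits when
    # self.logits_embed is set) into the reference context embedding; also
    # returns the CLS token embedding.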
    def get_learned_embedding(self, c, bg=False, mapping=False, sketch=None, *args, **kwargs):
        clip_emb = self.cond_stage_model.encode(c, "full").detach()
        wd_emb, logits = self.img_embedder.encode(c, pooled=False, return_logits=True)
        cls_emb, local_emb = clip_emb[:, :1], clip_emb[:, 1:]

        if mapping:
            _, sketch_logits = self.img_embedder.encode(-sketch, pooled=False, return_logits=True)
            # pool the sketch logits over the token dimension before the update
            sketch_logits = sketch_logits.mean(dim=1, keepdim=True)
            logits = self.img_embedder.geometry_update(logits, sketch_logits)
        emb = self.proj(clip_emb, logits if self.logits_embed else wd_emb, bg)
        return emb, cls_emb

    def prepare_conditions(
            self,
            bs,
            sketch,
            reference,
            height,
            width,
            control_scale = (1., 1., 1., 1.),
            merge_scale = 0,
            mask_scale = 1.,
            fg_scale = 1.,
            bg_scale = 1.,
            smask = None,
            rmask = None,
            mask_threshold_ref = 0.,
            mask_threshold_sketch = 0.,
            style_enhance = False,
            fg_enhance = False,
            bg_enhance = False,
            background = None,
            targets = None,
            anchors = None,
            controls = None,
            target_scales = None,
            enhances = None,
            thresholds_list = None,
            geometry_map = False,
            latent_inpaint = False,
            low_vram = False,
            *args,
            **kwargs
    ):
        # prepare reference embedding
        # manipulate = self.check_manipulate(target_scales)
        c = {}
        uc = [{}, {}]

        if exists(reference):
            emb, cls_emb = self.get_learned_embedding(reference, sketch=sketch, mapping=geometry_map)
        else:
            emb, cls_emb = map(lambda t: torch.zeros_like(t), self.get_learned_embedding(sketch))

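        # Scalar conditioning: embed the target resolution (square root of the
        # pixel area) together with a fixed score value, then concatenate the
        # two embeddings into the vector conditioning y.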
        h, w, score = torch.Tensor([height]), torch.Tensor([width]), torch.Tensor([7.])
        y = torch.cat(self.scalar_embedder(torch.cat([(h*w)**0.5, score])).cuda().chunk(2), 1)

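        # Background enhancement: encode the background (or the masked-out
        # reference) into an extra embedding and append it to the reference
        # context along the token dimension; with latent inpainting the encoded
        # background latents are also passed through as "inpaint_bg".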
        if bg_enhance:
            assert exists(rmask) and exists(smask)

            if low_vram:
                self.low_vram_shift(["first", "cond", "img_embedder", "proj"])
            
            if latent_inpaint and exists(background):
                bgh, bgw = background.shape[2:]
                ch, cw = get_crop_scale(torch.tensor([height]), torch.tensor([width]), bgh, bgw)
                hs_bg = self.get_first_stage_encoding(resize_and_crop(background, ch, cw, height, width).to(self.first_stage_model.dtype))
                bg_emb, _ = self.get_learned_embedding(background, bg=True)
                hs_bg = expand_to_batch_size(hs_bg, bs)
                c.update({"inpaint_bg": hs_bg})
            else:
                if exists(background):
                    bg_emb, _ = self.get_learned_embedding(background, bg=True)
                else:
                    bg_emb, _ = self.get_learned_embedding(
                        torch.where(rmask < mask_threshold_ref, reference, torch.ones_like(reference)),
                        True
                    )
            emb = torch.cat([emb, bg_emb], 1)

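        # Foreground enhancement: enable the foreground LoRA weights; when no
        # background embedding was appended, duplicate the reference embedding
        # along the token dimension instead.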
        if fg_enhance and exists(self.loras):
            self.loras.switch_lora(True, "foreground")
            if not bg_enhance:
                emb = emb.repeat(1, 2, 1)

        if fg_enhance or bg_enhance:
            # sketch mask for cross-attention
            smask = expand_to_batch_size(smask.to(self.dtype), bs)
            for d in [c] + uc:
                d.update({"mask": F.interpolate(smask, scale_factor=0.125)})
        elif exists(self.loras):
            self.loras.switch_lora(False)

        sketch = sketch.to(self.dtype)
        context = expand_to_batch_size(emb, bs).to(self.dtype)
        y = expand_to_batch_size(y, bs)
        uc_context = torch.zeros_like(context)

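        # Encode the sketch and an "empty" all -1 sketch in one batch to get
        # multi-scale control features; scale each level and split it into the
        # conditional / unconditional halves.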
        control = []
        uc_control = []
        if low_vram:
            self.low_vram_shift(["control_encoder"])
        encoded_sketch = self.control_encoder(
            torch.cat([sketch, -torch.ones_like(sketch)], 0)
        )
        for idx, es in enumerate(encoded_sketch):
            es = es * control_scale[idx]
            ec, uec = es.chunk(2)
            control.append(expand_to_batch_size(ec, bs))
            uc_control.append(expand_to_batch_size(uec, bs))

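        # Assemble the conditional dict and the two unconditional dicts
        # (zeroed context in uc[0], empty-sketch control in uc[1]).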
        c.update({"control": control, "context": [context], "y": [y]})
        uc[0].update({"control": control, "context": [uc_context], "y": [y]})
        uc[1].update({"control": uc_control, "context": [context], "y": [y]})
        return c, uc