| | import torch |
| | import math |
| | import torch.nn.functional as F |
| | from comfy.ldm.modules.attention import optimized_attention |
| | from .utils import tensor_to_size |
| |
|
class Attn2Replace:
    """Cross-attention (attn2) replacement that adds the output of registered callbacks.

    The base result is ``optimized_attention(q, k, v, n_heads)``. Each registered
    callback is called as ``callback(out, q, k, v, extra_options, **kwargs)`` and its
    return value is added to the running output, but only when the current sigma lies
    inside that callback's inclusive ``[sigma_end, sigma_start]`` window.
    """

    def __init__(self, callback=None, **kwargs):
        """Create the patch and register ``callback`` with its ``kwargs``.

        With ``callback=None`` the patch starts with no callbacks (``kwargs`` is then
        ignored) instead of storing a ``None`` entry that would fail when called.
        Unlike :meth:`add`, the constructor does not mirror ``kwargs`` onto the instance.
        """
        if callback is None:
            self.callback = []
            self.kwargs = []
        else:
            self.callback = [callback]
            self.kwargs = [kwargs]

    def add(self, callback, **kwargs):
        """Register another callback.

        ``kwargs`` is stored alongside the callback and must contain ``sigma_start`` and
        ``sigma_end`` (both are read by ``__call__``). Each entry is also mirrored onto
        the instance as an attribute, with later registrations overwriting earlier ones.
        Names that would overwrite the internal registries (``callback``, ``kwargs``) or
        shadow class members such as ``add`` are not mirrored.
        """
        self.callback.append(callback)
        self.kwargs.append(kwargs)

        for key, value in kwargs.items():
            # Mirroring these would break the registry lists or replace methods.
            if key in ("callback", "kwargs") or hasattr(type(self), key):
                continue
            setattr(self, key, value)

    def __call__(self, q, k, v, extra_options):
        """Run base attention plus every callback whose sigma window contains the current sigma.

        Args:
            q, k, v: attention inputs forwarded to ``optimized_attention`` and the callbacks.
            extra_options: must contain ``"n_heads"``; ``"sigmas"`` is optional.

        Returns:
            The summed attention output cast back to ``q.dtype``.
        """
        dtype = q.dtype
        out = optimized_attention(q, k, v, extra_options["n_heads"])

        if 'sigmas' in extra_options:
            sigma = extra_options["sigmas"].detach().cpu()[0].item()
        else:
            # Huge sentinel: only callbacks whose sigma_start is at least this large run.
            sigma = 999999999.9

        for callback, cb_kwargs in zip(self.callback, self.kwargs):
            if sigma <= cb_kwargs["sigma_start"] and sigma >= cb_kwargs["sigma_end"]:
                out = out + callback(out, q, k, v, extra_options, **cb_kwargs)

        return out.to(dtype=dtype)
| |
|
# weight_type -> block types whose IPAdapter weight is scaled by 0.2
_WEAKENED_BLOCKS = {
    'weak input': ('input',),
    'weak middle': ('middle',),
    'weak output': ('output',),
    'strong middle': ('input', 'output'),
}


def _expand_weight(weight, length, n_chunks, sub_idxs=None):
    """Broadcast a weight to the attention batch and report whether it is all zero.

    Non-tensor weights are returned unchanged. Tensor weights are resized to ``length``
    with ``tensor_to_size``, optionally narrowed to the frames ``sub_idxs``, and tiled
    ``n_chunks`` times along dim 0 (one copy per cond/uncond chunk).

    Returns:
        ``(weight, is_zero)``. When ``is_zero`` is true the caller skips the patch, so
        the tensor is returned untiled in that case.
    """
    if not isinstance(weight, torch.Tensor):
        return weight, weight == 0
    weight = tensor_to_size(weight, length)
    if sub_idxs is not None:
        weight = weight[sub_idxs]
    if torch.all(weight == 0):
        return weight, True
    return weight.repeat(n_chunks, 1, 1), False


def _take_frames(embeds, full_length, sub_idxs):
    """Select the AnimateDiff context-window frames ``sub_idxs`` from a per-frame batch.

    A batch shorter than ``full_length`` is first resized with ``tensor_to_size``
    (assumed to pad it along dim 0 - TODO confirm against utils). The decision is made
    per tensor, so cond and uncond batches of different lengths are each handled.
    """
    if embeds.shape[0] < full_length:
        embeds = tensor_to_size(embeds, full_length)
    return embeds[sub_idxs]


def _prepare_mask(mask, oh, ow, seq_len, batch_prompt, n_chunks, ad_params):
    """Resize ``mask`` to this layer's token grid.

    Returns a ``(n_chunks * N, seq_len, 1)`` tensor that broadcasts over the channel
    dim of the attention output.
    """
    # Estimate the token grid (grid_h x grid_w ~= seq_len) from the original aspect
    # ratio. Both sides are clamped to >= 1 so extreme aspect ratios cannot divide by
    # zero or request an empty interpolation size.
    grid_h = max(1, int(oh / math.sqrt(oh * ow / seq_len)))
    grid_h += int(seq_len % grid_h != 0)
    grid_w = max(1, seq_len // grid_h)

    def resize(m):
        return F.interpolate(m.unsqueeze(1), size=(grid_h, grid_w), mode="bilinear").squeeze(1)

    # AnimateDiff sliding context window: keep only this window's frames.
    if mask.shape[0] > 1 and ad_params is not None and ad_params["sub_idxs"] is not None:
        full_length = ad_params["full_length"]
        sub_idxs = ad_params["sub_idxs"]
        if mask.shape[0] >= full_length:
            mask = resize(mask[sub_idxs])
        else:
            mask = tensor_to_size(resize(mask), full_length)[sub_idxs]
    else:
        mask = tensor_to_size(resize(mask), batch_prompt)

    mask = mask.repeat(n_chunks, 1, 1)
    mask = mask.view(mask.shape[0], -1, 1)

    # The rounded grid may not hold exactly seq_len cells: zero-pad or center-crop.
    mask_len = grid_h * grid_w
    if mask_len < seq_len:
        pad_len = seq_len - mask_len
        pad_front = pad_len // 2
        mask = F.pad(mask, (0, 0, pad_front, pad_len - pad_front), value=0.0)
    elif mask_len > seq_len:
        crop_start = (mask_len - seq_len) // 2
        mask = mask[:, crop_start:crop_start + seq_len, :]
    return mask


def ipadapter_attention(out, q, k, v, extra_options, module_key='', ipadapter=None, weight=1.0, cond=None, cond_alt=None, uncond=None, weight_type="linear", mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False, embeds_scaling='V only', **kwargs):
    """Compute the IPAdapter image-prompt attention term for one attn2 call.

    The image embeds ``cond``/``uncond`` are projected through
    ``ipadapter.ip_layers.to_kvs[module_key + "_to_k_ip" / "_to_v_ip"]``, attended by
    ``q``, scaled by ``weight`` (shaped by ``weight_type`` and ``embeds_scaling``) and
    optionally multiplied by ``mask``.

    Args:
        out, k, v: unused here; part of the callback signature.
        q: attention queries; batch = len(cond_or_uncond) chunks of prompts.
        extra_options: reads "cond_or_uncond", "block", "transformer_index",
            "original_shape", "n_heads" and, if present, "ad_params" (AnimateDiff).
        weight: scalar, tensor, or dict keyed by transformer index.
        cond_alt: optional dict transformer_index -> replacement cond embeds; only
            consulted when ``weight`` is a dict.
        unfold_batch: when true, cond/uncond are resized per batch item instead of a
            single embed being repeated for every prompt.
        sigma_start, sigma_end, **kwargs: accepted so all registration kwargs can be
            forwarded; the sigma window is checked by Attn2Replace.__call__.

    Returns:
        The IPAdapter attention output cast to ``q.dtype``, or the int ``0`` when the
        patch contributes nothing (layer missing from a dict weight, or zero weight).
    """
    dtype = q.dtype
    cond_or_uncond = extra_options["cond_or_uncond"]
    n_chunks = len(cond_or_uncond)
    block_type = extra_options["block"][0]
    t_idx = extra_options["transformer_index"]
    to_kvs = ipadapter.ip_layers.to_kvs
    # Length of the layer schedule used by the "ease" and "precise" modes, chosen by
    # whether the adapter exposes a "101_to_k_ip" projection.
    layers = 11 if '101_to_k_ip' in to_kvs else 16
    k_key = module_key + "_to_k_ip"
    v_key = module_key + "_to_v_ip"

    # extra options for AnimateDiff context windows
    ad_params = extra_options['ad_params'] if "ad_params" in extra_options else None

    seq_len = q.shape[1]
    batch_prompt = q.shape[0] // n_chunks
    _, _, oh, ow = extra_options["original_shape"]

    if weight_type == 'ease in':
        weight = weight * (0.05 + 0.95 * (1 - t_idx / layers))
    elif weight_type == 'ease out':
        weight = weight * (0.05 + 0.95 * (t_idx / layers))
    elif weight_type == 'ease in-out':
        weight = weight * (0.05 + 0.95 * (1 - abs(t_idx - (layers/2)) / (layers/2)))
    elif weight_type == 'reverse in-out':
        weight = weight * (0.05 + 0.95 * (abs(t_idx - (layers/2)) / (layers/2)))
    elif block_type in _WEAKENED_BLOCKS.get(weight_type, ()):
        weight = weight * 0.2
    elif isinstance(weight, dict):
        if t_idx not in weight:
            return 0

        # "precise" modes: on the selected layers the cond embeds are zeroed and the
        # original cond is used as uncond.
        is_style_layer = t_idx == 3 if layers == 11 else t_idx in (4, 5)
        if (weight_type == "style transfer precise" and is_style_layer) or \
                (weight_type == "composition precise" and not is_style_layer):
            uncond = cond
            cond = cond * 0

        weight = weight[t_idx]

        if cond_alt is not None and t_idx in cond_alt:
            cond = cond_alt[t_idx]

    if unfold_batch:
        if ad_params is not None and ad_params["sub_idxs"] is not None:
            # AnimateDiff context window: keep only this window's frames
            full_length = ad_params["full_length"]
            sub_idxs = ad_params["sub_idxs"]
            weight, is_zero = _expand_weight(weight, full_length, n_chunks, sub_idxs)
            if is_zero:
                return 0
            cond = _take_frames(cond, full_length, sub_idxs)
            uncond = _take_frames(uncond, full_length, sub_idxs)
        else:
            weight, is_zero = _expand_weight(weight, batch_prompt, n_chunks)
            if is_zero:
                return 0
            cond = tensor_to_size(cond, batch_prompt)
            uncond = tensor_to_size(uncond, batch_prompt)

        k_cond = to_kvs[k_key](cond)
        k_uncond = to_kvs[k_key](uncond)
        v_cond = to_kvs[v_key](cond)
        v_uncond = to_kvs[v_key](uncond)
    else:
        weight, is_zero = _expand_weight(weight, batch_prompt, n_chunks)
        if is_zero:
            return 0

        # The same projected embeds are repeated for every prompt in the chunk.
        k_cond = to_kvs[k_key](cond).repeat(batch_prompt, 1, 1)
        k_uncond = to_kvs[k_key](uncond).repeat(batch_prompt, 1, 1)
        v_cond = to_kvs[v_key](cond).repeat(batch_prompt, 1, 1)
        v_uncond = to_kvs[v_key](uncond).repeat(batch_prompt, 1, 1)

    # cond_or_uncond gives the chunk order of the batch: 0 selects the cond
    # projections, 1 the uncond ones, and 2 reuses the cond projections.
    ip_k = torch.cat([(k_cond, k_uncond, k_cond)[i] for i in cond_or_uncond], dim=0)
    ip_v = torch.cat([(v_cond, v_uncond, v_cond)[i] for i in cond_or_uncond], dim=0)

    if embeds_scaling in ('K+mean(V) w/ C penalty', 'K+V w/ C penalty'):
        # Channel penalty: scale by ip_k's last dim (presumably the projected width)
        # relative to 1280.
        weight = weight * (float(ip_k.shape[2]) / 1280.0)

    if embeds_scaling == 'K+mean(V) w/ C penalty':
        ip_k = ip_k * weight
        # Only the token-mean component of V is scaled.
        ip_v_mean = torch.mean(ip_v, dim=1, keepdim=True)
        ip_v = (ip_v - ip_v_mean) + ip_v_mean * weight
        del ip_v_mean
        out_ip = optimized_attention(q, ip_k, ip_v, extra_options["n_heads"])
    elif embeds_scaling in ('K+V w/ C penalty', 'K+V'):
        ip_k = ip_k * weight
        ip_v = ip_v * weight
        out_ip = optimized_attention(q, ip_k, ip_v, extra_options["n_heads"])
    else:
        # 'V only' (and any other value): scale the attention output instead.
        out_ip = optimized_attention(q, ip_k, ip_v, extra_options["n_heads"])
        out_ip = out_ip * weight

    if mask is not None:
        out_ip = out_ip * _prepare_mask(mask, oh, ow, seq_len, batch_prompt, n_chunks, ad_params)

    return out_ip.to(dtype=dtype)
| |
|