|
|
""" |
|
|
Credit: laksjdjf |
|
|
https://github.com/laksjdjf/cgem156-ComfyUI/blob/main/scripts/attention_couple/node.py |
|
|
|
|
|
Modified by. Haoming02 to work with Forge |
|
|
""" |
|
|
|
|
|
from scripts.attention_masks import get_mask, lcm_for_list |
|
|
from modules.devices import get_optimal_device |
|
|
import torch |
|
|
|
|
|
|
|
|
class AttentionCouple: |
|
|
|
|
|
def patch_unet(self, model, base_mask, kwargs: dict): |
|
|
|
|
|
new_model = model.clone() |
|
|
dtype = new_model.model.diffusion_model.dtype |
|
|
device = get_optimal_device() |
|
|
|
|
|
num_conds = len(kwargs) // 2 + 1 |
|
|
|
|
|
mask = [base_mask] + [kwargs[f"mask_{i}"] for i in range(1, num_conds)] |
|
|
mask = torch.stack(mask, dim=0).to(device, dtype=dtype) |
|
|
assert mask.sum(dim=0).min() > 0, "Masks must not contain zeroes..." |
|
|
mask = mask / mask.sum(dim=0, keepdim=True) |
|
|
|
|
|
conds = [ |
|
|
kwargs[f"cond_{i}"][0][0].to(device, dtype=dtype) |
|
|
for i in range(1, num_conds) |
|
|
] |
|
|
num_tokens = [cond.shape[1] for cond in conds] |
|
|
|
|
|
def attn2_patch(q, k, v, extra_options): |
|
|
assert k.mean() == v.mean(), "k and v must be the same." |
|
|
cond_or_unconds = extra_options["cond_or_uncond"] |
|
|
num_chunks = len(cond_or_unconds) |
|
|
self.batch_size = q.shape[0] // num_chunks |
|
|
q_chunks = q.chunk(num_chunks, dim=0) |
|
|
k_chunks = k.chunk(num_chunks, dim=0) |
|
|
lcm_tokens = lcm_for_list(num_tokens + [k.shape[1]]) |
|
|
conds_tensor = torch.cat( |
|
|
[ |
|
|
cond.repeat(self.batch_size, lcm_tokens // num_tokens[i], 1) |
|
|
for i, cond in enumerate(conds) |
|
|
], |
|
|
dim=0, |
|
|
) |
|
|
|
|
|
qs, ks = [], [] |
|
|
for i, cond_or_uncond in enumerate(cond_or_unconds): |
|
|
k_target = k_chunks[i].repeat(1, lcm_tokens // k.shape[1], 1) |
|
|
if cond_or_uncond == 1: |
|
|
qs.append(q_chunks[i]) |
|
|
ks.append(k_target) |
|
|
else: |
|
|
qs.append(q_chunks[i].repeat(num_conds, 1, 1)) |
|
|
ks.append(torch.cat([k_target, conds_tensor], dim=0)) |
|
|
|
|
|
qs = torch.cat(qs, dim=0) |
|
|
ks = torch.cat(ks, dim=0) |
|
|
|
|
|
return qs, ks, ks |
|
|
|
|
|
def attn2_output_patch(out, extra_options): |
|
|
cond_or_unconds = extra_options["cond_or_uncond"] |
|
|
mask_downsample = get_mask( |
|
|
mask, self.batch_size, out.shape[1], extra_options["original_shape"] |
|
|
) |
|
|
outputs = [] |
|
|
pos = 0 |
|
|
for cond_or_uncond in cond_or_unconds: |
|
|
if cond_or_uncond == 1: |
|
|
outputs.append(out[pos : pos + self.batch_size]) |
|
|
pos += self.batch_size |
|
|
else: |
|
|
masked_output = ( |
|
|
out[pos : pos + num_conds * self.batch_size] * mask_downsample |
|
|
).view(num_conds, self.batch_size, out.shape[1], out.shape[2]) |
|
|
masked_output = masked_output.sum(dim=0) |
|
|
outputs.append(masked_output) |
|
|
pos += num_conds * self.batch_size |
|
|
return torch.cat(outputs, dim=0) |
|
|
|
|
|
new_model.set_model_attn2_patch(attn2_patch) |
|
|
new_model.set_model_attn2_output_patch(attn2_output_patch) |
|
|
|
|
|
return new_model |
|
|
|