"""
Credit: laksjdjf
https://github.com/laksjdjf/cgem156-ComfyUI/blob/main/scripts/attention_couple/node.py
Modified by Haoming02 to work with Forge
"""

import torch

from modules.devices import get_optimal_device
from scripts.attention_masks import get_mask, lcm_for_list


class AttentionCouple:
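    """Mask-based prompt coupling: restrict each extra prompt's cross-attention to its own region."""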
def patch_unet(self, model, base_mask, kwargs: dict):
new_model = model.clone()
dtype = new_model.model.diffusion_model.dtype
device = get_optimal_device()
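        # kwargs holds (mask_i, cond_i) pairs; together with the base mask
        # that gives num_conds attention regions in total.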
num_conds = len(kwargs) // 2 + 1
mask = [base_mask] + [kwargs[f"mask_{i}"] for i in range(1, num_conds)]
mask = torch.stack(mask, dim=0).to(device, dtype=dtype)
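        # Every spatial location must be covered by at least one mask;
        # normalize so the region weights sum to 1 at each location.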
        assert mask.sum(dim=0).min() > 0, "Every location must be covered by at least one mask"
mask = mask / mask.sum(dim=0, keepdim=True)
conds = [
kwargs[f"cond_{i}"][0][0].to(device, dtype=dtype)
for i in range(1, num_conds)
]
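        # Token length of each extra conditioning (prompts may differ in length).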
num_tokens = [cond.shape[1] for cond in conds]
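
        # Cross-attention input patch: duplicate the queries once per region
        # and append each region's conditioning to the shared k/v context.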
def attn2_patch(q, k, v, extra_options):
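            # In cross-attention k and v come from the same context, so the
            # patch returns ks for both below; this is a cheap sanity check.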
assert k.mean() == v.mean(), "k and v must be the same."
cond_or_unconds = extra_options["cond_or_uncond"]
num_chunks = len(cond_or_unconds)
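            # Per-chunk batch size; stashed on self so the output patch can reuse it.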
self.batch_size = q.shape[0] // num_chunks
q_chunks = q.chunk(num_chunks, dim=0)
k_chunks = k.chunk(num_chunks, dim=0)
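            # Pad every conditioning to a common token length by repetition;
            # the LCM guarantees each repeat count is a whole number.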
lcm_tokens = lcm_for_list(num_tokens + [k.shape[1]])
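            # Tile each region's cond across the batch and up to the common token length.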
conds_tensor = torch.cat(
[
cond.repeat(self.batch_size, lcm_tokens // num_tokens[i], 1)
for i, cond in enumerate(conds)
],
dim=0,
)
qs, ks = [], []
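            # Uncond chunks pass through untouched; cond chunks get one query
            # copy per region, with the region conds appended along the batch dim.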
for i, cond_or_uncond in enumerate(cond_or_unconds):
k_target = k_chunks[i].repeat(1, lcm_tokens // k.shape[1], 1)
if cond_or_uncond == 1: # uncond
qs.append(q_chunks[i])
ks.append(k_target)
else:
qs.append(q_chunks[i].repeat(num_conds, 1, 1))
ks.append(torch.cat([k_target, conds_tensor], dim=0))
qs = torch.cat(qs, dim=0)
ks = torch.cat(ks, dim=0)
return qs, ks, ks
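
        # Cross-attention output patch: blend the per-region outputs back
        # together using the masks, downsampled to the current resolution.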
def attn2_output_patch(out, extra_options):
cond_or_unconds = extra_options["cond_or_uncond"]
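            # Downsample the masks to match this attention layer's spatial size.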
mask_downsample = get_mask(
mask, self.batch_size, out.shape[1], extra_options["original_shape"]
)
outputs = []
pos = 0
for cond_or_uncond in cond_or_unconds:
if cond_or_uncond == 1: # uncond
outputs.append(out[pos : pos + self.batch_size])
pos += self.batch_size
else:
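                    # Weight each region's output by its mask and sum over regions.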
masked_output = (
out[pos : pos + num_conds * self.batch_size] * mask_downsample
).view(num_conds, self.batch_size, out.shape[1], out.shape[2])
masked_output = masked_output.sum(dim=0)
outputs.append(masked_output)
pos += num_conds * self.batch_size
return torch.cat(outputs, dim=0)
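
        # Register both patches on the cloned model and hand it back.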
new_model.set_model_attn2_patch(attn2_patch)
new_model.set_model_attn2_output_patch(attn2_output_patch)
return new_model