| import torch
|
| from einops import rearrange
|
| from torch import Tensor
|
| from shared.attention import pay_attention
|
|
|
|
|
| def attention(qkv_list, pe: Tensor, *, txt_len: int | None = None, NAG: dict | None = None) -> Tensor:
|
| q, k, v = qkv_list
|
| qkv_list.clear()
|
| q_list = [q]
|
| q = None
|
| q = apply_rope_(q_list, pe)
|
| k_list = [k]
|
| k = None
|
| k = apply_rope_(k_list, pe)
|
|
|
| if NAG is not None and txt_len is not None:
|
| cap_len = int(NAG.get("cap_embed_len", 0) or 0)
|
| prefix_len = int(NAG.get("prefix_len", 0) or 0)
|
| total_len = q.shape[2]
|
| img_start = txt_len
|
| packed_len = txt_len - prefix_len
|
| if cap_len > 0 and packed_len == (cap_len * 2) and img_start <= total_len:
|
| pos_start = prefix_len
|
| pos_end = pos_start + cap_len
|
| neg_start = pos_end
|
| neg_end = neg_start + cap_len
|
| if neg_end <= txt_len:
|
|
|
| q_neg = torch.cat( (q[:, :, :prefix_len], q[:, :, neg_start:neg_end], q[:, :, img_start:]), dim=2, )
|
| k_neg = torch.cat( (k[:, :, :prefix_len], k[:, :, neg_start:neg_end], k[:, :, img_start:]), dim=2, )
|
| v_neg = torch.cat( (v[:, :, :prefix_len], v[:, :, neg_start:neg_end], v[:, :, img_start:]), dim=2, )
|
|
|
| q_pos = torch.cat((q[:, :, :pos_end], q[:, :, img_start:]), dim=2)
|
| k_pos = torch.cat((k[:, :, :pos_end], k[:, :, img_start:]), dim=2)
|
| v_pos = torch.cat((v[:, :, :pos_end], v[:, :, img_start:]), dim=2)
|
| del q, k, v
|
|
|
| qkv_pos = [q_pos.transpose(1, 2), k_pos.transpose(1, 2), v_pos.transpose(1, 2)]
|
| q_pos = k_pos = v_pos = None
|
| x_pos = pay_attention(qkv_pos)
|
| x_pos = x_pos.flatten(2, 3)
|
|
|
| qkv_neg = [q_neg.transpose(1, 2), k_neg.transpose(1, 2), v_neg.transpose(1, 2)]
|
| q_neg = k_neg = v_neg = None
|
| x_neg = pay_attention(qkv_neg)
|
| x_neg = x_neg.flatten(2, 3)
|
|
|
| neg_slice_end = prefix_len + cap_len
|
| neg_out = x_neg[:, prefix_len:neg_slice_end].clone()
|
| nag_scale = NAG["scale"]
|
| nag_alpha = NAG["alpha"]
|
| nag_tau = NAG["tau"]
|
| dtype = x_pos.dtype
|
|
|
| x_guidance = x_neg
|
| x_guidance.mul_(1 - nag_scale)
|
| x_guidance.add_(x_pos, alpha=nag_scale)
|
| norm_positive = torch.norm(x_pos, p=1, dim=-1, keepdim=True)
|
| norm_guidance = torch.norm(x_guidance, p=1, dim=-1, keepdim=True)
|
| scale = norm_guidance / norm_positive
|
| torch.nan_to_num(scale, nan=10.0, posinf=10.0, neginf=10.0, out=scale)
|
| factor = (1 / (norm_guidance + 1e-7) * norm_positive * nag_tau).to(x_guidance.dtype)
|
| x_guidance = torch.where(scale > nag_tau, x_guidance * factor, x_guidance).to(dtype)
|
| del norm_positive, norm_guidance, scale, factor
|
|
|
| x_guidance.mul_(nag_alpha)
|
| x_guidance.add_(x_pos, alpha=(1 - nag_alpha))
|
| x_pos = None
|
|
|
| prefix_pos_guidance = x_guidance[:, :pos_end]
|
| img_guidance = x_guidance[:, pos_end:]
|
| x_guidance = None
|
|
|
| out = torch.cat([prefix_pos_guidance, neg_out, img_guidance], dim=1)
|
| prefix_pos_guidance = neg_out = img_guidance = None
|
| return out
|
|
|
| qkv_list = [q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)]
|
| del q, k, v
|
| x = pay_attention(qkv_list).transpose(1, 2)
|
|
|
| x = rearrange(x, "B H L D -> B L (H D)")
|
|
|
| return x
|
|
|
|
|
| def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
| assert dim % 2 == 0
|
| scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
|
| omega = 1.0 / (theta**scale)
|
| out = torch.einsum("...n,d->...nd", pos, omega)
|
| out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
| out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
| return out.float()
|
|
|
|
|
| def apply_rope_(q_list, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
|
| xq= q_list[0]
|
| xqshape = xq.shape
|
| xqdtype= xq.dtype
|
| q_list.clear()
|
| xq = xq.float().reshape(*xqshape[:-1], -1, 1, 2)
|
| xq_out = freqs_cis[..., 0] * xq[..., 0]
|
| xq = freqs_cis[..., 1] * xq[..., 1]
|
|
|
| xq_out.add_(xq)
|
|
|
|
|
| return xq_out.reshape(*xqshape).to(xqdtype)
|
|
|
| def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
|
| xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
| xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
| xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
| xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
| return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
|
|