# Wan2GP / models/flux/math.py
# Uploaded by Egnalkram ("Upload folder using huggingface_hub"), commit 4689c2b (verified).
import torch
from einops import rearrange
from torch import Tensor
from shared.attention import pay_attention
def attention(qkv_list: list[Tensor], pe: Tensor, *, txt_len: int | None = None, NAG: dict | None = None) -> Tensor:
    """Rotary-embedded attention over q/k/v, with an optional NAG-guided path.

    RoPE from ``pe`` is applied to q and k through ``apply_rope_``, and the
    result is passed to ``pay_attention``. ``qkv_list`` is emptied in place.

    Args:
        qkv_list: ``[q, k, v]``. The sequence length is read from
            ``q.shape[2]`` and dims 1 and 2 are swapped before
            ``pay_attention``, so these appear to be ``(B, H, L, D)`` --
            TODO confirm with callers.
        pe: rotation tensor forwarded unchanged to ``apply_rope_``.
        txt_len: number of text tokens at the start of the sequence; image
            tokens start at this index. Only used by the NAG path.
        NAG: optional guidance settings. ``"cap_embed_len"`` and
            ``"prefix_len"`` (missing or falsy -> 0) locate the caption
            tokens; ``"scale"``, ``"alpha"`` and ``"tau"`` are read (and
            required) only when the guided path actually runs.

    Returns:
        Attention output shaped ``(B, L, H * D)`` (per the final rearrange),
        covering the full input sequence length in both paths.

    The guided path is taken only when ``NAG`` and ``txt_len`` are given,
    ``cap_embed_len > 0``, and the text after the prefix consists of exactly
    two captions of ``cap_embed_len`` tokens each (positive, then negative)::

        [ prefix | positive caption | negative caption | image ... ]
        0        prefix_len         pos_end            txt_len

    Attention then runs twice, on ``prefix + positive + image`` and on
    ``prefix + negative + image``. The two outputs are combined as
    ``g = scale * x_pos + (1 - scale) * x_neg``. Tokens whose L1-norm ratio
    ``|g| / |x_pos|`` exceeds ``tau`` are rescaled to roughly that ratio, and
    the result is blended as ``alpha * g + (1 - alpha) * x_pos``. In the
    returned sequence, the negative-caption slots hold the unguided outputs
    of the negative pass. If the layout check fails, a single standard
    attention pass is used instead.
    """
    q, k, v = qkv_list
    qkv_list.clear()
    # Each tensor is handed to apply_rope_ inside a one-element list after the
    # local name is dropped. apply_rope_ clears that list, so this frame does not
    # keep the pre-RoPE tensor alive (presumably to lower peak memory).
    q_list = [q]
    q = None
    q = apply_rope_(q_list, pe)
    k_list = [k]
    k = None
    k = apply_rope_(k_list, pe)
    if NAG is not None and txt_len is not None:
        cap_len = int(NAG.get("cap_embed_len", 0) or 0)
        prefix_len = int(NAG.get("prefix_len", 0) or 0)
        total_len = q.shape[2]
        img_start = txt_len
        packed_len = txt_len - prefix_len
        # Guided path only if the text after the prefix is exactly [pos caption | neg caption].
        if cap_len > 0 and packed_len == (cap_len * 2) and img_start <= total_len:
            pos_start = prefix_len
            pos_end = pos_start + cap_len
            neg_start = pos_end
            neg_end = neg_start + cap_len
            # Given the check above, neg_end == txt_len, so this guard always passes here.
            if neg_end <= txt_len:
                # Build pos/neg sequences that share prefix + image tokens.
                q_neg = torch.cat( (q[:, :, :prefix_len], q[:, :, neg_start:neg_end], q[:, :, img_start:]), dim=2, )
                k_neg = torch.cat( (k[:, :, :prefix_len], k[:, :, neg_start:neg_end], k[:, :, img_start:]), dim=2, )
                v_neg = torch.cat( (v[:, :, :prefix_len], v[:, :, neg_start:neg_end], v[:, :, img_start:]), dim=2, )
                # Prefix and positive caption are contiguous, so one slice up to pos_end covers both.
                q_pos = torch.cat((q[:, :, :pos_end], q[:, :, img_start:]), dim=2)
                k_pos = torch.cat((k[:, :, :pos_end], k[:, :, img_start:]), dim=2)
                v_pos = torch.cat((v[:, :, :pos_end], v[:, :, img_start:]), dim=2)
                del q, k, v
                # Same dim-1/dim-2 swap as the standard path below; flatten(2, 3)
                # then merges the last two output dims.
                qkv_pos = [q_pos.transpose(1, 2), k_pos.transpose(1, 2), v_pos.transpose(1, 2)]
                q_pos = k_pos = v_pos = None
                x_pos = pay_attention(qkv_pos)
                x_pos = x_pos.flatten(2, 3)
                qkv_neg = [q_neg.transpose(1, 2), k_neg.transpose(1, 2), v_neg.transpose(1, 2)]
                q_neg = k_neg = v_neg = None
                x_neg = pay_attention(qkv_neg)
                x_neg = x_neg.flatten(2, 3)
                # Keep the negative pass's own caption-token outputs. They are cloned
                # because x_neg is reused in place as x_guidance below.
                neg_slice_end = prefix_len + cap_len
                neg_out = x_neg[:, prefix_len:neg_slice_end].clone()
                nag_scale = NAG["scale"]
                nag_alpha = NAG["alpha"]
                nag_tau = NAG["tau"]
                dtype = x_pos.dtype
                # In place on x_neg: x_guidance = nag_scale * x_pos + (1 - nag_scale) * x_neg
                x_guidance = x_neg
                x_guidance.mul_(1 - nag_scale)
                x_guidance.add_(x_pos, alpha=nag_scale)
                # Per-token L1 norms over the last dim. Where |guidance| / |positive|
                # exceeds nag_tau, the token is rescaled so the ratio becomes ~nag_tau.
                norm_positive = torch.norm(x_pos, p=1, dim=-1, keepdim=True)
                norm_guidance = torch.norm(x_guidance, p=1, dim=-1, keepdim=True)
                scale = norm_guidance / norm_positive
                # NaN/inf ratios (e.g. a zero positive norm) are replaced by 10.0.
                torch.nan_to_num(scale, nan=10.0, posinf=10.0, neginf=10.0, out=scale)
                factor = (1 / (norm_guidance + 1e-7) * norm_positive * nag_tau).to(x_guidance.dtype)
                x_guidance = torch.where(scale > nag_tau, x_guidance * factor, x_guidance).to(dtype)
                del norm_positive, norm_guidance, scale, factor
                # Blend: x_guidance = nag_alpha * x_guidance + (1 - nag_alpha) * x_pos
                x_guidance.mul_(nag_alpha)
                x_guidance.add_(x_pos, alpha=(1 - nag_alpha))
                x_pos = None
                # Rebuild the full-length token layout along dim 1: guided prefix +
                # positive caption, raw negative-caption outputs, guided image tokens.
                prefix_pos_guidance = x_guidance[:, :pos_end]
                img_guidance = x_guidance[:, pos_end:]
                x_guidance = None
                out = torch.cat([prefix_pos_guidance, neg_out, img_guidance], dim=1)
                prefix_pos_guidance = neg_out = img_guidance = None
                return out
    # Standard path: no NAG settings, or the token layout did not match.
    qkv_list = [q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)]
    del q, k, v
    x = pay_attention(qkv_list).transpose(1, 2)
    # x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    x = rearrange(x, "B H L D -> B L (H D)")
    return x
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    """Build 2x2 rotary-position-embedding rotation matrices.

    Args:
        pos: positions; the last axis indexes tokens and any leading axes
            (e.g. batch) are preserved. ``pos.dtype`` is used for the
            frequency ``arange``.
        dim: rotary width; must be even. ``dim // 2`` frequencies are built.
        theta: base of the geometric frequency progression.

    Returns:
        float32 tensor of shape ``(*pos.shape, dim // 2, 2, 2)``. Each trailing
        2x2 block is ``[[cos a, -sin a], [sin a, cos a]]`` with
        ``a = pos * theta ** (-2k / dim)`` for frequency index ``k``.

    Raises:
        ValueError: if ``dim`` is odd.
    """
    # An explicit check (not ``assert``) so validation survives ``python -O``.
    if dim % 2 != 0:
        raise ValueError(f"rope dim must be even, got {dim}")
    scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    # Outer product over the token axis: (..., n) x (d,) -> (..., n, d).
    out = torch.einsum("...n,d->...nd", pos, omega)
    cos, sin = torch.cos(out), torch.sin(out)
    out = torch.stack([cos, -sin, sin, cos], dim=-1)
    # Split the trailing 4 entries into a row-major 2x2 block. This keeps any
    # number of leading axes, matching the ellipsis in the einsum above.
    out = out.reshape(*out.shape[:-1], 2, 2)
    return out.float()
def apply_rope_(q_list, freqs_cis: Tensor) -> Tensor:
    """Apply rotary position embedding to the single tensor held in ``q_list``.

    The list is cleared before the computation, so the caller's list no
    longer references the input tensor.

    Args:
        q_list: one-element list holding the tensor to rotate. Its last dim
            is split into pairs; the list is emptied in place.
        freqs_cis: rotation matrices whose last two dims are 2x2 blocks
            (presumably the output of ``rope``, possibly with extra broadcast
            axes -- confirm with callers). Must broadcast against the input
            reshaped to ``(..., last_dim // 2, 1, 2)``.

    Returns:
        The rotated tensor, with the input's shape and dtype.
    """
    xq = q_list[0]
    xq_shape = xq.shape
    xq_dtype = xq.dtype
    # Drop the caller-visible reference; from here on only this frame holds it.
    q_list.clear()
    xq = xq.float().reshape(*xq_shape[:-1], -1, 1, 2)
    # out = f[..., 0] * x0 + f[..., 1] * x1, computed in two steps so the second
    # product is accumulated into the first product's buffer.
    xq_out = freqs_cis[..., 0] * xq[..., 0]
    # Rebinding ``xq`` lets the float/reshaped input be released early.
    xq = freqs_cis[..., 1] * xq[..., 1]
    xq_out.add_(xq)
    return xq_out.reshape(*xq_shape).to(xq_dtype)
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    """Rotate query and key tensors with precomputed RoPE matrices.

    Each tensor's last dim is split into pairs and multiplied, in float32, by
    the 2x2 blocks held in the last two dims of ``freqs_cis``. These blocks
    must broadcast against ``(..., last_dim // 2, 1, 2)``.

    Returns:
        ``(rotated_q, rotated_k)``, each with its input's shape and dtype.
    """

    def _rotate(t: Tensor) -> Tensor:
        # (..., D) -> (..., D/2, 1, 2) so each pair lines up with one 2x2 block.
        pairs = t.float().reshape(*t.shape[:-1], -1, 1, 2)
        first = freqs_cis[..., 0] * pairs[..., 0]
        second = freqs_cis[..., 1] * pairs[..., 1]
        return (first + second).reshape(*t.shape).type_as(t)

    return _rotate(xq), _rotate(xk)