File size: 5,110 Bytes
7344bef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
from einops import rearrange
from torch import Tensor
from shared.attention import pay_attention


def attention(qkv_list: list[Tensor], pe: Tensor, *, txt_len: int | None = None, NAG: dict | None = None) -> Tensor:
    """Apply rotary embeddings to q/k and run attention, optionally with NAG guidance.

    Args:
        qkv_list: List holding ``q, k, v``. It is emptied by this call so the
            caller's list no longer references the tensors. Dim 2 of each
            tensor is treated as the sequence axis (presumably the layout is
            ``(B, H, L, D)`` -- TODO confirm against callers).
        pe: Rotary table handed to ``apply_rope_`` for both ``q`` and ``k``.
        txt_len: Index on the sequence axis where the image tokens start.
            Only consulted when ``NAG`` is given.
        NAG: Optional guidance settings (presumably "Normalized Attention
            Guidance" -- confirm). ``"cap_embed_len"`` and ``"prefix_len"``
            are read with a default of 0; ``"scale"``, ``"alpha"`` and
            ``"tau"`` are read only when the guided path is taken.

    Returns:
        The attention output with the last two dims of ``pay_attention``'s
        result merged, i.e. ``B L (H D)``.

    The guided path runs only when the sequence is laid out as
    ``[prefix | positive caption | negative caption | image tokens]`` with
    both captions ``cap_embed_len`` long. In every other case the function
    falls through to a single plain attention call.
    """
    q, k, v = qkv_list
    # Drop the caller-visible references to q/k/v.
    qkv_list.clear()
    # apply_rope_ takes its input wrapped in a list and empties that list;
    # the local name is reset first so the list holds the only local reference.
    q_list = [q]
    q = None
    q = apply_rope_(q_list, pe)
    k_list = [k]
    k = None
    k = apply_rope_(k_list, pe)

    if NAG is not None and txt_len is not None:
        # "or 0" maps a stored None (or other falsy value) to 0 before int().
        cap_len = int(NAG.get("cap_embed_len", 0) or 0)
        prefix_len = int(NAG.get("prefix_len", 0) or 0)
        total_len = q.shape[2]
        img_start = txt_len
        # Tokens between the prefix and the image tokens.
        packed_len = txt_len - prefix_len
        # Guided path requires exactly two caption-sized segments in that span.
        if cap_len > 0 and packed_len == (cap_len * 2) and img_start <= total_len:
            pos_start = prefix_len
            pos_end = pos_start + cap_len
            neg_start = pos_end
            neg_end = neg_start + cap_len
            if neg_end <= txt_len:
                # Build pos/neg sequences that share prefix + image tokens.
                q_neg = torch.cat( (q[:, :, :prefix_len], q[:, :, neg_start:neg_end], q[:, :, img_start:]), dim=2, )
                k_neg = torch.cat( (k[:, :, :prefix_len], k[:, :, neg_start:neg_end], k[:, :, img_start:]), dim=2, )
                v_neg = torch.cat( (v[:, :, :prefix_len], v[:, :, neg_start:neg_end], v[:, :, img_start:]), dim=2, )

                q_pos = torch.cat((q[:, :, :pos_end], q[:, :, img_start:]), dim=2)
                k_pos = torch.cat((k[:, :, :pos_end], k[:, :, img_start:]), dim=2)
                v_pos = torch.cat((v[:, :, :pos_end], v[:, :, img_start:]), dim=2)
                # The full-length q/k/v are no longer needed once both
                # branch copies exist.
                del q, k, v

                # Positive branch: prefix + positive caption + image tokens.
                qkv_pos = [q_pos.transpose(1, 2), k_pos.transpose(1, 2), v_pos.transpose(1, 2)]
                q_pos = k_pos = v_pos = None
                x_pos = pay_attention(qkv_pos)
                # Merge the last two dims (same result layout as the
                # "B L (H D)" rearrange on the plain path below).
                x_pos = x_pos.flatten(2, 3)

                # Negative branch: prefix + negative caption + image tokens.
                qkv_neg = [q_neg.transpose(1, 2), k_neg.transpose(1, 2), v_neg.transpose(1, 2)]
                q_neg = k_neg = v_neg = None
                x_neg = pay_attention(qkv_neg)
                x_neg = x_neg.flatten(2, 3)

                neg_slice_end = prefix_len + cap_len
                # Copy the negative-caption rows now: x_neg is modified in
                # place below, and these rows are returned unguided.
                neg_out = x_neg[:, prefix_len:neg_slice_end].clone()
                nag_scale = NAG["scale"]
                nag_alpha = NAG["alpha"]
                nag_tau = NAG["tau"]
                dtype = x_pos.dtype

                # x_guidance aliases x_neg; computed in place as
                # x_neg * (1 - scale) + x_pos * scale.
                x_guidance = x_neg
                x_guidance.mul_(1 - nag_scale)
                x_guidance.add_(x_pos, alpha=nag_scale)
                # Per-token L1 norms over the feature dim.
                norm_positive = torch.norm(x_pos, p=1, dim=-1, keepdim=True)
                norm_guidance = torch.norm(x_guidance, p=1, dim=-1, keepdim=True)
                scale = norm_guidance / norm_positive
                # NaN/inf ratios (e.g. from a zero positive norm) become 10.0.
                torch.nan_to_num(scale, nan=10.0, posinf=10.0, neginf=10.0, out=scale)
                # Where the ratio exceeds tau, rescale the guidance so its L1
                # norm is about tau * norm_positive; elsewhere keep it as is.
                factor = (1 / (norm_guidance + 1e-7) * norm_positive * nag_tau).to(x_guidance.dtype)
                x_guidance = torch.where(scale > nag_tau, x_guidance * factor, x_guidance).to(dtype)
                del norm_positive, norm_guidance, scale, factor

                # Final blend, in place: alpha * guidance + (1 - alpha) * x_pos.
                x_guidance.mul_(nag_alpha)
                x_guidance.add_(x_pos, alpha=(1 - nag_alpha))
                x_pos = None

                # x_guidance is laid out as [prefix | caption | image].
                prefix_pos_guidance = x_guidance[:, :pos_end]
                img_guidance = x_guidance[:, pos_end:]
                x_guidance = None

                # Re-insert the unguided negative-caption rows so the output
                # has the same sequence length as the input.
                out = torch.cat([prefix_pos_guidance, neg_out, img_guidance], dim=1)
                prefix_pos_guidance = neg_out = img_guidance = None
                return out

    # Plain path: one attention call over the whole sequence.
    qkv_list = [q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)]
    del q, k, v
    x = pay_attention(qkv_list).transpose(1, 2)
    # x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    x = rearrange(x, "B H L D -> B L (H D)")

    return x


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    """Build 2x2 rotation matrices for rotary position embedding.

    Args:
        pos: Positions of shape ``(b, n)`` (the final ``rearrange`` pattern
            requires the stacked result to have exactly four dims).
        dim: Number of features to rotate; must be even.
        theta: Base of the geometric frequency progression.

    Returns:
        A float32 tensor of shape ``(b, n, dim / 2, 2, 2)`` whose trailing
        2x2 block is ``[[cos, -sin], [sin, cos]]`` for each position and
        frequency.
    """
    assert dim % 2 == 0
    # One exponent in [0, 1) per feature pair.
    exponents = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
    inv_freq = 1.0 / (theta**exponents)
    # Outer product of positions and frequencies gives the rotation angles.
    angles = torch.einsum("...n,d->...nd", pos, inv_freq)
    cos = torch.cos(angles)
    sin = torch.sin(angles)
    flat_rot = torch.stack([cos, -sin, sin, cos], dim=-1)
    rot = rearrange(flat_rot, "b n d (i j) -> b n d i j", i=2, j=2)
    return rot.float()


def apply_rope_(q_list: list[Tensor], freqs_cis: Tensor) -> Tensor:
    """Apply a rotary embedding to one tensor, consuming the list that holds it.

    Args:
        q_list: List whose first element is the tensor to rotate. The list is
            cleared so it no longer keeps the input tensor alive.
        freqs_cis: Rotation table whose last two dims form a 2x2 matrix per
            feature pair; it must broadcast against the input reshaped to
            ``(..., pairs, 1, 2)``.

    Returns:
        A single tensor with the input's shape and dtype. (The previous
        ``tuple[Tensor, Tensor]`` annotation did not match this.)
    """
    xq = q_list[0]
    xqshape = xq.shape
    xqdtype = xq.dtype
    q_list.clear()
    # Compute in float32, viewing the last dim as (pairs, 1, 2).
    xq = xq.float().reshape(*xqshape[:-1], -1, 1, 2)
    # Column 0 of each 2x2 matrix times the first element of each pair ...
    xq_out = freqs_cis[..., 0] * xq[..., 0]
    # ... plus column 1 times the second element. Rebinding xq drops the
    # float copy of the input before the in-place accumulation.
    xq = freqs_cis[..., 1] * xq[..., 1]

    xq_out.add_(xq)

    return xq_out.reshape(*xqshape).to(xqdtype)

def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    """Rotate query and key tensors with the same rotary table.

    Args:
        xq: Query tensor; its last dim is read as consecutive pairs.
        xk: Key tensor, handled exactly like ``xq``.
        freqs_cis: Rotation table whose last two dims form a 2x2 matrix per
            pair; it must broadcast against the inputs reshaped to
            ``(..., pairs, 1, 2)``.

    Returns:
        ``(rotated_q, rotated_k)``, each with the shape and type of its input.
    """
    first_col = freqs_cis[..., 0]
    second_col = freqs_cis[..., 1]

    def _rotate(x: Tensor) -> Tensor:
        # Work in float32 on (..., pairs, 1, 2), then restore shape and type.
        pairs = x.float().reshape(*x.shape[:-1], -1, 1, 2)
        rotated = first_col * pairs[..., 0] + second_col * pairs[..., 1]
        return rotated.reshape(*x.shape).type_as(x)

    return _rotate(xq), _rotate(xk)