File size: 4,336 Bytes
c6535db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import torch
import torch.nn as nn
from ..wanvideo.modules.attention import attention

def vector_to_list(tensor, lens, dim):
    return list(torch.split(tensor, lens, dim=dim))

def list_to_vector(tensor_list, dim):
    lens = [tensor.shape[dim] for tensor in tensor_list]
    tensor = torch.cat(tensor_list, dim)
    return tensor, lens

def merge_token_lists(list1, list2, dim):
    assert(len(list1) == len(list2))
    return [torch.cat((t1, t2), dim) for t1, t2 in zip(list1, list2)]

try:
    from sageattention import sageattn_varlen
except ImportError:
    sageattn_varlen = None

class WanLynxIPCrossAttention(nn.Module):
    def __init__(self, cross_attention_dim=5120, dim=5120, n_registers=16, bias=True):
        super().__init__()
        self.to_k_ip = nn.Linear(cross_attention_dim, dim, bias=bias)
        self.to_v_ip = nn.Linear(cross_attention_dim, dim, bias=bias)
        if n_registers > 0:
            self.registers = nn.Parameter(torch.randn(1, n_registers, cross_attention_dim) / dim**0.5)
        else:
            self.registers = None

    def forward(self, block, q, ip_x):
        b, s, n, d = q.shape
        ip_lens = [ip_x.shape[1]]

        if self.registers is not None and ip_x is not None and ip_x.shape[0] == 1:
            ip_x = torch.cat([ip_x, self.registers], dim=1)
            ip_lens[0] += self.registers.shape[1]
        elif self.registers is not None and ip_x.shape[0] > 1:
            ip_x_list = vector_to_list(ip_x, ip_lens, 1)
            ip_x_list = merge_token_lists(ip_x_list, [self.registers] * len(ip_x_list), 1)
            ip_x, ip_lens = list_to_vector(ip_x_list, 1)

        ip_key = self.to_k_ip(ip_x)
        ip_value = self.to_v_ip(ip_x)

        if self.registers is None: # lite model normalization
            ip_key = ip_key * torch.rsqrt(ip_key.pow(2).mean(dim=-1, keepdim=True) + 1e-5).to(ip_key.dtype)
        else: # full model
            ip_key = block.norm_k(ip_key)

        return attention(
            q,
            ip_key.view(b, -1, n, d),
            ip_value.view(b, -1, n, d)
        ).flatten(2)


#@torch.compiler.disable()
class WanLynxRefAttention(nn.Module):
    def __init__(self, dim=5120, bias=True, attention_mode="sdpa"):
        super().__init__()
        self.to_k_ref = nn.Linear(dim, dim, bias=bias)
        self.to_v_ref = nn.Linear(dim, dim, bias=bias)
        self.attention_mode = attention_mode
        # Pre-compute attention mode flags to avoid string operations in forward
        self.use_flash_attn = "flash_attn" in attention_mode
        self.use_sageattn = sageattn_varlen is not None

    def forward(self, block, q, ref_feature):
        b, s, n, d = q.shape

        ref_key = self.to_k_ref(ref_feature)
        ref_value = self.to_v_ref(ref_feature)
        ref_key = block.norm_k(ref_key)

        # Use pre-computed flags instead of runtime string checks
        if not self.use_flash_attn and not self.use_sageattn:
            # Pad ref_key and ref_value to match q's sequence length (s)
            seq_len = ref_key.shape[1]
            pad_len = s - seq_len
            if pad_len > 0:
                # Pad on the sequence dimension (dim=1)
                ref_key = torch.nn.functional.pad(ref_key, (0, 0, 0, pad_len))
                ref_value = torch.nn.functional.pad(ref_value, (0, 0, 0, pad_len))

            # Create attention mask: True for real tokens, False for padded
            attn_mask = torch.zeros((b, s), dtype=torch.bool, device=ref_key.device)
            attn_mask[:, :seq_len] = True

            ref_key = ref_key.view(b, s, n, d)
            ref_value = ref_value.view(b, s, n, d)

            ref_x = attention(
                q,
                ref_key,
                ref_value,
                attention_mode="sdpa",
                attn_mask=attn_mask,
            )
        else:
            q_lens = [s] * b
            k_lens = [ref_key.shape[1]] * b

            ref_x = attention(
                q.view(-1, n, d),
                ref_key.view(-1, n, d),
                ref_value.view(-1, n, d),
                q_lens=q_lens,
                k_lens=k_lens,
                max_seqlen_k=ref_key.shape[1],
                max_seqlen_q=s,
                attention_mode='sageattn_varlen' if self.use_sageattn else self.attention_mode,
                )

        return ref_x