import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn

from flash_attn import flash_attn_varlen_func

from transformers.models.t5.modeling_t5 import T5LayerNorm as RMSNorm


class FlashVarLenAttention(nn.Module):
    """Multi-head attention over packed variable-length sequences, backed by flash_attn_varlen_func."""

    def __init__(self, embed_dim, num_heads, causal=False, window_size=(-1, -1)):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

        self.causal = causal
        self.window_size = window_size

    def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor):
        """hidden_states: packed tokens of shape (total_tokens, embed_dim);
        seq_len: per-sequence lengths, summing to total_tokens."""
        total_tokens, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        query_states = query_states.view(total_tokens, self.num_heads, self.head_dim).contiguous()
        key_states = self.k_proj(hidden_states)
        key_states = key_states.view(total_tokens, self.num_heads, self.head_dim).contiguous()
        value_states = self.v_proj(hidden_states)
        value_states = value_states.view(total_tokens, self.num_heads, self.head_dim).contiguous()

        # Cumulative sequence boundaries ([0, len_0, len_0 + len_1, ...]) and the longest
        # sequence length, as required by the varlen flash-attention kernel.
        cu_len = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(torch.int32)
        max_seqlen = int(seq_len.max().item())

        attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_len, cu_len,
                                             max_seqlen, max_seqlen, causal=self.causal,
                                             window_size=self.window_size)  # (total_tokens, nheads, headdim)
        attn_output = attn_output.reshape(total_tokens, self.embed_dim)
        attn_output = self.out_proj(attn_output)
        return attn_output
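
# Hedged usage sketch for FlashVarLenAttention (illustrative only; the sizes, dtype and
# device below are assumptions, not taken from a caller in this file). flash-attn needs
# a CUDA device and fp16/bf16 tensors:
#
#   attn = FlashVarLenAttention(embed_dim=512, num_heads=4, causal=True).cuda().half()
#   seq_len = torch.tensor([3, 5], dtype=torch.int32, device="cuda")      # two sequences packed together
#   packed = torch.randn(int(seq_len.sum()), 512, device="cuda", dtype=torch.float16)
#   out = attn(packed, seq_len)                                           # (8, 512), same packed layout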



class CasualDepthTransformerLayer(nn.Module):
    def __init__(self, depth, transformer_dim, transformer_ffn_scale):
        super().__init__()
        self.depth = depth
        self.transformer_dim = transformer_dim
        self.transformer_ffn_scale = transformer_ffn_scale
        self.num_heads = self.transformer_dim // 128  # fixed head dim of 128 per attention head

        assert self.transformer_dim % 128 == 0
        assert self.transformer_dim % depth == 0

        self.self_attention = FlashVarLenAttention(embed_dim=self.transformer_dim, num_heads=self.num_heads, causal=True)

        self.layernorm1 = RMSNorm(self.transformer_dim)
        self.layernorm2 = RMSNorm(self.transformer_dim)
        
        # Only the weights of these two projections are used in forward() (sliced per depth
        # position via einsum); their biases are never applied.
        self.linear1 = nn.Linear(self.transformer_dim, self.transformer_ffn_scale * self.transformer_dim)
        self.linear2 = nn.Linear(self.transformer_ffn_scale * self.transformer_dim, self.transformer_dim)

    def forward(self, x):
        # x: (bsz, depth, transformer_dim)
        bsz = x.shape[0]
        res = x
        x = self.layernorm1(x)
        # Every item attends causally over its own `depth` positions, so the packed
        # varlen attention receives `bsz` sequences of identical length `depth`.
        seqlens = torch.tensor([self.depth] * bsz, dtype=torch.int32, device=x.device)
        _x = self.self_attention(x.view(-1, self.transformer_dim), seqlens)
        _x = _x.view(bsz, self.depth, self.transformer_dim).contiguous()

        _res = _x + res  # (bsz, depth, d)
        res = self.layernorm2(_res)
        # Depth-partitioned FFN: the shared linear1/linear2 weights are reshaped so that each
        # depth position l uses its own (ffn_scale * d / depth)-wide slice of the FFN.
        x = torch.einsum('bld,tld->blt', res,
                         torch.reshape(self.linear1.weight,
                                       (self.transformer_ffn_scale * self.transformer_dim // self.depth, self.depth, self.transformer_dim)))
        x = F.gelu(x)
        x = torch.einsum('blt,dlt->bld', x,
                         torch.reshape(self.linear2.weight,
                                       (self.transformer_dim, self.depth, self.transformer_ffn_scale * self.transformer_dim // self.depth)))
        return _res + x
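

# Hedged reference for the depth-partitioned FFN above (an illustrative sketch, not used by
# the model; the function name is an assumption). It spells out what the two einsums over the
# reshaped linear1/linear2 weights compute: each depth position l gets its own slice of the
# shared FFN weight matrices, and the biases are ignored exactly as in forward().
def _reference_depthwise_ffn(layer: CasualDepthTransformerLayer, res: torch.Tensor) -> torch.Tensor:
    bsz, depth, dim = res.shape
    hidden_per_pos = layer.transformer_ffn_scale * dim // depth
    w1 = layer.linear1.weight   # (ffn_scale * dim, dim)
    w2 = layer.linear2.weight   # (dim, ffn_scale * dim)
    out = torch.empty_like(res)
    for l in range(depth):
        # Rows w1[l::depth] correspond to the 'tld' reshape in forward().
        h = F.gelu(F.linear(res[:, l], w1[l::depth]))                                  # (bsz, hidden_per_pos)
        # Columns w2[:, l*T:(l+1)*T] correspond to the 'dlt' reshape in forward().
        out[:, l] = F.linear(h, w2[:, l * hidden_per_pos:(l + 1) * hidden_per_pos])    # (bsz, dim)
    return out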
    

class CasualDepthTransformerHead(nn.Module):
    """
    Depth-wise causal transformer head shared by image/audio heads.
    """

    def __init__(
        self,
        hidden_size,
        codebook_sizes,
        transformer_layer_num,
        transformer_dim,
        transformer_ffn_scale,
        gradient_checkpointing=False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.codebook_sizes = codebook_sizes
        self.transformer_ffn_scale = transformer_ffn_scale
        self.gradient_checkpointing = gradient_checkpointing

        if self.transformer_ffn_scale > 0:
            self.hidden_norm = RMSNorm(self.hidden_size)
            self.hidden_proj = nn.Linear(self.hidden_size, transformer_dim, bias=False)

        self.transformer_layers = nn.ModuleList(
            [
                CasualDepthTransformerLayer(len(codebook_sizes), transformer_dim, transformer_ffn_scale)
                for _ in range(transformer_layer_num)
            ]
        )
        self.headnorm = RMSNorm(transformer_dim)
        self.heads = nn.ModuleList(
            [nn.Linear(transformer_dim, vq_size + 1) for vq_size in codebook_sizes]
        )

        # All head parameters are kept frozen.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x, visual_tokens, visual_emb_layers, level):
        # Codebook embeddings are looked up on the first GPU and moved back to x's device.
        main_device = "cuda:0"
        visual_tokens = visual_tokens.to(main_device)
        visual_emb_layers = visual_emb_layers.to(main_device)

        # Embed the tokens of all but the last codebook level and accumulate them, so that
        # depth position k conditions on every level shallower than k.
        cumsum_visual_embed = torch.stack([
            visual_emb_layers(visual_tokens[..., i])
            for i in range(len(self.codebook_sizes) - 1)
        ], dim=1).to(x.device)
        cumsum_visual_embed = torch.cumsum(cumsum_visual_embed, dim=1)  # (bs, depth-1, d)

        # Prepend the backbone hidden state as depth position 0.
        hidden_states = torch.concat([x.reshape(-1, 1, self.hidden_size), cumsum_visual_embed], dim=1)  # (bs, depth, d)
        assert hidden_states.size(1) == len(self.codebook_sizes)

        if self.transformer_ffn_scale > 0:
            hidden_states = self.hidden_norm(hidden_states)
            hidden_states = self.hidden_proj(hidden_states)

        for tlayer in self.transformer_layers:
            if self.gradient_checkpointing and self.training:
                # Recompute layer activations in the backward pass to save memory.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(tlayer), hidden_states,
                )
            else:
                hidden_states = tlayer(hidden_states)
        hidden_states = self.headnorm(hidden_states)
        logits = self.heads[level](hidden_states[:, level])
        return logits
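

if __name__ == "__main__":
    # Minimal smoke-test sketch, assuming a CUDA device with flash-attn installed and the
    # hypothetical sizes below (none of these values come from this file).
    device = "cuda:0"  # forward() looks codebook embeddings up on cuda:0
    codebook_sizes = [1024, 1024, 1024, 1024]   # assumed 4-level residual codebooks
    hidden_size = transformer_dim = 512

    head = CasualDepthTransformerHead(
        hidden_size=hidden_size,
        codebook_sizes=codebook_sizes,
        transformer_layer_num=2,
        transformer_dim=transformer_dim,
        transformer_ffn_scale=4,
    ).to(device, dtype=torch.float16)

    bsz = 3
    x = torch.randn(bsz, hidden_size, device=device, dtype=torch.float16)    # backbone hidden states
    visual_tokens = torch.randint(0, min(codebook_sizes), (bsz, len(codebook_sizes)), device=device)
    visual_emb_layers = nn.Embedding(max(codebook_sizes), hidden_size).to(device, dtype=torch.float16)

    logits = head(x, visual_tokens, visual_emb_layers, level=2)
    print(logits.shape)  # expected: (bsz, codebook_sizes[2] + 1)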