import torch
import torch.utils.checkpoint
| import torch.nn.functional as F |
| from torch import nn |
|
|
| from flash_attn import flash_attn_varlen_func |
|
|
# T5LayerNorm applies RMS normalization (no mean subtraction, no bias term), so it is
# reused here as an RMSNorm implementation.
from transformers.models.t5.modeling_t5 import T5LayerNorm as RMSNorm
|
|
|
|
| class FlashVarLenAttention(nn.Module): |
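    """Multi-head self-attention over packed variable-length sequences.

    Wraps FlashAttention's varlen kernel, so the input is expected as a packed
    (total_tokens, embed_dim) tensor together with the per-sequence lengths.
    """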
| def __init__(self, embed_dim, num_heads, causal=False, window_size=(-1,-1)): |
| super().__init__() |
| self.embed_dim = embed_dim |
| self.num_heads = num_heads |
| self.head_dim = embed_dim // num_heads |
|
|
| self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False) |
| self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True) |
| self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True) |
| self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) |
|
|
        self.causal = causal
        # window_size=(-1, -1) means full attention (no sliding window) in flash-attn.
        self.window_size = window_size
|
|
    def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor):
        # hidden_states is packed as (total_tokens, embed_dim): the sequences are
        # concatenated along the first dimension and seq_len holds the length of
        # each sequence (total_tokens == seq_len.sum()).
        total_tokens, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        query_states = query_states.view(total_tokens, self.num_heads, self.head_dim).contiguous()
        key_states = self.k_proj(hidden_states)
        key_states = key_states.view(total_tokens, self.num_heads, self.head_dim).contiguous()
        value_states = self.v_proj(hidden_states)
        value_states = value_states.view(total_tokens, self.num_heads, self.head_dim).contiguous()

        # Cumulative sequence lengths [0, len_0, len_0 + len_1, ...] as int32, plus the
        # maximum sequence length as a Python int, as expected by the varlen kernel.
        cu_len = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(torch.int32)
        max_seqlen = int(seq_len.max())

        attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_len, cu_len,
                                             max_seqlen, max_seqlen, causal=self.causal, window_size=self.window_size)
        attn_output = attn_output.reshape(total_tokens, self.embed_dim)
        attn_output = self.out_proj(attn_output)
        return attn_output
|
|
|
|
|
|
| class CasualDepthTransformerLayer(nn.Module): |
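    """One layer of the depth (per-codebook-level) transformer.

    Causal self-attention runs over the `depth` codebook levels of each batch element,
    followed by a feed-forward block whose weights are split so that every depth
    position uses its own slice (see forward()).
    """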
| def __init__(self, depth, transformer_dim, transformer_ffn_scale): |
| super().__init__() |
| self.depth = depth |
| self.transformer_dim = transformer_dim |
| self.transformer_ffn_scale = transformer_ffn_scale |
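        # The attention head dimension is fixed at 128, so transformer_dim must be a
        # multiple of 128 (asserted below).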
| self.num_heads = self.transformer_dim // 128 |
|
|
| assert self.transformer_dim % 128 == 0 |
| assert self.transformer_dim % depth == 0 |
|
|
| self.self_attention = FlashVarLenAttention(embed_dim=self.transformer_dim, num_heads=self.num_heads, causal=True) |
|
|
| self.layernorm1 = RMSNorm(self.transformer_dim) |
| self.layernorm2 = RMSNorm(self.transformer_dim) |
| |
        # These weights are reshaped in forward() so that every depth position uses its
        # own slice of the FFN hidden dimension; the biases are never applied there.
        self.linear1 = nn.Linear(self.transformer_dim, self.transformer_ffn_scale * self.transformer_dim)
        self.linear2 = nn.Linear(self.transformer_ffn_scale * self.transformer_dim, self.transformer_dim)
|
|
    def forward(self, x):
        # x: (batch, depth, transformer_dim)
        bsz = x.shape[0]
        res = x
        x = self.layernorm1(x)
        # Each batch element is an independent packed sequence of `depth` tokens for
        # the varlen attention, so every sequence length equals `depth`.
        seqlens = torch.tensor([self.depth] * bsz, dtype=torch.int32, device=x.device)
        _x = self.self_attention(x.view(-1, self.transformer_dim), seqlens)
        _x = _x.view(bsz, self.depth, self.transformer_dim).contiguous()

        _res = _x + res
        res = self.layernorm2(_res)
        # Position-wise FFN with a separate weight slice per depth position: the
        # linear1/linear2 weights are reshaped so that depth position l only uses its
        # own (transformer_ffn_scale * transformer_dim // depth)-wide hidden slice.
        # Note that the linear biases are not used here.
        x = torch.einsum(
            'bld,tld->blt', res,
            torch.reshape(self.linear1.weight,
                          (self.transformer_ffn_scale * self.transformer_dim // self.depth, self.depth, self.transformer_dim)))
        x = F.gelu(x)
        x = torch.einsum(
            'blt,dlt->bld', x,
            torch.reshape(self.linear2.weight,
                          (self.transformer_dim, self.depth, self.transformer_ffn_scale * self.transformer_dim // self.depth)))
        return _res + x
| |
|
|
| class CasualDepthTransformerHead(nn.Module): |
| """ |
    Depth-wise causal transformer head shared by the image/audio heads.

    Given a backbone hidden state and the codebook tokens already generated for the
    earlier levels, it runs a small causal transformer over the `depth` levels and
    returns the logits for the requested level.
| """ |
|
|
| def __init__( |
| self, |
| hidden_size, |
| codebook_sizes, |
| transformer_layer_num, |
| transformer_dim, |
| transformer_ffn_scale, |
| gradient_checkpointing=False, |
| ): |
| super().__init__() |
| self.hidden_size = hidden_size |
| self.codebook_sizes = codebook_sizes |
| self.transformer_ffn_scale = transformer_ffn_scale |
| self.gradient_checkpointing = gradient_checkpointing |
|
|
| if self.transformer_ffn_scale > 0: |
| self.hidden_norm = RMSNorm(self.hidden_size) |
| self.hidden_proj = nn.Linear(self.hidden_size, transformer_dim, bias=False) |
|
|
| self.transformer_layers = nn.ModuleList( |
| [ |
| CasualDepthTransformerLayer(len(codebook_sizes), transformer_dim, transformer_ffn_scale) |
| for _ in range(transformer_layer_num) |
| ] |
| ) |
| self.headnorm = RMSNorm(transformer_dim) |
| self.heads = nn.ModuleList( |
| [nn.Linear(transformer_dim, vq_size + 1) for vq_size in codebook_sizes] |
| ) |
|
|
        # All parameters of this head are frozen at construction time.
        for param in self.parameters():
            param.requires_grad = False
|
|
| def forward(self, x, visual_tokens, visual_emb_layers, level): |
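        # x: backbone hidden states; visual_tokens: already generated codebook tokens
        # (the last dimension indexes the level); level: index of the codebook level
        # whose logits are returned. The embedding lookup runs on the first CUDA
        # device and the result is moved back to x.device further below.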
| main_device = "cuda:0" |
| visual_tokens = visual_tokens.to(main_device) |
| visual_emb_layers = visual_emb_layers.to(main_device) |
|
|
        # Embed every preceding codebook level; the cumulative sum below makes depth
        # position l condition on the summed embeddings of levels < l.
        cumsum_visual_embed = torch.stack([
            visual_emb_layers(visual_tokens[..., i])
            for i in range(len(self.codebook_sizes) - 1)
        ], dim=1).to(x.device)
|
|
| cumsum_visual_embed = torch.cumsum(cumsum_visual_embed, dim=1) |
|
|
        # Prepend the backbone hidden state so the sequence has one position per codebook level.
        hidden_states = torch.cat([x.reshape(-1, 1, self.hidden_size), cumsum_visual_embed], dim=1)
| assert hidden_states.size(1) == len(self.codebook_sizes) |
|
|
| if self.transformer_ffn_scale > 0: |
| hidden_states = self.hidden_norm(hidden_states) |
| hidden_states = self.hidden_proj(hidden_states) |
|
|
        # Optionally recompute each layer's activations during backward to save memory.
        for tlayer in self.transformer_layers:
            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(tlayer), hidden_states,
                )
            else:
                hidden_states = tlayer(hidden_states)
| hidden_states = self.headnorm(hidden_states) |
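        # Only the logits of the requested codebook level are computed and returned.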
| logits = self.heads[level](hidden_states[:, level]) |
| return logits |
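

# Minimal usage sketch for the depth transformer layer: the dimensions below are
# illustrative only, and a CUDA device with flash-attn installed plus fp16 weights
# are assumed, since the FlashAttention kernel requires fp16/bf16 inputs.
if __name__ == "__main__":
    layer = (
        CasualDepthTransformerLayer(depth=8, transformer_dim=1024, transformer_ffn_scale=4)
        .cuda()
        .half()
    )
    # (batch, depth, transformer_dim) input, matching the layer's expected shape.
    x = torch.randn(2, 8, 1024, device="cuda", dtype=torch.float16)
    with torch.no_grad():
        y = layer(x)
    print(y.shape)  # torch.Size([2, 8, 1024])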
|
|