LongCat-Next / modular_longcat_next.py
Locke's picture
code
3e934f1
import torch
import torch.nn.functional as F
from torch import nn
from flash_attn import flash_attn_varlen_func
from transformers.models.t5.modeling_t5 import T5LayerNorm as RMSNorm
class FlashVarLenAttention(nn.Module):
def __init__(self, embed_dim, num_heads, causal=False, window_size=(-1,-1)):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.causal = causal
self.window_size = window_size
def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor):
bsz, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
query_states = query_states.view(bsz, self.num_heads, self.head_dim).contiguous()
key_states = self.k_proj(hidden_states)
key_states = key_states.view(bsz, self.num_heads, self.head_dim).contiguous()
value_states = self.v_proj(hidden_states)
value_states = value_states.view(bsz, self.num_heads, self.head_dim).contiguous()
cu_len = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(torch.int32)
max_seqlen = torch.max(seq_len).to(torch.int32).detach()
attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_len, cu_len, max_seqlen,
max_seqlen, causal=self.causal, window_size=self.window_size) # (bsz * qlen, nheads, headdim)
attn_output = attn_output.reshape(bsz, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output
class CasualDepthTransformerLayer(nn.Module):
def __init__(self, depth, transformer_dim, transformer_ffn_scale):
super().__init__()
self.depth = depth
self.transformer_dim = transformer_dim
self.transformer_ffn_scale = transformer_ffn_scale
self.num_heads = self.transformer_dim // 128
assert self.transformer_dim % 128 == 0
assert self.transformer_dim % depth == 0
self.self_attention = FlashVarLenAttention(embed_dim=self.transformer_dim, num_heads=self.num_heads, causal=True)
self.layernorm1 = RMSNorm(self.transformer_dim)
self.layernorm2 = RMSNorm(self.transformer_dim)
self.linear1 = nn.Linear(self.transformer_dim, self.transformer_ffn_scale * self.transformer_dim)
self.linear2 = nn.Linear(self.transformer_ffn_scale * self.transformer_dim, self.transformer_dim)
def forward(self, x):
bsz = x.shape[0]
res = x
x = self.layernorm1(x)
seqlens = torch.tensor([self.depth] * bsz, dtype=torch.int32, device=x.device)
_x = self.self_attention(x.view(-1, self.transformer_dim), seqlens)
_x = _x.view(bsz, self.depth, self.transformer_dim).contiguous()
_res = _x + res # (bs, sl, d)
res = self.layernorm2(_res)
x = torch.einsum('bld,tld->blt', res, torch.reshape(self.linear1.weight, (self.transformer_ffn_scale * self.transformer_dim // self.depth, self.depth, self.transformer_dim)))
x = torch.nn.functional.gelu(x)
x = torch.einsum('blt,dlt->bld',x, torch.reshape(self.linear2.weight, (self.transformer_dim, self.depth, self.transformer_ffn_scale * self.transformer_dim // self.depth)))
return _res + x
class CasualDepthTransformerHead(nn.Module):
"""
Depth-wise causal transformer head shared by image/audio heads.
"""
def __init__(
self,
hidden_size,
codebook_sizes,
transformer_layer_num,
transformer_dim,
transformer_ffn_scale,
gradient_checkpointing=False,
):
super().__init__()
self.hidden_size = hidden_size
self.codebook_sizes = codebook_sizes
self.transformer_ffn_scale = transformer_ffn_scale
self.gradient_checkpointing = gradient_checkpointing
if self.transformer_ffn_scale > 0:
self.hidden_norm = RMSNorm(self.hidden_size)
self.hidden_proj = nn.Linear(self.hidden_size, transformer_dim, bias=False)
self.transformer_layers = nn.ModuleList(
[
CasualDepthTransformerLayer(len(codebook_sizes), transformer_dim, transformer_ffn_scale)
for _ in range(transformer_layer_num)
]
)
self.headnorm = RMSNorm(transformer_dim)
self.heads = nn.ModuleList(
[nn.Linear(transformer_dim, vq_size + 1) for vq_size in codebook_sizes]
)
for param in self.parameters():
param.requires_grad = False
def forward(self, x, visual_tokens, visual_emb_layers, level):
main_device = "cuda:0"
visual_tokens = visual_tokens.to(main_device)
visual_emb_layers = visual_emb_layers.to(main_device)
cumsum_visual_embed = torch.stack([
visual_emb_layers(visual_tokens[..., i])
for i, vq_size in enumerate(self.codebook_sizes[:-1])
], dim=1).to(x.device)
cumsum_visual_embed = torch.cumsum(cumsum_visual_embed, dim=1) # (bs, depth-1, d)
hidden_states = torch.concat([x.reshape(-1, 1, self.hidden_size), cumsum_visual_embed], dim=1) # (bs, depth, d)
assert hidden_states.size(1) == len(self.codebook_sizes)
if self.transformer_ffn_scale > 0:
hidden_states = self.hidden_norm(hidden_states)
hidden_states = self.hidden_proj(hidden_states)
for i, tlayer in enumerate(self.transformer_layers):
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(tlayer), hidden_states,
)
else:
hidden_states = tlayer(
hidden_states,
)
hidden_states = self.headnorm(hidden_states)
logits = self.heads[level](hidden_states[:, level])
return logits