| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from einops import rearrange |
| from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func |
| from .fused_add_dropout_scale import modulate_fused, bias_dropout_add_scale_fused_train, bias_dropout_add_scale_fused_inference |
| from .transformer import LayerNorm, EmbeddingLayer |
| from . import rotary |
|
|
|
|
| class CausalDiTBlock(nn.Module): |
| def __init__(self, dim, n_heads, cond_dim, mlp_ratio=4, dropout=0.1): |
| super().__init__() |
| self.n_heads = n_heads |
| self.norm1 = LayerNorm(dim) |
| self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False) |
| self.attn_out = nn.Linear(dim, dim, bias=False) |
| self.dropout1 = nn.Dropout(dropout) |
| self.norm2 = LayerNorm(dim) |
| self.mlp = nn.Sequential( |
| nn.Linear(dim, mlp_ratio * dim, bias=True), |
| nn.GELU(approximate="tanh"), |
| nn.Linear(mlp_ratio * dim, dim, bias=True) |
| ) |
| self.dropout2 = nn.Dropout(dropout) |
| self.dropout = dropout |
| |
|
|
| def _get_bias_dropout_scale(self): |
| return ( |
| bias_dropout_add_scale_fused_train |
| if self.training |
| else bias_dropout_add_scale_fused_inference |
| ) |
|
|
| def forward(self, x, rotary_cos_sin, seqlens=None): |
| batch_size, seq_len = x.shape[0], x.shape[1] |
| bias_dropout_scale_fn = self._get_bias_dropout_scale() |
|
|
| |
| x_skip = x |
| x = self.norm1(x) |
| |
|
|
| qkv = self.attn_qkv(x) |
| qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.n_heads) |
| with torch.cuda.amp.autocast(enabled=False): |
| cos, sin = rotary_cos_sin |
| qkv = rotary.apply_rotary_pos_emb( |
| qkv, cos.to(qkv.dtype), sin.to(qkv.dtype) |
| ) |
| qkv = rearrange(qkv, 'b s ... -> (b s) ...') |
| if seqlens is None: |
| cu_seqlens = torch.arange( |
| 0, (batch_size + 1) * seq_len, step=seq_len, |
| dtype=torch.int32, device=qkv.device |
| ) |
| else: |
| cu_seqlens = seqlens.cumsum(-1) |
| x = flash_attn_varlen_qkvpacked_func( |
| qkv, cu_seqlens, seq_len, 0., causal=True) |
| x = rearrange(x, '(b s) h d -> b s (h d)', b=batch_size) |
| |
| scale = torch.ones(1, device=x.device, dtype=x.dtype) |
| x = bias_dropout_scale_fn(self.attn_out(x), None, scale, x_skip, self.dropout) |
|
|
| |
| x = bias_dropout_scale_fn( |
| self.mlp(self.norm2(x)), None, scale, x, self.dropout |
| ) |
| return x |
|
|
| class CausalDiT(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
| if isinstance(config, dict): |
| config = OmegaConf.create(config) |
|
|
| self.config = config |
| self.vocab_size = config.interpolant.tokens |
| self.pad_token = config.interpolant.pad_token |
|
|
| self.vocab_embed = EmbeddingLayer(config.model.hidden_size, self.vocab_size) |
| self.rotary_emb = rotary.Rotary(config.model.hidden_size // config.model.n_heads) |
| self.blocks = nn.ModuleList([ |
| CausalDiTBlock(config.model.hidden_size, config.model.n_heads, config.model.cond_dim, dropout=config.model.dropout) |
| for _ in range(config.model.n_blocks) |
| ]) |
| self.output_layer = nn.Linear(config.model.hidden_size, self.vocab_size) |
|
|
| def forward(self, indices): |
| x = self.vocab_embed(indices) |
| rotary_cos_sin = self.rotary_emb(x) |
| with torch.amp.autocast('cuda', dtype=torch.bfloat16): |
| for block in self.blocks: |
| x = block(x, rotary_cos_sin, seqlens=None) |
| logits = self.output_layer(x) |
| return logits |