import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
from omegaconf import OmegaConf

from .fused_add_dropout_scale import (
    modulate_fused,
    bias_dropout_add_scale_fused_train,
    bias_dropout_add_scale_fused_inference,
)
from .transformer import LayerNorm, EmbeddingLayer
from . import rotary
class CausalDiTBlock(nn.Module):
    """Transformer block with causal flash attention, rotary embeddings,
    and fused bias-dropout-scale residual connections."""

    def __init__(self, dim, n_heads, cond_dim, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads

        self.norm1 = LayerNorm(dim)
        self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.attn_out = nn.Linear(dim, dim, bias=False)
        self.dropout1 = nn.Dropout(dropout)

        self.norm2 = LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_ratio * dim, dim, bias=True),
        )
        self.dropout2 = nn.Dropout(dropout)
        self.dropout = dropout
        # No time or label conditioning, so no adaLN_modulation.

    def _get_bias_dropout_scale(self):
        # Pick the fused residual kernel matching the current train/eval mode.
        return (
            bias_dropout_add_scale_fused_train
            if self.training
            else bias_dropout_add_scale_fused_inference
        )
    def forward(self, x, rotary_cos_sin, seqlens=None):
        batch_size, seq_len = x.shape[0], x.shape[1]
        bias_dropout_scale_fn = self._get_bias_dropout_scale()

        # Attention operation.
        x_skip = x
        x = self.norm1(x)

        qkv = self.attn_qkv(x)
        qkv = rearrange(
            qkv, 'b s (three h d) -> b s three h d',
            three=3, h=self.n_heads,
        )
        # Rotary embeddings are applied with autocast disabled.
        with torch.cuda.amp.autocast(enabled=False):
            cos, sin = rotary_cos_sin
            qkv = rotary.apply_rotary_pos_emb(
                qkv, cos.to(qkv.dtype), sin.to(qkv.dtype)
            )
        # Flash attention's varlen interface takes packed (total_tokens, 3, h, d)
        # input plus int32 cumulative sequence offsets starting at 0.
        qkv = rearrange(qkv, 'b s ... -> (b s) ...')
        if seqlens is None:
            cu_seqlens = torch.arange(
                0, (batch_size + 1) * seq_len, step=seq_len,
                dtype=torch.int32, device=qkv.device,
            )
        else:
            cu_seqlens = seqlens.cumsum(-1)
        x = flash_attn_varlen_qkvpacked_func(
            qkv, cu_seqlens, seq_len, 0., causal=True)
        x = rearrange(x, '(b s) h d -> b s (h d)', b=batch_size)

        scale = torch.ones(1, device=x.device, dtype=x.dtype)
        x = bias_dropout_scale_fn(
            self.attn_out(x), None, scale, x_skip, self.dropout)

        # MLP operation.
        x = bias_dropout_scale_fn(
            self.mlp(self.norm2(x)), None, scale, x, self.dropout
        )
        return x
class CausalDiT(nn.Module):
    """Causal (autoregressive) DiT: token embedding, a stack of
    CausalDiTBlocks, and a linear head producing vocabulary logits."""

    def __init__(self, config):
        super().__init__()
        if isinstance(config, dict):
            config = OmegaConf.create(config)
        self.config = config

        self.vocab_size = config.interpolant.tokens
        self.pad_token = config.interpolant.pad_token

        self.vocab_embed = EmbeddingLayer(config.model.hidden_size, self.vocab_size)
        self.rotary_emb = rotary.Rotary(
            config.model.hidden_size // config.model.n_heads)
        self.blocks = nn.ModuleList([
            CausalDiTBlock(
                config.model.hidden_size,
                config.model.n_heads,
                config.model.cond_dim,
                dropout=config.model.dropout,
            )
            for _ in range(config.model.n_blocks)
        ])
        self.output_layer = nn.Linear(config.model.hidden_size, self.vocab_size)

    def forward(self, indices):
        x = self.vocab_embed(indices)
        rotary_cos_sin = self.rotary_emb(x)
        # Blocks and output head run under bfloat16 autocast.
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            for block in self.blocks:
                x = block(x, rotary_cos_sin, seqlens=None)
            logits = self.output_layer(x)
        return logits
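

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. The config keys
    # below (interpolant.tokens, interpolant.pad_token, model.*) mirror the
    # attributes read in CausalDiT.__init__, but the concrete values are
    # illustrative assumptions. Requires a CUDA device, since the blocks rely on
    # flash-attn's varlen kernel and bfloat16 autocast.
    example_config = {
        "interpolant": {"tokens": 1024, "pad_token": 0},
        "model": {
            "hidden_size": 512,
            "n_heads": 8,
            "cond_dim": 128,
            "dropout": 0.1,
            "n_blocks": 4,
        },
    }
    model = CausalDiT(example_config).cuda()
    indices = torch.randint(0, 1024, (2, 128), device="cuda")  # (batch, seq_len)
    logits = model(indices)
    print(logits.shape)  # expected: (2, 128, 1024)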