| | import math |
| | from types import SimpleNamespace |
| | from typing import Optional, Tuple |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from einops import rearrange, repeat |
| | from torch import Tensor, nn |
| | from transformers import PreTrainedModel |
| |
|
| | try: |
| | import flash_attn |
| | except ImportError: |
| | flash_attn = None |
| |
|
| | try: |
| | import flash_attn_interface |
| | except ImportError: |
| | flash_attn_interface = None |
| | from configuration_dfm import DFMConfig |
| |
|
| |
|
class Rotary(torch.nn.Module):
    """
    Rotary position-embedding table generator.

    Builds (and caches) cos/sin tables of shape ``(1, seq_len, 3, 1, rot_dim)``,
    where ``rot_dim`` is twice the number of inverse frequencies (equal to ``dim``
    when ``dim`` is even). Slot 2 of the size-3 axis is forced to cos=1 / sin=0,
    i.e. an identity rotation (presumably the value slot of a packed q/k/v layout).

    From: https://github.com/louaaron/Score-Entropy-Discrete-Diffusion
    """

    def __init__(self, dim: int, base: int = 10_000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        # Lazily (re)built by forward(); plain attributes, so `.to()` does NOT move them.
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x: Tensor, seq_dim: int = 1) -> Tuple[Tensor, Tensor]:
        """
        Return ``(cos, sin)`` tables for the sequence length of ``x``.

        Args:
            x: tensor whose ``shape[seq_dim]`` is the sequence length; only its
                shape and device are read.
            seq_dim: axis of ``x`` holding the sequence length.

        Returns:
            ``(cos, sin)``, each of shape ``(1, seq_len, 3, 1, rot_dim)`` on
            ``x.device``. The tensors are cached and shared between calls, so
            callers must not modify them in place.
        """
        seq_len = x.shape[seq_dim]
        # A length-only cache key would keep serving tensors built on an old
        # device / dtype after the module or its inputs move (e.g. `.to("cuda")`,
        # `.double()`), so the device and buffer dtype are part of the key too.
        cache_is_stale = (
            self.cos_cached is None
            or seq_len != self.seq_len_cached
            or self.cos_cached.device != x.device
            or self.cos_cached.dtype != self.inv_freq.dtype
        )
        if cache_is_stale:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq.clone())
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            # Layout: (1, seq_len, 3, 1, rot_dim).
            self.cos_cached = emb.cos()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
            self.sin_cached = emb.sin()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)

            # Make slot 2 an identity rotation (cos=1, sin=0).
            self.cos_cached[:, :, 2, :, :].fill_(1.0)
            self.sin_cached[:, :, 2, :, :].fill_(0.0)

        return self.cos_cached, self.sin_cached
| |
|
| |
|
def rotate_half(x: Tensor) -> Tensor:
    """
    Rotate the two halves of the last dimension: ``[a, b] -> [-b, a]``.

    The split point is ``x.shape[-1] // 2``; for an odd size the second half
    gets the extra element.
    """
    split = x.shape[-1] // 2
    first, second = x[..., :split], x[..., split:]
    return torch.cat((second.neg(), first), dim=-1)
| |
|
| |
|
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    Apply rotary position embeddings to ``x`` in plain PyTorch.

    Adapted from:
    https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py#L20

    Args:
        x: tensor to rotate; its last axis holds the channels. The tables are
            broadcast as ``(seq, 1, ro_dim)``, so ``x`` is presumably
            ``(batch, seq, heads, head_dim)`` — confirm at call sites.
        cos, sin: tables in the layout produced by ``Rotary.forward``; only
            ``[0, :, 0, 0, : last_dim // 2]`` is read (the first half of the last
            axis, which ``Rotary`` duplicates into the second half).
        interleaved: if False, channel ``i`` is paired with ``i + ro_dim // 2``
            (GPT-NeoX style); if True, adjacent channels ``(2i, 2i + 1)`` are paired
            (GPT-J style).

    Returns:
        Tensor with the same shape as ``x``: the first ``ro_dim`` channels are
        rotated and any remaining channels are passed through unchanged.
    """
    cos = cos[0, :, 0, 0, : cos.shape[-1] // 2]
    sin = sin[0, :, 0, 0, : sin.shape[-1] // 2]

    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]

    # Expand (seq, d) -> (seq, 1, 2d) so the tables broadcast over the head axis.
    if interleaved:
        cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)  # [c0, c0, c1, c1, ...]
        sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
    else:
        cos = torch.cat((cos, cos), dim=-1).unsqueeze(-2)  # [c0, c1, ..., c0, c1, ...]
        sin = torch.cat((sin, sin), dim=-1).unsqueeze(-2)

    x_rot = x[..., :ro_dim]
    if interleaved:
        # Pairs are adjacent channels: (x0, x1) -> (-x1, x0), ...
        even, odd = x_rot[..., ::2], x_rot[..., 1::2]
        rotated = torch.stack((-odd, even), dim=-1).flatten(-2)
    else:
        # Pairs are (x_i, x_{i + ro_dim/2}): [a, b] -> [-b, a].
        first, second = x_rot.chunk(2, dim=-1)
        rotated = torch.cat((-second, first), dim=-1)

    # Re-attach the channels beyond ro_dim so the output keeps x's shape.
    return torch.cat([x_rot * cos + rotated * sin, x[..., ro_dim:]], dim=-1)
| |
|
| |
|
def bias_dropout_add_scale(
    x: Tensor, scale: Tensor, residual: Optional[Tensor], prob: float, training: bool
) -> Tensor:
    """
    Gated residual update: ``residual + scale * dropout(x)``.

    Args:
        x: branch output to be dropped out and gated.
        scale: gate multiplied onto the dropped-out branch (broadcast against ``x``).
        residual: skip-connection tensor; if ``None`` only the gated branch is
            returned (the signature has always allowed ``None``).
        prob: dropout probability passed to ``F.dropout``.
        training: dropout is only active when True.
    """
    gated = scale * F.dropout(x, p=prob, training=training)
    if residual is None:
        return gated
    return residual + gated
| |
|
| |
|
def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
    """Affine modulation used for adaLN conditioning: ``x * (1 + scale) + shift``."""
    gain = 1 + scale
    return x * gain + shift
| |
|
| |
|
class LayerNorm(nn.Module):
    """
    Bias-free layer normalisation over the last dimension with a learnable scale.

    Normalisation runs in float32 with CUDA autocast disabled, so the output is
    float32 for any floating input dtype. The per-feature ``weight`` broadcasts
    against inputs of any rank whose last dimension is ``dim``.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones([dim]))
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        with torch.amp.autocast("cuda", enabled=False):
            x = F.layer_norm(x.float(), [self.dim])

        # Plain broadcasting: identical to the previous `weight[None, None, :]`
        # for 3-D inputs, but does not inflate 2-D (N, D) inputs to (1, N, D).
        return x * self.weight
| |
|
| |
|
class TimestepEmbedder(nn.Module):
    """
    Map scalar (possibly fractional) timesteps to ``hidden_size``-dim vectors.

    A fixed sinusoidal encoding of width ``frequency_embedding_size`` is fed
    through a two-layer MLP (Linear -> SiLU -> Linear).
    """

    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
        super().__init__()
        # Creation order (and therefore parameter init order) is unchanged.
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(time: Tensor, dim: int, max_period: int = 10000) -> Tensor:
        """
        Sinusoidal encoding of a batch of timesteps.

        :param time: 1-D tensor with one (possibly fractional) value per batch element.
        :param dim: width of the output encoding.
        :param max_period: sets the lowest frequency, ``1 / max_period``.
        :return: float32 tensor of shape ``(N, dim)``: ``[cos(args), sin(args)]``,
            with one trailing zero column when ``dim`` is odd.
        """
        half = dim // 2
        exponents = (
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        )
        freqs = torch.exp(exponents).to(device=time.device)
        args = time[:, None].float() * freqs[None]

        pieces = [torch.cos(args), torch.sin(args)]
        if dim % 2:
            # Pad odd widths with a single zero column.
            pieces.append(torch.zeros_like(args[:, :1]))
        return torch.cat(pieces, dim=-1)

    def forward(self, time: Tensor) -> Tensor:
        """Encode ``time`` sinusoidally, then project it with the MLP."""
        freq_features = self.timestep_embedding(time=time, dim=self.frequency_embedding_size)
        return self.mlp(freq_features)
| |
|
| |
|
class DDiTBlock(nn.Module):
    """
    Transformer block with adaLN conditioning and rotary self-attention.

    Both sub-blocks (attention and MLP) are pre-norm. Each is modulated by a
    shift/scale derived from the conditioning vector ``c``, and each adds back
    to its residual through a ``c``-derived gate.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        cond_dim: int,
        mlp_ratio: int = 4,
        dropout: float = 0.1,
    ):
        super().__init__()
        assert dim % n_heads == 0, "dim must be devisable by n_heads"

        self.n_heads = n_heads
        self.dim = dim
        self.dropout = dropout

        self.head_dim = self.dim // self.n_heads

        self.norm1 = LayerNorm(dim=dim)

        self.qw = nn.Linear(dim, dim, bias=False)
        self.kw = nn.Linear(dim, dim, bias=False)
        self.vw = nn.Linear(dim, dim, bias=False)

        self.attn_out = nn.Linear(dim, dim, bias=False)
        # Not used by forward() (dropout goes through bias_dropout_add_scale);
        # kept for attribute compatibility.
        self.dropout1 = nn.Dropout(dropout)

        self.norm2 = LayerNorm(dim=dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_ratio * dim, dim, bias=True),
        )

        # Zero-initialised, so all shifts/scales/gates start at 0.
        self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
        self.adaLN_modulation.weight.data.zero_()
        self.adaLN_modulation.bias.data.zero_()

    def _attention(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """
        Non-causal multi-head attention on ``(batch, seq, heads, head_dim)`` inputs.

        Returns ``(batch, seq, heads * head_dim)``. It uses FlashAttention (the
        ``flash_attn_interface`` module first, then ``flash_attn``) when one of them
        is importable and the tensors are on CUDA in fp16/bf16. Otherwise it falls
        back to ``F.scaled_dot_product_attention``.
        """
        batch_size = q.shape[0]
        use_flash_attn = (
            (flash_attn_interface is not None or flash_attn is not None)
            and q.is_cuda
            # FlashAttention kernels only support fp16/bf16; fp32 inputs (e.g. this
            # block run outside a bf16 autocast region) must take the SDPA path.
            and q.dtype in (torch.float16, torch.bfloat16)
        )
        if use_flash_attn:
            qkv = torch.stack((q, k, v), dim=2)
            if flash_attn_interface is not None:
                out = flash_attn_interface.flash_attn_qkvpacked_func(qkv, causal=False)
                # NOTE(review): some FlashAttention-3 releases may return
                # (out, softmax_lse); keep only the attention output either way.
                if isinstance(out, tuple):
                    out = out[0]
            else:
                out = flash_attn.flash_attn_qkvpacked_func(qkv, 0.0, causal=False)
            return rearrange(out, "b s h d -> b s (h d)", b=batch_size)

        q, k, v = (item.transpose(1, 2) for item in (q, k, v))
        out = F.scaled_dot_product_attention(query=q, key=k, value=v)
        return rearrange(out, "b h s d -> b s (h d)", b=batch_size)

    def forward(
        self, x: Tensor, rotary_cos_sin: Tuple[Tensor, Tensor], c: Tensor
    ) -> Tensor:
        """
        Args:
            x: hidden states; dims 0 and 1 are read as batch and sequence length
                and the last dim must equal ``dim``.
            rotary_cos_sin: ``(cos, sin)`` tables as returned by ``Rotary.forward``.
            c: conditioning tensor fed to ``adaLN_modulation``; presumably
                ``(batch, cond_dim)`` — confirm at the call site.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch_size, seq_len = x.shape[0], x.shape[1]

        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)

        # --- attention sub-block ---
        x_skip = x
        x = modulate(x=self.norm1(x), shift=shift_msa, scale=scale_msa)

        q = self.qw(x)
        k = self.kw(x)
        v = self.vw(x)

        q, k, v = (
            item.view(batch_size, seq_len, self.n_heads, self.head_dim)
            for item in (q, k, v)
        )

        # Apply the rotary embedding in float32, then restore the working dtype.
        with torch.amp.autocast("cuda", enabled=False):
            cos, sin = rotary_cos_sin
            original_dtype = q.dtype

            q = apply_rotary_emb_torch(
                x=q.float(), cos=cos.float(), sin=sin.float()
            ).to(original_dtype)
            k = apply_rotary_emb_torch(
                x=k.float(), cos=cos.float(), sin=sin.float()
            ).to(original_dtype)

        attn = self._attention(q, k, v)
        x = bias_dropout_add_scale(
            x=self.attn_out(attn),
            scale=gate_msa,
            residual=x_skip,
            prob=self.dropout,
            training=self.training,
        )

        # --- MLP sub-block ---
        x = bias_dropout_add_scale(
            x=self.mlp(modulate(x=self.norm2(x), shift=shift_mlp, scale=scale_mlp)),
            scale=gate_mlp,
            residual=x,
            prob=self.dropout,
            training=self.training,
        )

        return x
| |
|
| |
|
class DDitFinalLayer(nn.Module):
    """
    Output head: adaLN-modulated LayerNorm followed by a projection to ``out_channels``.

    Both linear layers are zero-initialised, so the head outputs zeros until trained.
    """

    def __init__(self, hidden_size: int, out_channels: int, cond_dim: int):
        super().__init__()
        self.norm_final = LayerNorm(hidden_size)
        self.linear = nn.Linear(hidden_size, out_channels)
        self.adaLN_modulation = nn.Linear(cond_dim, 2 * hidden_size, bias=True)

        for layer in (self.linear, self.adaLN_modulation):
            layer.weight.data.zero_()
            layer.bias.data.zero_()

    def forward(self, x: Tensor, c: Tensor) -> Tensor:
        """Modulate the normalised ``x`` with ``c``-derived shift/scale, then project."""
        shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
        normed = self.norm_final(x)
        return self.linear(modulate(x=normed, shift=shift, scale=scale))
| |
|
| |
|
class Transformer(nn.Module):
    """
    DDiT-style transformer mapping token ids plus a time value to per-token logits.

    Args:
        vocab_size: number of token ids in the base vocabulary.
        masked: when True, one extra id is added to both the embedding table and
            the output layer (``vocab_size + 1`` rows / output classes).
        config: dict or attribute namespace providing ``hidden_size``,
            ``cond_dim``, ``n_heads``, ``n_blocks`` and ``dropout``. A dict is
            converted to a ``SimpleNamespace``; other entries are kept on
            ``self.config`` but not read here.
    """

    def __init__(self, vocab_size: int, masked: bool, config):
        super().__init__()

        cfg = SimpleNamespace(**config) if isinstance(config, dict) else config
        self.config = cfg
        self.vocab_size = vocab_size

        num_tokens = vocab_size + (1 if masked else 0)

        # Submodules are created in the original order, which fixes parameter
        # initialisation order and state-dict key order.
        self.vocab_embed = nn.Embedding(num_tokens, cfg.hidden_size)
        self.time_embedding = TimestepEmbedder(hidden_size=cfg.cond_dim)
        self.rotary_emb = Rotary(dim=cfg.hidden_size // cfg.n_heads)
        self.blocks = nn.ModuleList(
            DDiTBlock(
                dim=cfg.hidden_size,
                n_heads=cfg.n_heads,
                cond_dim=cfg.cond_dim,
                dropout=cfg.dropout,
            )
            for _ in range(cfg.n_blocks)
        )
        self.output_layer = DDitFinalLayer(
            hidden_size=cfg.hidden_size,
            out_channels=num_tokens,
            cond_dim=cfg.cond_dim,
        )

    def forward(self, x_t: Tensor, time: Tensor) -> Tensor:
        """
        Args:
            x_t: token ids fed to ``nn.Embedding``; presumably ``(batch, seq)``.
            time: one value per batch element (consumed by ``TimestepEmbedder``).

        Returns:
            Logits with ``vocab_size (+1 if masked)`` classes in the last dim. The
            blocks and output head run inside a CUDA bf16 autocast region.
        """
        hidden = self.vocab_embed(x_t)
        cond = F.silu(self.time_embedding(time=time))

        rotary_cos_sin = self.rotary_emb(x=hidden)

        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            for block in self.blocks:
                hidden = block(x=hidden, rotary_cos_sin=rotary_cos_sin, c=cond)

            logits = self.output_layer(x=hidden, c=cond)

        return logits
| |
|
| |
|
class DFMModel(PreTrainedModel):
    """
    Hugging Face ``PreTrainedModel`` wrapper around ``Transformer``.

    The network is stored as ``self.model`` (matching ``base_model_prefix``),
    so its parameters appear under ``model.*`` state-dict keys.
    """

    config_class = DFMConfig
    base_model_prefix = "model"

    def __init__(self, config: DFMConfig):
        super().__init__(config)
        transformer_config = {
            "hidden_size": config.hidden_size,
            "cond_dim": config.cond_dim,
            "length": config.sequence_length,
            "n_blocks": config.n_blocks,
            "n_heads": config.n_heads,
            "dropout": config.dropout,
            "compile": False,
        }
        self.model = Transformer(
            vocab_size=config.vocab_size,
            # An extra (mask) token id is reserved only for the "mask" source.
            masked=config.source_distribution == "mask",
            config=transformer_config,
        )
        self.post_init()

    def forward(
        self,
        x_t: torch.Tensor,
        time: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Run the wrapped ``Transformer``; any extra keyword arguments are ignored."""
        del kwargs  # accepted but intentionally unused
        return self.model(x_t=x_t, time=time)

    @classmethod
    def _load_pretrained_model(
        cls,
        model,
        state_dict,
        *args,
        **kwargs,
    ):
        """
        Normalise checkpoint keys, then delegate to the ``PreTrainedModel`` loader.

        Two adjustments are made when ``state_dict`` is given:
        - a nested ``{"model": {...}}`` checkpoint is unwrapped;
        - if no key starts with ``"model."``, every key gets that prefix.

        NOTE(review): this only takes effect when the installed ``transformers``
        passes an in-memory state dict as the second positional argument. In
        versions where weights are read from files inside the loader,
        ``state_dict`` may be ``None`` and no remapping happens; confirm against
        the pinned version.
        """
        if state_dict is not None:
            nested = state_dict["model"] if "model" in state_dict else None
            if isinstance(nested, dict):
                state_dict = nested

            lacks_prefix = bool(state_dict) and all(
                not key.startswith("model.") for key in state_dict.keys()
            )
            if lacks_prefix:
                state_dict = {f"model.{key}": value for key, value in state_dict.items()}

        return super()._load_pretrained_model(
            model,
            state_dict,
            *args,
            **kwargs,
        )
| |
|