| import torch |
| import torch.nn.functional as F |
| from . import rotary |
| from .transformer import EmbeddingLayer, TimestepEmbedder, DDiTBlock, DDitFinalLayer |
| from omegaconf import OmegaConf |
| from torch.nn.attention.flex_attention import create_block_mask |
|
|
|
|
| def _dense_mask(b, h, q_idx, kv_idx): |
| return torch.full_like(q_idx, True, dtype=torch.bool) |
|
|
|
|
class DDiTNoLengthModel(torch.nn.Module):
    """DDiT-style denoiser for vanilla masked diffusion (MDM).

    The model produces only per-token outputs; it has no sequence-length head.

    Args:
        config: An OmegaConf config, or a plain ``dict`` that is converted with
            ``OmegaConf.create``. Keys read here:
            ``interpolant.tokens``, ``interpolant.pad_token``,
            ``interpolant.mask_token``, ``model.hidden_size``,
            ``model.n_heads``, ``model.cond_dim``, ``model.dropout`` and
            ``model.n_blocks``.
    """

    def __init__(self, config):
        super().__init__()

        if isinstance(config, dict):
            config = OmegaConf.create(config)

        self.config = config
        self.vocab_size = config.interpolant.tokens
        # pad/mask token ids are stored for use outside this class; the
        # visible code here does not read them.
        self.pad_token = config.interpolant.pad_token
        self.mask_token = config.interpolant.mask_token

        model_cfg = config.model

        self.vocab_embed = EmbeddingLayer(model_cfg.hidden_size, self.vocab_size)
        self.sigma_map = TimestepEmbedder(model_cfg.cond_dim)
        # Rotary embeddings are sized to one attention head's width.
        self.rotary_emb = rotary.Rotary(model_cfg.hidden_size // model_cfg.n_heads)

        self.blocks = torch.nn.ModuleList(
            [
                DDiTBlock(
                    model_cfg.hidden_size,
                    model_cfg.n_heads,
                    model_cfg.cond_dim,
                    dropout=model_cfg.dropout,
                )
                for _ in range(model_cfg.n_blocks)
            ]
        )

        self.output_layer = DDitFinalLayer(
            model_cfg.hidden_size, self.vocab_size, model_cfg.cond_dim
        )

    def forward(self, indices: torch.Tensor, t: torch.Tensor):
        """Run the denoiser on a batch of token sequences.

        Args:
            indices: ``(B, L)`` token indices.
            t: ``(B,)`` timestep scalars, embedded by ``self.sigma_map`` to
                condition every block and the final layer.

        Returns:
            The per-token logits produced by ``self.output_layer(x, c)``,
            returned directly (no wrapper object). Presumably shaped
            ``(B, L, vocab_size)`` — TODO confirm against ``DDitFinalLayer``.
        """
        batch_size, seq_len = indices.shape

        # Full attention: ``_dense_mask`` admits every (query, key) pair.
        # Build the mask on the inputs' device instead of create_block_mask's
        # default device ("cuda").
        block_mask = create_block_mask(
            _dense_mask,
            B=batch_size,
            H=None,
            Q_LEN=seq_len,
            KV_LEN=seq_len,
            device=indices.device,
        )

        x = self.vocab_embed(indices)
        c = F.silu(self.sigma_map(t))
        rotary_cos_sin = self.rotary_emb(x)

        # Only the transformer blocks run under bf16 autocast; the final
        # projection below is executed outside the autocast region.
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            for block in self.blocks:
                x = block(x, rotary_cos_sin, c, block_mask)

        return self.output_layer(x, c)
|
|