# File size: 2,621 Bytes
# 8019be0
import torch
import torch.nn.functional as F
from . import rotary
from .transformer import EmbeddingLayer, TimestepEmbedder, DDiTBlock, DDitFinalLayer
from omegaconf import OmegaConf
from torch.nn.attention.flex_attention import create_block_mask
def _dense_mask(b, h, q_idx, kv_idx):
return torch.full_like(q_idx, True, dtype=torch.bool)
class DDiTNoLengthModel(torch.nn.Module):
    """
    A DDiT-style model that predicts only per-token outputs, without any
    sequence-length head (the vanilla MDM variant).

    The network is: token embedding -> stack of ``DDiTBlock`` layers
    conditioned on a timestep embedding -> ``DDitFinalLayer`` over the
    vocabulary.
    """

    def __init__(self, config):
        """
        Build the model from ``config``.

        config: object exposing ``interpolant.tokens``,
            ``interpolant.pad_token``, ``interpolant.mask_token`` and
            ``model.hidden_size``, ``model.cond_dim``, ``model.n_heads``,
            ``model.n_blocks``, ``model.dropout``. A plain ``dict`` is
            converted with ``OmegaConf.create`` first.
        """
        super().__init__()

        # allowing dict configs too
        if isinstance(config, dict):
            config = OmegaConf.create(config)

        self.config = config
        self.vocab_size = config.interpolant.tokens
        self.pad_token = config.interpolant.pad_token
        self.mask_token = config.interpolant.mask_token

        self.vocab_embed = EmbeddingLayer(config.model.hidden_size, self.vocab_size)
        self.sigma_map = TimestepEmbedder(config.model.cond_dim)
        # Rotary embedding is sized per attention head.
        self.rotary_emb = rotary.Rotary(
            config.model.hidden_size // config.model.n_heads
        )

        self.blocks = torch.nn.ModuleList(
            [
                DDiTBlock(
                    config.model.hidden_size,
                    config.model.n_heads,
                    config.model.cond_dim,
                    dropout=config.model.dropout,
                )
                for _ in range(config.model.n_blocks)
            ]
        )

        # final per-token head only / no length head
        self.output_layer = DDitFinalLayer(
            config.model.hidden_size, self.vocab_size, config.model.cond_dim
        )

    def forward(self, indices: torch.Tensor, t: torch.Tensor):
        """
        indices: (B, L) token indices
        t: (B,) timestep scalars
        returns: the tensor produced by ``self.output_layer`` for every
            position (the per-token head output); no length prediction is
            made.
        """
        batch_size, seq_len = indices.shape

        # Dense (all-True) attention mask, built on the same device as the
        # input rather than on the library's default device.
        block_mask = create_block_mask(
            _dense_mask,
            B=batch_size,
            H=None,
            Q_LEN=seq_len,
            KV_LEN=seq_len,
            device=indices.device,
        )

        x = self.vocab_embed(indices)  # (B, L, hidden)
        c = F.silu(self.sigma_map(t))  # (B, cond_dim)
        rotary_cos_sin = self.rotary_emb(x)  # precompute rotary embeddings

        # run the stack
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            for block in self.blocks:
                x = block(x, rotary_cos_sin, c, block_mask)

            # NOTE(review): the original indentation of this statement was
            # not recoverable; it is placed inside the autocast region here.
            # Confirm against the intended layout.
            token_logits = self.output_layer(x, c)

        return token_logits
|