A2D2 / model /MDM_transformer.py
Sophia
initial commit
8019be0
Raw
History Blame Contribute Delete
2.62 kB
import torch
import torch.nn.functional as F
from . import rotary
from .transformer import EmbeddingLayer, TimestepEmbedder, DDiTBlock, DDitFinalLayer
from omegaconf import OmegaConf
from torch.nn.attention.flex_attention import create_block_mask
def _dense_mask(b, h, q_idx, kv_idx):
return torch.full_like(q_idx, True, dtype=torch.bool)
class DDiTNoLengthModel(torch.nn.Module):
    """
    A DDiT-style model that predicts only per-token outputs, without any
    sequence-length head; intended for the vanilla masked diffusion model (MDM).

    Args:
        config: An OmegaConf config; a plain ``dict`` is also accepted and is
            converted with ``OmegaConf.create``. Reads
            ``config.interpolant.{tokens, pad_token, mask_token}`` and
            ``config.model.{hidden_size, n_heads, cond_dim, dropout, n_blocks}``.
    """

    def __init__(self, config):
        super().__init__()
        # Allow plain dict configs too; OmegaConf provides attribute access.
        if isinstance(config, dict):
            config = OmegaConf.create(config)
        self.config = config

        self.vocab_size = config.interpolant.tokens
        # Stored but not read by this class' forward; presumably used by
        # callers (e.g. the sampler/loss) — verify before removing.
        self.pad_token = config.interpolant.pad_token
        self.mask_token = config.interpolant.mask_token

        model_cfg = config.model
        self.vocab_embed = EmbeddingLayer(model_cfg.hidden_size, self.vocab_size)
        self.sigma_map = TimestepEmbedder(model_cfg.cond_dim)
        # Rotary embeddings are applied per attention head, hence head dim.
        self.rotary_emb = rotary.Rotary(model_cfg.hidden_size // model_cfg.n_heads)
        self.blocks = torch.nn.ModuleList(
            [
                DDiTBlock(
                    model_cfg.hidden_size,
                    model_cfg.n_heads,
                    model_cfg.cond_dim,
                    dropout=model_cfg.dropout,
                )
                for _ in range(model_cfg.n_blocks)
            ]
        )
        # Final per-token head only; no length head.
        self.output_layer = DDitFinalLayer(
            model_cfg.hidden_size, self.vocab_size, model_cfg.cond_dim
        )

    def forward(self, indices: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Run the transformer stack and return per-token logits.

        Args:
            indices: (B, L) token indices.
            t: (B,) timestep scalars.

        Returns:
            The tensor produced by ``self.output_layer`` — per-token logits
            over the vocabulary (presumably (B, L, vocab_size); confirm
            against ``DDitFinalLayer``).
        """
        batch_size, seq_len = indices.shape
        # Fully bidirectional attention: every query may attend to every key.
        # Build the mask on the input's device rather than relying on
        # create_block_mask's default of "cuda" (the current CUDA device).
        block_mask = create_block_mask(
            _dense_mask,
            B=batch_size,
            H=None,
            Q_LEN=seq_len,
            KV_LEN=seq_len,
            device=indices.device,
        )

        x = self.vocab_embed(indices)
        c = F.silu(self.sigma_map(t))
        rotary_cos_sin = self.rotary_emb(x)  # precomputed once for all blocks

        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            for block in self.blocks:
                x = block(x, rotary_cos_sin, c, block_mask)
            # Kept inside autocast so the bf16 activations from the blocks
            # meet the head's weights under the same autocast policy.
            token_logits = self.output_layer(x, c)
        return token_logits