|
|
import math |
|
|
import typing |
|
|
|
|
|
import einops |
|
|
from functools import partial |
|
|
import huggingface_hub |
|
|
import omegaconf |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.nn.attention.flex_attention import flex_attention, create_block_mask |
|
|
import transformers |
|
|
from functools import lru_cache |
|
|
from .config import EsoLMConfig |
|
|
|
|
|
# --- Global PyTorch runtime configuration (module import side effects) ---
# Allow TF32 matmuls on Ampere+ GPUs: large speedup, negligible precision loss.
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision("high")
# Let cuDNN autotune kernels; beneficial because input shapes are static here.
torch.backends.cudnn.benchmark = True

import torch._inductor.config as inductor_cfg

# Inductor: capture CUDA graphs and enable coordinate-descent autotuning for
# the torch.compile'd flex_attention kernel defined below.
inductor_cfg.triton.cudagraphs = True
inductor_cfg.coordinate_descent_tuning = True

# Disable the TorchScript profiling executor and force the legacy fuser on,
# so the @torch.jit.script helpers below fuse eagerly on CPU and GPU.
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
|
|
|
|
|
|
|
|
@lru_cache |
|
|
def _causal_mask(b, h, q_idx, kv_idx): |
|
|
causal = q_idx >= kv_idx |
|
|
return causal |
|
|
|
|
|
|
|
|
@lru_cache
def _get_causal_mask(seq_len):
  """Build (and memoize per seq_len) a square causal BlockMask."""
  return create_block_mask(
      _causal_mask,
      B=None,
      H=None,
      Q_LEN=seq_len,
      KV_LEN=seq_len)
|
|
|
|
|
|
|
|
@lru_cache |
|
|
def _bidirectional_mask(b, h, q_idx, kv_idx): |
|
|
bidirectional = q_idx == q_idx |
|
|
return bidirectional |
|
|
|
|
|
|
|
|
@lru_cache
def _get_bidirectional_mask(seq_len):
  """Build (and memoize per seq_len) an all-visible square BlockMask."""
  return create_block_mask(
      _bidirectional_mask,
      B=None,
      H=None,
      Q_LEN=seq_len,
      KV_LEN=seq_len)
|
|
|
|
|
|
|
|
@lru_cache |
|
|
def _mixed_mask(b, h, q_idx, kv_idx, cutoffs): |
|
|
causal = q_idx >= kv_idx |
|
|
block_identity = q_idx >= cutoffs[b] |
|
|
return causal | block_identity |
|
|
|
|
|
|
|
|
@lru_cache
def _get_mixed_mask(seq_len, cutoffs):
  """Memoized BlockMask combining causal attention with full attention for
  queries past their per-sample cutoff."""
  mask_mod = partial(_mixed_mask, cutoffs=cutoffs)
  return create_block_mask(
      mask_mod,
      B=None,
      H=None,
      Q_LEN=seq_len,
      KV_LEN=seq_len)
|
|
|
|
|
|
|
|
@lru_cache |
|
|
def _mixed2_mask(b, h, q_idx, kv_idx, cutoffs): |
|
|
causal = q_idx >= kv_idx |
|
|
block_identity = (q_idx < cutoffs[b]) & (kv_idx < cutoffs[b]) |
|
|
return causal | block_identity |
|
|
|
|
|
|
|
|
@lru_cache
def _get_mixed2_mask(seq_len, cutoffs):
  """Memoized BlockMask: causal plus a bidirectional clean-prefix per sample."""
  mask_mod = partial(_mixed2_mask, cutoffs=cutoffs)
  return create_block_mask(
      mask_mod,
      B=None,
      H=None,
      Q_LEN=seq_len,
      KV_LEN=seq_len)
|
|
|
|
|
|
|
|
def _block_diff_mask(b, h, q_idx, kv_idx, block_size=1, n=None): |
|
|
""" |
|
|
Copied directly from BD3LM's codebase: https://github.com/kuleshov-group/bd3lms |
|
|
|
|
|
Constructs the specialized block diffusion attention mask for training |
|
|
composed of three masks: |
|
|
- **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks |
|
|
- **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context |
|
|
- **Block Causal Mask (M_BC)**: Attention to update x0 |
|
|
|
|
|
Args: |
|
|
b, h: Batch and head indices (ignored for mask logic). |
|
|
q_idx, kv_idx: Query and Key indices. |
|
|
seq_len: Total sequence length. |
|
|
block_size: Defines the block structure. |
|
|
|
|
|
Returns: |
|
|
A boolean attention mask. |
|
|
""" |
|
|
|
|
|
|
|
|
x0_flag_q = (q_idx >= n) |
|
|
x0_flag_kv = (kv_idx >= n) |
|
|
|
|
|
|
|
|
block_q = torch.where(x0_flag_q == 1, |
|
|
(q_idx - n) // block_size, |
|
|
q_idx // block_size) |
|
|
block_kv = torch.where(x0_flag_kv == 1, |
|
|
(kv_idx - n) // block_size, |
|
|
kv_idx // block_size) |
|
|
|
|
|
|
|
|
block_diagonal = ( |
|
|
block_q == block_kv) & (x0_flag_q == x0_flag_kv) |
|
|
|
|
|
|
|
|
offset_block_causal = ((block_q > block_kv) |
|
|
& (x0_flag_kv == 1) |
|
|
& (x0_flag_q == 0)) |
|
|
|
|
|
|
|
|
block_causal = (block_q >= block_kv) & ( |
|
|
x0_flag_kv == 1) & (x0_flag_q == 1) |
|
|
|
|
|
|
|
|
return block_diagonal | offset_block_causal | block_causal |
|
|
|
|
|
|
|
|
@lru_cache
def _get_seq_mask(seq_len):
  """Memoized BlockMask for sequential mode over [z_t ; x_0] (length 2*seq_len)."""
  mask_mod = partial(_block_diff_mask, block_size=1, n=seq_len)
  return create_block_mask(
      mask_mod,
      B=None,
      H=None,
      Q_LEN=seq_len * 2,
      KV_LEN=seq_len * 2)
|
|
|
|
|
|
|
|
def _block_diff_mask_prefix_lm(b, h, q_idx, kv_idx, n, cutoffs):
  """Block-diffusion mask plus bidirectional (prefix-LM) attention among the
  first cutoffs[b] positions of the clean x_0 half ([n, n + cutoffs[b]))."""
  base = _block_diff_mask(
      b, h, q_idx, kv_idx, block_size=1, n=n)
  q_in_prefix = (n <= q_idx) & (q_idx < n + cutoffs[b])
  kv_in_prefix = (n <= kv_idx) & (kv_idx < n + cutoffs[b])
  return base | (q_in_prefix & kv_in_prefix)
|
|
|
|
|
|
|
|
@lru_cache
def _get_seq_mask_prefix_lm(seq_len, cutoffs):
  """Memoized BlockMask for sequential mode with a per-sample prefix-LM region."""
  mask_mod = partial(_block_diff_mask_prefix_lm, n=seq_len, cutoffs=cutoffs)
  return create_block_mask(
      mask_mod,
      B=None,
      H=None,
      Q_LEN=seq_len * 2,
      KV_LEN=seq_len * 2)
|
|
|
|
|
|
|
|
# Compile flex_attention once at import time; shapes are static (dynamic=False)
# and cudagraph capture ('reduce-overhead') amortizes launch overhead.
flex_attention_compiled = torch.compile(flex_attention, dynamic=False, fullgraph=True, mode='reduce-overhead')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def fused_flex_attention(q, k, v, mask=None):
  """Run the compiled flex_attention kernel; `mask` is an optional BlockMask."""
  return flex_attention_compiled(q, k, v, block_mask=mask)
|
|
|
|
|
|
|
|
def bias_dropout_add_scale(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float,
    training: bool) -> torch.Tensor:
  """Compute residual + scale * dropout(x + bias).

  `bias` and `residual` are optional; kept TorchScript-compatible because the
  fused train/inference wrappers below script this function.
  """
  if bias is not None:
    x = x + bias
  out = scale * F.dropout(x, p=prob, training=training)
  if residual is not None:
    out = out + residual
  return out
|
|
|
|
|
|
|
|
def get_bias_dropout_add_scale(training):
  """Return bias_dropout_add_scale with `training` bound in via a closure."""
  def _bound(x, bias, scale, residual, prob):
    return bias_dropout_add_scale(
        x, bias, scale, residual, prob, training)
  return _bound
|
|
|
|
|
|
|
|
|
|
|
def modulate(x: torch.Tensor,
             shift: torch.Tensor,
             scale: torch.Tensor) -> torch.Tensor:
  """adaLN modulation: apply gain (1 + scale) to x, then shift."""
  gain = 1 + scale
  return x * gain + shift
|
|
|
|
|
|
|
|
@torch.jit.script
def bias_dropout_add_scale_fused_train(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float) -> torch.Tensor:
  """TorchScript-fused bias + dropout + scale + residual with dropout active."""
  return bias_dropout_add_scale(
      x, bias, scale, residual, prob, True)
|
|
|
|
|
|
|
|
@torch.jit.script
def bias_dropout_add_scale_fused_inference(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float) -> torch.Tensor:
  """TorchScript-fused bias + dropout + scale + residual with dropout disabled."""
  return bias_dropout_add_scale(
      x, bias, scale, residual, prob, False)
|
|
|
|
|
|
|
|
@torch.jit.script
def modulate_fused(x: torch.Tensor,
                   shift: torch.Tensor,
                   scale: torch.Tensor) -> torch.Tensor:
  """TorchScript-fused adaLN modulation: x * (1 + scale) + shift."""
  return modulate(x, shift, scale)
|
|
|
|
|
|
|
|
class Rotary(torch.nn.Module):
  """Rotary position embedding (RoPE) with cos/sin tables cached per seq length.

  The cached tables have shape (1, seq_len, 3, 1, dim), where axis 2 indexes
  q/k/v; the v slot is forced to cos=1 / sin=0 so values stay unrotated when
  the same table is applied to all three projections downstream.
  """

  def __init__(self, dim, base=10_000):
    super().__init__()
    # Standard RoPE inverse frequencies: base^(-2i/dim) for i in [0, dim/2).
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    self.register_buffer('inv_freq', inv_freq)
    # Lazy cache, rebuilt whenever the incoming sequence length changes.
    # NOTE(review): the cache is keyed on seq_len only — presumably device and
    # dtype are constant after setup; confirm if that ever changes.
    self.seq_len_cached = None
    self.cos_cached = None
    self.sin_cached = None

  def forward(self, x, seq_dim=1):
    """Return cached (cos, sin) tables sized for x's sequence length."""
    seq_len = x.shape[seq_dim]
    if seq_len != self.seq_len_cached:
      self.seq_len_cached = seq_len
      t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
      freqs = torch.einsum("i,j->ij", t, self.inv_freq.clone())
      # Duplicate the half-dim table so it covers the full rotary dim.
      emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

      # Broadcastable shape (1, seq, 3, 1, dim); axis 2 repeats for q/k/v.
      self.cos_cached = emb.cos()[None, :, None, None, :].repeat(1,1,3,1,1)
      self.sin_cached = emb.sin()[None, :, None, None, :].repeat(1,1,3,1,1)

      # Identity rotation for the v slot (index 2): cos=1, sin=0.
      self.cos_cached[:,:,2,:,:].fill_(1.)
      self.sin_cached[:,:,2,:,:].fill_(0.)

    return self.cos_cached, self.sin_cached
|
|
|
|
|
|
|
|
def rotate_half(x, interleaved=False):
  """Copied and refactored from FlashAttention"""
  if not interleaved:
    # Halves layout: rotate (x1, x2) -> (-x2, x1).
    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)
  # Interleaved layout: pairs (even, odd) -> (-odd, even), re-interleaved.
  even, odd = x[..., ::2], x[..., 1::2]
  stacked = torch.stack((-odd, even), dim=-1)
  return einops.rearrange(stacked, "... d two -> ... (d two)", two=2)
|
|
|
|
|
|
|
|
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
  """
  Copied and refactored from FlashAttention
  x: (batch_size, seq_len, nheads, headdim)
  cos, sin: (seq_len, rotary_dim / 2) or (batch_size, seq_len, rotary_dim / 2)
  """
  ro_dim = cos.shape[-1] * 2
  assert ro_dim <= x.shape[-1]
  # Expand half-dim tables to the full rotary dim with a broadcast head axis.
  if interleaved:
    pattern = "... d -> ... 1 (d 2)"
  else:
    pattern = "... d -> ... 1 (2 d)"
  cos = einops.repeat(cos, pattern)
  sin = einops.repeat(sin, pattern)
  # Rotate only the leading ro_dim channels; pass the rest through untouched.
  rotated = (x[..., :ro_dim] * cos
             + rotate_half(x[..., :ro_dim], interleaved) * sin)
  return torch.cat([rotated, x[..., ro_dim:]], dim=-1)
|
|
|
|
|
|
|
|
def _split_rotary(rotary_cos_sin, dtype): |
|
|
cos, sin = rotary_cos_sin |
|
|
cos = cos.to(dtype) |
|
|
sin = sin.to(dtype) |
|
|
cos = cos[0,:,0,0,:cos.shape[-1]//2] |
|
|
sin = sin[0,:,0,0,:sin.shape[-1]//2] |
|
|
return cos, sin |
|
|
|
|
|
|
|
|
def split_and_apply_rotary_pos_emb(qkv, rotary_cos_sin):
  """Split qkv of shape (b, s, 3, h, d) into q, k, v and rotate q and k with a
  single RoPE table shared across the batch. Rotation runs in full precision."""
  with torch.amp.autocast('cuda', enabled=False):
    cos, sin = _split_rotary(rotary_cos_sin, dtype=qkv.dtype)
    q_in, k_in, v_in = qkv.chunk(3, dim=2)
    q = apply_rotary_emb_torch(q_in.squeeze(dim=2), cos, sin)
    k = apply_rotary_emb_torch(k_in.squeeze(dim=2), cos, sin)
    # v is never rotated — just drop the singleton qkv axis.
    v = v_in.squeeze(dim=2)
  return q, k, v
|
|
|
|
|
|
|
|
def split_and_apply_rotary_pos_emb_batch(qkv, rotary_cos_sin):
  """Per-sample variant of split_and_apply_rotary_pos_emb: the cos/sin tables
  carry a real batch dimension (used after per-sample token shuffling)."""
  with torch.amp.autocast('cuda', enabled=False):
    cos_full, sin_full = rotary_cos_sin
    cos = cos_full.to(qkv.dtype)[:, :, 0, 0, :cos_full.shape[-1] // 2]
    sin = sin_full.to(qkv.dtype)[:, :, 0, 0, :sin_full.shape[-1] // 2]
    q_in, k_in, v_in = qkv.chunk(3, dim=2)
    q = apply_rotary_emb_torch(q_in.squeeze(dim=2), cos, sin)
    k = apply_rotary_emb_torch(k_in.squeeze(dim=2), cos, sin)
    # v is never rotated — just drop the singleton qkv axis.
    v = v_in.squeeze(dim=2)
  return q, k, v
|
|
|
|
|
|
|
|
def flex_attention_multi_headed(q, k, v, mask):
  """Run flex attention on (b, s, h, d) inputs and merge heads into (b, s, h*d).

  flex_attention expects (b, h, s, d), hence the transposes around the call.
  """
  q, k, v = (t.transpose(1, 2).contiguous() for t in (q, k, v))
  out = fused_flex_attention(q, k, v, mask=mask)
  out = out.transpose(1, 2).contiguous()
  return einops.rearrange(out, 'b s h d -> b s (h d)')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LayerNorm(nn.Module):
  """Bias-free LayerNorm computed in float32, with the learnable per-channel
  gain applied outside the (autocast-disabled) normalization."""

  def __init__(self, dim):
    super().__init__()
    self.weight = nn.Parameter(torch.ones([dim]))
    self.dim = dim

  def forward(self, x):
    # Normalize in full precision regardless of the ambient autocast dtype.
    with torch.amp.autocast('cuda', enabled=False):
      normed = F.layer_norm(x.float(), [self.dim])
    # Expects x of rank 3 (batch, seq, dim) for this broadcast.
    return normed * self.weight[None, None, :]
|
|
|
|
|
|
|
|
def residual_linear(x, W, x_skip, residual_scale):
  """x_skip + residual_scale * W @ x, fused via addmm over flattened leading dims."""
  dim_out, dim_in = W.shape
  fused = torch.addmm(
      x_skip.view(-1, dim_out),
      x.view(-1, dim_in),
      W.T,
      alpha=residual_scale)
  return fused.view(*x.shape[:-1], dim_out)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TimestepEmbedder(nn.Module):
  """
  Embeds scalar timesteps into vector representations via sinusoidal features
  followed by a two-layer SiLU MLP.
  """

  def __init__(self, hidden_size, frequency_embedding_size=256):
    super().__init__()
    self.mlp = nn.Sequential(
        nn.Linear(frequency_embedding_size, hidden_size, bias=True),
        nn.SiLU(),
        nn.Linear(hidden_size, hidden_size, bias=True))
    self.frequency_embedding_size = frequency_embedding_size

  @staticmethod
  def timestep_embedding(t, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
              These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, D) Tensor of positional embeddings.
    """
    half = dim // 2
    # Geometric frequency ladder from 1 down to 1/max_period.
    freqs = torch.exp(
        - math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
        / half)
    angles = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
    if dim % 2:
      # Odd dims get a zero column so the output width is exactly `dim`.
      pad = torch.zeros_like(embedding[:, :1])
      embedding = torch.cat([embedding, pad], dim=-1)
    return embedding

  def forward(self, t):
    features = self.timestep_embedding(t, self.frequency_embedding_size)
    return self.mlp(features)
|
|
|
|
|
|
|
|
class LabelEmbedder(nn.Module):
  """Embeds class labels into vector representations.

  Also handles label dropout for classifier-free guidance.
  """

  def __init__(self, num_classes, cond_size):
    super().__init__()
    # One extra row reserved for the "dropped label" (CFG null) index.
    self.embedding_table = nn.Embedding(num_classes + 1, cond_size)
    self.num_classes = num_classes

  def forward(self, labels):
    return self.embedding_table(labels)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DDiTBlockCausal(nn.Module):
  """Transformer block with causal attention and no adaLN conditioning
  (autoregressive mode).

  Supports incremental decoding through a per-block KV cache: call
  `reset_kv_cache()` before a new sequence, then pass `kv_cache=True` to
  `forward` with one token at a time.
  """

  def __init__(self, dim, n_heads, mlp_ratio=4, dropout=0.1):
    super().__init__()
    self.n_heads = n_heads
    self.dim = dim

    self.norm1 = LayerNorm(dim)
    self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
    self.attn_out = nn.Linear(dim, dim, bias=False)
    self.dropout1 = nn.Dropout(dropout)

    self.norm2 = LayerNorm(dim)
    self.mlp = nn.Sequential(
        nn.Linear(dim, mlp_ratio * dim, bias=True),
        nn.GELU(approximate='tanh'),
        nn.Linear(mlp_ratio * dim, dim, bias=True))
    self.dropout2 = nn.Dropout(dropout)
    self.dropout = dropout

    # KV cache for incremental (one-token-at-a-time) decoding.
    self.past_k = None
    self.past_v = None

  def _get_bias_dropout_scale(self):
    # Pick the TorchScript-fused kernel matching the current train/eval mode.
    if self.training:
      return bias_dropout_add_scale_fused_train
    else:
      return bias_dropout_add_scale_fused_inference

  def reset_kv_cache(self):
    self.past_k = None
    self.past_v = None

  def _process_and_update_kv(self, k, v):
    """Prepend the cached keys/values (if any) and store the concatenation as
    the new cache. Returns the full (k, v) to attend over."""
    if (self.past_k is not None
        and self.past_v is not None):
      k = torch.cat([self.past_k, k], dim=1)
      v = torch.cat([self.past_v, v], dim=1)
    self.past_k = k
    self.past_v = v
    return k, v

  @torch.no_grad()
  def _attention_with_kv_cache(self, qkv, rotary_cos_sin):
    """Single-token attention step against the KV cache.

    :param qkv: (batch, 1, 3, heads, head_dim) projections of the newest token.
    :param rotary_cos_sin: full-length rotary tables from the Rotary module.
    :return: (batch, 1, dim) attention output.
    """
    assert qkv.shape[1] == 1
    q, k, v = qkv.chunk(3, dim=2)
    k, v = self._process_and_update_kv(k=k, v=v)
    with torch.amp.autocast('cuda', enabled=False):
      cos, sin = _split_rotary(rotary_cos_sin, q.dtype)
      # The query is only the last position; keys span the whole cache.
      q = apply_rotary_emb_torch(
          q.squeeze(dim=2), cos[-1:, :], sin[-1:, :])
      k = apply_rotary_emb_torch(k.squeeze(dim=2), cos, sin)
      v = v.squeeze(dim=2)
    scale = q.shape[-1] ** 0.5

    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    # No explicit mask needed: the single query is the latest position and may
    # attend to everything already in the cache.
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) / scale
    attn_weights = F.softmax(attn_scores, dim=-1)
    x = torch.matmul(attn_weights, v).transpose(1, 2)
    return x.view(x.shape[0], 1, self.dim)

  def forward(self, x, rotary_cos_sin, kv_cache=False, **kwargs):
    """Causal transformer block. Extra kwargs (e.g. `c`) are ignored so the
    block shares a call signature with the adaLN variant."""
    del kwargs
    bias_dropout_scale_fn = self._get_bias_dropout_scale()
    x_skip = x
    x = self.norm1(x)
    qkv = einops.rearrange(
        self.attn_qkv(x),
        'b s (three h d) -> b s three h d',
        three=3,
        h=self.n_heads)

    if kv_cache:
      # Bug fix: rotary_cos_sin was previously not forwarded, which made the
      # KV-cache decoding path raise a TypeError (missing positional argument).
      x = self._attention_with_kv_cache(qkv.detach(), rotary_cos_sin)
    else:
      q, k, v = split_and_apply_rotary_pos_emb(qkv, rotary_cos_sin)
      attn_mask = _get_causal_mask(x.shape[1])
      x = flex_attention_multi_headed(q, k, v, attn_mask)

    # No adaLN: use a unit scale for the fused residual/dropout kernels.
    scale = torch.ones(1, device=x.device, dtype=x.dtype)
    x = bias_dropout_scale_fn(
        self.attn_out(x), None, scale, x_skip, self.dropout)
    x = bias_dropout_scale_fn(
        self.mlp(self.norm2(x)), None, scale, x, self.dropout)
    return x
|
|
|
|
|
|
|
|
class DDiTBlock(nn.Module):
  """DiT transformer block with optional adaLN-zero conditioning and a KV
  cache used during sampling.

  Attention runs either through flex_attention with a caller-supplied
  BlockMask (training / full forward) or through a hand-rolled softmax
  attention against the cache (`kv_cache=True`).
  """

  def __init__(self, dim, n_heads, adaLN,
               cond_dim=None, mlp_ratio=4,
               dropout=0.1):
    super().__init__()
    self.n_heads = n_heads
    self.dim = dim
    self.adaLN = adaLN  # if True, condition on `c` via adaLN-zero modulation

    self.norm1 = LayerNorm(dim)
    self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
    self.attn_out = nn.Linear(dim, dim, bias=False)
    self.dropout1 = nn.Dropout(dropout)

    self.norm2 = LayerNorm(dim)
    self.mlp = nn.Sequential(
        nn.Linear(dim, mlp_ratio * dim, bias=True),
        nn.GELU(approximate='tanh'),
        nn.Linear(mlp_ratio * dim, dim, bias=True))
    self.dropout2 = nn.Dropout(dropout)
    self.dropout = dropout

    if self.adaLN:
      # adaLN-zero: zero init so modulation starts as the identity mapping.
      self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim)
      self.adaLN_modulation.weight.data.zero_()
      self.adaLN_modulation.bias.data.zero_()

    # KV cache state for sampling.
    self.past_k = None
    self.past_v = None
    # Large negative used as -inf when masking attention scores below.
    self.neg_infinity = -1000000.0

  def _get_bias_dropout_scale(self):
    # Pick the TorchScript-fused kernel matching the current train/eval mode.
    if self.training:
      return bias_dropout_add_scale_fused_train
    else:
      return bias_dropout_add_scale_fused_inference

  def reset_kv_cache(self):
    self.past_k = None
    self.past_v = None

  def _process_and_update_kv(self, k, v, num_clean):
    """Return the keys/values to attend over and fold the first `num_clean`
    (already-decoded, non-mask) positions of the current chunk into the cache.

    k, v: (batch, chunk_len, heads, head_dim) for the current chunk only.
    Mask positions are attended over now but never cached.
    """
    if num_clean == 0:
      # Nothing newly decoded: attend over the current chunk only, cache as-is.
      return k, v
    else:
      if (self.past_k is None
          and self.past_v is None):
        # First step: seed the cache with the clean prefix of this chunk.
        self.past_k = k[:, :num_clean, :, :]
        self.past_v = v[:, :num_clean, :, :]
        return k, v
      else:
        # Attend over everything cached so far plus the whole current chunk...
        k_so_far = torch.cat([self.past_k, k], dim=1)
        v_so_far = torch.cat([self.past_v, v], dim=1)

        # ...but append only the clean prefix of the chunk to the cache.
        self.past_k = torch.cat(
            [self.past_k, k[:, :num_clean, :, :]], dim=1)
        self.past_v = torch.cat(
            [self.past_v, v[:, :num_clean, :, :]], dim=1)
        return k_so_far, v_so_far

  @torch.no_grad()
  def _attention_with_kv_cache(self, qkv, rotary_cos_sin,
                               num_clean, num_clean_and_mask):
    """Attend from the current chunk (clean + mask tokens) over the cache plus
    the chunk itself, with a causal-style mask applied only inside the chunk.

    qkv: (batch, num_clean_and_mask, 3, heads, head_dim).
    """
    assert qkv.shape[1] == num_clean_and_mask

    q, k, v = qkv.chunk(3, dim=2)
    q = q.squeeze(dim=2)
    k = k.squeeze(dim=2)
    v = v.squeeze(dim=2)
    k, v = self._process_and_update_kv(
        k=k, v=v, num_clean=num_clean)

    with torch.amp.autocast('cuda', enabled=False):
      # Rotary tables are batched (per-sample orderings); reduce to 2-D halves.
      cos, sin = rotary_cos_sin
      cos = cos.to(qkv.dtype)
      sin = sin.to(qkv.dtype)
      cos = cos[:,:,0,0,:cos.shape[-1]//2]
      sin = sin[:,:,0,0,:sin.shape[-1]//2]
      # Queries are the trailing chunk; keys span all positions so far.
      cos_part = cos[:, -num_clean_and_mask:]
      sin_part = sin[:, -num_clean_and_mask:]
      q = apply_rotary_emb_torch(q, cos_part, sin_part)
      k = apply_rotary_emb_torch(k, cos, sin)
    scale = q.shape[-1] ** 0.5

    # (batch, heads, seq, head_dim) layout for matmul attention.
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)

    attn_scores = torch.matmul(q, k.transpose(-2, -1)) / scale
    ones = torch.ones(
        num_clean_and_mask, num_clean_and_mask).to(qkv.device)

    # Mask only the chunk-vs-chunk corner: cached positions are fully visible,
    # but positions within the chunk may not look ahead of themselves.
    A = self.neg_infinity * torch.triu(ones, diagonal=1)
    A = A.view(1, 1, num_clean_and_mask, num_clean_and_mask)
    attn_scores[:, :, :, -num_clean_and_mask:] += A
    attn_weights = F.softmax(attn_scores, dim=-1)

    attn_output = torch.matmul(attn_weights, v).transpose(1, 2)
    return einops.rearrange(attn_output, 'b s h d -> b s (h d)')

  def forward(self, x, rotary_cos_sin, c=None, attn_mask=None,
              kv_cache=False, num_clean=None, num_clean_and_mask=None):
    """Apply the block.

    :param x: (batch, seq, dim) hidden states.
    :param rotary_cos_sin: (cos, sin) rotary tables (batched or shared).
    :param c: conditioning vector for adaLN (required when self.adaLN).
    :param attn_mask: BlockMask for the flex-attention path.
    :param kv_cache: use the incremental cached-attention path.
    :param num_clean / num_clean_and_mask: chunk bookkeeping for the cache path.
    """
    bias_dropout_scale_fn = self._get_bias_dropout_scale()

    x_skip = x
    x = self.norm1(x)
    if self.adaLN:
      # A single projection of `c` yields all six modulation terms
      # (shift/scale/gate for attention and for the MLP).
      (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
       gate_mlp) = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
      x = modulate_fused(x, shift_msa, scale_msa)

    qkv = einops.rearrange(
        self.attn_qkv(x),
        'b s (three h d) -> b s three h d',
        three=3,
        h=self.n_heads).contiguous()
    if kv_cache:
      x = self._attention_with_kv_cache(
          qkv.detach(), rotary_cos_sin,
          num_clean=num_clean, num_clean_and_mask=num_clean_and_mask)
    else:
      # A batched rotary table (leading dim > 1) means per-sample orderings.
      if rotary_cos_sin[0].shape[0] > 1:
        q, k, v = split_and_apply_rotary_pos_emb_batch(qkv, rotary_cos_sin)
      else:
        q, k, v = split_and_apply_rotary_pos_emb(qkv, rotary_cos_sin)
      x = flex_attention_multi_headed(q, k, v, attn_mask)

    if self.adaLN:
      # adaLN path: residuals are gated by the learned gate terms.
      x = bias_dropout_scale_fn(self.attn_out(x),
                                None,
                                gate_msa,
                                x_skip,
                                self.dropout)
      x = bias_dropout_scale_fn(
          self.mlp(modulate_fused(
              self.norm2(x), shift_mlp, scale_mlp)),
          None, gate_mlp, x, self.dropout)
    else:
      # Unconditional path: unit scale on both residual branches.
      scale = torch.ones(1, device=x.device, dtype=x.dtype)
      x = bias_dropout_scale_fn(
          self.attn_out(x), None, scale, x_skip, self.dropout)
      x = bias_dropout_scale_fn(
          self.mlp(self.norm2(x)), None, scale, x, self.dropout)
    return x
|
|
|
|
|
|
|
|
class EmbeddingLayer(nn.Module):
  """Token embedding accepting either hard token ids or soft distributions.

  2-D integer input (b, l) indexes rows directly; 3-D input (b, l, v) is
  treated as logits and embedded as a softmax-weighted mixture of rows.
  """

  def __init__(self, dim, vocab_dim):
    super().__init__()
    self.embedding = nn.Parameter(torch.empty((vocab_dim, dim)))
    torch.nn.init.kaiming_uniform_(self.embedding, a=math.sqrt(5))

  def forward(self, x):
    if x.ndim == 2:
      # Hard path: plain row gather by token id.
      return self.embedding[x]
    assert x.ndim == 3
    # Soft path: mixture computed in float32, cast back to the input dtype.
    probs = torch.nn.functional.softmax(x, dim=-1).float()
    mixed = torch.einsum("blv,ve->ble", probs, self.embedding.float())
    return mixed.to(x.dtype)
|
|
|
|
|
|
|
|
class DDiTFinalLayer(nn.Module):
  """Final norm + optional adaLN modulation + zero-initialized projection.

  Both the output projection and the adaLN projection start at zero, so the
  model initially predicts all-zero logits (the adaLN-zero convention).
  """

  def __init__(self, hidden_size, out_channels, cond_dim,
               adaLN):
    super().__init__()
    self.norm_final = LayerNorm(hidden_size)
    self.linear = nn.Linear(hidden_size, out_channels)
    self.linear.weight.data.zero_()
    self.linear.bias.data.zero_()
    self.adaLN = adaLN
    if self.adaLN:
      self.adaLN_modulation = nn.Linear(cond_dim,
                                        2 * hidden_size,
                                        bias=True)
      self.adaLN_modulation.weight.data.zero_()
      self.adaLN_modulation.bias.data.zero_()

  def forward(self, x, c):
    x = self.norm_final(x)
    if self.adaLN:
      # One projection of `c` provides both the shift and the scale.
      shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
      x = modulate_fused(x, shift, scale)
    return self.linear(x)
|
|
|
|
|
|
|
|
class DiT(nn.Module, huggingface_hub.PyTorchModelHubMixin):
  """Diffusion Transformer backbone, usable either as a causal AR model
  (DDiTBlockCausal, no conditioning) or as a bidirectional diffusion model
  (DDiTBlock with adaLN conditioning on the noise level sigma).
  """

  def __init__(self, config, vocab_size: int):
    super().__init__()
    if type(config) == dict:
      # Config arrives as a plain dict when loaded via the HF hub mixin.
      config = omegaconf.OmegaConf.create(config)
    self.causal = config.algo.causal_attention
    # adaLN conditioning is only used by the non-causal (diffusion) variant.
    self.adaLN = not self.causal
    self.config = config
    self.vocab_size = vocab_size
    dim = config.model.hidden_size
    cond_dim = config.model.cond_dim
    self.vocab_embed = EmbeddingLayer(dim, vocab_size)
    if not self.causal:
      # Maps the scalar noise level sigma to the adaLN conditioning vector.
      self.sigma_map = TimestepEmbedder(cond_dim)
    self.rotary_dim = dim // config.model.n_heads
    self.rotary_emb = Rotary(self.rotary_dim)

    blocks = []
    for _ in range(config.model.n_blocks):
      if self.causal:
        block = DDiTBlockCausal(
            dim=dim,
            n_heads=config.model.n_heads,
            dropout=config.model.dropout)
      else:
        block = DDiTBlock(
            dim=dim,
            n_heads=config.model.n_heads,
            cond_dim=cond_dim,
            adaLN=self.adaLN,
            dropout=config.model.dropout)
      blocks.append(block)
    self.blocks = nn.ModuleList(blocks)

    self.output_layer = DDiTFinalLayer(
        hidden_size=dim,
        out_channels=vocab_size,
        cond_dim=cond_dim,
        adaLN=self.adaLN)
    self.scale_by_sigma = config.model.scale_by_sigma

  def _get_bias_dropout_scale(self):
    # Pick the TorchScript-fused kernel matching the current train/eval mode.
    if self.training:
      return bias_dropout_add_scale_fused_train
    else:
      return bias_dropout_add_scale_fused_inference

  def reset_kv_cache(self):
    # Clear every block's cached keys/values before a fresh generation pass.
    for block in self.blocks:
      block.reset_kv_cache()

  def forward(self, x, sigma, x0=None, kv_cache=False):
    """Embed tokens, run all blocks under bfloat16 autocast, project to logits.

    :param x: (batch, seq) token ids (or (batch, seq, vocab) soft tokens).
    :param sigma: noise levels; ignored when self.causal.
    :param x0: unused here (must be None); subclasses override this behavior.
    :param kv_cache: incremental decoding — only the last token is processed.
    """
    assert x0 is None
    x = self.vocab_embed(x)
    if self.causal:
      t_cond = None
    else:
      t_cond = F.silu(self.sigma_map(sigma))

    # Rotary tables are built for the full sequence even when decoding one
    # token, since cached keys need embeddings for every position.
    rotary_cos_sin = self.rotary_emb(x)
    if kv_cache:
      x = x[:, -1:, :]
    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
      for i in range(len(self.blocks)):
        x = self.blocks[i](
            x, rotary_cos_sin, c=t_cond, kv_cache=kv_cache)
      x = self.output_layer(x, c=t_cond)
    return x
|
|
|
|
|
|
|
|
def _get_reverse_indices(indices): |
|
|
""" |
|
|
indices: LongTensor of shape [B, N] representing permutations |
|
|
returns: LongTensor of shape [B, N] representing the inverse permutations |
|
|
""" |
|
|
B, N = indices.shape |
|
|
reverse_indices = torch.empty_like(indices) |
|
|
arange = torch.arange(N, device=indices.device).unsqueeze(0).expand(B, -1) |
|
|
reverse_indices.scatter_(1, indices, arange) |
|
|
return reverse_indices |
|
|
|
|
|
|
|
|
class EsoLMDiT(DiT):
  """DiT specialized for EsoLM: a shuffled any-order diffusion phase and a
  sequential (block-diffusion style) phase over the doubled sequence
  [z_t ; x_0].

  Tokens are sorted so that clean (non-mask) tokens precede mask tokens; the
  rotary tables are permuted to match, so positional identity follows each
  token through the shuffle.
  """

  def __init__(self, config, vocab_size: int, mask_index: int):
    super().__init__(config, vocab_size)

    # EsoLM always uses the bidirectional/adaLN variant of the backbone.
    assert not self.causal and self.adaLN
    self.mask_index = mask_index

    # Shuffle / attention-mask settings for the two training phases.
    self.diffusion_shuffle = config.algo.diffusion_shuffle
    self.diffusion_attn_mode = config.algo.diffusion_attn_mode
    self.sequential_shuffle = config.algo.sequential_shuffle
    self.sequential_attn_mode = config.algo.sequential_attn_mode

    # Lazily-built BlockMasks, cached after first use (seq_len assumed fixed).
    self.mdlm_mask = None
    self.seq_mask = None

  def _sort_indices(
      self, indices, shuffle, keep_masks_unshuffled=False):
    """Sort each row so clean tokens come first and mask tokens last.

    With shuffle=True the order within each group is random (unless
    keep_masks_unshuffled keeps masked tokens in their original relative
    order); otherwise the original order is preserved within each group.
    Returns (sorted token ids, sort permutation).
    """
    masked = (indices == self.mask_index)
    if shuffle:
      # Random tie-break offsets strictly below 1 so group order dominates.
      offsets = torch.rand(
          indices.shape).to(indices.device) * 0.9
      if keep_masks_unshuffled:
        # Monotone offsets keep masked tokens in original relative order.
        offsets[masked] = torch.linspace(
            0, 1, torch.sum(masked)).to(indices.device)
      else:
        pass
    else:
      # Deterministic monotone offsets: stable clean-before-mask ordering.
      offsets = torch.linspace(
          0, 0.9, indices.shape[1]).to(indices.device)
    # masked contributes 0/1, offsets < 1: masks sort after clean tokens.
    sort_idx = (masked + offsets).argsort(descending=False)
    indices = torch.gather(indices, dim=1, index=sort_idx)
    return indices, sort_idx

  def _sort_rotary_cos_sin(self, rotary_cos_sin, sort_idx):
    """Permute the cached rotary tables (1, seq, 3, 1, dim) to match the
    per-sample token ordering; returns batched (bs, seq, 3, 1, dim) tables."""
    cos, sin = rotary_cos_sin
    bs = sort_idx.shape[0]
    cos = cos.expand(bs, -1, -1, -1, -1)
    sin = sin.expand(bs, -1, -1, -1, -1)
    cos = torch.gather(
        cos, dim=1,
        index=sort_idx[:, :, None, None, None].expand(
            -1, -1, 3, -1, self.rotary_dim)).contiguous()
    sin = torch.gather(
        sin, dim=1,
        index=sort_idx[:, :, None, None, None].expand(
            -1, -1, 3, -1, self.rotary_dim)).contiguous()
    return cos, sin

  def _get_attention_mask(self, seq_len, attn_mode=None,
                          cutoffs=None):
    """Return the BlockMask for the requested attention mode.

    'causal'/'bidirectional' masks are cached on the instance; 'mixed'/'mixed2'
    depend on per-batch cutoffs and go through the lru_cached builders.
    NOTE(review): `cutoffs` is a tensor here, so those lru_cache entries key on
    tensor identity — presumably acceptable, but verify cache behavior.
    """
    if attn_mode == 'causal':
      if self.mdlm_mask is None:
        self.mdlm_mask = _get_causal_mask(seq_len)
      return self.mdlm_mask
    elif attn_mode == 'bidirectional':
      if self.mdlm_mask is None:
        self.mdlm_mask = _get_bidirectional_mask(seq_len)
      return self.mdlm_mask
    elif attn_mode == 'mixed':
      # Causal plus fully-visible rows past the per-sample cutoff.
      return _get_mixed_mask(seq_len=seq_len,
                             cutoffs=cutoffs)
    elif attn_mode == 'mixed2':
      # Causal plus a bidirectional clean prefix per sample.
      return _get_mixed2_mask(seq_len=seq_len,
                              cutoffs=cutoffs)

  def _diffusion_features(self, zt, sort_idx=None,
                          attn_mode=None, cutoffs=None):
    """Prepare inputs for the diffusion phase: sort tokens (clean first),
    embed them, permute the rotary tables, and build the attention mask.

    Defaults: cutoffs = per-sample clean-token count; attn_mode from config;
    sort_idx computed here unless the caller (sampling) supplies one.
    """
    if cutoffs is None:
      cutoffs = torch.sum(zt != self.mask_index, dim=1)
    if attn_mode is None:
      attn_mode = self.diffusion_attn_mode
    if sort_idx is None:
      zt, sort_idx = self._sort_indices(
          zt, self.diffusion_shuffle)
    x = self.vocab_embed(zt)
    rotary_cos_sin = self.rotary_emb(x)
    rotary_cos_sin = self._sort_rotary_cos_sin(
        rotary_cos_sin, sort_idx)
    attention_mask = self._get_attention_mask(
        seq_len=zt.shape[1],
        attn_mode=attn_mode,
        cutoffs=cutoffs)
    return {'x': x,
            'rotary': rotary_cos_sin,
            'attention': attention_mask,
            'sorted_indices': sort_idx}

  def _sequential_features(self, zt, x0):
    """Prepare inputs for the sequential phase over the doubled sequence
    [z_t ; x_0]: sort both halves with the same permutation, embed the
    concatenation, tile the (sorted) rotary tables over both halves, and
    build the block-diffusion attention mask.
    """
    seq_len = zt.shape[1]
    zt, sort_idx = self._sort_indices(
        zt, self.sequential_shuffle,
        keep_masks_unshuffled=True)
    # x0 must follow the same permutation so the two halves stay aligned.
    x0 = torch.gather(x0, dim=1, index=sort_idx)
    zt_and_x0 = torch.cat([zt, x0], dim=1)
    cutoffs = torch.sum(zt != self.mask_index, dim=1)
    x = self.vocab_embed(zt_and_x0)
    rotary_cos_sin = self.rotary_emb(x[:, :seq_len])
    rotary_cos_sin = self._sort_rotary_cos_sin(
        rotary_cos_sin, sort_idx)
    cos, sin = rotary_cos_sin
    # Both halves share positions, so tile the tables along the seq axis.
    cos = torch.cat([cos, cos], dim=1)
    sin = torch.cat([sin, sin], dim=1)
    rotary_cos_sin = (cos, sin)

    if self.sequential_attn_mode == 'causal':
      if self.seq_mask is None:
        self.seq_mask = _get_seq_mask(seq_len)
      return {'x': x,
              'rotary': rotary_cos_sin,
              'attention': self.seq_mask,
              'sorted_indices': sort_idx}
    elif self.sequential_attn_mode == 'mixed':
      return {'x': x,
              'rotary': rotary_cos_sin,
              'attention': _get_seq_mask_prefix_lm(
                  seq_len, cutoffs=cutoffs),
              'sorted_indices': sort_idx}

  def forward(self, zt, sigma, x0=None):
    """Training/eval forward. With x0=None this is the diffusion phase;
    otherwise the sequential phase over [z_t ; x_0], whose output is cropped
    back to the z_t half and un-permuted to the original token order."""
    diffusion_mode = x0 is None
    seq_len = zt.shape[1]

    if diffusion_mode:
      features = self._diffusion_features(zt)
    else:
      features = self._sequential_features(zt, x0)
    x = features['x']
    t_cond = F.silu(self.sigma_map(sigma))
    # NOTE: runs in full precision here, unlike DiT.forward's bfloat16 autocast.
    with torch.amp.autocast('cuda', enabled=False):
      for i in range(len(self.blocks)):
        x = self.blocks[i](x, features['rotary'], c=t_cond,
                           attn_mask=features['attention'])
      x = self.output_layer(x, c=t_cond)

    if not diffusion_mode:
      # Keep only the z_t half and undo the sorting permutation.
      x = x[:, :seq_len]
      sort_idx_reversed = _get_reverse_indices(features['sorted_indices'])
      x = torch.gather(
          x, dim=1,
          index=sort_idx_reversed[:, :, None].expand(
              -1, -1, self.vocab_size))
    return x

  @torch.no_grad()
  def forward_sample(self, zt, sort_idx, attn_mode=None,
                     cutoffs=None, kv_cache=False,
                     last_k_start=None,
                     curr_k_start=None,
                     curr_k_end=None):
    """
    zt is expected to be sorted as per sort_idx.

    When kv_cache is true:
    - zt will have shape (num_samples, model.length); we need its shape to generate
      all the rotary embeddings because any of them can be selected by
      the random ordering
    - sort_idx will have shape
      (num_samples, model.length) for the same reason
    - last_k_start (starting index)
    - curr_k_start
    - curr_k_end (ending index)
    - use these indices to select features['x'] to pass into the blocks

    Within self._diffusion_features, zt will be used
    to generate the full rotary embeddings, and sort_idx
    will index the embedded zt into shape
    (num_samples, num_tokens_generated_last_time (non-mask) + num_tokens_to_gen (mask), hidden)

    We want to append the kv values for num_tokens_generated_last_time to the old kv cache
    and not build up kv values for num_tokens_to_gen (because they are masks)
    """
    assert attn_mode is not None
    ones = torch.ones(zt.shape[0], device=zt.device)
    if cutoffs is not None:
      # Broadcast a scalar cutoff to one value per sample.
      cutoffs = cutoffs * ones
      assert cutoffs.ndim == 1
    features = self._diffusion_features(
        zt=zt,
        sort_idx=sort_idx,
        attn_mode=attn_mode,
        cutoffs=cutoffs)
    # Sampling is done at sigma = 0 (fully denoised conditioning).
    zeros = torch.zeros(zt.shape[0], device=zt.device)
    t_cond = F.silu(self.sigma_map(zeros))

    x = features['x']
    rotary = features['rotary']
    if kv_cache:
      # Only the current window (new clean tokens + masks to fill) is fed in.
      x = x[:, last_k_start:curr_k_end, :]

      cos, sin = rotary
      rotary = (cos[:, :curr_k_end], sin[:, :curr_k_end])
      num_clean = curr_k_start - last_k_start
      num_clean_and_mask = curr_k_end - last_k_start
    else:
      num_clean = None
      num_clean_and_mask = None

    with torch.amp.autocast('cuda', enabled=False):
      for i in range(len(self.blocks)):
        x = self.blocks[i](
            x, rotary, c=t_cond,
            attn_mask=features['attention'],
            kv_cache=kv_cache,
            num_clean=num_clean,
            num_clean_and_mask=num_clean_and_mask)
      x = self.output_layer(x, c=t_cond)

    if kv_cache:
      # Return logits only for the mask positions being generated now.
      x = x[:, num_clean:, :]
    return x
|
|
|
|
|
|
|
|
class EsoLMHFDiT(nn.Module):
  """EsoLM DiT backbone driven by a flat HF-style config (adaLN always on)."""

  def __init__(self, config):
    super().__init__()
    self.vocab_embed = EmbeddingLayer(
        config.hidden_size, config.vocab_size)
    self.sigma_map = TimestepEmbedder(config.cond_dim)
    self.rotary_dim = config.hidden_size // config.n_heads
    self.rotary_emb = Rotary(self.rotary_dim)

    self.blocks = nn.ModuleList([
        DDiTBlock(
            dim=config.hidden_size,
            n_heads=config.n_heads,
            cond_dim=config.cond_dim,
            adaLN=True,
            dropout=config.dropout)
        for _ in range(config.n_blocks)])

    self.output_layer = DDiTFinalLayer(
        hidden_size=config.hidden_size,
        out_channels=config.vocab_size,
        cond_dim=config.cond_dim,
        adaLN=True)

  def reset_kv_cache(self):
    # Clear every block's cached keys/values before a fresh generation pass.
    for blk in self.blocks:
      blk.reset_kv_cache()
|
|
|
|
|
|
|
|
class EsoLM(transformers.PreTrainedModel): |
|
|
"""HF-compatible model.""" |
|
|
config_class = EsoLMConfig |
|
|
base_model_prefix = 'esolm' |
|
|
|
|
|
def __init__(self, config: EsoLMConfig): |
|
|
super().__init__(config) |
|
|
self.config = config |
|
|
self.backbone = EsoLMHFDiT(config) |
|
|
|