Raja-65's picture
Add ELF demo
6ab8280
Raw
History Blame Contribute Delete
10.2 kB
"""ELF transformer model."""
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from modules.layers import (
Attention, BottleneckTextProj, FinalLayer, RMSNorm, SwiGLUFFN,
TextRotaryEmbeddingFast, TimestepEmbedder,
DEFAULT_KERNEL_INIT, DEFAULT_BIAS_INIT, NORMAL_INIT_002,
_make_linear,
)
class ELFBlock(nn.Module):
    """Single ELF transformer layer.

    Applies two residual sub-layers in sequence: RMS-normalised attention
    followed by an RMS-normalised SwiGLU feed-forward network.
    """

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0,
                 attn_drop: float = 0.0, proj_drop: float = 0.0):
        """Build the normalisation, attention and feed-forward sub-modules.

        Args:
            hidden_size: Width passed to both norms, the attention and the FFN.
            num_heads: Head count forwarded to ``Attention``.
            mlp_ratio: The FFN hidden width is ``int(hidden_size * mlp_ratio)``.
            attn_drop: Forwarded to ``Attention`` as ``attn_drop``.
            proj_drop: Forwarded to ``Attention`` as ``proj_drop`` and to the
                FFN as ``drop``.
        """
        super().__init__()

        # Constructor arguments are kept on the instance for later inspection.
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.attn_drop = attn_drop
        self.proj_drop = proj_drop

        # Attention sub-layer (pre-norm).
        self.norm1 = RMSNorm(hidden_size, eps=1e-6)
        self.attn = Attention(
            hidden_size,
            num_heads,
            qkv_bias=True,
            qk_norm=True,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )

        # Feed-forward sub-layer (pre-norm).
        self.norm2 = RMSNorm(hidden_size, eps=1e-6)
        ffn_width = int(hidden_size * mlp_ratio)
        self.mlp = SwiGLUFFN(hidden_size, ffn_width, drop=proj_drop)

    def forward(self, x: torch.Tensor, rope_fn: Optional[nn.Module] = None,
                attention_mask: Optional[torch.Tensor] = None,
                deterministic: bool = True) -> torch.Tensor:
        """Run the attention and feed-forward residual branches on ``x``.

        Args:
            x: Input hidden states.
            rope_fn: Passed positionally to the attention module.
            attention_mask: Forwarded unchanged to the attention module.
            deterministic: Forwarded to both the attention and the FFN.

        Returns:
            The hidden states after both residual additions.
        """
        # Residual 1: x + Attention(norm1(x)).
        x = x + self.attn(
            self.norm1(x),
            rope_fn,
            attention_mask=attention_mask,
            deterministic=deterministic,
        )
        # Residual 2: x + FFN(norm2(x)).
        x = x + self.mlp(self.norm2(x), deterministic=deterministic)
        return x
class ELF(nn.Module):
    """Text ELF Transformer.

    The input is embedded through a bottleneck text projection, conditioning
    prefix tokens (time, optional SC-CFG scale, optional model-mode tokens) are
    prepended, the sequence runs through a stack of ``ELFBlock`` layers, and
    the prefix is stripped again. The model returns the ``FinalLayer`` output
    plus, optionally, logits from a factored
    hidden -> text_encoder_dim -> vocab unembedding.
    """

    def __init__(
        self,
        text_encoder_dim: int,
        max_length: int,
        hidden_size: int = 1024,
        depth: int = 24,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        bottleneck_dim: int = 128,
        num_time_tokens: int = 4,
        num_self_cond_cfg_tokens: int = 4,
        num_model_mode_tokens: int = 0,
        vocab_size: int = 0,
        gradient_checkpointing: bool = False,
    ):
        """Construct all sub-modules and learned parameters.

        Args:
            text_encoder_dim: Channel width of the input / output embeddings.
            max_length: Passed to the rotary embedding as ``pt_seq_len``.
            hidden_size: Transformer width.
            depth: Number of ``ELFBlock`` layers.
            num_heads: Attention heads per block.
            mlp_ratio: FFN expansion ratio forwarded to every block.
            attn_drop: Attention dropout, applied only to the middle blocks
                (indices ``depth // 4`` up to, but excluding,
                ``depth // 4 * 3``).
            proj_drop: Projection dropout, applied to the same middle blocks.
            bottleneck_dim: Bottleneck width of the text projection.
            num_time_tokens: Number of learned time prefix tokens (must be > 0).
            num_self_cond_cfg_tokens: Number of learned SC-CFG prefix tokens;
                ``<= 0`` disables them.
            num_model_mode_tokens: Number of learned model-mode prefix tokens;
                ``<= 0`` disables them.
            vocab_size: Output width of the decoder unembedding.
            gradient_checkpointing: Enables activation checkpointing of the
                blocks while training with gradients enabled.

        Raises:
            ValueError: If ``num_time_tokens`` is not positive.
        """
        super().__init__()

        # Keep the configuration on the instance.
        self.text_encoder_dim = text_encoder_dim
        self.max_length = max_length
        self.hidden_size = hidden_size
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.attn_drop = attn_drop
        self.proj_drop = proj_drop
        self.bottleneck_dim = bottleneck_dim
        self.num_time_tokens = num_time_tokens
        self.num_self_cond_cfg_tokens = num_self_cond_cfg_tokens
        self.num_model_mode_tokens = num_model_mode_tokens
        self.vocab_size = vocab_size
        self.gradient_checkpointing = gradient_checkpointing

        # Input side: the self-conditioning projection folds a concatenated
        # [z, x_pred] input back to text_encoder_dim (forward() only applies
        # it when the last dim is doubled); the bottleneck projection then
        # lifts the result to hidden_size.
        self.self_cond_proj = _make_linear(2 * text_encoder_dim, text_encoder_dim, bias=True)
        self.text_proj = BottleneckTextProj(text_encoder_dim, hidden_size, bottleneck_dim)

        # Time conditioning is mandatory.
        if num_time_tokens <= 0:
            raise ValueError("num_time_tokens must be positive for prefix time conditioning")
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.t_emb_tokens = self._init_prefix_param(num_time_tokens, hidden_size)

        # Optional SC-CFG scale conditioning.
        has_sc_cfg = num_self_cond_cfg_tokens > 0
        if has_sc_cfg:
            self.self_cond_cfg_embedder = TimestepEmbedder(hidden_size)
            self.self_cond_cfg_tokens = self._init_prefix_param(
                num_self_cond_cfg_tokens, hidden_size)

        # Optional learned model-mode tokens.
        if num_model_mode_tokens > 0:
            self.mode_tokens = self._init_prefix_param(num_model_mode_tokens, hidden_size)

        # The rotary embedding is told how many prefix slots precede the text.
        # NOTE(review): SC-CFG tokens are counted here whenever they are
        # configured, but build_context() only emits them when a scale is
        # passed -- confirm TextRotaryEmbeddingFast tolerates a shorter prefix.
        prefix_total = (
            num_model_mode_tokens
            + num_time_tokens
            + (num_self_cond_cfg_tokens if has_sc_cfg else 0)
        )
        self.feat_rope = TextRotaryEmbeddingFast(
            dim=hidden_size // num_heads,
            pt_seq_len=max_length,
            num_empty_token=prefix_total,
        )

        # Dropout is only active for blocks in [depth // 4, depth // 4 * 3).
        drop_start = depth // 4
        drop_stop = depth // 4 * 3
        layers = []
        for index in range(depth):
            uses_dropout = drop_start <= index < drop_stop
            layers.append(ELFBlock(
                hidden_size,
                num_heads,
                mlp_ratio=mlp_ratio,
                attn_drop=attn_drop if uses_dropout else 0.0,
                proj_drop=proj_drop if uses_dropout else 0.0,
            ))
        self.blocks = nn.ModuleList(layers)

        # Final flow-matching output head.
        self.final_layer = FinalLayer(hidden_size, patch_size=1, out_channels=text_encoder_dim)

        # Factored decoder unembedding: hidden -> text_encoder_dim -> vocab.
        inner_dim = text_encoder_dim
        self.proj_kernel = nn.Parameter(torch.empty(hidden_size, inner_dim))
        DEFAULT_KERNEL_INIT(self.proj_kernel)
        self.proj_bias = nn.Parameter(torch.empty(inner_dim))
        DEFAULT_BIAS_INIT(self.proj_bias)
        self.unembed_kernel = nn.Parameter(torch.empty(inner_dim, vocab_size))
        DEFAULT_KERNEL_INIT(self.unembed_kernel)
        self.unembed_bias = nn.Parameter(torch.empty(vocab_size))
        DEFAULT_BIAS_INIT(self.unembed_bias)

    @staticmethod
    def _init_prefix_param(num_tokens: int, width: int) -> nn.Parameter:
        """Allocate a ``(1, num_tokens, width)`` parameter and apply NORMAL_INIT_002."""
        param = nn.Parameter(torch.empty(1, num_tokens, width))
        NORMAL_INIT_002(param)
        return param

    @staticmethod
    def _mode_gate(decoder_step_active, dtype: torch.dtype):
        """Turn ``decoder_step_active`` into a multiplier for the mode tokens.

        ``None`` gives ``0.0``; a tensor with at least one dimension is cast to
        ``dtype`` and reshaped to ``(-1, 1, 1)`` for per-example gating;
        anything else (Python bool, 0-dim tensor) is converted with ``float``.
        """
        if decoder_step_active is None:
            return 0.0
        per_example = (
            isinstance(decoder_step_active, torch.Tensor)
            and decoder_step_active.dim() > 0
        )
        if per_example:
            return decoder_step_active.to(dtype).view(-1, 1, 1)
        return float(decoder_step_active)

    @staticmethod
    def _extend_mask(mask: Optional[torch.Tensor], batch: int,
                     width: int) -> Optional[torch.Tensor]:
        """Prepend ``width`` columns of ones to ``mask``; ``None`` passes through."""
        if mask is None:
            return None
        ones = torch.ones((batch, width), dtype=mask.dtype, device=mask.device)
        return torch.cat([ones, mask], dim=1)

    def build_context(self, t: torch.Tensor,
                      self_cond_cfg_scale: Optional[torch.Tensor] = None) -> list:
        """Build the list of conditioning prefix token tensors.

        Args:
            t: Timesteps; its leading dimension is used as the batch size.
            self_cond_cfg_scale: Optional SC-CFG scale. Its tokens are added
                only when this is given and SC-CFG tokens are configured.

        Returns:
            A list holding the time tokens, followed by the SC-CFG tokens when
            applicable. Each entry is the learned token bank broadcast over the
            batch plus the matching embedding added to every token.
        """
        batch = t.shape[0]

        time_emb = self.t_embedder(t)  # (B, hidden)
        tokens = [self.t_emb_tokens.expand(batch, -1, -1) + time_emb.unsqueeze(1)]

        if self.num_self_cond_cfg_tokens > 0 and self_cond_cfg_scale is not None:
            scale_emb = self.self_cond_cfg_embedder(self_cond_cfg_scale)
            tokens.append(
                self.self_cond_cfg_tokens.expand(batch, -1, -1) + scale_emb.unsqueeze(1)
            )
        return tokens

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        deterministic: bool = True,
        self_cond_cfg_scale: Optional[torch.Tensor] = None,
        decoder_step_active: Optional[bool] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the model.

        Args:
            x: ``(N, S, C)``, or ``(N, S, 2C)`` when self-conditioning input is
                concatenated.
            t: ``(N,)`` timesteps.
            attention_mask: ``(N, S)`` with 1 marking valid positions. It is
                extended with ones for every prepended prefix token.
            deterministic: Forwarded to every block.
            self_cond_cfg_scale: Optional SC-CFG scale for the prefix tokens.
            decoder_step_active: ``None``, a Python bool, or a ``(B,)`` tensor.
                It gates the model-mode tokens, and decoder logits are computed
                whenever it is not ``None``.

        Returns:
            ``(output, decoder_logits)``: the ``FinalLayer`` output for the
            non-prefix positions, and the decoder logits (``None`` when
            ``decoder_step_active`` is ``None``).
        """
        batch = x.shape[0]

        # Input projections run in float32 with CUDA autocast switched off.
        with torch.amp.autocast('cuda', enabled=False):
            has_self_cond = x.shape[-1] == 2 * self.text_encoder_dim
            if has_self_cond:
                x = self.self_cond_proj(x.float())
            x = self.text_proj(x.float())

        conditioning = self.build_context(t, self_cond_cfg_scale)

        # Model-mode tokens sit directly in front of the text tokens and are
        # scaled by the decoder_step_active gate.
        model_mode_offset = 0
        if self.num_model_mode_tokens > 0:
            model_mode_offset = self.num_model_mode_tokens
            gate = self._mode_gate(decoder_step_active, self.mode_tokens.dtype)
            gated_mode = self.mode_tokens.expand(batch, -1, -1) * gate
            x = torch.cat([gated_mode, x], dim=1)
            attention_mask = self._extend_mask(attention_mask, batch, model_mode_offset)

        # Conditioning tokens go in front of everything else.
        prefix_len = 0
        if conditioning:
            prefix = torch.cat(conditioning, dim=1)
            prefix_len = prefix.shape[1]
            x = torch.cat([prefix, x], dim=1)
            attention_mask = self._extend_mask(attention_mask, batch, prefix_len)

        use_checkpoint = (
            self.gradient_checkpointing and self.training and torch.is_grad_enabled()
        )
        block_kwargs = dict(
            rope_fn=self.feat_rope,
            attention_mask=attention_mask,
            deterministic=deterministic,
        )
        for layer in self.blocks:
            if use_checkpoint:
                # Bind the current layer as a default so the recomputation
                # never sees a later loop value.
                def _recompute(hidden: torch.Tensor, layer: ELFBlock = layer) -> torch.Tensor:
                    return layer(hidden, **block_kwargs)

                x = checkpoint(_recompute, x, use_reentrant=False)
            else:
                x = layer(x, **block_kwargs)

        # Drop every prefix position (conditioning + model-mode tokens).
        x = x[:, prefix_len + model_mode_offset:]

        # Output heads also run in float32 with CUDA autocast switched off.
        with torch.amp.autocast('cuda', enabled=False):
            decoder_logits = None
            if decoder_step_active is not None:
                # Factored unembedding: hidden -> text_encoder_dim -> vocab.
                bottleneck = F.gelu(
                    x.float() @ self.proj_kernel + self.proj_bias,
                    approximate="tanh",
                )
                decoder_logits = bottleneck @ self.unembed_kernel + self.unembed_bias
            output = self.final_layer(x.float())
        return output, decoder_logits
# Model factory functions
def ELF_B(**kwargs):
    """Build an ``ELF`` with depth=12, hidden_size=768, num_heads=12.

    All other constructor arguments are taken from ``kwargs``.
    """
    return ELF(
        depth=12,
        hidden_size=768,
        num_heads=12,
        **kwargs,
    )
def ELF_M(**kwargs):
    """Build an ``ELF`` with depth=24, hidden_size=1056, num_heads=16.

    All other constructor arguments are taken from ``kwargs``.
    """
    return ELF(
        depth=24,
        hidden_size=1056,
        num_heads=16,
        **kwargs,
    )
def ELF_L(**kwargs):
    """Build an ``ELF`` with depth=32, hidden_size=1280, num_heads=16.

    All other constructor arguments are taken from ``kwargs``.
    """
    return ELF(
        depth=32,
        hidden_size=1280,
        num_heads=16,
        **kwargs,
    )
# Registry mapping a model-size name to its factory function.
ELF_models = {
    'ELF-B': ELF_B,
    'ELF-M': ELF_M,
    'ELF-L': ELF_L,
}