Spaces:
Sleeping
Sleeping
| """ELF transformer model.""" | |
| from typing import Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.checkpoint import checkpoint | |
| from modules.layers import ( | |
| Attention, BottleneckTextProj, FinalLayer, RMSNorm, SwiGLUFFN, | |
| TextRotaryEmbeddingFast, TimestepEmbedder, | |
| DEFAULT_KERNEL_INIT, DEFAULT_BIAS_INIT, NORMAL_INIT_002, | |
| _make_linear, | |
| ) | |
class ELFBlock(nn.Module):
    """ELF Transformer block.

    Two residual sub-layers applied in sequence: ``norm1`` -> ``attn`` and
    ``norm2`` -> ``mlp``, each added back onto its own input.

    Args:
        hidden_size: Feature width handed to both ``RMSNorm`` layers, to
            ``Attention`` and to ``SwiGLUFFN``.
        num_heads: Head count handed to ``Attention``.
        mlp_ratio: The feed-forward hidden width is
            ``int(hidden_size * mlp_ratio)``.
        attn_drop: Forwarded to ``Attention`` as ``attn_drop``.
        proj_drop: Forwarded to ``Attention`` as ``proj_drop`` and to
            ``SwiGLUFFN`` as ``drop``.
    """

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0,
                 attn_drop: float = 0.0, proj_drop: float = 0.0):
        super().__init__()

        # Keep the constructor arguments available on the instance.
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.attn_drop = attn_drop
        self.proj_drop = proj_drop

        # Attention sub-layer (registered first so parameter order is
        # norm1, attn, norm2, mlp).
        self.norm1 = RMSNorm(hidden_size, eps=1e-6)
        self.attn = Attention(
            hidden_size,
            num_heads,
            qkv_bias=True,
            qk_norm=True,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )

        # Feed-forward sub-layer.
        self.norm2 = RMSNorm(hidden_size, eps=1e-6)
        ffn_width = int(hidden_size * mlp_ratio)
        self.mlp = SwiGLUFFN(hidden_size, ffn_width, drop=proj_drop)

    def forward(self, x: torch.Tensor, rope_fn: Optional[nn.Module] = None,
                attention_mask: Optional[torch.Tensor] = None,
                deterministic: bool = True) -> torch.Tensor:
        """Run the attention and feed-forward residual branches.

        Args:
            x: Input hidden states.
            rope_fn: Passed positionally to ``self.attn`` as its second
                argument.
            attention_mask: Forwarded unchanged to ``self.attn``.
            deterministic: Forwarded to both ``self.attn`` and ``self.mlp``
                (presumably disables dropout when True -- confirm in
                ``modules.layers``).

        Returns:
            ``x`` plus the attention branch, plus the feed-forward branch
            evaluated on that intermediate sum.
        """
        x = x + self.attn(
            self.norm1(x),
            rope_fn,
            attention_mask=attention_mask,
            deterministic=deterministic,
        )
        return x + self.mlp(self.norm2(x), deterministic=deterministic)
class ELF(nn.Module):
    """Text ELF Transformer.

    Projects text-encoder features to ``hidden_size``, prepends learned
    prefix tokens (time, optionally self-cond CFG scale, optionally model
    mode), runs a stack of ``ELFBlock`` layers, strips the prefix again and
    applies two heads: the flow-matching ``final_layer`` and a factored
    decoder unembedding (hidden -> text_encoder_dim -> vocab).

    Args:
        text_encoder_dim: Channel width of the input features; also the
            ``out_channels`` of ``final_layer`` and the inner width of the
            factored decoder head.
        max_length: Passed to the rotary embedding as ``pt_seq_len``.
        hidden_size: Transformer width.
        depth: Number of ``ELFBlock`` layers.
        num_heads: Attention heads per block; the rotary embedding is built
            with ``dim = hidden_size // num_heads``.
        mlp_ratio: Forwarded to every ``ELFBlock``.
        attn_drop: Attention dropout, applied only to blocks whose index
            ``i`` satisfies ``depth // 4 <= i < depth // 4 * 3``.
        proj_drop: Projection dropout, restricted to the same blocks.
        bottleneck_dim: Forwarded to ``BottleneckTextProj``.
        num_time_tokens: Number of learned time prefix tokens (must be > 0).
        num_self_cond_cfg_tokens: Number of learned self-cond CFG prefix
            tokens; values <= 0 disable that embedder.
        num_model_mode_tokens: Number of learned model-mode tokens; values
            <= 0 disable them.
        vocab_size: Output width of the decoder unembedding.
        gradient_checkpointing: Checkpoint each block while training with
            gradients enabled.

    Raises:
        ValueError: If ``num_time_tokens`` is not positive.
    """

    def __init__(
        self,
        text_encoder_dim: int,
        max_length: int,
        hidden_size: int = 1024,
        depth: int = 24,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        bottleneck_dim: int = 128,
        num_time_tokens: int = 4,
        num_self_cond_cfg_tokens: int = 4,
        num_model_mode_tokens: int = 0,
        vocab_size: int = 0,
        gradient_checkpointing: bool = False,
    ):
        super().__init__()

        # Record the configuration on the instance.
        self.text_encoder_dim = text_encoder_dim
        self.max_length = max_length
        self.hidden_size = hidden_size
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.attn_drop = attn_drop
        self.proj_drop = proj_drop
        self.bottleneck_dim = bottleneck_dim
        self.num_time_tokens = num_time_tokens
        self.num_self_cond_cfg_tokens = num_self_cond_cfg_tokens
        self.num_model_mode_tokens = num_model_mode_tokens
        self.vocab_size = vocab_size
        self.gradient_checkpointing = gradient_checkpointing

        # Self-conditioning input projection (only used when input is [z, x_pred]).
        self.self_cond_proj = _make_linear(2 * text_encoder_dim, text_encoder_dim, bias=True)

        # Text bottleneck projection.
        self.text_proj = BottleneckTextProj(text_encoder_dim, hidden_size, bottleneck_dim)

        # Time conditioning: an embedder plus learned prefix tokens.
        if num_time_tokens <= 0:
            raise ValueError("num_time_tokens must be positive for prefix time conditioning")
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.t_emb_tokens = nn.Parameter(torch.empty(1, num_time_tokens, hidden_size))
        NORMAL_INIT_002(self.t_emb_tokens)

        # Optional self-cond CFG scale conditioning.
        has_sc_tokens = num_self_cond_cfg_tokens > 0
        if has_sc_tokens:
            self.self_cond_cfg_embedder = TimestepEmbedder(hidden_size)
            self.self_cond_cfg_tokens = nn.Parameter(
                torch.empty(1, num_self_cond_cfg_tokens, hidden_size)
            )
            NORMAL_INIT_002(self.self_cond_cfg_tokens)

        # Optional learned model-mode tokens.
        if num_model_mode_tokens > 0:
            self.mode_tokens = nn.Parameter(
                torch.empty(1, num_model_mode_tokens, hidden_size)
            )
            NORMAL_INIT_002(self.mode_tokens)

        # Rotary embedding; ``num_empty_token`` counts every prefix token
        # this configuration can prepend.
        # NOTE(review): build_context() omits the self-cond CFG tokens when
        # no scale is supplied, so the runtime prefix can be shorter than
        # prefix_total -- confirm TextRotaryEmbeddingFast tolerates that.
        prefix_total = (
            num_model_mode_tokens
            + num_time_tokens
            + (num_self_cond_cfg_tokens if has_sc_tokens else 0)
        )
        self.feat_rope = TextRotaryEmbeddingFast(
            dim=hidden_size // num_heads,
            pt_seq_len=max_length,
            num_empty_token=prefix_total,
        )

        # Transformer stack. Dropout is only switched on for blocks with
        # drop_start <= index < drop_stop; all other blocks get 0.0.
        drop_start = depth // 4
        drop_stop = depth // 4 * 3
        layers = []
        for layer_idx in range(depth):
            use_drop = drop_start <= layer_idx < drop_stop
            layers.append(
                ELFBlock(
                    hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    attn_drop=attn_drop if use_drop else 0.0,
                    proj_drop=proj_drop if use_drop else 0.0,
                )
            )
        self.blocks = nn.ModuleList(layers)

        # Final flow-matching output head.
        self.final_layer = FinalLayer(hidden_size, patch_size=1, out_channels=text_encoder_dim)

        # Factored decoder unembedding: hidden -> text_encoder_dim -> vocab.
        decoder_width = text_encoder_dim
        self.proj_kernel = nn.Parameter(torch.empty(hidden_size, decoder_width))
        self.proj_bias = nn.Parameter(torch.empty(decoder_width))
        self.unembed_kernel = nn.Parameter(torch.empty(decoder_width, vocab_size))
        self.unembed_bias = nn.Parameter(torch.empty(vocab_size))
        # Initialise in declaration order so RNG consumption stays stable.
        for init_fn, param in (
            (DEFAULT_KERNEL_INIT, self.proj_kernel),
            (DEFAULT_BIAS_INIT, self.proj_bias),
            (DEFAULT_KERNEL_INIT, self.unembed_kernel),
            (DEFAULT_BIAS_INIT, self.unembed_bias),
        ):
            init_fn(param)

    def build_context(self, t: torch.Tensor,
                      self_cond_cfg_scale: Optional[torch.Tensor] = None) -> list:
        """Build the list of conditioning prefix-token tensors.

        Args:
            t: Timesteps; the batch size is read from ``t.shape[0]``.
            self_cond_cfg_scale: Optional self-cond CFG scale. Its tokens are
                added only when this is not None and the model was built
                with ``num_self_cond_cfg_tokens > 0``.

        Returns:
            A list whose first entry is the time tokens (learned tokens plus
            the time embedding broadcast over the token axis), optionally
            followed by the self-cond CFG tokens built the same way.
        """
        batch = t.shape[0]

        time_cond = self.t_embedder(t).unsqueeze(1)  # (B, 1, hidden)
        context = [self.t_emb_tokens.expand(batch, -1, -1) + time_cond]

        if self_cond_cfg_scale is not None and self.num_self_cond_cfg_tokens > 0:
            sc_cond = self.self_cond_cfg_embedder(self_cond_cfg_scale).unsqueeze(1)
            context.append(self.self_cond_cfg_tokens.expand(batch, -1, -1) + sc_cond)

        return context

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        deterministic: bool = True,
        self_cond_cfg_scale: Optional[torch.Tensor] = None,
        decoder_step_active: Optional[bool] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the model.

        Args:
            x: (N, S, C), or (N, S, 2C) when self-conditioning is
                concatenated on the channel axis.
            t: (N,) timesteps.
            attention_mask: (N, S), 1 = valid. All-ones columns are
                prepended for every prefix/mode token added here.
            deterministic: Forwarded to every block.
            self_cond_cfg_scale: Optional input for the self-cond CFG prefix.
            decoder_step_active: ``None``, a Python bool, or a (B,) tensor.
                It gates the model-mode tokens (``None`` gates them to
                zero); any non-None value also enables the decoder head.
                NOTE(review): the annotation says ``Optional[bool]`` but
                tensors are handled explicitly below.

        Returns:
            ``(output, decoder_logits)``. ``output`` is ``final_layer``
            applied to the text positions (presumably (N, S, C) -- confirm
            against ``FinalLayer``). ``decoder_logits`` is ``None`` when
            ``decoder_step_active`` is ``None``.
        """
        batch = x.shape[0]

        # Input projections run in float32 with CUDA autocast disabled.
        with torch.amp.autocast('cuda', enabled=False):
            # Self-conditioning: input is [z, x_pred] when 2x encoder dim.
            if x.shape[-1] == 2 * self.text_encoder_dim:
                x = self.self_cond_proj(x.float())
            x = self.text_proj(x.float())

        # NOTE(review): assumed to run under the caller's autocast state,
        # i.e. outside the float32 region above -- confirm.
        context_tokens = self.build_context(t, self_cond_cfg_scale)

        # Prepend learnable model-mode tokens, gated by decoder_step_active.
        # The (B,) tensor form allows per-example gating.
        num_mode = 0
        if self.num_model_mode_tokens > 0:
            num_mode = self.num_model_mode_tokens
            mode_tokens = self.mode_tokens.expand(batch, -1, -1)

            if decoder_step_active is None:
                gate = 0.0
            elif isinstance(decoder_step_active, torch.Tensor) and decoder_step_active.dim() > 0:
                gate = decoder_step_active.to(mode_tokens.dtype).view(-1, 1, 1)
            else:
                gate = float(decoder_step_active)

            x = torch.cat([mode_tokens * gate, x], dim=1)
            if attention_mask is not None:
                # Mode tokens are always attendable.
                mode_mask = attention_mask.new_ones((batch, num_mode))
                attention_mask = torch.cat([mode_mask, attention_mask], dim=1)

        # Prepend the conditioning prefix, so the sequence becomes
        # [time, (self-cond CFG), (mode), text].
        prefix_len = 0
        if context_tokens:
            prefix = torch.cat(context_tokens, dim=1)
            prefix_len = prefix.shape[1]
            x = torch.cat([prefix, x], dim=1)
            if attention_mask is not None:
                prefix_mask = attention_mask.new_ones((batch, prefix_len))
                attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)

        # Transformer stack, optionally with activation checkpointing.
        block_kwargs = dict(
            rope_fn=self.feat_rope,
            attention_mask=attention_mask,
            deterministic=deterministic,
        )
        use_checkpoint = self.gradient_checkpointing and self.training and torch.is_grad_enabled()
        for block in self.blocks:
            if not use_checkpoint:
                x = block(x, **block_kwargs)
                continue

            # Bind ``block`` as a default so the closure keeps this layer.
            def _run(hidden: torch.Tensor, block: ELFBlock = block) -> torch.Tensor:
                return block(hidden, **block_kwargs)

            x = checkpoint(_run, x, use_reentrant=False)

        # Drop every prepended token; only text positions reach the heads.
        x = x[:, prefix_len + num_mode:]

        # Both heads run in float32 with CUDA autocast disabled.
        # NOTE(review): final_layer is assumed to sit inside this region.
        with torch.amp.autocast('cuda', enabled=False):
            x_f32 = x.float()

            # Factored decoder unembedding: hidden -> text_encoder_dim -> vocab.
            decoder_logits = None
            if decoder_step_active is not None:
                bottleneck = F.gelu(x_f32 @ self.proj_kernel + self.proj_bias, approximate="tanh")
                decoder_logits = bottleneck @ self.unembed_kernel + self.unembed_bias

            output = self.final_layer(x_f32)

        return output, decoder_logits
| # Model factory functions | |
def ELF_B(**kwargs):
    """Build the ELF-B preset: depth 12, hidden size 768, 12 heads.

    All other ``ELF`` constructor arguments are taken from ``kwargs``.
    """
    preset = dict(depth=12, hidden_size=768, num_heads=12)
    return ELF(**preset, **kwargs)
def ELF_M(**kwargs):
    """Build the ELF-M preset: depth 24, hidden size 1056, 16 heads.

    All other ``ELF`` constructor arguments are taken from ``kwargs``.
    """
    preset = dict(depth=24, hidden_size=1056, num_heads=16)
    return ELF(**preset, **kwargs)
def ELF_L(**kwargs):
    """Build the ELF-L preset: depth 32, hidden size 1280, 16 heads.

    All other ``ELF`` constructor arguments are taken from ``kwargs``.
    """
    preset = dict(depth=32, hidden_size=1280, num_heads=16)
    return ELF(**preset, **kwargs)
# Registry mapping each model-size name to its factory function.
ELF_models = {
    'ELF-B': ELF_B,
    'ELF-M': ELF_M,
    'ELF-L': ELF_L,
}