# Adapted from https://github.com/meta-llama/llama3/blob/main/llama/model.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import torch
import torch.nn.functional as F
from torch import nn

from ..util import get_logger
from .adaln_zero import AdaLNZero

logger = get_logger()

try:
    from flash_attn import flash_attn_func, flash_attn_with_kvcache

    FLASH_ATTN_AVAILABLE = True
except ImportError:
    FLASH_ATTN_AVAILABLE = False
    logger.warning(
        "FlashAttention is not installed. Falling back to the PyTorch SDPA implementation; "
        "correct behavior of the model is not guaranteed."
    )
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precompute complex rotary embedding frequencies for positions [0, end)."""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64, shape (end, dim // 2)
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Reshape freqs_cis so it broadcasts over the batch and head dimensions of x."""
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embeddings to x of shape (bsz, seqlen, n_heads, head_dim)."""
    x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, x_)
    x_out = torch.view_as_real(x_ * freqs_cis).flatten(3)
    return x_out.type_as(x)
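
# Illustrative shape check for the rotary helpers above (a sketch, not part of the module;
# head_dim=64, n_heads=8, bsz=2 and seqlen=16 are arbitrary example values):
#
#     freqs_cis = precompute_freqs_cis(dim=64, end=16)   # (16, 32), complex64
#     q = torch.randn(2, 16, 8, 64)                      # (bsz, seqlen, n_heads, head_dim)
#     q_rot = apply_rotary_emb(q, freqs_cis)             # same shape and dtype as q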

class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        n_heads: int,
        dropout: float,
        window_size: int | None,
        qkv_bias: bool = False,
        proj_bias: bool = False,
        use_flash_attention: bool = False,
        causal: bool = False,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads

        self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=qkv_bias)
        self.wk = nn.Linear(dim, n_heads * self.head_dim, bias=qkv_bias)
        self.wv = nn.Linear(dim, n_heads * self.head_dim, bias=qkv_bias)
        self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=proj_bias)

        self.scale = self.head_dim**-0.5
        self.dropout = dropout

        # Enable local attention if window_size is specified
        self.use_local_attention = window_size is not None
        if self.use_local_attention:
            assert window_size % 2 == 1, "Window size must be odd for local attention."
            self.window_per_side = window_size // 2

        self.use_flash_attention = use_flash_attention
        self.causal = causal

    def create_mask(
        self, bsz: int, seqlen: int, mask: torch.Tensor | None, device: torch.device
    ) -> torch.Tensor | None:
        """Create an attention mask combining the provided mask with causal and local-attention constraints."""
        # Nothing to mask: no local window, no causal constraint, and no user-provided mask.
        if not self.use_local_attention and not self.causal and mask is None:
            return None

        # Start with all positions allowed
        attn_mask = torch.ones((seqlen, seqlen), dtype=torch.bool, device=device)

        if self.causal:
            # Causal mask: no future positions allowed
            attn_mask = torch.tril(attn_mask)

        # Apply local attention constraints
        if self.use_local_attention:
            attn_mask = torch.triu(attn_mask, diagonal=-self.window_per_side)
            attn_mask = torch.tril(attn_mask, diagonal=self.window_per_side)

        # Expand mask to batch size
        attn_mask = attn_mask.unsqueeze(0).expand(bsz, -1, -1)

        # Apply global mask if provided
        if mask is not None:
            assert mask.shape[-1] == seqlen and mask.shape[-2] == seqlen, (
                "Mask must be square and match sequence length."
            )
            # Ensure mask has correct batch dimensions
            if mask.dim() == 2:
                mask = mask.unsqueeze(0).expand(bsz, -1, -1)
            attn_mask = attn_mask & mask

        # Expand to head dimension
        attn_mask = attn_mask.unsqueeze(1).expand(-1, self.n_heads, -1, -1)
        return attn_mask
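
    # Example of the combined mask (illustrative values): with seqlen=4, causal=True and
    # window_size=3 (so window_per_side=1), each query attends to itself and one previous
    # position, i.e. every (batch, head) slice of attn_mask is:
    #     [[1, 0, 0, 0],
    #      [1, 1, 0, 0],
    #      [0, 1, 1, 0],
    #      [0, 0, 1, 1]]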

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor | None,
        mask: torch.Tensor | None,
        return_kv: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Forward pass for multi-head attention.

        Args:
            x (torch.Tensor): Input tensor of shape (bsz, seqlen, dim).
            freqs_cis (torch.Tensor, optional): Precomputed rotary frequencies.
            mask (torch.Tensor, optional): Attention mask.
            return_kv (bool): Whether to return KV pairs for caching.

        Returns:
            output (torch.Tensor): Output tensor of shape (bsz, seqlen, dim).
            new_kv (tuple, optional): KV pairs if return_kv is True.
        """
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_heads, self.head_dim)

        # Apply rotary embeddings if provided
        if freqs_cis is not None:
            xq = apply_rotary_emb(xq, freqs_cis=freqs_cis[:seqlen])
            xk = apply_rotary_emb(xk, freqs_cis=freqs_cis[:seqlen])

        if self.use_flash_attention and FLASH_ATTN_AVAILABLE:
            assert mask is None, "Flash attention does not support arbitrary masking."
            # Flash Attention
            window_size = (self.window_per_side, self.window_per_side) if self.use_local_attention else (-1, -1)
            output = flash_attn_func(
                xq,  # (bsz, seqlen, n_heads, head_dim)
                xk,  # (bsz, seqlen, n_heads, head_dim)
                xv,  # (bsz, seqlen, n_heads, head_dim)
                dropout_p=(self.dropout if self.training else 0.0),
                softmax_scale=self.scale,
                window_size=window_size,
                causal=self.causal,
            )  # (bsz, seqlen, n_heads, head_dim)
        else:
            attn_mask = self.create_mask(bsz, seqlen, mask, x.device)
            # SDPA Attention; disable dropout at eval time, matching the Flash Attention path.
            output = F.scaled_dot_product_attention(
                xq.transpose(1, 2),  # (bsz, n_heads, seqlen, head_dim)
                xk.transpose(1, 2),  # (bsz, n_heads, seqlen, head_dim)
                xv.transpose(1, 2),  # (bsz, n_heads, seqlen, head_dim)
                attn_mask=attn_mask,  # (bsz, n_heads, seqlen, seqlen) boolean mask
                dropout_p=(self.dropout if self.training else 0.0),
                scale=self.scale,
            ).transpose(1, 2)  # (bsz, seqlen, n_heads, head_dim)

        output = output.contiguous().view(bsz, seqlen, -1)
        output = self.wo(output)

        if return_kv:
            return output, (xk, xv)
        return output

    def forward_with_cache(
        self,
        x: torch.Tensor,
        kv_cache: tuple[torch.Tensor, torch.Tensor],
        freqs_cis: torch.Tensor,
        start_pos: int,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass with KV cache for efficient inference. Only used for inference.

        Args:
            x (torch.Tensor): Input tensor for the current step. Shape: (bsz, 1, dim)
            kv_cache: A tuple of (key_cache, value_cache) from previous steps.
            freqs_cis (torch.Tensor): Precomputed rotary frequencies.
            start_pos (int): The starting position of the new token in the sequence.

        Returns:
            output (torch.Tensor): Output tensor after attention. Shape: (bsz, 1, dim)
            new_kv (tuple): Key and value of the new token only, to be appended to the caller's cache.
        """
        bsz, seqlen, _ = x.shape
        assert seqlen == 1, "KV cache method is designed for single-token generation."

        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_heads, self.head_dim)

        # Apply rotary embeddings using the correct positional slice
        xq = apply_rotary_emb(xq, freqs_cis=freqs_cis[start_pos : start_pos + seqlen])
        xk = apply_rotary_emb(xk, freqs_cis=freqs_cis[start_pos : start_pos + seqlen])

        # Update the KV cache
        k_cache, v_cache = kv_cache
        new_kv = (xk, xv)
        xk = torch.cat([k_cache, xk], dim=1)
        xv = torch.cat([v_cache, xv], dim=1)

        # For single-token generation, the causal mask is implicitly handled:
        # we attend to all keys (prefix + previous tokens).
        if self.use_flash_attention and FLASH_ATTN_AVAILABLE:
            # Flash Attention
            output = flash_attn_with_kvcache(
                xq,  # (bsz, 1, n_heads, head_dim)
                xk,  # (bsz, 1+kv_len, n_heads, head_dim)
                xv,  # (bsz, 1+kv_len, n_heads, head_dim)
                softmax_scale=self.scale,
            )  # (bsz, 1, n_heads, head_dim)
        else:
            # SDPA Attention
            output = F.scaled_dot_product_attention(
                xq.transpose(1, 2),  # (bsz, n_heads, 1, head_dim)
                xk.transpose(1, 2),  # (bsz, n_heads, 1+kv_len, head_dim)
                xv.transpose(1, 2),  # (bsz, n_heads, 1+kv_len, head_dim)
                scale=self.scale,
            ).transpose(1, 2)  # (bsz, 1, n_heads, head_dim)

        output = output.contiguous().view(bsz, seqlen, -1)
        return self.wo(output), new_kv

class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: float | None,
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
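
# Worked example of the SwiGLU sizing above (illustrative, assuming dim=512, hidden_dim=4*dim
# and multiple_of=256): hidden_dim starts at 2048, is reduced to int(2 * 2048 / 3) = 1365, and
# is then rounded up to the next multiple of 256, giving 1536. So w1 and w3 map 512 -> 1536,
# and w2 maps 1536 -> 512.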

class TransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        n_heads: int,
        qkv_bias: bool,
        proj_bias: bool,
        window_size: int | None,
        multiple_of: int,
        ffn_dim_multiplier: float | None,
        dropout: float,
        norm_eps: float,
        adanorm_condition_dim: int | None = None,
        use_flash_attention: bool = False,
        use_adaln_zero: bool = False,
        causal: bool = False,
    ):
        super().__init__()
        self.attention = Attention(
            dim=dim,
            n_heads=n_heads,
            dropout=dropout,
            window_size=window_size,
            use_flash_attention=use_flash_attention,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            causal=causal,
        )
        self.feed_forward = FeedForward(
            dim=dim,
            hidden_dim=4 * dim,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier,
        )

        # Choose between AdaLNZero and regular LayerNorm
        self.use_adaln_zero = use_adaln_zero
        if self.use_adaln_zero:
            assert adanorm_condition_dim is not None, "condition_dim must be provided when using AdaLNZero"
            self.attention_norm = AdaLNZero(dim, adanorm_condition_dim, eps=norm_eps, return_gate=True)
            self.ffn_norm = AdaLNZero(dim, adanorm_condition_dim, eps=norm_eps, return_gate=True)
        else:
            self.attention_norm = nn.LayerNorm(dim, eps=norm_eps)
            self.ffn_norm = nn.LayerNorm(dim, eps=norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor | None,
        mask: torch.Tensor | None,
        condition: torch.Tensor | None = None,
        return_kv: bool = False,
        kv_cache: tuple[torch.Tensor, torch.Tensor] | None = None,
        start_pos: int | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass for a single Transformer block.

        Args:
            x (torch.Tensor): Input tensor of shape (bsz, seqlen, dim).
            freqs_cis (torch.Tensor, optional): Precomputed rotary frequencies.
            mask (torch.Tensor, optional): Attention mask.
            condition (torch.Tensor, optional): Conditioning tensor for AdaLNZero.
            return_kv (bool): Whether to return KV pairs for caching.
            kv_cache (tuple, optional): KV cache for efficient inference.
            start_pos (int, optional): Starting position for the KV cache.

        Returns:
            out (torch.Tensor): Output tensor of shape (bsz, seqlen, dim).
            new_kv (tuple, optional): New KV pairs if return_kv is True or kv_cache is provided.
        """
        # Apply normalization
        if self.use_adaln_zero:
            assert condition is not None, "condition must be provided when using AdaLNZero"
            attn_normed, attn_gate = self.attention_norm(x, condition=condition)
        else:
            attn_normed = self.attention_norm(x)

        # Forward attention with KV cache if provided
        new_kv = None
        if kv_cache is not None and start_pos is not None:
            # Use KV cache for efficient inference
            attn_out, new_kv = self.attention.forward_with_cache(attn_normed, kv_cache, freqs_cis, start_pos)
        elif return_kv:
            # Return KV pairs for caching
            attn_out, new_kv = self.attention(attn_normed, freqs_cis, mask, return_kv=True)
        else:
            attn_out = self.attention(attn_normed, freqs_cis, mask)

        # Apply gating for attention if using AdaLNZero
        if self.use_adaln_zero:
            h = x + attn_gate * attn_out  # residual + gate * attention output
        else:
            h = x + attn_out

        # Apply normalization for feedforward
        if self.use_adaln_zero:
            ffn_normed, ffn_gate = self.ffn_norm(h, condition=condition)
        else:
            ffn_normed = self.ffn_norm(h)
        ffn_out = self.feed_forward(ffn_normed)

        # Apply gating for feedforward if using AdaLNZero
        if self.use_adaln_zero:
            out = h + ffn_gate * ffn_out  # residual + gate * feedforward output
        else:
            out = h + ffn_out

        # If using KV cache, return the new KV pairs
        if new_kv is not None:
            return out, new_kv
        return out

class Transformer(nn.Module):
    def __init__(
        self,
        dim: int = 4096,
        n_layers: int = 32,
        n_heads: int = 32,
        qkv_bias: bool = False,
        proj_bias: bool = False,
        window_size: int | None = None,
        multiple_of: int = 256,
        ffn_dim_multiplier: float | None = None,
        dropout: float = 0.1,
        norm_eps: float = 1e-5,
        use_rope: bool = True,
        rope_theta: float = 500000.0,
        max_seq_len: int = 2048,
        input_dim: int | None = None,
        output_dim: int | None = None,
        adanorm_condition_dim: int | None = None,
        use_flash_attention: bool = False,
        use_adaln_zero: bool = False,
        use_xavier_init: bool = True,
        causal: bool = False,
    ):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.rope_theta = rope_theta
        self.use_adaln_zero = use_adaln_zero

        self.layers = nn.ModuleList()
        for layer_id in range(n_layers):
            self.layers.append(
                TransformerBlock(
                    dim=dim,
                    n_heads=n_heads,
                    window_size=window_size,
                    multiple_of=multiple_of,
                    ffn_dim_multiplier=ffn_dim_multiplier,
                    dropout=dropout,
                    qkv_bias=qkv_bias,
                    proj_bias=proj_bias,
                    norm_eps=norm_eps,
                    adanorm_condition_dim=adanorm_condition_dim,
                    use_flash_attention=use_flash_attention,
                    use_adaln_zero=use_adaln_zero,
                    causal=causal,
                )
            )

        # Choose between AdaLNZero (without gate) and regular LayerNorm for the final norm
        if self.use_adaln_zero:
            assert adanorm_condition_dim is not None, "condition_dim must be provided when using AdaLNZero"
            self.norm = AdaLNZero(dim, adanorm_condition_dim, eps=norm_eps, return_gate=False)
        else:
            self.norm = nn.LayerNorm(dim, eps=norm_eps)

        self.input_proj = nn.Linear(input_dim, dim) if input_dim is not None else nn.Identity()
        self.output_proj = nn.Linear(dim, output_dim) if output_dim is not None else nn.Identity()
        self.output_dim_ = output_dim if output_dim is not None else dim

        if use_rope:
            self.freqs_cis = precompute_freqs_cis(dim // n_heads, max_seq_len * 2, rope_theta)
            logger.debug(
                f"Using RoPE with theta={rope_theta}, max_seq_len={max_seq_len}, "
                f"dim={dim}, n_heads={n_heads}, freqs_cis shape={self.freqs_cis.shape}"
            )
        else:
            self.freqs_cis = None

        if window_size is not None:
            logger.debug(f"Using local attention with window size {window_size}")
        if self.use_adaln_zero:
            logger.debug(f"Using AdaLNZero conditioning with condition_dim={adanorm_condition_dim}")
        if use_flash_attention:
            logger.debug("Using Flash Attention for memory-efficient attention computation")
        if use_xavier_init:
            logger.debug("Using Xavier initialization for linear layers")
            self.apply(self._init_weights)
        self.apply(self._init_adaln_zero)

    def output_dim(self) -> int:
        return self.output_dim_

    def _init_weights(self, module: nn.Module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def _init_adaln_zero(self, module: nn.Module):
        if isinstance(module, AdaLNZero):
            # Initialize condition projection weights to zero
            nn.init.zeros_(module.condition_proj[1].weight)
            nn.init.zeros_(module.condition_proj[1].bias)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor | None = None,
        condition: torch.Tensor | None = None,
        return_kv: bool = False,
        kv_cache: list[tuple[torch.Tensor, torch.Tensor]] | None = None,
        start_pos: int | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
        """
        Forward pass for the Transformer model.

        Args:
            x (torch.Tensor): Input tensor of shape (bsz, seqlen, input_dim).
            mask (torch.Tensor, optional): Attention mask.
            condition (torch.Tensor, optional): Conditioning tensor for AdaLNZero.
            return_kv (bool): Whether to return KV pairs for caching.
            kv_cache (list, optional): Per-layer KV caches for efficient inference.
            start_pos (int, optional): Starting position for the KV cache.

        Returns:
            output (torch.Tensor): Output tensor of shape (bsz, seqlen, output_dim).
            new_kv_list (list, optional): New KV pairs for each layer if return_kv is True or kv_cache is provided.
        """
        bsz, seqlen, _dim = x.shape
        if self.use_adaln_zero:
            assert condition is not None, "condition must be provided when using AdaLNZero"

        # Rotary embeddings
        if self.freqs_cis is not None:
            # Recompute freqs_cis if the sequence length or starting position exceeds the precomputed length
            expected_len = (start_pos + 1) if start_pos is not None else seqlen
            if expected_len > self.freqs_cis.shape[0]:
                logger.warning(
                    f"Input sequence length {expected_len} exceeds precomputed RoPE length {self.freqs_cis.shape[0]}. Recomputing freqs_cis."
                )
                self.freqs_cis = precompute_freqs_cis(self.dim // self.n_heads, expected_len * 4, self.rope_theta)
            self.freqs_cis = self.freqs_cis.to(x.device)
            freqs_cis = self.freqs_cis
        else:
            freqs_cis = None

        x = self.input_proj(x)

        new_kv_list = []
        for i, layer in enumerate(self.layers):
            # Collect KV cache if provided
            if kv_cache is not None and start_pos is not None:
                x, new_kv = layer(x, freqs_cis, mask, condition, kv_cache=kv_cache[i], start_pos=start_pos)
                new_kv_list.append(new_kv)
            elif return_kv:
                x, new_kv = layer(x, freqs_cis, mask, condition, return_kv=True)
                new_kv_list.append(new_kv)
            else:
                x = layer(x, freqs_cis, mask, condition)

        # Apply final normalization
        if self.use_adaln_zero:
            x, _ = self.norm(x, condition=condition)  # Final norm doesn't use gate
        else:
            x = self.norm(x)

        output = self.output_proj(x)

        # If using KV cache, return the new KV pairs
        if new_kv_list:
            return output, new_kv_list
        return output
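
# Illustrative usage sketch (not part of the module; hyperparameters are arbitrary examples):
#
#     model = Transformer(
#         dim=512, n_layers=4, n_heads=8,
#         input_dim=80, output_dim=80,
#         max_seq_len=1024, causal=True,
#     ).eval()
#     x = torch.randn(2, 128, 80)                    # (bsz, seqlen, input_dim)
#     y = model(x)                                   # (2, 128, 80)
#
#     # Incremental decoding with a per-layer KV cache:
#     y, kv = model(x, return_kv=True)               # prime the cache on the prefix
#     step = torch.randn(2, 1, 80)
#     y_step, step_kv = model(step, kv_cache=kv, start_pos=128)
#     # step_kv holds only the new token's keys/values per layer; append it to the
#     # running cache before the next decoding step.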