Spaces:
Build error
Build error
| import torch | |
| from torch import nn | |
| from dataclasses import dataclass | |
| from enum import Enum | |
| from typing import * | |
| from math import ceil | |
class AttentionBackend(Enum):
    """Identifies which attention kernel the attention heads dispatch to."""

    # Plain matmul + softmax attention with an additive mask bias.
    Naive = 0
    # Attention computed by `flash_attn_func`.
    FlashAttentionCuda = 1
    # Attention computed by `flash_attn_func_triton`.
    FlashAttentionTriton = 2
# Process-wide settings read at call time by the model code.
# 'attn_backend' selects the attention kernel; it defaults to the naive path.
global_config = {'attn_backend': AttentionBackend.Naive}
@dataclass
class TransformerConfig:
    """Hyper-parameters shared by the sub-modules of the transformer.

    The class is a dataclass so that it can be constructed positionally in
    field order (vocab_size, num_layers, num_heads, hidden_size, max_seq_len,
    root_model, device, dtype).  The integer fields default to -1, meaning
    "not set".
    """

    # Number of entries in the token embedding table / output logits.
    vocab_size: int = -1
    # Number of stacked decoder layers.
    num_layers: int = -1
    # Number of attention heads per layer.
    num_heads: int = -1
    # Width of the residual stream; the per-head size is hidden_size // num_heads.
    hidden_size: int = -1
    # Number of positions covered by the pre-computed RoPE cache.
    max_seq_len: int = -1
    # Back-reference to the owning model; attention heads read its `rope_cache`.
    root_model: Optional['ToyTransformer'] = None
    # Device the model is placed on.
    device: torch.device = torch.device('cpu')
    # Parameter dtype used when creating the layers.
    dtype: torch.dtype = torch.float32
def expand_attn_mask(custom_attn_mask: torch.Tensor):
    """Turn a per-token id mask of shape (B, T) into a boolean (B, T, T) mask.

    Entry ``[b, i, j]`` of the result is True only when all of these hold:
    position ``j`` carries the same id as position ``i``, ``j <= i``, and the
    id at position ``j`` is greater than zero.
    (Presumably ids label packed sequences and 0 marks padding -- confirm
    against the caller.)
    """
    batch, length = custom_attn_mask.shape
    # key_ids[b, i, j] == custom_attn_mask[b, j]
    key_ids = custom_attn_mask.unsqueeze(1).repeat((1, length, 1))
    # query_ids[b, i, 0] == custom_attn_mask[b, i]; broadcasts over j
    query_ids = custom_attn_mask.unsqueeze(2)
    same_id = key_ids == query_ids
    # tril zeroes everything above the diagonal, so "> 0" is both the causal
    # constraint and the non-zero-id constraint.
    causal_nonzero = torch.tril(key_ids) > 0
    return same_id & causal_nonzero
| # expand attn mask to cu_seqlens for flash attn | |
def expand_attn_mask_to_seq_lengths(attn_mask: torch.Tensor):
    """Convert a (B, T) id mask into flattened run-boundary offsets.

    The mask is treated as one flat sequence of B*T elements.  A new run
    starts at the first column of every row and wherever a value differs from
    its left neighbour.  The result is an int32 column tensor of shape
    (num_runs + 1, 1): the start offset of each run followed by B*T.
    """
    attn_mask = attn_mask.to('cpu')
    rows, cols = attn_mask.shape[0], attn_mask.shape[1]
    total = rows * cols
    # Column 0 of every row always opens a run.
    row_starts = torch.tensor([[True]] * rows)
    value_changes = attn_mask[:, 1:] != attn_mask[:, :-1]
    boundaries = torch.cat([row_starts, value_changes], dim=1)
    run_starts = torch.nonzero(boundaries.view((-1,)))
    end_marker = torch.tensor([[total]])
    return torch.cat([run_starts, end_marker]).to(dtype=torch.int32)
| # naive RoPE implementation following https://arxiv.org/pdf/2104.09864.pdf | |
def get_rope_cache_slow(seq_len: int, dim: int, theta: int, device: torch.device, dtype: torch.dtype):
    """Pre-compute the tables used by `apply_rope_slow`.

    Returns a tuple ``(cos, signed_sin, swap)``:
      * ``cos``        -- (seq_len, dim) cosines of position * frequency,
      * ``signed_sin`` -- (seq_len, dim) sines with alternating sign (+, -, +, -, ...),
      * ``swap``       -- (dim,) index tensor [1, 0, 3, 2, ...] that exchanges
                          the two channels of every pair.
    The first two are moved to ``device``/``dtype``; ``swap`` only to ``device``.
    ``dim`` must be even.
    """
    assert dim % 2 == 0
    # One frequency per channel pair, then duplicated so both channels share it.
    pair_freqs = theta ** (-2 * torch.arange(0, dim // 2, 1.) / dim)
    channel_freqs = torch.repeat_interleave(pair_freqs, 2)
    steps = torch.arange(seq_len, dtype=torch.float).view((seq_len, 1))
    angles = steps * channel_freqs
    cos_table = torch.cos(angles)
    signed_sin_table = torch.sin(angles) * torch.tensor([1, -1] * (dim // 2))
    # i ^ 1 flips the lowest bit: 0<->1, 2<->3, ...
    swap = torch.tensor([i ^ 1 for i in range(dim)])
    return cos_table.to(device, dtype=dtype), signed_sin_table.to(device, dtype=dtype), swap.to(device)
def apply_rope_slow(x, rope_cache, positions: Optional[torch.Tensor] = None):
    """Apply rotary position embedding to ``x`` using the "slow" tables.

    Args:
        x: 3-D tensor; dimension 1 is the sequence axis and the last dimension
            is the (even) channel axis.
        rope_cache: the ``(cos, signed_sin, swap)`` tuple, i.e. two
            (max_len, dim) tables and a (dim,) channel-swap index tensor.
        positions: optional integer tensor of table rows to use.  When omitted,
            rows ``0 .. x.shape[1] - 1`` are used.  The selected rows must
            broadcast against ``x`` (e.g. shape (T,) or (B, T)).

    Returns:
        Tensor with the same shape as ``x``.
    """
    cos_table, sin_table, swap = rope_cache
    # Unpacking also enforces that x is 3-D, as the original code did.
    seq_len, _ = x.shape[1:]
    if positions is None:
        cos_rows = cos_table[:seq_len, :]
        sin_rows = sin_table[:seq_len, :]
    else:
        # Select whole rows.  Indexing rows and columns together with two index
        # tensors would broadcast them against each other, which only works for
        # a single position.
        cos_rows = cos_table[positions]
        sin_rows = sin_table[positions]
    # For a channel pair (a, b): a*cos - b*sin and b*cos + a*sin.  The sign is
    # already folded into sin_table and `swap` exchanges the pair.
    return x * cos_rows + (x * sin_rows)[:, :, swap]
| # Optimized RoPE implementation adapted from https://github.com/facebookresearch/llama/blob/main/llama/model.py | |
def get_rope_cache_fast(seq_len: int, dim: int, theta: int, device: torch.device, dtype: torch.dtype):
    """Pre-compute complex rotation factors for `apply_rope_fast`.

    Returns a complex tensor of shape (seq_len, dim // 2) on ``device`` whose
    entry ``[p, k]`` is ``exp(i * p * theta ** (-2k / dim))``.
    ``dtype`` is accepted for signature parity with the slow variant but is
    not used.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    steps = torch.arange(seq_len, device=inv_freq.device)
    angles = torch.outer(steps, inv_freq).float()
    # Unit magnitude + angle -> cos(angle) + i*sin(angle)
    unit = torch.ones_like(angles)
    return torch.polar(unit, angles).to(device)
def apply_rope_fast(x, rope_cache, positions: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Rotate consecutive channel pairs of ``x`` by the cached complex factors.

    Args:
        x: tensor whose dimension 1 is the sequence axis and whose last
            dimension holds an even number of channels.
        rope_cache: complex (max_len, channels // 2) tensor of rotation factors.
        positions: optional integer tensor of cache rows to use; when omitted
            the first ``x.shape[1]`` rows are used.

    Returns:
        The rotated tensor, cast back to the dtype of ``x``.
    """
    # View every (even, odd) channel pair as one complex number.
    channel_pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    x_complex = torch.view_as_complex(channel_pairs)

    if positions is not None:
        rotations = rope_cache[positions, :]
    else:
        # Slicing past the end simply yields the whole cache.
        rotations = rope_cache[:x.shape[1], :]

    # Keep the sequence and pair axes, use size 1 elsewhere so it broadcasts.
    last_axis = x_complex.ndim - 1
    broadcast_shape = [size if axis in (1, last_axis) else 1
                       for axis, size in enumerate(x_complex.shape)]
    rotations = rotations.view(broadcast_shape)

    rotated = torch.view_as_real(x_complex * rotations)
    return rotated.flatten(2).type_as(x)
| # RMSNorm implementation following https://arxiv.org/pdf/1910.07467.pdf | |
class RMSNorm(nn.Module):
    """Root-mean-square layer normalisation with a learnable per-channel gain.

    Each vector along the last dimension is divided by
    ``sqrt(mean(x ** 2) + eps)`` and then multiplied by ``weight``.
    """

    def __init__(self, hidden_size, dtype, eps=1e-6):
        super().__init__()
        # Gain starts at 1 so the layer is a pure normalisation initially.
        self.weight = nn.Parameter(torch.ones(hidden_size, dtype=dtype))
        self.eps = eps

    def forward(self, x: torch.Tensor):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalised = x * inv_rms
        return self.weight * normalised
class AttentionHead(nn.Module):
    """A single self-attention head with rotary position embedding.

    The head projects the input to queries, keys and values of width
    ``hidden_size // num_heads``, rotates queries and keys with the RoPE cache
    stored on ``config.root_model`` and runs the attention kernel selected by
    ``global_config['attn_backend']``.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.head_size = config.hidden_size // config.num_heads
        self.dtype = config.dtype
        # Creation order (q, k, v) is kept so parameter initialisation is unchanged.
        self.q_proj = nn.Linear(config.hidden_size, self.head_size, dtype=config.dtype)
        self.k_proj = nn.Linear(config.hidden_size, self.head_size, dtype=config.dtype)
        self.v_proj = nn.Linear(config.hidden_size, self.head_size, dtype=config.dtype)

    def forward(self, x: torch.Tensor, attn_masked_bias: Optional[torch.Tensor],
                kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Run attention for this head.

        Args:
            x: input activations; dimension 1 is the sequence axis.
            attn_masked_bias: additive bias applied to the attention scores.
                The naive backend requires it; the Triton backend accepts None;
                the CUDA backend ignores it.
            kv_cache: ``[keys, values]`` from earlier steps, or None.

        Returns:
            ``(attention_output, [keys, values])`` where keys/values include
            any cached entries concatenated along dimension 1.
        """
        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # With a cache, the new token sits right after the cached ones, so its
        # position index equals the cached length.
        if kv_cache is None:
            positions = None
        else:
            positions = torch.tensor([kv_cache[0].shape[1]]).to(q.device)

        rope_cache = self.config.root_model.rope_cache
        q = apply_rope_fast(q, rope_cache, positions)
        k = apply_rope_fast(k, rope_cache, positions)

        if kv_cache is not None:
            past_k, past_v = kv_cache[0], kv_cache[1]
            k = torch.concat([past_k, k], dim=1)
            v = torch.concat([past_v, v], dim=1)

        backend = global_config['attn_backend']
        if backend == AttentionBackend.FlashAttentionCuda:
            # The flash kernels are called with an extra axis inserted at dim 2.
            attn_result = flash_attn_func(
                q.unsqueeze(2), k.unsqueeze(2), v.unsqueeze(2), causal=True
            ).squeeze(2)
        elif backend == AttentionBackend.FlashAttentionTriton:
            bias = None if attn_masked_bias is None else attn_masked_bias.unsqueeze(1)
            # Causal masking is only requested when no cache is in use.
            attn_result = flash_attn_func_triton(
                q.unsqueeze(2), k.unsqueeze(2), v.unsqueeze(2), bias, kv_cache is None
            ).squeeze(2)
        else:
            scale = self.head_size ** 0.5
            attn_score = (q @ k.permute(0, 2, 1) / scale) + attn_masked_bias
            attn_result = torch.softmax(attn_score, dim=2) @ v

        return attn_result, [k, v]
class MultiHeadAttention(nn.Module):
    """Runs ``num_heads`` independent attention heads and mixes their outputs."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.attn_heads = nn.ModuleList([AttentionHead(config) for _ in range(config.num_heads)])
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, dtype=config.dtype)

    def forward(self, x: torch.Tensor, attn_masked_bias: Optional[torch.Tensor],
                kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[List[torch.Tensor]]]:
        """Apply every head to ``x`` and project the concatenated result.

        ``kv_cache`` (when given) is indexed by head.  The second return value
        is the list of per-head ``[keys, values]`` caches in head order.
        """
        outputs = []
        caches = []
        for idx, head in enumerate(self.attn_heads):
            head_cache = None if kv_cache is None else kv_cache[idx]
            head_out, head_kv = head(x, attn_masked_bias, head_cache)
            outputs.append(head_out)
            caches.append(head_kv)
        # Heads are concatenated along the channel axis before the output projection.
        merged = torch.concat(outputs, dim=2)
        return self.o_proj(merged), caches
class DecoderLayer(nn.Module):
    """Pre-norm decoder block: attention and a 4x-wide GELU feed-forward,
    each wrapped in a residual connection."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        # Sub-module creation order is kept so parameter initialisation is unchanged.
        self.mha = MultiHeadAttention(config)
        self.up_proj = nn.Linear(config.hidden_size, config.hidden_size * 4, dtype=config.dtype)
        self.down_proj = nn.Linear(config.hidden_size * 4, config.hidden_size, dtype=config.dtype)
        self.ln_mha = nn.LayerNorm(config.hidden_size, dtype=config.dtype)
        self.ln_ffn = nn.LayerNorm(config.hidden_size, dtype=config.dtype)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor, attn_masked_bias: Optional[torch.Tensor],
                kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[List[torch.Tensor]]]:
        """Return ``(block_output, new_kv_cache)`` for this layer."""
        # Attention sub-block with residual.
        attn_out, new_kv_cache = self.mha(self.ln_mha(x), attn_masked_bias, kv_cache)
        after_attn = x + attn_out

        # Feed-forward sub-block with residual.
        widened = self.up_proj(self.ln_ffn(after_attn))
        ffn_out = self.down_proj(self.act(widened))
        return after_attn + ffn_out, new_kv_cache
class ToyTransformer(nn.Module):
    """Decoder-only transformer language model.

    Token embedding -> ``num_layers`` decoder layers -> linear LM head.
    Positional information comes from the RoPE cache stored in
    ``self.rope_cache`` (a plain tensor attribute, created on ``device``).
    """

    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, hidden_size: int, max_seq_len: int,
                 device: torch.device = torch.device('cpu'), dtype: torch.dtype = torch.float32):
        super().__init__()
        # The config carries a back-reference to this model so that attention
        # heads can reach `rope_cache`.
        self.config = TransformerConfig(vocab_size, num_layers, num_heads, hidden_size, max_seq_len, self, device,
                                        dtype)

        self.sem_embed = nn.Embedding(vocab_size, hidden_size, dtype=dtype)
        self.rope_cache = get_rope_cache_fast(max_seq_len, hidden_size // num_heads, 10000, device, dtype)
        self.decoder_layers = nn.ModuleList([DecoderLayer(self.config) for _ in range(num_layers)])
        self.lm_head = nn.Linear(hidden_size, vocab_size, dtype=dtype)

        self.to(device)

    def forward(self, seq: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None,
                kv_cache: Optional[List[torch.Tensor]] = None) -> Tuple[torch.Tensor, List[List[List[torch.Tensor]]]]:
        """Compute next-token logits.

        Args:
            seq: integer token ids of shape (batch, seq_len).
            attn_mask: optional (batch, seq_len) id mask understood by
                `expand_attn_mask`.  Cannot be combined with ``kv_cache``.
            kv_cache: optional cache returned by a previous call, indexed as
                ``kv_cache[layer][head] == [keys, values]``.  Requires batch size 1.

        Returns:
            ``(logits, new_kv_cache)``.

        Raises:
            AssertionError: if ``attn_mask`` and ``kv_cache`` are both given,
                if ``kv_cache`` is used with batch size != 1, or if a custom
                ``attn_mask`` is used with the FlashAttention CUDA backend.
        """
        # sanity checks
        assert attn_mask is None or kv_cache is None  # No support for attn_mask and kv_cache both enabled
        if kv_cache is not None:
            assert seq.shape[0] == 1, 'kv_cache is not supported for batch inference'

        backend = global_config['attn_backend']
        device = self.device

        # handle flash-attn triton alignment requirement (actually only needed for backward)
        seq_length = seq.shape[1]
        needs_padding = (kv_cache is None
                         and backend == AttentionBackend.FlashAttentionTriton
                         and seq_length % 128 != 0)
        if needs_padding:
            if attn_mask is None:  # forcibly enable attn_mask due to padding
                attn_mask = torch.ones(seq.shape, device=device)
            pad_length = (ceil(seq_length / 128) * 128) - seq_length
            # Padded mask entries are 0, which expand_attn_mask treats as masked.
            seq = nn.functional.pad(seq, (0, pad_length))
            attn_mask = nn.functional.pad(attn_mask, (0, pad_length))

        # handle attn_bias: pick a boolean "may attend" mask (or None)
        if backend == AttentionBackend.FlashAttentionCuda:
            assert attn_mask is None, 'FlashAttn-Cuda does not support custom attn_mask'
            attn_masked_bias = None
        elif backend == AttentionBackend.FlashAttentionTriton and attn_mask is None:
            attn_masked_bias = None
        elif attn_mask is not None:
            attn_masked_bias = expand_attn_mask(attn_mask)
        elif kv_cache is None:
            # No custom mask: an all-ones id mask yields a plain causal mask.
            attn_masked_bias = expand_attn_mask(torch.ones(seq.shape, device=device))
        else:
            # Decoding with a cache: the new token may attend to everything.
            attn_masked_bias = torch.ones((1, seq.shape[1], seq.shape[1]), dtype=torch.bool, device=device)

        if attn_masked_bias is not None:
            # Convert the boolean mask into an additive bias: 0 where allowed,
            # a large negative value (half of the dtype minimum) where masked.
            mask_zero = torch.tensor(0, dtype=self.config.dtype)
            mask_val = torch.tensor(torch.finfo(self.config.dtype).min / 2, dtype=self.config.dtype)
            attn_masked_bias = torch.where(attn_masked_bias, mask_zero, mask_val).to(device)

        hidden = self.sem_embed(seq)

        new_kv_cache = []
        for idx, decoder in enumerate(self.decoder_layers):
            layer_cache = kv_cache[idx] if kv_cache is not None else None
            hidden, layer_kv_cache = decoder(hidden, attn_masked_bias, layer_cache)
            new_kv_cache.append(layer_kv_cache)

        logits = self.lm_head(hidden)

        # remove padding for flash-attn triton
        if needs_padding:
            logits = logits[:, :seq_length, :]
            new_kv_cache = [[[cache[:, :seq_length, :] for cache in head] for head in layer] for layer in new_kv_cache]

        return logits, new_kv_cache

    @property
    def device(self):
        """Device of the model's first parameter.

        Exposed as a property because `forward` passes ``self.device`` straight
        to tensor constructors as the ``device=`` argument.
        """
        return next(self.parameters()).device