Commit 43539ed
Parent(s): 8182ebe

Copy Python verbatim from vortex

Files changed:
- attention.py +999 -0
- cache.py +62 -0
- engine.py +597 -0
- generation.py +373 -0
- layers.py +272 -0
- model.py +937 -0
- positional_embeddings.py +114 -0
- sample.py +60 -0
- special_tokens_map.json +1 -0
- utils.py +251 -0
attention.py
ADDED
@@ -0,0 +1,999 @@
import math
from functools import partial

import torch
import torch.nn as nn
from einops import rearrange, repeat

from .utils import get_dim_for_local_rank

# Not bothering with ops right now
# try:
#     from vortex.ops import (
#         local_flash_attn_kvpacked_func,
#         local_flash_attn_qkvpacked_func,
#         local_flash_attn_varlen_kvpacked_func,
#         local_flash_attn_varlen_qkvpacked_func,
#         local_flash_attn_with_kvcache,
#     )
# except ImportError:
#     local_flash_attn_varlen_qkvpacked_func, local_flash_attn_varlen_kvpacked_func = (
#         None,
#         None,
#     )
#     local_flash_attn_qkvpacked_func, local_flash_attn_kvpacked_func = None, None
#     local_flash_attn_with_kvcache = None

local_flash_attn_varlen_qkvpacked_func, local_flash_attn_varlen_kvpacked_func = (
    None,
    None,
)
local_flash_attn_qkvpacked_func, local_flash_attn_kvpacked_func = None, None
local_flash_attn_with_kvcache = None

FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None

from .rotary import RotaryEmbedding


# From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
def get_alibi_slopes(nheads):
    def get_slopes_power_of_2(nheads):
        start = 2 ** (-(2 ** -(math.log2(nheads) - 3)))
        ratio = start
        return [start * ratio**i for i in range(nheads)]

    if math.log2(nheads).is_integer():
        return get_slopes_power_of_2(nheads)
    else:
        closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
        return (
            get_slopes_power_of_2(closest_power_of_2)
            + get_alibi_slopes(2 * closest_power_of_2)[0::2][: nheads - closest_power_of_2]
        )


class FlashSelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(
        self,
        layer_number,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=False,
    ):
        super().__init__()
        assert local_flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
        assert local_flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
        self.layer_number = layer_number
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into qkv.
            max_seqlen: int. Maximum sequence length in the batch.
        Returns:
        --------
            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
                else (B, S, H, D).
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda

        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if self.alibi_slopes is not None:
            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            return local_flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens,
                max_seqlen,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            y = local_flash_attn_qkvpacked_func(
                qkv,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )

            return y


class FlashCrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        alibi_slopes=None,
        window_size=(-1, -1),
        deterministic=False,
    ):
        super().__init__()
        assert local_flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
        assert local_flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(
        self,
        q,
        kv,
        causal=None,
        cu_seqlens=None,
        max_seqlen=None,
        cu_seqlens_k=None,
        max_seqlen_k=None,
    ):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
            max_seqlen: int. Maximum sequence length in the batch of q.
            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
            max_seqlen_k: int. Maximum sequence length in the batch of k and v.
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if self.alibi_slopes is not None:
            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            assert cu_seqlens_k is not None
            assert cu_seqlens_k.dtype == torch.int32
            assert max_seqlen_k is not None
            assert isinstance(max_seqlen_k, int)
            return local_flash_attn_varlen_kvpacked_func(
                q,
                kv,
                cu_seqlens,
                cu_seqlens_k,
                max_seqlen,
                max_seqlen_k,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            batch_size, seqlen_q = q.shape[0], q.shape[1]
            seqlen_k = kv.shape[1]
            assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
            return local_flash_attn_kvpacked_func(
                q,
                kv,
                self.drop.p if self.training else 0.0,
                causal=causal,
                softmax_scale=self.softmax_scale,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )


class SelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, S)
        """
        q, k, v = qkv.unbind(dim=2)  # each: (B, T, H, D)
        q = q.permute(0, 2, 1, 3)  # (B, H, T, D)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)
        batch_size, num_heads, seqlen, d = q.shape

        scale = self.softmax_scale if self.softmax_scale is not None else 1.0 / math.sqrt(d)
        q = q * (scale * math.sqrt(d))

        attn_mask = None
        if key_padding_mask is not None:
            attn_mask = torch.where(
                repeat(key_padding_mask, "b s -> b t s", t=seqlen),
                0.0,
                -10000.0,
            )

        output = torch.nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            dropout_p=self.drop.p if self.training else 0.0,
            is_causal=(self.causal if causal is None else causal),
        )

        output = output.permute(0, 2, 1, 3)
        return output


class CrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, q, kv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, Sk)
        """
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        causal = self.causal if causal is None else causal
        seqlen_k = kv.shape[1]
        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
        if kv.shape[3] != q.shape[2]:  # MQA/GQA
            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
        k, v = kv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen_k),
                -10000.0,
                dtype=scores.dtype,
                device=scores.device,
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # causal mask needs to take into account the difference between seqlen_q and seqlen_k
            row_idx = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1")
            col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
            sk = seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
            causal_mask = col_idx > row_idx + sk - seqlen_q
            scores = scores.masked_fill(causal_mask, -10000.0)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output


class LinearResidual(nn.Linear):
    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return super().forward(input), input


def _update_kv_cache(kv, inference_params, layer_idx):
    """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
    # Pre-allocate memory for key-values for inference.
    num_heads, head_dim = kv.shape[-2:]
    if layer_idx not in inference_params.key_value_memory_dict:
        kv_cache = torch.empty(
            inference_params.max_batch_size,
            inference_params.max_seqlen,
            2,
            num_heads,
            head_dim,
            dtype=kv.dtype,
            device=kv.device,
        )
        inference_params.key_value_memory_dict[layer_idx] = kv_cache
    else:
        kv_cache = inference_params.key_value_memory_dict[layer_idx]
    # Adjust key and value for inference
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + kv.shape[0]
    sequence_start = inference_params.seqlen_offset
    sequence_end = sequence_start + kv.shape[1]
    assert batch_end <= kv_cache.shape[0]
    assert sequence_end <= kv_cache.shape[1]
    assert kv_cache is not None
    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
    return kv_cache[batch_start:batch_end, :sequence_end, ...]


class MHA(nn.Module):
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        num_heads_kv=None,
        cross_attn=False,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        dwconv=False,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_alibi=False,
        window_size=(-1, -1),
        fused_bias_fc=False,
        use_flash_attn=False,
        return_residual=False,
        checkpointing=False,
        device=None,
        dtype=None,
    ) -> None:
        """
        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.return_residual = return_residual
        self.checkpointing = checkpointing
        if use_alibi:
            assert use_flash_attn, "ALiBi code path requires flash_attn"
            alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
        else:
            alibi_slopes = None
        if window_size != (-1, -1):
            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert self.num_heads % self.num_heads_kv == 0, "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv

        if self.rotary_emb_dim > 0:
            assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
        inner_attn_cls = (
            partial(
                FlashSelfAttention,
                layer_number=self.layer_idx,
                alibi_slopes=alibi_slopes,
                window_size=window_size,
            )
            if use_flash_attn
            else SelfAttention
        )
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else CrossAttention
        )
        if not self.cross_attn:
            self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
        else:
            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
            self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
        if self.dwconv:
            if self.num_heads_kv == self.num_heads:
                self.dwconv_qkv = nn.Conv1d(qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim)
            else:
                self.dwconv_q = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim)
                self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
        self.inner_attn = inner_attn_cls(
            causal=causal,
            softmax_scale=softmax_scale,
            attention_dropout=dropout,
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        return torch.empty(
            batch_size,
            max_seqlen,
            2,
            self.num_heads_kv,
            self.head_dim,
            dtype=dtype,
            device=device,
        )

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert not self.dwconv, "Generation does not support dwconv yet"
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
        """
        assert inference_params is not None and inference_params.seqlen_offset > 0
        assert self.use_flash_attn
        if self.rotary_emb_dim > 0:
            assert self.rotary_emb.scale is None, "This code path does not support xPos"
            self.rotary_emb._update_cos_sin_cache(inference_params.max_seqlen, device=q.device, dtype=q.dtype)
            rotary_cos, rotary_sin = (
                self.rotary_emb._cos_cached,
                self.rotary_emb._sin_cached,
            )
        else:
            rotary_cos, rotary_sin = None, None
        batch = q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
        context = local_flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
            alibi_slopes=alibi_slopes,
        )
        return context

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention"""
        if inference_params.seqlen_offset == 0 or local_flash_attn_with_kvcache is None or not self.use_flash_attn:
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            batch = q.shape[0]
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
            return local_flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_cross_attn.softmax_scale,
                causal=self.inner_cross_attn.causal,
                alibi_slopes=alibi_slopes,
            )

    def forward(
        self,
        x,
        x_kv=None,
        key_padding_mask=None,
        cu_seqlens=None,
        max_seqlen=None,
        mixer_subset=None,
        inference_params=None,
        **kwargs,
    ):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the is the sum of the sequence lengths in the batch.
            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into x. Only applicable when using
                FlashAttention.
            max_seqlen: int. Maximum sequence length in the batch.
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen). Only applicable when not using FlashAttention.
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
            https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        """
        if cu_seqlens is not None:
            assert max_seqlen is not None
            assert key_padding_mask is None
            assert self.use_flash_attn
            assert not self.dwconv
            assert self.rotary_emb_dim == 0
        if key_padding_mask is not None:
            assert cu_seqlens is None
            assert max_seqlen is None
            assert not self.use_flash_attn
        if inference_params is not None:
            assert key_padding_mask is None
            assert cu_seqlens is None and max_seqlen is None
            assert not self.dwconv

        kwargs = (
            {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
            if self.use_flash_attn
            else {"key_padding_mask": key_padding_mask, **kwargs}
        )
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
        )
        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
        batch, seqlen = x.shape[:2]
        if not self.cross_attn and self.num_heads_kv == self.num_heads:
            assert x_kv is None and mixer_subset is None
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            if self.dwconv:
                qkv = rearrange(
                    self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2],
                    "b d s -> b s d",
                ).contiguous()
            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen)
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    context = self._update_kvcache_attention(qkv[:, :, 0], qkv[:, :, 1:], inference_params)
            else:
                context = self._apply_rotary_update_kvcache_attention(qkv[:, :, 0], qkv[:, :, 1:], inference_params)
        else:
            if self.cross_attn:
                if not self.return_residual:
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                    kv = self.Wkv(x_kv if x_kv is not None else x)
                else:
                    if x_kv is not None:
                        kv, x_kv = self.Wkv(x_kv)
                    else:
                        kv, x = self.Wkv(x)
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
            else:
                assert self.num_heads_kv != self.num_heads
                if not self.return_residual:
                    qkv = self.Wqkv(x)
                else:
                    qkv, x = self.Wqkv(x)
                q = qkv[..., : self.num_heads * self.head_dim]
                kv = qkv[..., self.num_heads * self.head_dim :]
            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
            if self.dwconv:
                q = rearrange(
                    self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2],
                    "b d s -> b s d",
                ).contiguous()
                kv = rearrange(
                    self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2],
                    "b d s -> b s d",
                ).contiguous()
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen)
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_cross_attn, q, kv, **kwargs)
                else:
                    context = self._update_kvcache_attention(q, kv, inference_params)
            else:
                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
        return out if not self.return_residual else (out, x)


class ParallelMHA(nn.Module):
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        process_group,
        num_heads_kv=None,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_alibi=False,
        window_size=(-1, -1),
        use_flash_attn=False,
        checkpointing=False,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal
        self.layer_idx = layer_idx
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.checkpointing = checkpointing
        self.process_group = process_group
        self.world_size = process_group.size()
        self.local_rank = torch.distributed.get_rank(process_group)

        self.num_heads = num_heads
        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"

        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert self.num_heads % self.num_heads_kv == 0, "num_heads must be divisible by num_heads_kv"

        self.num_heads_per_rank = get_dim_for_local_rank(self.num_heads, self.world_size, self.local_rank)
        self.num_heads_kv_per_rank = get_dim_for_local_rank(self.num_heads_kv, self.world_size, self.local_rank)
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)

        if use_alibi:
            assert use_flash_attn, "ALiBi code path requires flash_attn"
            num_heads_local = math.ceil(self.num_heads / self.world_size)
            alibi_slopes = torch.tensor(
                get_alibi_slopes(num_heads)[
                    self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local
                ],
                device=device,
            )
        else:
            alibi_slopes = None
        if window_size != (-1, -1):
            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

        if self.rotary_emb_dim > 0:
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        self.Wqkv = ColumnParallelLinear(
            embed_dim,
            qkv_dim,
            process_group,
            bias=qkv_proj_bias,
            sequence_parallel=sequence_parallel,
            multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
            **factory_kwargs,
        )
        inner_attn_cls = (
            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else SelfAttention
        )
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else CrossAttention
        )
        self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            process_group,
            bias=out_proj_bias,
            sequence_parallel=sequence_parallel,
            multiple_of=self.head_dim,
            **factory_kwargs,
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        return torch.empty(
            batch_size,
            max_seqlen,
            2,
            self.num_heads_kv_per_rank,
            self.head_dim,
            dtype=dtype,
            device=device,
        )

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
        """
        assert inference_params is not None and inference_params.seqlen_offset > 0
        assert self.use_flash_attn
        if self.rotary_emb_dim > 0:
            assert self.rotary_emb.scale is None, "This code path does not support xPos"
            self.rotary_emb._update_cos_sin_cache(inference_params.max_seqlen, device=q.device, dtype=q.dtype)
            rotary_cos, rotary_sin = (
                self.rotary_emb._cos_cached,
                self.rotary_emb._sin_cached,
            )
        else:
            rotary_cos, rotary_sin = None, None
        batch = q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
        context = local_flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
            alibi_slopes=alibi_slopes,
        )
        return context

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention"""
        if inference_params.seqlen_offset == 0 or not self.use_flash_attn:
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            batch = q.shape[0]
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
            context = local_flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_cross_attn.softmax_scale,
                causal=self.inner_cross_attn.causal,
                alibi_slopes=alibi_slopes,
            )
            return context

    def forward(self, x, seqlen=None, inference_params=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
                If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
                split x during sequence parallel, we split the batch * seqlen dimension
                (in case batch is small).
        """
        qkv = self.Wqkv(x)
        if seqlen is not None:
            qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
        )
        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
        if self.num_heads_kv == self.num_heads:
            qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen)
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    context = self._update_kvcache_attention(qkv[:, :, 0], qkv[:, :, 1:], inference_params)
            else:
                context = self._apply_rotary_update_kvcache_attention(qkv[:, :, 0], qkv[:, :, 1:], inference_params)
        else:
            q = rearrange(
                qkv[..., : self.num_heads_per_rank * self.head_dim],
                "... (h d) -> ... h d",
                d=self.head_dim,
            )
            kv = rearrange(
                qkv[..., self.num_heads_per_rank * self.head_dim :],
                "... (two hkv d) -> ... two hkv d",
                two=2,
                d=self.head_dim,
            )
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen)
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_cross_attn, q, kv, **kwargs)
                else:
                    context = self._update_kvcache_attention(q, kv, inference_params)
            else:
                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
        context = rearrange(context, "b s h d -> b s (h d)")
        if seqlen is not None:
            context = rearrange(context, "b s d -> (b s) d")
        out = self.out_proj(context)
        return out
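A minimal usage sketch (not part of the diff above), assuming the MHA class defined in this attention.py is in scope. It exercises the non-flash path (use_flash_attn=False), which routes through the SelfAttention module built on torch's scaled_dot_product_attention, so it runs even though this file stubs the local flash-attn kernels to None.

import torch

# Hypothetical standalone check of the copied MHA module (non-flash path).
mha = MHA(embed_dim=512, num_heads=8, causal=True, use_flash_attn=False)
x = torch.randn(2, 16, 512)   # (batch, seqlen, hidden_dim)
y = mha(x)                    # projected attention output, same shape as x
print(y.shape)                # torch.Size([2, 16, 512])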
cache.py
ADDED
@@ -0,0 +1,62 @@
# Copied verbatim from vortex
# Copyright (c) 2024, Michael Poli.


from dataclasses import dataclass, field
from typing import Optional

from torch import Tensor


# https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py
@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficienly calculate and store the context during inference."""

    max_seqlen: int
    max_batch_size: int
    seqlen_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)
    lengths_per_sample: Optional[Tensor] = None

    def reset(self, max_seqlen, max_batch_size):
        self.max_seqlen = max_seqlen
        self.max_batch_size = max_batch_size
        self.seqlen_offset = 0
        if self.lengths_per_sample is not None:
            self.lengths_per_sample.zero_()


@dataclass
class HyenaCascadeIIRInferenceParams:
    """Inference parameters passed to long Hyena blocks with recurrent mode."""

    fir_filter_length: int = 3
    state_dim: int = 16
    seqlen_offset: int = 0
    fir_state_dict: dict = field(default_factory=dict)
    state_dict: dict = field(default_factory=dict)

    def reset(self):
        self.fir_filter_length = 3
        self.state_dim = 16
        self.seqlen_offset = 0


@dataclass
class HyenaCascadeFIRInferenceParams:
    """Inference parameters passed to short and medium Hyena blocks."""

    fir_filter_length: int = 3
    fir_inner_filter_length: int = 4
    seqlen_offset: int = 0
    fir_inner_state_dict: dict = field(default_factory=dict)
    fir_state_dict: dict = field(default_factory=dict)
    state_dict: dict = field(default_factory=dict)

    def reset(self):
        self.fir_filter_length = 3
        self.fir_inner_filter_length = 4
        self.seqlen_offset = 0
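A minimal sketch (not part of the diff above) of how InferenceParams is typically threaded through incremental decoding, assuming the dataclass defined in cache.py is in scope. The per-layer cache shape shown in the comment follows MHA.allocate_inference_cache from attention.py; the concrete sizes are illustrative only.

import torch

params = InferenceParams(max_seqlen=128, max_batch_size=2)

# Per-layer KV buffers live in key_value_memory_dict, keyed by layer_idx;
# shape follows MHA.allocate_inference_cache: (batch, seqlen, 2, heads_kv, head_dim).
params.key_value_memory_dict[0] = torch.empty(2, 128, 2, 8, 64)

params.seqlen_offset += 16   # after prefilling a 16-token prompt
for _ in range(4):           # one decode step per generated token
    params.seqlen_offset += 1

params.reset(max_seqlen=128, max_batch_size=2)   # reuse for the next request
print(params.seqlen_offset)  # 0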
engine.py
ADDED
@@ -0,0 +1,597 @@
# Copied verbatim from vortex
# Copyright (c) 2024, Michael Poli.

import gc

import torch
import torch.nn.functional as F

try:
    pass
except:
    pass
from .utils import column_split
from .rich_logging import activations_logger

IIR_PREFILL_MODES = [
    "recurrence",
    "modal-fft",
    "hybrid-modal-recurrence",
    "modal-scan",
    "canonical-fft",
    "iir-fir-caching",
]


def adjust_filter_shape_for_broadcast(u, h):
    h = h.squeeze()  # Standardize to [D, L] from [1, D, L] and [D, 1, L]

    # Case: u: [B, D, L], k_f: [D, L]
    if len(u.shape) > len(h.shape):
        h = h.unsqueeze(0)

    # Case: u: [B, D1, D2, L], k_f: [B, D, L]
    if len(u.shape) > 3:
        h = h.unsqueeze(1)
    return h


def fftconv_func(
    u,
    k,
    D,
    dropout_mask,
    gelu=True,
    k_rev=None,
    bidirectional=False,
    print_activations=False,
    layer_idx=None,
    **kwargs,
):
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen

    k_f = torch.fft.rfft(k, n=fft_size) / fft_size
    k_f = adjust_filter_shape_for_broadcast(u, k_f)
    k = k.squeeze()

    if bidirectional:
        u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)
        k, k2 = k.split(k.shape[1] // 2, dim=1)
        k2_f = torch.fft.rfft(k2, n=fft_size) / fft_size
        y1 = u_f * k_f
        y2 = u_f.conj() * k2_f.conj()

        y = torch.fft.irfft(y1 + y2, n=fft_size, norm="forward")[..., :seqlen]

    else:
        if k_rev is not None:
            k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size
            k_f = k_f + k_rev_f.conj()

        u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)

        y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen]

    if print_activations:
        activations_logger.info(f"post fftconv pre bias {y} {y.min()} {y.max()}")

    out = y + u * D.unsqueeze(-1)

    if print_activations:
        activations_logger.info(f"post fftconv post bias {out} {out.min()} {out.max()}")

    return out.to(dtype=u.dtype)


def canonicalize_modal_system(poles, residues):
    """Canonicalize a modal system.

    Args:
        poles (Tensor): The poles of the system.
        residues (Tensor): The residues of the system.

    Returns:
        Tuple[Tensor, Tensor]: The canonicalized poles and residues.
    """
    raise NotImplementedError


def list_tensors(idx):
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) and isinstance(obj, torch.Tensor):
                # dump to log
                print(type(obj), obj.size())
                el = obj[0]
                with open(f"tensors_{idx}.txt", "a") as f:
                    f.write(f"{type(obj)} {obj.size()} {el}\n")
        except Exception:
            pass


class HyenaInferenceEngine:
    def __init__(
        self,
        fir_fn=None,
        iir_prefill_style="modal-fft",
        layer_idx=None,
        ground_truth_activations_path=None,
        print_activations=False,
        hyena_flip_x1x2=False,
    ) -> None:
        self.fir_fn = fir_fn
        assert iir_prefill_style in IIR_PREFILL_MODES, f"iir_prefill_style must be one of {IIR_PREFILL_MODES}"
        self.iir_prefill_style = iir_prefill_style
        self.layer_idx = layer_idx
        self.low_mem_mode = False
        self.ground_truth_activations_path = ground_truth_activations_path
        self.print_activations = print_activations
        self.hyena_flip_x1x2 = hyena_flip_x1x2

    def parallel_fir(
        self,
        fir_fn,
        u,
        weight,
        bias,
        L,
        dims,
        groups=None,
        gated_bias=False,
        column_split_hyena=False,
        dim_last=True,
        fir_length=3,
        gate=False,
        inference_params=None,
        prefill_mode=None,
        padding_mask=None,
    ):
        L = u.shape[1] if dim_last else u.shape[2]
        if gate:
            hidden_size, num_attention_heads, hidden_size_per_attention_head, _, _ = dims
            # Compatibility with training infra that column splits the projections
            if column_split_hyena:
                x2, x1, v = column_split(u, num_attention_heads, hidden_size_per_attention_head)
            else:
                x2, x1, v = u.split([hidden_size, hidden_size, hidden_size], dim=1)
            if self.hyena_flip_x1x2:
                x1, x2 = x2, x1
            u = x1 * v

            if self.print_activations:
                activations_logger.info(f"q: {x2}, {x2.min()}, {x2.max()}")
                activations_logger.info(f"k: {x1}, {x1.min()}, {x1.max()}")
                activations_logger.info(f"v: {v}, {v.min()}, {v.max()}")
                activations_logger.info(f"post pregate: {u}, {u.min()}, {u.max()}")

        # prepare input layout, dimensions and dispatch to fir kernel
        # Deprecated
        if fir_fn != torch.nn.functional.conv1d:
            if dim_last:
                u = u.permute(0, 2, 1)  # B, D, L
            z = fir_fn(u)[:, :L]  # B, L, D

        elif fir_length >= 128:
            with torch.autocast("cuda"):
                z = fftconv_func(
                    u.to(torch.float32),
                    weight[:, :, :L].to(torch.float32),
                    bias,
                    None,
                    gelu=False,
                    bidirectional=False,
                    print_activations=self.print_activations,
                    groups=groups,
                    layer_idx=self.layer_idx,
                )
            z = z.to(u.dtype)
        else:
            if dim_last:
                u = u.permute(0, 2, 1)  # B, D, L

            if groups is None:
                g = u.shape[1]
            else:
                g = groups

            z = fir_fn(
                u.to(torch.float32),
                weight.to(torch.float32),
                bias=None,
                stride=1,
                padding=fir_length - 1,
                groups=u.shape[1],  # always set to D, regardless of filter grouping
            )[..., :L]
            if self.print_activations:
                activations_logger.info(f"post filter: {z}, {z.min()}, {z.max()}")

            z = z.to(u.dtype)

        if gated_bias is False:
            if self.print_activations:
                activations_logger.info(f"post dw conv {z} {z.min()} {z.max()}")
                # if self.ground_truth_activations_path:
                #     z_savanna = torch.load(f"{self.ground_truth_activations_path}/post_dw_conv_{self.layer_idx}.pt")
                #     z_savanna = z_savanna.permute(1, 2, 0)
                #     z_diff = (z.squeeze() - z_savanna.squeeze()).abs().max()
                #     activations_logger.info(f"dw_conv_diff: {z_diff}")

        if bias is not None:
            if gated_bias:
                z = z + bias[None, :, None] * u
            else:
                z = z + bias[None, :, None]

        # handle padding post fir, the only place with biases
        if type(padding_mask) == torch.Tensor:
            z = z * padding_mask[:, None]

        if gate:
            # if self.layer_idx == 1:
            #     breakpoint()
            z = x2 * z

        if self.print_activations:
            activations_logger.info(f"hyena filter: {weight}, {weight.min()}, {weight.max()}")
            activations_logger.info(f"post postgate: {z}, {z.min()}, {z.max()}")
            # if self.ground_truth_activations_path:
            #     q_savanna = torch.load(f"{self.ground_truth_activations_path}/q_{self.layer_idx}.pt")
            #     k_savanna = torch.load(f"{self.ground_truth_activations_path}/k_{self.layer_idx}.pt")
            #     v_savanna = torch.load(f"{self.ground_truth_activations_path}/v_{self.layer_idx}.pt")

            #     q_diff = (x2 - q_savanna).abs()
            #     k_diff = (x1 - k_savanna).abs()
            #     v_diff = (v - v_savanna).abs()

            #     activations_logger.info(f"q_diff: {q_diff.max()}, {q_diff.mean()}")
            #     activations_logger.info(f"k_diff: {k_diff.max()}, {k_diff.mean()}")
            #     activations_logger.info(f"v_diff: {v_diff.max()}, {v_diff.mean()}")

            #     h_savanna = torch.load(f"/home/zymrael/checkpoints/evo2/activations/savanna/hyena_filter_{self.layer_idx}.pt")
            #     h_diff = (weight[..., :h_savanna.shape[-1]].squeeze() - h_savanna.squeeze()).abs()

            #     activations_logger.info(f"h_diff: {h_diff.max()}, {h_diff.mean()}")

        if inference_params is not None:
            fir_state = u[..., -fir_length + 1 :]
        else:
            fir_state = None

        return z, fir_state

    def parallel_iir(
        self,
        z_pre,
        h,
        D,
        L,
        poles,
        residues,
        t,
        dims,
        layer_idx,
        inference_params=None,
        prefill_style="fft",
        fftconv_fn=None,
        padding_mask=None,
        use_flashfft=False,
        column_split_hyena=False,
        long_fir_threshold=None,
    ):
        """Compute the output state of the short convolutional filter."""
        fft_size = 2 * L
        hidden_size, num_attention_heads, hidden_size_per_attention_head, _, _ = dims
        # Compatibility with training infra that column splits the projections
        if column_split_hyena:
            z = z_pre.reshape(
                z_pre.shape[0],
                num_attention_heads,
                3 * hidden_size_per_attention_head,
                z_pre.shape[2],
            )
            x2, x1, v = (
                z[:, :, :hidden_size_per_attention_head],
                z[
                    :,
                    :,
                    hidden_size_per_attention_head : 2 * hidden_size_per_attention_head,
                ],
                z[:, :, 2 * hidden_size_per_attention_head :],
            )
            x2, x1, v = (
                x2.reshape(x2.shape[0], -1, x2.shape[-1]),
                x1.reshape(x1.shape[0], -1, x1.shape[-1]),
                v.reshape(v.shape[0], -1, v.shape[-1]),
            )
        else:
            x2, x1, v = z_pre.split([hidden_size, hidden_size, hidden_size], dim=1)

        if self.hyena_flip_x1x2:
            x1, x2 = x2, x1

        x1v = x1 * v

        if inference_params is not None and prefill_style == "recurrence":
            y = self.prefill_via_direct_recurrence(
                inference_params=inference_params,
                x1v=x1v,
                L=L,
                poles=poles,
                residues=residues,
            )

        else:
            if use_flashfft and (L % 2) == 0:  # only works with even L
                y = fftconv_fn(
                    x1v.to(dtype=torch.bfloat16).contiguous(),
                    h.to(dtype=torch.float32),
                )
                X_s = None

            elif long_fir_threshold is None:
                H = torch.fft.rfft(h.to(dtype=torch.float32), n=fft_size) / fft_size
                X_s = torch.fft.fft(x1v.to(dtype=torch.float32), n=fft_size)
                X = X_s[..., : H.shape[-1]]
                if len(z_pre.shape) > 3:
                    H = H.unsqueeze(1)
                y = torch.fft.irfft(X * H, n=fft_size, norm="forward")[..., :L]

            else:
                assert h.shape[0] == 1, "batch size must be 1 for long_fir_threshold"
                h = h[0][:, None]  # rearrange to d, 1, l for depthwise conv1d
                h = h[..., :long_fir_threshold]
                y = F.conv1d(
                    x1v,
                    h.to(dtype=x1v.dtype),
                    stride=1,
                    groups=x1v.shape[1],
                    padding=h.shape[-1] - 1,
                )[..., :L]
        # if self.layer_idx == 2:
        #     breakpoint()
        y = y.to(dtype=x1v.dtype)
        y = (y + x1v * D.unsqueeze(-1)) * x2

        if self.print_activations:
            activations_logger.info(f"hyena filter: {h}, {h.min()}, {h.max()}")
            activations_logger.info(f"post hyena iir gate: {y}, {y.min()}, {y.max()}")
            activations_logger.info(f"q: {x2}, {x2.min()}, {x2.max()}")
            activations_logger.info(f"k: {x1}, {x1.min()}, {x1.max()}")
            activations_logger.info(f"v: {v}, {v.min()}, {v.max()}")
            # if self.ground_truth_activations_path:
            #     q_savanna = torch.load(f"{self.ground_truth_activations_path}/q_{self.layer_idx}.pt")
            #     k_savanna = torch.load(f"{self.ground_truth_activations_path}/k_{self.layer_idx}.pt")
            #     v_savanna = torch.load(f"{self.ground_truth_activations_path}/v_{self.layer_idx}.pt")

            #     q_diff = (x2 - q_savanna).abs()
            #     k_diff = (x1 - k_savanna).abs()
            #     v_diff = (v - v_savanna).abs()

            #     activations_logger.info(f"q_diff: {q_diff.max()}, {q_diff.mean()}")
            #     activations_logger.info(f"k_diff: {k_diff.max()}, {k_diff.mean()}")
            #     activations_logger.info(f"v_diff: {v_diff.max()}, {v_diff.mean()}")

            #     h_savanna = torch.load(f"/home/zymrael/checkpoints/evo2/activations/savanna/hyena_filter_{self.layer_idx}.pt")

            #     h_diff = (h[..., :h_savanna.shape[-1]].squeeze() - h_savanna.squeeze()).abs()
            #     activations_logger.info(f"h_diff: {h_diff.max()}, {h_diff.mean()}")

        if inference_params is not None:
            if prefill_style == "fft":
                self.prefill_via_modal_fft(
                    inference_params=inference_params,
                    x1v=x1v,
                    X_s=X_s,
                    L=L,
                    t=t,
                    poles=poles,
                    dims=dims,
                    layer_idx=layer_idx,
                    use_flashfft=use_flashfft,
                    fftconv_fn=fftconv_fn,
                )

            elif prefill_style == "recurrence":
                # recurrent prefill is done before
                pass
            else:
                raise NotImplementedError
            if self.low_mem_mode:
                # TODO: smarter gc
                del z_pre, x2, x1, v, x1v, h, poles, residues
                torch.cuda.empty_cache()

        return y.permute(0, 2, 1)

    def step_fir(self, u, fir_state, weight, bias=None, gated_bias=False, flip_filter=False):
        """Steps forward FIR filters in the architecture.

        FIR filters generally include truncated convolutions in Hyena with an explicit or hybrid time-domain parametrization:
            * Short FIR filters in Hyena featurizers
            * Short and medium FIR filters in Hyena operators

        Note:
            `fir_state` contains the last FIR filter length - 1 elements of `u`: `u_(L-2), u_{L-1), ...`
            We assume dimensions of `short_filter_weight` to be `[d, 1, short_filter_len]`.
        """
        weight = weight.squeeze()

        cache_size = fir_state.shape[-1]
        filter_length = weight.shape[-1]
        if flip_filter:
            weight = weight.flip(-1)
            weight = weight[..., -cache_size - 1 :].unsqueeze(0)
        else:
            weight = weight[..., : cache_size + 1].unsqueeze(0)

        input_dtype = u.dtype
        weight = weight.to(torch.float32)
        u = u.to(torch.float32)
        fir_state = fir_state.to(torch.float32)
        bias = bias.to(torch.float32) if bias is not None else None

        h0, h = weight[..., -1], weight[..., :-1]
        y = h0 * u + torch.sum(fir_state * h, dim=-1)

        if bias is not None:
            if gated_bias:
                y = y + bias * u
            else:
                y = y + bias

        # Update the state
        if cache_size < filter_length - 1:
            fir_state = torch.cat([fir_state, u[..., None]], dim=-1)
        else:
            fir_state = torch.roll(fir_state, -1, dims=2)
            fir_state[..., -1] = u

        return y.to(input_dtype), fir_state

    def step_iir(self, x2, x1, v, D, residues, poles, iir_state, iir_groups=1):
        # TODO: kernelize
        x1v = x1 * v
        poles = torch.exp(poles)  # poles arg contains log_poles
        poles = poles[..., 0][None]  # squeeze dummy seqlen dim and add dummy batch dim
        residues = residues[None]  # add dummy batch dim
        iir_state = poles * iir_state + x1v[..., None]

        res_state = torch.sum(residues * iir_state, dim=-1)

        if iir_groups > 1:
            raise NotImplementedError
        # if self.layer_idx == 2:
        #     breakpoint()
        y = x2 * (res_state + D * x1v)

        return y, iir_state

    def prefill_via_fir_caching(self, u, inference_params, L, *args, **kwargs):
        """Turns the IIR filter into a FIR and uses a cache for decoding."""
        raise NotImplementedError(":)")

    def prefill_via_direct_recurrence(self, inference_params, x1v, L, residues, poles, *args, **kwargs) -> torch.Tensor:
        """
        Compute the IIR state via explicit recurrence (modal form)

        This is the most memory efficient prefilling method for Hyena filters.

        Note:
            dtypes: [state: float32, poles: float32, x1v: bfloat16, output: bfloat16]
        """
        state_dim = poles.shape[1]
        x1v_ = x1v[..., None, None]  # b, d, l, sdim, reim
        x1v_ = x1v_.repeat(1, 1, 1, state_dim, 2)  # b, d, l, sdim, reim
        x1v_[..., 1] = 0

        state = 0 * x1v_[:, :, 0]
        output = 0 * x1v_[:, :, :, 0, 0]  # b, d, l

        # suppress dummy seqlen dimension
        poles = poles[:, :, 0][None]
        residues = residues[:, :, 0][None].repeat(x1v_.shape[0], 1, 1, 1)  # b, d, sdim, reim

        # state: b, d, sdim, reim
        # poles: 1, d, sdim, reim
        # x1v_: b, d, l, sdim, reim
        for i in range(L):
            state[..., 0] = poles[..., 0] * state[..., 0] - poles[..., 1] * state[..., 1] + x1v_[:, :, i, :, 0]
            state[..., 1] = poles[..., 0] * state[..., 1] + poles[..., 1] * state[..., 0] + x1v_[:, :, i, :, 1]
            output[:, :, i] = torch.sum(residues * state, dim=-2)[..., 0]  # .real

        inference_params.state_dict[self.layer_idx] = state.to(dtype=torch.float32)

        return output

    def prefill_via_hybrid_recurrence(self, inference_params, u, log_poles, x1v_f_a, L, *args, **kwargs):
        """
        Compute the IIR state via hybrid recurrence-convolution over blocks
        """
        raise NotImplementedError(":)")

    def prefill_via_scan(self, u, inference_params=None, *args, **kwargs):
        raise NotImplementedError

    def prefill_via_canonical_fft(self, u, inference_params=None, *args, **kwargs):
        """
        Compute the IIR state via a single FFT

        This is the most memory efficient "parallelized" prefilling method for Hyena.

        From: https://arxiv.org/abs/2310.18780
        """
        raise NotImplementedError(":)")

    def prefill_via_modal_fft(
        self,
        inference_params,
        x1v,
        L,
        poles,
        t,
        dims,
        layer_idx,
        X_s=None,
        use_flashfft=False,
        fftconv_fn=None,
        state_dtype=torch.float32,
        *args,
        **kwargs,
    ):
        """
        Compute the IIR state via a single FFT
        """
        # When the model has a long convolution derived from a recurrence in modal form and prefill_style is "fft",
        # we split the filter into poles and residues and reuse FFT computation on the input.
        hidden_size, _, _, state_size, hyena_filter_groups = dims

        assert X_s is not None
        bs = x1v.shape[0]
        fft_size = 2 * L
        # poles = torch.view_as_complex(poles.to(torch.float32))
        state_s = (poles.to(torch.float32) * t).exp()

        # state_s = poles**t
        state_S = torch.fft.fft(state_s, n=fft_size).repeat(bs, 1, 1, 1)  # B, D, state_dim, 2 * L
        if hyena_filter_groups > 1:
            state_S = state_S.repeat_interleave(hidden_size // hyena_filter_groups, 1)
        state = torch.fft.ifft(X_s[..., None, :] * state_S, n=fft_size)
        inference_params.state_dict[layer_idx] = state[..., L - 1].to(dtype=state_dtype)

    def _compute_state(self, log_poles, u, t, L, *args, **kwargs):
        """
        Compute the IIR state given an input `u` and log_poles of the modal system.
        """
        bs = u.shape[0]
        fft_size = 2 * L
        U = torch.fft.rfft(u.to(torch.float32), n=fft_size)
        fft_size = 2 * L
        x = (log_poles * t).exp()
        # [batch, hidden_size, state_dim, 2 * seqlen]
        X = torch.fft.fft(x, n=fft_size).repeat(bs, 1, 1, 1)
        state = torch.fft.ifft(U[..., None, :] * X, n=fft_size)[..., :L]
        return state


# I don't think this class is used anywhere? Comment out
class HyenaFilter:
    """Handles Hyena filter computations including FFT and direct convolution."""

    def __init__(self, use_flash_fft=False):
        self.use_flash_fft = use_flash_fft

    def fft_conv(self, u, k, D, **kwargs):
        """FFT-based convolution implementation."""
        seqlen = u.shape[-1]
        fft_size = 2 * seqlen

        k_f = self._prepare_filter(k, u, fft_size)
        y = self._compute_fft_conv(u, k_f, fft_size, seqlen, **kwargs)

        return y + u * D.unsqueeze(-1)

    def _prepare_filter(self, k, u, fft_size):
        """Prepare filter for FFT convolution."""
        k_f = torch.fft.rfft(k, n=fft_size) / fft_size
        return adjust_filter_shape_for_broadcast(u, k_f)
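
The sketch below is not part of the committed file; it is a minimal illustration, under made-up shapes, of the identity `fftconv_func` relies on: zero-padding to `2 * L` and truncating the inverse FFT reproduces a causal depthwise convolution, so the frequency-domain product matches a direct `conv1d` reference.

import torch
import torch.nn.functional as F

# Illustrative sketch (not part of engine.py): FFT conv vs. direct causal conv.
B, D, L = 2, 4, 128
u = torch.randn(B, D, L, dtype=torch.float64)
k = torch.randn(D, L, dtype=torch.float64)  # one length-L filter per channel
fft_size = 2 * L

# Frequency-domain path, mirroring fftconv_func's scaling and truncation.
u_f = torch.fft.rfft(u, n=fft_size)
k_f = torch.fft.rfft(k, n=fft_size) / fft_size
y_fft = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :L]

# Time-domain reference: depthwise causal convolution (flip turns
# conv1d's cross-correlation into a true convolution).
y_ref = F.conv1d(u, k.flip(-1).unsqueeze(1), padding=L - 1, groups=D)[..., :L]

print(torch.allclose(y_fft, y_ref))  # expected: True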
generation.py
ADDED
@@ -0,0 +1,373 @@
# Copied verbatim from vortex
# Copyright (c) 2024, Michael Poli.

from dataclasses import dataclass

import torch
import sys
import numpy as np

from .sample import sample
from .tokenizer import CharLevelTokenizer
from .utils import print_rank_0


class Generator:
    def __init__(self, model, tokenizer, top_k=50, top_p=0.7, temperature=1):
        self.model = model
        self.tokenizer = tokenizer
        self.top_k = top_k
        self.top_p = top_p
        self.temperature = temperature
        self.untils = ["\n\n"]

    def generate(
        self,
        device: str,
        input_string: str = None,
        input_ids: torch.Tensor = None,
        num_tokens: int = 32,
        cached_generation: bool = True,
        force_prompt_threshold: int = None,
        max_seqlen: int = None,
        print_generation: bool = True,
        verbose: bool = False,
        skip_special_tokens: bool = False,
        stop_at_eos: bool = True,
        inference_params_dict: dict = None,
        token_callback=lambda i: None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Generates using the model with optional cached sampling replay.

        This method enables passing in and returning the `inference_params_dict` for
        replaying cached sampling from a given state, for example for beam search.

        Args:
            device: The device to run the model on.
            input_string: The input prompt to generate from.
            input_ids: The input prompt token ids to generate from.
            num_tokens: The number of tokens to generate.
            cached_generation: Whether to use cached generation. Defaults to False.
            force_prompt_threshold: Number of tokens to prefill in parallel before
                switching to prompt forcing. Used to reduce peak memory usage and
                support longer prompts. Defaults to None.
            max_seqlen: Maximum sequence length to generate. Determines the max size
                of the cache if larger. Otherwise automatically determined using
                prompt length + max_tokens. Defaults to None.
            print_generation: Whether to print generated tokens. Defaults to False.
            verbose: Whether to print verbose output. Defaults to False.
            skip_special_tokens: Whether to skip special tokens. Defaults to True.
            stop_at_eos: Whether to stop generation at EOS token. Defaults to True.
            inference_params_dict: Dictionary of inference parameters to use for
                replaying cached sampling. Defaults to None.
            token_callback: Optional callback function called after each token is
                generated. Defaults to None.

        Returns:
            dict: The inference parameters dictionary used for generation, which can
                be used to replay the exact same sampling sequence.
        """
        if isinstance(self.tokenizer.eos, int):
            eos_token_ids = torch.LongTensor([self.tokenizer.eos]).to(device)
        else:
            eos_token_ids = self.tokenizer.tokenize(self.tokenizer.eos).to(device)

        if input_ids is None:
            input = self.tokenizer.tokenize(input_string)
            if isinstance(input, list):
                input = torch.LongTensor(input).unsqueeze(0).to(device)
            else:
                input = input.unsqueeze(0).to(device)
        else:
            input = input_ids
        x = input

        if max_seqlen is not None:
            x = x[:, -max_seqlen:]

        num_tokens = int(num_tokens)
        batch_size = x.shape[0]

        prompt_length = x.shape[1]
        prompt_forcing = inference_params_dict is None and force_prompt_threshold is not None and prompt_length > force_prompt_threshold
        if prompt_forcing:
            forced_prompt_length = prompt_length - force_prompt_threshold
            x_force = x[:, force_prompt_threshold:]
            x = x[:, :force_prompt_threshold]
        else:
            forced_prompt_length = 0
        tot_length = prompt_length + num_tokens
        if max_seqlen is not None:
            if max_seqlen > tot_length:
                tot_length = max_seqlen

        generation = torch.empty(
            x.shape[0],
            num_tokens,
            dtype=torch.long,
            device=x.device,
        )

        scores = torch.empty(
            x.shape[0],
            num_tokens,
            self.tokenizer.vocab_size,
            dtype=torch.float,
            device=x.device,
        )

        if inference_params_dict is not None:
            cached_generation = True
            prefilled = True
            # Ensure that the cached data is loaded on the correct device.
            if any(data.device != x.device for data in inference_params_dict["hcl"].fir_state_dict.values()):
                for key, data in inference_params_dict["mha"].key_value_memory_dict.items():
                    inference_params_dict["mha"].key_value_memory_dict[key] = data.to(x.device)
                for key, data in inference_params_dict["hcl"].fir_state_dict.items():
                    inference_params_dict["hcl"].fir_state_dict[key] = data.to(x.device)
                for key, data in inference_params_dict["hcl"].state_dict.items():
                    inference_params_dict["hcl"].state_dict[key] = data.to(x.device)
                for key, data in inference_params_dict["hcm"].fir_inner_state_dict.items():
                    inference_params_dict["hcm"].fir_inner_state_dict[key] = data.to(x.device)
                for key, data in inference_params_dict["hcm"].fir_state_dict.items():
                    inference_params_dict["hcm"].fir_state_dict[key] = data.to(x.device)
                for key, data in inference_params_dict["hcm"].state_dict.items():
                    inference_params_dict["hcm"].state_dict[key] = data.to(x.device)
                for key, data in inference_params_dict["hcs"].fir_state_dict.items():
                    inference_params_dict["hcs"].fir_state_dict[key] = data.to(x.device)
                for key, data in inference_params_dict["hcs"].fir_inner_state_dict.items():
                    inference_params_dict["hcs"].fir_inner_state_dict[key] = data.to(x.device)
                for key, data in inference_params_dict["hcs"].state_dict.items():
                    inference_params_dict["hcs"].state_dict[key] = data.to(x.device)
            inference_params_dict["mha"].max_batch_size = batch_size
        elif cached_generation:
            inference_params_dict = self.model.initialize_inference_params(max_seqlen=tot_length)
            inference_params_dict["mha"].max_batch_size = batch_size
            prefilled = False
        else:
            inference_params_dict = None
            prefilled = False

        if verbose:
            mem_after_tok = torch.cuda.memory_allocated(device=x.device) / 1e9
            print_rank_0(f"Memory after tokenization: {mem_after_tok} GB")
            print_rank_0("Starting generation...")
            if input_string is not None:
                print_rank_0("Prompt: " + input_string)
            else:
                print_rank_0(f"Prompt ids: {input_ids} {input_ids.shape}")

        i = 0
        for i in range(forced_prompt_length + num_tokens):
            post_prefill = prefilled or (cached_generation and i > 0)

            # prefill then process only the last token
            if post_prefill:
                x = x[:, -1:]
                seqlen_offset = inference_params_dict["mha"].seqlen_offset

                if seqlen_offset == 0:
                    if prompt_forcing:
                        seqlen_offset = force_prompt_threshold
                    else:
                        seqlen_offset = input.shape[-1]
                    inference_params_dict["mha"].seqlen_offset = seqlen_offset
                    inference_params_dict["hcl"].seqlen_offset = seqlen_offset
                    inference_params_dict["hcm"].seqlen_offset = seqlen_offset
                    inference_params_dict["hcs"].seqlen_offset = seqlen_offset
                else:
                    inference_params_dict["mha"].seqlen_offset += 1
                    inference_params_dict["hcl"].seqlen_offset += 1
                    inference_params_dict["hcm"].seqlen_offset += 1
                    inference_params_dict["hcs"].seqlen_offset += 1

            # do forward pass with no gradient
            with torch.inference_mode():
                logits, inference_params_dict = self.model(
                    x,
                    inference_params_dict=inference_params_dict,
                )

            token_callback(i)

            last_logits = logits[:, -1]

            if prompt_forcing and i < forced_prompt_length:
                new_idx = x_force[:, i]
            else:
                new_idx = sample(
                    last_logits,
                    top_k=self.top_k,
                    top_p=self.top_p,
                    temperature=self.temperature,
                )

            if stop_at_eos and (generation[0, -1:] == eos_token_ids).all():
                print("Stopping generation at EOS")

            if print_generation and verbose and batch_size == 1:
                print(
                    f"{self.tokenizer.detokenize([new_idx.item()])}",
                    end=" ",
                    flush=True,
                )

            if prompt_forcing:
                if i >= forced_prompt_length:
                    scores[:, i - forced_prompt_length] = last_logits
                    generation[:, i - forced_prompt_length] = new_idx
            else:
                scores[:, i] = last_logits
                generation[:, i] = new_idx

            if post_prefill:
                x = new_idx[:, None]
            else:
                x = torch.cat([x, new_idx[:, None]], dim=-1)

        if verbose:
            y = self.tokenizer.detokenize_batch(generation[:, : i + 1])

            for until in self.untils:
                if until in y:
                    y = y.split(until)[0]
                    break

            print(f"\nInput: {input_string}, Output: {y}")

            mem_end = torch.cuda.memory_allocated(device=x.device) / 1e9
            print(f"Memory after generation: {mem_end} GB")

        return generation[:, : i + 1], scores[:, : i + 1], inference_params_dict


def logits_to_logprobs(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Convert logits to log probabilities."""
    probs = torch.log_softmax(logits, dim=-1)
    return torch.gather(probs, -1, tokens.unsqueeze(-1)).squeeze(-1)


def prepare_batch(
    seqs: list[str], tokenizer: CharLevelTokenizer, prepend_bos: bool = False, device: str = "cuda:0"
) -> tuple[torch.Tensor, list[int]]:
    """Prepare a batch of sequences for the model."""
    if prepend_bos:
        seqs = [tokenizer.bos + seq for seq in seqs]

    tokens = [tokenizer.tokenize(seq) for seq in seqs]
    if isinstance(tokens[0], list):
        tokens = [torch.tensor(t, dtype=torch.long) for t in tokens]

    max_len = max(len(t) for t in tokens)
    batch = torch.zeros((len(tokens), max_len), dtype=torch.long)

    for i, t in enumerate(tokens):
        batch[i, : len(t)] = t

    return batch.to(device), [len(t) for t in tokens]


@dataclass(kw_only=True)
class GenerationOutput:
    sequences: list[str]
    logits: list[torch.Tensor]
    logprobs_mean: list[float]


def generate(
    *,
    prompt_seqs: list[str],
    model,
    tokenizer: CharLevelTokenizer,
    n_tokens: int = 100,
    temperature: float = 0.0,
    top_k: int = 1,
    top_p: float = 1.0,
    batched: bool = True,
    prepend_bos: bool = False,
    force_prompt_threshold: int = 1000,
    cached_generation: bool = True,
    verbose: int = 1,
    device: str = "cuda:0",
    **kwargs,
) -> GenerationOutput:
    """
    Performs generation from a list of prompts.
    If all prompts are the same length, this can do batched generation.
    Also supports cached generation for efficient sampling.
    """
    model.eval()

    g = Generator(
        model,
        tokenizer,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
    )

    uniform_lengths = all(len(s) == len(prompt_seqs[0]) for s in prompt_seqs)

    if batched and uniform_lengths:
        input_ids_list = [
            prepare_batch(
                prompt_seqs,
                tokenizer,
                prepend_bos=prepend_bos,
                device=device,
            )[0]
        ]
    else:
        sys.stderr.write("WARNING: Batched generation is turned off.\n")
        input_ids_list = [
            prepare_batch(
                [prompt_seq],
                tokenizer,
                prepend_bos=prepend_bos,
                device=device,
            )[0]
            for prompt_seq in prompt_seqs
        ]

    generated_seqs, generated_scores, logitss = [], [], []
    for input_ids in input_ids_list:
        batch_size = input_ids.shape[0]

        output_ids, logits, _ = g.generate(
            input_ids=input_ids,
            num_tokens=n_tokens,
            device=device,
            print_generation=(verbose > 1),
            verbose=(verbose > 1),
            stop_at_eos=False,
            force_prompt_threshold=force_prompt_threshold,
            cached_generation=cached_generation,
            **kwargs,
        )

        if verbose > 1:
            print("input_ids.shape", input_ids.shape)
            print("output_ids.shape", output_ids.shape)
            print("logits.shape", logits.shape)

        generated_seqs_batch = list(tokenizer.detokenize_batch(output_ids))
        assert len(generated_seqs_batch) == batch_size
        generated_seqs += generated_seqs_batch
        logitss.append(logits)

        logprobs = logits_to_logprobs(logits, output_ids)
        logprobs = logprobs.float().cpu().numpy()

        generated_scores += [np.mean(logprobs[idx]) for idx in range(batch_size)]

    assert len(generated_seqs) == len(generated_scores) == len(prompt_seqs)
    if verbose:
        for seq, score, prompt in zip(generated_seqs, generated_scores, prompt_seqs):
            print(f'Prompt: "{prompt}",\tOutput: "{seq}",\tScore: {score}')

    return GenerationOutput(
        sequences=generated_seqs,
        logits=logitss,
        logprobs_mean=generated_scores,
    )
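
The snippet below is not part of the committed file; it is a small sketch, with made-up tensor shapes, of how the per-sequence score returned by the module-level `generate` is derived: gather the log-softmax probability of each generated token (`logits_to_logprobs`) and average over the generated positions.

import torch

# Illustrative only: per-sequence mean log-probability scoring.
logits = torch.randn(2, 5, 512)           # (batch, generated tokens, vocab)
tokens = torch.randint(0, 512, (2, 5))    # token ids actually generated

logprobs = torch.log_softmax(logits, dim=-1)
token_logprobs = torch.gather(logprobs, -1, tokens.unsqueeze(-1)).squeeze(-1)
score = token_logprobs.mean(dim=-1)       # one mean log-prob per sequence
print(score.shape)                        # torch.Size([2])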
layers.py
ADDED
@@ -0,0 +1,272 @@
# Copied verbatim from vortex (minus the commented out code)
# Copyright (c) 2024, Michael Poli.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Callable
from .utils import grab_first_if_tuple

from transformer_engine.pytorch import Linear
from transformer_engine.common.recipe import Format, DelayedScaling
import transformer_engine.pytorch as te

# Not bothering with ops right now (which is an interface with custom Triton
# kernels)
# try:
#     from hyena_ops import hyena_se_fwd, hyena_mr_fwd, hyena_li_fwd
# except ImportError:
#     hyena_se_fwd, hyena_mr_fwd, hyena_li_fwd = None, None, None

hyena_se_fwd, hyena_mr_fwd, hyena_li_fwd = None, None, None


def set_format_recipe():
    fp8_format = Format.HYBRID  # E4M3 during forward pass, E5M2 during backward pass
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
    return fp8_format, fp8_recipe


class TELinear(Linear):
    """
    Wrapper for Transformer-Engine's `Linear` layer.

    Note that if Megatron's parallel_state has not been initialized
    yet, the tp_group passed to TE will be None and must be set later
    via set_tensor_parallel_group().
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        init_method: Callable,
        bias: bool = True,
        skip_bias_add: bool = False,
        use_fp8: bool = False,
        **kwargs,
    ):
        # Parameters are initialized at higher precision even if fp8
        # is used
        params_dtype = torch.bfloat16

        # TE returns a zero length Tensor when bias=False and
        # return_bias=True, but we prefer None. So in that case we
        # tell TE to not return the bias, and return None
        # ourselves. This way our forward always returns two values
        # and we don't have to deal with the zero length Tensor.
        self.te_return_bias = skip_bias_add and bias

        self.use_fp8_input_projections = use_fp8
        if use_fp8:
            self.fp8_format, self.fp8_recipe = set_format_recipe()

        super().__init__(
            in_features=input_size,
            out_features=output_size,
            sequence_parallel=False,
            fuse_wgrad_accumulation=False,
            tp_group=None,
            tp_size=1,
            init_method=init_method,
            params_dtype=params_dtype,
            parallel_mode=None,
            bias=bias,
            return_bias=self.te_return_bias,
            **kwargs,
        )

    def forward(self, x):
        if self.use_fp8_input_projections:
            with te.fp8_autocast(enabled=True, fp8_recipe=self.fp8_recipe):
                out = super().forward(x)
        else:
            out = super().forward(x)

        # TE only returns a tuple when return_bias is True, otherwise
        # it returns a single Tensor, we always want to return two
        # values regardless of the arguments.
        if self.te_return_bias:
            return out
        return out, None


class FlexLinear:
    """
    Megatron and Transformer Engine linear layer compatible with fp8, bf16, fp16 and fp32
    """

    def __new__(
        self,
        input_size,
        output_size,
        config,
        parallel_mode: str,
        bias: bool = False,
        skip_bias_add: bool = True,
        use_fp8: bool = False,
        input_is_parallel=False,  # for row parallel
        gather_output: bool = True,  # for column parallel
        parallel_output: bool = False,  # for row parallel
        **kwargs,
    ):
        # use_fp8 = config.use_fp8_linears
        self.config = config
        instance = None

        if use_fp8:
            instance = TELinear(
                input_size=input_size,
                output_size=output_size,
                config=self.config,
                parallel_mode=parallel_mode,
                bias=bias,
                skip_bias_add=skip_bias_add,
                **kwargs,
            )

        return instance


class RMSNorm(torch.nn.Module):
    def __init__(self, config):
        super(RMSNorm, self).__init__()
        self.eps, self.hidden_size = config.eps, config.hidden_size
        self.scale = torch.nn.Parameter(torch.ones(self.hidden_size, dtype=config.params_dtype))
        self.register_parameter("scale", self.scale)
        self.use_flash_rmsnorm = config.get("use_flash_rmsnorm", False)

        if self.use_flash_rmsnorm:
            from flash_attn.ops.rms_norm import rms_norm as rmsnorm_func

            self.rmsnorm_func = rmsnorm_func

    def forward(self, x):
        if self.use_flash_rmsnorm:
            return self.rmsnorm_func(x, self.scale, self.eps)
        else:
            y = x / (x.norm(2, dim=-1, keepdim=True) * self.hidden_size ** (-1.0 / 2) + self.eps)
            return self.scale * y


class ParallelGatedMLP(nn.Module):
    def __init__(
        self,
        config,
        layer_idx,
    ):
        super().__init__()

        self.layer_idx = layer_idx
        multiple_of = config.get("inner_size_multiple_of", 64)
        self.act_type = config.get("mlp_activation", "gelu")
        if self.act_type == "gelu":
            self.act = F.gelu
        elif self.act_type == "silu":
            self.act = F.silu
        else:
            raise NotImplementedError

        if self.layer_idx > 0 and config.get("evo2_style_activations", False):
            self.act = nn.Identity()

        self.multiple_of = multiple_of * config.model_parallel_size

        inner_size = int(2 * config.hidden_size * 4 / 3)
        inner_size = self.multiple_of * ((inner_size + self.multiple_of - 1) // self.multiple_of)
        inner_size = config.get("inner_mlp_size", inner_size)

        self.l1 = nn.Linear(
            in_features=config.hidden_size,
            out_features=inner_size,
            bias=False,
        )
        self.l2 = nn.Linear(
            in_features=config.hidden_size,
            out_features=inner_size,
            bias=False,
        )
        self.l3 = nn.Linear(
            in_features=inner_size,
            out_features=config.hidden_size,
            bias=False,
        )

    def forward(self, z):
        z1, z2 = self.l1(z), self.l2(z)
        z1, z2 = grab_first_if_tuple(z1), grab_first_if_tuple(z2)
        y = self.l3(self.act(z1) * z2)
        return grab_first_if_tuple(y)


class Embedding(nn.Module):
    _train_dtype = "bf16"

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)

    def embed(self, input_ids, position_ids=None, tokentype_ids=None):
        embeddings = self.word_embeddings(input_ids)
        return embeddings

    def unembed(self, u):
        weight = self.word_embeddings.weight
        return torch.matmul(u, weight)


class VocabParallelEmbedding(nn.Embedding):
    "Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/embedding.py"

    def __init__(self, config):
        vocab_size, process_group, padding_idx = (
            config.vocab_size,
            config.get("process_group", None),
            config.get("padding_idx", None),
        )
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if vocab_size % world_size != 0:
                raise ValueError(f"vocab_size ({vocab_size}) must be divisible by " f"world_size ({world_size})")
            if world_size > 1 and padding_idx is not None:
                raise RuntimeError("ParallelEmbedding does not support padding_idx")
        else:
            world_size = 1
        super().__init__(
            vocab_size // world_size,
            embedding_dim=config.hidden_size,
            padding_idx=padding_idx,
        )

    def forward(self, input: Tensor) -> Tensor:
        if self.process_group is None:
            return super().forward(input)
        else:
            rank = torch.distributed.get_rank(self.process_group)
            vocab_size = self.num_embeddings
            vocab_start_index, vocab_end_index = (
                rank * vocab_size,
                (rank + 1) * vocab_size,
            )
            # Create a mask of valid vocab ids (1 means it needs to be masked).
            input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
            input = input - vocab_start_index
            input[input_ids_mask] = 0
            embeddings = self.forward(input)
            embeddings[input_ids_mask] = 0.0
            # Reduce to the global process group
            torch.distributed.all_reduce(embeddings, group=self.process_group)
            return embeddings

    def unembed(self, u: Tensor) -> Tensor:
        if self.process_group is None:
            return u @ self.weight.T
        else:
            raise NotImplementedError


class VocabParallelUnembedding(VocabParallelEmbedding):
    def forward(self, input: Tensor) -> Tensor:
        return self.unembed(input)
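
The check below is not part of the committed file; it is a quick, self-contained illustration that the fallback `RMSNorm.forward` above divides by the root-mean-square of the last dimension (`||x||_2 / sqrt(d)`), i.e. it matches the usual RMSNorm formulation before the learned scale is applied. Dimensions are illustrative.

import torch

d, eps = 8, 1e-6
x = torch.randn(3, d)

# The layer's fallback path: L2 norm rescaled by 1/sqrt(d), plus eps.
y_layer = x / (x.norm(2, dim=-1, keepdim=True) * d ** (-0.5) + eps)

# Textbook RMSNorm: divide by sqrt(mean(x^2)) over the last dimension.
y_rms = x / (torch.sqrt((x ** 2).mean(dim=-1, keepdim=True)) + eps)

print(torch.allclose(y_layer, y_rms, atol=1e-6))  # expected: True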
model.py
ADDED
@@ -0,0 +1,937 @@
# Copied verbatim from vortex

# Copyright (c) 2024, Michael Poli.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from .cache import (
    InferenceParams,
    HyenaCascadeFIRInferenceParams,
    HyenaCascadeIIRInferenceParams,
)
from .engine import HyenaInferenceEngine
from .layers import (
    ParallelGatedMLP,
    RMSNorm,
    VocabParallelEmbedding,
    VocabParallelUnembedding,
    TELinear,
)
from .utils import (
    Lambda,
    column_split,
    interleave,
    print_rank_0,
    move_to_device,
    fixup_fp8_extra_states,
    fixup_te_workspace,
)
from .rich_logging import activations_logger, enable_activations_logging

import logging
from tqdm import tqdm

from attention import MHA

try:
    from vortex.model.positional_embeddings import swap_mha_rope
except ImportError:
    "could not import swap_mha_rope from src.positional_embeddings"


class AttentionBlock(nn.Module):
    def __init__(self, config, layer_idx) -> None:
        super().__init__()
        self.config = config
        self.pre_norm, self.post_norm = RMSNorm(config), RMSNorm(config)
        self.layer_idx = layer_idx
        self.print_activations = config.get("print_activations", False)
        self.proj_groups = config.get("proj_groups", 1)
        dtype = config.get("attn_block_dtype", torch.bfloat16)
        mlp_dtype = config.get("mlp_dtype", torch.bfloat16)
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.hidden_size_per_attention_head = config.hidden_size // config.num_attention_heads

        self.counter = 0
        self.inner_mha_cls = MHA(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_heads_kv=config.num_attention_heads // self.proj_groups,
            rotary_emb_dim=config.hidden_size // config.num_attention_heads,
            qkv_proj_bias=config.get("qkv_proj_bias", True),
            rotary_emb_base=config.get("rotary_emb_base", 1000000),
            causal=True,
            layer_idx=layer_idx,
            out_proj_bias=config.get("mha_out_proj_bias", True),
            use_flash_attn=self.config.use_flash_attn,
        ).to(dtype=dtype)

        # check if using interpolated rotary pos emb from config, and swap the rope emb
        if config.get("use_interpolated_rotary_pos_emb", False):
            swap_mha_rope(
                mha=self.inner_mha_cls,
                kwargs_new_rope={"scaling_factor": config.get("rotary_emb_scaling_factor", 1.0)},
            )

        if self.config.get("smeared_gqa", False):
            self.inner_mha_cls.num_heads_kv = self.inner_mha_cls.num_heads
        self.inner_mha_cls.rotary_emb.register_buffer("inv_freq", self.inner_mha_cls.rotary_emb.inv_freq)

        self.mlp = ParallelGatedMLP(config, layer_idx).to(dtype=mlp_dtype)

    def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
        if (
            type(padding_mask) == torch.Tensor
        ):  # workaround for masking bug in FA. This works because Wqkv does not have bias
            # and attention scores will be also automatically zeroed.
            u = u * padding_mask[..., None]

        if self.print_activations:
            activations_logger.info(f"pre mha: {u}")

        u = (
            self.inner_mha_cls(
                self.pre_norm(u),
                inference_params=inference_params,
            )
            + u
        )
        if self.print_activations:
            activations_logger.info(f"post mha: {u}")

        if type(padding_mask) == torch.Tensor:  # guard against bias
            u = u * padding_mask[..., None]

        if self.print_activations:
            activations_logger.info(f"pre mlp: {u} {u.min()} {u.max()} {self.mlp.__class__}")
            activations_logger.info(
                f"post mlp norm: {self.post_norm(u)} {self.post_norm(u).min()} {self.post_norm(u).max()}"
            )
            activations_logger.info(
                f"post mlp: {self.mlp(self.post_norm(u))} {self.mlp(self.post_norm(u)).min()} {self.mlp(self.post_norm(u)).max()}"
            )

        u = self.mlp(self.post_norm(u)) + u
        return u, None


class HyenaCascade(nn.Module):
    def __init__(self, config, layer_idx, hyena_filter_groups=None, fir_inner_filter_length=None) -> None:
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hyena_filter_groups = hyena_filter_groups
        self.print_activations = config.get("print_activations", False)
        self.ground_truth_activations_path = config.get("ground_truth_activations_path", None)

        self.use_flashfft = config.get("use_flashfft", False)
        self.state_size = config.state_size
        self.hidden_size = config.hidden_size
        self.num_filters = config.num_filters
        self.inference_mode = config.get("inference_mode", True)
        self.counter = 0
        self.column_split_hyena = config.get("column_split_hyena", True)
        self.hyena_flip_x1x2 = config.get("hyena_flip_x1x2", False)

        assert self.hidden_size % self.num_filters == 0 and self.num_filters <= self.hidden_size

        # attention heads are not used except to split post short_filter
        # projections in the same way as the checkpoint
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads

        self.fir_inner_filter_length = fir_inner_filter_length
        self.short_filter_length = config.short_filter_length
        self.short_filter_weight = nn.Parameter(torch.randn(3 * config.hidden_size, 1, config.short_filter_length))
        self.short_filter_bias = nn.Parameter(torch.randn(3 * config.hidden_size)) if config.short_filter_bias else None

        self.engine = HyenaInferenceEngine(
            layer_idx=layer_idx,
            ground_truth_activations_path=self.ground_truth_activations_path,
            print_activations=self.print_activations,
            hyena_flip_x1x2=config.get("hyena_flip_x1x2", False),
        )
        self.use_flash_depthwise = config.get("use_flash_depthwise", False)
        self.data_dtype = None

        if self.use_flash_depthwise:
            try:
                from flashfftconv import FlashDepthwiseConv1d

                self.fir_fn = FlashDepthwiseConv1d(
                    channels=3 * self.hidden_size,
                    kernel_size=self.short_filter_length,
                    padding=self.short_filter_length - 1,
                    weights=self.short_filter_weight,
                    bias=self.short_filter_bias,
                    device=None,
                    dtype=self.config.get("depthwise_dtype", torch.bfloat16),
                )
            except ImportError:
                "flashfftconv not installed"
        else:
            self.fir_fn = F.conv1d

        self.fir_inner_fn = F.conv1d

        self.fftconv_fn = None
        self.long_fir_threshold = config.get("long_fir_threshold", None)
        if self.long_fir_threshold is not None:
            assert self.use_flashfft is False, "long_fir_threshold not compatible with fused flashfft"

        self.num_systems = self.hyena_filter_groups
        self.channels_per_group = self.hidden_size // self.hyena_filter_groups

        if self.fir_inner_filter_length:
            self.h = nn.Parameter(torch.randn(self.hyena_filter_groups, 1, fir_inner_filter_length))

            if fir_inner_filter_length >= 128:
                self.D = nn.Parameter(torch.zeros(self.hidden_size))

            if fir_inner_filter_length < 128:
                self.D = None

        else:
            log_poles = torch.randn(self.num_systems, self.state_size, 1, dtype=torch.float32)

            # TODO: bring over init from internals
            # poles[..., 0] = 1e-2 * torch.randn(self.num_systems, self.state_size, 1)
            # poles[..., 1] = 1e-3 * torch.randn(self.num_systems, self.state_size, 1)

            self.log_poles = nn.Parameter(log_poles)
            self.residues = nn.Parameter(torch.randn(self.num_systems, self.state_size, dtype=torch.float32))
            self.D = nn.Parameter(torch.zeros(self.hidden_size))
            self.h = None
            self.t = None

    def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
        if inference_params is not None and self.layer_idx in inference_params.fir_state_dict.keys():
            return self.sequential_forward(u, inference_params)

        else:
            return self.parallel_forward(u, inference_params, padding_mask)

    def parallel_forward(self, u, inference_params=None, padding_mask=None):
        L = u.shape[1]
        dims = (
            self.hidden_size,
            self.num_attention_heads,
            self.hidden_size_per_attention_head,
            self.state_size,
            self.hyena_filter_groups,
        )
        if self.print_activations:
            activations_logger.info(f"pre 1 parallel fir: {u}, {u.min()}, {u.max()}")

        z_pre, fir_state = self.engine.parallel_fir(
            self.fir_fn,
            u,
            self.short_filter_weight,
            self.short_filter_bias,
            L,
            dims=dims,
            gate=False,
            column_split_hyena=self.column_split_hyena,
            fir_length=self.short_filter_length,
            inference_params=inference_params,
            padding_mask=padding_mask,
            dim_last=True,
        )

        if inference_params:
            inference_params.fir_state_dict[self.layer_idx] = fir_state

        if self.config.interleave:
            z_pre = interleave(z_pre)

        if self.h is None:
            h, _, _, _ = self.compute_filter(L, u.device)
        else:
            h = self.h

        D = self.D

        if self.hyena_filter_groups > 1:
            h = h.repeat_interleave(self.hidden_size // self.hyena_filter_groups, 0)

        # if inference_params is not None, we plan to perform generation:
        # prefilling is handled by the engine.
        if self.fir_inner_filter_length is not None:
            if self.print_activations:
                activations_logger.info(
                    f"pre 2 parallel fir: {z_pre}, {z_pre.min()}, {z_pre.max()}, {self.fir_inner_filter_length}"
                )
            y, fir_inner_state = self.engine.parallel_fir(
                self.fir_inner_fn,
                z_pre,
                h,
                D,
                L,
                dims=dims,
                gate=True,
                gated_bias=self.fir_inner_filter_length >= 128,
                dim_last=False,
                column_split_hyena=self.column_split_hyena,
                fir_length=self.fir_inner_filter_length,
                inference_params=inference_params,
                padding_mask=padding_mask,
                groups=self.hyena_filter_groups,
            )
            if self.print_activations:
                activations_logger.info(f"post 2 parallel fir: {y}, {y.min()}, {y.max()}")
            y = y.permute(0, 2, 1)
            if inference_params:
                inference_params.fir_inner_state_dict[self.layer_idx] = fir_inner_state
        else:
            if self.print_activations:
                activations_logger.info(f"pre 2 parallel iir: {z_pre}, {z_pre.min()}, {z_pre.max()}")
            y = self.engine.parallel_iir(
                z_pre,
                h,
                D,
                L,
                t=self.t,
                poles=self.log_poles,
                residues=self.residues,
                dims=dims,
                inference_params=inference_params,
                layer_idx=self.layer_idx,
                prefill_style=self.config.get("prefill_style", "fft"),
                use_flashfft=self.use_flashfft,
                fftconv_fn=self.fftconv_fn,
                column_split_hyena=self.column_split_hyena,
                long_fir_threshold=self.long_fir_threshold,
                padding_mask=padding_mask,
            )
            if self.print_activations:
                activations_logger.info(f"post 2 parallel iir: {y}, {y.min()}, {y.max()}")

        return y, inference_params

    def sequential_forward(self, u, inference_params):
        if self.data_dtype is None:
            self.data_dtype = u.dtype

        if len(u.shape) > 2:
            u = u[:, -1]

        z_pre, fir_state = self.engine.step_fir(
            u,
            inference_params.fir_state_dict[self.layer_idx],
            weight=self.short_filter_weight,
            bias=self.short_filter_bias,
        )
        inference_params.fir_state_dict[self.layer_idx] = fir_state

        if self.config.interleave:
            z_pre = interleave(z_pre)

        x2, x1, v = (
            column_split(z_pre, self.num_attention_heads, self.hidden_size_per_attention_head)
            if self.column_split_hyena
            else z_pre.split([self.hidden_size, self.hidden_size, self.hidden_size], dim=1)
        )

        if self.hyena_flip_x1x2:
            x1, x2 = x2, x1

        if self.fir_inner_filter_length is not None:
            if self.hyena_filter_groups > 1:
                h = self.h.repeat_interleave(self.hidden_size // self.hyena_filter_groups, 0)
            else:
                h = self.h

            y, fir_inner_state = self.engine.step_fir(
                x1 * v,
                inference_params.fir_inner_state_dict[self.layer_idx],
                weight=h,
                bias=self.D,
                flip_filter=self.fir_inner_filter_length >= 128,
                gated_bias=self.fir_inner_filter_length >= 128,
            )
            y = y * x2
            inference_params.fir_inner_state_dict[self.layer_idx] = fir_inner_state
        else:
            y, iir_state = self.engine.step_iir(
                x2,
                x1,
                v,
                self.D,
                self.residues,
                self.log_poles,
                inference_params.state_dict[self.layer_idx],
                iir_groups=1,
            )
            inference_params.state_dict[self.layer_idx] = iir_state

        y = y.to(dtype=self.data_dtype)
        return y[:, None], inference_params

    def update_time(self, L, device):
        """
        Set [0, 1, ..., L-1] where L is the length of the current batch of inputs.
        If L is greater than the length of the previous batch, then the time vector is
        reinitialized. Otherwise, the time vector is truncated from cache.
        """
        if self.t is None:
            self.t = torch.arange(L, device=device)[None, None]
        elif self.t.shape[-1] < L:
            self.t = torch.arange(L, device=device)[None, None]
        else:
            self.t = self.t[..., :L]

    def compute_filter(self, L, device):
        self.update_time(L, device)
        filter_dtype = torch.float32
        residues, log_poles = (
            self.residues.to(filter_dtype),
            self.log_poles.to(filter_dtype),
        )
        h = (residues[..., None] * (log_poles * self.t).exp()).sum(1)[None]  # B, D, L
        return h, filter_dtype, log_poles, residues


class ParallelGatedConvBlock(nn.Module):
    def __init__(self, config, layer_idx, hyena_filter_groups=None, fir_inner_filter_length=None) -> None:
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.print_activations = config.get("print_activations", False)
        self.ground_truth_activations_path = config.get("ground_truth_activations_path", None)
        self.low_mem_mode = config.get("low_mem_mode", False)
        self.fir_inner_filter_length = fir_inner_filter_length
        self.hyena_filter_groups = hyena_filter_groups if hyena_filter_groups is not None else config.hidden_size
        dtype = config.get("hyena_block_dtype", torch.bfloat16)
        mlp_dtype = config.get("mlp_dtype", torch.bfloat16)
        self.pre_norm, self.post_norm = (
            RMSNorm(config).to(dtype=dtype),
            RMSNorm(config).to(dtype=dtype),
        )
        self.filter = HyenaCascade(
            config,
            layer_idx,
            hyena_filter_groups=self.hyena_filter_groups,
            fir_inner_filter_length=fir_inner_filter_length,
        ).to(dtype=dtype)

        # For posterity/debugging: TELinear can be easily replaced by
        # nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.qkv_proj_bias).to(dtype=dtype)
        # which sometimes is very useful when debugging FP8.
        self.projections = TELinear(
            config.hidden_size,
            3 * config.hidden_size,
            bias=config.qkv_proj_bias,
            init_method=torch.nn.init.xavier_uniform_,
            use_fp8=config.get("use_fp8_input_projections", False),
        )

        self.out_filter_dense = nn.Linear(config.hidden_size, config.hidden_size, bias=config.hyena_out_proj_bias).to(
            dtype
        )
        self.mlp = ParallelGatedMLP(config, layer_idx).to(dtype=mlp_dtype)

        # self.proj_norm_fn = self.proj_norm
        # self.res_mlp_norm_fn = self.res_mlp_norm

        if self.config.get("compile", False):
            self.proj_norm_fn = torch.compile(self.proj_norm, fullgraph=True, dynamic=False, mode="reduce-overhead")
            self.res_mlp_norm_fn = torch.compile(
                self.res_mlp_norm, fullgraph=True, dynamic=False, mode="reduce-overhead"
            )

    def pad_to_multiple(self, x, multiple=16):
        """Pad input tensor to multiple of 16 only when FP8 is enabled"""
        if not self.config.get("use_fp8_input_projections", False):
            return x

        batch_size, seq_len, hidden_dim = x.size()
        pad_len = (multiple - (seq_len % multiple)) % multiple
        if pad_len == 0:
            return x
        return F.pad(x, (0, 0, 0, pad_len))

    def proj_norm(self, x):
        if self.print_activations:
            activations_logger.info(f"pre mixer norm: {x} {x.min()} {x.max()} {self.projections.__class__}")
            activations_logger.info(
                f"post mixer norm: {self.pre_norm(x)} {self.pre_norm(x).min()} {self.pre_norm(x).max()}"
            )

            if self.ground_truth_activations_path:
                pre_norm_savanna = torch.load(
                    f"{self.ground_truth_activations_path}/pre_mixer_norm_{self.layer_idx}.pt"
                )
                post_norm_savanna = torch.load(
                    f"{self.ground_truth_activations_path}/post_mixer_norm_{self.layer_idx}.pt"
                )

                activation_diff = (x.squeeze() - pre_norm_savanna.squeeze()).abs()
                activations_logger.info(
                    f"pre mixer norm activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
                )
                activation_diff = (self.pre_norm(x).squeeze() - post_norm_savanna.squeeze()).abs()
                activations_logger.info(
                    f"post mixer norm activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
                )
                activations_logger.info(
                    f"pre norm scale: {self.pre_norm.scale}, {self.pre_norm.scale.min()}, {self.pre_norm.scale.max()}"
                )

        normalized = self.pre_norm(x)
        normalized = self.pad_to_multiple(normalized)
        with torch.cuda.device(x.device):
            projected = self.projections(normalized)

        if isinstance(projected, tuple):
            projected = projected[0]

        original_seq_len = x.size(1)
        # Slice back to original sequence length if padding was added
        if projected.size(1) > original_seq_len:
            projected = projected[:, :original_seq_len, :]

        return projected

    def res_mlp_norm(self, x):
        if self.print_activations:
            activations_logger.info(f"pre mlp: {x} {x.min()} {x.max()} {self.mlp.__class__}")
            activations_logger.info(
                f"post mlp norm: {self.post_norm(x)} {self.post_norm(x).min()} {self.post_norm(x).max()}"
            )
            activations_logger.info(
                f"post mlp: {self.mlp(self.post_norm(x))} {self.mlp(self.post_norm(x)).min()} {self.mlp(self.post_norm(x)).max()}"
            )
            if self.ground_truth_activations_path:
                pre_mlp_savanna = torch.load(f"{self.ground_truth_activations_path}/pre_mlp_{self.layer_idx}.pt")
                post_mlp_savanna = torch.load(f"{self.ground_truth_activations_path}/post_mlp_norm_{self.layer_idx}.pt")

                activation_diff = (x.squeeze() - pre_mlp_savanna.squeeze()).abs()
                activations_logger.info(f"pre mlp activation_diff: {activation_diff.max()}, {activation_diff.mean()}")
                activation_diff = (self.post_norm(x).squeeze() - post_mlp_savanna.squeeze()).abs()
                activations_logger.info(
                    f"post mlp norm activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
                )
        return self.mlp(self.post_norm(x)) + x

    def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
        z = self.proj_norm(u)

        if type(padding_mask) == torch.Tensor:  # guard against bias
            z = z * padding_mask[..., None]

        if self.print_activations:
            activations_logger.info(f"pre filter: {z} {z.min()} {z.max()} {self.filter.__class__}")
            if self.ground_truth_activations_path:
                z_savanna = torch.load(f"{self.ground_truth_activations_path}/pre_filter_{self.layer_idx}.pt")
                activation_diff = (z - z_savanna.squeeze()).abs()
                activations_logger.info(
                    f"pre filter activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
                )
        z, inference_params = self.filter(z, inference_params=inference_params, padding_mask=padding_mask)

        if self.print_activations:
            activations_logger.info(f"post postgate: {z} {z.min()} {z.max()} {self.filter.__class__}")
            activations_logger.info(
                f"post out proj: {self.out_filter_dense(z)} {self.out_filter_dense(z).min()} {self.out_filter_dense(z).max()} {self.out_filter_dense.__class__}"
            )
            activations_logger.info(
                f"post mixer dense and residual: {self.out_filter_dense(z) + u} {(self.out_filter_dense(z) + u).min()} {(self.out_filter_dense(z) + u).max()}"
            )
            activations_logger.info(
                f"post mixer dense: {self.out_filter_dense(z)} {self.out_filter_dense(z).min()} {self.out_filter_dense(z).max()}"
            )
            activations_logger.info(f"post mixer: {z} {z.min()} {z.max()}")
            if self.ground_truth_activations_path:
                z_savanna = torch.load(f"{self.ground_truth_activations_path}/post_filter_{self.layer_idx}.pt")
                activation_diff = (z - z_savanna.squeeze()).abs()
                activations_logger.info(
                    f"post filter activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
                )

                z_savanna = torch.load(f"{self.ground_truth_activations_path}/post_out_proj_{self.layer_idx}.pt")
                z_ = F.linear(z, self.out_filter_dense.weight)
                activation_diff = (z_ - z_savanna.squeeze()).abs()
                activations_logger.info(
                    f"post out proj activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
                )

        z_in = self.out_filter_dense(z) + u

        # if self.layer_idx == 0:
        #     z_in = z_savanna.squeeze() + u + self.out_filter_dense.bias

        if type(padding_mask) == torch.Tensor:  # guard against bias
            z_in = z_in * padding_mask[..., None]

        y = self.res_mlp_norm(z_in)

        return y, inference_params


def get_block(config, layer_idx, flash_fft=None):
    if layer_idx in config.attn_layer_idxs:
        return AttentionBlock(config, layer_idx)
    elif layer_idx in config.hcl_layer_idxs:
        block = ParallelGatedConvBlock(config, layer_idx)
        if config.get("use_flashfft", "False"):
            block.filter.fftconv_fn = flash_fft
        return block
    elif layer_idx in config.hcm_layer_idxs:
        block = ParallelGatedConvBlock(
            config,
            layer_idx,
            hyena_filter_groups=config.hcm_filter_groups,
            fir_inner_filter_length=config.hcm_filter_length,
        )
        return block
    elif layer_idx in config.hcs_layer_idxs:
        block = ParallelGatedConvBlock(
            config,
            layer_idx,
            hyena_filter_groups=config.hcs_filter_groups,
            fir_inner_filter_length=config.hcs_filter_length,
        )
        return block
    else:
        raise NotImplementedError


class StripedHyena(nn.Module):
    def __init__(self, config):
        super().__init__()
        fixup_te_workspace()  # Workaround global cublas workspaces in TE

        self.config = config
        self.print_activations = config.get("print_activations", False)

        if self.print_activations:
            enable_activations_logging()
        self.logger = logging.getLogger(self.__class__.__name__)

        self.ground_truth_activations_path = config.get("ground_truth_activations_path", None)
        self.logger.info(f"Initializing StripedHyena with config: {config}")

        with torch.device("cuda:0" if torch.cuda.is_available() else "cpu"):
            self.embedding_layer = VocabParallelEmbedding(config)

        if config.get("use_flashfft", "True"):
            try:
                from flashfftconv import FlashFFTConv

                self.flash_fft = FlashFFTConv(config.seqlen, dtype=torch.bfloat16)
            except ImportError:
                "flashfftconv not installed"
        else:
            self.flash_fft = None
        if not self.config.get('evo2_style_activations', False):
            self.logger.warning(
                "⚠️ Not using Evo2 style activations ⚠️\n"
                "⚠️ Set 'evo2_style_activations: True' in config if you are using Evo 2 checkpoints ⚠️"
            )
        self.logger.info(f"Initializing {config.num_layers} blocks...")
        self.blocks = nn.ModuleList()
        self.block_idx_to_device = {}

        # Calculate layers per GPU
        num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
        layers_per_gpu = math.ceil(config.num_layers / num_gpus)
        self.logger.info(f"Distributing across {num_gpus} GPUs, approximately {layers_per_gpu} layers per GPU")

        for layer_idx in tqdm(range(config.num_layers)):
            # Determine which GPU should handle this layer
            device_idx = min(layer_idx // layers_per_gpu, num_gpus - 1)
            device = f"cuda:{device_idx}" if torch.cuda.is_available() else "cpu"

            with torch.device(device):
                # TELinear uses `device="cuda"` device to allocate empty bias
                # tensor. This makes sure that the empty tensor is allocated on the
                # correct device. (torch.device(), unlike torch.cuda.device(),
                # doesn't override current CUDA device.)
                with torch.cuda.device(device):
                    block = get_block(config, layer_idx, flash_fft=self.flash_fft)
            move_to_device(block, device)

            self.blocks.append(block)
            self.block_idx_to_device[layer_idx] = device
            self.logger.info(f"Assigned {layer_idx=} to {device=}")
            self.logger.info(
                f"Parameter count for block {layer_idx}: {sum(p.numel() for p in self.blocks[-1].parameters())}"
            )

        with torch.device(self.block_idx_to_device[0]):
            with torch.cuda.device(self.block_idx_to_device[0]):
                self.norm = RMSNorm(config) if config.get("final_norm", True) else None
                if config.tie_embeddings:
                    # Lambda usage is to be able to use forward() on caller side, which in
                    # turn is needed for PyTorch hooks to work properly.
                    self.unembed = Lambda(self.embedding_layer.unembed)
                else:
                    if config.tie_embeddings:
                        # Technically we can support this mode, just need to
                        # copy tensors across GPUs then. But let's implement it
                        # once/if needed.
                        self.logger.info("Ignoring tie_embeddings for now.")
                    self.unembed = VocabParallelUnembedding(config)

        self.logger.info("Initialized model")

    def forward(self, x, inference_params_dict=None, padding_mask=None):
        L = x.shape[1]
        if self.print_activations:
            activations_logger.info(f"pre embedding: {x}, {x.min()}, {x.max()}")

        x = self.embedding_layer(x)

        if self.print_activations:
            activations_logger.info(f"post embedding: {x}, {x.min()}, {x.max()}")

        if inference_params_dict is not None:
            x, inference_params_dict_out = self.stateful_forward(
                x,
                inference_params_dict=inference_params_dict,
            )
        else:
            x, inference_params_dict_out = self.stateless_forward(x, padding_mask=padding_mask)

        if self.print_activations:
            activations_logger.info(f"pre norm: {x}, {x.min()}, {x.max()}")

        # By convention, we return results on the first device
        x = x.to(self.block_idx_to_device[0])
        x = self.norm(x)

        if self.print_activations:
            activations_logger.info(f"post norm: {x}, {x.min()}, {x.max(), {self.norm.scale}}")

        x = self.unembed(x)
        return x, inference_params_dict_out

    def block_idx_to_name(self, block_idx):
        if block_idx in self.config.attn_layer_idxs:
            return "mha"
        elif block_idx in self.config.hcl_layer_idxs:
            return "hcl"
        elif block_idx in self.config.hcm_layer_idxs:
            return "hcm"
        elif block_idx in self.config.hcs_layer_idxs:
            return "hcs"
        else:
            raise ValueError(f"Block index {block_idx} not found")

    def cross_device_transfer(self, x, block_idx):
        if self.block_idx_to_device[max(block_idx - 1, 0)] != self.block_idx_to_device[block_idx]:
            x = x.to(self.block_idx_to_device[block_idx])
        return x

    def stateful_forward(self, x, inference_params_dict=None):
        for block_idx, block in enumerate(self.blocks):
            inference_params = inference_params_dict[self.block_idx_to_name(block_idx)]

            if self.print_activations:
                activations_logger.info(f"pre block {block_idx}: {x}, {x.min()}, {x.max()} {block.__class__}")
                if self.ground_truth_activations_path:
                    x_savanna = torch.load(f"{self.ground_truth_activations_path}/pre_block_{block_idx}.pt")
                    activation_diff = (x - x_savanna.squeeze()).abs()
                    activations_logger.info(
                        f"pre block {block_idx} activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
                    )

            x = self.cross_device_transfer(x, block_idx)
            x, _ = block(x, inference_params=inference_params)

            if self.print_activations:
                activations_logger.info(f"post block {block_idx}: {x}, {x.min()}, {x.max()}")
                if self.ground_truth_activations_path:
                    x_savanna = torch.load(f"{self.ground_truth_activations_path}/post_block_{block_idx}.pt")
                    activation_diff = (x - x_savanna.squeeze()).abs()
                    activations_logger.info(
                        f"post block {block_idx} activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
                    )

        return x, inference_params_dict

    def stateless_forward(self, x, padding_mask=None):
        if type(padding_mask) == torch.Tensor:
            x = x * padding_mask[..., None]

        for block_idx, block in enumerate(self.blocks):
            if self.print_activations:
                activations_logger.info(f"pre block {block_idx}: {x}, {x.min()}, {x.max()} {block.__class__}")
                if self.ground_truth_activations_path:
                    x_savanna = torch.load(f"{self.ground_truth_activations_path}/pre_block_{block_idx}.pt")
                    activation_diff = (x - x_savanna.squeeze()).abs()
                    activations_logger.info(
                        f"pre block {block_idx} activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
                    )

            x = self.cross_device_transfer(x, block_idx)
            x, _ = block(x, inference_params=None, padding_mask=padding_mask)

            if self.print_activations:
                activations_logger.info(f"post block {block_idx}: {x}, {x.min()}, {x.max()}")
                if self.ground_truth_activations_path:
                    x_savanna = torch.load(f"{self.ground_truth_activations_path}/post_block_{block_idx}.pt")
                    activation_diff = (x - x_savanna.squeeze()).abs()
                    activations_logger.info(
                        f"post block {block_idx} activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
                    )

        return x, None

    def initialize_inference_params(self, max_seqlen=None):
        ## Input seqlen takes priority over config!
        ## WARNING: This avoids potential errors but means the model can be used beyond length it was trained at
        config_seqlen = self.config.get("max_seqlen", None)
        if config_seqlen is None:
            print("No max_seqlen found in config!!! using default value of 8192")
            config_seqlen = 8192
        new_max_seqlen = max_seqlen if max_seqlen != None else config_seqlen
        # self.config["max_seqlen"] = new_max_seqlen
        ## Note: changing the stored config max_seqlen will change the max_seqlen used in flash attention, leading to minor logit differences
        print(f"Initializing inference params with max_seqlen={new_max_seqlen}")

        inference_params_dict = {
            "mha": InferenceParams(
                max_seqlen=new_max_seqlen,
                max_batch_size=self.config.get("max_batch_size", 1),
                seqlen_offset=0,
            ),
            "hcl": HyenaCascadeIIRInferenceParams(
                fir_filter_length=self.config.short_filter_length,
                state_dim=self.config.state_size,
                seqlen_offset=0,
            ),
            "hcm": HyenaCascadeFIRInferenceParams(
                fir_filter_length=self.config.short_filter_length,
                fir_inner_filter_length=self.config.hcm_filter_length,
                seqlen_offset=0,
            ),
            "hcs": HyenaCascadeFIRInferenceParams(
                fir_filter_length=self.config.short_filter_length,
                fir_inner_filter_length=self.config.hcs_filter_length,
                seqlen_offset=0,
            ),
        }
        return inference_params_dict

    def precompute_filters(self, L, device):
        for block_idx, block in enumerate(self.blocks):
            if type(block) == ParallelGatedConvBlock:
                if type(block.filter) == HyenaCascade:
                    L = block.filter.long_fir_threshold or L
                    print_rank_0(f"Precomputing filters, L={L}...")

                    filter_dtype = torch.float16 if L >= 2048 else torch.float32

                    block.filter._set_time(L, device)
                    residues, poles = (
                        block.filter.residues.to(torch.float16),
                        block.filter.poles.to(torch.float16),
                    )

                    block.filter.h = (residues * poles**block.filter.t).real.sum(1)[None]
                    block.filter.h = block.filter.h.to(dtype=filter_dtype)

    def load_poles_residues(self, path):
        "Load different poles and residues for each layer."
        for block_idx, block in enumerate(self.blocks):
            if type(block) == ParallelGatedConvBlock:
                if type(block.filter) == HyenaCascade:
                    self.logger.info(f"Loading approximatepoles and residues for block {block_idx}")
                    poles = torch.load(path + f"/approx_poles_{block_idx+1}.pt", map_location="cpu")
                    poles = torch.view_as_real(poles)
                    residues = torch.load(path + f"/approx_residues_{block_idx+1}.pt", map_location="cpu")
                    residues = torch.view_as_real(residues)
                    poles = poles.permute(1, 0, 2).unsqueeze(-2)
                    residues = residues.permute(1, 0, 2).unsqueeze(-2)

                    block.filter.poles = nn.Parameter(poles)
                    block.filter.residues = nn.Parameter(residues)

    def custom_load_state_dict(self, state_dict, strict=True):
        """
        Post-processes the state_dict to convert savanna checkpoints to vortex checkpoints.
        """
        self.logger.debug(f"Loading state dict: {state_dict}, (ignoring extra keys) with strict: {strict}")
        model_dict = self.state_dict()

        # Find keys that are in model_dict but not in state_dict
        missing_in_state_dict = model_dict.keys() - state_dict.keys()
        # Find keys that are in state_dict but not in model_dict
        extra_in_state_dict = state_dict.keys() - model_dict.keys()

        if missing_in_state_dict:
            print(f"Keys missing in state_dict: {missing_in_state_dict}")
        if extra_in_state_dict:
            print(f"Extra keys in state_dict: {extra_in_state_dict}")

        filtered_dict = {k: v for k, v in state_dict.items() if k in model_dict}

        if all("._extra_state" in k for k in missing_in_state_dict):
            self.logger.info("Checkpoint has no FP8 extra state, will be using initial state.")
            for k in missing_in_state_dict:
                filtered_dict[k] = None

        self.load_state_dict(filtered_dict, strict=strict)
        fixup_fp8_extra_states(self)

        if self.config.get("column_split", True):
            self.logger.info("Adjusting Wqkv for column split (permuting rows)")
            for layer_idx, block in enumerate(self.blocks):
                if type(block) == AttentionBlock:
                    target_device = block.inner_mha_cls.Wqkv.weight.device

                    Wqkv = state_dict[f"blocks.{layer_idx}.inner_mha_cls.Wqkv.weight"]
                    try:
                        bias = state_dict[f"blocks.{layer_idx}.inner_mha_cls.Wqkv.bias"]
                    except:
                        bias = None

                    size_att_head = block.hidden_size_per_attention_head

                    Wqkv = Wqkv.permute(1, 0)
                    Wqkv = Wqkv.reshape(block.hidden_size, block.num_attention_heads, 3, size_att_head)
                    Wq, Wk, Wv = Wqkv.unbind(dim=-2)
                    Wq = Wq.reshape(block.hidden_size, -1)
                    Wk = Wk.reshape(block.hidden_size, -1)
                    Wv = Wv.reshape(block.hidden_size, -1)
                    Wqkv = torch.cat([Wq, Wk, Wv], dim=-1)
                    Wqkv = Wqkv.permute(1, 0)

                    # Single device transfer at the end
                    block.inner_mha_cls.Wqkv.weight.data = Wqkv.to(target_device)

                    if bias is not None:
                        bias = bias.cpu()  # Process on CPU
                        bias = bias.reshape(block.num_attention_heads, 3, size_att_head)
                        bias_q, bias_k, bias_v = bias.unbind(dim=-2)
                        bias_q = bias_q.reshape(block.hidden_size)
                        bias_k = bias_k.reshape(block.hidden_size)
                        bias_v = bias_v.reshape(block.hidden_size)
                        bias = torch.cat([bias_q, bias_k, bias_v], dim=0)
                        try:
                            block.inner_mha_cls.Wqkv.bias.data = bias.to(target_device)
                        except:
                            pass

    def to_bfloat16_except_pr_lc(self, to_float32=False):
        """Convert all parameters to bfloat16 except for the poles and residues.

        Particularly important for longer prompts.
        """
        excluded_shapes = [(4096, 1, 128)]
        for k, p in self.named_parameters():
            if "projections" not in k:  # avoid TE linears
                if "log_poles" not in k and "residues" not in k and p.shape not in excluded_shapes:
                    p.data = p.data.to(torch.bfloat16)
                else:
                    if to_float32:
                        p.data = p.data.to(torch.float32)
        for k, b in self.named_buffers():
            if "inv_freq" in k:
                if to_float32:
                    b.data = b.data.to(torch.float32)
positional_embeddings.py
ADDED
@@ -0,0 +1,114 @@
# Copied verbatim from vortex
"""
Armin Thomas, Jan 2023. Modified by Eric Nguyen.

Wrappers for linearly interpolated rope embeddings to use inside of MHA layers of Flash Attn.

"""

import torch
from einops import rearrange
from .rotary import RotaryEmbedding


# simple wrapper for flash-attn RoPE with linear scaling:
class LinearlyScaledRotaryEmbedding(RotaryEmbedding):
    def __init__(
        self,
        dim: int,
        scaling_factor: float = 1.0,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        super().__init__(
            dim=dim,
            base=base,
            interleaved=interleaved,
            scale_base=scale_base,
            pos_idx_in_fp32=pos_idx_in_fp32,
            device=device,
        )
        self._linear_scaling_factor = scaling_factor

    # adpated from: https://github.com/Dao-AILab/flash-attention/blob/43ceab630bc6c27712428da5a33fc9cb5c369d91/flash_attn/layers/rotary.py#L368
    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        # Reset the tables if the sequence length has changed,
        # if we're on a new device (possibly due to tracing for instance),
        # or if we're switching from inference mode to training
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # linear scaling:
                t = t / self._linear_scaling_factor
                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
                # will be large. Having it in bf16 will lose a lot of precision and cause the
                # cos & sin output to change significantly.
                # We want to recompute self.inv_freq if it was not loaded in fp32
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self._compute_inv_freq(device=device)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                # linear scaling:
                t = t / self._linear_scaling_factor
                inv_freq = self.inv_freq
            # Don't do einsum, it converts fp32 to fp16 under AMP
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)


# swap out RoPE of existing mha:
def swap_mha_rope(
    mha,
    new_rope: torch.nn.Module = LinearlyScaledRotaryEmbedding,
    kwargs_new_rope: dict = None,
):
    # determine mha dtype and device:
    dtype = mha.Wq.weight.dtype if mha.cross_attn else mha.Wqkv.weight.dtype
    device = mha.Wq.weight.device if mha.cross_attn else mha.Wqkv.weight.device
    # determine RoPE settings:
    kwargs_old_rope = dict(
        dim=mha.rotary_emb.dim,
        base=mha.rotary_emb.base,
        interleaved=mha.rotary_emb.interleaved,
        scale_base=mha.rotary_emb.scale_base,
        pos_idx_in_fp32=mha.rotary_emb.pos_idx_in_fp32,
        device=mha.rotary_emb.inv_freq.device,
    )
    # delete old RoPE:
    del mha.rotary_emb
    # create new RoPE:
    kwargs_new_rope = kwargs_new_rope or {"scaling_factor": 1.0}
    scaled_rope = new_rope(**kwargs_new_rope, **kwargs_old_rope).to(dtype)
    # attach new RoPE to mha:
    mha.rotary_emb = scaled_rope
    # make new sure RoPE is correctly registered:
    assert isinstance(mha.rotary_emb, new_rope)
    return mha
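
As a usage note, model.py above already calls this helper when `use_interpolated_rotary_pos_emb` is enabled. A standalone sketch, with a hypothetical flash-attn style `mha` module (one that exposes `Wqkv` and `rotary_emb`, as in attention.py), might look like:

# Sketch only: `mha` is a hypothetical MHA instance following the flash-attn interface.
mha = swap_mha_rope(
    mha=mha,
    kwargs_new_rope={"scaling_factor": 4.0},  # illustrative: positions are divided by 4 before RoPE
)
assert isinstance(mha.rotary_emb, LinearlyScaledRotaryEmbedding)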
sample.py
ADDED
@@ -0,0 +1,60 @@
# Copied verbatim from vortex
import torch


# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
def modify_logits_for_top_k_filtering(logits, top_k):
    """Set the logits for none top-k values to -inf. Done in-place."""
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits.masked_fill_(indices_to_remove, float("-Inf"))


# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits for none top-p values to -inf. Done in-place."""
    if top_p <= 0.0 or top_p >= 1.0:
        return

    # First sort and calculate cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # scatter sorted tensors to original indexing
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits.masked_fill_(indices_to_remove, float("-inf"))


# https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py
def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
    """Sample from top-k logits.
    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
    """
    logits = torch.nan_to_num(logits)
    logits = torch.where(logits == float("-inf"), 0, logits)
    logits = torch.where(logits == float("inf"), 0, logits)

    if top_k == 1:  # Short-circuit for greedy decoding
        return logits.argmax(dim=-1)
    else:
        if top_p > 0.0:
            assert top_p <= 1.0, "top-p should be in (0, 1]."
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))  # Safety check
            logits_top, indices = torch.topk(logits, top_k, dim=-1)
            if temperature != 1.0:
                logits_top /= temperature
            modify_logits_for_top_p_filtering(logits_top, top_p)

            return indices[
                torch.arange(indices.shape[0], device=indices.device),
                torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
            ]
        else:
            # Clone so that when we modify for top_p we don't change the original logits
            logits_top = logits / temperature if temperature != 1.0 else logits.clone()
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
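
A short, self-contained illustration of the sampling helpers above, using a synthetic logits tensor:

import torch

logits = torch.randn(2, 512)  # synthetic (batch_size, vocab_size) logits

greedy = sample(logits)                                 # top_k=1 short-circuits to argmax
topk_temp = sample(logits, top_k=50, temperature=0.8)   # top-k filtering plus temperature
nucleus = sample(logits, top_k=0, top_p=0.9)            # pure top-p (nucleus) sampling
print(greedy.shape, topk_temp.shape, nucleus.shape)     # torch.Size([2]) for each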
special_tokens_map.json
ADDED
@@ -0,0 +1 @@
{}
utils.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copied veratim from vortex
|
| 2 |
+
import torch
|
| 3 |
+
import logging
|
| 4 |
+
|
| 5 |
+
log = logging.getLogger(__name__)
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
|
| 9 |
+
"""Get the dim for the local rank derived from splitting dim on world_size processes.
|
| 10 |
+
|
| 11 |
+
The split may not be even across the world_size processes.
|
| 12 |
+
"""
|
| 13 |
+
multiple = dim // multiple_of
|
| 14 |
+
div = multiple // world_size
|
| 15 |
+
mod = multiple % world_size
|
| 16 |
+
local_multiple = div + int(local_rank < mod)
|
| 17 |
+
return local_multiple * multiple_of
|
| 18 |
+
|
| 19 |
+
|
def grab_first_if_tuple(x):
    if x.__class__.__name__ == "tuple":
        return x[0]
    else:
        return x


def interleave(z_pre):
    if len(z_pre.shape) == 3:  # non-cached
        x1 = z_pre[:, 0::3, :]
        x2 = z_pre[:, 1::3, :]
        v = z_pre[:, 2::3, :]
        z_pre = torch.cat([x1, x2, v], dim=1)
        return z_pre
    else:
        x1 = z_pre[..., 0::3]
        x2 = z_pre[..., 1::3]
        v = z_pre[..., 2::3]
        z_pre = torch.concat([x1, x2, v], dim=-1)
        return z_pre
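interleave() regroups channels laid out per triple as [x1_0, x2_0, v_0, x1_1, x2_1, v_1, ...] into three contiguous blocks [x1..., x2..., v...]. A toy check of the non-cached (3-D) branch, included purely for illustration:

    import torch
    z = torch.arange(6).reshape(1, 6, 1)     # channels 0..5 along dim=1
    # interleave(z)[:, :, 0] -> tensor([[0, 3, 1, 4, 2, 5]])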
def column_split(x, num_heads, head_size):
    """Split a tensor with `num_heads` along the head dimension, instead of
    across heads. Fixed to three projections.
    """
    # FIXME: merge cases
    if len(x.shape) == 2:
        x_reshaped = x.reshape(
            x.shape[0],
            num_heads,
            3 * head_size,
        )

        x2, x1, v = (
            x_reshaped[..., :head_size],
            x_reshaped[..., head_size : 2 * head_size],
            x_reshaped[..., 2 * head_size :],
        )
        x2, x1, v = (
            x2.reshape(x2.shape[0], -1),
            x1.reshape(x1.shape[0], -1),
            v.reshape(v.shape[0], -1),
        )
        return x2, x1, v
    else:
        x = x.reshape(
            x.shape[0],
            num_heads,
            3 * head_size,
            x.shape[2],
        )
        x2, x1, v = (
            x[:, :, :head_size],
            x[
                :,
                :,
                head_size : 2 * head_size,
            ],
            x[:, :, 2 * head_size :],
        )
        x2, x1, v = (
            x2.reshape(x2.shape[0], -1, x2.shape[-1]),
            x1.reshape(x1.shape[0], -1, x1.shape[-1]),
            v.reshape(v.shape[0], -1, v.shape[-1]),
        )
        return x2, x1, v
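column_split() views the projection output as (batch, num_heads, 3 * head_size) and slices the last dimension, so each head keeps a contiguous [x2 | x1 | v] block. A shape-only sketch for the 2-D branch, with toy sizes assumed for illustration:

    import torch
    num_heads, head_size, batch = 2, 4, 3
    x = torch.randn(batch, num_heads * 3 * head_size)     # (3, 24)
    # x2, x1, v = column_split(x, num_heads, head_size)
    # each result has shape (batch, num_heads * head_size) == (3, 8)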
def load_checkpoint(model, checkpoint_path):
    if checkpoint_path is None:
        log.warning("Using random weights (dry-run)")
        return
    log.info(f"Loading {checkpoint_path}")

    # We must allowlist BytesIO, as fp8-enabled checkpoints store this type
    # in Transformer Engine layers' _extra keys. If not, weights_only=True
    # will not be happy.
    import io

    torch.serialization.add_safe_globals([io.BytesIO])

    with torch.inference_mode():
        state = torch.load(
            checkpoint_path,
            # Make sure we override the device location specified in the
            # checkpoint dictionary (e.g. checkpoints may have "cuda:0"
            # as the location for all layers, which wouldn't work for
            # the multi-GPU case).
            map_location="cpu",
            # This is an optimization: with it, we don't read the whole
            # checkpoint dictionary from disk into CPU memory in one go;
            # instead, PyTorch only loads the relevant layers into CPU
            # memory when we are about to copy them to GPU.
            mmap=True,
            # Make sure PyTorch is not issuing a warning regarding potential
            # security issues.
            weights_only=True,
        )
        model.to_bfloat16_except_pr_lc(to_float32=True)

        model.custom_load_state_dict(state)

        model.to_bfloat16_except_pr_lc()
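The torch.load keyword arguments used here (map_location, mmap, weights_only) are standard PyTorch options, though mmap and add_safe_globals need a reasonably recent release. A minimal sketch of loading an ordinary checkpoint the same way, with a placeholder path and a plain nn.Module in place of this repo's custom hooks:

    import torch

    state = torch.load(
        "checkpoint.pt",        # placeholder path
        map_location="cpu",     # ignore device tags baked into the checkpoint
        mmap=True,              # page tensors in lazily instead of reading the whole file
        weights_only=True,      # restrict unpickling to tensors and plain containers
    )
    # model.load_state_dict(state)  # for a model without custom loading logic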
def move_to_device(module, device):
    """Recursively moves all parameters and buffers to the specified device."""
    for child in module.children():
        move_to_device(child, device)

    for param in module.parameters(recurse=False):
        if param.device != device:
            param.data = param.data.to(device)

    for buf in module.buffers(recurse=False):
        if buf.device != device:
            buf.data = buf.data.to(device)

    module.to(device)


def fixup_fp8_extra_states(module):
    """Recursively fixes the device location of TE's Linear fp8 extra states."""
    for child in module.children():
        fixup_fp8_extra_states(child)

    # TE Linear uses the default "cuda" device to load extra state, which causes
    # trouble when the layer is moved to another GPU. Instead, this is how
    # TE Linear should load extra_state: using the parameters' device.
    torch_load = torch.load

    def overriden_load(state, map_location):
        device = next(module.parameters()).device
        return torch_load(state, map_location=device)

    if hasattr(module, "fp8_meta"):
        log.debug(f"Reloading fp8 extra state to a proper device for {module}")
        from unittest.mock import patch

        with patch("torch.load", new=overriden_load):
            module.set_extra_state(module.get_extra_state())


def fixup_te_workspace():
    """TE uses a single workspace tensor for all calls, disregarding that inputs
    may be on separate GPUs. This patches TE's Linear module to use per-device
    workspaces."""
    from functools import lru_cache

    @lru_cache
    def te_cublas_get_workspace_per_device(device):
        log.info(f"Fixup applied: Allocating cublas workspace for {device=}")
        import transformer_engine.pytorch.module.base as tebase

        with torch.cuda.device(device):
            tebase._cublas_workspace = None  # Force get_workspace() to reallocate the tensor
            return tebase.get_workspace()

    def get_workspace():
        return te_cublas_get_workspace_per_device(torch.cuda.current_device())

    import transformer_engine.pytorch.module.linear as telinear

    telinear.get_workspace = get_workspace
def get_init_from_string(init_str):
    if type(init_str) == str:
        if init_str == "torch.nn.init.zeros_":
            return torch.nn.init.zeros_
        elif init_str == "torch.nn.init.xavier_uniform_":
            return torch.nn.init.xavier_uniform_
        elif init_str == "torch.nn.init.xavier_normal_":
            return torch.nn.init.xavier_normal_
        else:
            raise ValueError(f"Unrecognized init {init_str}")


def print_rank_0(message, debug=False, end="\n"):
    """Print from rank 0 only."""
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(message, flush=True, end=end)
    else:
        print(message, flush=True, end=end)


class dotdict(dict):
    """dot.notation access to dictionary attributes"""

    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
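Because `__getattr__` is bound to `dict.get`, attribute access on a dotdict returns None for missing keys instead of raising. A small illustrative example:

    cfg = dotdict({"hidden_size": 4096})
    cfg.num_heads = 32                   # attribute assignment writes a dict item
    assert cfg.hidden_size == 4096 and cfg["num_heads"] == 32
    assert cfg.missing_key is None       # dict.get semantics, no AttributeError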
def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)


def divide(numerator, denominator):
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator


class Lambda(torch.nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)


class VocabUtility:
    """Split the vocabulary into `world_size` chunks and return the
    first and last index of the vocabulary belonging to the `rank`
    partition. Note that indices are in [first, last)."""

    @staticmethod
    def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
        index_f = rank * per_partition_vocab_size
        index_l = index_f + per_partition_vocab_size
        return index_f, index_l

    @staticmethod
    def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
        per_partition_vocab_size = divide(global_vocab_size, world_size)
        return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size)
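For example, splitting a 50,304-token vocabulary across 4 tensor-parallel ranks gives each rank a 12,576-token slice; rank 1 would own ids 12576 through 25151 (illustrative numbers only):

    # VocabUtility.vocab_range_from_global_vocab_size(50304, rank=1, world_size=4) -> (12576, 25152)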