Upload FastSLMForCausalLM

Changed files:
- config.json (+2 / -2)
- delta_net.py (+1 / -2)
- model.safetensors.index.json (+1 / -0)
- modeling_fast_slm.py (+130 / -30)
config.json
CHANGED
@@ -13,6 +13,7 @@
   "bos_token_id": 1,
   "calc_logits_for_entire_prompt": false,
   "d_conv": 4,
+  "dtype": "bfloat16",
   "eos_token_id": 2,
   "ffn_expand_ratio": 3,
   "global_attn_idx": [],
@@ -127,8 +128,7 @@
   "router_aux_loss_coef": 0.001,
   "sliding_window": null,
   "tie_word_embeddings": true,
-  "torch_dtype": "bfloat16",
-  "transformers_version": "4.48.2",
+  "transformers_version": "4.56.2",
   "use_cache": false,
   "use_mamba_kernels": true,
   "v_head_dim": -1,
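
The new "dtype" key and the bumped "transformers_version" simply reflect the config being re-serialized by a newer transformers release. A minimal loading sketch consistent with this config; the repo id "your-org/FastSLM" is a placeholder, not the actual model path:

# Sketch: load the checkpoint with the dtype declared in config.json.
# "your-org/FastSLM" is a hypothetical repo id used only for illustration.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "your-org/FastSLM"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,   # matches "dtype": "bfloat16" in config.json
    trust_remote_code=True,       # required for the custom FastSLM modeling code
)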
delta_net.py
CHANGED
@@ -10,7 +10,6 @@ import torch.nn as nn
 from einops import rearrange
 from torch.nn import functional as F

-import fla
 from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
 from fla.ops.delta_rule import chunk_delta_rule, fused_recurrent_delta_rule

@@ -331,7 +330,7 @@ class Cache(transformers.cache_utils.Cache):
         self,
         seen_tokens: int = 0
     ) -> Cache:
-        super().__init__()
+        super().__init__(layers=[0])

         self.states: List[Dict[str, Any]] = []

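
The changed super().__init__ call tracks the Cache refactor in recent transformers releases: the old no-argument constructor was fine on 4.48, while the 4.56-era base class is constructed with a layers argument here. A version-tolerant sketch of that idea, assuming older releases reject the layers keyword with a TypeError (the committed code itself just calls super().__init__(layers=[0])):

from typing import Any, Dict, List

import transformers


class CompatCache(transformers.cache_utils.Cache):
    """Sketch only: construct the base Cache across transformers versions."""

    def __init__(self, seen_tokens: int = 0):
        try:
            # newer constructor, as used in this commit
            super().__init__(layers=[0])
        except TypeError:
            # older releases: no-argument constructor
            super().__init__()
        self.states: List[Dict[str, Any]] = []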
model.safetensors.index.json
CHANGED
@@ -1,5 +1,6 @@
 {
   "metadata": {
+    "total_parameters": 2750017056,
     "total_size": 5500034112
   },
   "weight_map": {
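
A quick consistency check on the added metadata field (standalone; both numbers are copied from the index file above): at 2 bytes per bfloat16 parameter, total_parameters accounts exactly for total_size.

# Verify that the new "total_parameters" entry matches the shard size on disk.
total_parameters = 2_750_017_056
total_size = 5_500_034_112
bytes_per_param = 2  # bfloat16

assert total_parameters * bytes_per_param == total_size
print(f"{total_parameters / 1e9:.2f}B parameters, {total_size / 2**30:.2f} GiB on disk")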
modeling_fast_slm.py
CHANGED
@@ -46,6 +46,12 @@ from transformers.modeling_outputs import (
     SequenceClassifierOutputWithPast,
 )
 from transformers.modeling_utils import PreTrainedModel
+
+try:
+    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+except ImportError:
+    pass
+
 from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13
 from transformers.utils import (
     add_start_docstrings,
@@ -280,7 +286,6 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
     def __init__(self, config, batch_size, dtype=torch.float16, device=None, layer_type=None):
         self.dtype = dtype
         # self.layers_block_type = config.layers_block_type
-        self.has_previous_state = False
         intermediate_size = config.mamba_expand * config.hidden_size
         ssm_state_size = config.mamba_d_state
         conv_kernel_size = config.mamba_d_conv
@@ -804,6 +809,75 @@ class FastSLMFlashAttention2(FastSLMAttention):
         )


+
+class FastSLMSDPAAttention(nn.Module):
+
+    def __init__(self, config, layer_idx: int, reuse_kv=False):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+        self.scaling = self.head_dim**-0.5
+        self.attention_dropout = config.attention_dropout
+        self.is_causal = True
+
+        self.q_proj = nn.Linear(
+            config.hidden_size, config.num_attention_heads * self.head_dim, bias=False
+        )
+        self.k_proj = nn.Linear(
+            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False
+        )
+        self.v_proj = nn.Linear(
+            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False
+        )
+        self.o_proj = nn.Linear(
+            config.num_attention_heads * self.head_dim, config.hidden_size, bias=False
+        )
+
+        self.sliding_window = self.config.sliding_window if self.layer_idx not in self.config.global_attn_idx else None
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        # position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor],
+        past_key_value: Optional[Cache] = None,
+        **kwargs,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        # cos, sin = position_embeddings
+        # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)  # , cache_kwargs)
+
+        attention_interface = ALL_ATTENTION_FUNCTIONS['flash_attention_2']
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            sliding_window=self.sliding_window,  # diff with Llama
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, attn_weights, past_key_value, (key_states, value_states)
+
+
 class FastSLMFused_MHA(FastSLMAttention):
     """
     FastSLM flash attention module. This module inherits from `FastSLMAttention` as the weights of the module stays
@@ -938,9 +1012,6 @@ class FastSLMFused_MHA(FastSLMAttention):
         v_dim = query_states.shape[-2] * value_states.shape[-1]
         attn_output = attn_output.reshape(bsz, q_len, v_dim).contiguous()

-        if past_key_value is not None:
-            past_key_value.has_previous_state = True
-
         attn_output = self.o_proj(attn_output)

         if not output_attentions:
@@ -952,6 +1023,7 @@ class FastSLMFused_MHA(FastSLMAttention):
 JAMBA_ATTENTION_CLASSES = {
     "flash_attention_2": FastSLMFlashAttention2,
     "fused_mha": FastSLMFused_MHA,
+    "sdpa": FastSLMSDPAAttention,
 }

 class FastSLMMLP(nn.Module):
@@ -1633,6 +1705,8 @@ class FastSLMModel(FastSLMPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

+        self.has_previous_state = False
+

     def get_input_embeddings(self):
         return self.embed_tokens
@@ -1684,21 +1758,13 @@ class FastSLMModel(FastSLMPreTrainedModel):
         )
         use_cache = False

-        past_key_values_length = 0
-        if use_cache:
-            if past_key_values is not None:
-                past_key_values_length = past_key_values.get_usable_length(seq_length, 0)
-            else:
-                use_cache = False
-
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
-            position_ids = torch.arange(
-                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device
             )
             position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
         else:
-            if self.config.num_memory_tokens > 0 and past_key_values is not None and not past_key_values.has_previous_state:
+            if self.config.num_memory_tokens > 0 and past_key_values is not None and not self.has_previous_state:
                 position_ids = position_ids.view(-1, seq_length + self.config.num_memory_tokens).long()
             else:
                 position_ids = position_ids.view(-1, seq_length).long()
@@ -1708,7 +1774,7 @@ class FastSLMModel(FastSLMPreTrainedModel):

         ori_b, ori_n = inputs_embeds.shape[0], inputs_embeds.shape[1]

-        if self.config.num_memory_tokens > 0 and (past_key_values is None or not past_key_values.has_previous_state):
+        if self.config.num_memory_tokens > 0 and (past_key_values is None or not self.has_previous_state):
             mem = repeat(self.memory_tokens, 'n d -> b n d', b = inputs_embeds.shape[0]) # prepend the memory to every segment of m by repeating the memory tokens
             inputs_embeds, mem_packed_shape = pack((mem, inputs_embeds), 'b * d')

@@ -1718,6 +1784,7 @@ class FastSLMModel(FastSLMPreTrainedModel):
         if attention_mask is not None and attention_mask.shape[1] < inputs_embeds.shape[1]:
             assert attention_mask.shape[1] + self.config.num_memory_tokens == inputs_embeds.shape[1]
             attention_mask = torch.cat([torch.ones(inputs_embeds.shape[0], self.config.num_memory_tokens, device=attention_mask.device), attention_mask], dim=1)
+

         if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
             is_padding_right = attention_mask[:, -1].sum().item() != batch_size
@@ -1784,21 +1851,12 @@ class FastSLMModel(FastSLMPreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)

-        if self.config.num_memory_tokens > 0 and (past_key_values is None or not past_key_values.has_previous_state):
+        if self.config.num_memory_tokens > 0 and (past_key_values is None or not self.has_previous_state):
             mem, hidden_states = unpack(hidden_states, mem_packed_shape, 'b * d')
             hidden_states = hidden_states[:, :ori_n, :]

-        if past_key_values and not past_key_values.has_previous_state:
-
-            if past_key_values.get_seq_length(layer_idx_) > 0:
-                past_key_values.has_previous_state = True
-                break
-
-        if mamba_inference_params is not None and mamba_inference_params.seqlen_offset > 0:
-            past_key_values.has_previous_state = True
-
-        if fla_past_key_values is not None and len(fla_past_key_values.states) > 0:
-            past_key_values.has_previous_state = True
+        if past_key_values is not None and not self.has_previous_state:
+            self.has_previous_state = True

         next_cache = None
         if use_cache:
@@ -2011,7 +2069,7 @@ class FastSLMForCausalLM(FastSLMPreTrainedModel):
         static_logits = torch.zeros((batch_size, self.config.vocab_size), device=device)

         # Set up for graph capture
-
+        self.model.has_previous_state = True
         if mamba_inference_params is not None:
             mamba_inference_params.seqlen_offset = 1

@@ -2055,7 +2113,7 @@ class FastSLMForCausalLM(FastSLMPreTrainedModel):
         if hasattr(module, 'reset_kv_cache'):
             module.reset_kv_cache()

-
+        self.model.has_previous_state = False

         # Return generation state
         generation_state = {
@@ -2134,7 +2192,7 @@ class FastSLMForCausalLM(FastSLMPreTrainedModel):
         if hasattr(module, 'reset_kv_cache'):
             module.reset_kv_cache()

-
+        self.model.has_previous_state = False

         # Prefill phase - process input sequence
         position_ids = torch.arange(
@@ -2215,6 +2273,48 @@ class FastSLMForCausalLM(FastSLMPreTrainedModel):
         return generated_ids


+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        output_router_logits=False,
+        **kwargs,
+    ):
+        if self.config.num_memory_tokens > 0:
+            attention_mask = torch.cat([torch.ones(input_ids.shape[0], self.config.num_memory_tokens, device=attention_mask.device), attention_mask], dim=1)
+
+        past_key_values = None  # Disable cache for now
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            position_ids = position_ids[:, -input_ids.shape[1]:]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None:
+            if input_ids.shape[1] == 0:
+                model_inputs = {"inputs_embeds": inputs_embeds}
+            else:
+                inputs_embeds_new = self.model.embed_tokens(input_ids)
+                model_inputs = {"inputs_embeds": torch.cat([inputs_embeds, inputs_embeds_new], dim=1)}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
+
+
 def sample_token(logits, temperature=1.0, top_k=0, top_p=0.9):
     """
     Sample a token from logits with temperature, top-k, and top-p filtering.
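
One detail from the new prepare_inputs_for_generation worth spelling out: position ids are rebuilt from the attention mask, so left-padded batches get correct positions while padded slots are clamped to 1. A toy, self-contained reproduction of that computation (not model code, just the arithmetic):

import torch

# Two sequences of length 5; the first is left-padded by two tokens.
attention_mask = torch.tensor([
    [0, 0, 1, 1, 1],
    [1, 1, 1, 1, 1],
])

# Same trick as in prepare_inputs_for_generation above.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])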