arch-btw committed on
Commit
d844dcb
·
verified ·
1 Parent(s): 862b31e

Upload 8 files

Browse files
config.json ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "KimiLinearForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_kimi.KimiLinearConfig",
7
+ "AutoModelForCausalLM": "modeling_kimi.KimiLinearForCausalLM",
8
+ "AutoTokenizer": [
9
+ "tokenization_kimi.TikTokenTokenizer",
10
+ null
11
+ ]
12
+ },
13
+ "bos_token_id": 163584,
14
+ "dtype": "float32",
15
+ "eos_token_id": 163585,
16
+ "first_k_dense_replace": 0,
17
+ "head_dim": 64,
18
+ "hidden_act": "silu",
19
+ "hidden_size": 256,
20
+ "initializer_range": 0.02,
21
+ "intermediate_size": 512,
22
+ "kv_lora_rank": 32,
23
+ "linear_attn_config": {
24
+ "full_attn_layers": [
25
+ 2,
26
+ 4
27
+ ],
28
+ "head_dim": 64,
29
+ "kda_layers": [
30
+ 1,
31
+ 3
32
+ ],
33
+ "num_heads": 4,
34
+ "short_conv_kernel_size": 4
35
+ },
36
+ "mla_use_nope": true,
37
+ "model_type": "kimi_linear",
38
+ "moe_intermediate_size": 128,
39
+ "moe_layer_freq": 1,
40
+ "moe_renormalize": true,
41
+ "moe_router_activation_func": "sigmoid",
42
+ "num_attention_heads": 4,
43
+ "num_expert_group": 1,
44
+ "num_experts": 4,
45
+ "num_experts_per_token": 2,
46
+ "num_hidden_layers": 4,
47
+ "num_key_value_heads": 4,
48
+ "num_nextn_predict_layers": 0,
49
+ "num_shared_experts": 1,
50
+ "pad_token_id": 163839,
51
+ "q_lora_rank": null,
52
+ "qk_nope_head_dim": 32,
53
+ "qk_rope_head_dim": 32,
54
+ "rms_norm_eps": 1e-06,
55
+ "rope_scaling": null,
56
+ "rope_theta": 10000.0,
57
+ "routed_scaling_factor": 1.0,
58
+ "tie_word_embeddings": false,
59
+ "topk_group": 1,
60
+ "transformers_version": "4.57.3",
61
+ "use_cache": true,
62
+ "use_grouped_topk": true,
63
+ "v_head_dim": 64,
64
+ "vocab_size": 200000
65
+ }
configuration_kimi.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ from typing import Optional
3
+
4
+ from transformers.configuration_utils import PretrainedConfig
5
+
6
+
7
class KimiLinearConfig(PretrainedConfig):
    """Configuration for Kimi-Linear models.

    Describes a hybrid decoder stack that mixes Kimi Delta Attention (KDA,
    linear attention) layers with full-attention layers (MLA-style when the
    ``*_lora_rank`` / ``*_head_dim`` fields are set), plus an optional
    DeepSeek-style MoE feed-forward configuration.
    """

    model_type = "kimi_linear"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        model_type="kimi_linear",
        vocab_size=163840,
        hidden_size=4096,
        head_dim=None,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        rope_theta=10000.0,
        rope_scaling=None,
        tie_word_embeddings=False,
        moe_intermediate_size: Optional[int] = None,
        moe_renormalize: bool = True,
        moe_router_activation_func: str = "sigmoid",
        num_experts: Optional[int] = None,
        num_experts_per_token: Optional[int] = None,
        num_shared_experts: int = 0,
        routed_scaling_factor: float = 1.0,
        first_k_dense_replace: int = 0,
        moe_layer_freq: int = 1,
        use_grouped_topk: bool = True,
        num_expert_group: int = 1,
        topk_group: int = 1,
        q_lora_rank: Optional[int] = None,
        kv_lora_rank: Optional[int] = None,
        qk_nope_head_dim: Optional[int] = None,
        qk_rope_head_dim: Optional[int] = None,
        v_head_dim: Optional[int] = None,
        mla_use_nope: Optional[bool] = False,
        num_nextn_predict_layers: int = 0,
        linear_attn_config: Optional[dict] = None,
        **kwargs,
    ):
        self.model_type = model_type
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        # Default head_dim to the conventional hidden_size / num_heads split.
        self.head_dim = (
            head_dim if head_dim is not None else hidden_size // num_attention_heads
        )
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility: plain MHA unless a KV head count is given
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling

        # MLA (multi-latent attention) dims; all None => non-MLA attention.
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.mla_use_nope = mla_use_nope
        # moe config (None num_experts => dense FFN, see `is_moe`)
        self.num_experts = num_experts
        self.num_experts_per_token = num_experts_per_token
        self.moe_renormalize = moe_renormalize
        self.num_shared_experts = num_shared_experts
        self.routed_scaling_factor = routed_scaling_factor
        self.moe_router_activation_func = moe_router_activation_func
        assert self.moe_router_activation_func in ("softmax", "sigmoid")
        self.moe_intermediate_size = moe_intermediate_size
        self.first_k_dense_replace = first_k_dense_replace
        self.moe_layer_freq = moe_layer_freq
        self.use_grouped_topk = use_grouped_topk
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.num_nextn_predict_layers = num_nextn_predict_layers

        if linear_attn_config is not None:
            assert linear_attn_config["kda_layers"] is not None
            assert linear_attn_config["full_attn_layers"] is not None
        # BUG FIX: always assign the attribute. It was previously set only
        # inside the `is not None` branch, so a config constructed without
        # `linear_attn_config` raised AttributeError in `is_linear_attn`
        # and `is_kda_layer`.
        self.linear_attn_config = linear_attn_config

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    @property
    def is_mla(self):
        """True when any MLA-specific field is configured."""
        return (
            self.q_lora_rank is not None
            or self.kv_lora_rank is not None
            or self.qk_nope_head_dim is not None
            or self.qk_rope_head_dim is not None
            or self.v_head_dim is not None
            or self.mla_use_nope is True
        )

    @property
    def is_moe(self):
        """True when a routed-expert count is configured."""
        return self.num_experts is not None

    @property
    def is_linear_attn(self) -> bool:
        """True when at least one KDA (linear-attention) layer is configured."""
        return not (
            self.linear_attn_config is None
            or (
                isinstance(self.linear_attn_config, dict)
                and self.linear_attn_config["kda_layers"] is not None
                and len(self.linear_attn_config["kda_layers"]) == 0
            )
        )

    def is_kda_layer(self, layer_idx: int):
        """Whether 0-based `layer_idx` is a KDA layer (`kda_layers` entries are 1-based)."""
        return (
            self.linear_attn_config is not None
            and (layer_idx + 1) in self.linear_attn_config["kda_layers"]
        )
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 163584,
4
+ "eos_token_id": 163585,
5
+ "pad_token_id": 163839,
6
+ "transformers_version": "4.57.3"
7
+ }
modeling_kimi.py ADDED
@@ -0,0 +1,1102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from collections.abc import Callable
3
+ from typing import Any
4
+ from typing import Any, Optional, Tuple, Dict, List, Union
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import transformers
9
+ from einops import rearrange, repeat
10
+ from packaging import version
11
+ from torch import nn
12
+ from transformers.activations import ACT2FN
13
+ from transformers.cache_utils import Cache
14
+ from transformers.generation import GenerationMixin
15
+ from transformers.masking_utils import create_causal_mask
16
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
17
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
18
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
19
+ from transformers.processing_utils import Unpack
20
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
21
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
22
+ from transformers.utils.generic import OutputRecorder, check_model_inputs
23
+
24
# The KDA kernels and fused modules come from the `fla` (flash-linear-attention)
# package; fail fast with an actionable message when it is missing.
try:
    from fla.modules import FusedRMSNormGated, ShortConvolution
    from fla.ops.kda import chunk_kda, fused_recurrent_kda
    from fla.ops.kda.gate import fused_kda_gate
    from fla.ops.utils.index import prepare_cu_seqlens_from_mask, prepare_lens_from_mask
    from fla.utils import tensor_cache
except ImportError as e:
    # Fixed typo ("Plese" -> "Please") and chained the original error for context.
    raise ImportError("Please run `pip install -U fla-core`") from e
32
+
33
+ from .configuration_kimi import KimiLinearConfig
34
+
35
+ assert version.parse(transformers.__version__) >= version.parse("4.56.0"), \
36
+ "Please upgrade transformers to >= 4.56.0"
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+
41
def index_first_axis(x, indices):
    """Gather rows of ``x`` along dim 0 at ``indices``, keeping trailing dims.

    Equivalent to ``x[indices]`` but expressed via ``torch.gather`` on a
    flattened view (gather-friendly for downstream tooling).
    """
    trailing_shape = x.shape[1:]
    flat = x.reshape(x.shape[0], -1)
    # Broadcast each index across the flattened trailing dimension.
    idx = indices.unsqueeze(-1).expand(-1, flat.shape[1])
    return torch.gather(flat, 0, idx).reshape(-1, *trailing_shape)
47
+
48
+
49
def index_put_first_axis(x, indices, first_axis_dim):
    """Scatter rows of ``x`` into a fresh zero tensor along the first axis.

    Rows not listed in ``indices`` stay zero; this is the inverse of
    ``index_first_axis`` for the padded positions.
    """
    out_shape = (first_axis_dim,) + tuple(x.shape[1:])
    # new_zeros inherits device and dtype from x.
    out = x.new_zeros(out_shape)
    out[indices] = x
    return out
55
+
56
+
57
@tensor_cache
def get_unpad_data(
    attention_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Summarise a 0/1 padding mask for unpadded (varlen) attention.

    Returns (flat indices of non-padding tokens, cumulative sequence
    lengths, longest sequence length in the batch).
    """
    flat_indices = attention_mask.flatten().nonzero(as_tuple=False).flatten()
    cu_seqlens = prepare_cu_seqlens_from_mask(attention_mask)
    seqlens = prepare_lens_from_mask(attention_mask)
    longest = seqlens.max().item()
    return flat_indices, cu_seqlens, longest
66
+
67
+
68
def unpad_input(
    q: torch.Tensor,
    states: tuple[torch.Tensor],
    attention_mask: torch.Tensor,
    q_len: int,
    keepdim: bool = False,
):
    """Strip padding tokens from queries and key/value ``states``.

    Supports two layouts: prefill (``q_len == seq_len``) where queries share
    the key unpadding, and single-token decode (``q_len == 1``) where each
    sequence contributes exactly one query. Returns the unpadded tensors plus
    the index/cu_seqlen/max_seqlen bookkeeping needed to re-pad later.
    """
    indices_k, cu_seqlens_k, max_seqlen_k = get_unpad_data(attention_mask)
    batch_size, seq_len = states[0].shape[0], states[0].shape[1]

    unpadded_states = tuple(
        index_first_axis(rearrange(s, "b s ... -> (b s) ..."), indices_k)
        for s in states
    )

    if q_len == seq_len:
        # Prefill: queries follow exactly the key layout.
        q = index_first_axis(rearrange(q, "b s ... -> (b s) ..."), indices_k)
        indices_q = indices_k
        cu_seqlens_q = cu_seqlens_k
        max_seqlen_q = max_seqlen_k
    elif q_len == 1:
        # Decode: one query per sequence, so cu_seqlens is just 0..batch_size.
        max_seqlen_q = 1
        cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
        indices_q = cu_seqlens_q[:-1]
        q = q.squeeze(1)
    else:
        raise NotImplementedError("We only support either q_len == k_len (prefilling) or q_len == 1 (decoding)")

    if keepdim:
        # Re-introduce a dummy batch dim of 1 (varlen kernels expect it).
        q = q.unsqueeze(0)
        unpadded_states = tuple(s.unsqueeze(0) for s in unpadded_states)

    return (
        q,
        unpadded_states,
        indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_q, max_seqlen_k),
    )
107
+
108
+
109
def pad_input(
    hidden_states: torch.Tensor,
    indices: torch.LongTensor,
    batch_size: int,
    seq_len: int,
) -> torch.Tensor:
    """Inverse of unpadding: scatter flat token rows back to (batch, seq_len, ...).

    Positions not present in ``indices`` (the padding slots) are zero-filled.
    """
    flat = index_put_first_axis(hidden_states, indices, batch_size * seq_len)
    return flat.reshape(batch_size, seq_len, *flat.shape[1:])
117
+
118
+
119
class KimiDynamicCache:
    """
    Dynamic cache for Kimi model.
    Inspired by Qwen3-Next

    Per decoder layer this holds either:
      * a growing key/value cache (``key_cache`` / ``value_cache``) for
        full-attention layers, or
      * fixed-size linear-attention state for KDA layers: ``conv_states``
        stores the (q, k, v) short-convolution caches and
        ``recurrent_states`` stores the recurrent state.
    Unused slots stay ``None`` and are populated lazily on first update.
    """
    # Cache entries are lazily-allocated Python lists, not static tensors.
    is_compileable = False

    def __init__(self, config: KimiLinearConfig):
        super().__init__()
        self.config = config

        # Classify each layer as linear (KDA) vs full attention.
        if config.linear_attn_config is not None:
            self.layer_types = []
            for i in range(config.num_hidden_layers):
                if config.is_kda_layer(i):
                    self.layer_types.append("linear_attention")
                else:
                    self.layer_types.append("full_attention")
        else:
            self.layer_types = ["full_attention"] * config.num_hidden_layers

        # Indices of full-attention layers (used by get_seq_length).
        self.transformer_layers = [
            i for i in range(config.num_hidden_layers) if self.layer_types[i] == "full_attention"
        ]

        linear_layers = [
            i for i in range(config.num_hidden_layers) if self.layer_types[i] == "linear_attention"
        ]
        # -1 is the sentinel for "no linear layers" (see has_previous_state).
        self.last_linear_layer = linear_layers[-1] if linear_layers else -1

        # One slot per layer; a slot stays None until its layer first writes.
        self.conv_states = [None for _ in range(config.num_hidden_layers)]
        self.recurrent_states = [None for _ in range(config.num_hidden_layers)]
        self.key_cache = [None for _ in range(config.num_hidden_layers)]
        self.value_cache = [None for _ in range(config.num_hidden_layers)]

    def __len__(self):
        # Number of layers this cache tracks.
        return len(self.layer_types)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Append new KV states for a full-attention layer and return the
        concatenated keys/values. Concatenation is along dim 2 (the sequence
        dimension). ``cache_kwargs`` is accepted for interface compatibility
        but unused here.
        """
        if self.key_cache[layer_idx] is None:
            # First write for this layer: adopt the tensors directly.
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        else:
            self.key_cache[layer_idx] = torch.cat(
                [self.key_cache[layer_idx], key_states], dim=2)
            self.value_cache[layer_idx] = torch.cat(
                [self.value_cache[layer_idx], value_states], dim=2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            if self.key_cache[layer_idx] is not None:
                device = self.key_cache[layer_idx].device
                beam_idx = beam_idx.to(device)
                # Batch dimension is dim 0 for KV caches.
                self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
                    0, beam_idx)
                self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
                    0, beam_idx)

            if self.conv_states[layer_idx] is not None:
                device = self.conv_states[layer_idx][0].device
                beam_idx = beam_idx.to(device)
                # Linear-attention layers carry a (q, k, v) conv-cache triple
                # plus a recurrent state; reorder all of them together.
                q_conv, k_conv, v_conv = self.conv_states[layer_idx]
                self.conv_states[layer_idx] = (
                    q_conv.index_select(0, beam_idx),
                    k_conv.index_select(0, beam_idx),
                    v_conv.index_select(0, beam_idx),
                )
                self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(
                    0, beam_idx)

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # take any layer that contains cache and not empty tensor
        # NOTE(review): assumes at least one full-attention layer exists;
        # transformer_layers[0] would raise IndexError otherwise — confirm
        # configs always include full_attn_layers.
        layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
        if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None:
            return 0
        # KV layout caches sequence length at dim -2.
        return self.key_cache[layer_idx].shape[-2]

    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
        """
        Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
        the given layer at `layer_idx`.
        The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer.
        """
        kv_offset = 0
        query_length = cache_position.shape[0]
        past_seen_tokens = self.get_seq_length(layer_idx)
        kv_length = query_length + past_seen_tokens
        return kv_length, kv_offset

    @property
    def has_previous_state(self):
        """We have a previous state if the last linear (conv) layer was already updated."""
        if self.last_linear_layer == -1:
            return False
        return self.conv_states[self.last_linear_layer] is not None
223
+
224
+
225
class KimiRMSNorm(nn.Module):
    """Root-mean-square layer norm (equivalent to T5LayerNorm).

    Scale-only normalisation: divides by the RMS of the last dimension and
    applies a learned per-channel gain; no mean subtraction, no bias.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Learnable gain, initialised to the identity.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        # Statistics are computed in float32 for numerical stability.
        x = hidden_states.to(torch.float32)
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        normed = x * torch.rsqrt(mean_square + self.variance_epsilon)
        # Cast back before applying the gain so the output keeps the input dtype.
        return self.weight * normed.to(orig_dtype)
241
+
242
+
243
+ ALL_LAYERNORM_LAYERS.append(KimiRMSNorm)
244
+
245
+
246
class KimiBlockSparseMLP(nn.Module):
    """SwiGLU expert MLP: computes ``w2(act(w1(x)) * w3(x))``.

    Used as the per-expert feed-forward; w1 = gate, w3 = up, w2 = down.
    """

    def __init__(self, config: KimiLinearConfig, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.config = config
        # Fall back to the model-wide sizes unless explicit ones are given.
        self.ffn_dim = intermediate_size if intermediate_size is not None else config.intermediate_size
        self.hidden_dim = hidden_size if hidden_size is not None else config.hidden_size

        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)  # down projection
        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)  # up projection

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        gated = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        return self.w2(gated)
264
+
265
+
266
class KimiMLP(nn.Module):
    """Dense SwiGLU feed-forward block: ``down(act(gate(x)) * up(x))``."""

    def __init__(self, config: KimiLinearConfig, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.config = config
        # Explicit sizes override the model-wide defaults from the config.
        self.hidden_size = hidden_size if hidden_size is not None else config.hidden_size
        self.intermediate_size = (
            intermediate_size if intermediate_size is not None else config.intermediate_size
        )
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
284
+
285
+
286
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    if n_rep == 1:
        # MHA case: nothing to duplicate.
        return hidden_states
    bsz, n_kv_heads, seqlen, head_dim = hidden_states.shape
    # Insert a repeat axis next to the head axis, broadcast it, then fold it in.
    expanded = hidden_states.unsqueeze(2).expand(bsz, n_kv_heads, n_rep, seqlen, head_dim)
    return expanded.reshape(bsz, n_kv_heads * n_rep, seqlen, head_dim)
297
+
298
+
299
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Reference (non-fused) scaled-dot-product attention with GQA support.

    Returns the attention output in (batch, seq, heads, head_dim) layout
    together with the (post-dropout) attention probabilities.
    """
    # Expand KV heads so every query head has a matching key/value head.
    keys = repeat_kv(key, module.num_key_value_groups)
    values = repeat_kv(value, module.num_key_value_groups)

    scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Slice the (pre-broadcast, additive) mask to the actual key length.
        scores = scores + attention_mask[:, :, :, : keys.shape[-2]]

    # Softmax in float32 for stability, then cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    out = torch.matmul(probs, values)
    out = out.transpose(1, 2).contiguous()
    return out, probs
325
+
326
+
327
class KimiMLAAttention(nn.Module):
    """
    Multi-Latent Attention adapted from deepseek-v3

    Queries are projected at full rank (``q_lora_rank`` must be None); keys
    and values are produced from a low-rank compressed latent
    (``kv_a_proj_with_mqa`` -> RMSNorm -> ``kv_b_proj``). Head dims are split
    into a "nope" part and a "rope"-named part, but NOTE(review): no rotary
    embedding is applied anywhere in this block — consistent with the
    ``use_nope`` assertion; confirm against the released modeling code.
    """

    def __init__(self, config: KimiLinearConfig, layer_idx: int):
        nn.Module.__init__(self)
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Query-heads per KV-head (used by eager_attention_forward's repeat_kv).
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.rope_theta = config.rope_theta
        # attention_dropout is optional on the config; default to 0.
        self.attention_dropout = getattr(config, "attention_dropout", 0.0)

        try:
            self.q_lora_rank = config.q_lora_rank
            self.qk_rope_head_dim = config.qk_rope_head_dim
            self.kv_lora_rank = config.kv_lora_rank
            self.v_head_dim = config.v_head_dim
            self.qk_nope_head_dim = config.qk_nope_head_dim
            # Per-head query/key dim = nope part + rope-named part.
            self.q_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
            self.use_nope = config.mla_use_nope
            self.scaling = self.q_head_dim ** (-0.5)
        except Exception as e:
            # Any None field above makes the arithmetic fail -> surface as config error.
            raise ValueError(
                f"Kimi MLA config is not found or not properly formatted: {e}")

        # Only the non-low-rank query path is implemented here.
        assert self.q_lora_rank is None
        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.q_head_dim, bias=False,
        )
        # Joint projection: kv latent (kv_lora_rank) + shared single-head k_rot part.
        self.kv_a_proj_with_mqa = nn.Linear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
        )
        self.kv_a_layernorm = KimiRMSNorm(self.kv_lora_rank)
        # Expands the normalised latent into per-head (k_nope, value) pairs.
        self.kv_b_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads
            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
        )
        self.is_causal = True
        # This implementation only supports the NoPE variant.
        assert self.use_nope

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, ...]]]:
        batch_size, seq_length = hidden_states.shape[:-1]
        query_shape = (batch_size, seq_length, -1, self.q_head_dim)
        key_shape = (batch_size, seq_length, -1,
                     self.qk_nope_head_dim + self.v_head_dim)

        # Queries: full-rank projection, split into nope/rot parts per head.
        q_states = self.q_proj(hidden_states)
        q_states = q_states.view(query_shape).transpose(1, 2)
        q_pass, q_rot = torch.split(
            q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Keys/values: compress to a latent, plus a single shared k_rot slice.
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        k_pass, k_rot = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        # Expand the normalised latent into per-head (k_nope, value).
        k_pass = self.kv_b_proj(self.kv_a_layernorm(
            k_pass)).view(key_shape).transpose(1, 2)
        k_pass, value_states = torch.split(
            k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # The single k_rot slice is shared across heads (MQA-style broadcast).
        k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

        query_states = torch.cat((q_pass, q_rot), dim=-1)
        key_states = torch.cat((k_pass, k_rot), dim=-1)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(
                key_states, value_states, self.layer_idx)

        # flash_attention_2 requires equal q/k/v head dims: pad values up, crop later.
        if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
            value_states = F.pad(
                value_states, [0, self.q_head_dim - self.v_head_dim])

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, _ = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Undo the value padding applied for flash_attention_2 above.
        if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, : self.v_head_dim]

        attn_output = attn_output.reshape(
            batch_size, seq_length, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output
443
+
444
+
445
class KimiDeltaAttention(nn.Module):
    """Kimi Delta Attention (KDA): gated linear attention with short convolutions.

    q/k/v are projected, passed through depthwise short convolutions, then fed
    to the fused KDA kernels from `fla` (chunked for long inputs / training,
    recurrent for short decode steps). Output is gated through a fused RMSNorm
    before the final projection.

    The only code change vs. the original is the fixed typo in the mode
    assertion message ("suppoerted" -> "supported").
    """

    def __init__(self, config: KimiLinearConfig, layer_idx: int):
        super().__init__()
        self.config = config
        # Default execution mode; forward may switch to fused_recurrent for short inputs.
        self.mode = "chunk"

        self.hidden_size = config.hidden_size
        self.conv_size = config.linear_attn_config["short_conv_kernel_size"]
        self.head_dim = config.linear_attn_config["head_dim"]
        self.num_heads = config.linear_attn_config["num_heads"]
        # Keys use the same head geometry as values in this configuration.
        self.head_k_dim = self.head_dim
        self.num_k_heads = self.num_heads

        self.layer_idx = layer_idx

        assert self.mode in [
            'chunk', 'fused_recurrent'], f"Not supported mode `{self.mode}`."

        projection_k_size = self.head_k_dim * self.num_k_heads
        projection_size = self.head_dim * self.num_heads

        self.q_proj = nn.Linear(
            self.hidden_size, projection_k_size, bias=False)
        self.k_proj = nn.Linear(
            self.hidden_size, projection_k_size, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, projection_size, bias=False)

        # Depthwise short (causal) convolutions with silu activation, one per stream.
        self.q_conv1d = ShortConvolution(
            hidden_size=projection_k_size,
            kernel_size=self.conv_size,
            activation='silu',
        )
        self.k_conv1d = ShortConvolution(
            hidden_size=projection_k_size,
            kernel_size=self.conv_size,
            activation='silu',
        )
        self.v_conv1d = ShortConvolution(
            hidden_size=projection_size,
            kernel_size=self.conv_size,
            activation='silu',
        )

        # Per-head decay parameter, stored as log of a Uniform(1, 16) draw.
        self.A_log = torch.nn.Parameter(torch.log(torch.empty(
            self.num_heads, dtype=torch.float32).uniform_(1, 16)).view(1, 1, -1, 1))

        # Low-rank bottleneck producing the forget-gate input.
        self.f_a_proj = nn.Linear(self.hidden_size, self.head_dim, bias=False)
        self.f_b_proj = nn.Linear(self.head_dim, projection_size, bias=False)

        # NOTE(review): dt_bias is created uninitialised (torch.empty); presumably
        # overwritten by checkpoint loading — confirm before training from scratch.
        self.dt_bias = nn.Parameter(
            torch.empty(projection_size, dtype=torch.float32))

        # Per-head beta (write-strength) gate.
        self.b_proj = nn.Linear(self.hidden_size, self.num_heads, bias=False)

        # Low-rank bottleneck producing the output gate.
        self.g_a_proj = nn.Linear(self.hidden_size, self.head_dim, bias=False)
        self.g_b_proj = nn.Linear(self.head_dim, projection_size, bias=False)

        self.o_norm = FusedRMSNormGated(
            self.head_dim, eps=config.rms_norm_eps, activation='sigmoid')
        self.o_proj = nn.Linear(projection_size, self.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        cache_params: Optional[KimiDynamicCache] = None,
        **kwargs: Unpack[dict],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        # Only 2-D 0/1 padding masks are meaningful here; fall back to the
        # explicit padding_mask kwarg if a broadcast attention mask was passed.
        if attention_mask is not None:
            if attention_mask.dim() != 2:
                attention_mask = kwargs.get("padding_mask")

        if attention_mask is not None and attention_mask.dim() != 2:
            raise ValueError(
                "attention_mask must be a 0-1 matrix of shape [batch_size, seq_len] "
                "(0 = padding). 3D masks are not supported here.",
            )
        use_cache = cache_params is not None
        batch_size, q_len, _ = hidden_states.shape
        # Short inputs (e.g. decode steps) use the recurrent kernel.
        mode = 'fused_recurrent' if q_len <= 64 else self.mode
        if self.training:
            assert mode == 'chunk', "Only chunk mode is supported in training."

        cu_seqlens = kwargs.get('cu_seqlens')
        indices = None
        if attention_mask is not None:
            # Unpad: collapse to a single varlen "batch" of non-padding tokens.
            indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
            hidden_states = index_first_axis(
                rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0)

        conv_state_q, conv_state_k, conv_state_v = None, None, None
        recurrent_state = None
        if cache_params is not None:
            if cache_params.conv_states[self.layer_idx] is not None:
                conv_state_q, conv_state_k, conv_state_v = cache_params.conv_states[
                    self.layer_idx]
            recurrent_state = cache_params.recurrent_states[self.layer_idx]
        # Project then run the short convolutions, threading cached conv state.
        q, conv_state_q = self.q_conv1d(
            x=self.q_proj(hidden_states),
            cache=conv_state_q,
            output_final_state=use_cache,
            cu_seqlens=cu_seqlens,
        )
        k, conv_state_k = self.k_conv1d(
            x=self.k_proj(hidden_states),
            cache=conv_state_k,
            output_final_state=use_cache,
            cu_seqlens=cu_seqlens,
        )
        v, conv_state_v = self.v_conv1d(
            x=self.v_proj(hidden_states),
            cache=conv_state_v,
            output_final_state=use_cache,
            cu_seqlens=cu_seqlens,
        )
        # Forget gate: low-rank projection fed through the fused KDA gate.
        g = self.f_b_proj(self.f_a_proj(hidden_states))
        g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias)
        # Write-strength in [0, 1], computed in float32.
        beta = self.b_proj(hidden_states).float().sigmoid()

        # Split flat projections into per-head layout.
        q, k = map(lambda x: rearrange(
            x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k))
        v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)

        if mode == 'chunk':
            o, recurrent_state = chunk_kda(
                q=q,
                k=k,
                v=v,
                g=g,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=True,
                use_qk_l2norm_in_kernel=True,
                cu_seqlens=cu_seqlens,
            )
        else:
            o, recurrent_state = fused_recurrent_kda(
                q=q,
                k=k,
                v=v,
                g=g,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=True,
                use_qk_l2norm_in_kernel=True,
                cu_seqlens=cu_seqlens,
            )
        if cache_params is not None:
            cache_params.recurrent_states[self.layer_idx] = recurrent_state
            cache_params.conv_states[self.layer_idx] = (
                conv_state_q, conv_state_k, conv_state_v)

        # Output gate + gated RMSNorm, then merge heads and project out.
        g = self.g_b_proj(self.g_a_proj(hidden_states))
        g = rearrange(g, '... (h d) -> ... h d', d=self.head_dim)
        o = self.o_norm(o, g)

        o = rearrange(o, 'b t h d -> b t (h d)')
        o = self.o_proj(o)
        if attention_mask is not None:
            # Re-pad back to the (batch, q_len, hidden) layout.
            o = pad_input(o.squeeze(0), indices, batch_size, q_len)

        return o
607
+
608
+
609
class KimiMoEGate(nn.Module):
    """
    MoEGate adapted from Deepseek-V3.
    Parameter correspondences:
        num_experts -> n_routed_experts
        num_experts_per_token -> num_experts_per_tok
        num_expert_group -> n_group
        moe_router_activation_func -> scoring_func
    """

    def __init__(self, config: "KimiLinearConfig"):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_token
        self.num_experts = config.num_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.moe_router_activation_func = config.moe_router_activation_func
        self.num_expert_group = getattr(config, "num_expert_group", 1)
        self.topk_group = getattr(config, "topk_group", 1)

        # topk selection algorithm
        self.moe_renormalize = config.moe_renormalize
        self.gating_dim = config.hidden_size
        self.weight = nn.Parameter(
            torch.empty((self.num_experts, self.gating_dim)),
        )
        # Bias that steers expert *selection* only; final routing weights are
        # gathered from the unbiased scores (Deepseek-V3 "noaux" behavior).
        # NOTE: left as torch.empty — assumed to be overwritten by checkpoint
        # loading, since reset_parameters() does not initialize it.
        self.e_score_correction_bias = nn.Parameter(
            torch.empty(self.num_experts),
        )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Kaiming-initialize the gating weight (bias is checkpoint-loaded)."""
        import torch.nn.init as init

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        """Select the top-k routed experts for every token.

        Args:
            hidden_states: (batch, seq_len, hidden_size) activations.

        Returns:
            topk_idx: (batch*seq_len, top_k) chosen expert indices.
            topk_weight: (batch*seq_len, top_k) routing weights taken from the
                *unbiased* scores, optionally renormalized to sum to 1 and
                multiplied by ``routed_scaling_factor``.
        """
        bsz, seq_len, h = hidden_states.shape
        # compute gating score in float32 for numerical stability
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(
            hidden_states.type(torch.float32), self.weight.type(
                torch.float32), None,
        )
        if self.moe_router_activation_func == "sigmoid":
            scores = logits.sigmoid()
        elif self.moe_router_activation_func == "softmax":
            scores = logits.softmax(dim=1)
        else:
            raise NotImplementedError(
                f"insupportable scoring function for MoE gating: {self.moe_router_activation_func}",
            )

        # select top-k experts (inference-only path)
        assert not self.training
        # FIX: the bias must influence *which* experts are chosen, but the
        # routing weights below are gathered from the unbiased `scores`.
        # The previous in-place `+=` mutated `scores` through the view, so the
        # correction bias leaked into the returned weights; the out-of-place
        # add matches the Deepseek-V3 reference implementation.
        scores_for_choice = scores.view(
            bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
        # Group-limited routing: score each expert group by the sum of its
        # two best (biased) expert scores.
        group_scores = (
            scores_for_choice.view(
                bsz * seq_len, self.num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
        )  # [n, num_expert_group]
        group_idx = torch.topk(
            group_scores, k=self.topk_group, dim=-1, sorted=False,
        )[
            1
        ]  # [n, top_k_group]
        group_mask = torch.zeros_like(group_scores)  # [n, num_expert_group]
        group_mask.scatter_(1, group_idx, 1)  # [n, num_expert_group]
        # Broadcast the group mask back to per-expert granularity.
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(
                bsz * seq_len, self.num_expert_group, self.num_experts // self.num_expert_group,
            )
            .reshape(bsz * seq_len, -1)
        )  # [n, e]
        tmp_scores = scores_for_choice.masked_fill(
            ~score_mask.bool(), 0.0)  # [n, e]
        _, topk_idx = torch.topk(
            tmp_scores, k=self.top_k, dim=-1, sorted=False,
        )
        # Weights come from the unbiased scores at the chosen indices.
        topk_weight = scores.gather(1, topk_idx)

        # norm gate to sum 1
        if self.top_k > 1 and self.moe_renormalize:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator
        # must multiply the scaling factor
        topk_weight = topk_weight * self.routed_scaling_factor

        return topk_idx, topk_weight
700
+
701
+
702
class KimiSparseMoeBlock(nn.Module):
    """
    Adapted from Deepseek-V3's MOE implementation.
    The namings are consistent with Kimi's version.
    """

    def __init__(self, config: KimiLinearConfig):
        super().__init__()
        self.config = config
        self.hidden_dim = config.hidden_size
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_token
        self.moe_renormalize = config.moe_renormalize

        # Expert parallelism is disabled in this implementation: one rank
        # owns every expert.
        self.ep_size = 1
        self.experts_per_rank = config.num_experts
        self.ep_rank = 0
        self.experts = nn.ModuleList(
            [
                KimiBlockSparseMLP(
                    config, intermediate_size=config.moe_intermediate_size,
                )
                for _ in range(config.num_experts)
            ],
        )
        self.gate = KimiMoEGate(config)
        if config.num_shared_experts is not None:
            # Shared experts process every token, in addition to the routed ones.
            intermediate_size = config.moe_intermediate_size * config.num_shared_experts
            self.shared_experts = KimiMLP(
                config=config, intermediate_size=intermediate_size,
            )

    def forward(self, hidden_states):
        """Route each token through its top-k experts (inference only).

        The training-mode dispatch path is intentionally unimplemented.
        """
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight = self.gate(hidden_states)
        # Flatten (batch, seq) into a single token axis for dispatch.
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        if not self.training:
            y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
        else:
            raise NotImplementedError("Training mode is not supported in KimiSparseMoeBlock")
        if self.config.num_shared_experts is not None:
            # Add the always-on shared-expert path on the unflattened input.
            y = y + self.shared_experts(identity)
        return y

    @torch.no_grad()
    def moe_infer(self, x, topk_ids, topk_weight):
        """Dispatch tokens to experts, run each expert on its contiguous batch,
        then undo the permutation and combine the top-k outputs per token.

        Args:
            x: flattened tokens, shape (num_tokens, hidden_dim).
            topk_ids: chosen expert indices, shape (num_tokens, top_k).
            topk_weight: routing weights, shape (num_tokens, top_k).
        """
        # Count how many (token, slot) assignments each expert receives.
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        # Sort flattened assignments by expert id so each expert sees one
        # contiguous slice of `sorted_tokens`.
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]

        tokens_per_expert = tokens_per_expert.cpu().numpy()

        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                # Expert received no tokens; start_idx is already correct.
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            expert_out = expert(tokens_for_this_expert)
            outputs.append(expert_out)
            start_idx = end_idx

        outs = torch.cat(outputs, dim=0) if len(
            outputs) else sorted_tokens.new_empty(0)

        # Invert the sort permutation, then weight and sum the top_k expert
        # outputs for every token (sum over dim=1 collapses the slot axis).
        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (
            new_x.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )
        return final_out
782
+
783
+
784
class KimiDecoderLayer(nn.Module):
    """One transformer block: KDA (linear) or MLA (full) attention followed by
    a dense MLP or sparse-MoE feed-forward, each in a pre-norm residual."""

    def __init__(self, config: KimiLinearConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.config = config

        # The attention flavour of each layer is dictated by the config.
        if config.is_kda_layer(layer_idx):
            self.is_linear_attn = True
            self.self_attn = KimiDeltaAttention(
                config=config, layer_idx=layer_idx)
        elif config.is_mla:
            self.is_linear_attn = False
            self.self_attn = KimiMLAAttention(
                config=config, layer_idx=layer_idx)
        else:
            raise NotImplementedError

        # MoE replaces the dense MLP from `first_k_dense_replace` onward,
        # every `moe_layer_freq`-th layer.
        use_moe = (
            config.num_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % getattr(config, "moe_layer_freq", 1) == 0
        )
        if use_moe:
            self.block_sparse_moe = KimiSparseMoeBlock(config)
        else:
            self.mlp = KimiMLP(config)

        self.input_layernorm = KimiRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = KimiRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[torch.Tensor, ...]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            position_ids (`torch.LongTensor`, *optional*): positions, forwarded to full-attention layers only.
            past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached key/value projection states.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attention tensors of all attention layers.
            use_cache (`bool`, *optional*):
                If set to `True`, key/value states are cached and returned to speed up decoding.
        """
        attn_input = self.input_layernorm(hidden_states)

        # The two attention flavours expose slightly different cache kwargs:
        # KDA takes `cache_params`, MLA takes `position_ids`/`past_key_values`.
        if self.is_linear_attn:
            attn_output = self.self_attn(
                hidden_states=attn_input,
                attention_mask=attention_mask,
                cache_params=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                **kwargs,
            )
        else:
            attn_output = self.self_attn(
                hidden_states=attn_input,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                **kwargs,
            )
        hidden_states = hidden_states + attn_output

        # Feed-forward sub-block with its own pre-norm residual.
        ffn_input = self.post_attention_layernorm(hidden_states)
        if hasattr(self, "block_sparse_moe"):
            ffn_output = self.block_sparse_moe(ffn_input)
        else:
            ffn_output = self.mlp(ffn_input)

        return hidden_states + ffn_output
872
+
873
+
874
class KimiPreTrainedModel(PreTrainedModel):
    """Shared HF plumbing for Kimi models: config class, FA2 support flags,
    no-split module list, and weight initialization."""

    config_class = KimiLinearConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["KimiDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _can_record_outputs = {
        "router_logits": OutputRecorder(KimiBlockSparseMLP, index=1),
        "hidden_states": KimiDecoderLayer,
        "attentions": KimiMLAAttention,
    }
    _is_stateful = True

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights with N(0, initializer_range);
        zero Linear biases and the Embedding padding row."""
        std = self.config.initializer_range
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
898
+
899
+
900
class KimiLinearModel(KimiPreTrainedModel):
    """Backbone transformer: token embedding, a stack of ``KimiDecoderLayer``
    (KDA linear-attention or MLA full-attention per layer), and a final RMSNorm."""

    def __init__(self, config: KimiLinearConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([KimiDecoderLayer(
            config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = KimiRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps)

        # Only warn about non-FA2 implementations; the config value is
        # deliberately left untouched (the overwrites are commented out).
        if getattr(config, "_attn_implementation", None) is not None:
            if config._attn_implementation != "flash_attention_2":
                logger.warning_once(
                    f"Ignoring the provided attention implementation {config._attn_implementation}")
                logger.warning_once("Using flash_attention_2 backend instead.")
                pass # config._attn_implementation = "flash_attention_2"
            else:
                pass # config._attn_implementation = "flash_attention_2"

        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def _update_linear_attn_mask(self, attention_mask, cache_position):
        """
        NOTE: Left-padding is used for linear attention mask.
        No need for zeroing states when
        1. Cached forward
        2. Attending to all inputs
        """
        linear_attn_mask = attention_mask
        # Drop the mask entirely for cached (decode) steps or when no position
        # is padded, so KDA layers can take their unmasked fast path.
        if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
            linear_attn_mask = None
        return linear_attn_mask

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPast:
        """Run the decoder stack and return the final hidden states.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given.
        """

        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) and (inputs_embeds is None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds")

        # Get inputs_embeds
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Lazily create the hybrid cache (conv/recurrent states for KDA layers
        # plus KV cache for full-attention layers).
        if use_cache and past_key_values is None:
            past_key_values = KimiDynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length(
            ) if past_key_values is not None else 0
            cache_position: torch.Tensor = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device,
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )
        linear_attn_mask = self._update_linear_attn_mask(
            attention_mask, cache_position)

        hidden_states = inputs_embeds
        if past_key_values is not None:
            assert isinstance(past_key_values, KimiDynamicCache)

        for decoder_layer in self.layers:
            # KDA layers consume the (possibly dropped) 2D padding mask;
            # full-attention layers consume the 4D causal mask.
            layer_mask = linear_attn_mask if decoder_layer.is_linear_attn else causal_mask

            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=layer_mask,
                past_key_values=past_key_values,
                cache_position=cache_position,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
1008
+
1009
+
1010
class KimiLinearForCausalLM(KimiPreTrainedModel, GenerationMixin):
    """Causal-LM head on top of ``KimiLinearModel``: backbone -> lm_head logits,
    with optional cross-entropy loss when labels are provided."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = KimiLinearModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(
            config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        generation_mode: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | CausalLMOutputWithPast:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, KimiLinearForCausalLM

        >>> model = KimiLinearForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        # NOTE(review): `logits` here is still the backbone's last hidden
        # state; it only becomes vocabulary logits after `lm_head` below.
        logits = outputs[0]
        if generation_mode:
            # During generation only the last position feeds the next token.
            logits = logits[:, -1:]
        logits = self.lm_head(logits)

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits, labels, self.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
special_tokens_map.json ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "[extra_id_0]",
4
+ "[extra_id_1]",
5
+ "[extra_id_2]",
6
+ "[extra_id_3]",
7
+ "[start_header_id]",
8
+ "[end_header_id]",
9
+ "[extra_id_4]",
10
+ "[EOT]",
11
+ "[extra_id_5]",
12
+ "[extra_id_6]",
13
+ "[extra_id_7]",
14
+ "[extra_id_8]",
15
+ "[extra_id_9]",
16
+ "[extra_id_10]",
17
+ "[extra_id_11]",
18
+ "[extra_id_12]",
19
+ "[extra_id_13]",
20
+ "[extra_id_14]",
21
+ "[extra_id_15]",
22
+ "[extra_id_16]",
23
+ "[extra_id_17]",
24
+ "[extra_id_18]",
25
+ "[extra_id_19]",
26
+ "[extra_id_20]",
27
+ "[extra_id_21]",
28
+ "[extra_id_22]",
29
+ "[extra_id_23]",
30
+ "[extra_id_24]",
31
+ "[extra_id_25]",
32
+ "[extra_id_26]",
33
+ "[extra_id_27]",
34
+ "[extra_id_28]",
35
+ "[extra_id_29]",
36
+ "[extra_id_30]",
37
+ "[extra_id_31]",
38
+ "[extra_id_32]",
39
+ "[extra_id_33]",
40
+ "[extra_id_34]",
41
+ "[extra_id_35]",
42
+ "[extra_id_36]",
43
+ "[extra_id_37]",
44
+ "[extra_id_38]",
45
+ "[extra_id_39]",
46
+ "[extra_id_40]",
47
+ "[extra_id_41]",
48
+ "[extra_id_42]",
49
+ "[extra_id_43]",
50
+ "[extra_id_44]",
51
+ "[extra_id_45]",
52
+ "[extra_id_46]",
53
+ "[extra_id_47]",
54
+ "[extra_id_48]",
55
+ "[extra_id_49]",
56
+ "[extra_id_50]",
57
+ "[extra_id_51]",
58
+ "[extra_id_52]",
59
+ "[extra_id_53]",
60
+ "[extra_id_54]",
61
+ "[extra_id_55]",
62
+ "[extra_id_56]",
63
+ "[extra_id_57]",
64
+ "[extra_id_58]",
65
+ "[extra_id_59]",
66
+ "[extra_id_60]",
67
+ "[extra_id_61]",
68
+ "[extra_id_62]",
69
+ "[extra_id_63]",
70
+ "[extra_id_64]",
71
+ "[extra_id_65]",
72
+ "[extra_id_66]",
73
+ "[extra_id_67]",
74
+ "[extra_id_68]",
75
+ "[extra_id_69]",
76
+ "[extra_id_70]",
77
+ "[extra_id_71]",
78
+ "[extra_id_72]",
79
+ "[extra_id_73]",
80
+ "[extra_id_74]",
81
+ "[extra_id_75]",
82
+ "[extra_id_76]",
83
+ "[extra_id_77]",
84
+ "[extra_id_78]",
85
+ "[extra_id_79]",
86
+ "[extra_id_80]",
87
+ "[extra_id_81]",
88
+ "[extra_id_82]",
89
+ "[extra_id_83]",
90
+ "[extra_id_84]",
91
+ "[extra_id_85]",
92
+ "[extra_id_86]",
93
+ "[extra_id_87]",
94
+ "[extra_id_88]",
95
+ "[extra_id_89]",
96
+ "[extra_id_90]",
97
+ "[extra_id_91]",
98
+ "[extra_id_92]",
99
+ "[extra_id_93]",
100
+ "[extra_id_94]",
101
+ "[extra_id_95]",
102
+ "[extra_id_96]",
103
+ "[extra_id_97]",
104
+ "[extra_id_98]",
105
+ "[extra_id_99]",
106
+ "[extra_id_100]",
107
+ "[extra_id_101]",
108
+ "[extra_id_102]",
109
+ "[extra_id_103]",
110
+ "[extra_id_104]",
111
+ "[extra_id_105]",
112
+ "[extra_id_106]",
113
+ "[extra_id_107]",
114
+ "[extra_id_108]",
115
+ "[extra_id_109]",
116
+ "[extra_id_110]",
117
+ "[extra_id_111]",
118
+ "[extra_id_112]",
119
+ "[extra_id_113]",
120
+ "[extra_id_114]",
121
+ "[extra_id_115]",
122
+ "[extra_id_116]",
123
+ "[extra_id_117]",
124
+ "[extra_id_118]",
125
+ "[extra_id_119]",
126
+ "[extra_id_120]",
127
+ "[extra_id_121]",
128
+ "[extra_id_122]",
129
+ "[extra_id_123]",
130
+ "[extra_id_124]",
131
+ "[extra_id_125]",
132
+ "[extra_id_126]",
133
+ "[extra_id_127]",
134
+ "[extra_id_128]",
135
+ "[extra_id_129]",
136
+ "[extra_id_130]",
137
+ "[extra_id_131]",
138
+ "[extra_id_132]",
139
+ "[extra_id_133]",
140
+ "[extra_id_134]",
141
+ "[extra_id_135]",
142
+ "[extra_id_136]",
143
+ "[extra_id_137]",
144
+ "[extra_id_138]",
145
+ "[extra_id_139]",
146
+ "[extra_id_140]",
147
+ "[extra_id_141]",
148
+ "[extra_id_142]",
149
+ "[extra_id_143]",
150
+ "[extra_id_144]",
151
+ "[extra_id_145]",
152
+ "[extra_id_146]",
153
+ "[extra_id_147]",
154
+ "[extra_id_148]",
155
+ "[extra_id_149]",
156
+ "[extra_id_150]",
157
+ "[extra_id_151]",
158
+ "[extra_id_152]",
159
+ "[extra_id_153]",
160
+ "[extra_id_154]",
161
+ "[extra_id_155]",
162
+ "[extra_id_156]",
163
+ "[extra_id_157]",
164
+ "[extra_id_158]",
165
+ "[extra_id_159]",
166
+ "[extra_id_160]",
167
+ "[extra_id_161]",
168
+ "[extra_id_162]",
169
+ "[extra_id_163]",
170
+ "[extra_id_164]",
171
+ "[extra_id_165]",
172
+ "[extra_id_166]",
173
+ "[extra_id_167]",
174
+ "[extra_id_168]",
175
+ "[extra_id_169]",
176
+ "[extra_id_170]",
177
+ "[extra_id_171]",
178
+ "[extra_id_172]",
179
+ "[extra_id_173]",
180
+ "[extra_id_174]",
181
+ "[extra_id_175]",
182
+ "[extra_id_176]",
183
+ "[extra_id_177]",
184
+ "[extra_id_178]",
185
+ "[extra_id_179]",
186
+ "[extra_id_180]",
187
+ "[extra_id_181]",
188
+ "[extra_id_182]",
189
+ "[extra_id_183]",
190
+ "[extra_id_184]",
191
+ "[extra_id_185]",
192
+ "[extra_id_186]",
193
+ "[extra_id_187]",
194
+ "[extra_id_188]",
195
+ "[extra_id_189]",
196
+ "[extra_id_190]",
197
+ "[extra_id_191]",
198
+ "[extra_id_192]",
199
+ "[extra_id_193]",
200
+ "[extra_id_194]",
201
+ "[extra_id_195]",
202
+ "[extra_id_196]",
203
+ "[extra_id_197]",
204
+ "[extra_id_198]",
205
+ "[extra_id_199]",
206
+ "[extra_id_200]",
207
+ "[extra_id_201]",
208
+ "[extra_id_202]",
209
+ "[extra_id_203]",
210
+ "[extra_id_204]",
211
+ "[extra_id_205]",
212
+ "[extra_id_206]",
213
+ "[extra_id_207]",
214
+ "[extra_id_208]",
215
+ "[extra_id_209]",
216
+ "[extra_id_210]",
217
+ "[extra_id_211]",
218
+ "[extra_id_212]",
219
+ "[extra_id_213]",
220
+ "[extra_id_214]",
221
+ "[extra_id_215]",
222
+ "[extra_id_216]",
223
+ "[extra_id_217]",
224
+ "[extra_id_218]",
225
+ "[extra_id_219]",
226
+ "[extra_id_220]",
227
+ "[extra_id_221]",
228
+ "[extra_id_222]",
229
+ "[extra_id_223]",
230
+ "[extra_id_224]",
231
+ "[extra_id_225]",
232
+ "[extra_id_226]",
233
+ "[extra_id_227]",
234
+ "[extra_id_228]",
235
+ "[extra_id_229]",
236
+ "[extra_id_230]",
237
+ "[extra_id_231]",
238
+ "[extra_id_232]",
239
+ "[extra_id_233]",
240
+ "[extra_id_234]",
241
+ "[extra_id_235]",
242
+ "[extra_id_236]",
243
+ "[extra_id_237]",
244
+ "[extra_id_238]",
245
+ "[extra_id_239]",
246
+ "[extra_id_240]",
247
+ "[extra_id_241]",
248
+ "[extra_id_242]",
249
+ "[extra_id_243]",
250
+ "[extra_id_244]",
251
+ "[extra_id_245]",
252
+ "[extra_id_246]",
253
+ "[extra_id_247]",
254
+ "[extra_id_248]"
255
+ ],
256
+ "bos_token": "[BOS]",
257
+ "eos_token": "[EOS]",
258
+ "pad_token": "[extra_id_250]",
259
+ "unk_token": "[extra_id_249]"
260
+ }
tiktoken.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6c497a7469b33ced9c38afb1ad6e47f03f5e5dc05f15930799210ec050c5103
3
+ size 2795286
tokenization_kimi.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
import base64


def load_tiktoken_bpe(tiktoken_bpe_file):
    """Parse a tiktoken BPE vocabulary file into ``{token_bytes: rank}``.

    Each non-empty line holds a base64-encoded token and its integer merge
    rank, separated by whitespace.
    """
    ranks = {}
    with open(tiktoken_bpe_file, "rb") as fh:
        for raw_line in fh.read().splitlines():
            if not raw_line:
                continue
            token_b64, rank = raw_line.split()
            ranks[base64.b64decode(token_b64)] = int(rank)
    return ranks
10
+ import os
11
+ import tiktoken
12
+
13
+ from logging import getLogger
14
+ from pathlib import Path
15
+ from typing import (
16
+ cast,
17
+ Tuple,
18
+ Dict,
19
+ Iterator,
20
+ List,
21
+ Union,
22
+ Optional,
23
+ )
24
+ from shutil import copyfile
25
+
26
+ from tokenizers import AddedToken, pre_tokenizers, Regex
27
+ from transformers.tokenization_utils import PreTrainedTokenizer
28
+ from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
29
+ from typing import Any
30
+
31
+
32
+ logger = getLogger(__name__)
33
+ VOCAB_FILES_NAMES = {"vocab_file": "tiktoken.model"}
34
+
35
+
36
+ class TikTokenTokenizer(PreTrainedTokenizer):
37
+ """
38
+ Tokenizing and encoding/decoding text using the Tiktoken tokenizer. See megatron/tokenizer/tiktoken_tokenizer.py.
39
+
40
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
41
+ this superclass for more information regarding those methods.
42
+
43
+ Args:
44
+ vocab_file (`str`):
45
+ The path to the Tiktoken model file.
46
+ bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|begin_of_text|>",`):
47
+ The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
48
+ eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|end_of_text|>"`):
49
+ The end of sequence token.
50
+ unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|reserved_special_token_249|>"`):
51
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
52
+ token instead. The second to last item in special_tokens.
53
+ pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|reserved_special_token_250|>"`):
54
+ The token used for padding, for example when batching sequences of different lengths.
55
+ additional_special_tokens (list of `str`, *optional*):
56
+ A tuple or a list of additional tokens, which will be marked as `special`, meaning that they will be
57
+ skipped when decoding if `skip_special_tokens` is set to `True`.
58
+ """
59
+
60
+ vocab_files_names = VOCAB_FILES_NAMES
61
+
62
+ model_input_names = ["input_ids", "attention_mask"]
63
+
64
+ special_tokens: Dict[str, int]
65
+
66
+ num_reserved_special_tokens = 256
67
+
68
+ pat_str = "|".join(
69
+ [
70
+ r"""[\p{Han}]+""",
71
+ r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
72
+ r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
73
+ r"""\p{N}{1,3}""",
74
+ r""" ?[^\s\p{L}\p{N}]+[\r\n]*""",
75
+ r"""\s*[\r\n]+""",
76
+ r"""\s+(?!\S)""",
77
+ r"""\s+""",
78
+ ]
79
+ )
80
+
81
+ def __init__(
82
+ self,
83
+ vocab_file,
84
+ bos_token: Union[str, AddedToken]="[BOS]",
85
+ eos_token: Union[str, AddedToken]="[EOS]",
86
+ unk_token: Union[str, AddedToken, None]=None,
87
+ pad_token: Union[str, AddedToken, None]=None,
88
+ additional_special_tokens: List[str]=None,
89
+ added_tokens_decoder: Optional[dict] = None,
90
+ **kwargs,
91
+ ):
92
+ assert os.path.isfile(vocab_file), vocab_file
93
+
94
+ if additional_special_tokens is None:
95
+ additional_special_tokens = [
96
+ "<|im_end|>",
97
+ "<|im_user|>",
98
+ "<|im_assistant|>",
99
+ "<|start_header_id|>",
100
+ "<|end_header_id|>",
101
+ "[EOT]",
102
+ "<|im_system|>",
103
+ "<|im_middle|>",
104
+ ]
105
+
106
+ special_tokens_mapping = {
107
+ i: added_tokens_decoder[i].content for i in added_tokens_decoder
108
+ }
109
+
110
+ self.vocab_file = vocab_file
111
+ mergeable_ranks = load_tiktoken_bpe(vocab_file)
112
+ num_base_tokens = len(mergeable_ranks)
113
+ self.special_tokens = {
114
+ special_tokens_mapping.get(i, f"<|reserved_token_{i}|>"): i
115
+ for i in range(
116
+ num_base_tokens, num_base_tokens + self.num_reserved_special_tokens + 2
117
+ )
118
+ }
119
+
120
+
121
+
122
+ self.model = tiktoken.Encoding(
123
+ name=Path(vocab_file).name,
124
+ pat_str=self.pat_str,
125
+ mergeable_ranks=mergeable_ranks,
126
+ special_tokens=self.special_tokens,
127
+ )
128
+ logger.info(f"Reloaded tiktoken model from {vocab_file}")
129
+
130
+ self.n_words: int = self.model.n_vocab
131
+ # BOS / EOS token IDs
132
+ self.bos_id: int = self.special_tokens[str(bos_token)]
133
+ self.eos_id: int = self.special_tokens[str(eos_token)]
134
+ logger.info(
135
+ f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
136
+ )
137
+
138
+ self.pad_id: int = self.special_tokens[str(pad_token)]
139
+ self.unk_id: int = self.special_tokens[str(unk_token)]
140
+
141
+ self.byte_encoder = bytes_to_unicode()
142
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
143
+
144
+ self.decoder = {}
145
+ for i in range(self.n_words):
146
+ # Taken from https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee
147
+ decoding = ''.join([
148
+ self.byte_encoder[ord(char)] for char in
149
+ self.model.decode_single_token_bytes(i).decode('latin-1')
150
+ ])
151
+ self.decoder[i] = decoding
152
+
153
+ self.encoder = {}
154
+ for i in range(self.n_words):
155
+ if i in self.decoder:
156
+ self.encoder[self.decoder[i]] = i
157
+
158
+ super().__init__(
159
+ bos_token=bos_token,
160
+ eos_token=eos_token,
161
+ unk_token=unk_token,
162
+ pad_token=pad_token,
163
+ additional_special_tokens=additional_special_tokens,
164
+ **kwargs,
165
+ )
166
+ self.all_special_ids_set = set(self.all_special_ids)
167
+
168
+ def encode(
169
+ self,
170
+ text: str,
171
+ allow_special_tokens: bool = True,
172
+ **kwargs
173
+ ) -> List[int]:
174
+ """
175
+ Encodes a string into a list of token IDs.
176
+
177
+ Args:
178
+ text (str): The input string to be encoded.
179
+
180
+ Returns:
181
+ list[int]: A list of token IDs.
182
+ """
183
+ # If there are other args, we should call super().encode because there are a lot of code
184
+ # to handle those args. supper().encode finally will call _tokenize and _convert_token_to_id.
185
+ # NOTE: our encode method is not compatible with the super().encode method,
186
+ # e.g. split_special_tokens' default is True in our encode method.
187
+ if len(kwargs) > 0:
188
+ logger.warning( f"Calling super().encode with {kwargs}" )
189
+ return super().encode(text, **kwargs)
190
+
191
+ assert type(text) is str
192
+
193
+ # The tiktoken tokenizer can handle <=400k chars without
194
+ # pyo3_runtime.PanicException.
195
+ TIKTOKEN_MAX_ENCODE_CHARS = 400_000
196
+
197
+ # https://github.com/openai/tiktoken/issues/195
198
+ # Here we iterate over subsequences and split if we exceed the limit
199
+ # of max consecutive non-whitespace or whitespace characters.
200
+ MAX_NO_WHITESPACES_CHARS = 25_000
201
+
202
+ texts = self.pre_tokenizer_process(text)
203
+
204
+ all_substrs = []
205
+ for text in texts:
206
+ substrs = (
207
+ substr
208
+ for i in range(0, len(text), TIKTOKEN_MAX_ENCODE_CHARS)
209
+ for substr in self._split_whitespaces_or_nonwhitespaces(
210
+ text[i: i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
211
+ )
212
+ )
213
+ all_substrs.extend(substrs)
214
+
215
+ t: List[int] = []
216
+ for substr in all_substrs:
217
+ if allow_special_tokens:
218
+ t.extend(
219
+ # we should consider special token as a common token
220
+ self.model.encode(
221
+ substr,
222
+ allowed_special="all",
223
+ )
224
+ )
225
+ else:
226
+ t.extend(
227
+ # we should consider special token as a common token
228
+ self.model.encode(
229
+ substr,
230
+ disallowed_special=(),
231
+ )
232
+ )
233
+
234
+ return t
235
+
236
+ def decode(
237
+ self,
238
+ token_ids: Union[int, List[int]],
239
+ **kwargs
240
+ ) -> str:
241
+ """
242
+ Decodes a list of token IDs into a string.
243
+
244
+ Args:
245
+ token_ids (List[int]): The list of token IDs to be decoded.
246
+
247
+ Returns:
248
+ str: The decoded string.
249
+ """
250
+ # If there are other args, we should call super().decode because there are a lot of code
251
+ # to handle those args. supper().encode finally will call convert_tokens_to_string and _convert_id_to_token.
252
+ if len(kwargs) > 0:
253
+ return super().decode(token_ids, **kwargs)
254
+
255
+ if type(token_ids) is int:
256
+ token_ids = [token_ids]
257
+
258
+ return self.model.decode(cast(List[int], token_ids))
259
+
260
+ @staticmethod
261
+ def _split_whitespaces_or_nonwhitespaces(
262
+ s: str, max_consecutive_slice_len: int
263
+ ) -> Iterator[str]:
264
+ """
265
+ Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
266
+ consecutive whitespaces or consecutive non-whitespaces.
267
+ """
268
+ current_slice_len = 0
269
+ current_slice_is_space = s[0].isspace() if len(s) > 0 else False
270
+ slice_start = 0
271
+
272
+ for i in range(len(s)):
273
+ is_now_space = s[i].isspace()
274
+
275
+ if current_slice_is_space ^ is_now_space:
276
+ current_slice_len = 1
277
+ current_slice_is_space = is_now_space
278
+ else:
279
+ current_slice_len += 1
280
+ if current_slice_len > max_consecutive_slice_len:
281
+ yield s[slice_start:i]
282
+ slice_start = i
283
+ current_slice_len = 1
284
+ yield s[slice_start:]
285
+
286
+ def pre_tokenizer_process(self, text: str) -> List[str]:
287
+ """
288
+ pre-tokenizes the input text into a list of tokens.
289
+ This method is used to split the input text into smaller chunks for internal processing.
290
+ """
291
+ return [text]
292
+
293
+
294
+ """ ----- Below are the abstract methods required by PreTrainedTokenizer ----- """
295
+ @property
296
+ def vocab_size(self) -> int:
297
+ return self.n_words
298
+
299
+ def get_vocab(self) -> Dict[str, int]:
300
+ return self.encoder
301
+
302
+ def _tokenize(self, text: str, **kwargs) -> List[str]:
303
+ return [
304
+ self.decoder[t]
305
+ for t in self.encode(text)
306
+ ]
307
+
308
+ def _convert_token_to_id(self, token: str) -> int:
309
+ return self.encoder.get(token, self.unk_id)
310
+
311
+ def _convert_id_to_token(self, index: int) -> str:
312
+ return self.decoder.get(index)
313
+
314
+ @staticmethod
315
+ def clean_up_tokenization(out_string: str) -> str:
316
+ return out_string
317
+
318
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
319
+ text = ''.join(tokens)
320
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', 'replace')
321
+ return text
322
+
323
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
324
+ if not os.path.isdir(save_directory):
325
+ raise ValueError(f"vocabulary path ({save_directory}) should be a directory")
326
+ out_vocab_file = os.path.join(
327
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
328
+ )
329
+
330
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
331
+ copyfile(self.vocab_file, out_vocab_file)
332
+
333
+ return (out_vocab_file,)
334
+
335
+
336
+
337
+ def apply_chat_template(
338
+ self, conversation, tools: Optional[list[dict]] = None,
339
+ tokenize: bool = False,
340
+ add_generation_prompt: bool = True,
341
+ **kwargs
342
+ ):
343
+ tools = deep_sort_dict(tools)
344
+ return super().apply_chat_template(conversation,
345
+ tools=tools,
346
+ tokenize=tokenize,
347
+ add_generation_prompt=add_generation_prompt,
348
+ **kwargs)
349
+
350
+
351
def deep_sort_dict(obj: Any) -> Any:
    """Recursively return a copy of ``obj`` with every dict's items ordered
    by key; lists are rebuilt element-wise and all other values are
    returned unchanged."""
    if isinstance(obj, dict):
        result = {}
        for key in sorted(obj):
            result[key] = deep_sort_dict(obj[key])
        return result
    if isinstance(obj, list):
        return list(map(deep_sort_dict, obj))
    return obj
tokenizer_config.json ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "163584": {
4
+ "content": "[BOS]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "163585": {
12
+ "content": "[EOS]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "163586": {
20
+ "content": "<|im_end|>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "163587": {
28
+ "content": "<|im_user|>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "163588": {
36
+ "content": "<|im_assistant|>",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ },
43
+ "163590": {
44
+ "content": "<|start_header_id|>",
45
+ "lstrip": false,
46
+ "normalized": false,
47
+ "rstrip": false,
48
+ "single_word": false,
49
+ "special": true
50
+ },
51
+ "163591": {
52
+ "content": "<|end_header_id|>",
53
+ "lstrip": false,
54
+ "normalized": false,
55
+ "rstrip": false,
56
+ "single_word": false,
57
+ "special": true
58
+ },
59
+ "163593": {
60
+ "content": "[EOT]",
61
+ "lstrip": false,
62
+ "normalized": false,
63
+ "rstrip": false,
64
+ "single_word": false,
65
+ "special": true
66
+ },
67
+ "163594": {
68
+ "content": "<|im_system|>",
69
+ "lstrip": false,
70
+ "normalized": false,
71
+ "rstrip": false,
72
+ "single_word": false,
73
+ "special": true
74
+ },
75
+ "163595": {
76
+ "content": "<|tool_calls_section_begin|>",
77
+ "lstrip": false,
78
+ "normalized": false,
79
+ "rstrip": false,
80
+ "single_word": false,
81
+ "special": false
82
+ },
83
+ "163596": {
84
+ "content": "<|tool_calls_section_end|>",
85
+ "lstrip": false,
86
+ "normalized": false,
87
+ "rstrip": false,
88
+ "single_word": false,
89
+ "special": false
90
+ },
91
+ "163597": {
92
+ "content": "<|tool_call_begin|>",
93
+ "lstrip": false,
94
+ "normalized": false,
95
+ "rstrip": false,
96
+ "single_word": false,
97
+ "special": false
98
+ },
99
+ "163598": {
100
+ "content": "<|tool_call_argument_begin|>",
101
+ "lstrip": false,
102
+ "normalized": false,
103
+ "rstrip": false,
104
+ "single_word": false,
105
+ "special": false
106
+ },
107
+ "163599": {
108
+ "content": "<|tool_call_end|>",
109
+ "lstrip": false,
110
+ "normalized": false,
111
+ "rstrip": false,
112
+ "single_word": false,
113
+ "special": false
114
+ },
115
+ "163601": {
116
+ "content": "<|im_middle|>",
117
+ "lstrip": false,
118
+ "normalized": false,
119
+ "rstrip": false,
120
+ "single_word": false,
121
+ "special": true
122
+ },
123
+ "163838": {
124
+ "content": "[UNK]",
125
+ "lstrip": false,
126
+ "normalized": false,
127
+ "rstrip": false,
128
+ "single_word": false,
129
+ "special": true
130
+ },
131
+ "163839": {
132
+ "content": "[PAD]",
133
+ "lstrip": false,
134
+ "normalized": false,
135
+ "rstrip": false,
136
+ "single_word": false,
137
+ "special": true
138
+ }
139
+ },
140
+ "additional_special_tokens": [
141
+ "<|im_end|>",
142
+ "<|im_user|>",
143
+ "<|im_assistant|>",
144
+ "<|start_header_id|>",
145
+ "<|end_header_id|>",
146
+ "[EOT]",
147
+ "<|im_system|>",
148
+ "<|im_middle|>"
149
+ ],
150
+ "bos_token": "[BOS]",
151
+ "clean_up_tokenization_spaces": false,
152
+ "eos_token": "[EOS]",
153
+ "extra_special_tokens": {},
154
+ "model_max_length": 1000000000000000019884624838656,
155
+ "pad_token": "[PAD]",
156
+ "tokenizer_class": "TikTokenTokenizer",
157
+ "unk_token": "[UNK]",
158
+ "auto_map": {
159
+ "AutoTokenizer": [
160
+ "tokenization_kimi.TikTokenTokenizer",
161
+ null
162
+ ]
163
+ }
164
+ }