Boojum commited on
Commit
fb9b425
·
verified ·
1 Parent(s): 19ad981

Upload folder using huggingface_hub

Browse files
chat_template.jinja ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
{#- Kimi chat template. Renders tool declarations, per-role message markers,
    media placeholders and tool-call sections. Comments use `-` whitespace
    control on both sides so rendered output is unchanged. -#}
{% macro render_content(msg) -%}
{%- set c = msg.get('content') -%}
{%- if c is string -%}
{{ c }}
{%- elif c is not none -%}
{% for content in c -%}
{% if content['type'] == 'image' or 'image' in content or 'image_url' in content -%}
<|media_start|>image<|media_content|><|media_pad|><|media_end|>
{% else -%}
{{ content['text'] }}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{%- endmacro %}


{#- Declare available tools once, before the conversation. -#}
{%- if tools -%}
<|im_system|>tool_declare<|im_middle|>{{ tools | tojson(separators=(',', ':')) }}<|im_end|>
{%- endif -%}
{#- One <|im_*|>...<|im_end|> segment per message; role marker depends on role. -#}
{% for message in messages %}
{%- set role_name = message.get('name') or message['role'] -%}
{%- if message['role'] == 'user' -%}
<|im_user|>{{role_name}}<|im_middle|>
{%- elif message['role'] == 'assistant' -%}
<|im_assistant|>{{role_name}}<|im_middle|>
{%- else -%}
<|im_system|>{{role_name}}<|im_middle|>
{%- endif -%}

{%- if message['role'] == 'assistant' and message.get('tool_calls') -%}
{{render_content(message)}}<|tool_calls_section_begin|>
{%- for tool_call in message['tool_calls'] -%}
{%- set formatted_id = tool_call['id'] -%}
<|tool_call_begin|>{{ formatted_id }}<|tool_call_argument_begin|>{% if tool_call['function']['arguments'] is string %}{{ tool_call['function']['arguments'] }}{% else %}{{ tool_call['function']['arguments'] | tojson }}{% endif %}<|tool_call_end|>
{%- endfor -%}
<|tool_calls_section_end|>
{%- elif message['role'] == 'tool' -%}
{%- set tool_call_id = message.tool_call_id -%}
## Return of {{ tool_call_id }}
{{render_content(message)}}
{%- elif message['content'] is not none -%}
{{render_content(message)}}
{%- endif -%}
<|im_end|>
{%- endfor -%}
{%- if add_generation_prompt -%}
<|im_assistant|>assistant<|im_middle|>
{%- endif -%}
config.json ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "KimiLinearForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_kimi.KimiLinearConfig",
7
+ "AutoModel": "modeling_kimi.KimiLinearModel",
8
+ "AutoModelForCausalLM": "modeling_kimi.KimiLinearForCausalLM"
9
+ },
10
+ "bos_token_id": 163584,
11
+ "dtype": "bfloat16",
12
+ "eos_token_id": 200002,
13
+ "first_k_dense_replace": 1,
14
+ "head_dim": 128,
15
+ "hidden_act": "silu",
16
+ "hidden_size": 1024,
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 3072,
19
+ "kv_lora_rank": 512,
20
+ "linear_attn_config": {
21
+ "full_attn_layers": [
22
+ 4,
23
+ 8,
24
+ 12,
25
+ 16
26
+ ],
27
+ "head_dim": 128,
28
+ "kda_layers": [
29
+ 1,
30
+ 2,
31
+ 3,
32
+ 5,
33
+ 6,
34
+ 7,
35
+ 9,
36
+ 10,
37
+ 11,
38
+ 13,
39
+ 14,
40
+ 15
41
+ ],
42
+ "num_heads": 32,
43
+ "short_conv_kernel_size": 4
44
+ },
45
+ "mla_use_nope": true,
46
+ "model_max_length": 1048576,
47
+ "model_type": "kimi_linear",
48
+ "moe_intermediate_size": 2048,
49
+ "moe_layer_freq": 1,
50
+ "moe_renormalize": true,
51
+ "moe_router_activation_func": "sigmoid",
52
+ "num_attention_heads": 16,
53
+ "num_expert_group": 1,
54
+ "num_experts": 16,
55
+ "num_experts_per_token": 2,
56
+ "num_hidden_layers": 18,
57
+ "num_key_value_heads": 2,
58
+ "num_nextn_predict_layers": 0,
59
+ "num_shared_experts": 1,
60
+ "pad_token_id": 199999,
61
+ "q_lora_rank": null,
62
+ "qk_nope_head_dim": 128,
63
+ "qk_rope_head_dim": 64,
64
+ "rms_norm_eps": 1e-05,
65
+ "rope_scaling": null,
66
+ "rope_theta": 10000.0,
67
+ "routed_scaling_factor": 2.446,
68
+ "tie_word_embeddings": false,
69
+ "topk_group": 1,
70
+ "transformers_version": "4.57.1",
71
+ "use_cache": true,
72
+ "use_grouped_topk": true,
73
+ "v_head_dim": 128,
74
+ "vocab_size": 201088
75
+ }
configuration_kimi.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ from typing import Optional
3
+
4
+ from transformers.configuration_utils import PretrainedConfig
5
+
6
+
7
class KimiLinearConfig(PretrainedConfig):
    """Configuration for Kimi-Linear models.

    Combines dense / mixture-of-experts MLP settings with two attention
    flavours: Multi-Latent Attention (MLA, deepseek-style low-rank KV) and
    linear "KDA" attention, selected per layer via ``linear_attn_config``.

    Fixes vs. the original: input validation now raises ``ValueError``
    instead of using bare ``assert`` (which is stripped under ``python -O``),
    and ``self.linear_attn_config`` is always assigned so the
    ``is_linear_attn`` / ``is_kda_layer`` accessors never hit a missing
    attribute when no linear-attention config is supplied.
    """

    model_type = "kimi_linear"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        model_type="kimi_linear",
        vocab_size=163840,
        hidden_size=4096,
        head_dim=None,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        rope_theta=10000.0,
        rope_scaling=None,
        tie_word_embeddings=False,
        moe_intermediate_size: Optional[int] = None,
        moe_renormalize: bool = True,
        moe_router_activation_func: str = "sigmoid",
        num_experts: Optional[int] = None,
        num_experts_per_token: Optional[int] = None,
        num_shared_experts: int = 0,
        routed_scaling_factor: float = 1.0,
        first_k_dense_replace: int = 0,
        moe_layer_freq: int = 1,
        use_grouped_topk: bool = True,
        num_expert_group: int = 1,
        topk_group: int = 1,
        q_lora_rank: Optional[int] = None,
        kv_lora_rank: Optional[int] = None,
        qk_nope_head_dim: Optional[int] = None,
        qk_rope_head_dim: Optional[int] = None,
        v_head_dim: Optional[int] = None,
        mla_use_nope: Optional[bool] = False,
        num_nextn_predict_layers: int = 0,
        linear_attn_config: Optional[dict] = None,
        **kwargs,
    ):
        self.model_type = model_type
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        # Default to the conventional per-head width when not given explicitly.
        self.head_dim = (
            head_dim if head_dim is not None else hidden_size // num_attention_heads
        )
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility: older configs had no GQA field.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling

        # MLA (low-rank attention) hyper-parameters; all-None means plain MHA.
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.mla_use_nope = mla_use_nope
        # moe config
        self.num_experts = num_experts
        self.num_experts_per_token = num_experts_per_token
        self.moe_renormalize = moe_renormalize
        self.num_shared_experts = num_shared_experts
        self.routed_scaling_factor = routed_scaling_factor
        # Validate eagerly with a real exception (asserts vanish under -O).
        if moe_router_activation_func not in ("softmax", "sigmoid"):
            raise ValueError(
                f"moe_router_activation_func must be 'softmax' or 'sigmoid', "
                f"got {moe_router_activation_func!r}"
            )
        self.moe_router_activation_func = moe_router_activation_func
        self.moe_intermediate_size = moe_intermediate_size
        self.first_k_dense_replace = first_k_dense_replace
        self.moe_layer_freq = moe_layer_freq
        self.use_grouped_topk = use_grouped_topk
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.num_nextn_predict_layers = num_nextn_predict_layers

        if linear_attn_config is not None:
            # Both layer lists are required to build the hybrid layer stack.
            if linear_attn_config.get("kda_layers") is None:
                raise ValueError("linear_attn_config requires a 'kda_layers' entry")
            if linear_attn_config.get("full_attn_layers") is None:
                raise ValueError("linear_attn_config requires a 'full_attn_layers' entry")
        # Always assigned (possibly None) so the properties below never
        # raise AttributeError.
        self.linear_attn_config = linear_attn_config

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    @property
    def is_mla(self):
        """True when any MLA-specific field is set (low-rank attention in use)."""
        return (
            self.q_lora_rank is not None
            or self.kv_lora_rank is not None
            or self.qk_nope_head_dim is not None
            or self.qk_rope_head_dim is not None
            or self.v_head_dim is not None
            or self.mla_use_nope is True
        )

    @property
    def is_moe(self):
        """True when a routed-experts MLP is configured."""
        return self.num_experts is not None

    @property
    def is_linear_attn(self) -> bool:
        """True unless linear attention is absent or its layer list is empty."""
        return not (
            self.linear_attn_config is None
            or (
                isinstance(self.linear_attn_config, dict)
                and self.linear_attn_config["kda_layers"] is not None
                and len(self.linear_attn_config["kda_layers"]) == 0
            )
        )

    def is_kda_layer(self, layer_idx: int):
        """Whether 0-based `layer_idx` uses KDA (config lists 1-based layers)."""
        return (
            self.linear_attn_config is not None
            and (layer_idx + 1) in self.linear_attn_config["kda_layers"]
        )
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 163584,
4
+ "eos_token_id": 200002,
5
+ "pad_token_id": 199999,
6
+ "transformers_version": "4.57.1"
7
+ }
model-00001-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2b46dce368ffa84238ec6704b4aa4af05d919f4f999b7df347d9e1fedf4ddfc0
3
+ size 996133416
model-00002-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb8d14d02a6a03b684d1b21faa157a2c381a76ea4dc378999863cf840a53c93d
3
+ size 997527352
model-00003-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c088e0d0ff87e21d36a78b5e4f2c263cd69a3142cf9c978607f0d9cbe48afe31
3
+ size 997527440
model-00004-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d97415e8d3a757bb58f92e3a8de8b1088c5b9a429373e176741b22f9f20b01f5
3
+ size 997527624
model-00005-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2c505e76bf66a850d78c70cbe007fbf2def8ff8a8dc57b5afe3f9624e1fd2ffc
3
+ size 610668816
model-00006-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f2c821308069fd4f4b7b9a1df4604fcd22cc52f0ce3460e828d2127f3da34c7
3
+ size 411828352
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_kimi.py ADDED
@@ -0,0 +1,1099 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from collections.abc import Callable
3
+ from typing import Any
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import transformers
8
+ from einops import rearrange, repeat
9
+ from packaging import version
10
+ from torch import nn
11
+ from transformers.activations import ACT2FN
12
+ from transformers.cache_utils import Cache
13
+ from transformers.generation import GenerationMixin
14
+ from transformers.masking_utils import create_causal_mask
15
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
16
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
17
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
18
+ from transformers.processing_utils import Unpack
19
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
20
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
21
+ from transformers.utils.generic import OutputRecorder, check_model_inputs
22
+
23
try:
    from fla.modules import FusedRMSNormGated, ShortConvolution
    from fla.ops.kda import chunk_kda, fused_recurrent_kda
    from fla.ops.kda.gate import fused_kda_gate
    from fla.ops.utils.index import prepare_cu_seqlens_from_mask, prepare_lens_from_mask
    from fla.utils import tensor_cache
except ImportError as e:
    # Typo fixed ("Plese" -> "Please"); chain the original error so users can
    # see which fla symbol actually failed to import.
    raise ImportError("Please run `pip install -U fla-core`") from e
31
+
32
+ from .configuration_kimi import KimiLinearConfig
33
+
34
+ assert version.parse(transformers.__version__) >= version.parse("4.56.0"), \
35
+ "Please upgrade transformers to >= 4.56.0"
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+
40
def index_first_axis(x, indices):
    """Select rows of `x` along dim 0 by flat `indices`.

    Equivalent to `x[indices]` but written with an explicit flatten + gather
    (kept from the original, which noted gather/scatter being slightly faster
    than fancy indexing). The einops dependency of the original was
    unnecessary here: `rearrange(x, "b ... -> b (...)")` is just a reshape and
    `repeat(indices, "z -> z d")` is an unsqueeze + expand.

    Args:
        x: tensor of shape (N, ...).
        indices: 1-D long tensor of row positions to keep.

    Returns:
        Tensor of shape (len(indices), ...) with the selected rows.
    """
    other_shape = x.shape[1:]
    second_dim = other_shape.numel()
    flat = x.reshape(x.shape[0], second_dim)
    gather_index = indices.unsqueeze(-1).expand(indices.shape[0], second_dim)
    return torch.gather(flat, 0, gather_index).reshape(-1, *other_shape)
46
+
47
+
48
def index_put_first_axis(x, indices, first_axis_dim):
    """Scatter the rows of `x` into a zero tensor with `first_axis_dim` rows.

    Inverse of `index_first_axis`: row `x[j]` is placed at position
    `indices[j]`; every other row stays zero. Device and dtype follow `x`.
    """
    out = x.new_zeros((first_axis_dim, *x.shape[1:]))
    out[indices] = x
    return out
54
+
55
+
56
@tensor_cache
def get_unpad_data(
    attention_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Compute unpadding metadata from a 0/1 padding mask.

    Args:
        attention_mask: padding mask — assumed (batch, seq_len) with 0 marking
            padding; TODO confirm against callers.

    Returns:
        indices: flat (batch*seq) positions of all non-pad tokens.
        cu_seqlens: cumulative per-row sequence lengths (fla helper).
        max_seqlen_in_batch: longest row length as a Python int.
    """
    # Per-row token counts (fla helper).
    lens = prepare_lens_from_mask(attention_mask)
    # Flat positions of kept tokens in row-major order.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # .item() forces a host sync; @tensor_cache presumably memoizes per mask
    # tensor so this only happens once per mask — verify in fla.utils.
    max_seqlen_in_batch = lens.max().item()
    cu_seqlens = prepare_cu_seqlens_from_mask(attention_mask)
    return indices, cu_seqlens, max_seqlen_in_batch
65
+
66
+
67
def unpad_input(
    q: torch.Tensor,
    states: tuple[torch.Tensor],
    attention_mask: torch.Tensor,
    q_len: int,
    keepdim: bool = False,
):
    """Strip padding tokens from `q` and `states` for varlen attention kernels.

    Supports exactly two regimes: prefill (q_len == seq_len) and single-token
    decode (q_len == 1).

    Args:
        q: query tensor, (batch, q_len, ...).
        states: tensors shaped (batch, seq_len, ...) (e.g. key/value) to unpad.
        attention_mask: 0/1 padding mask, (batch, seq_len).
        q_len: query length of this step.
        keepdim: if True, re-add a leading batch dim of size 1 to the outputs.

    Returns:
        (q, states, indices_q, (cu_seqlens_q, cu_seqlens_k),
         (max_seqlen_q, max_seqlen_k)) — all flattened over (batch*tokens).
    """
    indices_k, cu_seqlens_k, max_seqlen_in_batch_k = get_unpad_data(attention_mask)
    batch_size, seq_len, *_ = states[0].shape

    # Flatten (b, s, ...) -> (b*s, ...) and keep only non-pad rows.
    state = tuple(
        index_first_axis(rearrange(s, "b s ... -> (b s) ..."), indices_k)
        for s in states
    )

    if q_len == seq_len:
        # Prefill: queries share the key layout, reuse the same metadata.
        q = index_first_axis(rearrange(q, "b s ... -> (b s) ..."), indices_k)
        cu_seqlens_q = cu_seqlens_k
        max_seqlen_in_batch_q = max_seqlen_in_batch_k
        indices_q = indices_k
    elif q_len == 1:
        # Decode: one query per batch row; cu_seqlens is simply 0..batch_size.
        max_seqlen_in_batch_q = 1
        cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
        indices_q = cu_seqlens_q[:-1]
        q = q.squeeze(1)
    else:
        raise NotImplementedError("We only support either q_len == k_len (prefilling) or q_len == 1 (decoding)")

    if keepdim:
        # Varlen kernels expect a dummy batch dim of 1 over the packed tokens.
        q = q.unsqueeze(0)
        state = tuple(s.unsqueeze(0) for s in state)

    return (
        q,
        state,
        indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
    )
106
+
107
+
108
def pad_input(
    hidden_states: torch.Tensor,
    indices: torch.LongTensor,
    batch_size: int,
    seq_len: int,
) -> torch.Tensor:
    """Inverse of unpadding: scatter packed token rows back to (batch, seq, ...).

    Positions not listed in `indices` are filled with zeros.
    """
    flat = index_put_first_axis(hidden_states, indices, batch_size * seq_len)
    return flat.reshape(batch_size, seq_len, *flat.shape[1:])
116
+
117
+
118
class KimiDynamicCache:
    """
    Dynamic cache for Kimi model.
    Inspired by Qwen3-Next.

    Holds two kinds of per-layer state for the hybrid stack:
      * full-attention layers: growing `key_cache` / `value_cache` tensors;
      * linear-attention (KDA) layers: `conv_states` (a (q, k, v) triple of
        short-conv caches) and `recurrent_states`.
    Unused slots stay None.
    """
    is_compileable = False

    def __init__(self, config: KimiLinearConfig):
        super().__init__()
        self.config = config

        # One type tag per layer, driven by config.is_kda_layer.
        if config.linear_attn_config is not None:
            self.layer_types = []
            for i in range(config.num_hidden_layers):
                if config.is_kda_layer(i):
                    self.layer_types.append("linear_attention")
                else:
                    self.layer_types.append("full_attention")
        else:
            self.layer_types = ["full_attention"] * config.num_hidden_layers

        self.transformer_layers = [
            i for i in range(config.num_hidden_layers) if self.layer_types[i] == "full_attention"
        ]

        linear_layers = [i for i in range(
            config.num_hidden_layers) if self.layer_types[i] == "linear_attention"]
        # -1 sentinel when the model has no linear-attention layers at all.
        self.last_linear_layer = linear_layers[-1] if linear_layers else -1

        self.conv_states = [None for _ in range(config.num_hidden_layers)]
        self.recurrent_states = [None for _ in range(config.num_hidden_layers)]
        self.key_cache = [None for _ in range(config.num_hidden_layers)]
        self.value_cache = [None for _ in range(config.num_hidden_layers)]

    def __len__(self):
        # Number of layers tracked, matching the Cache API convention.
        return len(self.layer_types)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Append K/V for a full-attention layer and return the full cache.

        Concatenation is along dim=2 — assumes (batch, heads, seq, head_dim)
        layout; TODO confirm against the attention modules.
        """
        if self.key_cache[layer_idx] is None:
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        else:
            self.key_cache[layer_idx] = torch.cat(
                [self.key_cache[layer_idx], key_states], dim=2)
            self.value_cache[layer_idx] = torch.cat(
                [self.value_cache[layer_idx], value_states], dim=2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            if self.key_cache[layer_idx] is not None:
                device = self.key_cache[layer_idx].device
                beam_idx = beam_idx.to(device)
                self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
                    0, beam_idx)
                self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
                    0, beam_idx)

            if self.conv_states[layer_idx] is not None:
                device = self.conv_states[layer_idx][0].device
                beam_idx = beam_idx.to(device)
                # conv_states holds a (q, k, v) triple of short-conv caches.
                q_conv, k_conv, v_conv = self.conv_states[layer_idx]
                self.conv_states[layer_idx] = (
                    q_conv.index_select(0, beam_idx),
                    k_conv.index_select(0, beam_idx),
                    v_conv.index_select(0, beam_idx),
                )
                self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(
                    0, beam_idx)

    def get_seq_length(self, layer_idx: int | None = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # take any layer that contains cache and not empty tensor
        # (linear-attention layers have no K/V cache, so fall back to the
        # first full-attention layer when asked about one of those).
        layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
        if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
        """
        Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
        the given layer at `layer_idx`.
        The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer.
        """
        kv_offset = 0
        query_length = cache_position.shape[0]
        past_seen_tokens = self.get_seq_length(layer_idx)
        kv_length = query_length + past_seen_tokens
        return kv_length, kv_offset

    @property
    def has_previous_state(self):
        """We have a previous state if the last linear (conv) layer was already updated."""
        if self.last_linear_layer == -1:
            return False
        return self.conv_states[self.last_linear_layer] is not None
222
+
223
+
224
class KimiRMSNorm(nn.Module):
    """Root-mean-square layer norm (T5-style: no mean subtraction, no bias).

    The statistics are computed in float32 and the result is cast back to the
    input dtype before the learned scale is applied.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        h = hidden_states.to(torch.float32)
        mean_square = h.pow(2).mean(dim=-1, keepdim=True)
        normed = h * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normed.to(orig_dtype)
240
+
241
+
242
+ ALL_LAYERNORM_LAYERS.append(KimiRMSNorm)
243
+
244
+
245
+ class KimiBlockSparseMLP(nn.Module):
246
+ def __init__(self, config: KimiLinearConfig, hidden_size=None, intermediate_size=None):
247
+ super().__init__()
248
+ self.config = config
249
+ self.ffn_dim = config.intermediate_size if intermediate_size is None else intermediate_size
250
+ self.hidden_dim = config.hidden_size if hidden_size is None else hidden_size
251
+
252
+ self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) # gate
253
+ self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) # down
254
+ self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) # up
255
+
256
+ self.act_fn = ACT2FN[config.hidden_act]
257
+
258
+ def forward(self, hidden_states):
259
+ current_hidden_states = self.act_fn(
260
+ self.w1(hidden_states)) * self.w3(hidden_states)
261
+ current_hidden_states = self.w2(current_hidden_states)
262
+ return current_hidden_states
263
+
264
+
265
+ class KimiMLP(nn.Module):
266
+ def __init__(self, config: KimiLinearConfig, hidden_size=None, intermediate_size=None):
267
+ super().__init__()
268
+ self.config = config
269
+ self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
270
+ self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
271
+ self.gate_proj = nn.Linear(
272
+ self.hidden_size, self.intermediate_size, bias=False)
273
+ self.up_proj = nn.Linear(
274
+ self.hidden_size, self.intermediate_size, bias=False)
275
+ self.down_proj = nn.Linear(
276
+ self.intermediate_size, self.hidden_size, bias=False)
277
+ self.act_fn = ACT2FN[config.hidden_act]
278
+
279
+ def forward(self, x):
280
+ down_proj = self.down_proj(self.act_fn(
281
+ self.gate_proj(x)) * self.up_proj(x))
282
+ return down_proj
283
+
284
+
285
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Duplicate each KV head `n_rep` times for grouped-query attention.

    Equivalent to torch.repeat_interleave(hidden_states, n_rep, dim=1):
    (batch, num_kv_heads, seq, head_dim) -> (batch, num_kv_heads * n_rep,
    seq, head_dim). Returns the input unchanged when n_rep == 1.
    """
    if n_rep == 1:
        return hidden_states
    bsz, kv_heads, seq, dim = hidden_states.shape
    expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq, dim)
    return expanded.reshape(bsz, kv_heads * n_rep, seq, dim)
296
+
297
+
298
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Reference (non-fused) scaled-dot-product attention.

    Expands grouped KV heads, applies an additive causal mask, softmaxes in
    float32 for stability, then returns (output transposed to
    (batch, seq, heads, head_dim), post-dropout attention weights).
    """
    k = repeat_kv(key, module.num_key_value_groups)
    v = repeat_kv(value, module.num_key_value_groups)

    scores = torch.matmul(query, k.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Mask is additive (-inf at disallowed positions), cropped to kv len.
        scores = scores + attention_mask[:, :, :, : k.shape[-2]]

    # float32 softmax, cast back to the query dtype.
    probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = F.dropout(probs, p=dropout, training=module.training)
    out = torch.matmul(probs, v).transpose(1, 2).contiguous()

    return out, probs
324
+
325
+
326
class KimiMLAAttention(nn.Module):
    """
    Multi-Latent Attention adapted from deepseek-v3.

    Keys/values are produced from a low-rank "latent" projection
    (kv_a_proj_with_mqa -> kv_a_layernorm -> kv_b_proj); queries are full-rank
    (q_lora_rank must be None here). Each head is split into a "nope" part and
    a shared rotary ("rope") part — though with mla_use_nope no rotary
    embedding is applied in this module.
    """

    def __init__(self, config: KimiLinearConfig, layer_idx: int):
        nn.Module.__init__(self)
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.rope_theta = config.rope_theta
        # Config may not define attention_dropout; default to 0.
        self.attention_dropout = getattr(config, "attention_dropout", 0.0)

        try:
            self.q_lora_rank = config.q_lora_rank
            self.qk_rope_head_dim = config.qk_rope_head_dim
            self.kv_lora_rank = config.kv_lora_rank
            self.v_head_dim = config.v_head_dim
            self.qk_nope_head_dim = config.qk_nope_head_dim
            # Full per-head query/key width = nope part + rope part.
            self.q_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
            self.use_nope = config.mla_use_nope
            self.scaling = self.q_head_dim ** (-0.5)
        except Exception as e:
            raise ValueError(
                f"Kimi MLA config is not found or not properly formatted: {e}")

        # This module only supports full-rank queries.
        assert self.q_lora_rank is None
        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.q_head_dim, bias=False,
        )
        # Produces the compressed KV latent plus the shared rope key part.
        self.kv_a_proj_with_mqa = nn.Linear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
        )
        self.kv_a_layernorm = KimiRMSNorm(self.kv_lora_rank)
        # Expands the latent into per-head (nope key, value) pairs.
        self.kv_b_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads
            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
        )
        self.is_causal = True
        # NoPE-only variant: no rotary embedding is applied below.
        assert self.use_nope

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
        batch_size, seq_length = hidden_states.shape[:-1]
        query_shape = (batch_size, seq_length, -1, self.q_head_dim)
        key_shape = (batch_size, seq_length, -1,
                     self.qk_nope_head_dim + self.v_head_dim)

        # Queries: (b, s, h*qd) -> (b, h, s, qd), then split nope/rope parts.
        q_states = self.q_proj(hidden_states)
        q_states = q_states.view(query_shape).transpose(1, 2)
        q_pass, q_rot = torch.split(
            q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Latent KV plus the single shared rope key stream.
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        k_pass, k_rot = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        # Expand the normalized latent into per-head nope-keys and values.
        k_pass = self.kv_b_proj(self.kv_a_layernorm(
            k_pass)).view(key_shape).transpose(1, 2)
        k_pass, value_states = torch.split(
            k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # Broadcast the shared rope key part across all heads (MQA-style).
        k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

        query_states = torch.cat((q_pass, q_rot), dim=-1)
        key_states = torch.cat((k_pass, k_rot), dim=-1)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(
                key_states, value_states, self.layer_idx)

        # flash-attention requires equal q/k/v head dims; pad values up, then
        # crop the padding back off after the kernel.
        if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
            value_states = F.pad(
                value_states, [0, self.q_head_dim - self.v_head_dim])

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, _ = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, : self.v_head_dim]

        attn_output = attn_output.reshape(
            batch_size, seq_length, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output
442
+
443
+
444
class KimiDeltaAttention(nn.Module):
    """Kimi Delta Attention (KDA): a gated linear-attention layer.

    Q/K/V projections are passed through short causal convolutions and fed to
    a delta-rule linear-attention kernel — the chunked kernel for long inputs
    and the fused recurrent kernel for short (decode-sized) inputs.  The
    recurrent state and the three convolution states are persisted in a
    ``KimiDynamicCache`` between calls so generation can resume incrementally.
    """

    def __init__(self, config: "KimiLinearConfig", layer_idx: int):
        super().__init__()
        self.config = config
        # Default kernel path; forward() falls back to 'fused_recurrent'
        # for short sequences regardless of this setting.
        self.mode = "chunk"

        self.hidden_size = config.hidden_size
        self.conv_size = config.linear_attn_config["short_conv_kernel_size"]
        self.head_dim = config.linear_attn_config["head_dim"]
        self.num_heads = config.linear_attn_config["num_heads"]
        # K heads mirror the Q/V head layout (no GQA-style sharing here).
        self.head_k_dim = self.head_dim
        self.num_k_heads = self.num_heads

        self.layer_idx = layer_idx

        # Fixed typo in the original message ("suppoerted").
        assert self.mode in [
            'chunk', 'fused_recurrent'], f"Not supported mode `{self.mode}`."

        projection_k_size = self.head_k_dim * self.num_k_heads
        projection_size = self.head_dim * self.num_heads

        self.q_proj = nn.Linear(
            self.hidden_size, projection_k_size, bias=False)
        self.k_proj = nn.Linear(
            self.hidden_size, projection_k_size, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, projection_size, bias=False)

        # Depthwise short convolutions (SiLU-activated) over each projection.
        self.q_conv1d = ShortConvolution(
            hidden_size=projection_k_size,
            kernel_size=self.conv_size,
            activation='silu',
        )
        self.k_conv1d = ShortConvolution(
            hidden_size=projection_k_size,
            kernel_size=self.conv_size,
            activation='silu',
        )
        self.v_conv1d = ShortConvolution(
            hidden_size=projection_size,
            kernel_size=self.conv_size,
            activation='silu',
        )

        # Per-head log decay rate, stored as log of a Uniform(1, 16) draw and
        # broadcastable to (batch, seq, heads, 1).
        self.A_log = torch.nn.Parameter(torch.log(torch.empty(
            self.num_heads, dtype=torch.float32).uniform_(1, 16)).view(1, 1, -1, 1))

        # Low-rank factorization of the forget-gate projection.
        self.f_a_proj = nn.Linear(self.hidden_size, self.head_dim, bias=False)
        self.f_b_proj = nn.Linear(self.head_dim, projection_size, bias=False)

        self.dt_bias = nn.Parameter(
            torch.empty(projection_size, dtype=torch.float32))

        # Per-head beta (delta-rule write strength), squashed by sigmoid.
        self.b_proj = nn.Linear(self.hidden_size, self.num_heads, bias=False)

        # Low-rank factorization of the output-gate projection.
        self.g_a_proj = nn.Linear(self.hidden_size, self.head_dim, bias=False)
        self.g_b_proj = nn.Linear(self.head_dim, projection_size, bias=False)

        self.o_norm = FusedRMSNormGated(
            self.head_dim, eps=config.rms_norm_eps, activation='sigmoid')
        self.o_proj = nn.Linear(projection_size, self.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        cache_params: "KimiDynamicCache | None" = None,
        **kwargs: Unpack[dict],
    ) -> torch.Tensor:
        """Run KDA over ``hidden_states``.

        Args:
            hidden_states: (batch, seq_len, hidden_size) inputs.
            attention_mask: optional 0/1 padding mask of shape
                (batch, seq_len); 3D score masks are rejected.
            cache_params: mutable cache holding per-layer conv and recurrent
                states; updated in place when provided.

        Returns:
            (batch, seq_len, hidden_size) outputs.  (The original annotation
            advertised a tuple, but only the tensor is returned.)
        """
        if attention_mask is not None:
            # Some callers pass a 4D score mask; fall back to the raw
            # padding mask forwarded through kwargs in that case.
            if attention_mask.dim() != 2:
                attention_mask = kwargs.get("padding_mask")

        if attention_mask is not None and attention_mask.dim() != 2:
            raise ValueError(
                "attention_mask must be a 0-1 matrix of shape [batch_size, seq_len] "
                "(0 = padding). 3D masks are not supported here.",
            )
        use_cache = cache_params is not None
        batch_size, q_len, _ = hidden_states.shape
        # Short inputs (e.g. single-token decode) use the recurrent kernel.
        mode = 'fused_recurrent' if q_len <= 64 else self.mode
        if self.training:
            assert mode == 'chunk', "Only chunk mode is supported in training."

        cu_seqlens = kwargs.get('cu_seqlens')
        indices = None
        if attention_mask is not None:
            # Unpad: flatten to a single "varlen" batch of valid tokens and
            # describe sequence boundaries via cu_seqlens.
            indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
            hidden_states = index_first_axis(
                rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0)

        conv_state_q, conv_state_k, conv_state_v = None, None, None
        recurrent_state = None
        if cache_params is not None:
            if cache_params.conv_states[self.layer_idx] is not None:
                conv_state_q, conv_state_k, conv_state_v = cache_params.conv_states[
                    self.layer_idx]
            recurrent_state = cache_params.recurrent_states[self.layer_idx]
        q, conv_state_q = self.q_conv1d(
            x=self.q_proj(hidden_states),
            cache=conv_state_q,
            output_final_state=use_cache,
            cu_seqlens=cu_seqlens,
        )
        k, conv_state_k = self.k_conv1d(
            x=self.k_proj(hidden_states),
            cache=conv_state_k,
            output_final_state=use_cache,
            cu_seqlens=cu_seqlens,
        )
        v, conv_state_v = self.v_conv1d(
            x=self.v_proj(hidden_states),
            cache=conv_state_v,
            output_final_state=use_cache,
            cu_seqlens=cu_seqlens,
        )
        # Forget gate: low-rank projection then fused gate transform.
        g = self.f_b_proj(self.f_a_proj(hidden_states))
        g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias)
        beta = self.b_proj(hidden_states).float().sigmoid()

        # Split the flat projections into (..., heads, head_dim).
        q, k = map(lambda x: rearrange(
            x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k))
        v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)

        if mode == 'chunk':
            o, recurrent_state = chunk_kda(
                q=q,
                k=k,
                v=v,
                g=g,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=True,
                use_qk_l2norm_in_kernel=True,
                cu_seqlens=cu_seqlens,
            )
        else:
            o, recurrent_state = fused_recurrent_kda(
                q=q,
                k=k,
                v=v,
                g=g,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=True,
                use_qk_l2norm_in_kernel=True,
                cu_seqlens=cu_seqlens,
            )
        if cache_params is not None:
            cache_params.recurrent_states[self.layer_idx] = recurrent_state
            cache_params.conv_states[self.layer_idx] = (
                conv_state_q, conv_state_k, conv_state_v)

        # Output gate: low-rank projection, applied inside the gated RMSNorm.
        g = self.g_b_proj(self.g_a_proj(hidden_states))
        g = rearrange(g, '... (h d) -> ... h d', d=self.head_dim)
        o = self.o_norm(o, g)

        o = rearrange(o, 'b t h d -> b t (h d)')
        o = self.o_proj(o)
        if attention_mask is not None:
            # Re-pad back to the caller's (batch, seq_len) layout.
            o = pad_input(o.squeeze(0), indices, batch_size, q_len)

        return o
606
+
607
+
608
class KimiMoEGate(nn.Module):
    """
    MoEGate adapted from Deepseek-V3.

    Scores every token against all routed experts, then performs grouped
    (node-limited) top-k selection.  The learned ``e_score_correction_bias``
    is applied *only* when choosing which experts to route to; the returned
    ``topk_weight`` values are gathered from the raw, un-biased scores,
    matching DeepSeek-V3's aux-loss-free balancing semantics.

    Parameter correspondences:
        num_experts -> n_routed_experts
        num_experts_per_token -> num_experts_per_tok
        num_expert_group -> n_group
        moe_router_activation_func -> scoring_func
    """

    def __init__(self, config: "KimiLinearConfig"):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_token
        self.num_experts = config.num_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.moe_router_activation_func = config.moe_router_activation_func
        self.num_expert_group = getattr(config, "num_expert_group", 1)
        self.topk_group = getattr(config, "topk_group", 1)

        # topk selection algorithm
        self.moe_renormalize = config.moe_renormalize
        self.gating_dim = config.hidden_size
        self.weight = nn.Parameter(
            torch.empty((self.num_experts, self.gating_dim)),
        )

        # Selection-only bias; presumably loaded from the checkpoint —
        # reset_parameters() deliberately leaves it untouched (TODO confirm).
        self.e_score_correction_bias = nn.Parameter(
            torch.empty(self.num_experts),
        )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        """Return ``(topk_idx, topk_weight)``, each ``[bsz*seq_len, top_k]``."""
        bsz, seq_len, h = hidden_states.shape
        # compute gating score (always in fp32 for numerical stability)
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(
            hidden_states.type(torch.float32), self.weight.type(
                torch.float32), None,
        )
        if self.moe_router_activation_func == "sigmoid":
            scores = logits.sigmoid()
        elif self.moe_router_activation_func == "softmax":
            scores = logits.softmax(dim=1)
        else:
            raise NotImplementedError(
                f"insupportable scoring function for MoE gating: {self.moe_router_activation_func}",
            )

        # select top-k experts (inference-only routing path)
        assert not self.training
        # BUGFIX: `scores.view(...)` returns a view aliasing `scores`, so the
        # previous in-place `+=` mutated `scores` and leaked the correction
        # bias into the gathered routing weights.  Build a fresh tensor so
        # the bias influences expert *selection* only.
        scores_for_choice = scores.view(
            bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
        # Score each expert group by the sum of its top-2 expert scores.
        group_scores = (
            scores_for_choice.view(
                bsz * seq_len, self.num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
        )  # [n, num_expert_group]
        group_idx = torch.topk(
            group_scores, k=self.topk_group, dim=-1, sorted=False,
        )[
            1
        ]  # [n, top_k_group]
        group_mask = torch.zeros_like(group_scores)  # [n, num_expert_group]
        group_mask.scatter_(1, group_idx, 1)  # [n, num_expert_group]
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(
                bsz * seq_len, self.num_expert_group, self.num_experts // self.num_expert_group,
            )
            .reshape(bsz * seq_len, -1)
        )  # [n, e]
        # Zero out experts outside the selected groups, then take top-k.
        tmp_scores = scores_for_choice.masked_fill(
            ~score_mask.bool(), 0.0)  # [n, e]
        _, topk_idx = torch.topk(
            tmp_scores, k=self.top_k, dim=-1, sorted=False,
        )
        # Weights come from the raw scores, without the correction bias.
        topk_weight = scores.gather(1, topk_idx)

        # norm gate to sum 1
        if self.top_k > 1 and self.moe_renormalize:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator
        # must multiply the scaling factor
        topk_weight = topk_weight * self.routed_scaling_factor

        return topk_idx, topk_weight
699
+
700
+
701
class KimiSparseMoeBlock(nn.Module):
    """
    Adapted from Deepseek-V3's MOE implementation
    The namings are consistent with Kimi's version.

    Routes each token to ``top_k`` experts chosen by :class:`KimiMoEGate`,
    runs the selected experts, and combines their outputs with the gate's
    routing weights.  Optionally adds always-on shared experts on top.
    Inference-only: the training path raises ``NotImplementedError``.
    """

    def __init__(self, config: KimiLinearConfig):
        super().__init__()
        self.config = config
        self.hidden_dim = config.hidden_size
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_token
        self.moe_renormalize = config.moe_renormalize

        # Expert-parallel bookkeeping; fixed to a single rank here, so
        # `i + ep_rank * experts_per_rank` in moe_infer reduces to `i`.
        self.ep_size = 1
        self.experts_per_rank = config.num_experts
        self.ep_rank = 0
        self.experts = nn.ModuleList(
            [
                KimiBlockSparseMLP(
                    config, intermediate_size=config.moe_intermediate_size,
                )
                for _ in range(config.num_experts)
            ],
        )
        self.gate = KimiMoEGate(config)
        if config.num_shared_experts is not None:
            # Shared experts process every token; their width scales with
            # the number of shared experts.
            intermediate_size = config.moe_intermediate_size * config.num_shared_experts
            self.shared_experts = KimiMLP(
                config=config, intermediate_size=intermediate_size,
            )

    def forward(self, hidden_states):
        # Keep the un-routed input for the shared-expert branch.
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        if not self.training:
            y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
        else:
            raise NotImplementedError("Training mode is not supported in KimiSparseMoeBlock")
        if self.config.num_shared_experts is not None:
            y = y + self.shared_experts(identity)
        return y

    @torch.no_grad()
    def moe_infer(self, x, topk_ids, topk_weight):
        """Dispatch tokens to experts, run them, and gather weighted outputs.

        Args:
            x: flattened tokens, shape [num_tokens, hidden].
            topk_ids: chosen expert ids, shape [num_tokens, top_k].
            topk_weight: matching routing weights, shape [num_tokens, top_k].

        Returns:
            Combined expert output, shape [num_tokens, hidden].
        """
        # Count how many (token, slot) assignments each expert received.
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        # Sort assignments by expert id so each expert sees one contiguous
        # slice; `idxs // top_k` maps a sorted slot back to its token row.
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]

        # Host-side counts drive the Python loop below (forces a sync).
        tokens_per_expert = tokens_per_expert.cpu().numpy()

        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            expert_out = expert(tokens_for_this_expert)
            outputs.append(expert_out)
            start_idx = end_idx

        outs = torch.cat(outputs, dim=0) if len(
            outputs) else sorted_tokens.new_empty(0)

        # Scatter expert outputs back to the original (token, slot) order...
        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        # ...then weight each slot and sum over the top_k slots per token.
        final_out = (
            new_x.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )
        return final_out
781
+
782
+
783
class KimiDecoderLayer(nn.Module):
    """One transformer block: (KDA or MLA) attention + (dense or MoE) FFN,
    each wrapped in a pre-RMSNorm residual connection."""

    def __init__(self, config: KimiLinearConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.config = config

        # Attention flavour is decided per layer by the config.
        self.is_linear_attn = bool(config.is_kda_layer(layer_idx))
        if self.is_linear_attn:
            self.self_attn = KimiDeltaAttention(
                config=config, layer_idx=layer_idx)
        elif config.is_mla:
            self.self_attn = KimiMLAAttention(
                config=config, layer_idx=layer_idx)
        else:
            raise NotImplementedError

        # Dense MLP for the first `first_k_dense_replace` layers (and layers
        # skipped by moe_layer_freq); sparse MoE everywhere else.
        uses_moe = (
            config.num_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % getattr(config, "moe_layer_freq", 1) == 0
        )
        if uses_moe:
            self.block_sparse_moe = KimiSparseMoeBlock(config)
        else:
            self.mlp = KimiMLP(config)

        self.input_layernorm = KimiRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = KimiRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: tuple[torch.Tensor] | None = None,
        output_attentions: bool | None = False,
        use_cache: bool | None = False,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`, *optional*): attention mask; padding
                positions carry very large negative values.
            position_ids / past_key_values / output_attentions / use_cache:
                forwarded to the attention module (linear-attention layers take the
                cache as `cache_params` and ignore `position_ids`).
        """
        # --- attention sub-block (pre-norm residual) ---
        shared = dict(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        if self.is_linear_attn:
            attn_out = self.self_attn(
                cache_params=past_key_values, **shared, **kwargs)
        else:
            attn_out = self.self_attn(
                position_ids=position_ids,
                past_key_values=past_key_values,
                **shared,
                **kwargs,
            )
        hidden_states = hidden_states + attn_out

        # --- feed-forward sub-block (pre-norm residual) ---
        residual = hidden_states
        normed = self.post_attention_layernorm(hidden_states)
        ffn = self.block_sparse_moe if hasattr(
            self, "block_sparse_moe") else self.mlp
        hidden_states = residual + ffn(normed)

        return hidden_states
871
+
872
+
873
class KimiPreTrainedModel(PreTrainedModel):
    """Base class hooking KimiLinear models into the HF `PreTrainedModel`
    machinery (config binding, weight init, flash-attention support flags)."""

    config_class = KimiLinearConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Layers are atomic units for model-parallel placement.
    _no_split_modules = ["KimiDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    # Output recording hooks used by `check_model_inputs`.
    _can_record_outputs = {
        "router_logits": OutputRecorder(KimiBlockSparseMLP, index=1),
        "hidden_states": KimiDecoderLayer,
        "attentions": KimiMLAAttention,
    }
    # Linear-attention layers carry recurrent state between calls.
    _is_stateful = True

    def _init_weights(self, module):
        # Normal(0, initializer_range) for Linear/Embedding weights;
        # zeros for biases and the padding embedding row.
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
897
+
898
+
899
class KimiLinearModel(KimiPreTrainedModel):
    """Decoder-only backbone: token embeddings -> stacked KimiDecoderLayers
    (mixing KDA and MLA attention) -> final RMSNorm."""

    def __init__(self, config: KimiLinearConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([KimiDecoderLayer(
            config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = KimiRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps)

        # Only the flash_attention_2 backend is implemented; any other
        # requested implementation is overridden with a warning.
        if getattr(config, "_attn_implementation", None) is not None:
            if config._attn_implementation != "flash_attention_2":
                logger.warning_once(
                    f"Ignoring the provided attention implementation {config._attn_implementation}")
                logger.warning_once("Using flash_attention_2 backend instead.")
                config._attn_implementation = "flash_attention_2"
        else:
            config._attn_implementation = "flash_attention_2"

        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def _update_linear_attn_mask(self, attention_mask, cache_position):
        """
        NOTE: Left-padding is used for linear attention mask.
        No need for zeroing states when
        1. Cached forward
        2. Attending to all inputs
        """
        linear_attn_mask = attention_mask
        if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
            linear_attn_mask = None
        return linear_attn_mask

    @check_model_inputs
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPast:
        """Embed, run all decoder layers, and apply the final norm.

        Returns a `BaseModelOutputWithPast` with the last hidden state and
        (possibly newly created) `KimiDynamicCache`.
        """

        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # NOTE(review): despite the message, this only checks that at least
        # one of the two inputs is provided, not "exactly one".
        if (input_ids is None) and (inputs_embeds is None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds")

        # Get inputs_embeds
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Lazily create the custom cache that holds both KV and recurrent
        # linear-attention state.
        if use_cache and past_key_values is None:
            past_key_values = KimiDynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length(
            ) if past_key_values is not None else 0
            cache_position: torch.Tensor = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device,
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Full-attention (MLA) layers get a causal mask; linear-attention
        # layers get a simplified 2D padding mask (or None).
        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )
        linear_attn_mask = self._update_linear_attn_mask(
            attention_mask, cache_position)

        hidden_states = inputs_embeds
        if past_key_values is not None:
            assert isinstance(past_key_values, KimiDynamicCache)

        for decoder_layer in self.layers:
            # Pick the mask matching this layer's attention flavour.
            layer_mask = linear_attn_mask if decoder_layer.is_linear_attn else causal_mask

            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=layer_mask,
                past_key_values=past_key_values,
                cache_position=cache_position,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
1006
+
1007
+
1008
class KimiLinearForCausalLM(KimiPreTrainedModel, GenerationMixin):
    """KimiLinear backbone plus a tied LM head for causal language modeling."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = KimiLinearModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(
            config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: list[torch.FloatTensor] | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        generation_mode: bool | None = None,
        return_dict: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | CausalLMOutputWithPast:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            generation_mode (`bool`, *optional*):
                When truthy, only the last position's hidden state is projected
                through the LM head, saving memory during decoding.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, KimiLinearForCausalLM

        >>> model = KimiLinearForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        # `outputs[0]` is the final hidden states; the `logits` name only
        # becomes accurate after the lm_head projection two lines below.
        logits = outputs[0]
        if generation_mode:
            # Decode-time shortcut: project only the last position.
            logits = logits[:, -1:]
        logits = self.lm_head(logits)

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits, labels, self.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
tokenization_kimi_py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tiktoken
3
+
4
+ from logging import getLogger
5
+ from pathlib import Path
6
+ from typing import (
7
+ cast,
8
+ Tuple,
9
+ Dict,
10
+ Iterator,
11
+ List,
12
+ Union,
13
+ Optional,
14
+ )
15
+ from shutil import copyfile
16
+ from tiktoken.load import load_tiktoken_bpe
17
+ from tokenizers import AddedToken, pre_tokenizers, Regex
18
+ from transformers.tokenization_utils import PreTrainedTokenizer
19
+ from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
20
+ from typing import Any
21
+
22
+
23
+ logger = getLogger(__name__)
24
+ VOCAB_FILES_NAMES = {"vocab_file": "tiktoken.model"}
25
+
26
+
27
class TikTokenTokenizer(PreTrainedTokenizer):
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer. See megatron/tokenizer/tiktoken_tokenizer.py.

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            The path to the Tiktoken model file.
        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"[BOS]"`):
            The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"[EOS]"`):
            The end of sequence token.
        unk_token (`str` or `tokenizers.AddedToken`, *optional*):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str` or `tokenizers.AddedToken`, *optional*):
            The token used for padding, for example when batching sequences of different lengths.
        additional_special_tokens (list of `str`, *optional*):
            A tuple or a list of additional tokens, which will be marked as `special`, meaning that they will be
            skipped when decoding if `skip_special_tokens` is set to `True`.
    """

    vocab_files_names = VOCAB_FILES_NAMES

    model_input_names = ["input_ids", "attention_mask"]

    # Mapping of special-token string -> token id, built in __init__.
    special_tokens: Dict[str, int]

    # Number of reserved special-token slots appended after the base vocab.
    num_reserved_special_tokens = 256

    # Pre-tokenization regex (tiktoken splits on these alternatives first);
    # Han characters get their own branch, ahead of the Latin-style rules.
    pat_str = "|".join(
        [
            r"""[\p{Han}]+""",
            r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
            r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
            r"""\p{N}{1,3}""",
            r""" ?[^\s\p{L}\p{N}]+[\r\n]*""",
            r"""\s*[\r\n]+""",
            r"""\s+(?!\S)""",
            r"""\s+""",
        ]
    )

    def __init__(
        self,
        vocab_file,
        bos_token: Union[str, AddedToken]="[BOS]",
        eos_token: Union[str, AddedToken]="[EOS]",
        unk_token: Union[str, AddedToken, None]=None,
        pad_token: Union[str, AddedToken, None]=None,
        additional_special_tokens: List[str]=None,
        added_tokens_decoder: Optional[dict] = None,
        **kwargs,
    ):
        # NOTE(review): `assert` is stripped under `python -O`; a missing
        # vocab file would then surface later in load_tiktoken_bpe.
        assert os.path.isfile(vocab_file), vocab_file

        if additional_special_tokens is None:
            additional_special_tokens = [
                "<|im_end|>",
                "<|im_user|>",
                "<|im_assistant|>",
                "<|start_header_id|>",
                "<|end_header_id|>",
                "[EOT]",
                "<|im_system|>",
                "<|im_middle|>",
            ]

        # id -> content for tokens declared in tokenizer_config.json.
        # NOTE(review): raises TypeError if added_tokens_decoder is None —
        # presumably always supplied by the config; verify against callers.
        special_tokens_mapping = {
            i: added_tokens_decoder[i].content for i in added_tokens_decoder
        }

        self.vocab_file = vocab_file
        mergeable_ranks = load_tiktoken_bpe(vocab_file)
        num_base_tokens = len(mergeable_ranks)
        # Special ids sit directly after the BPE ranks; unnamed slots become
        # "<|reserved_token_N|>".  The "+ 2" presumably accounts for
        # [BOS]/[EOS] on top of the reserved block — TODO confirm.
        self.special_tokens = {
            special_tokens_mapping.get(i, f"<|reserved_token_{i}|>"): i
            for i in range(
                num_base_tokens, num_base_tokens + self.num_reserved_special_tokens + 2
            )
        }

        self.model = tiktoken.Encoding(
            name=Path(vocab_file).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        logger.info(f"Reloaded tiktoken model from {vocab_file}")

        self.n_words: int = self.model.n_vocab
        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens[str(bos_token)]
        self.eos_id: int = self.special_tokens[str(eos_token)]
        logger.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )

        # NOTE(review): with the default pad_token/unk_token of None these
        # lookups use the key "None" and raise KeyError — the shipped
        # tokenizer_config presumably always provides both; verify.
        self.pad_id: int = self.special_tokens[str(pad_token)]
        self.unk_id: int = self.special_tokens[str(unk_token)]

        # GPT-2 style byte<->unicode tables used to expose tokens as strings.
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}

        # id -> printable token string (byte-level representation).
        self.decoder = {}
        for i in range(self.n_words):
            # Taken from https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee
            decoding = ''.join([
                self.byte_encoder[ord(char)] for char in
                self.model.decode_single_token_bytes(i).decode('latin-1')
            ])
            self.decoder[i] = decoding

        # token string -> id (inverse of `decoder`).
        self.encoder = {}
        for i in range(self.n_words):
            if i in self.decoder:
                self.encoder[self.decoder[i]] = i

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        self.all_special_ids_set = set(self.all_special_ids)

    def encode(
        self,
        text: str,
        allow_special_tokens: bool = True,
        **kwargs
    ) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            text (str): The input string to be encoded.
            allow_special_tokens (bool): When True, special-token strings in
                `text` are encoded to their special ids instead of raising.

        Returns:
            list[int]: A list of token IDs.
        """
        # If there are other args, we should call super().encode because there is a lot of code
        # to handle those args. super().encode finally will call _tokenize and _convert_token_to_id.
        # NOTE: our encode method is not compatible with the super().encode method,
        # e.g. split_special_tokens' default is True in our encode method.
        if len(kwargs) > 0:
            logger.warning( f"Calling super().encode with {kwargs}" )
            return super().encode(text, **kwargs)

        assert type(text) is str

        # The tiktoken tokenizer can handle <=400k chars without
        # pyo3_runtime.PanicException.
        TIKTOKEN_MAX_ENCODE_CHARS = 400_000

        # https://github.com/openai/tiktoken/issues/195
        # Here we iterate over subsequences and split if we exceed the limit
        # of max consecutive non-whitespace or whitespace characters.
        MAX_NO_WHITESPACES_CHARS = 25_000

        texts = self.pre_tokenizer_process(text)

        all_substrs = []
        for text in texts:
            substrs = (
                substr
                for i in range(0, len(text), TIKTOKEN_MAX_ENCODE_CHARS)
                for substr in self._split_whitespaces_or_nonwhitespaces(
                    text[i: i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
                )
            )
            all_substrs.extend(substrs)

        t: List[int] = []
        for substr in all_substrs:
            if allow_special_tokens:
                t.extend(
                    # we should consider special token as a common token
                    self.model.encode(
                        substr,
                        allowed_special="all",
                    )
                )
            else:
                t.extend(
                    # we should consider special token as a common token
                    self.model.encode(
                        substr,
                        disallowed_special=(),
                    )
                )

        return t

    def decode(
        self,
        token_ids: Union[int, List[int]],
        **kwargs
    ) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            token_ids (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # If there are other args, we should call super().decode because there is a lot of code
        # to handle those args. super().decode finally will call convert_tokens_to_string and _convert_id_to_token.
        if len(kwargs) > 0:
            return super().decode(token_ids, **kwargs)

        if type(token_ids) is int:
            token_ids = [token_ids]

        return self.model.decode(cast(List[int], token_ids))

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
        s: str, max_consecutive_slice_len: int
    ) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
        consecutive whitespaces or consecutive non-whitespaces.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            # Character class changed (space <-> non-space): reset the run.
            if current_slice_is_space ^ is_now_space:
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]

    def pre_tokenizer_process(self, text: str) -> List[str]:
        """
        pre-tokenizes the input text into a list of tokens.
        This method is used to split the input text into smaller chunks for internal processing.
        Default implementation is a no-op returning the whole text.
        """
        return [text]


    """ ----- Below are the abstract methods required by PreTrainedTokenizer ----- """
    @property
    def vocab_size(self) -> int:
        return self.n_words

    def get_vocab(self) -> Dict[str, int]:
        return self.encoder

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        # Encode to ids, then map ids to byte-level token strings.
        return [
            self.decoder[t]
            for t in self.encode(text)
        ]

    def _convert_token_to_id(self, token: str) -> int:
        return self.encoder.get(token, self.unk_id)

    def _convert_id_to_token(self, index: int) -> str:
        return self.decoder.get(index)

    @staticmethod
    def clean_up_tokenization(out_string: str) -> str:
        # Intentionally a no-op: byte-level decoding already yields clean text.
        return out_string

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        # Reverse the byte-level encoding and decode as UTF-8, replacing
        # any invalid byte sequences.
        text = ''.join(tokens)
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', 'replace')
        return text

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Copy the tiktoken model file into `save_directory` and return its path."""
        if not os.path.isdir(save_directory):
            raise ValueError(f"vocabulary path ({save_directory}) should be a directory")
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file,)



    def apply_chat_template(
        self, conversation, tools: Optional[list[dict]] = None,
        tokenize: bool = False,
        add_generation_prompt: bool = True,
        **kwargs
    ):
        # Canonicalize tool schemas (recursive key sort) so identical tools
        # always render identically in the chat template.
        tools = deep_sort_dict(tools)
        return super().apply_chat_template(conversation,
                                           tools=tools,
                                           tokenize=tokenize,
                                           add_generation_prompt=add_generation_prompt,
                                           **kwargs)
340
+
341
+
342
def deep_sort_dict(obj: Any) -> Any:
    """Return a copy of *obj* with every nested dict's keys in sorted order.

    Lists are rebuilt element-wise so dicts inside them are sorted too; any
    other value (including ``None``) passes through unchanged.
    """
    if isinstance(obj, dict):
        ordered = {}
        for key in sorted(obj):
            ordered[key] = deep_sort_dict(obj[key])
        return ordered
    if isinstance(obj, list):
        return [deep_sort_dict(entry) for entry in obj]
    return obj
tokenizer_config.json ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "163584": {
4
+ "content": "[BOS]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "163585": {
12
+ "content": "[EOS]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "163586": {
20
+ "content": "<|im_end|>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "163587": {
28
+ "content": "<|im_user|>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "163588": {
36
+ "content": "<|im_assistant|>",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ },
43
+ "163590": {
44
+ "content": "<|start_header_id|>",
45
+ "lstrip": false,
46
+ "normalized": false,
47
+ "rstrip": false,
48
+ "single_word": false,
49
+ "special": true
50
+ },
51
+ "163591": {
52
+ "content": "<|end_header_id|>",
53
+ "lstrip": false,
54
+ "normalized": false,
55
+ "rstrip": false,
56
+ "single_word": false,
57
+ "special": true
58
+ },
59
+ "163593": {
60
+ "content": "[EOT]",
61
+ "lstrip": false,
62
+ "normalized": false,
63
+ "rstrip": false,
64
+ "single_word": false,
65
+ "special": true
66
+ },
67
+ "163594": {
68
+ "content": "<|im_system|>",
69
+ "lstrip": false,
70
+ "normalized": false,
71
+ "rstrip": false,
72
+ "single_word": false,
73
+ "special": true
74
+ },
75
+ "163595": {
76
+ "content": "<|tool_calls_section_begin|>",
77
+ "lstrip": false,
78
+ "normalized": false,
79
+ "rstrip": false,
80
+ "single_word": false,
81
+ "special": false
82
+ },
83
+ "163596": {
84
+ "content": "<|tool_calls_section_end|>",
85
+ "lstrip": false,
86
+ "normalized": false,
87
+ "rstrip": false,
88
+ "single_word": false,
89
+ "special": false
90
+ },
91
+ "163597": {
92
+ "content": "<|tool_call_begin|>",
93
+ "lstrip": false,
94
+ "normalized": false,
95
+ "rstrip": false,
96
+ "single_word": false,
97
+ "special": false
98
+ },
99
+ "163598": {
100
+ "content": "<|tool_call_argument_begin|>",
101
+ "lstrip": false,
102
+ "normalized": false,
103
+ "rstrip": false,
104
+ "single_word": false,
105
+ "special": false
106
+ },
107
+ "163599": {
108
+ "content": "<|tool_call_end|>",
109
+ "lstrip": false,
110
+ "normalized": false,
111
+ "rstrip": false,
112
+ "single_word": false,
113
+ "special": false
114
+ },
115
+ "163601": {
116
+ "content": "<|im_middle|>",
117
+ "lstrip": false,
118
+ "normalized": false,
119
+ "rstrip": false,
120
+ "single_word": false,
121
+ "special": true
122
+ },
123
+ "163838": {
124
+ "content": "[UNK]",
125
+ "lstrip": false,
126
+ "normalized": false,
127
+ "rstrip": false,
128
+ "single_word": false,
129
+ "special": true
130
+ },
131
+ "163839": {
132
+ "content": "[PAD]",
133
+ "lstrip": false,
134
+ "normalized": false,
135
+ "rstrip": false,
136
+ "single_word": false,
137
+ "special": true
138
+ }
139
+ },
140
+ "additional_special_tokens": [
141
+ "<|im_end|>",
142
+ "<|im_user|>",
143
+ "<|im_assistant|>",
144
+ "<|start_header_id|>",
145
+ "<|end_header_id|>",
146
+ "[EOT]",
147
+ "<|im_system|>",
148
+ "<|im_middle|>"
149
+ ],
150
+ "bos_token": "[BOS]",
151
+ "clean_up_tokenization_spaces": false,
152
+ "eos_token": "[EOS]",
153
+ "extra_special_tokens": {},
154
+ "model_max_length": 1000000000000000019884624838656,
155
+ "pad_token": "[PAD]",
156
+ "tokenizer_class": "TikTokenTokenizer",
157
+ "unk_token": "[UNK]",
158
+ "auto_map": {
159
+ "AutoTokenizer": [
160
+ "tokenization_kimi.TikTokenTokenizer",
161
+ null
162
+ ]
163
+ }
164
+ }