OpenMOSE commited on
Commit
411272a
·
1 Parent(s): 8f164da

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
__pycache__/test_openai_api.cpython-312.pyc ADDED
Binary file (8.7 kB). View file
 
chat_template.jinja ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [gMASK]<sop>
2
+ {%- if tools -%}
3
+ <|system|>
4
+ # Tools
5
+
6
+ You may call one or more functions to assist with the user query.
7
+
8
+ You are provided with function signatures within <tools></tools> XML tags:
9
+ <tools>
10
+ {% for tool in tools %}
11
+ {{ tool | tojson(ensure_ascii=False) }}
12
+ {% endfor %}
13
+ </tools>
14
+
15
+ For each function call, output the function name and arguments within the following XML format:
16
+ <tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call>{%- endif -%}
17
+ {%- macro visible_text(content) -%}
18
+ {%- if content is string -%}
19
+ {{- content }}
20
+ {%- elif content is iterable and content is not mapping -%}
21
+ {%- for item in content -%}
22
+ {%- if item is mapping and item.type == 'text' -%}
23
+ {{- item.text }}
24
+ {%- elif item is string -%}
25
+ {{- item }}
26
+ {%- endif -%}
27
+ {%- endfor -%}
28
+ {%- else -%}
29
+ {{- content }}
30
+ {%- endif -%}
31
+ {%- endmacro -%}
32
+ {%- set ns = namespace(last_user_index=-1) %}
33
+ {%- for m in messages %}
34
+ {%- if m.role == 'user' %}
35
+ {% set ns.last_user_index = loop.index0 -%}
36
+ {%- endif %}
37
+ {%- endfor %}
38
+ {% for m in messages %}
39
+ {%- if m.role == 'user' -%}<|user|>{{ visible_text(m.content) }}
40
+ {%- elif m.role == 'assistant' -%}
41
+ <|assistant|>
42
+ {%- set reasoning_content = '' %}
43
+ {%- set content = visible_text(m.content) %}
44
+ {%- if m.reasoning_content is string %}
45
+ {%- set reasoning_content = m.reasoning_content %}
46
+ {%- else %}
47
+ {%- if '</think>' in content %}
48
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
49
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
50
+ {%- endif %}
51
+ {%- endif %}
52
+ {%- if ((clear_thinking is defined and not clear_thinking) or loop.index0 > ns.last_user_index) and reasoning_content -%}
53
+ {{ '<think>' + reasoning_content.strip() + '</think>'}}
54
+ {%- else -%}
55
+ {{ '</think>' }}
56
+ {%- endif -%}
57
+ {%- if content.strip() -%}
58
+ {{ content.strip() }}
59
+ {%- endif -%}
60
+ {% if m.tool_calls %}
61
+ {% for tc in m.tool_calls %}
62
+ {%- if tc.function %}
63
+ {%- set tc = tc.function %}
64
+ {%- endif %}
65
+ {{- '<tool_call>' + tc.name -}}
66
+ {% set _args = tc.arguments %}{% for k, v in _args.items() %}<arg_key>{{ k }}</arg_key><arg_value>{{ v | tojson(ensure_ascii=False) if v is not string else v }}</arg_value>{% endfor %}</tool_call>{% endfor %}
67
+ {% endif %}
68
+ {%- elif m.role == 'tool' -%}
69
+ {%- if m.content is string -%}
70
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
71
+ {{- '<|observation|>' }}
72
+ {%- endif %}
73
+ {{- '<tool_response>' }}
74
+ {{- m.content }}
75
+ {{- '</tool_response>' }}
76
+ {%- else -%}
77
+ <|observation|>{% for tr in m.content %}
78
+ <tool_response>{{ tr.output if tr.output is defined else tr }}</tool_response>{% endfor -%}
79
+ {% endif -%}
80
+ {%- elif m.role == 'system' -%}
81
+ <|system|>{{ visible_text(m.content) }}
82
+ {%- endif -%}
83
+ {%- endfor -%}
84
+ {%- if add_generation_prompt -%}
85
+ <|assistant|>{{- '</think>' if (enable_thinking is defined and not enable_thinking) else '<think>' -}}
86
+ {%- endif -%}
config.json ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "RWKV07IMoEForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_rwkv07i.RWKV07IConfig",
7
+ "AutoModelForCausalLM": "modeling_rwkv07i.RWKV07IMoEForCausalLM"
8
+ },
9
+ "description": "Prime-RWKV TICA (Tiny Infused Causal Attention)",
10
+ "base_model": "GLM4.7-Flash",
11
+ "model_revision": "alpha",
12
+ "transformer_layers": [-1],
13
+
14
+ "rwkv_layers": [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46],
15
+ "tiny_attention_layers": [21, 23, 25, 27, 28, 30, 33, 37, 38, 41, 43],
16
+ "rwkv_architecture": "hxa07i",
17
+ "enable_qk_norm": true,
18
+ "nope_in_transformer": true,
19
+ "nope_in_rwkv": false,
20
+ "tiny_head_dim": 128,
21
+ "tiny_n_heads": 4,
22
+ "tiny_kv_heads": 2,
23
+ "lora_rank_decay": 512,
24
+ "lora_rank_iclr": 256,
25
+ "lora_rank_value_residual_mix": -1,
26
+ "lora_rank_key_residual_mix": -1,
27
+ "lora_rank_gate": 384,
28
+ "attention_bias": false,
29
+ "attention_dropout": 0.0,
30
+ "pad_token_id": 154820,
31
+ "eos_token_id": [
32
+ 154820,
33
+ 154827,
34
+ 154829
35
+ ],
36
+ "hidden_act": "silu",
37
+ "hidden_size": 2048,
38
+ "intermediate_size": 10240,
39
+ "max_position_embeddings": 202752,
40
+ "model_type": "rwkv07i_moe",
41
+ "moe_intermediate_size": 1536,
42
+ "topk_method": "noaux_tc",
43
+ "norm_topk_prob": true,
44
+ "num_attention_heads": 20,
45
+ "n_group": 1,
46
+ "topk_group": 1,
47
+ "n_routed_experts": 64,
48
+ "n_shared_experts": 1,
49
+ "routed_scaling_factor": 1.8,
50
+ "num_experts_per_tok": 4,
51
+ "first_k_dense_replace": 1,
52
+ "num_hidden_layers": 47,
53
+ "num_key_value_heads": 20,
54
+ "num_nextn_predict_layers": 1,
55
+ "partial_rotary_factor": 1.0,
56
+ "rms_norm_eps": 1e-05,
57
+ "rope_scaling": null,
58
+ "rope_theta": 1000000,
59
+ "tie_word_embeddings": false,
60
+ "dtype": "bfloat16",
61
+ "transformers_version": "5.0.0",
62
+ "q_lora_rank": 768,
63
+ "kv_lora_rank": 512,
64
+ "qk_nope_head_dim": 192,
65
+ "qk_rope_head_dim": 64,
66
+ "v_head_dim": 256,
67
+ "vocab_size": 154880
68
+ }
configuration_rwkv07i.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """RWKV07I model configuration"""
16
+
17
+ #Never gonna give you up
18
+
19
+ from transformers.configuration_utils import PretrainedConfig, layer_type_validation
20
+ from transformers.modeling_rope_utils import rope_config_validation
21
+ from transformers.utils import logging
22
+
23
+
24
+ from transformers.configuration_utils import PretrainedConfig, layer_type_validation
25
+ #from transformers.modeling_rope_utils import RopeParameters
26
+ from typing import Optional, TypedDict
27
+ #from transformers.modeling_rope_utils import RopeParameters
28
+ class RopeParameters(TypedDict):
29
+ """
30
+ Args:
31
+ rope_theta (`float`):
32
+ The base period of the RoPE embeddings.
33
+ rope_type (`str`, *optional*, defaults to "default"):
34
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
35
+ 'llama3'], with 'default' being the original RoPE implementation.
36
+ factor (`float`, *optional*):
37
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
38
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
39
+ original maximum pre-trained length.
40
+ original_max_position_embeddings (`int`, *optional*):
41
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
42
+ pretraining.
43
+ attention_factor (`float`, *optional*):
44
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
45
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
46
+ `factor` field to infer the suggested value.
47
+ beta_fast (`float`, *optional*):
48
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
49
+ ramp function. If unspecified, it defaults to 32.
50
+ beta_slow (`float`, *optional*):
51
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
52
+ ramp function. If unspecified, it defaults to 1.
53
+ short_factor (`list[float]`, *optional*):
54
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
55
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
56
+ size divided by the number of attention heads divided by 2
57
+ long_factor (`list[float]`, *optional*):
58
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
59
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
60
+ size divided by the number of attention heads divided by 2
61
+ low_freq_factor (`float`, *optional*):
62
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
63
+ high_freq_factor (`float`, *optional*):
64
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
65
+ """
66
+
67
+ rope_theta: float
68
+ rope_type: Optional[str]
69
+ factor: Optional[float]
70
+ original_max_position_embeddings: Optional[int]
71
+ attention_factor: Optional[float]
72
+ beta_fast: Optional[float]
73
+ beta_slow: Optional[float]
74
+ short_factor: Optional[list[float]]
75
+ long_factor: Optional[list[float]]
76
+ low_freq_factor: Optional[float]
77
+ high_freq_factor: Optional[float]
78
+ logger = logging.get_logger(__name__)
79
+
80
+
81
class RWKV07IConfig(PretrainedConfig):
    r"""
    Configuration class for the RWKV07I MoE family ([`RWKV07IMoEForCausalLM`]),
    a hybrid architecture mixing RWKV-7 ("hxa07i") recurrent layers with
    attention layers and a routed mixture-of-experts MLP.

    Instantiating a configuration with the defaults yields the configuration
    shipped with this repository (47 layers, hidden size 2048, 64 routed
    experts). Configuration objects inherit from [`PretrainedConfig`] and can
    be used to control the model outputs. Read the documentation from
    [`PretrainedConfig`] for more information.

    Args:
        lora_rank_tokenshift (`int`, *optional*):
            Rank of the LoRA used for token shift.
        lora_rank_decay (`int`, *optional*):
            Rank of the LoRA used to generate decay.
        lora_rank_iclr (`int`, *optional*):
            Rank of the LoRA used to generate the in-context learning rate.
        lora_rank_value_residual_mix (`int`, *optional*):
            Rank of the LoRA used to generate the value residual mix amount.
        lora_rank_value_key_mix (`int`, *optional*):
            Rank of the LoRA used to generate the key residual mix amount.
        lora_rank_gate (`int`, *optional*):
            Rank of the LoRA used to generate the gate.
        vocab_size (`int`, *optional*, defaults to 154880):
            Vocabulary size of the model; the number of different tokens
            representable by `input_ids`.
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 10240):
            Dimension of the dense MLP representations.
        moe_intermediate_size (`int`, *optional*, defaults to 1536):
            Dimension of each routed expert's MLP.
        num_hidden_layers (`int`, *optional*, defaults to 47):
            Number of hidden layers.
        num_attention_heads (`int`, *optional*, defaults to 20):
            Number of attention heads for each attention layer.
        num_key_value_heads (`int`, *optional*, defaults to 20):
            Number of key/value heads used for Grouped Query Attention. Equal
            to `num_attention_heads` means MHA; 1 means MQA; otherwise GQA.
        n_shared_experts (`int`, *optional*, defaults to 1):
            Number of always-active shared experts.
        n_routed_experts (`int`, *optional*, defaults to 64):
            Number of routed experts.
        routed_scaling_factor (`float`, *optional*, defaults to 1.8):
            Scaling factor applied to the routed experts' output.
        kv_lora_rank (`int`, *optional*, defaults to 512):
            Low-rank size of the key/value compression projection.
        q_lora_rank (`int`, *optional*, defaults to 768):
            Low-rank size of the query compression projection.
        qk_rope_head_dim (`int`, *optional*, defaults to 64):
            Per-head dimension of the rotary part of the query/key.
        v_head_dim (`int`, *optional*, defaults to 256):
            Per-head dimension of the value.
        qk_nope_head_dim (`int`, *optional*, defaults to 192):
            Per-head dimension of the non-rotary part of the query/key.
        n_group (`int`, *optional*, defaults to 1):
            Number of expert groups used during routing.
        topk_group (`int`, *optional*, defaults to 1):
            Number of expert groups selected per token.
        num_experts_per_tok (`int`, *optional*, defaults to 4):
            Number of routed experts activated per token.
        norm_topk_prob (`bool`, *optional*, defaults to `True`):
            Whether to renormalize the top-k routing probabilities.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 202752):
            The maximum sequence length this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Standard deviation of the truncated-normal weight initializer.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            Epsilon used by the RMS normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether the model should return past key/values and states.
        pad_token_id (`int`, *optional*):
            Padding token id; forwarded to [`PretrainedConfig`].
        bos_token_id (`int`, *optional*, defaults to 0):
            Beginning-of-sequence token id; forwarded to [`PretrainedConfig`].
        eos_token_id (`int` or `list[int]`, *optional*, defaults to 1):
            End-of-sequence token id(s); forwarded to [`PretrainedConfig`].
        pretraining_tp (`int`, *optional*, defaults to 1):
            Tensor-parallelism degree used during pretraining.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the input and output word embeddings should be tied.
        rope_parameters (`RopeParameters` or `dict`, *optional*):
            RoPE configuration; see `RopeParameters` for the accepted keys.
        rope_interleave (`bool`, *optional*, defaults to `True`):
            Whether rotary dimensions are interleaved.
        mlp_layer_types (`list[str]`, *optional*):
            Per-layer MLP kind, `"dense"` or `"sparse"`. Defaults to a dense
            first layer followed by sparse (MoE) layers.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias term in the attention projections.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout ratio for the attention probabilities.

    ```python
    >>> from configuration_rwkv07i import RWKV07IConfig

    >>> # Initializing a configuration with the repository defaults
    >>> configuration = RWKV07IConfig()
    ```"""

    model_type = "rwkv07i_moe"

    keys_to_ignore_at_inference = ["past_key_values"]
    base_model_tp_plan = {
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.experts.gate_up_proj": "local_rowwise",
        "layers.*.mlp.experts.down_proj": "local_rowwise",
        "layers.*.mlp.experts": "gather",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }
    attribute_map = {
        "num_local_experts": "n_routed_experts",
    }

    def __init__(
        self,
        lora_rank_tokenshift=None,
        lora_rank_decay=None,
        lora_rank_iclr=None,
        lora_rank_value_residual_mix=None,
        lora_rank_value_key_mix=None,
        lora_rank_gate=None,

        vocab_size: int | None = 154880,
        hidden_size: int | None = 2048,
        intermediate_size: int | None = 10240,
        moe_intermediate_size: int | None = 1536,
        num_hidden_layers: int | None = 47,
        num_attention_heads: int | None = 20,
        num_key_value_heads: int | None = 20,
        n_shared_experts: int | None = 1,
        n_routed_experts: int | None = 64,
        routed_scaling_factor: float | None = 1.8,
        kv_lora_rank: int | None = 512,
        q_lora_rank: int | None = 768,
        qk_rope_head_dim: int | None = 64,
        v_head_dim: int | None = 256,
        qk_nope_head_dim: int | None = 192,
        n_group: int | None = 1,
        topk_group: int | None = 1,
        num_experts_per_tok: int | None = 4,
        norm_topk_prob: bool | None = True,
        hidden_act: str | None = "silu",
        max_position_embeddings: int | None = 202752,
        initializer_range: float | None = 0.02,
        rms_norm_eps: float | None = 1e-5,
        use_cache: bool | None = True,
        pad_token_id: int | None = None,
        bos_token_id: int | None = 0,
        eos_token_id: int | None = 1,
        pretraining_tp: int | None = 1,
        tie_word_embeddings: bool | None = False,
        rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None,
        rope_interleave: bool | None = True,
        mlp_layer_types=None,
        attention_bias: bool | None = False,
        attention_dropout: float | None = 0.0,

        **kwargs,
    ):
        # RWKV LoRA ranks.
        self.lora_rank_tokenshift = lora_rank_tokenshift
        self.lora_rank_decay = lora_rank_decay
        self.lora_rank_iclr = lora_rank_iclr
        self.lora_rank_value_residual_mix = lora_rank_value_residual_mix
        # Bug fix: this parameter was previously accepted but never stored.
        self.lora_rank_value_key_mix = lora_rank_value_key_mix
        self.lora_rank_gate = lora_rank_gate

        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers

        # Default to MoE ("sparse") from the second layer on.
        self.mlp_layer_types = mlp_layer_types
        if self.mlp_layer_types is None:
            self.mlp_layer_types = ["dense"] + ["sparse"] * (self.num_hidden_layers - 1)
        layer_type_validation(self.mlp_layer_types, self.num_hidden_layers, attention=False)

        # Sliding-window attention is hard-disabled here, so every layer is
        # "full_attention". The branch below is kept for forward compatibility;
        # getattr avoids a latent AttributeError since `max_window_layers` is
        # never defined on this config.
        self.sliding_window = None
        self.layer_types = [
            "sliding_attention"
            if self.sliding_window is not None and i >= getattr(self, "max_window_layers", 0)
            else "full_attention"
            for i in range(self.num_hidden_layers)
        ]

        self.moe_intermediate_size = moe_intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.routed_scaling_factor = routed_scaling_factor
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        # Full query/key head dim is the rotary plus non-rotary parts.
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.head_dim = qk_rope_head_dim
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
        self.norm_topk_prob = norm_topk_prob
        self.rope_interleave = rope_interleave
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.rope_parameters = rope_parameters

        # Bug fix: special-token ids and tie_word_embeddings must be forwarded
        # to PretrainedConfig.__init__. Previously they were assigned directly
        # before the super() call, and the base class then overwrote them with
        # its own defaults (e.g. tie_word_embeddings silently became True).
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
323
+
324
+ __all__ = ["RWKV07IConfig"]
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_rwkv07i.py ADDED
@@ -0,0 +1,1528 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PyTorch RWKV07I model.
3
+ base code from SmerkyG @ recursal.ai, featherless.ai
4
+ hxa07i implementation RWKV07I Few-Head-Gated-Attention
5
+
6
+ """
7
+
8
+ import math
9
+ import inspect
10
+ from typing import List, Optional, Tuple, Union, Dict, Any
11
+
12
+ import torch
13
+ import torch.utils.checkpoint
14
+ from torch import nn
15
+ import torch.nn.functional as F
16
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
17
+
18
+ from transformers import initialization as init
19
+ from transformers.activations import ACT2FN
20
+ from transformers.cache_utils import Cache, DynamicCache, CacheLayerMixin
21
+ from transformers.generation import GenerationMixin
22
+ from transformers.integrations import use_kernel_forward_from_hub
23
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
24
+ from transformers.integrations import use_experts_implementation, use_kernel_forward_from_hub, use_kernel_func_from_hub
25
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
26
+
27
+ from transformers.modeling_layers import (
28
+ GenericForQuestionAnswering,
29
+ GenericForSequenceClassification,
30
+ GenericForTokenClassification,
31
+ GradientCheckpointingLayer,
32
+ )
33
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
34
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
35
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
36
+ from transformers.processing_utils import Unpack
37
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
38
+ from transformers.utils.generic import check_model_inputs
39
+ from transformers.utils.generic import check_model_inputs, is_flash_attention_requested, maybe_autocast
40
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
41
+
42
+ from .configuration_rwkv07i import RWKV07IConfig
43
+
44
+ import torch
45
+ from dataclasses import dataclass, field
46
+ from typing import Any, Dict, List, Optional, Tuple
47
+
48
+
49
@dataclass
class LayerCache:
    """
    Per-layer cache container.

    Two independent kinds of cached data may live side by side:

      * ``rwkv_state`` / ``shift_state`` — RWKV recurrent state, updated with
        overwrite semantics (``copy_``) by the owning cache manager.
      * ``attn_key_cache`` / ``attn_value_cache`` — attention KV cache,
        updated with append semantics (concatenation along the sequence dim).

    A pure RWKV layer fills only the state fields, a pure attention layer
    fills only the KV fields, and a hybrid RWKV + small-attention layer
    uses all four.
    """
    rwkv_state: Optional[torch.Tensor] = None
    shift_state: Optional[torch.Tensor] = None
    attn_key_cache: Optional[torch.Tensor] = None
    attn_value_cache: Optional[torch.Tensor] = None

    @property
    def has_rwkv_state(self) -> bool:
        """Whether an RWKV recurrent state is currently stored."""
        return self.rwkv_state is not None

    @property
    def has_kv_cache(self) -> bool:
        """Whether an attention key cache is currently stored."""
        return self.attn_key_cache is not None

    @property
    def kv_seq_length(self) -> int:
        """Number of positions currently held in the KV cache (0 when empty)."""
        keys = self.attn_key_cache
        return 0 if keys is None else keys.size(-2)

    def reset_kv(self):
        """Drop the attention KV cache (e.g. for sliding-window truncation)."""
        self.attn_key_cache = None
        self.attn_value_cache = None

    def reset_rwkv(self):
        """Drop the RWKV recurrent and token-shift states."""
        self.rwkv_state = None
        self.shift_state = None

    def reset(self):
        """Drop everything cached for this layer."""
        self.reset_kv()
        self.reset_rwkv()
96
+
97
+
98
class RWKV07IState:
    """
    Cache manager for RWKV-Attention hybrid models.

    Each layer gets an independent LayerCache that can hold:
      - RWKV recurrent state (overwrite on update)
      - Token shift state (overwrite on update)
      - Attention K cache (append on update)
      - Attention V cache (append on update)

    Token accounting: ``_seen_tokens`` is incremented only by updates on
    layer 0, so exactly one update call per forward step must pass a
    non-zero ``token_count`` for that layer.

    Usage in model forward:
        # Pure RWKV layer
        cache.update_rwkv_state(layer_idx, new_state, new_shift, token_count=seq_len)

        # Pure Attention layer
        cache.update_kv_cache(layer_idx, key, value, token_count=seq_len)

        # Hybrid RWKV layer (RWKV + small Attention)
        cache.update_rwkv_state(layer_idx, new_state, new_shift, token_count=seq_len)
        cache.update_kv_cache(layer_idx, key, value)  # token_count=0, already counted
    """

    def __init__(self) -> None:
        # Total number of tokens processed so far (counted on layer 0 updates).
        self._seen_tokens: int = 0
        # One LayerCache per model layer, grown lazily by _ensure_layer.
        self._layers: List[LayerCache] = []
        # RoPE cache (shared across layers)
        self.sin: List[torch.Tensor] = []
        self.cos: List[torch.Tensor] = []
        self.cumulative_scores: List[torch.Tensor] = []

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    def _ensure_layer(self, layer_idx: int) -> LayerCache:
        """Ensure LayerCache exists for the given index, padding with empties if needed."""
        while layer_idx >= len(self._layers):
            self._layers.append(LayerCache())
        return self._layers[layer_idx]

    # ------------------------------------------------------------------ #
    # Core update methods (separated by semantics & timing)
    # ------------------------------------------------------------------ #

    @torch.no_grad()
    def update_rwkv_state(
        self,
        layer_idx: int,
        rwkv_state: torch.Tensor,
        shift_state: torch.Tensor,
        token_count: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update RWKV recurrent state — **overwrite** semantics.

        Called AFTER the RWKV state recurrence computation.

        Args:
            layer_idx: Index of the layer.
            rwkv_state: New recurrent state tensor.
            shift_state: New token-shift state tensor.
            token_count: Number of new tokens processed.
                Pass seq_len for the first cache-counting call per step,
                pass 0 if another update already counted this step.

        Returns:
            (rwkv_state, shift_state) stored in cache.
        """
        # Only layer 0 contributes to the global token count (see class docstring).
        if layer_idx == 0:
            self._seen_tokens += token_count

        cache = self._ensure_layer(layer_idx)

        if cache.rwkv_state is None:
            # First call — store directly (keeps the caller's tensor as storage).
            cache.rwkv_state = rwkv_state
            cache.shift_state = shift_state
        else:
            # Subsequent calls — overwrite in-place so the cached tensor identity
            # (and any views of it) remains stable across steps.
            cache.rwkv_state.copy_(rwkv_state)
            cache.shift_state.copy_(shift_state)

        return cache.rwkv_state, cache.shift_state

    @torch.no_grad()
    def update_kv_cache(
        self,
        layer_idx: int,
        key: torch.Tensor,
        value: torch.Tensor,
        token_count: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update Attention KV cache — **append** semantics.

        Called AFTER K, V projection.

        Args:
            layer_idx: Index of the layer.
            key: New key tensor, shape (..., new_seq_len, head_dim).
            value: New value tensor, shape (..., new_seq_len, head_dim).
            token_count: Number of new tokens processed.
                Pass seq_len for the first cache-counting call per step,
                pass 0 if another update already counted this step
                (e.g. hybrid layer where update_rwkv_state was called first).

        Returns:
            (full_key_cache, full_value_cache) after concatenation.
        """
        # Only layer 0 contributes to the global token count (see class docstring).
        if layer_idx == 0:
            self._seen_tokens += token_count

        cache = self._ensure_layer(layer_idx)

        if cache.attn_key_cache is None:
            # First call — store directly
            cache.attn_key_cache = key
            cache.attn_value_cache = value
        else:
            # Subsequent calls — append along the sequence dimension (dim=-2).
            cache.attn_key_cache = torch.cat(
                [cache.attn_key_cache, key], dim=-2
            )
            cache.attn_value_cache = torch.cat(
                [cache.attn_value_cache, value], dim=-2
            )

        return cache.attn_key_cache, cache.attn_value_cache

    @torch.no_grad()
    def update(
        self,
        kv_state: torch.Tensor,
        shift_state: torch.Tensor,
        layer_idx: int,
        token_count: int = 0,
        is_attention_layer: bool = True,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Legacy unified update for backward compatibility.

        Dispatches to update_kv_cache or update_rwkv_state based on
        is_attention_layer flag.  Note the parameter reuse: for attention
        layers, ``kv_state`` carries the key and ``shift_state`` the value.

        New code should prefer the explicit update_rwkv_state / update_kv_cache
        methods directly.
        """
        if is_attention_layer:
            return self.update_kv_cache(
                layer_idx=layer_idx,
                key=kv_state,
                value=shift_state,
                token_count=token_count,
            )
        else:
            return self.update_rwkv_state(
                layer_idx=layer_idx,
                rwkv_state=kv_state,
                shift_state=shift_state,
                token_count=token_count,
            )

    # ------------------------------------------------------------------ #
    # Accessors
    # ------------------------------------------------------------------ #

    def get_layer_cache(self, layer_idx: int) -> LayerCache:
        """Get the full LayerCache for a given layer.

        Raises:
            KeyError: if the layer has never been created.
        """
        if layer_idx < len(self._layers):
            return self._layers[layer_idx]
        raise KeyError(
            f"Cache only has {len(self._layers)} layers, "
            f"attempted to access layer with index {layer_idx}"
        )

    def get_rwkv_state(self, layer_idx: int) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Get (rwkv_state, shift_state) for a layer. Returns (None, None) if not set."""
        if layer_idx < len(self._layers):
            c = self._layers[layer_idx]
            return c.rwkv_state, c.shift_state
        return None, None

    def get_kv_cache(self, layer_idx: int) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Get (key_cache, value_cache) for a layer. Returns (None, None) if not set."""
        if layer_idx < len(self._layers):
            c = self._layers[layer_idx]
            return c.attn_key_cache, c.attn_value_cache
        return None, None

    # ------------------------------------------------------------------ #
    # Backward-compatible dict-like interface
    # ------------------------------------------------------------------ #

    def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Backward-compatible indexing: cache[layer_idx] returns
        (layer_kv_states, layer_shift_states).

        For Attention layers this returns (key_cache, value_cache).
        For RWKV layers this returns (rwkv_state, shift_state).
        For hybrid layers, this returns KV cache if present, else RWKV state.

        Prefer get_layer_cache / get_rwkv_state / get_kv_cache for new code.
        """
        if layer_idx >= len(self._layers):
            raise KeyError(
                f"Cache only has {len(self._layers)} layers, "
                f"attempted to access layer with index {layer_idx}"
            )
        c = self._layers[layer_idx]
        # Priority: KV cache (for seq length queries etc.), fallback to RWKV state
        if c.has_kv_cache:
            return c.attn_key_cache, c.attn_value_cache
        return c.rwkv_state, c.shift_state

    def __iter__(self):
        # Yields the same (state, state) pairs as __getitem__, per layer.
        for layer_idx in range(len(self._layers)):
            yield self[layer_idx]

    def __len__(self) -> int:
        return len(self._layers)

    # ------------------------------------------------------------------ #
    # Sequence / shape queries
    # ------------------------------------------------------------------ #

    def get_seq_length(self, layer_idx: int = 0) -> int:
        """Returns the total number of tokens seen so far.

        Note: this is the global counter; ``layer_idx`` is ignored.
        """
        return self._seen_tokens

    def get_kv_seq_length(self, layer_idx: int = 0) -> int:
        """Returns the sequence length stored in KV cache for a specific layer."""
        if layer_idx < len(self._layers):
            return self._layers[layer_idx].kv_seq_length
        return 0

    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """Given the sequence length of new inputs, returns the usable length.
        Linear attention / RWKV layers have no maximum length constraint."""
        return new_seq_length

    def get_max_cache_shape(self) -> Optional[int]:
        # No maximum: cache grows unboundedly (KV) / is constant-size (RWKV).
        return None

    def get_max_length(self) -> Optional[int]:
        # Deprecated alias of get_max_cache_shape in the transformers Cache API.
        return None

    def get_mask_sizes(
        self, cache_position: torch.Tensor, layer_idx: int
    ) -> Tuple[int, int]:
        """Return (kv_length, kv_offset) used to generate the attention mask.

        NOTE(review): adds the current query length on top of _seen_tokens,
        which assumes this is called BEFORE the step's layer-0 update has
        incremented the counter — confirm against the model's forward order.
        """
        kv_offset = 0
        query_length = cache_position.shape[0]
        past_seen_tokens = self.get_seq_length()
        kv_length = query_length + past_seen_tokens
        return kv_length, kv_offset

    # ------------------------------------------------------------------ #
    # Housekeeping
    # ------------------------------------------------------------------ #

    def crop(self, max_length: int):
        """Crop KV caches to max_length. RWKV state is unaffected.

        NOTE(review): ``_seen_tokens`` is not reduced here, so
        get_seq_length() can exceed the cropped KV length — confirm intended.
        """
        for cache in self._layers:
            if cache.has_kv_cache and cache.kv_seq_length > max_length:
                cache.attn_key_cache = cache.attn_key_cache[..., -max_length:, :]
                cache.attn_value_cache = cache.attn_value_cache[..., -max_length:, :]

    def reset_layer(self, layer_idx: int):
        """Clear all cached data for a specific layer."""
        if layer_idx < len(self._layers):
            self._layers[layer_idx].reset()

    def reset(self):
        """Clear the entire cache, including the shared RoPE tables."""
        self._layers.clear()
        self._seen_tokens = 0
        self.sin.clear()
        self.cos.clear()
        self.cumulative_scores.clear()

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """
        Reorder cache for beam search.
        KV caches can be reordered; RWKV states are reordered along batch dim.
        Assumes dim 0 is the batch dimension for all cached tensors.
        """
        for cache in self._layers:
            if cache.has_kv_cache:
                cache.attn_key_cache = cache.attn_key_cache.index_select(0, beam_idx)
                cache.attn_value_cache = cache.attn_value_cache.index_select(0, beam_idx)
            if cache.has_rwkv_state:
                cache.rwkv_state = cache.rwkv_state.index_select(0, beam_idx)
                cache.shift_state = cache.shift_state.index_select(0, beam_idx)

    @property
    def is_compileable(self) -> bool:
        # Advertises torch.compile compatibility to the generation machinery.
        return True

    # ------------------------------------------------------------------ #
    # Debug / inspection
    # ------------------------------------------------------------------ #

    def summary(self) -> str:
        """Human-readable summary of cache contents."""
        lines = [f"HybridCache: {len(self._layers)} layers, {self._seen_tokens} tokens seen"]
        for i, cache in enumerate(self._layers):
            parts = []
            if cache.has_rwkv_state:
                parts.append(f"rwkv_state={list(cache.rwkv_state.shape)}")
                parts.append(f"shift_state={list(cache.shift_state.shape)}")
            if cache.has_kv_cache:
                parts.append(f"key_cache={list(cache.attn_key_cache.shape)}")
                parts.append(f"value_cache={list(cache.attn_value_cache.shape)}")
            if not parts:
                parts.append("empty")
            lines.append(f"  layer {i:3d}: {', '.join(parts)}")
        return "\n".join(lines)
416
+
417
# Optional dependency: flash-linear-attention supplies the fused RWKV-7 kernels
# (chunked prefill + fused recurrent decode).  On ImportError we only print
# install instructions; any later use of chunk_rwkv7 / fused_recurrent_rwkv7
# will then raise NameError.
try:
    from fla.ops.rwkv7.chunk import chunk_rwkv7
    from fla.ops.rwkv7.fused_recurrent import fused_recurrent_rwkv7
except ImportError:
    print("Required module is not installed. Please install it using the following commands:")
    print("pip install --no-use-pep517 flash-linear-attention")
    print("Additionally, ensure you have at least version 2.2.0 of Triton installed:")
    print("pip install triton>=2.2.0")
425
+
426
+
427
+
428
def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    r"""Apply rotary position embedding to q/k stored with interleaved pairs.

    The head dimension holds rotation pairs interleaved as
    (x0, y0, x1, y1, ...); each tensor is first de-interleaved into the
    half-split layout (x0, x1, ..., y0, y1, ...) that ``rotate_half`` expects,
    then the standard RoPE rotation is applied.

    TODO: computing the frequencies in interleaved order directly would avoid
    the view + transpose + reshape round-trip; this path is not optimized.

    Args:
        q: Query tensor of shape (batch, heads, seq, head_dim).
        k: Key tensor of shape (batch, heads, seq, head_dim).
        cos: Cosine part of the rotary embedding.
        sin: Sine part of the rotary embedding.
        position_ids: Unused; kept for signature compatibility with callers
            that pass offset position ids.
        unsqueeze_dim: Dimension along which cos/sin are unsqueezed so they
            broadcast against q/k (1 for the (b, h, s, d) layout, 2 for
            (b, s, h, d)).

    Returns:
        Tuple of the rotated (q_embed, k_embed) tensors.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    def _deinterleave(t):
        # (b, h, s, d) interleaved pairs -> half-split layout along the last dim.
        b, h, s, d = t.shape
        return t.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    q = _deinterleave(q)
    k = _deinterleave(k)

    q_embed = q * cos + rotate_half(q) * sin
    k_embed = k * cos + rotate_half(k) * sin
    return q_embed, k_embed
464
+
465
+
466
def yarn_get_mscale(scale=1, mscale=1):
    """YaRN attention magnitude scaling: 1 + 0.1 * mscale * ln(scale) for
    context-extension factors above 1, and exactly 1.0 otherwise."""
    if scale > 1:
        return 0.1 * mscale * math.log(scale) + 1.0
    return 1.0
470
+
471
+ # def is_layer_attention(config, layer_id):
472
+ # return layer_id >= config.first_attention_layer and layer_id < config.first_post_attention_layer and (layer_id > min(config.num_hidden_layers, config.last_striping_layer) or (min(config.num_hidden_layers-1, config.last_striping_layer) - layer_id) % config.attention_striping == 0)
473
+
474
def is_layer_attention(config, layer_id):
    """Return True when ``layer_id`` is configured as a full-attention
    (transformer) layer, i.e. is listed in ``config.transformer_layers``."""
    attention_layers = config.transformer_layers
    return layer_id in attention_layers
476
+
477
def repeat_kv_rwkv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each KV head ``n_rep`` times along the head axis (GQA expansion).

    Input:  (B, T, H_kv, D)
    Output: (B, T, H_kv * n_rep, D), contiguous; repeats of one head are
    placed consecutively (repeat-interleave semantics).
    """
    if n_rep == 1:
        # Nothing to expand — return the caller's tensor unchanged.
        return hidden_states
    batch, seq_len, num_kv_heads, head_dim = hidden_states.shape
    expanded = hidden_states.unsqueeze(3).expand(
        batch, seq_len, num_kv_heads, n_rep, head_dim
    )
    return expanded.reshape(batch, seq_len, num_kv_heads * n_rep, head_dim).contiguous()
490
+
491
+
492
+
493
class Glm4MoeLiteRotaryEmbedding(nn.Module):
    """Rotary position embedding module producing the (cos, sin) tables
    consumed by the attention layers.  Supports the "default" RoPE plus any
    variant registered in ROPE_INIT_FUNCTIONS (e.g. yarn / dynamic)."""

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: RWKV07IConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        # Pick the frequency initializer from the configured rope_type;
        # non-default types are resolved through ROPE_INIT_FUNCTIONS.
        self.rope_type = self.config.rope_parameters["rope_type"]
        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Non-persistent buffers: recomputed on load, not stored in checkpoints.
        # original_inv_freq keeps an untouched copy for dynamic-rope resets.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: RWKV07IConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
        # Fall back to hidden_size / num_heads when head_dim is absent or falsy.
        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        dim = int(head_dim * partial_rotary_factor)

        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies: base^(-2i/dim) for even indices i.
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return (cos, sin) tables of shape (batch, seq, rotary_dim) in x's dtype.

        ``x`` is only used for its device and dtype; positions come from
        ``position_ids`` of shape (batch, seq).
        """
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # MPS has no autocast context; compute on "cpu" semantics there.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate (not interleave) frequencies — matches rotate_half layout.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
558
+
559
class Glm4MoeLiteMLP(nn.Module):
    """Gated feed-forward block: down_proj( act(gate_proj(x)) * up_proj(x) ).

    Serves both as the dense FFN and, with a custom ``intermediate_size``,
    as the shared-expert branch inside the MoE module.
    """

    def __init__(self, config, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        # Use the config-wide FFN width unless the caller overrides it.
        self.intermediate_size = (
            config.intermediate_size if intermediate_size is None else intermediate_size
        )
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
573
+
574
+
575
class Glm4MoeLiteTopkRouter(nn.Module):
    """Expert router that emits raw per-token routing logits.

    Top-k selection, group limiting and the e_score correction are applied by
    the caller (see Glm4MoeLiteMoE.route_tokens_to_experts); this module only
    owns the routing weight matrix and the correction-bias buffer.
    """

    def __init__(self, config: RWKV07IConfig):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        self.norm_topk_prob = config.norm_topk_prob

        # (n_routed_experts, hidden_size) routing matrix; initialised elsewhere.
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
        # Aux-loss-free load-balancing bias, added to scores at selection time only.
        self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts), dtype=torch.float32))

    def forward(self, hidden_states):
        # Flatten to (num_tokens, hidden) and route in float32 for stability.
        flat_tokens = hidden_states.view(-1, self.config.hidden_size)
        logits = F.linear(flat_tokens.type(torch.float32), self.weight.type(torch.float32))
        return logits
593
+
594
+
595
@use_kernel_forward_from_hub("RMSNorm")
class Glm4MoeLiteRMSNorm(nn.Module):
    """Root-mean-square layer norm with a learned scale (T5LayerNorm
    equivalent): no mean subtraction, no bias."""

    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Normalise in float32, then cast back to the caller's dtype.
        original_dtype = hidden_states.dtype
        x = hidden_states.to(torch.float32)
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
        return self.weight * (x * inv_rms).to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
614
+
615
+
616
@use_experts_implementation
class Glm4MoeLiteNaiveMoe(nn.Module):
    """Collection of expert weights stored as 3D tensors.

    All experts live in two stacked parameters: ``gate_up_proj`` packs the
    gate and up projections of every expert, ``down_proj`` the output
    projections.  The forward pass loops only over experts that actually
    received tokens.
    """

    def __init__(self, config):
        super().__init__()
        # NOTE(review): expert count comes from num_local_experts here, while
        # the router uses n_routed_experts — confirm these match in the config.
        self.num_experts = config.num_local_experts
        self.hidden_dim = config.hidden_size
        self.intermediate_dim = config.moe_intermediate_size
        # (E, 2*I, H): gate and up projections fused per expert.
        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
        # (E, H, I): per-expert output projection.
        self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """Dispatch tokens to their selected experts and sum weighted outputs.

        Args:
            hidden_states: (num_tokens, hidden_dim) flattened token states.
            top_k_index: (num_tokens, top_k) selected expert ids per token.
            top_k_weights: (num_tokens, top_k) routing weights per selection.

        Returns:
            (num_tokens, hidden_dim) combined expert outputs.
        """
        final_hidden_states = torch.zeros_like(hidden_states)
        with torch.no_grad():
            # expert_mask[e, k, t] == 1 iff token t selected expert e in slot k.
            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)
            # Indices of experts that received at least one token this batch.
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

        for expert_idx in expert_hit:
            expert_idx = expert_idx[0]
            # Defensive guard: one_hot(num_classes=self.num_experts) cannot
            # actually emit this index — NOTE(review): appears unreachable.
            if expert_idx == self.num_experts:
                continue
            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[token_idx]
            # Fused gate+up matmul, split into the two SwiGLU halves.
            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
            current_hidden_states = self.act_fn(gate) * up
            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
            # Scale each token's expert output by its routing weight.
            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
            # Scatter-accumulate back into the per-token output buffer.
            final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

        return final_hidden_states
654
+
655
+
656
class Glm4MoeLiteMoE(nn.Module):
    """
    A mixed expert module containing shared experts.

    Combines group-limited top-k routed experts with an always-on shared
    expert branch: output = routed_experts(x) + shared_experts(x).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.experts = Glm4MoeLiteNaiveMoe(config)
        self.gate = Glm4MoeLiteTopkRouter(config)
        # Shared experts are a single dense MLP with n_shared_experts-times
        # the per-expert intermediate width.
        self.shared_experts = Glm4MoeLiteMLP(
            config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
        )
        self.n_routed_experts = config.n_routed_experts
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        self.norm_topk_prob = config.norm_topk_prob
        self.routed_scaling_factor = config.routed_scaling_factor
        self.top_k = config.num_experts_per_tok

    def route_tokens_to_experts(self, router_logits):
        """Group-limited top-k routing (DeepSeek-V3 style).

        Experts are partitioned into n_group groups; only the topk_group
        best-scoring groups remain eligible, then the final top_k experts are
        chosen among them.  The e_score_correction_bias influences selection
        only — returned weights come from the uncorrected sigmoid scores.

        Returns:
            (topk_indices, topk_weights), each of shape (num_tokens, top_k).
        """
        router_logits = router_logits.sigmoid()
        router_logits_for_choice = router_logits + self.gate.e_score_correction_bias
        # Group score = sum of the top-2 expert scores inside each group.
        group_scores = (
            router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
            .topk(2, dim=-1)[0]
            .sum(dim=-1)
        )
        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        # Broadcast the kept-group mask back to per-expert granularity.
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
            .reshape(-1, self.n_routed_experts)
        )
        # Experts outside the kept groups are zeroed out before the final top-k.
        scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0)
        topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
        # Weights are gathered from the *uncorrected* scores.
        topk_weights = router_logits.gather(1, topk_indices)
        if self.norm_topk_prob:
            denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
            topk_weights /= denominator
        topk_weights = topk_weights * self.routed_scaling_factor
        return topk_indices, topk_weights

    def forward(self, hidden_states):
        residuals = hidden_states
        orig_shape = hidden_states.shape
        router_logits = self.gate(hidden_states)
        topk_indices, topk_weights = self.route_tokens_to_experts(router_logits)
        # Experts operate on a flat (num_tokens, hidden) view.
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape)
        # Shared experts always run, on the pre-routing activations.
        hidden_states = hidden_states + self.shared_experts(residuals)
        return hidden_states
710
+
711
+
712
+
713
def rms_norm(hidden_states, eps = 1e-6):
    """Weight-free RMS normalisation over the last dimension.

    Computes in float32 and casts back to the input dtype.
    """
    original_dtype = hidden_states.dtype
    x = hidden_states.to(torch.float32)
    inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return (x * inv_rms).to(original_dtype)
720
+
721
def generate_rotary_embedding(max_seqlen:int, dim:int, theta:float = 10000.0, scale:float = 1):
    """Precompute RoPE cos/sin tables.

    Frequencies run from 1.0 down to 1/theta, optionally divided by ``scale``,
    and are duplicated (concatenated, not interleaved) to match the
    rotate_half convention used elsewhere in this file.

    Returns:
        Tensor of shape (2, max_seqlen, dim): index 0 holds cos, index 1 sin.
    """
    exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
    angular_velocity = theta ** -exponents / scale
    angles = torch.outer(torch.arange(max_seqlen), angular_velocity)
    # Same permutation trick as the paper's equivalent formulation.
    table = torch.cat((angles, angles), dim=-1)
    return torch.stack((table.cos(), table.sin()), dim=0)
730
+
731
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
732
def rotate_half(x):
    """Rotates half the hidden dims of the input: (x1, x2) -> (-x2, x1),
    where x1/x2 are the two halves of the last dimension."""
    half = x.shape[-1] // 2
    first_half = x[..., :half]
    second_half = x[..., half:]
    return torch.cat((-second_half, first_half), dim=-1)
737
+
738
+
739
def apply_rotary_pos_emb_single(x, cos, sin, unsqueeze_dim=1):
    """Apply RoPE to a single tensor (query or key) instead of a q/k pair."""
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    return x * cos_b + rotate_half(x) * sin_b
741
+
742
+ from typing import Callable, Optional, Tuple, Union
743
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
744
+ from transformers.processing_utils import Unpack
745
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
746
+
747
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep): expands
    (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_attention_heads, seqlen, head_dim) by repeating each KV head
    ``n_rep`` times consecutively.
    """
    if n_rep == 1:
        # No GQA expansion required.
        return hidden_states
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    expanded = hidden_states[:, :, None].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
757
+
758
+
759
+
760
+
761
+
762
+
763
+
764
+ class RWKV07I_Attention(nn.Module):
765
+ # this supports
766
+ # Prime-RWKV mode(FHGA Assist)
767
+ # Effecient-RWKV mode
768
+
769
    def __init__(self, config: RWKV07IConfig, layer_idx: int):
        """Build one hybrid RWKV/attention layer.

        Creates MLA-style low-rank Q/KV projections (receptance_*/keyvalue_*),
        the RWKV-7 LoRA parameters (decay w*, in-context learning rate a*,
        gate g*), and — for layers listed in config.tiny_attention_layers —
        an auxiliary small GQA attention branch.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # GQA: how many query heads share one KV head.
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.attention_dropout = config.attention_dropout
        self.num_heads = config.num_attention_heads

        # MLA dimensions: low-rank Q/KV compression plus split rope/no-rope
        # portions of the per-head query/key dimension.
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.qk_head_dim = config.qk_head_dim

        self.is_causal = True
        # Query projection: direct, or low-rank (a_proj -> norm -> b_proj)
        # when q_lora_rank is configured.
        if self.q_lora_rank is None:
            self.receptance_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
        else:
            self.receptance_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
            self.receptance_a_layernorm = Glm4MoeLiteRMSNorm(config.q_lora_rank, eps=config.rms_norm_eps)
            self.receptance_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)

        # Compressed KV projection; also emits the shared (multi-query) rope
        # portion of the key, hence "with_mqa".
        self.keyvalue_a_proj_with_mqa = nn.Linear(
            config.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.keyvalue_a_layernorm = Glm4MoeLiteRMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        # Expands the compressed KV latent into per-head no-rope key + value.
        self.keyvalue_b_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.output_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            config.hidden_size,
            bias=config.attention_bias,
        )

        # Standard 1/sqrt(d) attention scaling.
        self.scaling = self.qk_head_dim ** (-0.5)

        # RWKV-7 LoRA ranks for the decay (w), in-context LR (a) and gate (g).
        lora_rank_decay = config.lora_rank_decay
        lora_rank_iclr = config.lora_rank_iclr
        lora_rank_gate = config.lora_rank_gate

        # The RWKV branch runs with doubled head count and halved head dim
        # relative to the attention projections (see forward: H = heads*2,
        # N = qk_head_dim // 2).
        H = self.num_heads * 2
        N = self.qk_head_dim // 2

        # Decay: base offset w0 plus tanh-LoRA (w1, w2).
        self.w0 = nn.Parameter(torch.empty(1, 1, H * N))
        self.w1 = nn.Parameter(torch.empty(config.hidden_size, lora_rank_decay))
        self.w2 = nn.Parameter(torch.empty(lora_rank_decay, H * N))

        # In-context learning rate: base offset a0 plus LoRA (a1, a2).
        self.a0 = nn.Parameter(torch.empty(1, 1, H * N))
        self.a1 = nn.Parameter(torch.empty(config.hidden_size, lora_rank_iclr))
        self.a2 = nn.Parameter(torch.empty(lora_rank_iclr, H * N))

        # Output gate LoRA (no base offset).
        self.g1 = nn.Parameter(torch.empty(config.hidden_size, lora_rank_gate))
        self.g2 = nn.Parameter(torch.empty(lora_rank_gate, H * N))

        # ---- Tiny Attention (optional per layer) ----
        # Small auxiliary GQA attention attached to selected RWKV layers.
        self.TinyAttention = layer_idx in config.tiny_attention_layers
        if self.TinyAttention:
            self.tiny_n_heads = config.tiny_n_heads      # e.g. 4
            self.tiny_head_dim = config.tiny_head_dim    # e.g. 128
            self.tiny_kv_heads = config.tiny_kv_heads    # e.g. 2
            self.tiny_kv_groups = self.tiny_n_heads // self.tiny_kv_heads

            self.tiny_q_proj = nn.Linear(config.hidden_size, self.tiny_n_heads * self.tiny_head_dim, bias=False)
            self.tiny_k_proj = nn.Linear(config.hidden_size, self.tiny_kv_heads * self.tiny_head_dim, bias=False)
            self.tiny_v_proj = nn.Linear(config.hidden_size, self.tiny_kv_heads * self.tiny_head_dim, bias=False)
            self.tiny_o_proj = nn.Linear(self.tiny_n_heads * self.tiny_head_dim, config.hidden_size, bias=False)

            # Per-head QK norms for the tiny branch.
            self.tiny_q_norm = Glm4MoeLiteRMSNorm(self.tiny_head_dim, eps=config.rms_norm_eps)
            self.tiny_k_norm = Glm4MoeLiteRMSNorm(self.tiny_head_dim, eps=config.rms_norm_eps)

            # (Disabled) near-zero init so tiny-attention starts as a small
            # perturbation — currently left to the default initializer:
            # nn.init.zeros_(self.tiny_o_proj.weight)
            # nn.init.xavier_uniform_(self.tiny_q_proj.weight)
            # nn.init.xavier_uniform_(self.tiny_k_proj.weight)
            # nn.init.xavier_uniform_(self.tiny_v_proj.weight)

            # LoRA gate for mixing the tiny-attention output back in;
            # tiny_g1 starts at zero so the gate is initially inactive.
            D_tiny_GATE = 128
            self.tiny_g1 = nn.Parameter(torch.zeros(config.hidden_size, D_tiny_GATE))
            self.tiny_g2 = nn.Parameter(torch.randn(D_tiny_GATE, self.tiny_n_heads * self.tiny_head_dim) * 0.01)
859
+
860
+
861
+ def forward(
862
+ self,
863
+ hidden_states: torch.Tensor,
864
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
865
+ attention_mask: torch.Tensor | None,
866
+ past_key_values: Cache | None = None,
867
+ cache_position: torch.LongTensor | None = None,
868
+ output_attentions: Optional[bool] = False,
869
+ use_cache: bool = True,
870
+ v_first: Optional[torch.Tensor] = None,
871
+ k_first: Optional[torch.Tensor] = None,
872
+ **kwargs: Unpack[FlashAttentionKwargs],
873
+ ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
874
+ batch_size, seq_length = hidden_states.shape[:-1]
875
+ query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
876
+ key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)
877
+
878
+ self.head_dim = self.qk_head_dim
879
+
880
+ output_shift_state = hidden_states[:, -1:].detach().clone()
881
+
882
+ x = hidden_states
883
+
884
+ B, T, C = hidden_states.shape
885
+
886
+ #RWKV Block uses 128
887
+ H = self.num_heads * 2
888
+ N = self.head_dim // 2
889
+
890
+ # ============================================================ #
891
+ # 1. Retrieve cached RWKV state (if any)
892
+ # ============================================================ #
893
+
894
+ input_vk_state = None
895
+ input_shift_state = None
896
+
897
+ if use_cache and past_key_values is not None and len(past_key_values) > self.layer_idx:
898
+ input_vk_state, input_shift_state = past_key_values.get_rwkv_state(self.layer_idx)
899
+
900
+ if input_vk_state is None:
901
+ input_vk_state = torch.zeros(
902
+ B, H, N, N,
903
+ dtype=torch.bfloat16, device=hidden_states.device
904
+ )
905
+ if input_shift_state is None:
906
+ input_shift_state = torch.zeros_like(hidden_states[:, -1:])
907
+
908
+ xr = xw = xk = xa = xg = xe = x
909
+
910
+ if self.q_lora_rank is None:
911
+ q_states = self.receptance_proj(hidden_states)
912
+ else:
913
+ q_states = self.receptance_b_proj(self.receptance_a_layernorm(self.receptance_a_proj(xr)))
914
+ q_states = q_states.view(query_shape).transpose(1, 2)
915
+ q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
916
+
917
+ compressed_kv = self.keyvalue_a_proj_with_mqa(xk)
918
+ k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
919
+
920
+ k_pass = self.keyvalue_b_proj(self.keyvalue_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
921
+ k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
922
+
923
+ k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
924
+
925
+ cos, sin = position_embeddings
926
+ if self.config.nope_in_rwkv == False:
927
+ if self.config.rope_interleave: # support using interleaved weights for efficiency
928
+ q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
929
+ else:
930
+ q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
931
+ k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
932
+
933
+ query_states = torch.cat((q_pass, q_rot), dim=-1)
934
+ key_states = torch.cat((k_pass, k_rot), dim=-1)
935
+
936
+ query_states = query_states.transpose(1, 2).contiguous()
937
+ key_states = key_states.transpose(1, 2).contiguous()
938
+ value_states = value_states.transpose(1, 2).contiguous()
939
+
940
+
941
+ q_len = T
942
+
943
+ r = query_states.reshape(B,T,-1)
944
+ k = key_states.reshape(B,T,-1)
945
+ v = value_states.reshape(B,T,-1)
946
+
947
+ log_neglog_w = -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) -0.5
948
+ a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2)
949
+ g = torch.sigmoid(xg @ self.g1) @ self.g2
950
+
951
+
952
+
953
+ kk = (k).view(B,T,H,-1).float()
954
+ kk = (kk / (torch.norm(kk, dim=-1, keepdim=True) + 1e-12)).view(B,T,-1).to(k.dtype)
955
+
956
+ w = (-log_neglog_w.float().exp()).exp()
957
+ k = k * (1.0 - w + a).to(dtype=torch.bfloat16)
958
+
959
+ aa = -kk
960
+ bb = kk * a
961
+ w = -log_neglog_w.float().exp()
962
+
963
+ r_,w_,k_,v_,aa_,bb_ = [i.view(B,T,H,N) for i in [r,w,k,v,aa,bb]]
964
+
965
+ if attention_mask is not None:
966
+ if attention_mask is not None:
967
+ if attention_mask.ndim == 2:
968
+ # [B, S]
969
+ mask = attention_mask[:, -T:] # [B, T]
970
+ v_ = v_ * mask[:, :, None, None] # → [B, T, 1, 1] に拡張して掛け算
971
+ elif attention_mask.ndim == 4:
972
+ # [B, 1, L, S]
973
+ mask = attention_mask[:, 0, -1, -T:] # [B, T]
974
+ v_ = v_ * mask[:, :, None, None] # 同上
975
+
976
+ x, output_vk_state = fused_recurrent_rwkv7(r_, w_, k_, v_, aa_, bb_, scale=1.0, initial_state=input_vk_state, output_final_state=True, head_first=False)
977
+
978
+ x = x.view(B,T,-1) * (float(N) ** -0.5)
979
+
980
+ if past_key_values is not None:
981
+ past_key_values.update_rwkv_state(
982
+ layer_idx=self.layer_idx,
983
+ rwkv_state=output_vk_state,
984
+ shift_state=output_shift_state,
985
+ token_count=T, # count tokens here
986
+ )
987
+
988
+ tiny_out = None
989
+ if self.TinyAttention:
990
+ # --- Q projection (always from current hidden_states) ---
991
+ mq = self.tiny_q_norm(
992
+ self.tiny_q_proj(hidden_states).view(B, T, self.tiny_n_heads, self.tiny_head_dim)
993
+ ).transpose(1, 2) # (B, n_heads, T, head_dim)
994
+
995
+ # --- K, V projection (current step) ---
996
+ mk_new = self.tiny_k_norm(
997
+ self.tiny_k_proj(hidden_states).view(B, T, self.tiny_kv_heads, self.tiny_head_dim)
998
+ ).transpose(1, 2) # (B, kv_heads, T, head_dim)
999
+
1000
+ mv_new = self.tiny_v_proj(hidden_states).view(
1001
+ B, T, self.tiny_kv_heads, self.tiny_head_dim
1002
+ ).transpose(1, 2) # (B, kv_heads, T, head_dim)
1003
+
1004
+ # ---- Update KV cache (append, AFTER projection) ----
1005
+ if past_key_values is not None:
1006
+ mk_full, mv_full = past_key_values.update_kv_cache(
1007
+ layer_idx=self.layer_idx,
1008
+ key=mk_new,
1009
+ value=mv_new,
1010
+ token_count=0, # already counted in update_rwkv_state
1011
+ )
1012
+ else:
1013
+ mk_full, mv_full = mk_new, mv_new
1014
+
1015
+ # GQA expand: (B, kv_heads, S, D) → (B, n_heads, S, D)
1016
+ S = mk_full.size(2) # full sequence length including past
1017
+ mk_expanded = mk_full[:, :, None, :, :].expand(
1018
+ B, self.tiny_kv_heads, self.tiny_kv_groups, S, self.tiny_head_dim
1019
+ ).reshape(B, self.tiny_n_heads, S, self.tiny_head_dim)
1020
+
1021
+ mv_expanded = mv_full[:, :, None, :, :].expand(
1022
+ B, self.tiny_kv_heads, self.tiny_kv_groups, S, self.tiny_head_dim
1023
+ ).reshape(B, self.tiny_n_heads, S, self.tiny_head_dim)
1024
+
1025
+ # SDPA — use is_causal only during prefill (T > 1)
1026
+ # During generation (T == 1), no causal mask needed
1027
+ # When past KV exists and T > 1, need explicit causal mask
1028
+ if T == 1:
1029
+ # Single-token generation: attend to all cached positions
1030
+ tiny_out = F.scaled_dot_product_attention(
1031
+ mq, mk_expanded, mv_expanded,
1032
+ is_causal=False,
1033
+ dropout_p=self.attention_dropout if self.training else 0.0,
1034
+ )
1035
+ elif past_key_values is None or S == T:
1036
+ # Prefill (no past cache, or first forward): standard causal
1037
+ tiny_out = F.scaled_dot_product_attention(
1038
+ mq, mk_expanded, mv_expanded,
1039
+ is_causal=True,
1040
+ dropout_p=self.attention_dropout if self.training else 0.0,
1041
+ )
1042
+ else:
1043
+ # Chunked prefill with existing cache:
1044
+ # build explicit causal mask allowing full attention to past
1045
+ # Q positions: [past_len, past_len + T)
1046
+ # K positions: [0, past_len + T)
1047
+ past_len = S - T
1048
+ causal_mask = torch.ones(T, S, dtype=torch.bool, device=hidden_states.device).tril(diagonal=past_len)
1049
+ # SDPA expects (B, n_heads, T, S) or broadcastable
1050
+ attn_mask = causal_mask.unsqueeze(0).unsqueeze(0) # (1, 1, T, S)
1051
+ tiny_out = F.scaled_dot_product_attention(
1052
+ mq, mk_expanded, mv_expanded,
1053
+ attn_mask=attn_mask,
1054
+ dropout_p=self.attention_dropout if self.training else 0.0,
1055
+ )
1056
+
1057
+ # (B, n_heads, T, head_dim) → (B, T, n_heads * head_dim)
1058
+ tiny_out = tiny_out.transpose(1, 2).contiguous().view(B, T, -1)
1059
+
1060
+ # Gated mixing via LoRA
1061
+ tiny_gate = torch.sigmoid(hidden_states @ self.tiny_g1) @ self.tiny_g2
1062
+ tiny_out = self.tiny_o_proj(tiny_out * tiny_gate)
1063
+
1064
+ if tiny_out is not None:
1065
+ x = self.output_proj(x * g) + tiny_out
1066
+ else:
1067
+ x = self.output_proj(x * g)
1068
+
1069
+ return x, v_first, k_first
1070
+
1071
+
1072
+
1073
+ class RWKV07IDecoderLayer(nn.Module):
1074
+ def __init__(self, config: RWKV07IConfig, layer_idx: int):
1075
+ super().__init__()
1076
+ self.hidden_size = config.hidden_size
1077
+
1078
+ self.layer_idx = layer_idx
1079
+
1080
+
1081
+ if layer_idx in config.tiny_attention_layers:
1082
+ print(f'layer {layer_idx} : Prime-RWKV')
1083
+ else:
1084
+ print(f'layer {layer_idx} : Effecient-RWKV')
1085
+
1086
+ att_fn = RWKV07I_Attention
1087
+
1088
+ self.self_attn = att_fn(config, layer_idx)
1089
+
1090
+ #Qwen Variant
1091
+ # if (layer_idx not in config.mlp_only_layers) and (
1092
+ # config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
1093
+ # ):
1094
+ # self.mlp = Qwen3MoeSparseMoeBlock(config)
1095
+ # else:
1096
+ # self.mlp = Qwen3MLP(config)
1097
+
1098
+ #GLM4-Lite variant
1099
+ if config.mlp_layer_types[layer_idx] == "sparse":
1100
+ self.mlp = Glm4MoeLiteMoE(config)
1101
+ print(f'Sparse MoE Mode = {layer_idx}')
1102
+ else:
1103
+ self.mlp = Glm4MoeLiteMLP(config)
1104
+ print(f'Dense MoE Mode = {layer_idx}')
1105
+
1106
+ self.input_layernorm = Glm4MoeLiteRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1107
+ self.post_attention_layernorm = Glm4MoeLiteRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1108
+ self.attention_type = config.layer_types[layer_idx]
1109
+
1110
+ def forward(
1111
+ self,
1112
+ hidden_states: torch.Tensor,
1113
+ frozen_residual: torch.Tensor,
1114
+ v_first: Optional[torch.Tensor],
1115
+ k_first: Optional[torch.Tensor],
1116
+ attention_mask: Optional[torch.Tensor] = None,
1117
+ position_ids: Optional[torch.LongTensor] = None,
1118
+ past_key_values: Optional[Cache] = None,
1119
+ output_attentions: Optional[bool] = False,
1120
+ use_cache: Optional[bool] = True,
1121
+ cache_position: Optional[torch.LongTensor] = None,
1122
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
1123
+ **kwargs,
1124
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1125
+ """
1126
+ Args:
1127
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
1128
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
1129
+ `(batch, sequence_length)` where padding elements are indicated by 0.
1130
+ output_attentions (`bool`, *optional*):
1131
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1132
+ returned tensors for more detail.
1133
+ output_router_logits (`bool`, *optional*):
1134
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss,
1135
+ and should not be returned during inference.
1136
+ use_cache (`bool`, *optional*):
1137
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1138
+ (see `past_key_values`).
1139
+ past_key_values (`Cache`, *optional*): cached past key and value projection states
1140
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
1141
+ Indices depicting the position of the input sequence tokens in the sequence.
1142
+ position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
1143
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
1144
+ with `head_dim` being the embedding dimension of each attention head.
1145
+ kwargs (`dict`, *optional*):
1146
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
1147
+ into the model
1148
+ """
1149
+ residual = hidden_states
1150
+
1151
+ hidden_states = self.input_layernorm(hidden_states)
1152
+
1153
+ # Self Attention
1154
+ hidden_states, v_first, k_first = self.self_attn(
1155
+ hidden_states=hidden_states,
1156
+ frozen_residual=frozen_residual,
1157
+ v_first=v_first,
1158
+ k_first=k_first,
1159
+ attention_mask=attention_mask,
1160
+ position_ids=position_ids,
1161
+ past_key_values=past_key_values,
1162
+ output_attentions=output_attentions,
1163
+ use_cache=use_cache,
1164
+ cache_position=cache_position,
1165
+ position_embeddings=position_embeddings,
1166
+ **kwargs,
1167
+ #is_causal=True,
1168
+ )
1169
+ hidden_states = residual + hidden_states
1170
+
1171
+ # Fully Connected
1172
+ residual = hidden_states
1173
+ hidden_states = self.post_attention_layernorm(hidden_states)
1174
+ hidden_states = self.mlp(hidden_states)
1175
+ # For the MoE layers, we need to unpack
1176
+ if isinstance(hidden_states, tuple):
1177
+ hidden_states, _ = hidden_states
1178
+ hidden_states = residual + hidden_states
1179
+
1180
+ outputs = (hidden_states, v_first,k_first,)
1181
+
1182
+ return outputs
1183
+
1184
+
1185
+ @auto_docstring
1186
+ class RWKV07IPreTrainedModel(PreTrainedModel):
1187
+ config: RWKV07IConfig
1188
+ config_class = RWKV07IConfig
1189
+ base_model_prefix = "model"
1190
+ supports_gradient_checkpointing = False
1191
+ _no_split_modules = ["RWKV07IDecoderLayer"]
1192
+ _skip_keys_device_placement = "past_key_values"
1193
+ _supports_flash_attn_2 = True
1194
+ _supports_sdpa = True
1195
+ _supports_flex_attn = True
1196
+
1197
+ _supports_cache_class = True
1198
+ _supports_quantized_cache = True
1199
+ _supports_static_cache = True
1200
+
1201
+ _can_compile_fullgraph = (
1202
+ is_grouped_mm_available()
1203
+ ) # https://huggingface.co/docs/transformers/experts_interface#torchcompile
1204
+ _supports_attention_backend = True
1205
+ _can_record_outputs = {
1206
+ "hidden_states": RWKV07IDecoderLayer,
1207
+ #"attentions": Glm4MoeLiteAttention,
1208
+ }
1209
+ _keep_in_fp32_modules_strict = ["e_score_correction_bias"]
1210
+
1211
+ # @torch.no_grad()
1212
+ # def _init_weights(self, module):
1213
+ # super()._init_weights(module)
1214
+ # if isinstance(module, Glm4MoeLiteTopkRouter):
1215
+ # init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
1216
+ # init.zeros_(module.e_score_correction_bias)
1217
+ # elif isinstance(module, Glm4MoeLiteNaiveMoe):
1218
+ # init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
1219
+ # init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
1220
+
1221
+ # def _init_weights(self, module):
1222
+ # std = self.config.initializer_range
1223
+ # if isinstance(module, nn.Linear):
1224
+ # module.weight.data.normal_(mean=0.0, std=std)
1225
+ # if module.bias is not None:
1226
+ # module.bias.data.zero_()
1227
+ # elif isinstance(module, nn.Embedding):
1228
+ # module.weight.data.normal_(mean=0.0, std=std)
1229
+ # if module.padding_idx is not None:
1230
+ # module.weight.data[module.padding_idx].zero_()
1231
+
1232
+
1233
+
1234
+ @auto_docstring
1235
+ class RWKV07IModel(RWKV07IPreTrainedModel):
1236
+ """
1237
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen3DecoderLayer`]
1238
+
1239
+ Args:
1240
+ config: RWKV07EConfig
1241
+ """
1242
+
1243
+ def __init__(self, config: RWKV07IConfig):
1244
+ super().__init__(config)
1245
+ self.padding_idx = config.pad_token_id
1246
+ self.vocab_size = config.vocab_size
1247
+
1248
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1249
+ self.layers = nn.ModuleList(
1250
+ [RWKV07IDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1251
+ )
1252
+ self.norm = Glm4MoeLiteRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1253
+ self.rotary_emb = Glm4MoeLiteRotaryEmbedding(config=config)
1254
+ self.gradient_checkpointing = False
1255
+ self.has_sliding_layers = "sliding_attention" in self.config.layer_types
1256
+
1257
+ # Initialize weights and apply final processing
1258
+ self.post_init()
1259
+
1260
+ @check_model_inputs
1261
+ @auto_docstring
1262
+ def forward(
1263
+ self,
1264
+ input_ids: Optional[torch.LongTensor] = None,
1265
+ attention_mask: Optional[torch.Tensor] = None,
1266
+ position_ids: Optional[torch.LongTensor] = None,
1267
+ past_key_values: Optional[Cache] = None,
1268
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1269
+ use_cache: Optional[bool] = None,
1270
+ output_attentions: Optional[bool] = None,
1271
+ output_hidden_states: Optional[bool] = None,
1272
+ cache_position: Optional[torch.LongTensor] = None,
1273
+ **kwargs: Unpack[TransformersKwargs],
1274
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1275
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1276
+ output_hidden_states = (
1277
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1278
+ )
1279
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1280
+
1281
+ if (input_ids is None) ^ (inputs_embeds is not None):
1282
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1283
+
1284
+ if self.gradient_checkpointing and self.training and use_cache:
1285
+ logger.warning_once(
1286
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1287
+ )
1288
+ use_cache = False
1289
+
1290
+ if inputs_embeds is None:
1291
+ inputs_embeds = self.embed_tokens(input_ids)
1292
+
1293
+ if use_cache and not isinstance(past_key_values, RWKV07IState):
1294
+ past_key_values = RWKV07IState()
1295
+
1296
+ if cache_position is None:
1297
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1298
+ cache_position = torch.arange(
1299
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1300
+ )
1301
+
1302
+ if position_ids is None:
1303
+ position_ids = cache_position.unsqueeze(0)
1304
+
1305
+ # It may already have been prepared by e.g. `generate`
1306
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
1307
+ # Prepare mask arguments
1308
+ mask_kwargs = {
1309
+ "config": self.config,
1310
+ "input_embeds": inputs_embeds,
1311
+ "attention_mask": attention_mask,
1312
+ "cache_position": cache_position,
1313
+ "past_key_values": past_key_values,
1314
+ "position_ids": position_ids,
1315
+ }
1316
+ # Create the masks
1317
+ causal_mask_mapping = {
1318
+ "full_attention": create_causal_mask(**mask_kwargs),
1319
+ }
1320
+ # The sliding window alternating layers are not always activated depending on the config
1321
+ if self.has_sliding_layers:
1322
+ causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
1323
+
1324
+ hidden_states = inputs_embeds
1325
+
1326
+ # create position embeddings to be shared across the decoder layers
1327
+ #if self.config.use_rope:
1328
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
1329
+ # else:
1330
+ # position_embeddings = None
1331
+
1332
+ # decoder layers
1333
+ all_hidden_states = () if output_hidden_states else None
1334
+ all_self_attns = () if output_attentions else None
1335
+ next_decoder_cache = None
1336
+ v_first = None
1337
+ k_first = None
1338
+ frozen_residual = None
1339
+
1340
+ for decoder_layer in self.layers:
1341
+ if not is_layer_attention(self.config, decoder_layer.layer_idx):
1342
+ frozen_residual = hidden_states#rms_norm(hidden_states)
1343
+ if output_hidden_states:
1344
+ all_hidden_states += (hidden_states,)
1345
+
1346
+ attention_mask = causal_mask_mapping[decoder_layer.attention_type]
1347
+ if attention_mask is not None and attention_mask.ndim == 1:
1348
+ attention_mask = None
1349
+ #attention_mask = None
1350
+
1351
+ layer_outputs = decoder_layer(
1352
+ hidden_states,
1353
+ frozen_residual=frozen_residual,
1354
+ attention_mask=attention_mask,
1355
+ position_ids=position_ids,
1356
+ past_key_values=past_key_values,
1357
+ output_attentions=output_attentions,
1358
+ use_cache=use_cache,
1359
+ cache_position=cache_position,
1360
+ position_embeddings=position_embeddings,
1361
+ v_first=v_first,
1362
+ k_first=k_first
1363
+ )
1364
+
1365
+ hidden_states = layer_outputs[0]
1366
+ v_first = layer_outputs[1]
1367
+ k_first = layer_outputs[2]
1368
+
1369
+ if output_attentions:
1370
+ all_self_attns += (layer_outputs[2],)
1371
+
1372
+ hidden_states = self.norm(hidden_states)
1373
+
1374
+ # add hidden states from the last decoder layer
1375
+ if output_hidden_states:
1376
+ all_hidden_states += (hidden_states,)
1377
+
1378
+ #if return_legacy_cache:
1379
+ # next_cache = next_cache.to_legacy_cache()
1380
+
1381
+ return BaseModelOutputWithPast(
1382
+ last_hidden_state=hidden_states,
1383
+ past_key_values=past_key_values if use_cache else None,
1384
+ hidden_states=all_hidden_states,
1385
+ attentions=all_self_attns,
1386
+ )
1387
+
1388
+ class RWKV07IMoEForCausalLM(RWKV07IPreTrainedModel, GenerationMixin):
1389
+
1390
+ _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
1391
+ _tp_plan = {"lm_head": "colwise_rep"}
1392
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
1393
+
1394
+ def __init__(self, config):
1395
+ super().__init__(config)
1396
+ self.model = RWKV07IModel(config)
1397
+ self.vocab_size = config.vocab_size
1398
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1399
+
1400
+ # Initialize weights and apply final processing
1401
+ self.post_init()
1402
+
1403
+ @can_return_tuple
1404
+ @auto_docstring
1405
+ def forward(
1406
+ self,
1407
+ input_ids: torch.LongTensor = None,
1408
+ attention_mask: Optional[torch.Tensor] = None,
1409
+ position_ids: Optional[torch.LongTensor] = None,
1410
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1411
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1412
+ labels: Optional[torch.LongTensor] = None,
1413
+ use_cache: Optional[bool] = True,
1414
+ output_attentions: Optional[bool] = None,
1415
+ output_hidden_states: Optional[bool] = None,
1416
+ cache_position: Optional[torch.LongTensor] = None,
1417
+ logits_to_keep: Union[int, torch.Tensor] = 0,
1418
+ **loss_kwargs,
1419
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1420
+ r"""
1421
+ Args:
1422
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1423
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1424
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1425
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1426
+
1427
+ num_logits_to_keep (`int`, *optional*):
1428
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
1429
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
1430
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
1431
+
1432
+ Returns:
1433
+
1434
+ Example:
1435
+
1436
+ ```python
1437
+ >>> from transformers import AutoTokenizer, RWKV07EQwen3ForCausalLM
1438
+
1439
+ >>> model = RWKV07EQwen3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1440
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1441
+
1442
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1443
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1444
+
1445
+ >>> # Generate
1446
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1447
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1448
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1449
+ ```"""
1450
+
1451
+ # # run the prefill only up to the last token, then run one more for the actual result
1452
+ # # we do this so that called code doesn't have to handle the dichotomy specially and can just check for L==1
1453
+ # for i in range(2):
1454
+ # all_but_one = max(1, input_ids.size(-1)-1)
1455
+ # iid = input_ids[..., i*all_but_one:(i+1)*all_but_one]
1456
+ # if iid.size(-1) == 0:
1457
+ # continue
1458
+ # pids = position_ids
1459
+ # if pids is not None:
1460
+ # pids = position_ids[..., i*all_but_one:(i+1)*all_but_one]
1461
+ # cp = cache_position
1462
+ # if cp is not None:
1463
+ # cp = cache_position[..., i*all_but_one:(i+1)*all_but_one]
1464
+ # rv = self.forward_inner(iid, attention_mask=attention_mask, position_ids=pids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, cache_position=cp, num_logits_to_keep=num_logits_to_keep, **loss_kwargs)
1465
+ # past_key_values = rv.past_key_values
1466
+ # return rv
1467
+
1468
+ # def forward_inner(
1469
+ # self,
1470
+ # input_ids: torch.LongTensor = None,
1471
+ # attention_mask: Optional[torch.Tensor] = None,
1472
+ # position_ids: Optional[torch.LongTensor] = None,
1473
+ # past_key_values: Optional[List[torch.FloatTensor]] = None,
1474
+ # inputs_embeds: Optional[torch.FloatTensor] = None,
1475
+ # labels: Optional[torch.LongTensor] = None,
1476
+ # use_cache: Optional[bool] = None,
1477
+ # output_attentions: Optional[bool] = None,
1478
+ # output_hidden_states: Optional[bool] = None,
1479
+ # cache_position: Optional[torch.LongTensor] = None,
1480
+ # num_logits_to_keep: int = 0,
1481
+ # **loss_kwargs,
1482
+ # ) -> Union[Tuple, CausalLMOutputWithPast]:
1483
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1484
+ output_hidden_states = (
1485
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1486
+ )
1487
+
1488
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1489
+ outputs = self.model(
1490
+ input_ids=input_ids,
1491
+ attention_mask=attention_mask,
1492
+ position_ids=position_ids,
1493
+ past_key_values=past_key_values,
1494
+ inputs_embeds=inputs_embeds,
1495
+ use_cache=use_cache,
1496
+ output_attentions=output_attentions,
1497
+ output_hidden_states=output_hidden_states,
1498
+ cache_position=cache_position,
1499
+ )
1500
+
1501
+ hidden_states = outputs.last_hidden_state
1502
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1503
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1504
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
1505
+
1506
+ loss = None
1507
+ if labels is not None:
1508
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size, **loss_kwargs)
1509
+
1510
+ return CausalLMOutputWithPast(
1511
+ loss=loss,
1512
+ logits=logits,
1513
+ past_key_values=outputs.past_key_values,
1514
+ hidden_states=outputs.hidden_states,
1515
+ attentions=outputs.attentions,
1516
+ )
1517
+
1518
+ @auto_docstring
1519
+ class RWKV07IQwen3ForSequenceClassification(RWKV07IPreTrainedModel):
1520
+ pass
1521
+
1522
+ @auto_docstring
1523
+ class RWKV07IQwen3ForTokenClassification(RWKV07IPreTrainedModel):
1524
+ pass
1525
+
1526
+ @auto_docstring
1527
+ class RWKV07IQwen3ForQuestionAnswering(RWKV07IPreTrainedModel):
1528
+ base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model`
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ transformers==5.0.0
2
+ flash-linear-attention
3
+ fastapi
4
+ uvicorn
teacher.txt ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ hf (pretrained=/workspace/llm/GLM-4.7-Flash,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 32
2
+ | Tasks |Version|Filter|n-shot|Metric| |Value | |Stderr|
3
+ |---------------------------------------|------:|------|-----:|------|---|-----:|---|-----:|
4
+ |mmlu | 2|none | |acc |↑ |0.7072|± |0.0035|
5
+ | - humanities | 2|none | |acc |↑ |0.6113|± |0.0063|
6
+ | - formal_logic | 1|none | 0|acc |↑ |0.5873|± |0.0440|
7
+ | - high_school_european_history | 1|none | 0|acc |↑ |0.7939|± |0.0316|
8
+ | - high_school_us_history | 1|none | 0|acc |↑ |0.8725|± |0.0234|
9
+ | - high_school_world_history | 1|none | 0|acc |↑ |0.8692|± |0.0219|
10
+ | - international_law | 1|none | 0|acc |↑ |0.8512|± |0.0325|
11
+ | - jurisprudence | 1|none | 0|acc |↑ |0.7870|± |0.0396|
12
+ | - logical_fallacies | 1|none | 0|acc |↑ |0.8160|± |0.0304|
13
+ | - moral_disputes | 1|none | 0|acc |↑ |0.7630|± |0.0229|
14
+ | - moral_scenarios | 1|none | 0|acc |↑ |0.2380|± |0.0142|
15
+ | - philosophy | 1|none | 0|acc |↑ |0.8006|± |0.0227|
16
+ | - prehistory | 1|none | 0|acc |↑ |0.8210|± |0.0213|
17
+ | - professional_law | 1|none | 0|acc |↑ |0.5385|± |0.0127|
18
+ | - world_religions | 1|none | 0|acc |↑ |0.8655|± |0.0262|
19
+ | - other | 2|none | |acc |↑ |0.7757|± |0.0071|
20
+ | - business_ethics | 1|none | 0|acc |↑ |0.7300|± |0.0446|
21
+ | - clinical_knowledge | 1|none | 0|acc |↑ |0.7585|± |0.0263|
22
+ | - college_medicine | 1|none | 0|acc |↑ |0.7514|± |0.0330|
23
+ | - global_facts | 1|none | 0|acc |↑ |0.3700|± |0.0485|
24
+ | - human_aging | 1|none | 0|acc |↑ |0.7444|± |0.0293|
25
+ | - management | 1|none | 0|acc |↑ |0.8641|± |0.0339|
26
+ | - marketing | 1|none | 0|acc |↑ |0.9017|± |0.0195|
27
+ | - medical_genetics | 1|none | 0|acc |↑ |0.8600|± |0.0349|
28
+ | - miscellaneous | 1|none | 0|acc |↑ |0.8710|± |0.0120|
29
+ | - nutrition | 1|none | 0|acc |↑ |0.8072|± |0.0226|
30
+ | - professional_accounting | 1|none | 0|acc |↑ |0.5780|± |0.0295|
31
+ | - professional_medicine | 1|none | 0|acc |↑ |0.8493|± |0.0217|
32
+ | - virology | 1|none | 0|acc |↑ |0.5663|± |0.0386|
33
+ | - social sciences | 2|none | |acc |↑ |0.8203|± |0.0068|
34
+ | - econometrics | 1|none | 0|acc |↑ |0.6228|± |0.0456|
35
+ | - high_school_geography | 1|none | 0|acc |↑ |0.8333|± |0.0266|
36
+ | - high_school_government_and_politics| 1|none | 0|acc |↑ |0.9223|± |0.0193|
37
+ | - high_school_macroeconomics | 1|none | 0|acc |↑ |0.7641|± |0.0215|
38
+ | - high_school_microeconomics | 1|none | 0|acc |↑ |0.8782|± |0.0212|
39
+ | - high_school_psychology | 1|none | 0|acc |↑ |0.9064|± |0.0125|
40
+ | - human_sexuality | 1|none | 0|acc |↑ |0.8550|± |0.0309|
41
+ | - professional_psychology | 1|none | 0|acc |↑ |0.7778|± |0.0168|
42
+ | - public_relations | 1|none | 0|acc |↑ |0.7091|± |0.0435|
43
+ | - security_studies | 1|none | 0|acc |↑ |0.7510|± |0.0277|
44
+ | - sociology | 1|none | 0|acc |↑ |0.8507|± |0.0252|
45
+ | - us_foreign_policy | 1|none | 0|acc |↑ |0.8800|± |0.0327|
46
+ | - stem | 2|none | |acc |↑ |0.6727|± |0.0080|
47
+ | - abstract_algebra | 1|none | 0|acc |↑ |0.5400|± |0.0501|
48
+ | - anatomy | 1|none | 0|acc |↑ |0.7037|± |0.0394|
49
+ | - astronomy | 1|none | 0|acc |↑ |0.8289|± |0.0306|
50
+ | - college_biology | 1|none | 0|acc |↑ |0.8264|± |0.0317|
51
+ | - college_chemistry | 1|none | 0|acc |↑ |0.5600|± |0.0499|
52
+ | - college_computer_science | 1|none | 0|acc |↑ |0.6400|± |0.0482|
53
+ | - college_mathematics | 1|none | 0|acc |↑ |0.4700|± |0.0502|
54
+ | - college_physics | 1|none | 0|acc |↑ |0.5490|± |0.0495|
55
+ | - computer_security | 1|none | 0|acc |↑ |0.7900|± |0.0409|
56
+ | - conceptual_physics | 1|none | 0|acc |↑ |0.8255|± |0.0248|
57
+ | - electrical_engineering | 1|none | 0|acc |↑ |0.7310|± |0.0370|
58
+ | - elementary_mathematics | 1|none | 0|acc |↑ |0.6270|± |0.0249|
59
+ | - high_school_biology | 1|none | 0|acc |↑ |0.8677|± |0.0193|
60
+ | - high_school_chemistry | 1|none | 0|acc |↑ |0.6946|± |0.0324|
61
+ | - high_school_computer_science | 1|none | 0|acc |↑ |0.8300|± |0.0378|
62
+ | - high_school_mathematics | 1|none | 0|acc |↑ |0.4148|± |0.0300|
63
+ | - high_school_physics | 1|none | 0|acc |↑ |0.5894|± |0.0402|
64
+ | - high_school_statistics | 1|none | 0|acc |↑ |0.6111|± |0.0332|
65
+ | - machine_learning | 1|none | 0|acc |↑ |0.5536|± |0.0472|
66
+
67
+ | Groups |Version|Filter|n-shot|Metric| |Value | |Stderr|
68
+ |------------------|------:|------|------|------|---|-----:|---|-----:|
69
+ |mmlu | 2|none | |acc |↑ |0.7072|± |0.0035|
70
+ | - humanities | 2|none | |acc |↑ |0.6113|± |0.0063|
71
+ | - other | 2|none | |acc |↑ |0.7757|± |0.0071|
72
+ | - social sciences| 2|none | |acc |↑ |0.8203|± |0.0068|
73
+ | - stem | 2|none | |acc |↑ |0.6727|± |0.0080|
74
+
75
+ hf (pretrained=/workspace/llm/GLM-4.7-Flash,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 32
76
+ |Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
77
+ |-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
78
+ |gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.8264|± |0.0104|
79
+ | | |strict-match | 5|exact_match|↑ |0.8271|± |0.0104|
test.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
3
+ #quantization_config = BitsAndBytesConfig(load_in_8bit=True)
4
+
5
+ MODEL_PATH = "/workspace/output/glm4_7_30b/hf_temp_07i/"
6
+ #MODEL_PATH = "/workspace/llm/GLM-4.7-Flash/"
7
+
8
+ messages = [{"role": "user", "content": "who is rick astley?"}]
9
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH
10
+ ,torch_dtype="auto",
11
+ device_map="auto",
12
+ trust_remote_code=True,
13
+ #quantization_config=quantization_config
14
+ )
15
+ inputs = tokenizer.apply_chat_template(
16
+ messages,
17
+ tokenize=True,
18
+ add_generation_prompt=True,
19
+ return_dict=True,
20
+ enable_thinking=False,
21
+
22
+ return_tensors="pt",
23
+ )
24
+
25
+ print(type(tokenizer))
26
+ print("chat_template is None?", tokenizer.chat_template is None)
27
+ print("chat_template head:\n", (tokenizer.chat_template or "")[:400])
28
+
29
+ print(inputs)
30
+
31
+
32
+ print('---------------------------')
33
+ print(tokenizer.decode(inputs['input_ids']))
34
+ #exit()
35
+ model = AutoModelForCausalLM.from_pretrained(
36
+ pretrained_model_name_or_path=MODEL_PATH,
37
+ torch_dtype=torch.bfloat16,
38
+ device_map="auto",
39
+ trust_remote_code=True,
40
+ )
41
+ inputs = inputs.to(model.device)
42
+ generated_ids = model.generate(**inputs, max_new_tokens=256,use_cache=True, do_sample=True)
43
+ output_text = tokenizer.decode(generated_ids[0][inputs.input_ids.shape[1]:])
44
+
45
+ print('--------------------------------------------------------------------------------------')
46
+ print(output_text)
47
+
test_client_api.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import json
3
+ import time
4
+ import sys
5
+
6
+ BASE_URL = "http://localhost:8000/v1"
7
+ MODEL_NAME = "RWKV-GLM-4.7-Flash-Preview-v0.1"
8
+
9
+ # ==========================================================
10
+ # Utility
11
+ # ==========================================================
12
+ def print_section(title):
13
+ print("\n" + "=" * 60)
14
+ print(title)
15
+ print("=" * 60)
16
+
17
+
18
+ def safe_json(resp):
19
+ try:
20
+ return resp.json()
21
+ except:
22
+ print("❌ JSON decode failed")
23
+ print(resp.text)
24
+ sys.exit(1)
25
+
26
+
27
+ # ==========================================================
28
+ # 1️⃣ Models API
29
+ # ==========================================================
30
+ def test_models():
31
+ print_section("TEST: /v1/models")
32
+
33
+ resp = requests.get(f"{BASE_URL}/models")
34
+ assert resp.status_code == 200, "Models API failed"
35
+
36
+ data = safe_json(resp)
37
+
38
+ assert "data" in data, "No model list returned"
39
+ assert len(data["data"]) > 0, "Empty model list"
40
+
41
+ print("✅ Models endpoint OK")
42
+ print("Available models:", [m["id"] for m in data["data"]])
43
+
44
+
45
+ # ==========================================================
46
+ # 2️⃣ Non-stream basic
47
+ # ==========================================================
48
+ def test_basic_completion():
49
+ print_section("TEST: Basic Non-Streaming Completion")
50
+
51
+ payload = {
52
+ "model": MODEL_NAME,
53
+ "messages": [{"role": "user", "content": "Say hello."}],
54
+ "max_tokens": 30,
55
+ "stream": False
56
+ }
57
+
58
+ resp = requests.post(
59
+ f"{BASE_URL}/chat/completions",
60
+ headers={"Content-Type": "application/json"},
61
+ data=json.dumps(payload)
62
+ )
63
+
64
+ assert resp.status_code == 200, "Completion failed"
65
+
66
+ data = safe_json(resp)
67
+
68
+ assert "choices" in data, "No choices returned"
69
+ assert "usage" in data, "No usage returned"
70
+
71
+ print("Assistant:", data["choices"][0]["message"]["content"])
72
+ print("Usage:", data["usage"])
73
+ print("✅ Basic completion OK")
74
+
75
+
76
+ # ==========================================================
77
+ # 3️⃣ Streaming
78
+ # ==========================================================
79
+ def test_streaming():
80
+ print_section("TEST: Streaming Completion")
81
+
82
+ payload = {
83
+ "model": MODEL_NAME,
84
+ "messages": [{"role": "user", "content": "Count from 1 to 5."}],
85
+ "max_tokens": 50,
86
+ "stream": True
87
+ }
88
+
89
+ full_text = ""
90
+
91
+ with requests.post(
92
+ f"{BASE_URL}/chat/completions",
93
+ headers={"Content-Type": "application/json"},
94
+ data=json.dumps(payload),
95
+ stream=True
96
+ ) as resp:
97
+
98
+ assert resp.status_code == 200, "Streaming failed"
99
+
100
+ for line in resp.iter_lines():
101
+ if line:
102
+ decoded = line.decode("utf-8")
103
+
104
+ if decoded.startswith("data: "):
105
+ content = decoded[len("data: "):]
106
+
107
+ if content == "[DONE]":
108
+ break
109
+
110
+ chunk = json.loads(content)
111
+ delta = chunk["choices"][0]["delta"]
112
+
113
+ if "content" in delta:
114
+ print(delta["content"], end="", flush=True)
115
+ full_text += delta["content"]
116
+
117
+ print("\n\n✅ Streaming OK")
118
+ assert len(full_text) > 0, "Streaming returned empty"
119
+
120
+
121
+ # ==========================================================
122
+ # 4️⃣ Sampling Variations
123
+ # ==========================================================
124
+ def test_sampling_variations():
125
+ print_section("TEST: Sampling Variations")
126
+
127
+ base_payload = {
128
+ "model": MODEL_NAME,
129
+ "messages": [{"role": "user", "content": "Write a creative sentence about AI."}],
130
+ "max_tokens": 50,
131
+ "stream": False
132
+ }
133
+
134
+ configs = [
135
+ {"temperature": 0.0},
136
+ {"temperature": 0.7},
137
+ {"top_p": 0.8},
138
+ {"top_k": 20},
139
+ {"repetition_penalty": 1.2},
140
+ {"presence_penalty": 0.5},
141
+ {"frequency_penalty": 0.5}
142
+ ]
143
+
144
+ for cfg in configs:
145
+ payload = base_payload.copy()
146
+ payload.update(cfg)
147
+
148
+ resp = requests.post(
149
+ f"{BASE_URL}/chat/completions",
150
+ headers={"Content-Type": "application/json"},
151
+ data=json.dumps(payload)
152
+ )
153
+
154
+ assert resp.status_code == 200, f"Sampling failed: {cfg}"
155
+
156
+ data = safe_json(resp)
157
+
158
+ text = data["choices"][0]["message"]["content"]
159
+
160
+ print(f"\nConfig: {cfg}")
161
+ print("Output:", text[:120], "...")
162
+
163
+ print("\n✅ Sampling parameter variations OK")
164
+
165
+
166
+ # ==========================================================
167
+ # 5️⃣ Deterministic Check (temperature=0)
168
+ # ==========================================================
169
+ def test_deterministic():
170
+ print_section("TEST: Deterministic Mode (temperature=0)")
171
+
172
+ payload = {
173
+ "model": MODEL_NAME,
174
+ "messages": [{"role": "user", "content": "Define gravity in one sentence."}],
175
+ "temperature": 0.0,
176
+ "max_tokens": 50,
177
+ "stream": False
178
+ }
179
+
180
+ resp1 = requests.post(f"{BASE_URL}/chat/completions",
181
+ headers={"Content-Type": "application/json"},
182
+ data=json.dumps(payload))
183
+ resp2 = requests.post(f"{BASE_URL}/chat/completions",
184
+ headers={"Content-Type": "application/json"},
185
+ data=json.dumps(payload))
186
+
187
+ out1 = safe_json(resp1)["choices"][0]["message"]["content"]
188
+ out2 = safe_json(resp2)["choices"][0]["message"]["content"]
189
+
190
+ print("Run1:", out1)
191
+ print("Run2:", out2)
192
+
193
+ assert out1 == out2, "❌ Deterministic mode not deterministic"
194
+ print("✅ Deterministic check OK")
195
+
196
+
197
+ # ==========================================================
198
+ # 6️⃣ Error Handling
199
+ # ==========================================================
200
+ def test_error_handling():
201
+ print_section("TEST: Error Handling")
202
+
203
+ payload = {
204
+ "model": MODEL_NAME,
205
+ # missing messages intentionally
206
+ }
207
+
208
+ resp = requests.post(
209
+ f"{BASE_URL}/chat/completions",
210
+ headers={"Content-Type": "application/json"},
211
+ data=json.dumps(payload)
212
+ )
213
+
214
+ if resp.status_code != 200:
215
+ print("✅ Server correctly handled bad request")
216
+ else:
217
+ print("⚠️ Warning: server did not reject bad request")
218
+
219
+
220
+ # ==========================================================
221
+ # Main
222
+ # ==========================================================
223
+ if __name__ == "__main__":
224
+ start = time.time()
225
+
226
+ test_models()
227
+ test_basic_completion()
228
+ test_streaming()
229
+ test_sampling_variations()
230
+ test_deterministic()
231
+ test_error_handling()
232
+
233
+ print_section("ALL TESTS PASSED")
234
+ print(f"Total time: {round(time.time() - start, 2)} sec")
test_openai_api.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import json
3
+ import uuid
4
+ import torch
5
+ from threading import Thread, Event
6
+ from fastapi import FastAPI, Request
7
+ from fastapi.responses import StreamingResponse
8
+ from transformers import (
9
+ AutoModelForCausalLM,
10
+ AutoTokenizer,
11
+ TextIteratorStreamer,
12
+ LogitsProcessor,
13
+ LogitsProcessorList,
14
+ StoppingCriteria,
15
+ StoppingCriteriaList,
16
+ )
17
+
18
+ # ==========================================================
19
+ # 設定
20
+ # ==========================================================
21
+ MODEL_ID = "/workspace/output/glm4_7_30b/hf_temp_07i"
22
+ VIEW_NAME = "RWKV-GLM-4.7-Flash"
23
+ HOST = "0.0.0.0"
24
+ PORT = 8000
25
+
26
+ # ==========================================================
27
+ # モデルロード
28
+ # ==========================================================
29
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
30
+ model = AutoModelForCausalLM.from_pretrained(
31
+ MODEL_ID,
32
+ torch_dtype=torch.bfloat16,
33
+ device_map="auto",
34
+ trust_remote_code=True,
35
+ )
36
+
37
+ app = FastAPI()
38
+
39
+
40
+ # ==========================================================
41
+ # Logits Processors
42
+ # ==========================================================
43
+ class PresencePenaltyProcessor(LogitsProcessor):
44
+ def __init__(self, penalty):
45
+ self.penalty = penalty
46
+
47
+ def __call__(self, input_ids, scores):
48
+ for batch_idx in range(input_ids.shape[0]):
49
+ unique_tokens = torch.unique(input_ids[batch_idx])
50
+ scores[batch_idx, unique_tokens] -= self.penalty
51
+ return scores
52
+
53
+
54
+ class FrequencyPenaltyProcessor(LogitsProcessor):
55
+ def __init__(self, penalty):
56
+ self.penalty = penalty
57
+
58
+ def __call__(self, input_ids, scores):
59
+ for batch_idx in range(input_ids.shape[0]):
60
+ token_counts = torch.bincount(
61
+ input_ids[batch_idx], minlength=scores.shape[-1]
62
+ )
63
+ scores[batch_idx] -= token_counts * self.penalty
64
+ return scores
65
+
66
+
67
+ # ==========================================================
68
+ # Cancellable Stopping Criteria
69
+ # ==========================================================
70
+ class CancelledStoppingCriteria(StoppingCriteria):
71
+ """threading.Event がセットされたら生成を打ち切る"""
72
+
73
+ def __init__(self, stop_event: Event):
74
+ self.stop_event = stop_event
75
+
76
+ def __call__(self, input_ids, scores, **kwargs):
77
+ return self.stop_event.is_set()
78
+
79
+
80
+ # ==========================================================
81
+ # Models Endpoint
82
+ # ==========================================================
83
+ @app.get("/v1/models")
84
+ async def list_models():
85
+ return {
86
+ "object": "list",
87
+ "data": [
88
+ {
89
+ "id": VIEW_NAME,
90
+ "object": "model",
91
+ "created": int(time.time()),
92
+ "owned_by": "local",
93
+ }
94
+ ],
95
+ }
96
+
97
+
98
+ # ==========================================================
99
+ # Chat Completions Endpoint
100
+ # ==========================================================
101
+ @app.post("/v1/chat/completions")
102
+ async def chat_completions(request: Request):
103
+ body = await request.json()
104
+
105
+ model_name = body.get("model", MODEL_ID)
106
+ messages = body["messages"]
107
+ stream = body.get("stream", False)
108
+
109
+ temperature = body.get("temperature", 1.0)
110
+ top_p = body.get("top_p", 1.0)
111
+ top_k = body.get("top_k", 50)
112
+ repetition_penalty = body.get("repetition_penalty", 1.0)
113
+ presence_penalty = body.get("presence_penalty", 0.0)
114
+ frequency_penalty = body.get("frequency_penalty", 0.0)
115
+ max_tokens = body.get("max_tokens", 2048)
116
+
117
+ prompt = tokenizer.apply_chat_template(
118
+ messages, tokenize=False, add_generation_prompt=True#,enable_thinking=False
119
+ )
120
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
121
+
122
+ processors = LogitsProcessorList()
123
+ if presence_penalty > 0:
124
+ processors.append(PresencePenaltyProcessor(presence_penalty))
125
+ if frequency_penalty > 0:
126
+ processors.append(FrequencyPenaltyProcessor(frequency_penalty))
127
+
128
+ generate_kwargs = dict(
129
+ **inputs,
130
+ max_new_tokens=max_tokens,
131
+ temperature=temperature,
132
+ top_p=top_p,
133
+ top_k=top_k,
134
+ repetition_penalty=repetition_penalty,
135
+ logits_processor=processors,
136
+ do_sample=temperature > 0,
137
+ use_cache=True, # 生成時に明示的に有効化
138
+ )
139
+
140
+ # ================= Non-stream =================
141
+ if not stream:
142
+ outputs = model.generate(**generate_kwargs)
143
+ completion_tokens = outputs.shape[1] - inputs["input_ids"].shape[1]
144
+ generated_text = tokenizer.decode(
145
+ outputs[0][inputs["input_ids"].shape[1] :], skip_special_tokens=False
146
+ )
147
+ return {
148
+ "id": f"chatcmpl-{uuid.uuid4().hex}",
149
+ "object": "chat.completion",
150
+ "created": int(time.time()),
151
+ "model": model_name,
152
+ "choices": [
153
+ {
154
+ "index": 0,
155
+ "message": {"role": "assistant", "content": generated_text},
156
+ "finish_reason": "stop",
157
+ }
158
+ ],
159
+ "usage": {
160
+ "prompt_tokens": inputs["input_ids"].shape[1],
161
+ "completion_tokens": completion_tokens,
162
+ "total_tokens": inputs["input_ids"].shape[1] + completion_tokens,
163
+ },
164
+ }
165
+
166
+ # ================= Streaming =================
167
+ stop_event = Event()
168
+
169
+ stopping_criteria = StoppingCriteriaList(
170
+ [CancelledStoppingCriteria(stop_event)]
171
+ )
172
+
173
+ streamer = TextIteratorStreamer(
174
+ tokenizer, skip_prompt=True, skip_special_tokens=True
175
+ )
176
+
177
+ generation_kwargs = dict(
178
+ **generate_kwargs,
179
+ streamer=streamer,
180
+ stopping_criteria=stopping_criteria,
181
+ )
182
+
183
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
184
+ thread.start()
185
+
186
+ async def event_generator():
187
+ completion_id = f"chatcmpl-{uuid.uuid4().hex}"
188
+ firsttime = "<think>"
189
+ cancelled = False
190
+
191
+ try:
192
+ for new_text in streamer:
193
+ if await request.is_disconnected():
194
+ stop_event.set()
195
+ cancelled = True
196
+ break
197
+
198
+ chunk = {
199
+ "id": completion_id,
200
+ "object": "chat.completion.chunk",
201
+ "created": int(time.time()),
202
+ "model": model_name,
203
+ "choices": [
204
+ {
205
+ "index": 0,
206
+ "delta": {"content": firsttime + new_text},
207
+ "finish_reason": None,
208
+ }
209
+ ],
210
+ }
211
+ firsttime = ""
212
+ yield f"data: {json.dumps(chunk)}\n\n"
213
+
214
+ if not cancelled:
215
+ yield "data: [DONE]\n\n"
216
+
217
+ except Exception:
218
+ stop_event.set()
219
+ cancelled = True
220
+ finally:
221
+ if cancelled:
222
+ for _ in streamer:
223
+ pass
224
+ thread.join(timeout=10)
225
+
226
+ return StreamingResponse(
227
+ event_generator(), media_type="text/event-stream"
228
+ )
229
+
230
+
231
+ # ==========================================================
232
+ # Python実行時に自動起動
233
+ # ==========================================================
234
+ if __name__ == "__main__":
235
+ import uvicorn
236
+
237
+ uvicorn.run(
238
+ "test_openai_api:app",
239
+ host=HOST,
240
+ port=PORT,
241
+ reload=False,
242
+ )
tokenizer_config.json ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "154820": {
4
+ "content": "<|endoftext|>",
5
+ "single_word": false,
6
+ "lstrip": false,
7
+ "rstrip": false,
8
+ "normalized": false,
9
+ "special": true
10
+ },
11
+ "154821": {
12
+ "content": "[MASK]",
13
+ "single_word": false,
14
+ "lstrip": false,
15
+ "rstrip": false,
16
+ "normalized": false,
17
+ "special": true
18
+ },
19
+ "154822": {
20
+ "content": "[gMASK]",
21
+ "single_word": false,
22
+ "lstrip": false,
23
+ "rstrip": false,
24
+ "normalized": false,
25
+ "special": true
26
+ },
27
+ "154823": {
28
+ "content": "[sMASK]",
29
+ "single_word": false,
30
+ "lstrip": false,
31
+ "rstrip": false,
32
+ "normalized": false,
33
+ "special": true
34
+ },
35
+ "154824": {
36
+ "content": "<sop>",
37
+ "single_word": false,
38
+ "lstrip": false,
39
+ "rstrip": false,
40
+ "normalized": false,
41
+ "special": true
42
+ },
43
+ "154825": {
44
+ "content": "<eop>",
45
+ "single_word": false,
46
+ "lstrip": false,
47
+ "rstrip": false,
48
+ "normalized": false,
49
+ "special": true
50
+ },
51
+ "154826": {
52
+ "content": "<|system|>",
53
+ "single_word": false,
54
+ "lstrip": false,
55
+ "rstrip": false,
56
+ "normalized": false,
57
+ "special": true
58
+ },
59
+ "154827": {
60
+ "content": "<|user|>",
61
+ "single_word": false,
62
+ "lstrip": false,
63
+ "rstrip": false,
64
+ "normalized": false,
65
+ "special": true
66
+ },
67
+ "154828": {
68
+ "content": "<|assistant|>",
69
+ "single_word": false,
70
+ "lstrip": false,
71
+ "rstrip": false,
72
+ "normalized": false,
73
+ "special": true
74
+ },
75
+ "154829": {
76
+ "content": "<|observation|>",
77
+ "single_word": false,
78
+ "lstrip": false,
79
+ "rstrip": false,
80
+ "normalized": false,
81
+ "special": true
82
+ },
83
+ "154830": {
84
+ "content": "<|begin_of_image|>",
85
+ "single_word": false,
86
+ "lstrip": false,
87
+ "rstrip": false,
88
+ "normalized": false,
89
+ "special": true
90
+ },
91
+ "154831": {
92
+ "content": "<|end_of_image|>",
93
+ "single_word": false,
94
+ "lstrip": false,
95
+ "rstrip": false,
96
+ "normalized": false,
97
+ "special": true
98
+ },
99
+ "154832": {
100
+ "content": "<|begin_of_video|>",
101
+ "single_word": false,
102
+ "lstrip": false,
103
+ "rstrip": false,
104
+ "normalized": false,
105
+ "special": true
106
+ },
107
+ "154833": {
108
+ "content": "<|end_of_video|>",
109
+ "single_word": false,
110
+ "lstrip": false,
111
+ "rstrip": false,
112
+ "normalized": false,
113
+ "special": true
114
+ },
115
+ "154834": {
116
+ "content": "<|begin_of_audio|>",
117
+ "single_word": false,
118
+ "lstrip": false,
119
+ "rstrip": false,
120
+ "normalized": false,
121
+ "special": true
122
+ },
123
+ "154835": {
124
+ "content": "<|end_of_audio|>",
125
+ "single_word": false,
126
+ "lstrip": false,
127
+ "rstrip": false,
128
+ "normalized": false,
129
+ "special": true
130
+ },
131
+ "154836": {
132
+ "content": "<|begin_of_transcription|>",
133
+ "single_word": false,
134
+ "lstrip": false,
135
+ "rstrip": false,
136
+ "normalized": false,
137
+ "special": true
138
+ },
139
+ "154837": {
140
+ "content": "<|end_of_transcription|>",
141
+ "single_word": false,
142
+ "lstrip": false,
143
+ "rstrip": false,
144
+ "normalized": false,
145
+ "special": true
146
+ },
147
+ "154838": {
148
+ "content": "<|code_prefix|>",
149
+ "single_word": false,
150
+ "lstrip": false,
151
+ "rstrip": false,
152
+ "normalized": false,
153
+ "special": false
154
+ },
155
+ "154839": {
156
+ "content": "<|code_middle|>",
157
+ "single_word": false,
158
+ "lstrip": false,
159
+ "rstrip": false,
160
+ "normalized": false,
161
+ "special": false
162
+ },
163
+ "154840": {
164
+ "content": "<|code_suffix|>",
165
+ "single_word": false,
166
+ "lstrip": false,
167
+ "rstrip": false,
168
+ "normalized": false,
169
+ "special": false
170
+ },
171
+ "154841": {
172
+ "content": "<think>",
173
+ "single_word": false,
174
+ "lstrip": false,
175
+ "rstrip": false,
176
+ "normalized": false,
177
+ "special": false
178
+ },
179
+ "154842": {
180
+ "content": "</think>",
181
+ "single_word": false,
182
+ "lstrip": false,
183
+ "rstrip": false,
184
+ "normalized": false,
185
+ "special": false
186
+ },
187
+ "154843": {
188
+ "content": "<tool_call>",
189
+ "single_word": false,
190
+ "lstrip": false,
191
+ "rstrip": false,
192
+ "normalized": false,
193
+ "special": false
194
+ },
195
+ "154844": {
196
+ "content": "</tool_call>",
197
+ "single_word": false,
198
+ "lstrip": false,
199
+ "rstrip": false,
200
+ "normalized": false,
201
+ "special": false
202
+ },
203
+ "154845": {
204
+ "content": "<tool_response>",
205
+ "single_word": false,
206
+ "lstrip": false,
207
+ "rstrip": false,
208
+ "normalized": false,
209
+ "special": false
210
+ },
211
+ "154846": {
212
+ "content": "</tool_response>",
213
+ "single_word": false,
214
+ "lstrip": false,
215
+ "rstrip": false,
216
+ "normalized": false,
217
+ "special": false
218
+ },
219
+ "154847": {
220
+ "content": "<arg_key>",
221
+ "single_word": false,
222
+ "lstrip": false,
223
+ "rstrip": false,
224
+ "normalized": false,
225
+ "special": false
226
+ },
227
+ "154848": {
228
+ "content": "</arg_key>",
229
+ "single_word": false,
230
+ "lstrip": false,
231
+ "rstrip": false,
232
+ "normalized": false,
233
+ "special": false
234
+ },
235
+ "154849": {
236
+ "content": "<arg_value>",
237
+ "single_word": false,
238
+ "lstrip": false,
239
+ "rstrip": false,
240
+ "normalized": false,
241
+ "special": false
242
+ },
243
+ "154850": {
244
+ "content": "</arg_value>",
245
+ "single_word": false,
246
+ "lstrip": false,
247
+ "rstrip": false,
248
+ "normalized": false,
249
+ "special": false
250
+ },
251
+ "154851": {
252
+ "content": "/nothink",
253
+ "single_word": false,
254
+ "lstrip": false,
255
+ "rstrip": false,
256
+ "normalized": false,
257
+ "special": false
258
+ },
259
+ "154852": {
260
+ "content": "<|begin_of_box|>",
261
+ "single_word": false,
262
+ "lstrip": false,
263
+ "rstrip": false,
264
+ "normalized": false,
265
+ "special": false
266
+ },
267
+ "154853": {
268
+ "content": "<|end_of_box|>",
269
+ "single_word": false,
270
+ "lstrip": false,
271
+ "rstrip": false,
272
+ "normalized": false,
273
+ "special": false
274
+ },
275
+ "154854": {
276
+ "content": "<|image|>",
277
+ "single_word": false,
278
+ "lstrip": false,
279
+ "rstrip": false,
280
+ "normalized": false,
281
+ "special": false
282
+ },
283
+ "154855": {
284
+ "content": "<|video|>",
285
+ "single_word": false,
286
+ "lstrip": false,
287
+ "rstrip": false,
288
+ "normalized": false,
289
+ "special": false
290
+ }
291
+ },
292
+ "additional_special_tokens": [
293
+ "<|endoftext|>",
294
+ "[MASK]",
295
+ "[gMASK]",
296
+ "[sMASK]",
297
+ "<sop>",
298
+ "<eop>",
299
+ "<|system|>",
300
+ "<|user|>",
301
+ "<|assistant|>",
302
+ "<|observation|>",
303
+ "<|begin_of_image|>",
304
+ "<|end_of_image|>",
305
+ "<|begin_of_video|>",
306
+ "<|end_of_video|>",
307
+ "<|begin_of_audio|>",
308
+ "<|end_of_audio|>",
309
+ "<|begin_of_transcription|>",
310
+ "<|end_of_transcription|>"
311
+ ],
312
+ "clean_up_tokenization_spaces": false,
313
+ "do_lower_case": false,
314
+ "eos_token": "<|endoftext|>",
315
+ "extra_special_tokens": {},
316
+ "model_max_length": 128000,
317
+ "pad_token": "<|endoftext|>",
318
+ "padding_side": "left",
319
+ "remove_space": false,
320
+ "tokenizer_class": "PreTrainedTokenizer"
321
+ }