knockknock404 committed on
Commit d44d492 · 1 parent: 79c57dd

Upload 4 files

configuration_qwen.py ADDED
@@ -0,0 +1,65 @@
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from transformers import PretrainedConfig


class QWenConfig(PretrainedConfig):
    model_type = "qwen"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        num_hidden_layers=32,
        num_attention_heads=32,
        emb_dropout_prob=0.0,
        attn_dropout_prob=0.0,
        layer_norm_epsilon=1e-6,
        initializer_range=0.02,
        max_position_embeddings=8192,
        scale_attn_weights=True,
        use_cache=True,
        bf16=False,
        fp16=False,
        fp32=False,
        kv_channels=128,
        rotary_pct=1.0,
        rotary_emb_base=10000,
        use_dynamic_ntk=True,
        use_logn_attn=True,
        use_flash_attn="auto",
        intermediate_size=22016,
        no_bias=True,
        tie_word_embeddings=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.emb_dropout_prob = emb_dropout_prob
        self.attn_dropout_prob = attn_dropout_prob
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.scale_attn_weights = scale_attn_weights
        self.use_cache = use_cache
        self.max_position_embeddings = max_position_embeddings
        self.bf16 = bf16
        self.fp16 = fp16
        self.fp32 = fp32
        self.kv_channels = kv_channels
        self.rotary_pct = rotary_pct
        self.rotary_emb_base = rotary_emb_base
        self.use_dynamic_ntk = use_dynamic_ntk
        self.use_logn_attn = use_logn_attn
        self.use_flash_attn = use_flash_attn
        self.no_bias = no_bias
        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs
        )
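
For reference, a minimal usage sketch of this config class on its own (assuming the file is importable as `configuration_qwen`; the defaults above reproduce the Qwen-7B shape, and the field names all come from the `__init__` signature):

# Illustrative sketch; assumes configuration_qwen.py is on the import path.
from configuration_qwen import QWenConfig

config = QWenConfig(
    use_flash_attn=False,  # skip the FlashAttention import path
    fp32=True,             # disable the automatic bf16/fp16 selection
)
print(config.hidden_size, config.num_attention_heads)  # 4096 32

Note that extra fields such as seq_length (read by the model code below) are not declared here; they arrive via **kwargs from config.json and are stored as attributes by PretrainedConfig.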
modeling_qwen.py ADDED
@@ -0,0 +1,1207 @@
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import importlib
import math
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.cuda.amp import autocast

from torch.nn import CrossEntropyLoss
from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
from transformers.generation.logits_process import LogitsProcessorList

if TYPE_CHECKING:
    from transformers.generation.streamers import BaseStreamer
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

try:
    from einops import rearrange
except ImportError:
    rearrange = None
from torch import nn

SUPPORT_CUDA = torch.cuda.is_available()
SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7

from .configuration_qwen import QWenConfig
from .qwen_generation_utils import (
    HistoryType,
    make_context,
    decode_tokens,
    get_stop_words_ids,
    StopWordsLogitsProcessor,
)


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "qwen"
_CONFIG_FOR_DOC = "QWenConfig"

QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"]

_ERROR_BAD_CHAT_FORMAT = """\
We detected that you are probably using the pretrained model (rather than the chat model) for chatting, since the chat_format in generation_config is not "chatml".
If you are directly using the model downloaded from Huggingface, please make sure you are using our "Qwen/Qwen-7B-Chat" Huggingface model (rather than "Qwen/Qwen-7B") when you call model.chat().
我们检测到您可能在使用预训练模型(而非chat模型)进行多轮chat,因为您当前在generation_config指定的chat_format,并未设置为我们在对话中所支持的"chatml"格式。
如果您在直接使用我们从Huggingface提供的模型,请确保您在调用model.chat()时,使用的是"Qwen/Qwen-7B-Chat"模型(而非"Qwen/Qwen-7B"预训练模型)。
"""

_SENTINEL = object()
_ERROR_STREAM_IN_CHAT = """\
Passing the argument `stream` to model.chat() is buggy, deprecated, and marked for removal. Please use model.chat_stream(...) instead of model.chat(..., stream=True).
向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。
"""

apply_rotary_emb_func = None
rms_norm = None
flash_attn_unpadded_func = None


def _import_flash_attn():
    global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func
    try:
        from flash_attn.layers.rotary import apply_rotary_emb_func as __apply_rotary_emb_func
        apply_rotary_emb_func = __apply_rotary_emb_func
    except ImportError:
        logger.warn(
            "Warning: failed to import flash_attn rotary; please install the FlashAttention rotary kernel for higher efficiency: "
            "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary"
        )

    try:
        from flash_attn.ops.rms_norm import rms_norm as __rms_norm
        rms_norm = __rms_norm
    except ImportError:
        logger.warn(
            "Warning: failed to import flash_attn rms_norm; please install the FlashAttention layer_norm kernel for higher efficiency: "
            "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm"
        )

    try:
        import flash_attn
        if not hasattr(flash_attn, '__version__'):
            from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
        else:
            # flash-attn v2 renamed the unpadded kernel to flash_attn_varlen_func.
            if int(flash_attn.__version__.split(".")[0]) >= 2:
                from flash_attn.flash_attn_interface import flash_attn_varlen_func as __flash_attn_unpadded_func
            else:
                from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
        flash_attn_unpadded_func = __flash_attn_unpadded_func
    except ImportError:
        logger.warn(
            "Warning: failed to import flash_attn; please install FlashAttention for higher efficiency: "
            "https://github.com/Dao-AILab/flash-attention"
        )


class FlashSelfAttention(torch.nn.Module):
    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
    ):
        super().__init__()
        assert flash_attn_unpadded_func is not None, (
            "Please install FlashAttention first, e.g., with pip install flash-attn"
        )
        assert (
            rearrange is not None
        ), "Please install einops first, e.g., with pip install einops"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, q, k, v):
        assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
        assert all((i.is_cuda for i in (q, k, v)))
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        seqlen_k = k.shape[1]

        # Flatten to the unpadded layout and build the cumulative sequence
        # lengths expected by the varlen FlashAttention kernel.
        q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
        cu_seqlens_q = torch.arange(
            0,
            (batch_size + 1) * seqlen_q,
            step=seqlen_q,
            dtype=torch.int32,
            device=q.device,
        )

        if self.training:
            assert seqlen_k == seqlen_q

            is_causal = self.causal
            cu_seqlens_k = cu_seqlens_q
        else:
            is_causal = seqlen_q == seqlen_k
            cu_seqlens_k = torch.arange(
                0,
                (batch_size + 1) * seqlen_k,
                step=seqlen_k,
                dtype=torch.int32,
                device=q.device,
            )
            self.dropout_p = 0

        output = flash_attn_unpadded_func(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            seqlen_q,
            seqlen_k,
            self.dropout_p,
            softmax_scale=self.softmax_scale,
            causal=is_causal,
        )

        new_shape = (batch_size, output.shape[0] // batch_size) + output.shape[1:]
        output = output.view(new_shape)
        return output


class QWenAttention(nn.Module):
    def __init__(self, config):
        super().__init__()

        max_positions = config.max_position_embeddings
        self.register_buffer(
            "bias",
            torch.tril(
                torch.ones((max_positions, max_positions), dtype=torch.bool)
            ).view(1, 1, max_positions, max_positions),
            persistent=False,
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
        self.seq_length = config.seq_length

        self.hidden_size = config.hidden_size
        self.split_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads

        self.use_flash_attn = config.use_flash_attn
        self.scale_attn_weights = True

        self.projection_size = config.kv_channels * config.num_attention_heads

        assert self.projection_size % config.num_attention_heads == 0
        self.hidden_size_per_attention_head = (
            self.projection_size // config.num_attention_heads
        )

        self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)

        self.c_proj = nn.Linear(
            config.hidden_size, self.projection_size, bias=not config.no_bias
        )

        self.is_fp32 = not (config.bf16 or config.fp16)
        if (
            self.use_flash_attn
            and flash_attn_unpadded_func is not None
            and not self.is_fp32
        ):
            self.core_attention_flash = FlashSelfAttention(
                causal=True, attention_dropout=config.attn_dropout_prob
            )

        self.bf16 = config.bf16

        self.use_dynamic_ntk = config.use_dynamic_ntk
        self.use_logn_attn = config.use_logn_attn

        # LogN scaling: queries at positions beyond the training length are
        # scaled by log_{seq_length}(position) to stabilize attention entropy.
        logn_list = [
            math.log(i, self.seq_length) if i > self.seq_length else 1
            for i in range(1, 32768)
        ]
        self.logn_tensor = torch.tensor(logn_list)[None, :, None, None]

        self.attn_dropout = nn.Dropout(config.attn_dropout_prob)

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            attn_weights = attn_weights / torch.full(
                [],
                value.size(-1) ** 0.5,
                dtype=attn_weights.dtype,
                device=attn_weights.device,
            )

        query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = self.bias[
            :, :, key_length - query_length : key_length, :key_length
        ]
        mask_value = torch.finfo(attn_weights.dtype).min
        mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(
            attn_weights.device
        )
        attn_weights = torch.where(
            causal_mask, attn_weights.to(attn_weights.dtype), mask_value
        )

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)
        attn_output = attn_output.transpose(1, 2)

        return attn_output, attn_weights

    def _upcast_and_reordered_attn(
        self, query, key, value, attention_mask=None, head_mask=None
    ):
        bsz, num_heads, q_seq_len, dk = query.size()
        _, _, k_seq_len, _ = key.size()

        attn_weights = torch.empty(
            bsz * num_heads,
            q_seq_len,
            k_seq_len,
            dtype=torch.float32,
            device=query.device,
        )

        scale_factor = 1.0
        if self.scale_attn_weights:
            scale_factor /= float(value.size(-1)) ** 0.5

        # Upcast to fp32 for the attention-score matmul to avoid overflow.
        with autocast(enabled=False):
            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(
                -1, dk, k_seq_len
            )
            attn_weights = torch.baddbmm(
                attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor
            )
            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

        query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = self.bias[
            :, :, key_length - query_length : key_length, :key_length
        ]
        mask_value = torch.finfo(attn_weights.dtype).min
        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(
            attn_weights.device
        )
        attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if attn_weights.dtype != torch.float32:
            raise RuntimeError(
                "Error with upcasting, attn_weights does not have dtype torch.float32"
            )
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def _split_heads(self, tensor, num_heads, attn_head_size):
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        tensor = tensor.contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(new_shape)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        rotary_pos_emb: Optional[List[torch.Tensor]] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ):
        mixed_x_layer = self.c_attn(hidden_states)

        query, key, value = mixed_x_layer.split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        if rotary_pos_emb is not None:
            # Slice the pos emb for the current inference step.
            cur_len = query.shape[1]
            rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
            rotary_pos_emb = (rotary_pos_emb,) * 2
            q_pos_emb, k_pos_emb = rotary_pos_emb
            query = apply_rotary_pos_emb(query, q_pos_emb)
            key = apply_rotary_pos_emb(key, k_pos_emb)

        if layer_past is not None:
            past_key, past_value = layer_past[0], layer_past[1]
            key = torch.cat((past_key, key), dim=1)
            value = torch.cat((past_value, value), dim=1)

        if use_cache:
            present = (key, value)
        else:
            present = None

        if self.use_logn_attn and not self.training:
            if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype:
                self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)
            seq_start = key.size(1) - query.size(1)
            seq_end = key.size(1)
            logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
            query = query * logn_tensor.expand_as(query)

        if (
            self.use_flash_attn
            and flash_attn_unpadded_func is not None
            and not self.is_fp32
            and query.is_cuda
        ):
            q, k, v = query, key, value
            context_layer = self.core_attention_flash(q, k, v)

            # b s h d -> b s (h d)
            context_layer = context_layer.flatten(2, 3).contiguous()
        else:
            query = query.permute(0, 2, 1, 3)
            key = key.permute(0, 2, 1, 3)
            value = value.permute(0, 2, 1, 3)
            attn_output, attn_weight = self._attn(
                query, key, value, attention_mask, head_mask
            )
            context_layer = self._merge_heads(
                attn_output, self.num_heads, self.head_dim
            )

        attn_output = self.c_proj(context_layer)

        outputs = (attn_output, present)
        if output_attentions:
            if (
                self.use_flash_attn
                and flash_attn_unpadded_func is not None
                and not self.is_fp32
            ):
                raise ValueError("Cannot output attentions while using flash-attn")
            else:
                outputs += (attn_weight,)

        return outputs


class QWenMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.w1 = nn.Linear(
            config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias
        )
        self.w2 = nn.Linear(
            config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias
        )
        ff_dim_in = config.intermediate_size // 2
        self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias)

    def forward(self, hidden_states):
        # SwiGLU: gate with w1, activate w2 with SiLU, then project back down.
        a1 = self.w1(hidden_states)
        a2 = self.w2(hidden_states)
        intermediate_parallel = a1 * F.silu(a2)
        output = self.c_proj(intermediate_parallel)
        return output


class QWenBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        hidden_size = config.hidden_size
        self.bf16 = config.bf16

        self.ln_1 = RMSNorm(
            hidden_size,
            eps=config.layer_norm_epsilon,
        )
        self.attn = QWenAttention(config)
        self.ln_2 = RMSNorm(
            hidden_size,
            eps=config.layer_norm_epsilon,
        )

        self.mlp = QWenMLP(config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        rotary_pos_emb: Optional[List[torch.Tensor]] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ):
        layernorm_output = self.ln_1(hidden_states)

        attn_outputs = self.attn(
            layernorm_output,
            rotary_pos_emb,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]

        outputs = attn_outputs[1:]

        residual = hidden_states
        layernorm_input = attn_output + residual

        layernorm_output = self.ln_2(layernorm_input)

        residual = layernorm_input
        mlp_output = self.mlp(layernorm_output)
        hidden_states = residual + mlp_output

        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            outputs = (hidden_states,) + outputs[1:]

        return outputs


class QWenPreTrainedModel(PreTrainedModel):
    config_class = QWenConfig
    base_model_prefix = "transformer"
    is_parallelizable = False
    supports_gradient_checkpointing = True
    _no_split_modules = ["QWenBlock"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, RMSNorm):
            module.weight.data.fill_(1.0)

        for name, p in module.named_parameters():
            if name == "c_proj.weight":
                p.data.normal_(
                    mean=0.0,
                    std=(
                        self.config.initializer_range
                        / math.sqrt(2 * self.config.num_hidden_layers)
                    ),
                )

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, QWenModel):
            module.gradient_checkpointing = value


class QWenModel(QWenPreTrainedModel):
    _keys_to_ignore_on_load_missing = ["attn.masked_bias"]

    def __init__(self, config):
        super().__init__(config)
        self.vocab_size = config.vocab_size
        self.num_hidden_layers = config.num_hidden_layers
        self.embed_dim = config.hidden_size

        self.gradient_checkpointing = False
        self.use_dynamic_ntk = config.use_dynamic_ntk
        self.seq_length = config.seq_length

        self.wte = nn.Embedding(self.vocab_size, self.embed_dim)

        self.drop = nn.Dropout(config.emb_dropout_prob)

        if config.rotary_pct == 1.0:
            self.rotary_ndims = None
        else:
            assert config.rotary_pct < 1
            self.rotary_ndims = int(
                config.kv_channels * config.rotary_pct
            )
        dim = (
            self.rotary_ndims
            if self.rotary_ndims is not None
            else config.kv_channels
        )
        self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)

        self.h = nn.ModuleList(
            [
                QWenBlock(
                    config,
                )
                for i in range(config.num_hidden_layers)
            ]
        )
        self.ln_f = RMSNorm(
            self.embed_dim,
            eps=config.layer_norm_epsilon,
        )

        self.post_init()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            past_length = past_key_values[0][0].size(-2)

        if position_ids is None:
            position_ids = torch.arange(
                past_length,
                input_shape[-1] + past_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError("batch_size has to be defined and > 0")
            attention_mask = attention_mask.view(batch_size, -1)
            attention_mask = attention_mask[:, None, None, :]
            attention_mask = attention_mask.to(dtype=self.dtype)
            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

        encoder_attention_mask = None
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        hidden_states = inputs_embeds

        kv_seq_len = hidden_states.size()[1]
        if past_key_values[0] is not None:
            # past key values[0][0] shape: bs * seq_len * head_num * dim
            kv_seq_len += past_key_values[0][0].shape[1]
        # Dynamic NTK: when the prompt exceeds the training length, enlarge the
        # rotary base so long contexts interpolate rather than extrapolate.
        if (
            self.use_dynamic_ntk
            and kv_seq_len == hidden_states.size()[1]
            and not self.training
        ):
            context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
            ntk_alpha = 2 ** math.ceil(context_value) - 1
            ntk_alpha = max(ntk_alpha, 1)
        else:
            ntk_alpha = self.rotary_emb._ntk_alpha_cached

        rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha)
        for idx in range(len(rotary_pos_emb)):
            rotary_pos_emb[idx] = rotary_pos_emb[idx].to(hidden_states.device)

        hidden_states = self.drop(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache, output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    rotary_pos_emb,
                    None,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    rotary_pos_emb=rotary_pos_emb,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v for v in [hidden_states, presents, all_hidden_states] if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class QWenLMHeadModel(QWenPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"]
    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"]

    def __init__(self, config):
        super().__init__(config)
        assert (
            config.bf16 + config.fp16 + config.fp32 <= 1
        ), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"

        autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0

        if autoset_precision:
            if SUPPORT_BF16:
                logger.warn(
                    "The model is automatically converted to bf16 for faster inference. "
                    "If you want to disable the automatic precision selection, please manually pass bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
                )
                config.bf16 = True
            elif SUPPORT_FP16:
                logger.warn(
                    "The model is automatically converted to fp16 for faster inference. "
                    "If you want to disable the automatic precision selection, please manually pass bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
                )
                config.fp16 = True
            else:
                config.fp32 = True

        if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
            logger.warn("Your device does NOT seem to support bf16; you can switch to fp16 or fp32 by passing fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\".")
        if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
            logger.warn("Your device does NOT support faster inference with fp16; please switch to fp32, which is likely to be faster.")
        if config.fp32:
            if SUPPORT_BF16:
                logger.warn("Your device supports faster inference by passing bf16=True to \"AutoModelForCausalLM.from_pretrained\".")
            elif SUPPORT_FP16:
                logger.warn("Your device supports faster inference by passing fp16=True to \"AutoModelForCausalLM.from_pretrained\".")

        if config.use_flash_attn == "auto":
            if config.bf16 or config.fp16:
                logger.warn("Trying to import flash-attention for faster inference...")
                config.use_flash_attn = True
            else:
                config.use_flash_attn = False
        if config.use_flash_attn and config.fp32:
            logger.warn("Flash attention will be disabled because it does NOT support fp32.")

        if config.use_flash_attn:
            _import_flash_attn()

        self.transformer = QWenModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        if config.bf16:
            self.transformer.bfloat16()
            self.lm_head.bfloat16()
        if config.fp16:
            self.transformer.half()
            self.lm_head.half()
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
    ):
        token_type_ids = kwargs.get("token_type_ids", None)
        if past_key_values:
            # With a KV cache, only the last token needs to be fed to the model.
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None

        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "position_ids": position_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids,
            }
        )
        return model_inputs

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            labels = labels.to(lm_logits.device)
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    @staticmethod
    def _reorder_cache(
        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        return tuple(
            tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past
            )
            for layer_past in past_key_values
        )

    def chat(
        self,
        tokenizer: PreTrainedTokenizer,
        query: str,
        history: Optional[HistoryType],
        system: str = "You are a helpful assistant.",
        append_history: bool = True,
        stream: Optional[bool] = _SENTINEL,
        stop_words_ids: Optional[List[List[int]]] = None,
        generation_config: Optional[GenerationConfig] = None,
        **kwargs,
    ) -> Tuple[str, HistoryType]:
        generation_config = generation_config if generation_config is not None else self.generation_config

        assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT
        assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
        if history is None:
            history = []
        if stop_words_ids is None:
            stop_words_ids = []

        max_window_size = kwargs.get('max_window_size', None)
        if max_window_size is None:
            max_window_size = generation_config.max_window_size
        raw_text, context_tokens = make_context(
            tokenizer,
            query,
            history=history,
            system=system,
            max_window_size=max_window_size,
            chat_format=generation_config.chat_format,
        )

        stop_words_ids.extend(get_stop_words_ids(
            generation_config.chat_format, tokenizer
        ))
        input_ids = torch.tensor([context_tokens]).to(self.device)
        outputs = self.generate(
            input_ids,
            stop_words_ids=stop_words_ids,
            return_dict_in_generate=False,
            generation_config=generation_config,
            **kwargs,
        )

        response = decode_tokens(
            outputs[0],
            tokenizer,
            raw_text_len=len(raw_text),
            context_length=len(context_tokens),
            chat_format=generation_config.chat_format,
            verbose=False,
            errors='replace'
        )

        if append_history:
            history.append((query, response))

        return response, history

    def chat_stream(
        self,
        tokenizer: PreTrainedTokenizer,
        query: str,
        history: Optional[HistoryType],
        system: str = "You are a helpful assistant.",
        stop_words_ids: Optional[List[List[int]]] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        generation_config: Optional[GenerationConfig] = None,
        **kwargs,
    ) -> Generator[str, Any, None]:
        generation_config = generation_config if generation_config is not None else self.generation_config
        assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
        if history is None:
            history = []
        if stop_words_ids is None:
            stop_words_ids = []

        max_window_size = kwargs.get('max_window_size', None)
        if max_window_size is None:
            max_window_size = generation_config.max_window_size
        raw_text, context_tokens = make_context(
            tokenizer,
            query,
            history=history,
            system=system,
            max_window_size=max_window_size,
            chat_format=generation_config.chat_format,
        )

        stop_words_ids.extend(get_stop_words_ids(
            generation_config.chat_format, tokenizer
        ))
        if stop_words_ids is not None:
            stop_words_logits_processor = StopWordsLogitsProcessor(
                stop_words_ids=stop_words_ids,
                eos_token_id=generation_config.eos_token_id,
            )
            if logits_processor is None:
                logits_processor = LogitsProcessorList([stop_words_logits_processor])
            else:
                logits_processor.append(stop_words_logits_processor)
        input_ids = torch.tensor([context_tokens]).to(self.device)

        from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
        self.__class__.generate_stream = NewGenerationMixin.generate
        self.__class__.sample_stream = NewGenerationMixin.sample_stream
        stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)

        def stream_generator():
            outputs = []
            for token in self.generate_stream(
                    input_ids,
                    return_dict_in_generate=False,
                    generation_config=stream_config,
                    logits_processor=logits_processor,
                    seed=-1,
                    **kwargs):
                outputs.append(token.item())
                yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore')

        return stream_generator()

    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[
            Callable[[int, torch.Tensor], List[int]]
        ] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        generation_config = generation_config if generation_config is not None else self.generation_config

        # Process stop_words_ids.
        stop_words_ids = kwargs.pop("stop_words_ids", None)
        if stop_words_ids is None and generation_config is not None:
            stop_words_ids = getattr(generation_config, "stop_words_ids", None)
        if stop_words_ids is None:
            stop_words_ids = getattr(generation_config, "stop_words_ids", None)

        if stop_words_ids is not None:
            stop_words_logits_processor = StopWordsLogitsProcessor(
                stop_words_ids=stop_words_ids,
                eos_token_id=generation_config.eos_token_id,
            )
            if logits_processor is None:
                logits_processor = LogitsProcessorList([stop_words_logits_processor])
            else:
                logits_processor.append(stop_words_logits_processor)

        return super().generate(
            inputs,
            generation_config=generation_config,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            synced_gpus=synced_gpus,
            assistant_model=assistant_model,
            streamer=streamer,
            **kwargs,
        )


class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        self.dim = dim
        self.base = base
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        if importlib.util.find_spec("einops") is None:
            raise RuntimeError("einops is required for Rotary Embedding")

        self._rotary_pos_emb_cache = None
        self._seq_len_cached = 0
        self._ntk_alpha_cached = 1.0

    def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):
        seqlen = max_seq_len + offset
        if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
            # NTK-aware scaling: stretch the rotary base so that long contexts
            # interpolate the positional frequencies instead of extrapolating.
            base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
            self.inv_freq = 1.0 / (
                base
                ** (
                    torch.arange(0, self.dim, 2, device=self.inv_freq.device).float()
                    / self.dim
                )
            )
            self._seq_len_cached = max(2 * seqlen, 16)
            self._ntk_alpha_cached = ntk_alpha
            seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device)
            freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)

            emb = torch.cat((freqs, freqs), dim=-1)
            from einops import rearrange

            emb = rearrange(emb, "n d -> 1 n 1 d")

            cos, sin = emb.cos(), emb.sin()
            self._rotary_pos_emb_cache = [cos, sin]

    def forward(self, max_seq_len, offset=0, ntk_alpha=1.0):
        self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha)
        cos, sin = self._rotary_pos_emb_cache
        return [cos[:, offset : offset + max_seq_len], sin[:, offset : offset + max_seq_len]]


def _rotate_half(x):
    from einops import rearrange

    x = rearrange(x, "... (j d) -> ... j d", j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(t, freqs):
    cos, sin = freqs
    if apply_rotary_emb_func is not None and t.is_cuda:
        t_ = t.float()
        cos = cos.squeeze(0).squeeze(1)[:, : cos.shape[-1] // 2]
        sin = sin.squeeze(0).squeeze(1)[:, : sin.shape[-1] // 2]
        output = apply_rotary_emb_func(t_, cos, sin).type_as(t)
        return output
    else:
        rot_dim = freqs[0].shape[-1]
        cos, sin = freqs
        t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
        t_ = t_.float()
        t_pass_ = t_pass_.float()
        t_ = (t_ * cos) + (_rotate_half(t_) * sin)
        return torch.cat((t_, t_pass_), dim=-1).type_as(t)


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        if rms_norm is not None and x.is_cuda:
            return rms_norm(x, self.weight, self.eps)
        else:
            output = self._norm(x.float()).type_as(x)
            return output * self.weight
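
Taken together with the config above, the intended entry point for this file is QWenLMHeadModel.chat. A minimal sketch of typical use through the Transformers remote-code loader; the repo id "Qwen/Qwen-7B-Chat" is the upstream chat model, so substitute this repository's id as appropriate:

# Illustrative sketch; assumes a checkpoint whose config.json points at the
# classes defined above and a tokenizer exposing the ChatML special tokens.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True
).eval()

# chat() builds a ChatML prompt via make_context(), generates with the ChatML
# stop words, and returns the reply plus the updated history.
response, history = model.chat(tokenizer, "Hi there!", history=None)
print(response)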
qwen_generation_utils.py ADDED
@@ -0,0 +1,416 @@
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """Generation support."""
7
+
8
+ from typing import Tuple, List, Union, Iterable
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from transformers import PreTrainedTokenizer
14
+ from transformers import logging
15
+ from transformers.generation import LogitsProcessor
16
+
17
+ logger = logging.get_logger(__name__)
18
+
19
+ # Types.
20
+ HistoryType = List[Tuple[str, str]]
21
+ TokensType = List[int]
22
+ BatchTokensType = List[List[int]]
23
+
24
+
25
+ def pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> BatchTokensType:
26
+ for tokens in batch:
27
+ context_length = len(tokens)
28
+ if context_length < seq_length:
29
+ tokens.extend([pad_id] * (seq_length - context_length))
30
+ return batch
31
+
32
+
33
+ def get_ltor_masks_and_position_ids(
34
+ data,
35
+ eod_token,
36
+ reset_position_ids,
37
+ reset_attention_mask,
38
+ eod_mask_loss,
39
+ ):
40
+ """Build masks and position id for left to right model."""
41
+
42
+ # Extract batch size and sequence length.
43
+ micro_batch_size, seq_length = data.size()
44
+
45
+ # Attention mask (lower triangular).
46
+ if reset_attention_mask:
47
+ att_mask_batch = micro_batch_size
48
+ else:
49
+ att_mask_batch = 1
50
+ attention_mask = torch.tril(
51
+ torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
52
+ ).view(att_mask_batch, 1, seq_length, seq_length)
53
+
54
+ # Loss mask.
55
+ loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
56
+ if eod_mask_loss:
57
+ loss_mask[data == eod_token] = 0.0
58
+
59
+ # Position ids.
60
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
61
+ position_ids = position_ids.unsqueeze(0).expand_as(data)
62
+ # We need to clone as the ids will be modifed based on batch index.
63
+ if reset_position_ids:
64
+ position_ids = position_ids.clone()
65
+
66
+ if reset_position_ids or reset_attention_mask:
67
+ # Loop through the batches:
68
+ for b in range(micro_batch_size):
69
+
70
+ # Find indecies where EOD token is.
71
+ eod_index = position_ids[b, data[b] == eod_token]
72
+ # Detach indecies from positions if going to modify positions.
73
+ if reset_position_ids:
74
+ eod_index = eod_index.clone()
75
+
76
+ # Loop through EOD indecies:
77
+ prev_index = 0
78
+ for j in range(eod_index.size()[0]):
79
+ i = eod_index[j]
80
+ # Mask attention loss.
81
+ if reset_attention_mask:
82
+ attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
83
+ # Reset positions.
84
+ if reset_position_ids:
85
+ position_ids[b, (i + 1) :] -= i + 1 - prev_index
86
+ prev_index = i + 1
87
+
88
+ # Convert attention mask to binary:
89
+ attention_mask = attention_mask < 0.5
90
+
91
+ return attention_mask, loss_mask, position_ids
92
+
93
+
94
+ def get_batch(context_tokens: torch.LongTensor, eod_id: int):
95
+ """Generate batch from context tokens."""
96
+ # Move to GPU.
97
+ tokens = context_tokens.contiguous().to(context_tokens.device)
98
+ # Get the attention mask and postition ids.
99
+ attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
100
+ tokens,
101
+ eod_id,
102
+ reset_position_ids=False,
103
+ reset_attention_mask=False,
104
+ eod_mask_loss=False,
105
+ )
106
+ return tokens, attention_mask, position_ids
107
+
108
+
109
+ def get_stop_words_ids(chat_format, tokenizer):
110
+ if chat_format == "raw":
111
+ stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]]
112
+ elif chat_format == "chatml":
113
+ stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
114
+ else:
115
+ raise NotImplementedError(f"Unknown chat format {chat_format!r}")
116
+ return stop_words_ids
117
+
118
+
119
+ def make_context(
120
+ tokenizer: PreTrainedTokenizer,
121
+ query: str,
122
+ history: List[Tuple[str, str]] = None,
123
+ system: str = "",
124
+ max_window_size: int = 6144,
125
+ chat_format: str = "chatml",
126
+ ):
127
+ if history is None:
128
+ history = []
129
+
130
+ if chat_format == "chatml":
131
+ im_start, im_end = "<|im_start|>", "<|im_end|>"
132
+ im_start_tokens = [tokenizer.im_start_id]
133
+ im_end_tokens = [tokenizer.im_end_id]
134
+ nl_tokens = tokenizer.encode("\n")
135
+
136
+ def _tokenize_str(role, content):
137
+ return f"{role}\n{content}", tokenizer.encode(
138
+ role, allowed_special=set()
139
+ ) + nl_tokens + tokenizer.encode(content, allowed_special=set())
140
+
141
+ system_text, system_tokens_part = _tokenize_str("system", system)
142
+ system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
143
+
144
+ raw_text = ""
145
+ context_tokens = []
146
+
147
+ for turn_query, turn_response in reversed(history):
148
+ query_text, query_tokens_part = _tokenize_str("user", turn_query)
149
+ query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
150
+ response_text, response_tokens_part = _tokenize_str(
151
+ "assistant", turn_response
152
+ )
153
+ response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
154
+
155
+ next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
156
+ prev_chat = (
157
+ f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
158
+ )
159
+
160
+ current_context_size = (
161
+ len(system_tokens) + len(next_context_tokens) + len(context_tokens)
162
+ )
163
+ if current_context_size < max_window_size:
164
+ context_tokens = next_context_tokens + context_tokens
165
+ raw_text = prev_chat + raw_text
166
+ else:
167
+ break
168
+
169
+ context_tokens = system_tokens + context_tokens
170
+ raw_text = f"{im_start}{system_text}{im_end}" + raw_text
171
+ context_tokens += (
172
+ nl_tokens
173
+ + im_start_tokens
174
+ + _tokenize_str("user", query)[1]
175
+ + im_end_tokens
176
+ + nl_tokens
177
+ + im_start_tokens
178
+ + tokenizer.encode("assistant")
179
+ + nl_tokens
180
+ )
181
+ raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
182
+
183
+ elif chat_format == "raw":
184
+ raw_text = query
185
+ context_tokens = tokenizer.encode(raw_text)
186
+ else:
187
+ raise NotImplementedError(f"Unknown chat format {chat_format!r}")
188
+
189
+ return raw_text, context_tokens
190
+
191
+
192
+ def _decode_default(
193
+ tokens: List[int],
194
+ *,
195
+ stop_words: List[str],
196
+ eod_words: List[str],
197
+ tokenizer: PreTrainedTokenizer,
198
+ raw_text_len: int,
199
+ verbose: bool = False,
200
+ return_end_reason: bool = False,
201
+ errors: str='replace',
202
+ ):
203
+ trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:]
204
+ if verbose:
205
+ print("\nRaw Generate: ", trim_decode_tokens)
206
+
207
+ end_reason = f"Gen length {len(tokens)}"
208
+ for stop_word in stop_words:
209
+ trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
210
+ for eod_word in eod_words:
211
+ if eod_word in trim_decode_tokens:
212
+ end_reason = f"Gen {eod_word!r}"
213
+ trim_decode_tokens = trim_decode_tokens.split(eod_word)[0]
214
+ trim_decode_tokens = trim_decode_tokens.strip()
215
+ if verbose:
216
+ print("\nEnd Reason:", end_reason)
217
+ print("\nGenerate: ", trim_decode_tokens)
218
+
219
+ if return_end_reason:
220
+ return trim_decode_tokens, end_reason
221
+ else:
222
+ return trim_decode_tokens
223
+
224
+
225
+ def _decode_chatml(
226
+ tokens: List[int],
227
+ *,
228
+ stop_words: List[str],
229
+ eod_token_ids: List[int],
230
+ tokenizer: PreTrainedTokenizer,
231
+ raw_text_len: int,
232
+ context_length: int,
233
+ verbose: bool = False,
234
+ return_end_reason: bool = False,
235
+ errors: str='replace'
236
+ ):
237
+ end_reason = f"Gen length {len(tokens)}"
238
+ eod_token_idx = context_length
239
+ for eod_token_idx in range(context_length, len(tokens)):
240
+ if tokens[eod_token_idx] in eod_token_ids:
241
+ end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
242
+ break
243
+
244
+ trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)[raw_text_len:]
245
+ if verbose:
246
+ print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens, errors=errors)[raw_text_len:])
247
+ print("\nRaw Generate:", trim_decode_tokens)
248
+ print("\nEnd Reason:", end_reason)
249
+ for stop_word in stop_words:
250
+ trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
251
+ trim_decode_tokens = trim_decode_tokens.strip()
252
+ if verbose:
253
+ print("\nGenerate:", trim_decode_tokens)
254
+
255
+ if return_end_reason:
256
+ return trim_decode_tokens, end_reason
257
+ else:
258
+ return trim_decode_tokens
259
+
+
+ def decode_tokens(
+     tokens: Union[torch.LongTensor, TokensType],
+     tokenizer: PreTrainedTokenizer,
+     raw_text_len: int,
+     context_length: int,
+     chat_format: str,
+     verbose: bool = False,
+     return_end_reason: bool = False,
+     errors: str = "replace",
+ ) -> str:
+     if torch.is_tensor(tokens):
+         tokens = tokens.cpu().numpy().tolist()
+
+     if chat_format == "chatml":
+         return _decode_chatml(
+             tokens,
+             stop_words=[],
+             eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id],
+             tokenizer=tokenizer,
+             raw_text_len=raw_text_len,
+             context_length=context_length,
+             verbose=verbose,
+             return_end_reason=return_end_reason,
+             errors=errors,
+         )
+     elif chat_format == "raw":
+         return _decode_default(
+             tokens,
+             stop_words=["<|endoftext|>"],
+             eod_words=["<|endoftext|>"],
+             tokenizer=tokenizer,
+             raw_text_len=raw_text_len,
+             verbose=verbose,
+             return_end_reason=return_end_reason,
+             errors=errors,
+         )
+     else:
+         raise NotImplementedError(f"Unknown chat format {chat_format!r}")
+
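Taken together, make_context and decode_tokens bracket a generation call. A minimal sketch of that flow (the `model` and `tokenizer` objects and the generation settings are assumptions, not defined in this file):

    import torch

    raw_text, context_tokens = make_context(
        tokenizer, "Hi", history=[], system="You are a helpful assistant."
    )
    input_ids = torch.tensor([context_tokens]).to(model.device)
    outputs = model.generate(input_ids)
    response = decode_tokens(
        outputs[0],
        tokenizer,
        raw_text_len=len(raw_text),
        context_length=len(context_tokens),
        chat_format="chatml",
    )

`raw_text_len` and `context_length` tell the decoder where the prompt ends, so only the newly generated reply is returned.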
+
+ class StopWordsLogitsProcessor(LogitsProcessor):
+     """
+     :class:`transformers.LogitsProcessor` that forces generation to stop once any
+     of the specified token sequences appears.
+
+     Args:
+         stop_words_ids (:obj:`List[List[int]]`):
+             List of token-id sequences that should stop generation. To get the
+             token ids of the words that should stop generation, use
+             :obj:`tokenizer(stop_word, add_prefix_space=True).input_ids`.
+         eos_token_id (:obj:`int`):
+             The id of the `end-of-sequence` token.
+     """
+
+     def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):
+
+         if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:
+             raise ValueError(
+                 f"`stop_words_ids` has to be a non-empty list, but is {stop_words_ids}."
+             )
+         if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):
+             raise ValueError(
+                 f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}."
+             )
+         if any(
+             any(
+                 (not isinstance(token_id, (int, np.integer)) or token_id < 0)
+                 for token_id in stop_word_ids
+             )
+             for stop_word_ids in stop_words_ids
+         ):
+             raise ValueError(
+                 f"Each list in `stop_words_ids` has to be a list of non-negative integers, but is {stop_words_ids}."
+             )
+
+         self.stop_words_ids = list(
+             filter(
+                 lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids
+             )
+         )
+         self.eos_token_id = eos_token_id
+         for stop_token_seq in self.stop_words_ids:
+             assert (
+                 len(stop_token_seq) > 0
+             ), "Stop words token sequences {} cannot have an empty list".format(
+                 stop_words_ids
+             )
+
+     def __call__(
+         self, input_ids: torch.LongTensor, scores: torch.FloatTensor
+     ) -> torch.FloatTensor:
+         # when a sample ends with a stop sequence, push its next token to EOS
+         stopped_samples = self._calc_stopped_samples(input_ids)
+         for i, should_stop in enumerate(stopped_samples):
+             if should_stop:
+                 scores[i, self.eos_token_id] = float(2**15)
+         return scores
+
+     def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
+         if len(tokens) == 0:
+             # an empty stop sequence matches trivially
+             return True
+         elif len(tokens) > len(prev_tokens):
+             # if the stop sequence is longer than the generated prefix, they can't match
+             return False
+         elif prev_tokens[-len(tokens):].tolist() == tokens:
+             # the generated prefix ends with the stop sequence
+             return True
+         else:
+             return False
+
+     def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
+         stopped_samples = []
+         for prev_input_ids_slice in prev_input_ids:
+             match = False
+             for stop_token_seq in self.stop_words_ids:
+                 if self._tokens_match(prev_input_ids_slice, stop_token_seq):
+                     # this sample ends with a stop sequence: mark it stopped
+                     match = True
+                     break
+             stopped_samples.append(match)
+
+         return stopped_samples
+
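A hedged usage sketch (assuming `tokenizer` and `model` from this repo are already loaded and `input_ids` is an encoded prompt): the processor plugs into `generate` through a standard `LogitsProcessorList` and forces EOS whenever a sequence ends with `<|im_end|>` or `<|im_start|>`.

    from transformers.generation import LogitsProcessorList

    stop_processor = StopWordsLogitsProcessor(
        stop_words_ids=[[tokenizer.im_end_id], [tokenizer.im_start_id]],
        eos_token_id=tokenizer.eod_id,
    )
    outputs = model.generate(
        input_ids,
        logits_processor=LogitsProcessorList([stop_processor]),
    )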
+
+ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
+     """This function has been mostly taken from huggingface conversational
+     ai code at
+     https://medium.com/huggingface/how-to-build-a-state-of-the-art-
+     conversational-ai-with-transfer-learning-2d818ac26313"""
+
+     if top_k > 0:
+         # Remove all tokens with a probability less than the
+         # last token of the top-k
+         indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+         logits[indices_to_remove] = filter_value
+
+     if top_p > 0.0:
+         # Sort the logits in descending order to build the cumulative distribution
+         sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+         cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+         # Remove tokens with cumulative probability above the threshold
+         sorted_indices_to_remove = cumulative_probs > top_p
+         # Shift the indices to the right to keep also the first token
+         # above the threshold
+         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+         sorted_indices_to_remove[..., 0] = 0
+         for i in range(sorted_indices.size(0)):
+             indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
+             logits[i][indices_to_remove] = filter_value
+
+     return logits
+
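A minimal sketch of the filtering behavior on a toy batch (the values are illustrative; note the function modifies `logits` in place, hence the `clone()`):

    import torch

    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
    filtered = top_k_logits(logits.clone(), top_k=2)
    # everything outside the top-2 is masked to -inf:
    # tensor([[2., 1., -inf, -inf]])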
+
+ def switch(val1, val2, boolean):
+     # element-wise select: val2 where `boolean` is true, val1 elsewhere
+     boolean = boolean.type_as(val1)
+     return (1 - boolean) * val1 + boolean * val2
tokenization_qwen.py ADDED
@@ -0,0 +1,228 @@
+ # Copyright (c) Alibaba Cloud.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """Tokenization classes for QWen."""
+
+ import base64
+ import logging
+ import os
+ import unicodedata
+ from typing import Collection, Dict, List, Set, Tuple, Union
+
+ import tiktoken
+ from transformers import PreTrainedTokenizer, AddedToken
+
+ logger = logging.getLogger(__name__)
+
+
+ VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken"}
+
+ PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
+ ENDOFTEXT = "<|endoftext|>"
+ IMSTART = "<|im_start|>"
+ IMEND = "<|im_end|>"
+ # as the default behavior is changed to allow special tokens in
+ # regular texts, the surface forms of special tokens need to be
+ # as different as possible to minimize the impact
+ EXTRAS = tuple(f"<|extra_{i}|>" for i in range(205))
+ SPECIAL_TOKENS = (
+     ENDOFTEXT,
+     IMSTART,
+     IMEND,
+ ) + EXTRAS
+
+
+ def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
+     with open(tiktoken_bpe_file, "rb") as f:
+         contents = f.read()
+     return {
+         base64.b64decode(token): int(rank)
+         for token, rank in (line.split() for line in contents.splitlines() if line)
+     }
+
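The vocab file layout implied by this parser (an assumption based on the code, matching the tiktoken BPE convention): one token per line, as the base64-encoded token bytes, a space, and the integer rank. A tiny round-trip sketch:

    import base64

    line = base64.b64encode(b"!").decode("utf8") + " 0"  # one vocab line: "IQ== 0"
    token, rank = line.split()
    assert base64.b64decode(token) == b"!" and int(rank) == 0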
+ class QWenTokenizer(PreTrainedTokenizer):
+     """QWen tokenizer."""
+
+     vocab_files_names = VOCAB_FILES_NAMES
+
+     def __init__(
+         self,
+         vocab_file,
+         errors="replace",
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.errors = errors  # how to handle errors in decoding
+
+         self.mergeable_ranks = _load_tiktoken_bpe(vocab_file)  # type: dict[bytes, int]
+         self.special_tokens = {
+             token: index
+             for index, token in enumerate(
+                 SPECIAL_TOKENS, start=len(self.mergeable_ranks)
+             )
+         }
+
+         enc = tiktoken.Encoding(
+             "Qwen",
+             pat_str=PAT_STR,
+             mergeable_ranks=self.mergeable_ranks,
+             special_tokens=self.special_tokens,
+         )
+         assert (
+             len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
+         ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
+
+         self.decoder = {
+             v: k for k, v in self.mergeable_ranks.items()
+         }  # type: dict[int, bytes|str]
+         self.decoder.update({v: k for k, v in self.special_tokens.items()})
+
+         self.tokenizer = enc  # type: tiktoken.Encoding
+
+         self.eod_id = self.tokenizer.eot_token
+         self.im_start_id = self.special_tokens[IMSTART]
+         self.im_end_id = self.special_tokens[IMEND]
+
+     def __len__(self) -> int:
+         return self.tokenizer.n_vocab
+
+     def get_vocab(self) -> Dict[bytes, int]:
+         return self.mergeable_ranks
+
+     def convert_tokens_to_ids(
+         self, tokens: Union[bytes, str, List[Union[bytes, str]]]
+     ) -> Union[int, List[int]]:
+         ids = []
+         if isinstance(tokens, (str, bytes)):
+             if tokens in self.special_tokens:
+                 return self.special_tokens[tokens]
+             else:
+                 return self.mergeable_ranks.get(tokens)
+         for token in tokens:
+             if token in self.special_tokens:
+                 ids.append(self.special_tokens[token])
+             else:
+                 ids.append(self.mergeable_ranks.get(token))
+         return ids
+
+     def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
+         if not special_tokens and new_tokens:
+             raise ValueError("Adding regular tokens is not supported")
+         for token in new_tokens:
+             surface_form = token.content if isinstance(token, AddedToken) else token
+             if surface_form not in SPECIAL_TOKENS:
+                 raise ValueError("Adding unknown special tokens is not supported")
+         return 0
+
+     def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
+         """
+         Save only the base vocabulary of the tokenizer (special tokens are not written).
+
+         Returns:
+             `Tuple[str]`: Paths to the files saved.
+         """
+         file_path = os.path.join(save_directory, "qwen.tiktoken")
+         with open(file_path, "w", encoding="utf8") as w:
+             for k, v in self.mergeable_ranks.items():
+                 line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n"
+                 w.write(line)
+         return (file_path,)
+
+     def tokenize(
+         self,
+         text: str,
+         allowed_special: Union[Set, str] = "all",
+         disallowed_special: Union[Collection, str] = (),
+         **kwargs,
+     ) -> List[Union[bytes, str]]:
+         """
+         Converts a string into a sequence of tokens.
+
+         Args:
+             text (`str`):
+                 The sequence to be encoded.
+             allowed_special (`Literal["all"]` or `set`):
+                 The surface forms of the tokens to be encoded as special tokens in regular texts.
+                 Defaults to "all".
+             disallowed_special (`Literal["all"]` or `Collection`):
+                 The surface forms of the tokens that should not appear in regular texts and trigger errors.
+                 Defaults to an empty tuple.
+
+             kwargs (additional keyword arguments, *optional*):
+                 Will be passed to the underlying model specific encode method.
+
+         Returns:
+             `List[bytes | str]`: The list of tokens.
+         """
+         tokens = []
+         text = unicodedata.normalize("NFC", text)
+
+         # this implementation takes a detour: text -> token id -> token surface form
+         for t in self.tokenizer.encode(
+             text, allowed_special=allowed_special, disallowed_special=disallowed_special
+         ):
+             tokens.append(self.decoder[t])
+         return tokens
+
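A rough round-trip sketch (the exact BPE split shown is an assumption): pieces come back as raw bytes for ordinary text and as `str` for special tokens, which is why `convert_tokens_to_string` below has to handle both.

    tok = QWenTokenizer("qwen.tiktoken")
    pieces = tok.tokenize("hello <|im_start|>")  # e.g. [b'hello', b' ', '<|im_start|>']
    text = tok.convert_tokens_to_string(pieces)  # back to 'hello <|im_start|>'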
+     def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
+         """
+         Converts a sequence of tokens into a single string.
+         """
+         text = ""
+         temp = b""
+         for t in tokens:
+             if isinstance(t, str):
+                 if temp:
+                     text += temp.decode("utf-8", errors=self.errors)
+                     temp = b""
+                 text += t
+             elif isinstance(t, bytes):
+                 temp += t
+             else:
+                 raise TypeError("token should only be of type bytes or str")
+         if temp:
+             text += temp.decode("utf-8", errors=self.errors)
+         return text
+
+     @property
+     def vocab_size(self):
+         return self.tokenizer.n_vocab
+
+     def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
+         """Converts an id to a token, special tokens included"""
+         if index in self.decoder:
+             return self.decoder[index]
+         raise ValueError("unknown ids")
+
+     def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
+         """Converts a token to an id using the vocab, special tokens included"""
+         if token in self.special_tokens:
+             return self.special_tokens[token]
+         if token in self.mergeable_ranks:
+             return self.mergeable_ranks[token]
+         raise ValueError("unknown token")
+
+     def _tokenize(self, text: str, **kwargs):
+         """
+         Converts a string into a sequence of tokens (string), using the tokenizer. Split in words for word-based
+         vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
+
+         Do NOT take care of added tokens.
+         """
+         raise NotImplementedError
+
+     def _decode(
+         self,
+         token_ids: Union[int, List[int]],
+         skip_special_tokens: bool = False,
+         errors: str = None,
+         **kwargs,
+     ) -> str:
+         if isinstance(token_ids, int):
+             token_ids = [token_ids]
+         if skip_special_tokens:
+             # special tokens occupy the highest ids, starting at eod_id
+             token_ids = [i for i in token_ids if i < self.eod_id]
+         return self.tokenizer.decode(token_ids, errors=errors or self.errors)
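A small sketch of the `skip_special_tokens` path (assuming `tok` is a loaded `QWenTokenizer`): because every special-token id is `>= eod_id`, the filter drops them all.

    ids = tok.tokenizer.encode("hello") + [tok.eod_id]
    tok.decode(ids)                            # 'hello<|endoftext|>'
    tok.decode(ids, skip_special_tokens=True)  # 'hello'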