yarkcy committed on
Commit de4adcc · verified · 1 Parent(s): 35b9bb2

Upload tokenization_bailing.py with huggingface_hub

Files changed (1)
  1. tokenization_bailing.py +1068 -0
tokenization_bailing.py ADDED
@@ -0,0 +1,1068 @@
#!/usr/bin/env python3
# coding=utf-8
# Copyright (c) Ant Group. All rights reserved.

import itertools
from typing import Any, Dict, List, Optional, Union

import torch
from transformers import PreTrainedTokenizerFast
from transformers.tokenization_utils_base import AddedToken, BatchEncoding
from transformers.utils import TensorType, logging

logger = logging.get_logger(__name__)


def is_system(msg):
    return msg['role'].lower() == 'system'


def is_user(msg):
    return msg['role'].lower() in ['human', 'user']


def is_assistant(msg):
    return msg['role'].lower() == 'assistant'


def _convert_to_conversation(query, system=None):
    conversation = []
    if system:
        conversation.append({"role": "SYSTEM", "content": system})
    if isinstance(query, str):
        conversation.append({"role": "HUMAN", "content": query})
    elif isinstance(query, List):
        conversation.extend(query)
    elif isinstance(query, Dict):
        if "messages" in query:
            conversation.extend(query["messages"])
            if "system_message" in query and len(conversation) > 0 and not is_system(conversation[0]):
                conversation.insert(0, {"role": "SYSTEM", "content": query["system_message"]})
        else:
            conversation.append(query)
    return conversation

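
# Illustrative sketch (not part of the original file): _convert_to_conversation
# normalizes every accepted query shape into a flat message list, e.g.:
#   _convert_to_conversation("hi")
#   -> [{"role": "HUMAN", "content": "hi"}]
#   _convert_to_conversation("hi", system="be brief")
#   -> [{"role": "SYSTEM", "content": "be brief"}, {"role": "HUMAN", "content": "hi"}]
#   _convert_to_conversation({"messages": [{"role": "HUMAN", "content": "hi"}],
#                             "system_message": "be brief"})
#   -> [{"role": "SYSTEM", "content": "be brief"}, {"role": "HUMAN", "content": "hi"}]
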
class BailingTokenizer(PreTrainedTokenizerFast):
    is_bailing_tokenizer = True
    model_input_names = ["input_ids", "attention_mask"]
    slow_tokenizer_class = None

    # add gmask_token
    SPECIAL_TOKENS_ATTRIBUTES = [
        "bos_token",
        "eos_token",
        "unk_token",
        "sep_token",
        "pad_token",
        "cls_token",
        "mask_token",
        "gmask_token",
        "additional_special_tokens",
    ]

    def __init__(
        self,
        vocab_file=None,
        merges_file=None,
        tokenizer_file=None,
        clean_up_tokenization_spaces=False,
        bos_token="<|startoftext|>",
        eos_token="<|endoftext|>",
        cls_token="[CLS]",
        pad_token="<|endoftext|>",
        gmask_token="[gMASK]",
        add_bos_token=False,
        add_eos_token=False,
        **kwargs,
    ):
        self.add_bos_token = add_bos_token

        self._gmask_token = (
            AddedToken(gmask_token, lstrip=False, rstrip=False, normalized=False)
            if isinstance(gmask_token, str)
            else gmask_token
        )

        self._sop_token = (
            AddedToken(bos_token, lstrip=False, rstrip=False, normalized=False)
            if isinstance(bos_token, str)
            else bos_token
        )

        self._eop_token = (
            AddedToken(eos_token, lstrip=False, rstrip=False, normalized=False)
            if isinstance(eos_token, str)
            else eos_token
        )

        super().__init__(
            vocab_file=vocab_file,
            merges_file=merges_file,
            tokenizer_file=tokenizer_file,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            bos_token=bos_token,
            eos_token=eos_token,
            cls_token=cls_token,
            pad_token=pad_token,
            gmask_token=gmask_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            **kwargs,
        )

        self.check_special_tokens()

    def check_special_tokens(self):
        '''
        Check that the required special tokens (eos, bos, cls, gmask) have been
        initialized, i.e. that none of them is None.
        '''
        for name, special_token in zip(
            ['eos', 'bos', 'cls', 'gmask'],
            [self.eos_token, self.bos_token, self.cls_token, self.gmask_token],
        ):
            assert special_token is not None, f'should init special token [{name}] in tokenizer_config.json'

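    # Illustrative note (not part of the original file): the assert above expects
    # the special tokens to be declared in tokenizer_config.json, mirroring the
    # __init__ defaults, e.g. (assumed fragment):
    #   {"bos_token": "<|startoftext|>", "eos_token": "<|endoftext|>",
    #    "cls_token": "[CLS]", "pad_token": "<|endoftext|>", "gmask_token": "[gMASK]"}
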
    @property
    def gmask_token(self) -> Optional[str]:
        if self._gmask_token is None:
            if self.verbose:
                logger.error("Using gmask_token, but it is not set yet.")
            return None
        return str(self._gmask_token)

    @gmask_token.setter
    def gmask_token(self, value):
        if not isinstance(value, (str, AddedToken)) and value is not None:
            raise ValueError("Cannot set a non-string value as the gmask token")
        self._gmask_token = value

    @property
    def gmask_token_id(self) -> Optional[int]:
        if self._gmask_token is None:
            return None
        return self.convert_tokens_to_ids(self.gmask_token)

    @property
    def sop_token(self) -> Optional[str]:
        if self._sop_token is None:
            if self.verbose:
                logger.error("Using sop_token, but it is not set yet.")
            return None
        return str(self._sop_token)

    @sop_token.setter
    def sop_token(self, value):
        if not isinstance(value, (str, AddedToken)) and value is not None:
            raise ValueError("Cannot set a non-string value as the sop token")
        self._sop_token = value

    @property
    def sop_token_id(self) -> Optional[int]:
        if self._sop_token is None:
            return None
        return self.convert_tokens_to_ids(self.sop_token)

    @property
    def eop_token(self) -> Optional[str]:
        if self._eop_token is None:
            if self.verbose:
                logger.error("Using eop_token, but it is not set yet.")
            return None
        return str(self._eop_token)

    @eop_token.setter
    def eop_token(self, value):
        if not isinstance(value, (str, AddedToken)) and value is not None:
            raise ValueError("Cannot set a non-string value as the eop token")
        self._eop_token = value

    @property
    def eop_token_id(self) -> Optional[int]:
        if self._eop_token is None:
            return None
        return self.convert_tokens_to_ids(self.eop_token)

    @property
    def vocab_size(self):
        return len(self.get_vocab())

    def _chat_from_json(self, chat, chat_format="antglm_chat", system=None):
        msgs = chat if "messages" not in chat else chat["messages"]
        _msgs = []
        sys_msg = None
        for msg in msgs:
            if is_system(msg):
                sys_msg = msg['content']
            else:
                _msgs.append(msg)
        chat = {"messages": _msgs}
        system = system or sys_msg
        if system:
            chat['system_message'] = system
        from .chat_format import Chat

        return Chat.from_json(chat, name=chat_format)

    def apply_chat_template(
        self,
        conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
        tools: Optional[List[Dict]] = None,
        documents: Optional[List[Dict[str, str]]] = None,
        chat_template: Optional[str] = None,
        add_generation_prompt: bool = False,
        system: str = None,  # only used for legacy chatml
        tokenize=False,
        padding: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_dict: bool = False,
        return_assistant_tokens_mask: bool = False,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        if hasattr(self, "chat_template") and self.chat_template:
            if isinstance(conversation, Dict) and "messages" in conversation:
                conversation = conversation["messages"]
            # use transformers built-in method
            return super().apply_chat_template(
                conversation=conversation,
                tools=tools,
                documents=documents,
                chat_template=chat_template,
                add_generation_prompt=add_generation_prompt,
                tokenize=tokenize,
                padding=padding,
                truncation=truncation,
                return_tensors=return_tensors,
                return_dict=return_dict,
                return_assistant_tokens_mask=return_assistant_tokens_mask,
                tokenizer_kwargs=tokenizer_kwargs,
            )

        # The non-chat_template path will no longer be supported in the future.
        logger.warning("Please set chat_template in tokenizer_config.json!")

        chat_format = kwargs.get('chat_format', 'antglm_chat')

        is_batched = False

        if isinstance(conversation, List) and (
            isinstance(conversation[0], (list, tuple)) or "messages" in conversation[0]
        ):
            conversations = conversation
            is_batched = True

        if not is_batched:
            conversations = [conversation]

        rendered = []
        for chat in conversations:
            rendered_chat = self._chat_from_json(chat, chat_format=chat_format, system=system).prompt_str
            rendered.append(rendered_chat)

        if not is_batched:
            rendered = rendered[0]

        if tokenize:
            out = self(
                rendered,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                add_special_tokens=False,
                return_tensors=return_tensors,
            )
            if return_dict:
                return out
            else:
                return out["input_ids"]
        else:
            return rendered

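    # Illustrative usage sketch (not part of the original file), assuming the
    # tokenizer is loaded with trust_remote_code=True and tokenizer_config.json
    # defines a chat_template (the supported path above):
    #   messages = [{"role": "HUMAN", "content": "What does this tokenizer do?"}]
    #   prompt = tokenizer.apply_chat_template(
    #       messages, tokenize=False, add_generation_prompt=True)
    #   input_ids = tokenizer.apply_chat_template(
    #       messages, tokenize=True, add_generation_prompt=True)
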
    def _build_position_ids(
        self,
        mask_pos: int,
        bos_pos: int,
        max_output_length: int,
        rotary_type: Optional[str] = "none",
        **kwargs,
    ) -> List[List[int]]:
        window_size = kwargs.get("window_size", 1024) - 1
        block_position_ids = [0] * bos_pos

        # Use the position of the mask token to build the output position ids below.
        if "1d" in rotary_type:
            position_ids = list(range(bos_pos)) + list(range(mask_pos + 1, mask_pos + max_output_length + 2))
            block_position_ids = block_position_ids + list(range(1, max_output_length + 2))
        elif "2d" in rotary_type:
            # A bos_id will be appended to input_ids later.
            position_ids = list(range(bos_pos))
            position_ids = position_ids + [mask_pos] * (1 + max_output_length)
            block_position_ids = block_position_ids + list(range(1, max_output_length + 2))
        else:
            # build position ids
            position_ids = []
            repeat_times = bos_pos // window_size
            for _ in range(repeat_times):
                position_ids += list(range(window_size))
            position_ids += list(range(bos_pos - window_size * repeat_times))
            # need to consider the additional bos_id after input_ids
            mask_pos = position_ids[-1]
            position_ids += [mask_pos] * (max_output_length + 1)

            block_repeat_times = max_output_length // (window_size - 1)
            additional_block_position_ids = []
            for _ in range(block_repeat_times):
                additional_block_position_ids += list(range(1, window_size))
            additional_block_position_ids += list(
                range(1, max_output_length + 2 - (window_size - 1) * block_repeat_times)
            )
            block_position_ids = block_position_ids + additional_block_position_ids

        position_ids = [position_ids, block_position_ids]
        return position_ids

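    # Worked example (not part of the original file): with mask_pos=3, bos_pos=4
    # and max_output_length=2 the two rotary variants produce
    #   "1d": position_ids       = [0, 1, 2, 3, 4, 5, 6]
    #         block_position_ids = [0, 0, 0, 0, 1, 2, 3]
    #   "2d": position_ids       = [0, 1, 2, 3, 3, 3, 3]
    #         block_position_ids = [0, 0, 0, 0, 1, 2, 3]
    # i.e. in 2d mode generated tokens keep the mask position in the first row
    # while the second row counts the offset inside the generated block.
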
    def _build_inputs_for_generation(
        self,
        input_ids: List[int],
        max_input_length=None,
        left_truncate=True,
        max_output_length=1024,
        rotary_type="none",
        unidirectional_attention: bool = True,
        attention_dtype=None,
        **kwargs,
    ):
        if max_input_length and len(input_ids) > max_input_length:
            if left_truncate:
                input_ids = input_ids[-max_input_length:]
            else:
                input_ids = input_ids[:max_input_length]

        is_left_padding = input_ids[0] == self.eos_token_id
        if not unidirectional_attention:
            if input_ids[0] != self.cls_token_id:
                input_ids = [self.cls_token_id] + input_ids

            if self.gmask_token_id not in set(input_ids):
                input_ids = input_ids + [self.gmask_token_id]

            mask_pos = input_ids.index(self.gmask_token_id)
            sep = len(input_ids)
        else:
            if self.add_bos_token:
                input_ids = input_ids + [self.bos_token_id]
                if self.eos_token_id in input_ids:
                    mask_pos = input_ids.index(self.eos_token_id) - 1
                else:
                    mask_pos = len(input_ids) - 1
                sep = len(input_ids) - 1
            else:
                sep = len(input_ids)
                if self.eos_token_id in input_ids:
                    if is_left_padding:
                        ori_input_ids = input_ids
                        input_ids = input_ids[::-1]
                    mask_pos = input_ids.index(self.eos_token_id) - 1
                    mask_pos = max(0, mask_pos)  # for empty sequence
                    if is_left_padding:
                        input_ids = ori_input_ids
                        mask_pos = sep - 1 - mask_pos  # the first non-eos token

                else:
                    mask_pos = len(input_ids) - 1

        position_ids = self._build_position_ids(mask_pos, sep, max_output_length, rotary_type, **kwargs)

        if is_left_padding:
            position_ids[0] = [max(0, i - mask_pos) for i in range(len(position_ids[0]))]

        # A bos_id will be appended to input_ids later.
        total_length = sep + max_output_length
        if self.add_bos_token:
            total_length += 1

        def build_mask_matrix(seq_length, sep, mask_pos, unidirectional_attention):
            # Use bool dtype for long sequences to save GPU memory.
            if unidirectional_attention:
                attention_mask = torch.ones([seq_length, seq_length], dtype=attention_dtype)
                attention_mask = torch.tril(attention_mask)
                if is_left_padding:
                    attention_mask[:, :mask_pos] = 0
                else:
                    attention_mask[:, mask_pos + 1 : sep] = 0
            else:
                attention_mask = torch.zeros([seq_length, seq_length], dtype=attention_dtype)
                attention_mask[:, : mask_pos + 1] = 1
                for i in range(sep, total_length):
                    attention_mask[i, sep : i + 1] = 1
            return attention_mask

        if self.add_bos_token:
            attention_mask = build_mask_matrix(total_length, sep + 1, mask_pos, unidirectional_attention)
        else:
            attention_mask = build_mask_matrix(total_length, sep, mask_pos, unidirectional_attention)
        attention_mask = torch.unsqueeze(attention_mask, dim=0)
        attention_mask = torch.unsqueeze(attention_mask, dim=1)
        if attention_dtype is None:
            attention_mask = attention_mask.long()
        inputs = {
            "input_ids": torch.Tensor([input_ids]).long(),
            "position_ids": torch.Tensor([position_ids]).long(),
            "attention_mask": attention_mask,
        }
        return BatchEncoding(inputs)

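    # Illustrative note (not part of the original file): the returned BatchEncoding
    # is already batched. With sep = len(input_ids) after special-token handling,
    # the default (unidirectional, no bos) shapes are roughly:
    #   "input_ids":      [1, sep]
    #   "position_ids":   [1, 2, sep + max_output_length + 1]
    #   "attention_mask": [1, 1, sep + max_output_length, sep + max_output_length]
    # i.e. the mask already reserves rows for the tokens to be generated.
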
    def build_inputs_for_generation(
        self,
        input_ids: Union[List[int], List[List[int]], torch.Tensor],
        max_input_length=None,
        left_truncate=True,
        max_output_length=1024,
        rotary_type="1d",
        unidirectional_attention=True,
        attention_dtype=None,
        **kwargs,
    ):
        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.tolist()

        if isinstance(input_ids[0], list):
            input_ids_list = []
            position_ids_list = []
            attention_mask_list = []
            for _input_ids in input_ids:
                inputs = self._build_inputs_for_generation(
                    _input_ids,
                    max_input_length=max_input_length,
                    left_truncate=left_truncate,
                    max_output_length=max_output_length,
                    rotary_type=rotary_type,
                    unidirectional_attention=unidirectional_attention,
                    attention_dtype=attention_dtype,
                    **kwargs,
                )
                input_ids_list.append(inputs['input_ids'])
                position_ids_list.append(inputs['position_ids'])
                attention_mask_list.append(inputs["attention_mask"])

            max_ids_length = max([input.size(1) for input in input_ids_list])

            for i in range(len(input_ids)):
                cur_ids_length = input_ids_list[i].size(1)
                if cur_ids_length < max_ids_length:
                    # pad input ids
                    pad_input_ids = input_ids_list[i].new_zeros((1, max_ids_length - cur_ids_length))
                    input_ids_list[i] = torch.cat([pad_input_ids, input_ids_list[i]], dim=-1)

                    # pad position ids with left pad
                    # 0, 1, 2, 3, 4 ... -> 0, ..., 0, 1, 2, 3, 4, ...
                    pad_position_ids = input_ids_list[i].new_zeros((1, 2, max_ids_length - cur_ids_length))
                    position_ids_list[i] = torch.cat([pad_position_ids, position_ids_list[i]], dim=-1)

                    # pad generation attention mask with left and bottom pad
                    new_attention_mask = input_ids_list[i].new_zeros(
                        1,
                        1,
                        max_ids_length + max_output_length,
                        max_ids_length + max_output_length,
                    )
                    new_attention_mask[
                        :,
                        :,
                        max_ids_length - cur_ids_length :,
                        max_ids_length - cur_ids_length :,
                    ] = attention_mask_list[i]
                    attention_mask_list[i] = new_attention_mask.contiguous()

            input_ids_list = torch.cat(input_ids_list, dim=0)
            position_ids_list = torch.cat(position_ids_list, dim=0)
            attention_mask_list = torch.cat(attention_mask_list, dim=0)

            inputs = {
                "input_ids": input_ids_list,
                "position_ids": position_ids_list,
                "attention_mask": attention_mask_list,
            }

            return BatchEncoding(inputs)
        else:
            return self._build_inputs_for_generation(
                input_ids,
                max_input_length=max_input_length,
                left_truncate=left_truncate,
                max_output_length=max_output_length,
                rotary_type=rotary_type,
                unidirectional_attention=unidirectional_attention,
                attention_dtype=attention_dtype,
                **kwargs,
            )

    def _build_inputs_for_train(
        self,
        inputs: Union[str, List[str]],
        outputs: Union[str, List[str]],
        new_conversation_offset: List[int] = None,
        max_length: int = 2048,
        rotary_type: str = "1d",
        left_truncate: bool = True,
        unidirectional_attention: bool = True,
        isolation_position_ids: bool = False,
        padding: bool = True,
        use_fa2: bool = True,
        use_packed: bool = True,
        use_baichuan_packed: bool = False,
        skip_truncated_turn: bool = False,
        return_attention_mask: bool = True,
    ):
        r"""
        Build tensor input for model training. If inputs and outputs are lists, they will be packed.

        Args:
            inputs (str, List[str], List[Dict], List[List[Dict]]): the input prompts.
            outputs (str, List[str]): the output responses.
            max_length (int, Optional): the maximum length of the final input ids for training. Default: 2048
            rotary_type (str, Optional): the rotary type of position embedding. Default: 1d
            left_truncate (bool, Optional): whether to truncate the inputs from the left. Default: True
            use_fa2 (bool, Optional): whether to build the attention mask for flash attention 2.
            new_conversation_offset (List[int], Optional): indices of samples that start a brand-new
                conversation. For example, [0, 1] means inputs[0]/outputs[0] form one conversation and
                inputs[1]/outputs[1] form another.
        """
        if use_packed and use_baichuan_packed and unidirectional_attention:
            return self._build_baichuan_inputs_for_train(
                inputs,
                outputs,
                new_conversation_offset,
                max_length,
                rotary_type,
                left_truncate,
                skip_truncated_turn,
                use_fa2,
                padding,
            )
        if isinstance(inputs, str):
            inputs = [inputs]
        if isinstance(outputs, str):
            outputs = [outputs]

        assert len(inputs) == len(outputs)

        input_ids = [self(item)['input_ids'] for item in inputs]
        output_ids = [self(item)['input_ids'] for item in outputs]

        packed_input_ids = []
        packed_output_ids = []
        if new_conversation_offset is None:
            new_conversation_offset = list(range(0, len(inputs)))
        assert 0 in new_conversation_offset, f"0 is missing, please check new_conversation_offset: {new_conversation_offset}"
        current_len = 0

        for idx, (input, output) in enumerate(zip(input_ids, output_ids)):
            num_special_tokens = 0
            if not unidirectional_attention:
                if idx in new_conversation_offset:
                    # cls and gmask
                    num_special_tokens += 2
                else:
                    # only gmask
                    num_special_tokens += 1
            else:
                # sop and eos
                if self.add_bos_token:
                    num_special_tokens += 2
                else:
                    num_special_tokens += 1

            # truncate
            if len(input) + len(output) + current_len > max_length - num_special_tokens:
                if not use_packed or use_fa2 and unidirectional_attention:
                    attention_mask = torch.tensor(0)
                elif use_fa2:
                    attention_mask = -1 * torch.ones([2, max_length])
                else:
                    attention_mask = torch.tril(torch.ones([max_length, max_length]))
                # Return an empty sample that does not take part in training.
                default_return = {
                    'input_ids': (torch.ones(max_length) * self.eos_token_id).long(),
                    'position_ids': torch.zeros(2, max_length).long(),
                    'attention_mask': (attention_mask.long()),
                    'labels': (torch.ones(max_length) * -100).long(),
                }
                # If truncated turns are skipped, return (or stop packing) right away.
                if skip_truncated_turn:
                    if current_len == 0:
                        return default_return
                    else:
                        break
                left_len = max_length - num_special_tokens - current_len
                # When truncating, only truncate the prompt.
                if left_len - len(output) > 0:
                    if left_truncate:
                        input = input[-(left_len - len(output)) :]
                    else:
                        input = input[: left_len - len(output)]
                else:
                    # The response exceeds left_len; return right away.
                    if current_len == 0:
                        return default_return
                    else:
                        break
            if unidirectional_attention:
                packed_input_ids.append(list(input))
            else:
                if num_special_tokens == 4:
                    packed_input_ids.append([self.cls_token_id] + list(input) + [self.gmask_token_id])
                else:
                    packed_input_ids.append(list(input) + [self.gmask_token_id])

            packed_output_ids.append(list(output) + [self.eos_token_id])
            current_len += len(input) + len(output) + num_special_tokens

        assert current_len <= max_length

        if use_packed:
            # Packed mode.
            def build_mask_matrix(seq_length, sep):
                # https://github.com/pytorch/pytorch/issues/101932, fix triu/tril bf16 support
                m = torch.ones((1, seq_length, seq_length))
                mask = torch.arange(1, m.shape[-1] + 1).reshape(1, -1, 1).to(m.device)
                ids = torch.arange(1, m.shape[-1] + 1).reshape(1, 1, -1).expand(1, m.shape[-1], -1).to(m.device)
                m = (ids <= mask).type_as(m)

                m[0, :, : int(sep)] = 1
                m = m.squeeze(0)
                return m

            tokens = []
            attention_mask_list = []
            input_length_list = []
            position_id_list = []
            block_position_id_list = []
            for input, output in zip(packed_input_ids, packed_output_ids):
                if self.add_bos_token:
                    data = input + [self.sop_token_id] + output
                    mask_pos = len(input) - 1
                else:
                    data = input + output
                    mask_pos = len(input) - 2
                if return_attention_mask:
                    if unidirectional_attention:
                        attention_mask = build_mask_matrix(len(data), 0)
                    else:
                        attention_mask = build_mask_matrix(len(data), len(input))
                    attention_mask = attention_mask.squeeze((0, 1))

                    attention_mask_list.append(attention_mask)
                input_length_list.append(len(input))
                tokens += data

                sop_pos = mask_pos + 1
                position_ids, block_position_ids = self._build_position_ids(
                    mask_pos=mask_pos, bos_pos=sop_pos, max_output_length=len(output), rotary_type=rotary_type
                )

                position_id_list.append(position_ids)
                block_position_id_list.append(block_position_ids)

            labels = []
            for i in range(len(packed_input_ids)):
                if self.add_bos_token:
                    labels += [-100] * len(packed_input_ids[i]) + packed_output_ids[i] + [-100]
                else:
                    labels += [-100] * (len(packed_input_ids[i]) - 1) + packed_output_ids[i] + [-100]

            total_len = 0
            if use_fa2:
                pack_attention_mask = -1 * torch.ones([2, max_length])
            else:
                pack_attention_mask = torch.tril(torch.ones([current_len, current_len]))

            pack_position_ids = []
            pack_block_position_ids = []
            total_len = 0
            max_index = 0
            for i in range(len(position_id_list)):

                if use_fa2:
                    pack_attention_mask[0][i] = total_len
                    pack_attention_mask[1][i] = total_len + input_length_list[i]
                else:
                    attention_mask = attention_mask_list[i]
                    pack_attention_mask[
                        total_len : total_len + attention_mask.shape[0],
                        total_len : total_len + attention_mask.shape[0],
                    ] = attention_mask
                position_ids = [pid + max_index for pid in position_id_list[i]]
                block_position_ids = block_position_id_list[i]
                pack_position_ids.extend(position_ids)
                pack_block_position_ids.extend(block_position_ids)
                if not isolation_position_ids:
                    max_index = pack_position_ids[-1] + 1
                total_len += len(position_id_list[i])
            position_ids = [pack_position_ids, pack_block_position_ids]
        else:
            # Single-input mode.
            # In true multi-turn mode one sample may contain several dialogue turns,
            # so we need the end position of the first sample.
            if len(new_conversation_offset) > 1:
                end_idx = new_conversation_offset[1]
            else:
                end_idx = 1
            input, output = list(itertools.chain(*packed_input_ids[:end_idx])), list(
                itertools.chain(*packed_output_ids[:end_idx])
            )
            if self.add_bos_token:
                tokens = input + [self.sop_token_id] + output
            else:
                tokens = input + output

            if self.add_bos_token:
                labels = [-100] * len(input) + output + [-100]
                position_ids = self._build_position_ids(
                    mask_pos=len(input) - 1, bos_pos=len(input), max_output_length=len(output), rotary_type=rotary_type
                )
            else:
                labels = [-100] * (len(input) - 1) + output + [-100]
                position_ids = self._build_position_ids(
                    mask_pos=len(input) - 2,
                    bos_pos=len(input) - 1,
                    max_output_length=len(output),
                    rotary_type=rotary_type,
                )
            attention_mask = len(input)
        assert current_len == len(tokens)

        # Pad up to the maximum length.
        if max_length > 0 and len(tokens) < max_length and padding:
            pad_length = max_length - len(tokens)
            tokens += [self.pad_token_id] * pad_length
            labels.extend([-100] * pad_length)
            position_ids[0] += [0] * pad_length
            position_ids[1] += [0] * pad_length

            if use_packed:
                if use_fa2:
                    new_attention_mask = -1 * torch.ones([2, max_length])
                    new_attention_mask[:, :current_len] = pack_attention_mask
                else:
                    new_attention_mask = torch.tril(torch.ones([max_length, max_length]))
                    new_attention_mask[:current_len, :current_len] = pack_attention_mask
                pack_attention_mask = new_attention_mask.contiguous()

        assert len(tokens) == len(labels)

        if max_length > 0 and padding:
            assert len(tokens) == max_length

        if use_fa2 and unidirectional_attention:
            # pack_attention_mask = torch.zeros([1], dtype=torch.long)
            pack_attention_mask = torch.tensor(0)

        if use_packed:
            if not use_fa2:
                attention_mask = pack_attention_mask.unsqueeze(0).long()
            else:
                attention_mask = pack_attention_mask
        else:
            attention_mask = torch.tensor(attention_mask).long()
        return {
            'input_ids': torch.tensor(tokens).long(),
            'position_ids': torch.tensor(position_ids).long(),
            'attention_mask': attention_mask,
            'labels': torch.tensor(labels).long(),
        }

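    # Illustrative sketch (not part of the original file): packing three Q/A
    # turns where the first two belong to one conversation and the third starts
    # a new one:
    #   batch = tokenizer._build_inputs_for_train(
    #       inputs=["u1", "u2", "u3"],
    #       outputs=["a1", "a2", "a3"],
    #       new_conversation_offset=[0, 2],  # sample 2 opens a new conversation
    #       max_length=2048,
    #   )
    #   # batch["input_ids"] / ["labels"] / ["position_ids"] are padded to
    #   # max_length; prompt and padding positions get the ignore label -100.
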
    def _build_baichuan_inputs_for_train(
        self,
        inputs: Union[str, List[str]],
        outputs: Union[str, List[str]],
        new_conversation_offset: List[int] = None,
        max_length: int = 2048,
        rotary_type: str = "1d",
        left_truncate: bool = True,
        skip_truncated_turn: bool = True,
        use_fa2: bool = True,
        padding: bool = True,
    ):
        '''
        input: <role> HUMAN </role> u1 <role> ASSISTANT </role> a11 a12 <role> HUMAN </role> u2 <role> ASSISTANT </role> a21 a22 <|endoftext|> <role> HUMAN </role> u1 <role> ASSISTANT </role> a11 a12 <role> HUMAN </role> u2 <role> ASSISTANT </role> a21 a22 <|endoftext|>
        output: x x x x x x a11 a12 <|endoftext|> x x x x x x a21 a22 <|endoftext|> x x x x x x x a11 a12 <|endoftext|> x x x x x x a21 a22 <|endoftext|> x
        Only for training unidirectional models on true multi-turn + packed data;
        use_true_multiturn must be enabled.
        '''
        if isinstance(inputs, str):
            inputs = [inputs]
        if isinstance(outputs, str):
            outputs = [outputs]
        assert len(inputs) == len(outputs)

        input_ids = [self(item)['input_ids'] for item in inputs]
        output_ids = [self(item)['input_ids'] for item in outputs]

        packed_input_ids = []
        packed_output_ids = []

        if new_conversation_offset is None:
            new_conversation_offset = list(range(0, len(inputs)))
        assert 0 in new_conversation_offset, f"0 is missing, please check new_conversation_offset: {new_conversation_offset}"
        current_len = 0

        for idx, (input, output) in enumerate(zip(input_ids, output_ids)):
            num_special_tokens = 0
            if idx != 0 and idx in new_conversation_offset:
                # Prepend eos to input_ids; only sample 0 does not get one.
                num_special_tokens += 1

            # truncate
            if len(input) + len(output) + current_len > max_length - num_special_tokens:
                if use_fa2:
                    attention_mask = torch.tensor(0)
                else:
                    attention_mask = torch.tril(torch.ones([max_length, max_length]))
                # Return an empty sample that does not take part in training.
                default_return = {
                    'input_ids': (torch.ones(max_length) * self.eos_token_id).long(),
                    'position_ids': torch.zeros(2, max_length).long(),
                    'attention_mask': (attention_mask.long()),
                    'labels': (torch.ones(max_length) * -100).long(),
                }

                # If truncated turns are skipped, return (or stop packing) right away.
                if skip_truncated_turn:
                    if current_len == 0:
                        return default_return
                    else:
                        break
                left_len = max_length - num_special_tokens - current_len
                # When truncating, only truncate the prompt.
                if left_len - len(output) > 0:
                    if left_truncate:
                        input = input[-(left_len - len(output)) :]
                    else:
                        input = input[: left_len - len(output)]
                else:
                    # The response exceeds left_len; return right away.
                    if current_len == 0:
                        return default_return
                    else:
                        break
            # Concatenate into input_ids here.
            if num_special_tokens == 1:
                packed_input_ids.append([self.eos_token_id] + list(input))
            else:
                packed_input_ids.append(list(input))
            packed_output_ids.append(list(output))
            current_len += len(input) + len(output) + num_special_tokens
        assert current_len <= max_length

        def build_mask_matrix(seq_length, sep):
            # https://github.com/pytorch/pytorch/issues/101932, fix triu/tril bf16 support
            m = torch.ones((1, seq_length, seq_length))
            mask = torch.arange(1, m.shape[-1] + 1).reshape(1, -1, 1).to(m.device)
            ids = torch.arange(1, m.shape[-1] + 1).reshape(1, 1, -1).expand(1, m.shape[-1], -1).to(m.device)
            m = (ids <= mask).type_as(m)

            m[0, :, : int(sep)] = 1
            m = m.squeeze(0)
            return m

        tokens = []
        attention_mask_list = []
        position_id_list = []
        block_position_id_list = []
        token_lens = []
        for input, output in zip(packed_input_ids, packed_output_ids):
            data = input + output
            if not use_fa2:
                attention_mask = build_mask_matrix(len(data), 0)
                attention_mask_list.append(attention_mask)
            tokens += data
            token_lens.append(len(data))

            position_ids, block_position_ids = self._build_position_ids(
                mask_pos=len(input) - 2, bos_pos=len(input) - 1, max_output_length=len(output), rotary_type=rotary_type
            )

            position_id_list.append(position_ids)
            block_position_id_list.append(block_position_ids)

        labels = []
        for i in range(len(packed_input_ids)):
            labels += [-100] * (len(packed_input_ids[i]) - 1) + packed_output_ids[i] + [self.eos_token_id]

        total_len = 0
        if use_fa2:
            pack_attention_mask = torch.Tensor([[0], [1]])
        else:
            pack_attention_mask = torch.tril(torch.ones([max_length, max_length]))

        pack_position_ids = []
        pack_block_position_ids = []
        total_len = 0
        max_index = 0
        for i in range(len(token_lens)):
            if not use_fa2:
                attention_mask = attention_mask_list[i]
                pack_attention_mask[
                    total_len : total_len + attention_mask.shape[0], total_len : total_len + attention_mask.shape[0]
                ] = attention_mask
            position_ids = [pid + max_index for pid in position_id_list[i]]
            block_position_ids = block_position_id_list[i]
            pack_position_ids.extend(position_ids)
            pack_block_position_ids.extend(block_position_ids)
            max_index = pack_position_ids[-1] + 1
            total_len += token_lens[i]
        position_ids = [pack_position_ids, pack_block_position_ids]

        if max_length > 0 and len(tokens) < max_length and padding:
            pad_length = max_length - len(tokens)
            tokens += [self.pad_token_id] * pad_length
            labels.extend([-100] * pad_length)
            position_ids[0] += [0] * pad_length
            position_ids[1] += [0] * pad_length

        assert len(tokens) == len(labels)

        if not use_fa2:
            attention_mask = pack_attention_mask.unsqueeze(0).long()
        else:
            attention_mask = torch.tensor(0)
        return {
            'input_ids': torch.tensor(tokens).long(),
            'position_ids': torch.tensor(position_ids).long(),
            'attention_mask': attention_mask,
            'labels': torch.tensor(labels).long(),
        }

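    # Illustrative note (not part of the original file): with
    # new_conversation_offset=[0, 1] the packed stream is laid out as
    #   [prompt 0][response 0] <|endoftext|> [prompt 1][response 1]
    # since eos is prepended to every sample after the first that opens a new
    # conversation, and each sample's labels end with the eos id so the model
    # learns to terminate a conversation.
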
    def build_inputs_for_train(
        self,
        data: Union[Dict, List[Dict]],
        new_conversation_offset: List[int] = None,
        chat_format="antglm_chat",
        is_chat_format=True,  # if a string is passed in, indicates whether it is already in chat format
        use_true_multiturn=False,
        max_length: int = 2048,
        rotary_type: str = "1d",
        left_truncate: bool = True,
        unidirectional_attention: bool = True,
        isolation_position_ids: bool = False,
        padding: bool = True,
        use_fa2: bool = True,
        use_packed: bool = True,
        use_baichuan_packed: bool = False,
        skip_truncated_turn: bool = False,
        return_attention_mask: bool = True,
    ):
        r"""
        Build tensor input for model training. If inputs and outputs are lists, they will be packed.

        Args:
            inputs (str, List[str], List[Dict], List[List[Dict]]): the input prompts.
            outputs (str, List[str]): the output responses.
            new_conversation_offset (List[int]): the offset indices of new conversation turns.
            is_chat_format (bool): whether the input is already in chatml format.
            max_length (int, Optional): the maximum length of the final input ids for training. Default: 2048
            rotary_type (str, Optional): the rotary type of position embedding. Default: 1d
            left_truncate (bool, Optional): whether to truncate the inputs from the left. Default: True
            use_fa2 (bool, Optional): whether to build the attention mask for flash attention 2.
        """
        if isinstance(data, List):
            # chatml list
            _inputs = []
            _outputs = []
            new_conversation_offset = []
            for _input in data:
                if use_true_multiturn:
                    chat = self._chat_from_json(_input, chat_format=chat_format)
                    chat_data = chat.prompt_pack
                    new_conversation_offset.append(len(_inputs))
                    _inputs.extend(chat_data['input'])
                    _outputs.extend(chat_data['output'])
                else:
                    _conversation = _convert_to_conversation(_input)
                    assert is_assistant(_conversation[-1])

                    _inputs.append(
                        self.apply_chat_template(_conversation[:-1], tokenize=False, add_generation_prompt=True)
                    )
                    _outputs.append(_conversation[-1]['content'])

            return self._build_inputs_for_train(
                inputs=_inputs,
                outputs=_outputs,
                new_conversation_offset=new_conversation_offset,
                max_length=max_length,
                rotary_type=rotary_type,
                left_truncate=left_truncate,
                unidirectional_attention=unidirectional_attention,
                isolation_position_ids=isolation_position_ids,
                padding=padding,
                use_fa2=use_fa2,
                use_packed=use_packed,
                use_baichuan_packed=use_baichuan_packed,
                skip_truncated_turn=skip_truncated_turn,
                return_attention_mask=return_attention_mask,
            )
        elif isinstance(data, Dict):
            if 'messages' in data:
                # chatml format
                if use_true_multiturn:
                    chat = self._chat_from_json(data, chat_format=chat_format)
                    chat_data = chat.prompt_pack
                else:
                    _conversation = _convert_to_conversation(data)
                    assert is_assistant(_conversation[-1])

                    chat_data = {
                        "input": self.apply_chat_template(
                            _conversation[:-1], tokenize=False, add_generation_prompt=True
                        ),
                        "output": _conversation[-1]['content'],
                    }

                return self._build_inputs_for_train(
                    inputs=chat_data['input'],
                    outputs=chat_data['output'],
                    max_length=max_length,
                    rotary_type=rotary_type,
                    left_truncate=left_truncate,
                    unidirectional_attention=unidirectional_attention,
                    isolation_position_ids=isolation_position_ids,
                    padding=padding,
                    use_fa2=use_fa2,
                    use_packed=use_packed,
                    use_baichuan_packed=use_baichuan_packed,
                    skip_truncated_turn=skip_truncated_turn,
                    return_attention_mask=return_attention_mask,
                )
            else:
                inputs = data['input']
                outputs = data['output']

                if isinstance(inputs, str):
                    inputs = [inputs]
                if isinstance(outputs, str):
                    outputs = [outputs]

                if not is_chat_format and chat_format:
                    inputs = [
                        self.apply_chat_template(
                            [{"role": "HUMAN", "content": item}], tokenize=False, chat_format=chat_format
                        )
                        for item in inputs
                    ]

                return self._build_inputs_for_train(
                    inputs=inputs,
                    outputs=outputs,
                    new_conversation_offset=new_conversation_offset,
                    max_length=max_length,
                    rotary_type=rotary_type,
                    left_truncate=left_truncate,
                    unidirectional_attention=unidirectional_attention,
                    isolation_position_ids=isolation_position_ids,
                    padding=padding,
                    use_fa2=use_fa2,
                    use_packed=use_packed,
                    use_baichuan_packed=use_baichuan_packed,
                    skip_truncated_turn=skip_truncated_turn,
                    return_attention_mask=return_attention_mask,
                )
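

# ---------------------------------------------------------------------------
# Illustrative end-to-end sketch (not part of the original file; the checkpoint
# path is an assumption):
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("path/to/bailing", trust_remote_code=True)
#   sample = {"messages": [
#       {"role": "HUMAN", "content": "Hello"},
#       {"role": "ASSISTANT", "content": "Hi there!"},
#   ]}
#   batch = tokenizer.build_inputs_for_train(sample, max_length=2048)
#   # -> dict with "input_ids", "position_ids", "attention_mask", "labels"
# ---------------------------------------------------------------------------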