Adorg committed on
Commit
81146ca
·
1 Parent(s): 8337280

Upload modelling_longitudinal.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modelling_longitudinal.py +515 -0
modelling_longitudinal.py ADDED
@@ -0,0 +1,515 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import warnings
3
+ from typing import Any, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import transformers
7
+ from peft import LoraConfig, TaskType, get_peft_config, get_peft_model
8
+ from torch.nn import CrossEntropyLoss
9
+ from transformers import (AutoModel, PreTrainedTokenizerFast,
10
+ VisionEncoderDecoderModel)
11
+ from transformers.configuration_utils import PretrainedConfig
12
+ from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
13
+ from transformers.modeling_utils import PreTrainedModel
14
+ from transformers.models.vision_encoder_decoder.configuration_vision_encoder_decoder import \
15
+ VisionEncoderDecoderConfig
16
+ from transformers.utils import logging
17
+
18
+ logger = logging.get_logger(__name__)
19
+
20
+
21
class CvtWithProjectionHeadConfig(transformers.CvtConfig):
    """CvT configuration extended with the output size of a projection head.

    Args:
        projection_size: Dimensionality the encoder's spatial features are
            projected to (read by ``CvtProjectionHead``). ``None`` leaves it
            unset.
    """

    # `Optional[int]` rather than `int`: the default is None.
    def __init__(self, projection_size: Optional[int] = None, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.projection_size = projection_size
25
+
26
+
27
class ModelOutputWithProjectionEmbedding(transformers.modeling_outputs.ModelOutput):
    """Encoder output: projected per-position embeddings plus their attention mask."""
    # Projected features for every spatial position of every image in the study.
    last_hidden_state: torch.FloatTensor
    # Mask over those positions (padding images masked out).
    attention_mask: torch.FloatTensor
30
+
31
+
32
class CvtProjectionHead(torch.nn.Module):
    """Layer-normalise final CvT features, then project them to ``config.projection_size``."""

    def __init__(self, config) -> None:
        super().__init__()

        hidden_size = config.embed_dim[-1]

        # https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/models/cvt/modeling_cvt.py#L657
        self.layer_norm = torch.nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)

        # No bias as following layer normalisation with bias:
        self.projection = torch.nn.Linear(hidden_size, config.projection_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.projection(self.layer_norm(x))
48
+
49
+
50
class MultiCvtWithProjectionHead(transformers.CvtPreTrainedModel):
    """CvT encoder applied to multiple images per example, with a projection head.

    ``pixel_values`` is expected to be 5-D (batch, images-per-study, C, H, W);
    the first two dimensions are flattened before the backbone runs and the
    projected features of all images in a study are concatenated again.
    """

    def __init__(self, config):
        super().__init__(config)

        self.cvt = transformers.CvtModel(config, add_pooling_layer=False)
        self.projection_head = CvtProjectionHead(config)

        # Initialize weights and apply final processing:
        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ModelOutputWithProjectionEmbedding]:
        """Encode every image of every study and project the features.

        Returns the projection only when ``return_dict`` is falsy, otherwise a
        ``ModelOutputWithProjectionEmbedding`` carrying the attention mask too.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Flatten the batch and study_id dimensions:
        outputs = self.cvt(
            pixel_values.view(-1, *pixel_values.shape[2:]),
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Fix: index instead of attribute access so this also works when the
        # backbone returns a plain tuple (return_dict=False), where
        # `.last_hidden_state` would raise AttributeError. Flattens h x w:
        last_hidden_state = torch.flatten(outputs[0], 2)

        # Project the features for each spatial position to the decoder's hidden size:
        projection = self.projection_head(torch.permute(last_hidden_state, [0, 2, 1]))

        # Concatenate the features for each chest X-ray of the study:
        projection = projection.view(pixel_values.shape[0], -1, projection.shape[-1])

        # Derive the attention mask from the pixel values. An image whose first
        # element is exactly 0.0 is treated as padding — assumes the data
        # pipeline zero-fills missing images; TODO confirm against the caller.
        attention_mask = (pixel_values[:, :, 0, 0, 0] != 0.0).repeat_interleave(last_hidden_state.shape[-1], dim=1)

        if not return_dict:
            return projection

        return ModelOutputWithProjectionEmbedding(
            last_hidden_state=projection, attention_mask=attention_mask,
        )
94
+
95
+
96
class LongitudinalPromptMultiCXREncoderDecoderModel(VisionEncoderDecoderModel):
    """Vision encoder-decoder for longitudinal chest X-ray report generation.

    Pairs a multi-image CvT encoder (``MultiCvtWithProjectionHead``) with a
    BERT language-model decoder, freezes both, and trains only LoRA adapters
    on the decoder's self-attention query/key projections. The previous
    report's findings/impression sections are tokenized as a generation
    prompt (see ``tokenize_prompt``).
    """

    config_class = VisionEncoderDecoderConfig
    base_model_prefix = "vision_encoder_decoder"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
102
+
103
    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        encoder: Optional[PreTrainedModel] = None,
        decoder: Optional[PreTrainedModel] = None,
        encoder_decoder_ckpt_name: Optional[str] = None,
    ):
        """
        Argument/s:
            config - joint encoder-decoder configuration; required when either
                sub-model is not given.
            encoder - pre-built encoder; built from config.encoder when None.
            decoder - pre-built decoder; built from config.decoder when None.
            encoder_decoder_ckpt_name - checkpoint to warm-start the whole
                encoder-decoder from *before* freezing and applying LoRA.
        """

        if decoder:
            assert decoder.config.add_cross_attention, '"add_cross_attention" must be True for the given decoder'
            assert decoder.config.is_decoder, '"is_decoder" must be True for the given decoder'

        if config is None and (encoder is None or decoder is None):
            raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
        if config is None:
            config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
        else:
            if not isinstance(config, self.config_class):
                raise ValueError(f"Config: {config} has to be of type {self.config_class}")

        # Encoder and decoder embeddings are unrelated; never tie them.
        config.tie_word_embeddings = False

        # Initialise with config. PreTrainedModel.__init__ is called directly to
        # bypass VisionEncoderDecoderModel.__init__ and its own encoder/decoder setup.
        PreTrainedModel.__init__(self, config)

        # Encoder:
        if encoder is None:
            encoder = MultiCvtWithProjectionHead(config=config.encoder)

        # Decoder:
        if decoder is None:
            decoder = transformers.BertLMHeadModel(config=config.decoder)

        self.encoder = encoder
        self.decoder = decoder

        if self.encoder.config.to_dict() != self.config.encoder.to_dict():
            logger.warning(
                f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
                f" {self.config.encoder}"
            )
        if self.decoder.config.to_dict() != self.config.decoder.to_dict():
            logger.warning(
                f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
                f" {self.config.decoder}"
            )

        # Point the sub-models at the shared config so later updates stay in sync:
        self.encoder.config = self.config.encoder
        self.decoder.config = self.config.decoder

        # Warm-start from a multi-image checkpoint. This must happen before LoRA
        # wraps the decoder, otherwise the state dict keys would no longer match:
        if encoder_decoder_ckpt_name:
            encoder_decoder = AutoModel.from_pretrained(encoder_decoder_ckpt_name, trust_remote_code=True)
            self.load_state_dict(encoder_decoder.state_dict())
        else:
            warnings.warn('The encoder-to-decoder model was not warm-started before applying low-rank approximation.')

        # Freeze the encoder:
        for p in self.encoder.parameters():
            p.requires_grad = False

        # Freeze the decoder and add LoRA; only the adapters on the BERT
        # self-attention query/key projections remain trainable:
        peft_config = LoraConfig(
            inference_mode=False,
            r=8,
            lora_alpha=32,
            lora_dropout=0.1,
            target_modules='bert.encoder.layer.[0-9]+.attention.self.(query|key)',
        )
        self.decoder = get_peft_model(self.decoder, peft_config)
        self.decoder.print_trainable_parameters()
174
+
175
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        """Encode the images (unless encoder_outputs is given) and run the decoder.

        Extra keyword arguments are routed by prefix: ``decoder_*`` go to the
        decoder (prefix stripped), everything else to the encoder. When
        ``labels`` is given, a token-level cross-entropy loss is computed.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}

        kwargs_decoder = {
            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        if encoder_outputs is None:
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")

            encoder_outputs = self.encoder(
                pixel_values,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )  # CvT does not support output_attentions.
        elif isinstance(encoder_outputs, tuple):
            encoder_outputs = BaseModelOutput(*encoder_outputs)

        encoder_hidden_states = encoder_outputs[0]

        # NOTE(review): `encoder_outputs.attention_mask` assumes the custom
        # ModelOutputWithProjectionEmbedding from MultiCvtWithProjectionHead;
        # a tuple converted to BaseModelOutput above has no such field — the
        # tuple path appears unused in practice. Verify before relying on it.
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_outputs.attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
            **kwargs_decoder,
        )

        # Loss: cross-entropy over the full vocabulary. Labels are expected to
        # already be shifted (see tokenize_report_teacher_forcing).
        loss = None
        if labels is not None:
            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))

        if not return_dict:
            if loss is not None:
                return (loss,) + decoder_outputs + encoder_outputs
            else:
                return decoder_outputs + encoder_outputs

        return Seq2SeqLMOutput(
            loss=loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            # encoder_hidden_states=encoder_outputs.hidden_states,
            # encoder_attentions=encoder_outputs.attentions,
        )
252
+
253
    def prepare_inputs_for_generation(
        self,
        input_ids,
        special_token_ids,
        mask_token_id,
        past_key_values=None,
        attention_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        """
        Build the decoder inputs (ids, masks, token types, positions) for one
        generation step.

        Modification of:
        https://github.com/huggingface/transformers/blob/main/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py#L660
        """

        # An update to generate() now prepends bos_token_id to each sequence if it does not exist at the start of the input:
        # https://github.com/huggingface/transformers/blob/d533465150532b0c5de167b574e59f64c68b1154/src/transformers/generation/utils.py#L699C13-L699C30
        # Hence, we remove the prepended bos_token_id from each sequence if it is there.
        # NOTE(review): the literal 1 assumes bos_token_id == 1 — confirm
        # against the tokenizer in use.
        if torch.all(input_ids[:, 0] == 1):
            input_ids = input_ids[:, 1:]

        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
        # Positions where the mask token sits are ignored by the decoder:
        decoder_attention_mask = (input_ids != mask_token_id).int()
        # Position ids count only attended tokens; relu clamps the leading -1:
        decoder_position_ids = torch.nn.functional.relu(
            torch.cumsum(decoder_attention_mask, dim=1, dtype=torch.int64) - 1
        )

        if not past_key_values:
            # First step: token types for the full prompt + generated prefix.
            token_type_ids = self.token_ids_to_token_type_ids(input_ids, special_token_ids, [0, 1, 0, 1])
        else:
            # Cached steps: only the newest position is fed to the decoder.
            token_type_ids = self.token_ids_to_token_type_ids_past(input_ids, special_token_ids, [0, 1, 0, 1])
            decoder_position_ids = decoder_position_ids[:, -1:]

        input_dict = {
            'attention_mask': attention_mask,
            'decoder_attention_mask': decoder_attention_mask,
            'decoder_input_ids': decoder_inputs['input_ids'],
            'decoder_token_type_ids': token_type_ids,
            'decoder_position_ids': decoder_position_ids,
            'encoder_outputs': encoder_outputs,
            'past_key_values': decoder_inputs['past_key_values'],
            'use_cache': use_cache,
        }
        return input_dict
298
+
299
+ def token_ids_to_token_type_ids(self, token_ids, special_token_ids, token_type_id_sections=None):
300
+ """
301
+ Extract token type identifiers from the token identifiers.
302
+
303
+ Argument/s:
304
+ token_ids - token identifiers.
305
+ special_token_ids - special token identifiers that indicate the separation between sections.
306
+ token_type_id_section - token type identifier for each section.
307
+
308
+ Returns:
309
+ token_type_ids - token type identifiers.
310
+ """
311
+
312
+ token_type_id_sections = token_type_id_sections if token_type_id_sections is not None else list(range(len(special_token_ids) + 1))
313
+
314
+ mbatch_size, seq_len = token_ids.shape
315
+ token_type_ids = torch.full_like(token_ids, token_type_id_sections[0], dtype=torch.long, device=token_ids.device)
316
+
317
+ for i, j in enumerate(special_token_ids):
318
+ # Find first occurrence of special tokens that indicate the boundary between sections:
319
+ cols = (token_ids == j).int().argmax(dim=1)
320
+ rows = torch.arange(mbatch_size, device=token_ids.device)
321
+
322
+ # https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer.create_token_type_ids_from_sequences.example
323
+ cols += 1
324
+
325
+ # Ensure that the column index is not out of bounds. If 0, then token_id not present.
326
+ # This is safe as index 0 is always a special token (now equal to 1 due to +1):
327
+ rows = rows[torch.logical_and(cols != 1, cols < seq_len)]
328
+ cols = cols[torch.logical_and(cols != 1, cols < seq_len)]
329
+
330
+ # Indices to that correspond to the second sequence:
331
+ if rows.nelement() != 0:
332
+ ids = torch.stack([
333
+ torch.stack([x, z]) for (x, y) in zip(rows, cols) for z in torch.arange(
334
+ y, seq_len, device=token_ids.device,
335
+ )
336
+ ])
337
+
338
+ token_type_ids[ids[:, 0], ids[:, 1]] = token_type_id_sections[i + 1]
339
+
340
+ return token_type_ids
341
+
342
+ def token_ids_to_token_type_ids_past(self, token_ids, special_token_ids, token_type_id_sections=None):
343
+ """
344
+ Extract token type identifiers from the token identifiers if past != None.
345
+
346
+ Argument/s:
347
+ token_ids - token identifiers.
348
+ special_token_ids - special token identifiers that indicate the separation between sections.
349
+
350
+ Returns:
351
+ token_type_ids - token type identifiers.
352
+ """
353
+
354
+ token_type_id_sections = token_type_id_sections if token_type_id_sections is not None else list(range(len(special_token_ids) + 1))
355
+ token_type_ids = torch.full([token_ids.shape[0], 1], token_type_id_sections[0], dtype=torch.long, device=token_ids.device)
356
+
357
+ # https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer.create_token_type_ids_from_sequences.example
358
+ token_ids = token_ids[:, :-1]
359
+
360
+ for i, j in enumerate(special_token_ids):
361
+
362
+ # Find first occurrence of special token, which indicates the boundary between sections:
363
+ exists = torch.any(token_ids == j, dim=1, keepdim=True)
364
+ token_type_ids[exists] = token_type_id_sections[i + 1]
365
+
366
+ return token_type_ids
367
+
368
+ def tokenize_report_teacher_forcing(self, findings: str, impression: str, tokenizer: PreTrainedTokenizerFast, max_len: int):
369
+ """
370
+ Tokenize the reports and creates the inputs and targets for teacher forcing.
371
+
372
+ Argument/s:
373
+ findings - findings section.
374
+ impression - impression section.
375
+ return_token_type_ids - return the token type identifiers.
376
+ tokenizer - Hugging Face tokenizer.
377
+ max_len - maximum number of tokens.
378
+
379
+ Returns:
380
+ decoder_input_ids - the token identifiers for the input of the decoder.
381
+ decoder_attention_mask - the attention mask for the decoder_input_ids.
382
+ label_ids - the label token identifiers for the decoder.
383
+ """
384
+
385
+ # Prepare the sections for the tokenizer by placing special tokens between each section:
386
+ report = [f'{tokenizer.bos_token}{i}{tokenizer.sep_token}{j}{tokenizer.eos_token}' for i, j in
387
+ zip(findings, impression)]
388
+
389
+ # Tokenize the report:
390
+ tokenized = tokenizer(
391
+ report,
392
+ padding='longest',
393
+ truncation=True,
394
+ max_length=max_len + 1, # +1 to account for the bias between input and target.
395
+ return_tensors='pt',
396
+ return_token_type_ids=False,
397
+ add_special_tokens=False,
398
+ ).to(self.device)
399
+
400
+ # Modify for language modelling:
401
+ batch_dict = {
402
+
403
+ # Labels for the decoder (shifted right by one for autoregression):
404
+ 'label_ids': tokenized['input_ids'][:, 1:].detach().clone(),
405
+
406
+ # Remove last token identifier to match the sequence length of the labels:
407
+ 'decoder_input_ids': tokenized['input_ids'][:, :-1],
408
+
409
+ # Attention mask for the decoder_input_ids (remove first token so that the eos_token_id is not considered):
410
+ 'decoder_attention_mask': tokenized['attention_mask'][:, 1:],
411
+ }
412
+
413
+ return batch_dict
414
+
415
+ def split_and_decode_sections(self, token_ids, special_token_ids, tokenizer: PreTrainedTokenizerFast):
416
+ """
417
+ Split the token identifiers into sections, then convert the token identifiers into strings.
418
+
419
+ Argument/s:
420
+ token_ids - token identifiers.
421
+ special_token_ids - special token identifiers that indicate the end of each section.
422
+ tokenizer - Hugging Face tokenizer.
423
+
424
+ Returns:
425
+ token_type_ids - token type identifiers.
426
+ """
427
+
428
+ _, seq_len = token_ids.shape
429
+
430
+ # The number of sections is the same as the number of special_token_ids:
431
+ num_sections = len(special_token_ids)
432
+
433
+ sections = {k: [] for k in range(num_sections)}
434
+
435
+ for i in token_ids:
436
+ prev_col = 0
437
+ for j, k in enumerate(special_token_ids):
438
+
439
+ # The maximum sequence length was exceeded, thus no more tokens:
440
+ if prev_col >= seq_len:
441
+ sections[j].append('')
442
+ continue
443
+
444
+ # Find first occurrence of special tokens that indicate the boundary between sections:
445
+ col = (i == k).int().argmax().item()
446
+
447
+ # If equal to 0, token was not found, set the column to the sequence length (as the decoder exceeded
448
+ # the maximum sequence length):
449
+ if col == 0:
450
+ col = seq_len
451
+
452
+ # Extract section token identifiers:
453
+ section_token_ids = i[prev_col:col]
454
+ prev_col = col
455
+ section_string = tokenizer.decode(section_token_ids, skip_special_tokens=True)
456
+
457
+ sections[j].append(section_string)
458
+
459
+ return tuple(sections.values())
460
+
461
    def tokenize_prompt(
        self,
        previous_findings: str,
        previous_impression: str,
        tokenizer: PreTrainedTokenizerFast,
        max_len: int,
        add_bos_token_id: bool = False,
    ):
        """
        Tokenize the sections of the previous report to be used as a prompt.

        Argument/s:
            previous_findings - previous findings section.
            previous_impression - previous impression section.
            tokenizer - Hugging Face tokenizer.
            max_len - maximum number of tokens.
            add_bos_token_id - whether to add the BOS token identifier to the prompt.

        Returns:
            input_ids - the input identifiers for the previous impression.
            attention_mask - the attention mask for the previous impression
        """

        # Use [NPF]/[NPI] special token if no previous findings/impression:
        previous_findings = ['[NPF]' if not i else i for i in previous_findings]
        previous_impression = ['[NPI]' if not i else i for i in previous_impression]

        # Prepare the sections for the tokenizer by placing special tokens;
        # the BOS token marks the end of the prompt when requested:
        previous_sections = [
            f'[PMT]{i}[PMT-SEP]{j}{tokenizer.bos_token}' if add_bos_token_id else f'[PMT]{i}[PMT-SEP]{j}' \
            for i, j in zip(previous_findings, previous_impression)
        ]

        # Tokenize:
        previous_sections = tokenizer(
            previous_sections,
            padding='longest',
            truncation=True,
            max_length=max_len,
            return_tensors='pt',
            return_token_type_ids=False,
            add_special_tokens=False,
        ).to(self.device)

        # If truncation hit max_len, the BOS token may have been cut off;
        # overwrite the final non-padding token with BOS so generation still
        # starts from it:
        if previous_sections.input_ids.shape[1] == max_len:
            previous_sections.input_ids[:, -1] = torch.where(
                previous_sections.attention_mask[:, -1] == 1,
                tokenizer.bos_token_id,
                previous_sections.input_ids[:, -1],
            )

        assert previous_sections.input_ids.shape[1] <= max_len

        return {'input_ids': previous_sections.input_ids, 'attention_mask': previous_sections.attention_mask}