koichi12 commited on
Commit
9d4bc92
·
verified ·
1 Parent(s): 7754566

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. .venv/lib/python3.11/site-packages/pkg_resources/_vendor/more_itertools/__pycache__/more.cpython-311.pyc +3 -0
  3. .venv/lib/python3.11/site-packages/transformers/data/__init__.py +45 -0
  4. .venv/lib/python3.11/site-packages/transformers/data/__pycache__/__init__.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/transformers/data/__pycache__/data_collator.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/transformers/data/data_collator.py +1656 -0
  7. .venv/lib/python3.11/site-packages/transformers/data/datasets/__init__.py +23 -0
  8. .venv/lib/python3.11/site-packages/transformers/data/datasets/__pycache__/__init__.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/transformers/data/datasets/__pycache__/glue.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/transformers/data/datasets/__pycache__/language_modeling.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/transformers/data/datasets/__pycache__/squad.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/transformers/data/datasets/glue.py +161 -0
  13. .venv/lib/python3.11/site-packages/transformers/data/datasets/language_modeling.py +530 -0
  14. .venv/lib/python3.11/site-packages/transformers/data/datasets/squad.py +229 -0
  15. .venv/lib/python3.11/site-packages/transformers/data/metrics/__init__.py +98 -0
  16. .venv/lib/python3.11/site-packages/transformers/data/metrics/__pycache__/__init__.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/transformers/data/metrics/__pycache__/squad_metrics.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/transformers/data/metrics/squad_metrics.py +779 -0
  19. .venv/lib/python3.11/site-packages/transformers/data/processors/__init__.py +18 -0
  20. .venv/lib/python3.11/site-packages/transformers/data/processors/__pycache__/__init__.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/transformers/data/processors/__pycache__/glue.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/transformers/data/processors/__pycache__/squad.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/transformers/data/processors/__pycache__/utils.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/transformers/data/processors/__pycache__/xnli.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/transformers/data/processors/glue.py +643 -0
  26. .venv/lib/python3.11/site-packages/transformers/data/processors/squad.py +845 -0
  27. .venv/lib/python3.11/site-packages/transformers/data/processors/utils.py +349 -0
  28. .venv/lib/python3.11/site-packages/transformers/data/processors/xnli.py +96 -0
  29. .venv/lib/python3.11/site-packages/transformers/generation/__init__.py +352 -0
  30. .venv/lib/python3.11/site-packages/transformers/generation/__pycache__/candidate_generator.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/transformers/generation/__pycache__/configuration_utils.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/transformers/generation/__pycache__/flax_logits_process.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/transformers/generation/__pycache__/flax_utils.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/transformers/generation/__pycache__/stopping_criteria.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/transformers/generation/__pycache__/streamers.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/transformers/generation/__pycache__/tf_logits_process.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/transformers/generation/__pycache__/watermarking.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/transformers/generation/beam_constraints.py +524 -0
  39. .venv/lib/python3.11/site-packages/transformers/generation/beam_search.py +1013 -0
  40. .venv/lib/python3.11/site-packages/transformers/generation/candidate_generator.py +871 -0
  41. .venv/lib/python3.11/site-packages/transformers/generation/configuration_utils.py +1628 -0
  42. .venv/lib/python3.11/site-packages/transformers/generation/flax_logits_process.py +544 -0
  43. .venv/lib/python3.11/site-packages/transformers/generation/flax_utils.py +1027 -0
  44. .venv/lib/python3.11/site-packages/transformers/generation/logits_process.py +0 -0
  45. .venv/lib/python3.11/site-packages/transformers/generation/stopping_criteria.py +514 -0
  46. .venv/lib/python3.11/site-packages/transformers/generation/streamers.py +318 -0
  47. .venv/lib/python3.11/site-packages/transformers/generation/tf_logits_process.py +603 -0
  48. .venv/lib/python3.11/site-packages/transformers/generation/tf_utils.py +0 -0
  49. .venv/lib/python3.11/site-packages/transformers/generation/utils.py +0 -0
  50. .venv/lib/python3.11/site-packages/transformers/generation/watermarking.py +549 -0
.gitattributes CHANGED
@@ -421,3 +421,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/
421
  .venv/lib/python3.11/site-packages/idna/__pycache__/idnadata.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
422
  .venv/lib/python3.11/site-packages/pkg_resources/_vendor/pyparsing/__pycache__/core.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
423
  .venv/lib/python3.11/site-packages/pkg_resources/__pycache__/__init__.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
 
 
421
  .venv/lib/python3.11/site-packages/idna/__pycache__/idnadata.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
422
  .venv/lib/python3.11/site-packages/pkg_resources/_vendor/pyparsing/__pycache__/core.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
423
  .venv/lib/python3.11/site-packages/pkg_resources/__pycache__/__init__.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
424
+ .venv/lib/python3.11/site-packages/pkg_resources/_vendor/more_itertools/__pycache__/more.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/pkg_resources/_vendor/more_itertools/__pycache__/more.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c2915b1a6eb03bcc7f50087391a236782ed9422d5e5eb9ed0a937ca85ce5e73a
3
+ size 167956
.venv/lib/python3.11/site-packages/transformers/data/__init__.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .data_collator import (
16
+ DataCollatorForLanguageModeling,
17
+ DataCollatorForPermutationLanguageModeling,
18
+ DataCollatorForSeq2Seq,
19
+ DataCollatorForSOP,
20
+ DataCollatorForTokenClassification,
21
+ DataCollatorForWholeWordMask,
22
+ DataCollatorWithFlattening,
23
+ DataCollatorWithPadding,
24
+ DefaultDataCollator,
25
+ default_data_collator,
26
+ )
27
+ from .metrics import glue_compute_metrics, xnli_compute_metrics
28
+ from .processors import (
29
+ DataProcessor,
30
+ InputExample,
31
+ InputFeatures,
32
+ SingleSentenceClassificationProcessor,
33
+ SquadExample,
34
+ SquadFeatures,
35
+ SquadV1Processor,
36
+ SquadV2Processor,
37
+ glue_convert_examples_to_features,
38
+ glue_output_modes,
39
+ glue_processors,
40
+ glue_tasks_num_labels,
41
+ squad_convert_examples_to_features,
42
+ xnli_output_modes,
43
+ xnli_processors,
44
+ xnli_tasks_num_labels,
45
+ )
.venv/lib/python3.11/site-packages/transformers/data/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.5 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/data/__pycache__/data_collator.cpython-311.pyc ADDED
Binary file (94.5 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/data/data_collator.py ADDED
@@ -0,0 +1,1656 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import random
16
+ import warnings
17
+ from collections.abc import Mapping
18
+ from dataclasses import dataclass
19
+ from random import randint
20
+ from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+
24
+ from ..models.bert import BertTokenizer, BertTokenizerFast
25
+ from ..tokenization_utils_base import PreTrainedTokenizerBase
26
+ from ..utils import PaddingStrategy
27
+
28
+
29
+ InputDataClass = NewType("InputDataClass", Any)
30
+
31
+ """
32
+ A DataCollator is a function that takes a list of samples from a Dataset and collate them into a batch, as a dictionary
33
+ of PyTorch/TensorFlow tensors or NumPy arrays.
34
+ """
35
+ DataCollator = NewType("DataCollator", Callable[[List[InputDataClass]], Dict[str, Any]])
36
+
37
+
38
+ class DataCollatorMixin:
39
+ def __call__(self, features, return_tensors=None):
40
+ if return_tensors is None:
41
+ return_tensors = self.return_tensors
42
+ if return_tensors == "tf":
43
+ return self.tf_call(features)
44
+ elif return_tensors == "pt":
45
+ return self.torch_call(features)
46
+ elif return_tensors == "np":
47
+ return self.numpy_call(features)
48
+ else:
49
+ raise ValueError(f"Framework '{return_tensors}' not recognized!")
50
+
51
+
52
+ def pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs):
53
+ """
54
+ Pads without triggering the warning about how using the pad function is sub-optimal when using a fast tokenizer.
55
+ """
56
+
57
+ # To avoid errors when using Feature extractors
58
+ if not hasattr(tokenizer, "deprecation_warnings"):
59
+ return tokenizer.pad(*pad_args, **pad_kwargs)
60
+
61
+ # Save the state of the warning, then disable it
62
+ warning_state = tokenizer.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False)
63
+ tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
64
+
65
+ try:
66
+ padded = tokenizer.pad(*pad_args, **pad_kwargs)
67
+ finally:
68
+ # Restore the state of the warning.
69
+ tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = warning_state
70
+
71
+ return padded
72
+
73
+
74
+ def default_data_collator(features: List[InputDataClass], return_tensors="pt") -> Dict[str, Any]:
75
+ """
76
+ Very simple data collator that simply collates batches of dict-like objects and performs special handling for
77
+ potential keys named:
78
+
79
+ - `label`: handles a single value (int or float) per object
80
+ - `label_ids`: handles a list of values per object
81
+
82
+ Does not do any additional preprocessing: property names of the input object will be used as corresponding inputs
83
+ to the model. See glue and ner for example of how it's useful.
84
+ """
85
+
86
+ # In this function we'll make the assumption that all `features` in the batch
87
+ # have the same attributes.
88
+ # So we will look at the first element as a proxy for what attributes exist
89
+ # on the whole batch.
90
+
91
+ if return_tensors == "pt":
92
+ return torch_default_data_collator(features)
93
+ elif return_tensors == "tf":
94
+ return tf_default_data_collator(features)
95
+ elif return_tensors == "np":
96
+ return numpy_default_data_collator(features)
97
+
98
+
99
+ @dataclass
100
+ class DefaultDataCollator(DataCollatorMixin):
101
+ """
102
+ Very simple data collator that simply collates batches of dict-like objects and performs special handling for
103
+ potential keys named:
104
+
105
+ - `label`: handles a single value (int or float) per object
106
+ - `label_ids`: handles a list of values per object
107
+
108
+ Does not do any additional preprocessing: property names of the input object will be used as corresponding inputs
109
+ to the model. See glue and ner for example of how it's useful.
110
+
111
+ This is an object (like other data collators) rather than a pure function like default_data_collator. This can be
112
+ helpful if you need to set a return_tensors value at initialization.
113
+
114
+ Args:
115
+ return_tensors (`str`, *optional*, defaults to `"pt"`):
116
+ The type of Tensor to return. Allowable values are "np", "pt" and "tf".
117
+ """
118
+
119
+ return_tensors: str = "pt"
120
+
121
+ def __call__(self, features: List[Dict[str, Any]], return_tensors=None) -> Dict[str, Any]:
122
+ if return_tensors is None:
123
+ return_tensors = self.return_tensors
124
+ return default_data_collator(features, return_tensors)
125
+
126
+
127
+ def torch_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
128
+ import torch
129
+
130
+ if not isinstance(features[0], Mapping):
131
+ features = [vars(f) for f in features]
132
+ first = features[0]
133
+ batch = {}
134
+
135
+ # Special handling for labels.
136
+ # Ensure that tensor is created with the correct type
137
+ # (it should be automatically the case, but let's make sure of it.)
138
+ if "label" in first and first["label"] is not None:
139
+ label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
140
+ dtype = torch.long if isinstance(label, int) else torch.float
141
+ batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
142
+ elif "label_ids" in first and first["label_ids"] is not None:
143
+ if isinstance(first["label_ids"], torch.Tensor):
144
+ batch["labels"] = torch.stack([f["label_ids"] for f in features])
145
+ else:
146
+ dtype = torch.long if isinstance(first["label_ids"][0], int) else torch.float
147
+ batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)
148
+
149
+ # Handling of all other possible keys.
150
+ # Again, we will use the first element to figure out which key/values are not None for this model.
151
+ for k, v in first.items():
152
+ if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
153
+ if isinstance(v, torch.Tensor):
154
+ batch[k] = torch.stack([f[k] for f in features])
155
+ elif isinstance(v, np.ndarray):
156
+ batch[k] = torch.from_numpy(np.stack([f[k] for f in features]))
157
+ else:
158
+ batch[k] = torch.tensor([f[k] for f in features])
159
+
160
+ return batch
161
+
162
+
163
+ def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
164
+ import tensorflow as tf
165
+
166
+ if not isinstance(features[0], Mapping):
167
+ features = [vars(f) for f in features]
168
+ first = features[0]
169
+ batch = {}
170
+
171
+ # Special handling for labels.
172
+ # Ensure that tensor is created with the correct type
173
+ # (it should be automatically the case, but let's make sure of it.)
174
+ if "label" in first and first["label"] is not None:
175
+ label_col_name = "label"
176
+ elif "label_ids" in first and first["label_ids"] is not None:
177
+ label_col_name = "label_ids"
178
+ elif "labels" in first and first["labels"] is not None:
179
+ label_col_name = "labels"
180
+ else:
181
+ label_col_name = None
182
+ if label_col_name is not None:
183
+ if isinstance(first[label_col_name], tf.Tensor):
184
+ dtype = tf.int64 if first[label_col_name].dtype.is_integer else tf.float32
185
+ elif isinstance(first[label_col_name], np.ndarray) or isinstance(first[label_col_name], np.generic):
186
+ dtype = tf.int64 if np.issubdtype(first[label_col_name].dtype, np.integer) else tf.float32
187
+ elif isinstance(first[label_col_name], (tuple, list)):
188
+ dtype = tf.int64 if isinstance(first[label_col_name][0], int) else tf.float32
189
+ else:
190
+ dtype = tf.int64 if isinstance(first[label_col_name], int) else tf.float32
191
+ batch["labels"] = tf.convert_to_tensor([f[label_col_name] for f in features], dtype=dtype)
192
+ # Handling of all other possible keys.
193
+ # Again, we will use the first element to figure out which key/values are not None for this model.
194
+ for k, v in first.items():
195
+ if k not in ("label", "label_ids", "labels") and v is not None and not isinstance(v, str):
196
+ if isinstance(v, (tf.Tensor, np.ndarray)):
197
+ batch[k] = tf.stack([f[k] for f in features])
198
+ else:
199
+ batch[k] = tf.convert_to_tensor([f[k] for f in features])
200
+
201
+ return batch
202
+
203
+
204
+ def numpy_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
205
+ if not isinstance(features[0], Mapping):
206
+ features = [vars(f) for f in features]
207
+ first = features[0]
208
+ batch = {}
209
+
210
+ # Special handling for labels.
211
+ # Ensure that tensor is created with the correct type
212
+ # (it should be automatically the case, but let's make sure of it.)
213
+ if "label" in first and first["label"] is not None:
214
+ label = first["label"].item() if isinstance(first["label"], np.ndarray) else first["label"]
215
+ dtype = np.int64 if isinstance(label, int) else np.float32
216
+ batch["labels"] = np.array([f["label"] for f in features], dtype=dtype)
217
+ elif "label_ids" in first and first["label_ids"] is not None:
218
+ if isinstance(first["label_ids"], np.ndarray):
219
+ batch["labels"] = np.stack([f["label_ids"] for f in features])
220
+ else:
221
+ dtype = np.int64 if isinstance(first["label_ids"][0], int) else np.float32
222
+ batch["labels"] = np.array([f["label_ids"] for f in features], dtype=dtype)
223
+
224
+ # Handling of all other possible keys.
225
+ # Again, we will use the first element to figure out which key/values are not None for this model.
226
+ for k, v in first.items():
227
+ if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
228
+ if isinstance(v, np.ndarray):
229
+ batch[k] = np.stack([f[k] for f in features])
230
+ else:
231
+ batch[k] = np.array([f[k] for f in features])
232
+
233
+ return batch
234
+
235
+
236
+ @dataclass
237
+ class DataCollatorWithPadding:
238
+ """
239
+ Data collator that will dynamically pad the inputs received.
240
+
241
+ Args:
242
+ tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
243
+ The tokenizer used for encoding the data.
244
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
245
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
246
+ among:
247
+
248
+ - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
249
+ sequence is provided).
250
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
251
+ acceptable input length for the model if that argument is not provided.
252
+ - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
253
+ max_length (`int`, *optional*):
254
+ Maximum length of the returned list and optionally padding length (see above).
255
+ pad_to_multiple_of (`int`, *optional*):
256
+ If set will pad the sequence to a multiple of the provided value.
257
+
258
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
259
+ 7.0 (Volta).
260
+ return_tensors (`str`, *optional*, defaults to `"pt"`):
261
+ The type of Tensor to return. Allowable values are "np", "pt" and "tf".
262
+ """
263
+
264
+ tokenizer: PreTrainedTokenizerBase
265
+ padding: Union[bool, str, PaddingStrategy] = True
266
+ max_length: Optional[int] = None
267
+ pad_to_multiple_of: Optional[int] = None
268
+ return_tensors: str = "pt"
269
+
270
+ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
271
+ batch = pad_without_fast_tokenizer_warning(
272
+ self.tokenizer,
273
+ features,
274
+ padding=self.padding,
275
+ max_length=self.max_length,
276
+ pad_to_multiple_of=self.pad_to_multiple_of,
277
+ return_tensors=self.return_tensors,
278
+ )
279
+ if "label" in batch:
280
+ batch["labels"] = batch["label"]
281
+ del batch["label"]
282
+ if "label_ids" in batch:
283
+ batch["labels"] = batch["label_ids"]
284
+ del batch["label_ids"]
285
+ return batch
286
+
287
+
288
+ @dataclass
289
+ class DataCollatorForTokenClassification(DataCollatorMixin):
290
+ """
291
+ Data collator that will dynamically pad the inputs received, as well as the labels.
292
+
293
+ Args:
294
+ tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
295
+ The tokenizer used for encoding the data.
296
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
297
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
298
+ among:
299
+
300
+ - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
301
+ sequence is provided).
302
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
303
+ acceptable input length for the model if that argument is not provided.
304
+ - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
305
+ max_length (`int`, *optional*):
306
+ Maximum length of the returned list and optionally padding length (see above).
307
+ pad_to_multiple_of (`int`, *optional*):
308
+ If set will pad the sequence to a multiple of the provided value.
309
+
310
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
311
+ 7.0 (Volta).
312
+ label_pad_token_id (`int`, *optional*, defaults to -100):
313
+ The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
314
+ return_tensors (`str`, *optional*, defaults to `"pt"`):
315
+ The type of Tensor to return. Allowable values are "np", "pt" and "tf".
316
+ """
317
+
318
+ tokenizer: PreTrainedTokenizerBase
319
+ padding: Union[bool, str, PaddingStrategy] = True
320
+ max_length: Optional[int] = None
321
+ pad_to_multiple_of: Optional[int] = None
322
+ label_pad_token_id: int = -100
323
+ return_tensors: str = "pt"
324
+
325
+ def torch_call(self, features):
326
+ import torch
327
+
328
+ label_name = "label" if "label" in features[0].keys() else "labels"
329
+ labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
330
+
331
+ no_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]
332
+
333
+ batch = pad_without_fast_tokenizer_warning(
334
+ self.tokenizer,
335
+ no_labels_features,
336
+ padding=self.padding,
337
+ max_length=self.max_length,
338
+ pad_to_multiple_of=self.pad_to_multiple_of,
339
+ return_tensors="pt",
340
+ )
341
+
342
+ if labels is None:
343
+ return batch
344
+
345
+ sequence_length = batch["input_ids"].shape[1]
346
+ padding_side = self.tokenizer.padding_side
347
+
348
+ def to_list(tensor_or_iterable):
349
+ if isinstance(tensor_or_iterable, torch.Tensor):
350
+ return tensor_or_iterable.tolist()
351
+ return list(tensor_or_iterable)
352
+
353
+ if padding_side == "right":
354
+ batch[label_name] = [
355
+ to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
356
+ ]
357
+ else:
358
+ batch[label_name] = [
359
+ [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
360
+ ]
361
+
362
+ batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
363
+ return batch
364
+
365
+ def tf_call(self, features):
366
+ import tensorflow as tf
367
+
368
+ label_name = "label" if "label" in features[0].keys() else "labels"
369
+ labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
370
+ batch = pad_without_fast_tokenizer_warning(
371
+ self.tokenizer,
372
+ features,
373
+ padding=self.padding,
374
+ max_length=self.max_length,
375
+ pad_to_multiple_of=self.pad_to_multiple_of,
376
+ # Conversion to tensors will fail if we have labels as they are not of the same length yet.
377
+ return_tensors="tf" if labels is None else None,
378
+ )
379
+
380
+ if labels is None:
381
+ return batch
382
+
383
+ sequence_length = tf.convert_to_tensor(batch["input_ids"]).shape[1]
384
+ padding_side = self.tokenizer.padding_side
385
+ if padding_side == "right":
386
+ batch["labels"] = [
387
+ list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
388
+ ]
389
+ else:
390
+ batch["labels"] = [
391
+ [self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels
392
+ ]
393
+
394
+ batch = {k: tf.convert_to_tensor(v, dtype=tf.int64) for k, v in batch.items()}
395
+ return batch
396
+
397
+ def numpy_call(self, features):
398
+ label_name = "label" if "label" in features[0].keys() else "labels"
399
+ labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
400
+ batch = pad_without_fast_tokenizer_warning(
401
+ self.tokenizer,
402
+ features,
403
+ padding=self.padding,
404
+ max_length=self.max_length,
405
+ pad_to_multiple_of=self.pad_to_multiple_of,
406
+ # Conversion to tensors will fail if we have labels as they are not of the same length yet.
407
+ return_tensors="np" if labels is None else None,
408
+ )
409
+
410
+ if labels is None:
411
+ return batch
412
+
413
+ sequence_length = np.array(batch["input_ids"]).shape[1]
414
+ padding_side = self.tokenizer.padding_side
415
+ if padding_side == "right":
416
+ batch["labels"] = [
417
+ list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
418
+ ]
419
+ else:
420
+ batch["labels"] = [
421
+ [self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels
422
+ ]
423
+
424
+ batch = {k: np.array(v, dtype=np.int64) for k, v in batch.items()}
425
+ return batch
426
+
427
+
428
+ def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
429
+ """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
430
+ import torch
431
+
432
+ # Tensorize if necessary.
433
+ if isinstance(examples[0], (list, tuple, np.ndarray)):
434
+ examples = [torch.tensor(e, dtype=torch.long) for e in examples]
435
+
436
+ length_of_first = examples[0].size(0)
437
+
438
+ # Check if padding is necessary.
439
+
440
+ are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
441
+ if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
442
+ if not isinstance(examples, torch.Tensor):
443
+ return torch.stack(examples, dim=0)
444
+
445
+ # If yes, check if we have a `pad_token`.
446
+ if tokenizer.pad_token is None:
447
+ raise ValueError(
448
+ "You are attempting to pad samples but the tokenizer you are using"
449
+ f" ({tokenizer.__class__.__name__}) does not have a pad token."
450
+ )
451
+
452
+ # Creating the full tensor and filling it with our data.
453
+ max_length = max(x.size(0) for x in examples)
454
+ if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
455
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
456
+ result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
457
+ for i, example in enumerate(examples):
458
+ if tokenizer.padding_side == "right":
459
+ result[i, : example.shape[0]] = example
460
+ else:
461
+ result[i, -example.shape[0] :] = example
462
+ return result
463
+
464
+
465
+ def _tf_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
466
+ import tensorflow as tf
467
+
468
+ """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
469
+ # Tensorize if necessary.
470
+ if isinstance(examples[0], (list, tuple)):
471
+ examples = [tf.convert_to_tensor(e, dtype=tf.int64) for e in examples]
472
+
473
+ # Check if padding is necessary.
474
+ length_of_first = len(examples[0])
475
+ are_tensors_same_length = all(len(x) == length_of_first for x in examples)
476
+ if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
477
+ return tf.stack(examples, axis=0)
478
+
479
+ # If yes, check if we have a `pad_token`.
480
+ if tokenizer.pad_token is None:
481
+ raise ValueError(
482
+ "You are attempting to pad samples but the tokenizer you are using"
483
+ f" ({tokenizer.__class__.__name__}) does not have a pad token."
484
+ )
485
+
486
+ # Creating the full tensor and filling it with our data.
487
+ max_length = max(len(x) for x in examples)
488
+ if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
489
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
490
+ # result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
491
+ result = []
492
+ rank = tf.rank(examples[0])
493
+ paddings = np.zeros((rank, 2), dtype=np.int32)
494
+ for example in examples:
495
+ if tokenizer.padding_side == "right":
496
+ paddings[0, 1] = max_length - len(example)
497
+ else:
498
+ paddings[0, 0] = max_length - len(example)
499
+ result.append(tf.pad(example, paddings, constant_values=tokenizer.pad_token_id))
500
+ return tf.stack(result, axis=0)
501
+
502
+
503
+ def _numpy_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
504
+ """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
505
+ # Tensorize if necessary.
506
+ if isinstance(examples[0], (list, tuple)):
507
+ examples = [np.array(e, dtype=np.int64) for e in examples]
508
+
509
+ # Check if padding is necessary.
510
+ length_of_first = len(examples[0])
511
+ are_tensors_same_length = all(len(x) == length_of_first for x in examples)
512
+ if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
513
+ return np.stack(examples, axis=0)
514
+
515
+ # If yes, check if we have a `pad_token`.
516
+ if tokenizer.pad_token is None:
517
+ raise ValueError(
518
+ "You are attempting to pad samples but the tokenizer you are using"
519
+ f" ({tokenizer.__class__.__name__}) does not have a pad token."
520
+ )
521
+
522
+ # Creating the full tensor and filling it with our data.
523
+ max_length = max(len(x) for x in examples)
524
+ if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
525
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
526
+ result = np.full(shape=(len(examples), max_length), fill_value=tokenizer.pad_token_id, dtype=examples[0].dtype)
527
+ for i, example in enumerate(examples):
528
+ if tokenizer.padding_side == "right":
529
+ result[i, : example.shape[0]] = example
530
+ else:
531
+ result[i, -example.shape[0] :] = example
532
+ return result
533
+
534
+
535
+ def tolist(x):
536
+ if isinstance(x, list):
537
+ return x
538
+ elif hasattr(x, "numpy"): # Checks for TF tensors without needing the import
539
+ x = x.numpy()
540
+ return x.tolist()
541
+
542
+
543
+ @dataclass
544
+ class DataCollatorForSeq2Seq:
545
+ """
546
+ Data collator that will dynamically pad the inputs received, as well as the labels.
547
+
548
+ Args:
549
+ tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
550
+ The tokenizer used for encoding the data.
551
+ model ([`PreTrainedModel`], *optional*):
552
+ The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to
553
+ prepare the *decoder_input_ids*
554
+
555
+ This is useful when using *label_smoothing* to avoid calculating loss twice.
556
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
557
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
558
+ among:
559
+
560
+ - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
561
+ sequence is provided).
562
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
563
+ acceptable input length for the model if that argument is not provided.
564
+ - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
565
+ max_length (`int`, *optional*):
566
+ Maximum length of the returned list and optionally padding length (see above).
567
+ pad_to_multiple_of (`int`, *optional*):
568
+ If set will pad the sequence to a multiple of the provided value.
569
+
570
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
571
+ 7.0 (Volta).
572
+ label_pad_token_id (`int`, *optional*, defaults to -100):
573
+ The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
574
+ return_tensors (`str`, *optional*, defaults to `"pt"`):
575
+ The type of Tensor to return. Allowable values are "np", "pt" and "tf".
576
+ """
577
+
578
+ tokenizer: PreTrainedTokenizerBase
579
+ model: Optional[Any] = None
580
+ padding: Union[bool, str, PaddingStrategy] = True
581
+ max_length: Optional[int] = None
582
+ pad_to_multiple_of: Optional[int] = None
583
+ label_pad_token_id: int = -100
584
+ return_tensors: str = "pt"
585
+
586
+ def __call__(self, features, return_tensors=None):
587
+ if return_tensors is None:
588
+ return_tensors = self.return_tensors
589
+
590
+ label_name = "label" if "label" in features[0].keys() else "labels"
591
+ labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
592
+ # reconvert list[None] to None if necessary
593
+ # this might occur when we pass {..., "labels": None}
594
+ if labels is not None and all(label is None for label in labels):
595
+ labels = None
596
+ non_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]
597
+
598
+ # run through tokenizer without labels to ensure no side effects
599
+ batch = pad_without_fast_tokenizer_warning(
600
+ self.tokenizer,
601
+ non_labels_features,
602
+ padding=self.padding,
603
+ max_length=self.max_length,
604
+ pad_to_multiple_of=self.pad_to_multiple_of,
605
+ return_tensors=return_tensors,
606
+ )
607
+
608
+ # we have to pad the labels manually as we cannot rely on `tokenizer.pad` and we need them to be of the same length to return tensors
609
+ no_padding = self.padding is False or self.padding == PaddingStrategy.DO_NOT_PAD
610
+ if labels is not None:
611
+ if no_padding:
612
+ if isinstance(features[0][label_name], list):
613
+ batch["labels"] = list(labels)
614
+ else:
615
+ batch["labels"] = [np.concatenate([label, []]) for label in labels]
616
+ else:
617
+ max_padding = self.padding == PaddingStrategy.MAX_LENGTH and self.max_length is not None
618
+ max_label_length = max(len(l) for l in labels) if not max_padding else self.max_length
619
+ if self.pad_to_multiple_of is not None:
620
+ max_label_length = (
621
+ (max_label_length + self.pad_to_multiple_of - 1)
622
+ // self.pad_to_multiple_of
623
+ * self.pad_to_multiple_of
624
+ )
625
+
626
+ padding_side = self.tokenizer.padding_side
627
+ if isinstance(features[0][label_name], list):
628
+ batch["labels"] = [
629
+ label + [self.label_pad_token_id] * (max_label_length - len(label))
630
+ if padding_side == "right"
631
+ else [self.label_pad_token_id] * (max_label_length - len(label)) + label
632
+ for label in labels
633
+ ]
634
+ else:
635
+ batch["labels"] = [
636
+ np.concatenate(
637
+ [
638
+ label,
639
+ np.array([self.label_pad_token_id] * (max_label_length - len(label)), dtype=np.int64),
640
+ ]
641
+ )
642
+ if padding_side == "right"
643
+ else np.concatenate(
644
+ [
645
+ np.array([self.label_pad_token_id] * (max_label_length - len(label)), dtype=np.int64),
646
+ label,
647
+ ]
648
+ )
649
+ for label in labels
650
+ ]
651
+
652
+ # reintroduce side effects via tokenizer that return respective datatypes for the `return_tensors` argument
653
+ if batch.get("labels", None) is not None:
654
+ if return_tensors == "pt":
655
+ import torch
656
+
657
+ batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)
658
+ elif return_tensors == "tf":
659
+ import tensorflow as tf
660
+
661
+ batch["labels"] = tf.constant(batch["labels"], dtype=tf.int64)
662
+ else:
663
+ batch["labels"] = np.array(batch["labels"], dtype=np.int64)
664
+ else:
665
+ batch["labels"] = None
666
+
667
+ # prepare decoder_input_ids
668
+ if (
669
+ labels is not None
670
+ and self.model is not None
671
+ and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
672
+ ):
673
+ decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=batch["labels"])
674
+ batch["decoder_input_ids"] = decoder_input_ids
675
+
676
+ return batch
677
+
678
+
679
+ @dataclass
680
+ class DataCollatorForLanguageModeling(DataCollatorMixin):
681
+ """
682
+ Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
683
+ are not all of the same length.
684
+
685
+ Args:
686
+ tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
687
+ The tokenizer used for encoding the data.
688
+ mlm (`bool`, *optional*, defaults to `True`):
689
+ Whether or not to use masked language modeling. If set to `False`, the labels are the same as the inputs
690
+ with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for non-masked
691
+ tokens and the value to predict for the masked token.
692
+ mlm_probability (`float`, *optional*, defaults to 0.15):
693
+ The probability with which to (randomly) mask tokens in the input, when `mlm` is set to `True`.
694
+ pad_to_multiple_of (`int`, *optional*):
695
+ If set will pad the sequence to a multiple of the provided value.
696
+
697
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
698
+ 7.0 (Volta).
699
+ return_tensors (`str`):
700
+ The type of Tensor to return. Allowable values are "np", "pt" and "tf".
701
+
702
+ <Tip>
703
+
704
+ For best performance, this data collator should be used with a dataset having items that are dictionaries or
705
+ BatchEncoding, with the `"special_tokens_mask"` key, as returned by a [`PreTrainedTokenizer`] or a
706
+ [`PreTrainedTokenizerFast`] with the argument `return_special_tokens_mask=True`.
707
+
708
+ </Tip>"""
709
+
710
+ tokenizer: PreTrainedTokenizerBase
711
+ mlm: bool = True
712
+ mlm_probability: float = 0.15
713
+ pad_to_multiple_of: Optional[int] = None
714
+ tf_experimental_compile: bool = False
715
+ return_tensors: str = "pt"
716
+
717
+ def __post_init__(self):
718
+ if self.mlm and self.tokenizer.mask_token is None:
719
+ raise ValueError(
720
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. "
721
+ "You should pass `mlm=False` to train on causal language modeling instead."
722
+ )
723
+ if self.tf_experimental_compile:
724
+ import tensorflow as tf
725
+
726
+ self.tf_mask_tokens = tf.function(self.tf_mask_tokens, jit_compile=True)
727
+
728
+ @staticmethod
729
+ def tf_bernoulli(shape, probability):
730
+ import tensorflow as tf
731
+
732
+ prob_matrix = tf.fill(shape, probability)
733
+ return tf.cast(prob_matrix - tf.random.uniform(shape, 0, 1) >= 0, tf.bool)
734
+
735
+ def tf_mask_tokens(
736
+ self, inputs: Any, vocab_size, mask_token_id, special_tokens_mask: Optional[Any] = None
737
+ ) -> Tuple[Any, Any]:
738
+ """
739
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
740
+ """
741
+ import tensorflow as tf
742
+
743
+ mask_token_id = tf.cast(mask_token_id, inputs.dtype)
744
+
745
+ input_shape = tf.shape(inputs)
746
+ # 1 for a special token, 0 for a normal token in the special tokens mask
747
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
748
+ masked_indices = self.tf_bernoulli(input_shape, self.mlm_probability) & ~special_tokens_mask
749
+ # Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
750
+ labels = tf.where(masked_indices, inputs, -100)
751
+
752
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
753
+ indices_replaced = self.tf_bernoulli(input_shape, 0.8) & masked_indices
754
+
755
+ inputs = tf.where(indices_replaced, mask_token_id, inputs)
756
+
757
+ # 10% of the time, we replace masked input tokens with random word
758
+ indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced
759
+ random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype)
760
+
761
+ inputs = tf.where(indices_random, random_words, inputs)
762
+
763
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
764
+ return inputs, labels
765
+
766
+ def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
767
+ import tensorflow as tf
768
+
769
+ # Handle dict or lists with proper padding and conversion to tensor.
770
+ if isinstance(examples[0], Mapping):
771
+ batch = pad_without_fast_tokenizer_warning(
772
+ self.tokenizer, examples, return_tensors="tf", pad_to_multiple_of=self.pad_to_multiple_of
773
+ )
774
+ else:
775
+ batch = {
776
+ "input_ids": _tf_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
777
+ }
778
+
779
+ # If special token mask has been preprocessed, pop it from the dict.
780
+ special_tokens_mask = batch.pop("special_tokens_mask", None)
781
+ if self.mlm:
782
+ if special_tokens_mask is None:
783
+ special_tokens_mask = [
784
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
785
+ for val in batch["input_ids"].numpy().tolist()
786
+ ]
787
+ # Cannot directly create as bool
788
+ special_tokens_mask = tf.cast(tf.convert_to_tensor(special_tokens_mask, dtype=tf.int64), tf.bool)
789
+ else:
790
+ special_tokens_mask = tf.cast(special_tokens_mask, tf.bool)
791
+ batch["input_ids"], batch["labels"] = self.tf_mask_tokens(
792
+ tf.cast(batch["input_ids"], tf.int64),
793
+ special_tokens_mask=special_tokens_mask,
794
+ mask_token_id=self.tokenizer.mask_token_id,
795
+ vocab_size=len(self.tokenizer),
796
+ )
797
+ else:
798
+ labels = batch["input_ids"]
799
+ if self.tokenizer.pad_token_id is not None:
800
+ # Replace self.tokenizer.pad_token_id with -100
801
+ labels = tf.where(labels == self.tokenizer.pad_token_id, -100, labels)
802
+ else:
803
+ labels = tf.identity(labels) # Makes a copy, just in case
804
+ batch["labels"] = labels
805
+ return batch
806
+
807
+ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
808
+ # Handle dict or lists with proper padding and conversion to tensor.
809
+ if isinstance(examples[0], Mapping):
810
+ batch = pad_without_fast_tokenizer_warning(
811
+ self.tokenizer, examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of
812
+ )
813
+ else:
814
+ batch = {
815
+ "input_ids": _torch_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
816
+ }
817
+
818
+ # If special token mask has been preprocessed, pop it from the dict.
819
+ special_tokens_mask = batch.pop("special_tokens_mask", None)
820
+ if self.mlm:
821
+ batch["input_ids"], batch["labels"] = self.torch_mask_tokens(
822
+ batch["input_ids"], special_tokens_mask=special_tokens_mask
823
+ )
824
+ else:
825
+ labels = batch["input_ids"].clone()
826
+ if self.tokenizer.pad_token_id is not None:
827
+ labels[labels == self.tokenizer.pad_token_id] = -100
828
+ batch["labels"] = labels
829
+ return batch
830
+
831
+ def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
832
+ """
833
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
834
+ """
835
+ import torch
836
+
837
+ labels = inputs.clone()
838
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
839
+ probability_matrix = torch.full(labels.shape, self.mlm_probability)
840
+ if special_tokens_mask is None:
841
+ special_tokens_mask = [
842
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
843
+ ]
844
+ special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
845
+ else:
846
+ special_tokens_mask = special_tokens_mask.bool()
847
+
848
+ probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
849
+ masked_indices = torch.bernoulli(probability_matrix).bool()
850
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
851
+
852
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
853
+ indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
854
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
855
+
856
+ # 10% of the time, we replace masked input tokens with random word
857
+ indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
858
+ random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
859
+ inputs[indices_random] = random_words[indices_random]
860
+
861
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
862
+ return inputs, labels
863
+
864
+ def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
865
+ # Handle dict or lists with proper padding and conversion to tensor.
866
+ if isinstance(examples[0], Mapping):
867
+ batch = pad_without_fast_tokenizer_warning(
868
+ self.tokenizer, examples, return_tensors="np", pad_to_multiple_of=self.pad_to_multiple_of
869
+ )
870
+ else:
871
+ batch = {
872
+ "input_ids": _numpy_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
873
+ }
874
+
875
+ # If special token mask has been preprocessed, pop it from the dict.
876
+ special_tokens_mask = batch.pop("special_tokens_mask", None)
877
+ if self.mlm:
878
+ batch["input_ids"], batch["labels"] = self.numpy_mask_tokens(
879
+ batch["input_ids"], special_tokens_mask=special_tokens_mask
880
+ )
881
+ else:
882
+ labels = np.copy(batch["input_ids"])
883
+ if self.tokenizer.pad_token_id is not None:
884
+ labels[labels == self.tokenizer.pad_token_id] = -100
885
+ batch["labels"] = labels
886
+ return batch
887
+
888
+ def numpy_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
889
+ """
890
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
891
+ """
892
+ labels = np.copy(inputs)
893
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
894
+ probability_matrix = np.full(labels.shape, self.mlm_probability)
895
+ if special_tokens_mask is None:
896
+ special_tokens_mask = [
897
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
898
+ ]
899
+ special_tokens_mask = np.array(special_tokens_mask, dtype=bool)
900
+ else:
901
+ special_tokens_mask = special_tokens_mask.astype(bool)
902
+
903
+ probability_matrix[special_tokens_mask] = 0
904
+ # Numpy doesn't have bernoulli, so we use a binomial with 1 trial
905
+ masked_indices = np.random.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
906
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
907
+
908
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
909
+ indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices
910
+ inputs[indices_replaced] = self.tokenizer.mask_token_id
911
+
912
+ # 10% of the time, we replace masked input tokens with random word
913
+ # indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
914
+ indices_random = (
915
+ np.random.binomial(1, 0.5, size=labels.shape).astype(bool) & masked_indices & ~indices_replaced
916
+ )
917
+ random_words = np.random.randint(
918
+ low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
919
+ )
920
+ inputs[indices_random] = random_words
921
+
922
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
923
+ return inputs, labels
924
+
925
+
926
+ @dataclass
927
+ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
928
+ """
929
+ Data collator used for language modeling that masks entire words.
930
+
931
+ - collates batches of tensors, honoring their tokenizer's pad_token
932
+ - preprocesses batches for masked language modeling
933
+
934
+ <Tip>
935
+
936
+ This collator relies on details of the implementation of subword tokenization by [`BertTokenizer`], specifically
937
+ that subword tokens are prefixed with *##*. For tokenizers that do not adhere to this scheme, this collator will
938
+ produce an output that is roughly equivalent to [`.DataCollatorForLanguageModeling`].
939
+
940
+ </Tip>"""
941
+
942
+ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
943
+ if isinstance(examples[0], Mapping):
944
+ input_ids = [e["input_ids"] for e in examples]
945
+ else:
946
+ input_ids = examples
947
+ examples = [{"input_ids": e} for e in examples]
948
+
949
+ batch_input = _torch_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
950
+
951
+ mask_labels = []
952
+ for e in examples:
953
+ ref_tokens = []
954
+ for id in tolist(e["input_ids"]):
955
+ token = self.tokenizer._convert_id_to_token(id)
956
+ ref_tokens.append(token)
957
+
958
+ # For Chinese tokens, we need extra inf to mark sub-word, e.g [喜,欢]-> [喜,##欢]
959
+ if "chinese_ref" in e:
960
+ ref_pos = tolist(e["chinese_ref"])
961
+ len_seq = len(e["input_ids"])
962
+ for i in range(len_seq):
963
+ if i in ref_pos:
964
+ ref_tokens[i] = "##" + ref_tokens[i]
965
+ mask_labels.append(self._whole_word_mask(ref_tokens))
966
+ batch_mask = _torch_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
967
+ inputs, labels = self.torch_mask_tokens(batch_input, batch_mask)
968
+ return {"input_ids": inputs, "labels": labels}
969
+
970
+ def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
971
+ import tensorflow as tf
972
+
973
+ if isinstance(examples[0], Mapping):
974
+ input_ids = [e["input_ids"] for e in examples]
975
+ else:
976
+ input_ids = examples
977
+ examples = [{"input_ids": e} for e in examples]
978
+
979
+ batch_input = _tf_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
980
+
981
+ mask_labels = []
982
+ for e in examples:
983
+ ref_tokens = []
984
+ for id in tolist(e["input_ids"]):
985
+ token = self.tokenizer._convert_id_to_token(id)
986
+ ref_tokens.append(token)
987
+
988
+ # For Chinese tokens, we need extra inf to mark sub-word, e.g [喜,欢]-> [喜,##欢]
989
+ if "chinese_ref" in e:
990
+ ref_pos = tolist(e["chinese_ref"])
991
+ len_seq = len(e["input_ids"])
992
+ for i in range(len_seq):
993
+ if i in ref_pos:
994
+ ref_tokens[i] = "##" + ref_tokens[i]
995
+ mask_labels.append(self._whole_word_mask(ref_tokens))
996
+ batch_mask = _tf_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
997
+ inputs, labels = self.tf_mask_tokens(tf.cast(batch_input, tf.int64), batch_mask)
998
+ return {"input_ids": inputs, "labels": labels}
999
+
1000
+ def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
1001
+ if isinstance(examples[0], Mapping):
1002
+ input_ids = [e["input_ids"] for e in examples]
1003
+ else:
1004
+ input_ids = examples
1005
+ examples = [{"input_ids": e} for e in examples]
1006
+
1007
+ batch_input = _numpy_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
1008
+
1009
+ mask_labels = []
1010
+ for e in examples:
1011
+ ref_tokens = []
1012
+ for id in tolist(e["input_ids"]):
1013
+ token = self.tokenizer._convert_id_to_token(id)
1014
+ ref_tokens.append(token)
1015
+
1016
+ # For Chinese tokens, we need extra inf to mark sub-word, e.g [喜,欢]-> [喜,##欢]
1017
+ if "chinese_ref" in e:
1018
+ ref_pos = tolist(e["chinese_ref"])
1019
+ len_seq = len(e["input_ids"])
1020
+ for i in range(len_seq):
1021
+ if i in ref_pos:
1022
+ ref_tokens[i] = "##" + ref_tokens[i]
1023
+ mask_labels.append(self._whole_word_mask(ref_tokens))
1024
+ batch_mask = _numpy_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
1025
+ inputs, labels = self.numpy_mask_tokens(batch_input, batch_mask)
1026
+ return {"input_ids": inputs, "labels": labels}
1027
+
1028
+ def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
1029
+ """
1030
+ Get 0/1 labels for masked tokens with whole word mask proxy
1031
+ """
1032
+ if not isinstance(self.tokenizer, (BertTokenizer, BertTokenizerFast)):
1033
+ warnings.warn(
1034
+ "DataCollatorForWholeWordMask is only suitable for BertTokenizer-like tokenizers. "
1035
+ "Please refer to the documentation for more information."
1036
+ )
1037
+
1038
+ cand_indexes = []
1039
+ for i, token in enumerate(input_tokens):
1040
+ if token == "[CLS]" or token == "[SEP]":
1041
+ continue
1042
+
1043
+ if len(cand_indexes) >= 1 and token.startswith("##"):
1044
+ cand_indexes[-1].append(i)
1045
+ else:
1046
+ cand_indexes.append([i])
1047
+
1048
+ random.shuffle(cand_indexes)
1049
+ num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * self.mlm_probability))))
1050
+ masked_lms = []
1051
+ covered_indexes = set()
1052
+ for index_set in cand_indexes:
1053
+ if len(masked_lms) >= num_to_predict:
1054
+ break
1055
+ # If adding a whole-word mask would exceed the maximum number of
1056
+ # predictions, then just skip this candidate.
1057
+ if len(masked_lms) + len(index_set) > num_to_predict:
1058
+ continue
1059
+ is_any_index_covered = False
1060
+ for index in index_set:
1061
+ if index in covered_indexes:
1062
+ is_any_index_covered = True
1063
+ break
1064
+ if is_any_index_covered:
1065
+ continue
1066
+ for index in index_set:
1067
+ covered_indexes.add(index)
1068
+ masked_lms.append(index)
1069
+
1070
+ if len(covered_indexes) != len(masked_lms):
1071
+ raise ValueError("Length of covered_indexes is not equal to length of masked_lms.")
1072
+ mask_labels = [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]
1073
+ return mask_labels
1074
+
1075
+ def torch_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
1076
+ """
1077
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Set
1078
+ 'mask_labels' means we use whole word mask (wwm), we directly mask idxs according to it's ref.
1079
+ """
1080
+ import torch
1081
+
1082
+ if self.tokenizer.mask_token is None:
1083
+ raise ValueError(
1084
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
1085
+ " --mlm flag if you want to use this tokenizer."
1086
+ )
1087
+ labels = inputs.clone()
1088
+ # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
1089
+
1090
+ probability_matrix = mask_labels
1091
+
1092
+ special_tokens_mask = [
1093
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
1094
+ ]
1095
+ probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
1096
+ if self.tokenizer.pad_token is not None:
1097
+ padding_mask = labels.eq(self.tokenizer.pad_token_id)
1098
+ probability_matrix.masked_fill_(padding_mask, value=0.0)
1099
+
1100
+ masked_indices = probability_matrix.bool()
1101
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
1102
+
1103
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
1104
+ indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
1105
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
1106
+
1107
+ # 10% of the time, we replace masked input tokens with random word
1108
+ indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
1109
+ random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
1110
+ inputs[indices_random] = random_words[indices_random]
1111
+
1112
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
1113
+ return inputs, labels
1114
+
1115
+ def tf_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
1116
+ """
1117
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Set
1118
+ 'mask_labels' means we use whole word mask (wwm), we directly mask idxs according to it's ref.
1119
+ """
1120
+ import tensorflow as tf
1121
+
1122
+ input_shape = tf.shape(inputs)
1123
+ if self.tokenizer.mask_token is None:
1124
+ raise ValueError(
1125
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
1126
+ " --mlm flag if you want to use this tokenizer."
1127
+ )
1128
+ labels = tf.identity(inputs)
1129
+ # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
1130
+
1131
+ masked_indices = tf.cast(mask_labels, tf.bool)
1132
+
1133
+ special_tokens_mask = [
1134
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels
1135
+ ]
1136
+ masked_indices = masked_indices & ~tf.cast(special_tokens_mask, dtype=tf.bool)
1137
+ if self.tokenizer.pad_token is not None:
1138
+ padding_mask = inputs == self.tokenizer.pad_token_id
1139
+ masked_indices = masked_indices & ~padding_mask
1140
+
1141
+ # Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
1142
+ labels = tf.where(masked_indices, inputs, -100)
1143
+
1144
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
1145
+ indices_replaced = self.tf_bernoulli(input_shape, 0.8) & masked_indices
1146
+
1147
+ inputs = tf.where(indices_replaced, self.tokenizer.mask_token_id, inputs)
1148
+
1149
+ # 10% of the time, we replace masked input tokens with random word
1150
+ indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced
1151
+ random_words = tf.random.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64)
1152
+ inputs = tf.where(indices_random, random_words, inputs)
1153
+
1154
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
1155
+ return inputs, labels
1156
+
1157
+ def numpy_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
1158
+ """
1159
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Set
1160
+ 'mask_labels' means we use whole word mask (wwm), we directly mask idxs according to it's ref.
1161
+ """
1162
+ if self.tokenizer.mask_token is None:
1163
+ raise ValueError(
1164
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
1165
+ " --mlm flag if you want to use this tokenizer."
1166
+ )
1167
+ labels = np.copy(inputs)
1168
+ # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
1169
+
1170
+ masked_indices = mask_labels.astype(bool)
1171
+
1172
+ special_tokens_mask = [
1173
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
1174
+ ]
1175
+ masked_indices[np.array(special_tokens_mask, dtype=bool)] = 0
1176
+ if self.tokenizer.pad_token is not None:
1177
+ padding_mask = labels == self.tokenizer.pad_token_id
1178
+ masked_indices[padding_mask] = 0
1179
+
1180
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
1181
+
1182
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
1183
+ indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices
1184
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
1185
+
1186
+ # 10% of the time, we replace masked input tokens with random word
1187
+ # indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
1188
+ indices_random = (
1189
+ np.random.binomial(1, 0.5, size=labels.shape).astype(bool) & masked_indices & ~indices_replaced
1190
+ )
1191
+ random_words = np.random.randint(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)
1192
+ inputs[indices_random] = random_words[indices_random]
1193
+
1194
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
1195
+ return inputs, labels
1196
+
1197
+
1198
+ @dataclass
1199
+ class DataCollatorForSOP(DataCollatorForLanguageModeling):
1200
+ """
1201
+ Data collator used for sentence order prediction task.
1202
+
1203
+ - collates batches of tensors, honoring their tokenizer's pad_token
1204
+ - preprocesses batches for both masked language modeling and sentence order prediction
1205
+ """
1206
+
1207
+ def __init__(self, *args, **kwargs):
1208
+ warnings.warn(
1209
+ "DataCollatorForSOP is deprecated and will be removed in a future version, you can now use "
1210
+ "DataCollatorForLanguageModeling instead.",
1211
+ FutureWarning,
1212
+ )
1213
+
1214
+ def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]:
1215
+ import torch
1216
+ from torch.nn.utils.rnn import pad_sequence
1217
+
1218
+ input_ids = [example["input_ids"] for example in examples]
1219
+ input_ids = _torch_collate_batch(input_ids, self.tokenizer)
1220
+ input_ids, labels, attention_mask = self.mask_tokens(input_ids)
1221
+
1222
+ token_type_ids = [example["token_type_ids"] for example in examples]
1223
+ # size of segment_ids varied because randomness, padding zero to the end as the original implementation
1224
+ token_type_ids = pad_sequence(token_type_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
1225
+
1226
+ sop_label_list = [example["sentence_order_label"] for example in examples]
1227
+ sentence_order_label = torch.stack(sop_label_list)
1228
+
1229
+ return {
1230
+ "input_ids": input_ids,
1231
+ "labels": labels,
1232
+ "attention_mask": attention_mask,
1233
+ "token_type_ids": token_type_ids,
1234
+ "sentence_order_label": sentence_order_label,
1235
+ }
1236
+
1237
+ def mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any]:
1238
+ """
1239
+ Prepare masked tokens inputs/labels/attention_mask for masked language modeling: 80% MASK, 10% random, 10%
1240
+ original. N-gram not applied yet.
1241
+ """
1242
+ import torch
1243
+
1244
+ if self.tokenizer.mask_token is None:
1245
+ raise ValueError(
1246
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
1247
+ " --mlm flag if you want to use this tokenizer."
1248
+ )
1249
+
1250
+ labels = inputs.clone()
1251
+ # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
1252
+ probability_matrix = torch.full(labels.shape, self.mlm_probability)
1253
+ special_tokens_mask = [
1254
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
1255
+ ]
1256
+ probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
1257
+ if self.tokenizer.pad_token is not None:
1258
+ padding_mask = labels.eq(self.tokenizer.pad_token_id)
1259
+ probability_matrix.masked_fill_(padding_mask, value=0.0)
1260
+ masked_indices = torch.bernoulli(probability_matrix).bool()
1261
+ # probability be `1` (masked), however in albert model attention mask `0` means masked, revert the value
1262
+ attention_mask = (~masked_indices).float()
1263
+ if self.tokenizer.pad_token is not None:
1264
+ attention_padding_mask = labels.eq(self.tokenizer.pad_token_id)
1265
+ attention_mask.masked_fill_(attention_padding_mask, value=1.0)
1266
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens, -100 is default for CE compute
1267
+
1268
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
1269
+ indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
1270
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
1271
+
1272
+ # 10% of the time, we replace masked input tokens with random word
1273
+ indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
1274
+ random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
1275
+ inputs[indices_random] = random_words[indices_random]
1276
+
1277
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
1278
+ return inputs, labels, attention_mask
1279
+
1280
+
1281
+ @dataclass
1282
+ class DataCollatorForPermutationLanguageModeling(DataCollatorMixin):
1283
+ """
1284
+ Data collator used for permutation language modeling.
1285
+
1286
+ - collates batches of tensors, honoring their tokenizer's pad_token
1287
+ - preprocesses batches for permutation language modeling with procedures specific to XLNet
1288
+ """
1289
+
1290
+ tokenizer: PreTrainedTokenizerBase
1291
+ plm_probability: float = 1 / 6
1292
+ max_span_length: int = 5 # maximum length of a span of masked tokens
1293
+ return_tensors: str = "pt"
1294
+
1295
+ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
1296
+ if isinstance(examples[0], Mapping):
1297
+ examples = [e["input_ids"] for e in examples]
1298
+ batch = _torch_collate_batch(examples, self.tokenizer)
1299
+ inputs, perm_mask, target_mapping, labels = self.torch_mask_tokens(batch)
1300
+ return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
1301
+
1302
+ def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
1303
+ if isinstance(examples[0], Mapping):
1304
+ examples = [e["input_ids"] for e in examples]
1305
+ batch = _tf_collate_batch(examples, self.tokenizer)
1306
+ inputs, perm_mask, target_mapping, labels = self.tf_mask_tokens(batch)
1307
+ return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
1308
+
1309
+ def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
1310
+ if isinstance(examples[0], Mapping):
1311
+ examples = [e["input_ids"] for e in examples]
1312
+ batch = _numpy_collate_batch(examples, self.tokenizer)
1313
+ inputs, perm_mask, target_mapping, labels = self.numpy_mask_tokens(batch)
1314
+ return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
1315
+
1316
+ def torch_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
1317
+ """
1318
+ The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
1319
+
1320
+ 0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
1321
+ 1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
1322
+ 2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be
1323
+ masked
1324
+ 3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length -
1325
+ span_length]` and mask tokens `start_index:start_index + span_length`
1326
+ 4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the
1327
+ sequence to be processed), repeat from Step 1.
1328
+ """
1329
+ import torch
1330
+
1331
+ if self.tokenizer.mask_token is None:
1332
+ raise ValueError(
1333
+ "This tokenizer does not have a mask token which is necessary for permutation language modeling."
1334
+ " Please add a mask token if you want to use this tokenizer."
1335
+ )
1336
+
1337
+ if inputs.size(1) % 2 != 0:
1338
+ raise ValueError(
1339
+ "This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see"
1340
+ " relevant comments in source code for details."
1341
+ )
1342
+
1343
+ labels = inputs.clone()
1344
+ # Creating the mask and target_mapping tensors
1345
+ masked_indices = torch.full(labels.shape, 0, dtype=torch.bool)
1346
+ target_mapping = torch.zeros((labels.size(0), labels.size(1), labels.size(1)), dtype=torch.float32)
1347
+
1348
+ for i in range(labels.size(0)):
1349
+ # Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
1350
+ cur_len = 0
1351
+ max_len = labels.size(1)
1352
+
1353
+ while cur_len < max_len:
1354
+ # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
1355
+ span_length = torch.randint(1, self.max_span_length + 1, (1,)).item()
1356
+ # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
1357
+ context_length = int(span_length / self.plm_probability)
1358
+ # Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
1359
+ start_index = cur_len + torch.randint(context_length - span_length + 1, (1,)).item()
1360
+ masked_indices[i, start_index : start_index + span_length] = 1
1361
+ # Set `cur_len = cur_len + context_length`
1362
+ cur_len += context_length
1363
+
1364
+ # Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
1365
+ # the i-th predict corresponds to the i-th token.
1366
+ target_mapping[i] = torch.eye(labels.size(1))
1367
+
1368
+ special_tokens_mask = torch.tensor(
1369
+ [self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()],
1370
+ dtype=torch.bool,
1371
+ )
1372
+ masked_indices.masked_fill_(special_tokens_mask, value=0.0)
1373
+ if self.tokenizer.pad_token is not None:
1374
+ padding_mask = labels.eq(self.tokenizer.pad_token_id)
1375
+ masked_indices.masked_fill_(padding_mask, value=0.0)
1376
+
1377
+ # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
1378
+ non_func_mask = ~(padding_mask | special_tokens_mask)
1379
+
1380
+ inputs[masked_indices] = self.tokenizer.mask_token_id
1381
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
1382
+
1383
+ perm_mask = torch.zeros((labels.size(0), labels.size(1), labels.size(1)), dtype=torch.float32)
1384
+
1385
+ for i in range(labels.size(0)):
1386
+ # Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
1387
+ # determine which tokens a given token can attend to (encoded in `perm_mask`).
1388
+ # Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
1389
+ # (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
1390
+ # we assume that reused length is half of sequence length and permutation length is equal to reused length.
1391
+ # This requires that the sequence length be even.
1392
+
1393
+ # Create a linear factorisation order
1394
+ perm_index = torch.arange(labels.size(1))
1395
+ # Split this into two halves, assuming that half the sequence is reused each time
1396
+ perm_index = perm_index.reshape((-1, labels.size(1) // 2)).transpose(0, 1)
1397
+ # Permute the two halves such that they do not cross over
1398
+ perm_index = perm_index[torch.randperm(labels.size(1) // 2)]
1399
+ # Flatten this out into the desired permuted factorisation order
1400
+ perm_index = torch.flatten(perm_index.transpose(0, 1))
1401
+ # Set the permutation indices of non-masked (non-functional) tokens to the
1402
+ # smallest index (-1) so that:
1403
+ # (1) They can be seen by all other positions
1404
+ # (2) They cannot see masked positions, so there won't be information leak
1405
+ perm_index.masked_fill_(~masked_indices[i] & non_func_mask[i], -1)
1406
+ # The logic for whether the i-th token can attend on the j-th token based on the factorisation order:
1407
+ # 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
1408
+ # 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
1409
+ perm_mask[i] = (
1410
+ perm_index.reshape((labels.size(1), 1)) <= perm_index.reshape((1, labels.size(1)))
1411
+ ) & masked_indices[i]
1412
+
1413
+ return inputs.long(), perm_mask, target_mapping, labels.long()
1414
+
1415
+ def tf_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
1416
+ """
1417
+ The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
1418
+
1419
+ 0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
1420
+ 1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
1421
+ 2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be
1422
+ masked
1423
+ 3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length -
1424
+ span_length]` and mask tokens `start_index:start_index + span_length`
1425
+ 4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the
1426
+ sequence to be processed), repeat from Step 1.
1427
+ """
1428
+ import tensorflow as tf
1429
+
1430
+ if self.tokenizer.mask_token is None:
1431
+ raise ValueError(
1432
+ "This tokenizer does not have a mask token which is necessary for permutation language modeling."
1433
+ " Please add a mask token if you want to use this tokenizer."
1434
+ )
1435
+
1436
+ if tf.shape(inputs)[1] % 2 != 0:
1437
+ raise ValueError(
1438
+ "This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see"
1439
+ " relevant comments in source code for details."
1440
+ )
1441
+
1442
+ labels = tf.identity(inputs)
1443
+ # Creating the mask and target_mapping tensors
1444
+ masked_indices = np.full(labels.shape.as_list(), 0, dtype=bool)
1445
+ labels_shape = tf.shape(labels)
1446
+ target_mapping = np.zeros((labels_shape[0], labels_shape[1], labels_shape[1]), dtype=np.float32)
1447
+
1448
+ for i in range(len(labels)):
1449
+ # Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
1450
+ cur_len = 0
1451
+ max_len = tf.shape(labels)[1]
1452
+
1453
+ while cur_len < max_len:
1454
+ # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
1455
+ span_length = randint(1, self.max_span_length + 1)
1456
+ # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
1457
+ context_length = int(span_length / self.plm_probability)
1458
+ # Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
1459
+ start_index = cur_len + randint(0, context_length - span_length + 1)
1460
+ masked_indices[i, start_index : start_index + span_length] = 1
1461
+ # Set `cur_len = cur_len + context_length`
1462
+ cur_len += context_length
1463
+
1464
+ # Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
1465
+ # the i-th predict corresponds to the i-th token.
1466
+ target_mapping[i] = np.eye(labels_shape[1])
1467
+ masked_indices = tf.cast(tf.convert_to_tensor(masked_indices), dtype=tf.bool)
1468
+ target_mapping = tf.convert_to_tensor(target_mapping)
1469
+ special_tokens_mask = tf.convert_to_tensor(
1470
+ [
1471
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
1472
+ for val in labels.numpy().tolist()
1473
+ ],
1474
+ )
1475
+ special_tokens_mask = tf.cast(special_tokens_mask, dtype=tf.bool)
1476
+ masked_indices = masked_indices & ~special_tokens_mask
1477
+ if self.tokenizer.pad_token is not None:
1478
+ padding_mask = labels == self.tokenizer.pad_token_id
1479
+ masked_indices = masked_indices & ~padding_mask
1480
+
1481
+ # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
1482
+ non_func_mask = ~(padding_mask | special_tokens_mask)
1483
+
1484
+ inputs = tf.where(masked_indices, self.tokenizer.mask_token_id, inputs)
1485
+ labels = tf.where(masked_indices, labels, -100) # We only compute loss on masked tokens
1486
+
1487
+ perm_mask = []
1488
+
1489
+ for i in range(len(labels)):
1490
+ # Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
1491
+ # determine which tokens a given token can attend to (encoded in `perm_mask`).
1492
+ # Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
1493
+ # (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
1494
+ # we assume that reused length is half of sequence length and permutation length is equal to reused length.
1495
+ # This requires that the sequence length be even.
1496
+
1497
+ # Create a linear factorisation order
1498
+ # tf.range is the equivalent of torch.arange
1499
+ perm_index = tf.range(labels_shape[1])
1500
+ # Split this into two halves, assuming that half the sequence is reused each time
1501
+ perm_index = tf.transpose(tf.reshape(perm_index, (-1, labels_shape[1] // 2)))
1502
+ # Permute the two halves such that they do not cross over
1503
+ perm_index = tf.random.shuffle(perm_index) # Shuffles along the first dimension
1504
+ # Flatten this out into the desired permuted factorisation order
1505
+ perm_index = tf.reshape(tf.transpose(perm_index), (-1,))
1506
+ # Set the permutation indices of non-masked (non-functional) tokens to the
1507
+ # smallest index (-1) so that:
1508
+ # (1) They can be seen by all other positions
1509
+ # (2) They cannot see masked positions, so there won't be information leak
1510
+ perm_index = tf.where(~masked_indices[i] & non_func_mask[i], -1, perm_index)
1511
+ # The logic for whether the i-th token can attend on the j-th token based on the factorisation order:
1512
+ # 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
1513
+ # 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
1514
+ perm_mask.append(
1515
+ (tf.reshape(perm_index, (labels_shape[1], 1)) <= tf.reshape(perm_index, (1, labels_shape[1])))
1516
+ & masked_indices[i]
1517
+ )
1518
+ perm_mask = tf.stack(perm_mask, axis=0)
1519
+
1520
+ return tf.cast(inputs, tf.int64), tf.cast(perm_mask, tf.float32), target_mapping, tf.cast(labels, tf.int64)
1521
+
1522
+ def numpy_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
1523
+ """
1524
+ The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
1525
+
1526
+ 0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
1527
+ 1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
1528
+ 2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be
1529
+ masked
1530
+ 3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length -
1531
+ span_length]` and mask tokens `start_index:start_index + span_length`
1532
+ 4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the
1533
+ sequence to be processed), repeat from Step 1.
1534
+ """
1535
+ if self.tokenizer.mask_token is None:
1536
+ raise ValueError(
1537
+ "This tokenizer does not have a mask token which is necessary for permutation language modeling."
1538
+ " Please add a mask token if you want to use this tokenizer."
1539
+ )
1540
+
1541
+ if inputs.shape[1] % 2 != 0:
1542
+ raise ValueError(
1543
+ "This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see"
1544
+ " relevant comments in source code for details."
1545
+ )
1546
+
1547
+ labels = np.copy(inputs)
1548
+ # Creating the mask and target_mapping tensors
1549
+ masked_indices = np.full(labels.shape, 0, dtype=bool)
1550
+ target_mapping = np.zeros((labels.shape[0], labels.shape[1], labels.shape[1]), dtype=np.float32)
1551
+
1552
+ for i in range(labels.shape[0]):
1553
+ # Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
1554
+ cur_len = 0
1555
+ max_len = labels.shape[1]
1556
+
1557
+ while cur_len < max_len:
1558
+ # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
1559
+ span_length = randint(1, self.max_span_length + 1)
1560
+ # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
1561
+ context_length = int(span_length / self.plm_probability)
1562
+ # Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
1563
+ start_index = cur_len + randint(0, context_length - span_length + 1)
1564
+ masked_indices[i, start_index : start_index + span_length] = 1
1565
+ # Set `cur_len = cur_len + context_length`
1566
+ cur_len += context_length
1567
+
1568
+ # Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
1569
+ # the i-th predict corresponds to the i-th token.
1570
+ target_mapping[i] = np.eye(labels.shape[1])
1571
+
1572
+ special_tokens_mask = np.array(
1573
+ [self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()],
1574
+ dtype=bool,
1575
+ )
1576
+ masked_indices[special_tokens_mask] = 0
1577
+ if self.tokenizer.pad_token is not None:
1578
+ padding_mask = labels == self.tokenizer.pad_token_id
1579
+ masked_indices[padding_mask] = 0.0
1580
+
1581
+ # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
1582
+ non_func_mask = ~(padding_mask | special_tokens_mask)
1583
+
1584
+ inputs[masked_indices] = self.tokenizer.mask_token_id
1585
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
1586
+
1587
+ perm_mask = np.zeros((labels.shape[0], labels.shape[1], labels.shape[1]), dtype=np.float32)
1588
+
1589
+ for i in range(labels.shape[0]):
1590
+ # Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
1591
+ # determine which tokens a given token can attend to (encoded in `perm_mask`).
1592
+ # Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
1593
+ # (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
1594
+ # we assume that reused length is half of sequence length and permutation length is equal to reused length.
1595
+ # This requires that the sequence length be even.
1596
+
1597
+ # Create a linear factorisation order
1598
+ perm_index = np.arange(labels.shape[1])
1599
+ # Split this into two halves, assuming that half the sequence is reused each time
1600
+ perm_index = perm_index.reshape((-1, labels.shape[1] // 2)).T
1601
+ # Permute the two halves such that they do not cross over
1602
+ np.random.shuffle(perm_index)
1603
+ # Flatten this out into the desired permuted factorisation order
1604
+ perm_index = perm_index.T.flatten()
1605
+ # Set the permutation indices of non-masked (non-functional) tokens to the
1606
+ # smallest index (-1) so that:
1607
+ # (1) They can be seen by all other positions
1608
+ # (2) They cannot see masked positions, so there won't be information leak
1609
+ perm_index[~masked_indices[i] & non_func_mask[i]] = -1
1610
+ # The logic for whether the i-th token can attend on the j-th token based on the factorisation order:
1611
+ # 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
1612
+ # 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
1613
+ perm_mask[i] = (
1614
+ perm_index.reshape((labels.shape[1], 1)) <= perm_index.reshape((1, labels.shape[1]))
1615
+ ) & masked_indices[i]
1616
+
1617
+ return inputs.astype(np.int64), perm_mask, target_mapping, labels.astype(np.int64)
1618
+
1619
+
1620
+ @dataclass
1621
+ class DataCollatorWithFlattening(DefaultDataCollator):
1622
+ """
1623
+ Data collator used for padding free approach. Does the following:
1624
+
1625
+ - concatate the entire mini batch into single long sequence [1, total_tokens]
1626
+ - uses `separator_id` to separate sequences within the concatenated `labels`, default value is -100
1627
+ - no padding will be added, returns `input_ids`, `labels` and `position_ids`
1628
+ """
1629
+
1630
+ def __init__(self, *args, return_position_ids=True, separator_id=-100, **kwargs):
1631
+ super().__init__(*args, **kwargs)
1632
+ self.return_position_ids = return_position_ids
1633
+ self.separator_id = separator_id
1634
+ warnings.warn(
1635
+ "Using `DataCollatorWithFlattening` will flatten the entire mini batch into single long sequence."
1636
+ "Make sure your attention computation is able to handle it!"
1637
+ )
1638
+
1639
+ def __call__(self, features, return_tensors=None, separator_id=None):
1640
+ if return_tensors is None:
1641
+ return_tensors = self.return_tensors
1642
+ if separator_id is None:
1643
+ separator_id = self.separator_id
1644
+ is_labels_provided = "labels" in features[0]
1645
+ ret = {"input_ids": [], "labels": []}
1646
+ if self.return_position_ids:
1647
+ ret.update({"position_ids": []})
1648
+ for idx in range(0, len(features)):
1649
+ ret["input_ids"] += features[idx]["input_ids"]
1650
+ if is_labels_provided:
1651
+ ret["labels"] += [separator_id] + features[idx]["labels"][1:]
1652
+ else:
1653
+ ret["labels"] += [separator_id] + features[idx]["input_ids"][1:]
1654
+ if self.return_position_ids:
1655
+ ret["position_ids"] += list(range(len(features[idx]["input_ids"])))
1656
+ return default_data_collator([ret], return_tensors)
.venv/lib/python3.11/site-packages/transformers/data/datasets/__init__.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .glue import GlueDataset, GlueDataTrainingArguments
16
+ from .language_modeling import (
17
+ LineByLineTextDataset,
18
+ LineByLineWithRefDataset,
19
+ LineByLineWithSOPTextDataset,
20
+ TextDataset,
21
+ TextDatasetForNextSentencePrediction,
22
+ )
23
+ from .squad import SquadDataset, SquadDataTrainingArguments
.venv/lib/python3.11/site-packages/transformers/data/datasets/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (674 Bytes). View file
 
.venv/lib/python3.11/site-packages/transformers/data/datasets/__pycache__/glue.cpython-311.pyc ADDED
Binary file (7.97 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/data/datasets/__pycache__/language_modeling.cpython-311.pyc ADDED
Binary file (27.9 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/data/datasets/__pycache__/squad.cpython-311.pyc ADDED
Binary file (11.2 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/data/datasets/glue.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import time
17
+ import warnings
18
+ from dataclasses import dataclass, field
19
+ from enum import Enum
20
+ from typing import List, Optional, Union
21
+
22
+ import torch
23
+ from filelock import FileLock
24
+ from torch.utils.data import Dataset
25
+
26
+ from ...tokenization_utils_base import PreTrainedTokenizerBase
27
+ from ...utils import logging
28
+ from ..processors.glue import glue_convert_examples_to_features, glue_output_modes, glue_processors
29
+ from ..processors.utils import InputFeatures
30
+
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
+ @dataclass
36
+ class GlueDataTrainingArguments:
37
+ """
38
+ Arguments pertaining to what data we are going to input our model for training and eval.
39
+
40
+ Using `HfArgumentParser` we can turn this class into argparse arguments to be able to specify them on the command
41
+ line.
42
+ """
43
+
44
+ task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(glue_processors.keys())})
45
+ data_dir: str = field(
46
+ metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."}
47
+ )
48
+ max_seq_length: int = field(
49
+ default=128,
50
+ metadata={
51
+ "help": (
52
+ "The maximum total input sequence length after tokenization. Sequences longer "
53
+ "than this will be truncated, sequences shorter will be padded."
54
+ )
55
+ },
56
+ )
57
+ overwrite_cache: bool = field(
58
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
59
+ )
60
+
61
+ def __post_init__(self):
62
+ self.task_name = self.task_name.lower()
63
+
64
+
65
+ class Split(Enum):
66
+ train = "train"
67
+ dev = "dev"
68
+ test = "test"
69
+
70
+
71
+ class GlueDataset(Dataset):
72
+ """
73
+ This will be superseded by a framework-agnostic approach soon.
74
+ """
75
+
76
+ args: GlueDataTrainingArguments
77
+ output_mode: str
78
+ features: List[InputFeatures]
79
+
80
+ def __init__(
81
+ self,
82
+ args: GlueDataTrainingArguments,
83
+ tokenizer: PreTrainedTokenizerBase,
84
+ limit_length: Optional[int] = None,
85
+ mode: Union[str, Split] = Split.train,
86
+ cache_dir: Optional[str] = None,
87
+ ):
88
+ warnings.warn(
89
+ "This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
90
+ "library. You can have a look at this example script for pointers: "
91
+ "https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py",
92
+ FutureWarning,
93
+ )
94
+ self.args = args
95
+ self.processor = glue_processors[args.task_name]()
96
+ self.output_mode = glue_output_modes[args.task_name]
97
+ if isinstance(mode, str):
98
+ try:
99
+ mode = Split[mode]
100
+ except KeyError:
101
+ raise KeyError("mode is not a valid split name")
102
+ # Load data features from cache or dataset file
103
+ cached_features_file = os.path.join(
104
+ cache_dir if cache_dir is not None else args.data_dir,
105
+ f"cached_{mode.value}_{tokenizer.__class__.__name__}_{args.max_seq_length}_{args.task_name}",
106
+ )
107
+ label_list = self.processor.get_labels()
108
+ if args.task_name in ["mnli", "mnli-mm"] and tokenizer.__class__.__name__ in (
109
+ "RobertaTokenizer",
110
+ "RobertaTokenizerFast",
111
+ "XLMRobertaTokenizer",
112
+ "BartTokenizer",
113
+ "BartTokenizerFast",
114
+ ):
115
+ # HACK(label indices are swapped in RoBERTa pretrained model)
116
+ label_list[1], label_list[2] = label_list[2], label_list[1]
117
+ self.label_list = label_list
118
+
119
+ # Make sure only the first process in distributed training processes the dataset,
120
+ # and the others will use the cache.
121
+ lock_path = cached_features_file + ".lock"
122
+ with FileLock(lock_path):
123
+ if os.path.exists(cached_features_file) and not args.overwrite_cache:
124
+ start = time.time()
125
+ self.features = torch.load(cached_features_file)
126
+ logger.info(
127
+ f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
128
+ )
129
+ else:
130
+ logger.info(f"Creating features from dataset file at {args.data_dir}")
131
+
132
+ if mode == Split.dev:
133
+ examples = self.processor.get_dev_examples(args.data_dir)
134
+ elif mode == Split.test:
135
+ examples = self.processor.get_test_examples(args.data_dir)
136
+ else:
137
+ examples = self.processor.get_train_examples(args.data_dir)
138
+ if limit_length is not None:
139
+ examples = examples[:limit_length]
140
+ self.features = glue_convert_examples_to_features(
141
+ examples,
142
+ tokenizer,
143
+ max_length=args.max_seq_length,
144
+ label_list=label_list,
145
+ output_mode=self.output_mode,
146
+ )
147
+ start = time.time()
148
+ torch.save(self.features, cached_features_file)
149
+ # ^ This seems to take a lot of time so I want to investigate why and how we can improve.
150
+ logger.info(
151
+ f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
152
+ )
153
+
154
+ def __len__(self):
155
+ return len(self.features)
156
+
157
+ def __getitem__(self, i) -> InputFeatures:
158
+ return self.features[i]
159
+
160
+ def get_labels(self):
161
+ return self.label_list
.venv/lib/python3.11/site-packages/transformers/data/datasets/language_modeling.py ADDED
@@ -0,0 +1,530 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import os
17
+ import pickle
18
+ import random
19
+ import time
20
+ import warnings
21
+ from typing import Dict, List, Optional
22
+
23
+ import torch
24
+ from filelock import FileLock
25
+ from torch.utils.data import Dataset
26
+
27
+ from ...tokenization_utils import PreTrainedTokenizer
28
+ from ...utils import logging
29
+
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
+ DEPRECATION_WARNING = (
35
+ "This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
36
+ "library. You can have a look at this example script for pointers: {0}"
37
+ )
38
+
39
+
40
+ class TextDataset(Dataset):
41
+ """
42
+ This will be superseded by a framework-agnostic approach soon.
43
+ """
44
+
45
+ def __init__(
46
+ self,
47
+ tokenizer: PreTrainedTokenizer,
48
+ file_path: str,
49
+ block_size: int,
50
+ overwrite_cache=False,
51
+ cache_dir: Optional[str] = None,
52
+ ):
53
+ warnings.warn(
54
+ DEPRECATION_WARNING.format(
55
+ "https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
56
+ ),
57
+ FutureWarning,
58
+ )
59
+ if os.path.isfile(file_path) is False:
60
+ raise ValueError(f"Input file path {file_path} not found")
61
+
62
+ block_size = block_size - tokenizer.num_special_tokens_to_add(pair=False)
63
+
64
+ directory, filename = os.path.split(file_path)
65
+ cached_features_file = os.path.join(
66
+ cache_dir if cache_dir is not None else directory,
67
+ f"cached_lm_{tokenizer.__class__.__name__}_{block_size}_{filename}",
68
+ )
69
+
70
+ # Make sure only the first process in distributed training processes the dataset,
71
+ # and the others will use the cache.
72
+ lock_path = cached_features_file + ".lock"
73
+ with FileLock(lock_path):
74
+ if os.path.exists(cached_features_file) and not overwrite_cache:
75
+ start = time.time()
76
+ with open(cached_features_file, "rb") as handle:
77
+ self.examples = pickle.load(handle)
78
+ logger.info(
79
+ f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
80
+ )
81
+
82
+ else:
83
+ logger.info(f"Creating features from dataset file at {directory}")
84
+
85
+ self.examples = []
86
+ with open(file_path, encoding="utf-8") as f:
87
+ text = f.read()
88
+
89
+ tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
90
+
91
+ for i in range(0, len(tokenized_text) - block_size + 1, block_size): # Truncate in block of block_size
92
+ self.examples.append(
93
+ tokenizer.build_inputs_with_special_tokens(tokenized_text[i : i + block_size])
94
+ )
95
+ # Note that we are losing the last truncated example here for the sake of simplicity (no padding)
96
+ # If your dataset is small, first you should look for a bigger one :-) and second you
97
+ # can change this behavior by adding (model specific) padding.
98
+
99
+ start = time.time()
100
+ with open(cached_features_file, "wb") as handle:
101
+ pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
102
+ logger.info(
103
+ f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
104
+ )
105
+
106
+ def __len__(self):
107
+ return len(self.examples)
108
+
109
+ def __getitem__(self, i) -> torch.Tensor:
110
+ return torch.tensor(self.examples[i], dtype=torch.long)
111
+
112
+
113
+ class LineByLineTextDataset(Dataset):
114
+ """
115
+ This will be superseded by a framework-agnostic approach soon.
116
+ """
117
+
118
+ def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int):
119
+ warnings.warn(
120
+ DEPRECATION_WARNING.format(
121
+ "https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
122
+ ),
123
+ FutureWarning,
124
+ )
125
+ if os.path.isfile(file_path) is False:
126
+ raise ValueError(f"Input file path {file_path} not found")
127
+ # Here, we do not cache the features, operating under the assumption
128
+ # that we will soon use fast multithreaded tokenizers from the
129
+ # `tokenizers` repo everywhere =)
130
+ logger.info(f"Creating features from dataset file at {file_path}")
131
+
132
+ with open(file_path, encoding="utf-8") as f:
133
+ lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
134
+
135
+ batch_encoding = tokenizer(lines, add_special_tokens=True, truncation=True, max_length=block_size)
136
+ self.examples = batch_encoding["input_ids"]
137
+ self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples]
138
+
139
+ def __len__(self):
140
+ return len(self.examples)
141
+
142
+ def __getitem__(self, i) -> Dict[str, torch.tensor]:
143
+ return self.examples[i]
144
+
145
+
146
+ class LineByLineWithRefDataset(Dataset):
147
+ """
148
+ This will be superseded by a framework-agnostic approach soon.
149
+ """
150
+
151
+ def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, ref_path: str):
152
+ warnings.warn(
153
+ DEPRECATION_WARNING.format(
154
+ "https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm_wwm.py"
155
+ ),
156
+ FutureWarning,
157
+ )
158
+ if os.path.isfile(file_path) is False:
159
+ raise ValueError(f"Input file path {file_path} not found")
160
+ if os.path.isfile(ref_path) is False:
161
+ raise ValueError(f"Ref file path {file_path} not found")
162
+ # Here, we do not cache the features, operating under the assumption
163
+ # that we will soon use fast multithreaded tokenizers from the
164
+ # `tokenizers` repo everywhere =)
165
+ logger.info(f"Creating features from dataset file at {file_path}")
166
+ logger.info(f"Use ref segment results at {ref_path}")
167
+ with open(file_path, encoding="utf-8") as f:
168
+ data = f.readlines() # use this method to avoid delimiter '\u2029' to split a line
169
+ data = [line.strip() for line in data if len(line) > 0 and not line.isspace()]
170
+ # Get ref inf from file
171
+ with open(ref_path, encoding="utf-8") as f:
172
+ ref = [json.loads(line) for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
173
+ if len(data) != len(ref):
174
+ raise ValueError(
175
+ f"Length of Input file should be equal to Ref file. But the length of {file_path} is {len(data)} "
176
+ f"while length of {ref_path} is {len(ref)}"
177
+ )
178
+
179
+ batch_encoding = tokenizer(data, add_special_tokens=True, truncation=True, max_length=block_size)
180
+ self.examples = batch_encoding["input_ids"]
181
+ self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples]
182
+
183
+ n = len(self.examples)
184
+ for i in range(n):
185
+ self.examples[i]["chinese_ref"] = torch.tensor(ref[i], dtype=torch.long)
186
+
187
+ def __len__(self):
188
+ return len(self.examples)
189
+
190
+ def __getitem__(self, i) -> Dict[str, torch.tensor]:
191
+ return self.examples[i]
192
+
193
+
194
+ class LineByLineWithSOPTextDataset(Dataset):
195
+ """
196
+ Dataset for sentence order prediction task, prepare sentence pairs for SOP task
197
+ """
198
+
199
+ def __init__(self, tokenizer: PreTrainedTokenizer, file_dir: str, block_size: int):
200
+ warnings.warn(
201
+ DEPRECATION_WARNING.format(
202
+ "https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
203
+ ),
204
+ FutureWarning,
205
+ )
206
+ if os.path.isdir(file_dir) is False:
207
+ raise ValueError(f"{file_dir} is not a directory")
208
+ logger.info(f"Creating features from dataset file folder at {file_dir}")
209
+ self.examples = []
210
+ # TODO: randomness could apply a random seed, ex. rng = random.Random(random_seed)
211
+ # file path looks like ./dataset/wiki_1, ./dataset/wiki_2
212
+ for file_name in os.listdir(file_dir):
213
+ file_path = os.path.join(file_dir, file_name)
214
+ if os.path.isfile(file_path) is False:
215
+ raise ValueError(f"{file_path} is not a file")
216
+ article_open = False
217
+ with open(file_path, encoding="utf-8") as f:
218
+ original_lines = f.readlines()
219
+ article_lines = []
220
+ for line in original_lines:
221
+ if "<doc id=" in line:
222
+ article_open = True
223
+ elif "</doc>" in line:
224
+ article_open = False
225
+ document = [
226
+ tokenizer.convert_tokens_to_ids(tokenizer.tokenize(line))
227
+ for line in article_lines[1:]
228
+ if (len(line) > 0 and not line.isspace())
229
+ ]
230
+
231
+ examples = self.create_examples_from_document(document, block_size, tokenizer)
232
+ self.examples.extend(examples)
233
+ article_lines = []
234
+ else:
235
+ if article_open:
236
+ article_lines.append(line)
237
+
238
+ logger.info("Dataset parse finished.")
239
+
240
+ def create_examples_from_document(self, document, block_size, tokenizer, short_seq_prob=0.1):
241
+ """Creates examples for a single document."""
242
+
243
+ # Account for special tokens
244
+ max_num_tokens = block_size - tokenizer.num_special_tokens_to_add(pair=True)
245
+
246
+ # We *usually* want to fill up the entire sequence since we are padding
247
+ # to `block_size` anyways, so short sequences are generally wasted
248
+ # computation. However, we *sometimes*
249
+ # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
250
+ # sequences to minimize the mismatch between pretraining and fine-tuning.
251
+ # The `target_seq_length` is just a rough target however, whereas
252
+ # `block_size` is a hard limit.
253
+ target_seq_length = max_num_tokens
254
+ if random.random() < short_seq_prob:
255
+ target_seq_length = random.randint(2, max_num_tokens)
256
+
257
+ # We DON'T just concatenate all of the tokens from a document into a long
258
+ # sequence and choose an arbitrary split point because this would make the
259
+ # next sentence prediction task too easy. Instead, we split the input into
260
+ # segments "A" and "B" based on the actual "sentences" provided by the user
261
+ # input.
262
+ examples = []
263
+ current_chunk = [] # a buffer stored current working segments
264
+ current_length = 0
265
+ i = 0
266
+ while i < len(document):
267
+ segment = document[i] # get a segment
268
+ if not segment:
269
+ i += 1
270
+ continue
271
+ current_chunk.append(segment) # add a segment to current chunk
272
+ current_length += len(segment) # overall token length
273
+ # if current length goes to the target length or reaches the end of file, start building token a and b
274
+ if i == len(document) - 1 or current_length >= target_seq_length:
275
+ if current_chunk:
276
+ # `a_end` is how many segments from `current_chunk` go into the `A` (first) sentence.
277
+ a_end = 1
278
+ # if current chunk has more than 2 sentences, pick part of it `A` (first) sentence
279
+ if len(current_chunk) >= 2:
280
+ a_end = random.randint(1, len(current_chunk) - 1)
281
+ # token a
282
+ tokens_a = []
283
+ for j in range(a_end):
284
+ tokens_a.extend(current_chunk[j])
285
+
286
+ # token b
287
+ tokens_b = []
288
+ for j in range(a_end, len(current_chunk)):
289
+ tokens_b.extend(current_chunk[j])
290
+
291
+ if len(tokens_a) == 0 or len(tokens_b) == 0:
292
+ continue
293
+
294
+ # switch tokens_a and tokens_b randomly
295
+ if random.random() < 0.5:
296
+ is_next = False
297
+ tokens_a, tokens_b = tokens_b, tokens_a
298
+ else:
299
+ is_next = True
300
+
301
+ def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
302
+ """Truncates a pair of sequences to a maximum sequence length."""
303
+ while True:
304
+ total_length = len(tokens_a) + len(tokens_b)
305
+ if total_length <= max_num_tokens:
306
+ break
307
+ trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
308
+ if not (len(trunc_tokens) >= 1):
309
+ raise ValueError("Sequence length to be truncated must be no less than one")
310
+ # We want to sometimes truncate from the front and sometimes from the
311
+ # back to add more randomness and avoid biases.
312
+ if random.random() < 0.5:
313
+ del trunc_tokens[0]
314
+ else:
315
+ trunc_tokens.pop()
316
+
317
+ truncate_seq_pair(tokens_a, tokens_b, max_num_tokens)
318
+ if not (len(tokens_a) >= 1):
319
+ raise ValueError(f"Length of sequence a is {len(tokens_a)} which must be no less than 1")
320
+ if not (len(tokens_b) >= 1):
321
+ raise ValueError(f"Length of sequence b is {len(tokens_b)} which must be no less than 1")
322
+
323
+ # add special tokens
324
+ input_ids = tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
325
+ # add token type ids, 0 for sentence a, 1 for sentence b
326
+ token_type_ids = tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)
327
+
328
+ example = {
329
+ "input_ids": torch.tensor(input_ids, dtype=torch.long),
330
+ "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
331
+ "sentence_order_label": torch.tensor(0 if is_next else 1, dtype=torch.long),
332
+ }
333
+ examples.append(example)
334
+ current_chunk = [] # clear current chunk
335
+ current_length = 0 # reset current text length
336
+ i += 1 # go to next line
337
+ return examples
338
+
339
+ def __len__(self):
340
+ return len(self.examples)
341
+
342
+ def __getitem__(self, i) -> Dict[str, torch.tensor]:
343
+ return self.examples[i]
344
+
345
+
346
+ class TextDatasetForNextSentencePrediction(Dataset):
347
+ """
348
+ This will be superseded by a framework-agnostic approach soon.
349
+ """
350
+
351
+ def __init__(
352
+ self,
353
+ tokenizer: PreTrainedTokenizer,
354
+ file_path: str,
355
+ block_size: int,
356
+ overwrite_cache=False,
357
+ short_seq_probability=0.1,
358
+ nsp_probability=0.5,
359
+ ):
360
+ warnings.warn(
361
+ DEPRECATION_WARNING.format(
362
+ "https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
363
+ ),
364
+ FutureWarning,
365
+ )
366
+ if not os.path.isfile(file_path):
367
+ raise ValueError(f"Input file path {file_path} not found")
368
+
369
+ self.short_seq_probability = short_seq_probability
370
+ self.nsp_probability = nsp_probability
371
+
372
+ directory, filename = os.path.split(file_path)
373
+ cached_features_file = os.path.join(
374
+ directory,
375
+ f"cached_nsp_{tokenizer.__class__.__name__}_{block_size}_{filename}",
376
+ )
377
+
378
+ self.tokenizer = tokenizer
379
+
380
+ # Make sure only the first process in distributed training processes the dataset,
381
+ # and the others will use the cache.
382
+ lock_path = cached_features_file + ".lock"
383
+
384
+ # Input file format:
385
+ # (1) One sentence per line. These should ideally be actual sentences, not
386
+ # entire paragraphs or arbitrary spans of text. (Because we use the
387
+ # sentence boundaries for the "next sentence prediction" task).
388
+ # (2) Blank lines between documents. Document boundaries are needed so
389
+ # that the "next sentence prediction" task doesn't span between documents.
390
+ #
391
+ # Example:
392
+ # I am very happy.
393
+ # Here is the second sentence.
394
+ #
395
+ # A new document.
396
+
397
+ with FileLock(lock_path):
398
+ if os.path.exists(cached_features_file) and not overwrite_cache:
399
+ start = time.time()
400
+ with open(cached_features_file, "rb") as handle:
401
+ self.examples = pickle.load(handle)
402
+ logger.info(
403
+ f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
404
+ )
405
+ else:
406
+ logger.info(f"Creating features from dataset file at {directory}")
407
+
408
+ self.documents = [[]]
409
+ with open(file_path, encoding="utf-8") as f:
410
+ while True:
411
+ line = f.readline()
412
+ if not line:
413
+ break
414
+ line = line.strip()
415
+
416
+ # Empty lines are used as document delimiters
417
+ if not line and len(self.documents[-1]) != 0:
418
+ self.documents.append([])
419
+ tokens = tokenizer.tokenize(line)
420
+ tokens = tokenizer.convert_tokens_to_ids(tokens)
421
+ if tokens:
422
+ self.documents[-1].append(tokens)
423
+
424
+ logger.info(f"Creating examples from {len(self.documents)} documents.")
425
+ self.examples = []
426
+ for doc_index, document in enumerate(self.documents):
427
+ self.create_examples_from_document(document, doc_index, block_size)
428
+
429
+ start = time.time()
430
+ with open(cached_features_file, "wb") as handle:
431
+ pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
432
+ logger.info(
433
+ f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
434
+ )
435
+
436
+ def create_examples_from_document(self, document: List[List[int]], doc_index: int, block_size: int):
437
+ """Creates examples for a single document."""
438
+
439
+ max_num_tokens = block_size - self.tokenizer.num_special_tokens_to_add(pair=True)
440
+
441
+ # We *usually* want to fill up the entire sequence since we are padding
442
+ # to `block_size` anyways, so short sequences are generally wasted
443
+ # computation. However, we *sometimes*
444
+ # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
445
+ # sequences to minimize the mismatch between pretraining and fine-tuning.
446
+ # The `target_seq_length` is just a rough target however, whereas
447
+ # `block_size` is a hard limit.
448
+ target_seq_length = max_num_tokens
449
+ if random.random() < self.short_seq_probability:
450
+ target_seq_length = random.randint(2, max_num_tokens)
451
+
452
+ current_chunk = [] # a buffer stored current working segments
453
+ current_length = 0
454
+ i = 0
455
+
456
+ while i < len(document):
457
+ segment = document[i]
458
+ current_chunk.append(segment)
459
+ current_length += len(segment)
460
+ if i == len(document) - 1 or current_length >= target_seq_length:
461
+ if current_chunk:
462
+ # `a_end` is how many segments from `current_chunk` go into the `A`
463
+ # (first) sentence.
464
+ a_end = 1
465
+ if len(current_chunk) >= 2:
466
+ a_end = random.randint(1, len(current_chunk) - 1)
467
+
468
+ tokens_a = []
469
+ for j in range(a_end):
470
+ tokens_a.extend(current_chunk[j])
471
+
472
+ tokens_b = []
473
+
474
+ if len(current_chunk) == 1 or random.random() < self.nsp_probability:
475
+ is_random_next = True
476
+ target_b_length = target_seq_length - len(tokens_a)
477
+
478
+ # This should rarely go for more than one iteration for large
479
+ # corpora. However, just to be careful, we try to make sure that
480
+ # the random document is not the same as the document
481
+ # we're processing.
482
+ for _ in range(10):
483
+ random_document_index = random.randint(0, len(self.documents) - 1)
484
+ if random_document_index != doc_index:
485
+ break
486
+
487
+ random_document = self.documents[random_document_index]
488
+ random_start = random.randint(0, len(random_document) - 1)
489
+ for j in range(random_start, len(random_document)):
490
+ tokens_b.extend(random_document[j])
491
+ if len(tokens_b) >= target_b_length:
492
+ break
493
+ # We didn't actually use these segments so we "put them back" so
494
+ # they don't go to waste.
495
+ num_unused_segments = len(current_chunk) - a_end
496
+ i -= num_unused_segments
497
+ # Actual next
498
+ else:
499
+ is_random_next = False
500
+ for j in range(a_end, len(current_chunk)):
501
+ tokens_b.extend(current_chunk[j])
502
+
503
+ if not (len(tokens_a) >= 1):
504
+ raise ValueError(f"Length of sequence a is {len(tokens_a)} which must be no less than 1")
505
+ if not (len(tokens_b) >= 1):
506
+ raise ValueError(f"Length of sequence b is {len(tokens_b)} which must be no less than 1")
507
+
508
+ # add special tokens
509
+ input_ids = self.tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
510
+ # add token type ids, 0 for sentence a, 1 for sentence b
511
+ token_type_ids = self.tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)
512
+
513
+ example = {
514
+ "input_ids": torch.tensor(input_ids, dtype=torch.long),
515
+ "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
516
+ "next_sentence_label": torch.tensor(1 if is_random_next else 0, dtype=torch.long),
517
+ }
518
+
519
+ self.examples.append(example)
520
+
521
+ current_chunk = []
522
+ current_length = 0
523
+
524
+ i += 1
525
+
526
+ def __len__(self):
527
+ return len(self.examples)
528
+
529
+ def __getitem__(self, i):
530
+ return self.examples[i]
.venv/lib/python3.11/site-packages/transformers/data/datasets/squad.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import time
17
+ from dataclasses import dataclass, field
18
+ from enum import Enum
19
+ from typing import Dict, List, Optional, Union
20
+
21
+ import torch
22
+ from filelock import FileLock
23
+ from torch.utils.data import Dataset
24
+
25
+ from ...models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
26
+ from ...tokenization_utils import PreTrainedTokenizer
27
+ from ...utils import logging
28
+ from ..processors.squad import SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features
29
+
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+ MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
34
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
35
+
36
+
37
+ @dataclass
38
+ class SquadDataTrainingArguments:
39
+ """
40
+ Arguments pertaining to what data we are going to input our model for training and eval.
41
+ """
42
+
43
+ model_type: str = field(
44
+ default=None, metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_TYPES)}
45
+ )
46
+ data_dir: str = field(
47
+ default=None, metadata={"help": "The input data dir. Should contain the .json files for the SQuAD task."}
48
+ )
49
+ max_seq_length: int = field(
50
+ default=128,
51
+ metadata={
52
+ "help": (
53
+ "The maximum total input sequence length after tokenization. Sequences longer "
54
+ "than this will be truncated, sequences shorter will be padded."
55
+ )
56
+ },
57
+ )
58
+ doc_stride: int = field(
59
+ default=128,
60
+ metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."},
61
+ )
62
+ max_query_length: int = field(
63
+ default=64,
64
+ metadata={
65
+ "help": (
66
+ "The maximum number of tokens for the question. Questions longer than this will "
67
+ "be truncated to this length."
68
+ )
69
+ },
70
+ )
71
+ max_answer_length: int = field(
72
+ default=30,
73
+ metadata={
74
+ "help": (
75
+ "The maximum length of an answer that can be generated. This is needed because the start "
76
+ "and end predictions are not conditioned on one another."
77
+ )
78
+ },
79
+ )
80
+ overwrite_cache: bool = field(
81
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
82
+ )
83
+ version_2_with_negative: bool = field(
84
+ default=False, metadata={"help": "If true, the SQuAD examples contain some that do not have an answer."}
85
+ )
86
+ null_score_diff_threshold: float = field(
87
+ default=0.0, metadata={"help": "If null_score - best_non_null is greater than the threshold predict null."}
88
+ )
89
+ n_best_size: int = field(
90
+ default=20, metadata={"help": "If null_score - best_non_null is greater than the threshold predict null."}
91
+ )
92
+ lang_id: int = field(
93
+ default=0,
94
+ metadata={
95
+ "help": (
96
+ "language id of input for language-specific xlm models (see"
97
+ " tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)"
98
+ )
99
+ },
100
+ )
101
+ threads: int = field(default=1, metadata={"help": "multiple threads for converting example to features"})
102
+
103
+
104
+ class Split(Enum):
105
+ train = "train"
106
+ dev = "dev"
107
+
108
+
109
+ class SquadDataset(Dataset):
110
+ """
111
+ This will be superseded by a framework-agnostic approach soon.
112
+ """
113
+
114
+ args: SquadDataTrainingArguments
115
+ features: List[SquadFeatures]
116
+ mode: Split
117
+ is_language_sensitive: bool
118
+
119
+ def __init__(
120
+ self,
121
+ args: SquadDataTrainingArguments,
122
+ tokenizer: PreTrainedTokenizer,
123
+ limit_length: Optional[int] = None,
124
+ mode: Union[str, Split] = Split.train,
125
+ is_language_sensitive: Optional[bool] = False,
126
+ cache_dir: Optional[str] = None,
127
+ dataset_format: Optional[str] = "pt",
128
+ ):
129
+ self.args = args
130
+ self.is_language_sensitive = is_language_sensitive
131
+ self.processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
132
+ if isinstance(mode, str):
133
+ try:
134
+ mode = Split[mode]
135
+ except KeyError:
136
+ raise KeyError("mode is not a valid split name")
137
+ self.mode = mode
138
+ # Load data features from cache or dataset file
139
+ version_tag = "v2" if args.version_2_with_negative else "v1"
140
+ cached_features_file = os.path.join(
141
+ cache_dir if cache_dir is not None else args.data_dir,
142
+ f"cached_{mode.value}_{tokenizer.__class__.__name__}_{args.max_seq_length}_{version_tag}",
143
+ )
144
+
145
+ # Make sure only the first process in distributed training processes the dataset,
146
+ # and the others will use the cache.
147
+ lock_path = cached_features_file + ".lock"
148
+ with FileLock(lock_path):
149
+ if os.path.exists(cached_features_file) and not args.overwrite_cache:
150
+ start = time.time()
151
+ self.old_features = torch.load(cached_features_file)
152
+
153
+ # Legacy cache files have only features, while new cache files
154
+ # will have dataset and examples also.
155
+ self.features = self.old_features["features"]
156
+ self.dataset = self.old_features.get("dataset", None)
157
+ self.examples = self.old_features.get("examples", None)
158
+ logger.info(
159
+ f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
160
+ )
161
+
162
+ if self.dataset is None or self.examples is None:
163
+ logger.warning(
164
+ f"Deleting cached file {cached_features_file} will allow dataset and examples to be cached in"
165
+ " future run"
166
+ )
167
+ else:
168
+ if mode == Split.dev:
169
+ self.examples = self.processor.get_dev_examples(args.data_dir)
170
+ else:
171
+ self.examples = self.processor.get_train_examples(args.data_dir)
172
+
173
+ self.features, self.dataset = squad_convert_examples_to_features(
174
+ examples=self.examples,
175
+ tokenizer=tokenizer,
176
+ max_seq_length=args.max_seq_length,
177
+ doc_stride=args.doc_stride,
178
+ max_query_length=args.max_query_length,
179
+ is_training=mode == Split.train,
180
+ threads=args.threads,
181
+ return_dataset=dataset_format,
182
+ )
183
+
184
+ start = time.time()
185
+ torch.save(
186
+ {"features": self.features, "dataset": self.dataset, "examples": self.examples},
187
+ cached_features_file,
188
+ )
189
+ # ^ This seems to take a lot of time so I want to investigate why and how we can improve.
190
+ logger.info(
191
+ f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
192
+ )
193
+
194
+ def __len__(self):
195
+ return len(self.features)
196
+
197
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
198
+ # Convert to Tensors and build dataset
199
+ feature = self.features[i]
200
+
201
+ input_ids = torch.tensor(feature.input_ids, dtype=torch.long)
202
+ attention_mask = torch.tensor(feature.attention_mask, dtype=torch.long)
203
+ token_type_ids = torch.tensor(feature.token_type_ids, dtype=torch.long)
204
+ cls_index = torch.tensor(feature.cls_index, dtype=torch.long)
205
+ p_mask = torch.tensor(feature.p_mask, dtype=torch.float)
206
+ is_impossible = torch.tensor(feature.is_impossible, dtype=torch.float)
207
+
208
+ inputs = {
209
+ "input_ids": input_ids,
210
+ "attention_mask": attention_mask,
211
+ "token_type_ids": token_type_ids,
212
+ }
213
+
214
+ if self.args.model_type in ["xlm", "roberta", "distilbert", "camembert"]:
215
+ del inputs["token_type_ids"]
216
+
217
+ if self.args.model_type in ["xlnet", "xlm"]:
218
+ inputs.update({"cls_index": cls_index, "p_mask": p_mask})
219
+ if self.args.version_2_with_negative:
220
+ inputs.update({"is_impossible": is_impossible})
221
+ if self.is_language_sensitive:
222
+ inputs.update({"langs": (torch.ones(input_ids.shape, dtype=torch.int64) * self.args.lang_id)})
223
+
224
+ if self.mode == Split.train:
225
+ start_positions = torch.tensor(feature.start_position, dtype=torch.long)
226
+ end_positions = torch.tensor(feature.end_position, dtype=torch.long)
227
+ inputs.update({"start_positions": start_positions, "end_positions": end_positions})
228
+
229
+ return inputs
.venv/lib/python3.11/site-packages/transformers/data/metrics/__init__.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ import warnings
14
+
15
+ from ...utils import is_sklearn_available, requires_backends
16
+
17
+
18
+ if is_sklearn_available():
19
+ from scipy.stats import pearsonr, spearmanr
20
+ from sklearn.metrics import f1_score, matthews_corrcoef
21
+
22
+
23
+ DEPRECATION_WARNING = (
24
+ "This metric will be removed from the library soon, metrics should be handled with the 🤗 Evaluate "
25
+ "library. You can have a look at this example script for pointers: "
26
+ "https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py"
27
+ )
28
+
29
+
30
+ def simple_accuracy(preds, labels):
31
+ warnings.warn(DEPRECATION_WARNING, FutureWarning)
32
+ requires_backends(simple_accuracy, "sklearn")
33
+ return (preds == labels).mean()
34
+
35
+
36
+ def acc_and_f1(preds, labels):
37
+ warnings.warn(DEPRECATION_WARNING, FutureWarning)
38
+ requires_backends(acc_and_f1, "sklearn")
39
+ acc = simple_accuracy(preds, labels)
40
+ f1 = f1_score(y_true=labels, y_pred=preds)
41
+ return {
42
+ "acc": acc,
43
+ "f1": f1,
44
+ "acc_and_f1": (acc + f1) / 2,
45
+ }
46
+
47
+
48
+ def pearson_and_spearman(preds, labels):
49
+ warnings.warn(DEPRECATION_WARNING, FutureWarning)
50
+ requires_backends(pearson_and_spearman, "sklearn")
51
+ pearson_corr = pearsonr(preds, labels)[0]
52
+ spearman_corr = spearmanr(preds, labels)[0]
53
+ return {
54
+ "pearson": pearson_corr,
55
+ "spearmanr": spearman_corr,
56
+ "corr": (pearson_corr + spearman_corr) / 2,
57
+ }
58
+
59
+
60
+ def glue_compute_metrics(task_name, preds, labels):
61
+ warnings.warn(DEPRECATION_WARNING, FutureWarning)
62
+ requires_backends(glue_compute_metrics, "sklearn")
63
+ assert len(preds) == len(labels), f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}"
64
+ if task_name == "cola":
65
+ return {"mcc": matthews_corrcoef(labels, preds)}
66
+ elif task_name == "sst-2":
67
+ return {"acc": simple_accuracy(preds, labels)}
68
+ elif task_name == "mrpc":
69
+ return acc_and_f1(preds, labels)
70
+ elif task_name == "sts-b":
71
+ return pearson_and_spearman(preds, labels)
72
+ elif task_name == "qqp":
73
+ return acc_and_f1(preds, labels)
74
+ elif task_name == "mnli":
75
+ return {"mnli/acc": simple_accuracy(preds, labels)}
76
+ elif task_name == "mnli-mm":
77
+ return {"mnli-mm/acc": simple_accuracy(preds, labels)}
78
+ elif task_name == "qnli":
79
+ return {"acc": simple_accuracy(preds, labels)}
80
+ elif task_name == "rte":
81
+ return {"acc": simple_accuracy(preds, labels)}
82
+ elif task_name == "wnli":
83
+ return {"acc": simple_accuracy(preds, labels)}
84
+ elif task_name == "hans":
85
+ return {"acc": simple_accuracy(preds, labels)}
86
+ else:
87
+ raise KeyError(task_name)
88
+
89
+
90
+ def xnli_compute_metrics(task_name, preds, labels):
91
+ warnings.warn(DEPRECATION_WARNING, FutureWarning)
92
+ requires_backends(xnli_compute_metrics, "sklearn")
93
+ if len(preds) != len(labels):
94
+ raise ValueError(f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}")
95
+ if task_name == "xnli":
96
+ return {"acc": simple_accuracy(preds, labels)}
97
+ else:
98
+ raise KeyError(task_name)
.venv/lib/python3.11/site-packages/transformers/data/metrics/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (4.64 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/data/metrics/__pycache__/squad_metrics.cpython-311.pyc ADDED
Binary file (31 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/data/metrics/squad_metrics.py ADDED
@@ -0,0 +1,779 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Very heavily inspired by the official evaluation script for SQuAD version 2.0 which was modified by XLNet authors to
16
+ update `find_best_threshold` scripts for SQuAD V2.0
17
+
18
+ In addition to basic functionality, we also compute additional statistics and plot precision-recall curves if an
19
+ additional na_prob.json file is provided. This file is expected to map question ID's to the model's predicted
20
+ probability that a question is unanswerable.
21
+ """
22
+
23
+ import collections
24
+ import json
25
+ import math
26
+ import re
27
+ import string
28
+
29
+ from ...models.bert import BasicTokenizer
30
+ from ...utils import logging
31
+
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+
36
+ def normalize_answer(s):
37
+ """Lower text and remove punctuation, articles and extra whitespace."""
38
+
39
+ def remove_articles(text):
40
+ regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
41
+ return re.sub(regex, " ", text)
42
+
43
+ def white_space_fix(text):
44
+ return " ".join(text.split())
45
+
46
+ def remove_punc(text):
47
+ exclude = set(string.punctuation)
48
+ return "".join(ch for ch in text if ch not in exclude)
49
+
50
+ def lower(text):
51
+ return text.lower()
52
+
53
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
54
+
55
+
56
+ def get_tokens(s):
57
+ if not s:
58
+ return []
59
+ return normalize_answer(s).split()
60
+
61
+
62
+ def compute_exact(a_gold, a_pred):
63
+ return int(normalize_answer(a_gold) == normalize_answer(a_pred))
64
+
65
+
66
+ def compute_f1(a_gold, a_pred):
67
+ gold_toks = get_tokens(a_gold)
68
+ pred_toks = get_tokens(a_pred)
69
+ common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
70
+ num_same = sum(common.values())
71
+ if len(gold_toks) == 0 or len(pred_toks) == 0:
72
+ # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
73
+ return int(gold_toks == pred_toks)
74
+ if num_same == 0:
75
+ return 0
76
+ precision = 1.0 * num_same / len(pred_toks)
77
+ recall = 1.0 * num_same / len(gold_toks)
78
+ f1 = (2 * precision * recall) / (precision + recall)
79
+ return f1
80
+
81
+
82
+ def get_raw_scores(examples, preds):
83
+ """
84
+ Computes the exact and f1 scores from the examples and the model predictions
85
+ """
86
+ exact_scores = {}
87
+ f1_scores = {}
88
+
89
+ for example in examples:
90
+ qas_id = example.qas_id
91
+ gold_answers = [answer["text"] for answer in example.answers if normalize_answer(answer["text"])]
92
+
93
+ if not gold_answers:
94
+ # For unanswerable questions, only correct answer is empty string
95
+ gold_answers = [""]
96
+
97
+ if qas_id not in preds:
98
+ print(f"Missing prediction for {qas_id}")
99
+ continue
100
+
101
+ prediction = preds[qas_id]
102
+ exact_scores[qas_id] = max(compute_exact(a, prediction) for a in gold_answers)
103
+ f1_scores[qas_id] = max(compute_f1(a, prediction) for a in gold_answers)
104
+
105
+ return exact_scores, f1_scores
106
+
107
+
108
+ def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
109
+ new_scores = {}
110
+ for qid, s in scores.items():
111
+ pred_na = na_probs[qid] > na_prob_thresh
112
+ if pred_na:
113
+ new_scores[qid] = float(not qid_to_has_ans[qid])
114
+ else:
115
+ new_scores[qid] = s
116
+ return new_scores
117
+
118
+
119
+ def make_eval_dict(exact_scores, f1_scores, qid_list=None):
120
+ if not qid_list:
121
+ total = len(exact_scores)
122
+ return collections.OrderedDict(
123
+ [
124
+ ("exact", 100.0 * sum(exact_scores.values()) / total),
125
+ ("f1", 100.0 * sum(f1_scores.values()) / total),
126
+ ("total", total),
127
+ ]
128
+ )
129
+ else:
130
+ total = len(qid_list)
131
+ return collections.OrderedDict(
132
+ [
133
+ ("exact", 100.0 * sum(exact_scores[k] for k in qid_list) / total),
134
+ ("f1", 100.0 * sum(f1_scores[k] for k in qid_list) / total),
135
+ ("total", total),
136
+ ]
137
+ )
138
+
139
+
140
+ def merge_eval(main_eval, new_eval, prefix):
141
+ for k in new_eval:
142
+ main_eval[f"{prefix}_{k}"] = new_eval[k]
143
+
144
+
145
+ def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
146
+ num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
147
+ cur_score = num_no_ans
148
+ best_score = cur_score
149
+ best_thresh = 0.0
150
+ qid_list = sorted(na_probs, key=lambda k: na_probs[k])
151
+ for i, qid in enumerate(qid_list):
152
+ if qid not in scores:
153
+ continue
154
+ if qid_to_has_ans[qid]:
155
+ diff = scores[qid]
156
+ else:
157
+ if preds[qid]:
158
+ diff = -1
159
+ else:
160
+ diff = 0
161
+ cur_score += diff
162
+ if cur_score > best_score:
163
+ best_score = cur_score
164
+ best_thresh = na_probs[qid]
165
+
166
+ has_ans_score, has_ans_cnt = 0, 0
167
+ for qid in qid_list:
168
+ if not qid_to_has_ans[qid]:
169
+ continue
170
+ has_ans_cnt += 1
171
+
172
+ if qid not in scores:
173
+ continue
174
+ has_ans_score += scores[qid]
175
+
176
+ return 100.0 * best_score / len(scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt
177
+
178
+
179
+ def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
180
+ best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw, na_probs, qid_to_has_ans)
181
+ best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, f1_raw, na_probs, qid_to_has_ans)
182
+ main_eval["best_exact"] = best_exact
183
+ main_eval["best_exact_thresh"] = exact_thresh
184
+ main_eval["best_f1"] = best_f1
185
+ main_eval["best_f1_thresh"] = f1_thresh
186
+ main_eval["has_ans_exact"] = has_ans_exact
187
+ main_eval["has_ans_f1"] = has_ans_f1
188
+
189
+
190
+ def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
191
+ num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
192
+ cur_score = num_no_ans
193
+ best_score = cur_score
194
+ best_thresh = 0.0
195
+ qid_list = sorted(na_probs, key=lambda k: na_probs[k])
196
+ for _, qid in enumerate(qid_list):
197
+ if qid not in scores:
198
+ continue
199
+ if qid_to_has_ans[qid]:
200
+ diff = scores[qid]
201
+ else:
202
+ if preds[qid]:
203
+ diff = -1
204
+ else:
205
+ diff = 0
206
+ cur_score += diff
207
+ if cur_score > best_score:
208
+ best_score = cur_score
209
+ best_thresh = na_probs[qid]
210
+ return 100.0 * best_score / len(scores), best_thresh
211
+
212
+
213
+ def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
214
+ best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
215
+ best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
216
+
217
+ main_eval["best_exact"] = best_exact
218
+ main_eval["best_exact_thresh"] = exact_thresh
219
+ main_eval["best_f1"] = best_f1
220
+ main_eval["best_f1_thresh"] = f1_thresh
221
+
222
+
223
+ def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_threshold=1.0):
224
+ qas_id_to_has_answer = {example.qas_id: bool(example.answers) for example in examples}
225
+ has_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if has_answer]
226
+ no_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if not has_answer]
227
+
228
+ if no_answer_probs is None:
229
+ no_answer_probs = {k: 0.0 for k in preds}
230
+
231
+ exact, f1 = get_raw_scores(examples, preds)
232
+
233
+ exact_threshold = apply_no_ans_threshold(
234
+ exact, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold
235
+ )
236
+ f1_threshold = apply_no_ans_threshold(f1, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold)
237
+
238
+ evaluation = make_eval_dict(exact_threshold, f1_threshold)
239
+
240
+ if has_answer_qids:
241
+ has_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=has_answer_qids)
242
+ merge_eval(evaluation, has_ans_eval, "HasAns")
243
+
244
+ if no_answer_qids:
245
+ no_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=no_answer_qids)
246
+ merge_eval(evaluation, no_ans_eval, "NoAns")
247
+
248
+ if no_answer_probs:
249
+ find_all_best_thresh(evaluation, preds, exact, f1, no_answer_probs, qas_id_to_has_answer)
250
+
251
+ return evaluation
252
+
253
+
254
+ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
255
+ """Project the tokenized prediction back to the original text."""
256
+
257
+ # When we created the data, we kept track of the alignment between original
258
+ # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
259
+ # now `orig_text` contains the span of our original text corresponding to the
260
+ # span that we predicted.
261
+ #
262
+ # However, `orig_text` may contain extra characters that we don't want in
263
+ # our prediction.
264
+ #
265
+ # For example, let's say:
266
+ # pred_text = steve smith
267
+ # orig_text = Steve Smith's
268
+ #
269
+ # We don't want to return `orig_text` because it contains the extra "'s".
270
+ #
271
+ # We don't want to return `pred_text` because it's already been normalized
272
+ # (the SQuAD eval script also does punctuation stripping/lower casing but
273
+ # our tokenizer does additional normalization like stripping accent
274
+ # characters).
275
+ #
276
+ # What we really want to return is "Steve Smith".
277
+ #
278
+ # Therefore, we have to apply a semi-complicated alignment heuristic between
279
+ # `pred_text` and `orig_text` to get a character-to-character alignment. This
280
+ # can fail in certain cases in which case we just return `orig_text`.
281
+
282
+ def _strip_spaces(text):
283
+ ns_chars = []
284
+ ns_to_s_map = collections.OrderedDict()
285
+ for i, c in enumerate(text):
286
+ if c == " ":
287
+ continue
288
+ ns_to_s_map[len(ns_chars)] = i
289
+ ns_chars.append(c)
290
+ ns_text = "".join(ns_chars)
291
+ return (ns_text, ns_to_s_map)
292
+
293
+ # We first tokenize `orig_text`, strip whitespace from the result
294
+ # and `pred_text`, and check if they are the same length. If they are
295
+ # NOT the same length, the heuristic has failed. If they are the same
296
+ # length, we assume the characters are one-to-one aligned.
297
+ tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
298
+
299
+ tok_text = " ".join(tokenizer.tokenize(orig_text))
300
+
301
+ start_position = tok_text.find(pred_text)
302
+ if start_position == -1:
303
+ if verbose_logging:
304
+ logger.info(f"Unable to find text: '{pred_text}' in '{orig_text}'")
305
+ return orig_text
306
+ end_position = start_position + len(pred_text) - 1
307
+
308
+ (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
309
+ (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
310
+
311
+ if len(orig_ns_text) != len(tok_ns_text):
312
+ if verbose_logging:
313
+ logger.info(f"Length not equal after stripping spaces: '{orig_ns_text}' vs '{tok_ns_text}'")
314
+ return orig_text
315
+
316
+ # We then project the characters in `pred_text` back to `orig_text` using
317
+ # the character-to-character alignment.
318
+ tok_s_to_ns_map = {}
319
+ for i, tok_index in tok_ns_to_s_map.items():
320
+ tok_s_to_ns_map[tok_index] = i
321
+
322
+ orig_start_position = None
323
+ if start_position in tok_s_to_ns_map:
324
+ ns_start_position = tok_s_to_ns_map[start_position]
325
+ if ns_start_position in orig_ns_to_s_map:
326
+ orig_start_position = orig_ns_to_s_map[ns_start_position]
327
+
328
+ if orig_start_position is None:
329
+ if verbose_logging:
330
+ logger.info("Couldn't map start position")
331
+ return orig_text
332
+
333
+ orig_end_position = None
334
+ if end_position in tok_s_to_ns_map:
335
+ ns_end_position = tok_s_to_ns_map[end_position]
336
+ if ns_end_position in orig_ns_to_s_map:
337
+ orig_end_position = orig_ns_to_s_map[ns_end_position]
338
+
339
+ if orig_end_position is None:
340
+ if verbose_logging:
341
+ logger.info("Couldn't map end position")
342
+ return orig_text
343
+
344
+ output_text = orig_text[orig_start_position : (orig_end_position + 1)]
345
+ return output_text
346
+
347
+
348
+ def _get_best_indexes(logits, n_best_size):
349
+ """Get the n-best logits from a list."""
350
+ index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
351
+
352
+ best_indexes = []
353
+ for i in range(len(index_and_score)):
354
+ if i >= n_best_size:
355
+ break
356
+ best_indexes.append(index_and_score[i][0])
357
+ return best_indexes
358
+
359
+
360
+ def _compute_softmax(scores):
361
+ """Compute softmax probability over raw logits."""
362
+ if not scores:
363
+ return []
364
+
365
+ max_score = None
366
+ for score in scores:
367
+ if max_score is None or score > max_score:
368
+ max_score = score
369
+
370
+ exp_scores = []
371
+ total_sum = 0.0
372
+ for score in scores:
373
+ x = math.exp(score - max_score)
374
+ exp_scores.append(x)
375
+ total_sum += x
376
+
377
+ probs = []
378
+ for score in exp_scores:
379
+ probs.append(score / total_sum)
380
+ return probs
381
+
382
+
383
+ def compute_predictions_logits(
384
+ all_examples,
385
+ all_features,
386
+ all_results,
387
+ n_best_size,
388
+ max_answer_length,
389
+ do_lower_case,
390
+ output_prediction_file,
391
+ output_nbest_file,
392
+ output_null_log_odds_file,
393
+ verbose_logging,
394
+ version_2_with_negative,
395
+ null_score_diff_threshold,
396
+ tokenizer,
397
+ ):
398
+ """Write final predictions to the json file and log-odds of null if needed."""
399
+ if output_prediction_file:
400
+ logger.info(f"Writing predictions to: {output_prediction_file}")
401
+ if output_nbest_file:
402
+ logger.info(f"Writing nbest to: {output_nbest_file}")
403
+ if output_null_log_odds_file and version_2_with_negative:
404
+ logger.info(f"Writing null_log_odds to: {output_null_log_odds_file}")
405
+
406
+ example_index_to_features = collections.defaultdict(list)
407
+ for feature in all_features:
408
+ example_index_to_features[feature.example_index].append(feature)
409
+
410
+ unique_id_to_result = {}
411
+ for result in all_results:
412
+ unique_id_to_result[result.unique_id] = result
413
+
414
+ _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
415
+ "PrelimPrediction", ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]
416
+ )
417
+
418
+ all_predictions = collections.OrderedDict()
419
+ all_nbest_json = collections.OrderedDict()
420
+ scores_diff_json = collections.OrderedDict()
421
+
422
+ for example_index, example in enumerate(all_examples):
423
+ features = example_index_to_features[example_index]
424
+
425
+ prelim_predictions = []
426
+ # keep track of the minimum score of null start+end of position 0
427
+ score_null = 1000000 # large and positive
428
+ min_null_feature_index = 0 # the paragraph slice with min null score
429
+ null_start_logit = 0 # the start logit at the slice with min null score
430
+ null_end_logit = 0 # the end logit at the slice with min null score
431
+ for feature_index, feature in enumerate(features):
432
+ result = unique_id_to_result[feature.unique_id]
433
+ start_indexes = _get_best_indexes(result.start_logits, n_best_size)
434
+ end_indexes = _get_best_indexes(result.end_logits, n_best_size)
435
+ # if we could have irrelevant answers, get the min score of irrelevant
436
+ if version_2_with_negative:
437
+ feature_null_score = result.start_logits[0] + result.end_logits[0]
438
+ if feature_null_score < score_null:
439
+ score_null = feature_null_score
440
+ min_null_feature_index = feature_index
441
+ null_start_logit = result.start_logits[0]
442
+ null_end_logit = result.end_logits[0]
443
+ for start_index in start_indexes:
444
+ for end_index in end_indexes:
445
+ # We could hypothetically create invalid predictions, e.g., predict
446
+ # that the start of the span is in the question. We throw out all
447
+ # invalid predictions.
448
+ if start_index >= len(feature.tokens):
449
+ continue
450
+ if end_index >= len(feature.tokens):
451
+ continue
452
+ if start_index not in feature.token_to_orig_map:
453
+ continue
454
+ if end_index not in feature.token_to_orig_map:
455
+ continue
456
+ if not feature.token_is_max_context.get(start_index, False):
457
+ continue
458
+ if end_index < start_index:
459
+ continue
460
+ length = end_index - start_index + 1
461
+ if length > max_answer_length:
462
+ continue
463
+ prelim_predictions.append(
464
+ _PrelimPrediction(
465
+ feature_index=feature_index,
466
+ start_index=start_index,
467
+ end_index=end_index,
468
+ start_logit=result.start_logits[start_index],
469
+ end_logit=result.end_logits[end_index],
470
+ )
471
+ )
472
+ if version_2_with_negative:
473
+ prelim_predictions.append(
474
+ _PrelimPrediction(
475
+ feature_index=min_null_feature_index,
476
+ start_index=0,
477
+ end_index=0,
478
+ start_logit=null_start_logit,
479
+ end_logit=null_end_logit,
480
+ )
481
+ )
482
+ prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True)
483
+
484
+ _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
485
+ "NbestPrediction", ["text", "start_logit", "end_logit"]
486
+ )
487
+
488
+ seen_predictions = {}
489
+ nbest = []
490
+ for pred in prelim_predictions:
491
+ if len(nbest) >= n_best_size:
492
+ break
493
+ feature = features[pred.feature_index]
494
+ if pred.start_index > 0: # this is a non-null prediction
495
+ tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
496
+ orig_doc_start = feature.token_to_orig_map[pred.start_index]
497
+ orig_doc_end = feature.token_to_orig_map[pred.end_index]
498
+ orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
499
+
500
+ tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
501
+
502
+ # tok_text = " ".join(tok_tokens)
503
+ #
504
+ # # De-tokenize WordPieces that have been split off.
505
+ # tok_text = tok_text.replace(" ##", "")
506
+ # tok_text = tok_text.replace("##", "")
507
+
508
+ # Clean whitespace
509
+ tok_text = tok_text.strip()
510
+ tok_text = " ".join(tok_text.split())
511
+ orig_text = " ".join(orig_tokens)
512
+
513
+ final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
514
+ if final_text in seen_predictions:
515
+ continue
516
+
517
+ seen_predictions[final_text] = True
518
+ else:
519
+ final_text = ""
520
+ seen_predictions[final_text] = True
521
+
522
+ nbest.append(_NbestPrediction(text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit))
523
+ # if we didn't include the empty option in the n-best, include it
524
+ if version_2_with_negative:
525
+ if "" not in seen_predictions:
526
+ nbest.append(_NbestPrediction(text="", start_logit=null_start_logit, end_logit=null_end_logit))
527
+
528
+ # In very rare edge cases we could only have single null prediction.
529
+ # So we just create a nonce prediction in this case to avoid failure.
530
+ if len(nbest) == 1:
531
+ nbest.insert(0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
532
+
533
+ # In very rare edge cases we could have no valid predictions. So we
534
+ # just create a nonce prediction in this case to avoid failure.
535
+ if not nbest:
536
+ nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
537
+
538
+ if len(nbest) < 1:
539
+ raise ValueError("No valid predictions")
540
+
541
+ total_scores = []
542
+ best_non_null_entry = None
543
+ for entry in nbest:
544
+ total_scores.append(entry.start_logit + entry.end_logit)
545
+ if not best_non_null_entry:
546
+ if entry.text:
547
+ best_non_null_entry = entry
548
+
549
+ probs = _compute_softmax(total_scores)
550
+
551
+ nbest_json = []
552
+ for i, entry in enumerate(nbest):
553
+ output = collections.OrderedDict()
554
+ output["text"] = entry.text
555
+ output["probability"] = probs[i]
556
+ output["start_logit"] = entry.start_logit
557
+ output["end_logit"] = entry.end_logit
558
+ nbest_json.append(output)
559
+
560
+ if len(nbest_json) < 1:
561
+ raise ValueError("No valid predictions")
562
+
563
+ if not version_2_with_negative:
564
+ all_predictions[example.qas_id] = nbest_json[0]["text"]
565
+ else:
566
+ # predict "" iff the null score - the score of best non-null > threshold
567
+ score_diff = score_null - best_non_null_entry.start_logit - (best_non_null_entry.end_logit)
568
+ scores_diff_json[example.qas_id] = score_diff
569
+ if score_diff > null_score_diff_threshold:
570
+ all_predictions[example.qas_id] = ""
571
+ else:
572
+ all_predictions[example.qas_id] = best_non_null_entry.text
573
+ all_nbest_json[example.qas_id] = nbest_json
574
+
575
+ if output_prediction_file:
576
+ with open(output_prediction_file, "w") as writer:
577
+ writer.write(json.dumps(all_predictions, indent=4) + "\n")
578
+
579
+ if output_nbest_file:
580
+ with open(output_nbest_file, "w") as writer:
581
+ writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
582
+
583
+ if output_null_log_odds_file and version_2_with_negative:
584
+ with open(output_null_log_odds_file, "w") as writer:
585
+ writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
586
+
587
+ return all_predictions
588
+
589
+
590
+ def compute_predictions_log_probs(
591
+ all_examples,
592
+ all_features,
593
+ all_results,
594
+ n_best_size,
595
+ max_answer_length,
596
+ output_prediction_file,
597
+ output_nbest_file,
598
+ output_null_log_odds_file,
599
+ start_n_top,
600
+ end_n_top,
601
+ version_2_with_negative,
602
+ tokenizer,
603
+ verbose_logging,
604
+ ):
605
+ """
606
+ XLNet write prediction logic (more complex than Bert's). Write final predictions to the json file and log-odds of
607
+ null if needed.
608
+
609
+ Requires utils_squad_evaluate.py
610
+ """
611
+ _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
612
+ "PrelimPrediction", ["feature_index", "start_index", "end_index", "start_log_prob", "end_log_prob"]
613
+ )
614
+
615
+ _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
616
+ "NbestPrediction", ["text", "start_log_prob", "end_log_prob"]
617
+ )
618
+
619
+ logger.info(f"Writing predictions to: {output_prediction_file}")
620
+
621
+ example_index_to_features = collections.defaultdict(list)
622
+ for feature in all_features:
623
+ example_index_to_features[feature.example_index].append(feature)
624
+
625
+ unique_id_to_result = {}
626
+ for result in all_results:
627
+ unique_id_to_result[result.unique_id] = result
628
+
629
+ all_predictions = collections.OrderedDict()
630
+ all_nbest_json = collections.OrderedDict()
631
+ scores_diff_json = collections.OrderedDict()
632
+
633
+ for example_index, example in enumerate(all_examples):
634
+ features = example_index_to_features[example_index]
635
+
636
+ prelim_predictions = []
637
+ # keep track of the minimum score of null start+end of position 0
638
+ score_null = 1000000 # large and positive
639
+
640
+ for feature_index, feature in enumerate(features):
641
+ result = unique_id_to_result[feature.unique_id]
642
+
643
+ cur_null_score = result.cls_logits
644
+
645
+ # if we could have irrelevant answers, get the min score of irrelevant
646
+ score_null = min(score_null, cur_null_score)
647
+
648
+ for i in range(start_n_top):
649
+ for j in range(end_n_top):
650
+ start_log_prob = result.start_logits[i]
651
+ start_index = result.start_top_index[i]
652
+
653
+ j_index = i * end_n_top + j
654
+
655
+ end_log_prob = result.end_logits[j_index]
656
+ end_index = result.end_top_index[j_index]
657
+
658
+ # We could hypothetically create invalid predictions, e.g., predict
659
+ # that the start of the span is in the question. We throw out all
660
+ # invalid predictions.
661
+ if start_index >= feature.paragraph_len - 1:
662
+ continue
663
+ if end_index >= feature.paragraph_len - 1:
664
+ continue
665
+
666
+ if not feature.token_is_max_context.get(start_index, False):
667
+ continue
668
+ if end_index < start_index:
669
+ continue
670
+ length = end_index - start_index + 1
671
+ if length > max_answer_length:
672
+ continue
673
+
674
+ prelim_predictions.append(
675
+ _PrelimPrediction(
676
+ feature_index=feature_index,
677
+ start_index=start_index,
678
+ end_index=end_index,
679
+ start_log_prob=start_log_prob,
680
+ end_log_prob=end_log_prob,
681
+ )
682
+ )
683
+
684
+ prelim_predictions = sorted(
685
+ prelim_predictions, key=lambda x: (x.start_log_prob + x.end_log_prob), reverse=True
686
+ )
687
+
688
+ seen_predictions = {}
689
+ nbest = []
690
+ for pred in prelim_predictions:
691
+ if len(nbest) >= n_best_size:
692
+ break
693
+ feature = features[pred.feature_index]
694
+
695
+ # XLNet un-tokenizer
696
+ # Let's keep it simple for now and see if we need all this later.
697
+ #
698
+ # tok_start_to_orig_index = feature.tok_start_to_orig_index
699
+ # tok_end_to_orig_index = feature.tok_end_to_orig_index
700
+ # start_orig_pos = tok_start_to_orig_index[pred.start_index]
701
+ # end_orig_pos = tok_end_to_orig_index[pred.end_index]
702
+ # paragraph_text = example.paragraph_text
703
+ # final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
704
+
705
+ # Previously used Bert untokenizer
706
+ tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
707
+ orig_doc_start = feature.token_to_orig_map[pred.start_index]
708
+ orig_doc_end = feature.token_to_orig_map[pred.end_index]
709
+ orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
710
+ tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
711
+
712
+ # Clean whitespace
713
+ tok_text = tok_text.strip()
714
+ tok_text = " ".join(tok_text.split())
715
+ orig_text = " ".join(orig_tokens)
716
+
717
+ if hasattr(tokenizer, "do_lower_case"):
718
+ do_lower_case = tokenizer.do_lower_case
719
+ else:
720
+ do_lower_case = tokenizer.do_lowercase_and_remove_accent
721
+
722
+ final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
723
+
724
+ if final_text in seen_predictions:
725
+ continue
726
+
727
+ seen_predictions[final_text] = True
728
+
729
+ nbest.append(
730
+ _NbestPrediction(text=final_text, start_log_prob=pred.start_log_prob, end_log_prob=pred.end_log_prob)
731
+ )
732
+
733
+ # In very rare edge cases we could have no valid predictions. So we
734
+ # just create a nonce prediction in this case to avoid failure.
735
+ if not nbest:
736
+ nbest.append(_NbestPrediction(text="", start_log_prob=-1e6, end_log_prob=-1e6))
737
+
738
+ total_scores = []
739
+ best_non_null_entry = None
740
+ for entry in nbest:
741
+ total_scores.append(entry.start_log_prob + entry.end_log_prob)
742
+ if not best_non_null_entry:
743
+ best_non_null_entry = entry
744
+
745
+ probs = _compute_softmax(total_scores)
746
+
747
+ nbest_json = []
748
+ for i, entry in enumerate(nbest):
749
+ output = collections.OrderedDict()
750
+ output["text"] = entry.text
751
+ output["probability"] = probs[i]
752
+ output["start_log_prob"] = entry.start_log_prob
753
+ output["end_log_prob"] = entry.end_log_prob
754
+ nbest_json.append(output)
755
+
756
+ if len(nbest_json) < 1:
757
+ raise ValueError("No valid predictions")
758
+ if best_non_null_entry is None:
759
+ raise ValueError("No valid predictions")
760
+
761
+ score_diff = score_null
762
+ scores_diff_json[example.qas_id] = score_diff
763
+ # note(zhiliny): always predict best_non_null_entry
764
+ # and the evaluation script will search for the best threshold
765
+ all_predictions[example.qas_id] = best_non_null_entry.text
766
+
767
+ all_nbest_json[example.qas_id] = nbest_json
768
+
769
+ with open(output_prediction_file, "w") as writer:
770
+ writer.write(json.dumps(all_predictions, indent=4) + "\n")
771
+
772
+ with open(output_nbest_file, "w") as writer:
773
+ writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
774
+
775
+ if version_2_with_negative:
776
+ with open(output_null_log_odds_file, "w") as writer:
777
+ writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
778
+
779
+ return all_predictions
.venv/lib/python3.11/site-packages/transformers/data/processors/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .glue import glue_convert_examples_to_features, glue_output_modes, glue_processors, glue_tasks_num_labels
16
+ from .squad import SquadExample, SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features
17
+ from .utils import DataProcessor, InputExample, InputFeatures, SingleSentenceClassificationProcessor
18
+ from .xnli import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
.venv/lib/python3.11/site-packages/transformers/data/processors/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (893 Bytes). View file
 
.venv/lib/python3.11/site-packages/transformers/data/processors/__pycache__/glue.cpython-311.pyc ADDED
Binary file (35.8 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/data/processors/__pycache__/squad.cpython-311.pyc ADDED
Binary file (34.7 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/data/processors/__pycache__/utils.cpython-311.pyc ADDED
Binary file (19.9 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/data/processors/__pycache__/xnli.cpython-311.pyc ADDED
Binary file (4.2 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/data/processors/glue.py ADDED
@@ -0,0 +1,643 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """GLUE processors and helpers"""
17
+
18
+ import os
19
+ import warnings
20
+ from dataclasses import asdict
21
+ from enum import Enum
22
+ from typing import List, Optional, Union
23
+
24
+ from ...tokenization_utils import PreTrainedTokenizer
25
+ from ...utils import is_tf_available, logging
26
+ from .utils import DataProcessor, InputExample, InputFeatures
27
+
28
+
29
+ if is_tf_available():
30
+ import tensorflow as tf
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+ DEPRECATION_WARNING = (
35
+ "This {0} will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
36
+ "library. You can have a look at this example script for pointers: "
37
+ "https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py"
38
+ )
39
+
40
+
41
+ def glue_convert_examples_to_features(
42
+ examples: Union[List[InputExample], "tf.data.Dataset"],
43
+ tokenizer: PreTrainedTokenizer,
44
+ max_length: Optional[int] = None,
45
+ task=None,
46
+ label_list=None,
47
+ output_mode=None,
48
+ ):
49
+ """
50
+ Loads a data file into a list of `InputFeatures`
51
+
52
+ Args:
53
+ examples: List of `InputExamples` or `tf.data.Dataset` containing the examples.
54
+ tokenizer: Instance of a tokenizer that will tokenize the examples
55
+ max_length: Maximum example length. Defaults to the tokenizer's max_len
56
+ task: GLUE task
57
+ label_list: List of labels. Can be obtained from the processor using the `processor.get_labels()` method
58
+ output_mode: String indicating the output mode. Either `regression` or `classification`
59
+
60
+ Returns:
61
+ If the `examples` input is a `tf.data.Dataset`, will return a `tf.data.Dataset` containing the task-specific
62
+ features. If the input is a list of `InputExamples`, will return a list of task-specific `InputFeatures` which
63
+ can be fed to the model.
64
+
65
+ """
66
+ warnings.warn(DEPRECATION_WARNING.format("function"), FutureWarning)
67
+ if is_tf_available() and isinstance(examples, tf.data.Dataset):
68
+ if task is None:
69
+ raise ValueError("When calling glue_convert_examples_to_features from TF, the task parameter is required.")
70
+ return _tf_glue_convert_examples_to_features(examples, tokenizer, max_length=max_length, task=task)
71
+ return _glue_convert_examples_to_features(
72
+ examples, tokenizer, max_length=max_length, task=task, label_list=label_list, output_mode=output_mode
73
+ )
74
+
75
+
76
+ if is_tf_available():
77
+
78
+ def _tf_glue_convert_examples_to_features(
79
+ examples: tf.data.Dataset,
80
+ tokenizer: PreTrainedTokenizer,
81
+ task=str,
82
+ max_length: Optional[int] = None,
83
+ ) -> tf.data.Dataset:
84
+ """
85
+ Returns:
86
+ A `tf.data.Dataset` containing the task-specific features.
87
+
88
+ """
89
+ processor = glue_processors[task]()
90
+ examples = [processor.tfds_map(processor.get_example_from_tensor_dict(example)) for example in examples]
91
+ features = glue_convert_examples_to_features(examples, tokenizer, max_length=max_length, task=task)
92
+ label_type = tf.float32 if task == "sts-b" else tf.int64
93
+
94
+ def gen():
95
+ for ex in features:
96
+ d = {k: v for k, v in asdict(ex).items() if v is not None}
97
+ label = d.pop("label")
98
+ yield (d, label)
99
+
100
+ input_names = tokenizer.model_input_names
101
+
102
+ return tf.data.Dataset.from_generator(
103
+ gen,
104
+ ({k: tf.int32 for k in input_names}, label_type),
105
+ ({k: tf.TensorShape([None]) for k in input_names}, tf.TensorShape([])),
106
+ )
107
+
108
+
109
+ def _glue_convert_examples_to_features(
110
+ examples: List[InputExample],
111
+ tokenizer: PreTrainedTokenizer,
112
+ max_length: Optional[int] = None,
113
+ task=None,
114
+ label_list=None,
115
+ output_mode=None,
116
+ ):
117
+ if max_length is None:
118
+ max_length = tokenizer.model_max_length
119
+
120
+ if task is not None:
121
+ processor = glue_processors[task]()
122
+ if label_list is None:
123
+ label_list = processor.get_labels()
124
+ logger.info(f"Using label list {label_list} for task {task}")
125
+ if output_mode is None:
126
+ output_mode = glue_output_modes[task]
127
+ logger.info(f"Using output mode {output_mode} for task {task}")
128
+
129
+ label_map = {label: i for i, label in enumerate(label_list)}
130
+
131
+ def label_from_example(example: InputExample) -> Union[int, float, None]:
132
+ if example.label is None:
133
+ return None
134
+ if output_mode == "classification":
135
+ return label_map[example.label]
136
+ elif output_mode == "regression":
137
+ return float(example.label)
138
+ raise KeyError(output_mode)
139
+
140
+ labels = [label_from_example(example) for example in examples]
141
+
142
+ batch_encoding = tokenizer(
143
+ [(example.text_a, example.text_b) for example in examples],
144
+ max_length=max_length,
145
+ padding="max_length",
146
+ truncation=True,
147
+ )
148
+
149
+ features = []
150
+ for i in range(len(examples)):
151
+ inputs = {k: batch_encoding[k][i] for k in batch_encoding}
152
+
153
+ feature = InputFeatures(**inputs, label=labels[i])
154
+ features.append(feature)
155
+
156
+ for i, example in enumerate(examples[:5]):
157
+ logger.info("*** Example ***")
158
+ logger.info(f"guid: {example.guid}")
159
+ logger.info(f"features: {features[i]}")
160
+
161
+ return features
162
+
163
+
164
+ class OutputMode(Enum):
165
+ classification = "classification"
166
+ regression = "regression"
167
+
168
+
169
+ class MrpcProcessor(DataProcessor):
170
+ """Processor for the MRPC data set (GLUE version)."""
171
+
172
+ def __init__(self, *args, **kwargs):
173
+ super().__init__(*args, **kwargs)
174
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
175
+
176
+ def get_example_from_tensor_dict(self, tensor_dict):
177
+ """See base class."""
178
+ return InputExample(
179
+ tensor_dict["idx"].numpy(),
180
+ tensor_dict["sentence1"].numpy().decode("utf-8"),
181
+ tensor_dict["sentence2"].numpy().decode("utf-8"),
182
+ str(tensor_dict["label"].numpy()),
183
+ )
184
+
185
+ def get_train_examples(self, data_dir):
186
+ """See base class."""
187
+ logger.info(f"LOOKING AT {os.path.join(data_dir, 'train.tsv')}")
188
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
189
+
190
+ def get_dev_examples(self, data_dir):
191
+ """See base class."""
192
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
193
+
194
+ def get_test_examples(self, data_dir):
195
+ """See base class."""
196
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
197
+
198
+ def get_labels(self):
199
+ """See base class."""
200
+ return ["0", "1"]
201
+
202
+ def _create_examples(self, lines, set_type):
203
+ """Creates examples for the training, dev and test sets."""
204
+ examples = []
205
+ for i, line in enumerate(lines):
206
+ if i == 0:
207
+ continue
208
+ guid = f"{set_type}-{i}"
209
+ text_a = line[3]
210
+ text_b = line[4]
211
+ label = None if set_type == "test" else line[0]
212
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
213
+ return examples
214
+
215
+
216
+ class MnliProcessor(DataProcessor):
217
+ """Processor for the MultiNLI data set (GLUE version)."""
218
+
219
+ def __init__(self, *args, **kwargs):
220
+ super().__init__(*args, **kwargs)
221
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
222
+
223
+ def get_example_from_tensor_dict(self, tensor_dict):
224
+ """See base class."""
225
+ return InputExample(
226
+ tensor_dict["idx"].numpy(),
227
+ tensor_dict["premise"].numpy().decode("utf-8"),
228
+ tensor_dict["hypothesis"].numpy().decode("utf-8"),
229
+ str(tensor_dict["label"].numpy()),
230
+ )
231
+
232
+ def get_train_examples(self, data_dir):
233
+ """See base class."""
234
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
235
+
236
+ def get_dev_examples(self, data_dir):
237
+ """See base class."""
238
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")
239
+
240
+ def get_test_examples(self, data_dir):
241
+ """See base class."""
242
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test_matched")
243
+
244
+ def get_labels(self):
245
+ """See base class."""
246
+ return ["contradiction", "entailment", "neutral"]
247
+
248
+ def _create_examples(self, lines, set_type):
249
+ """Creates examples for the training, dev and test sets."""
250
+ examples = []
251
+ for i, line in enumerate(lines):
252
+ if i == 0:
253
+ continue
254
+ guid = f"{set_type}-{line[0]}"
255
+ text_a = line[8]
256
+ text_b = line[9]
257
+ label = None if set_type.startswith("test") else line[-1]
258
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
259
+ return examples
260
+
261
+
262
+ class MnliMismatchedProcessor(MnliProcessor):
263
+ """Processor for the MultiNLI Mismatched data set (GLUE version)."""
264
+
265
+ def __init__(self, *args, **kwargs):
266
+ super().__init__(*args, **kwargs)
267
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
268
+
269
+ def get_dev_examples(self, data_dir):
270
+ """See base class."""
271
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_mismatched")
272
+
273
+ def get_test_examples(self, data_dir):
274
+ """See base class."""
275
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test_mismatched")
276
+
277
+
278
+ class ColaProcessor(DataProcessor):
279
+ """Processor for the CoLA data set (GLUE version)."""
280
+
281
+ def __init__(self, *args, **kwargs):
282
+ super().__init__(*args, **kwargs)
283
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
284
+
285
+ def get_example_from_tensor_dict(self, tensor_dict):
286
+ """See base class."""
287
+ return InputExample(
288
+ tensor_dict["idx"].numpy(),
289
+ tensor_dict["sentence"].numpy().decode("utf-8"),
290
+ None,
291
+ str(tensor_dict["label"].numpy()),
292
+ )
293
+
294
+ def get_train_examples(self, data_dir):
295
+ """See base class."""
296
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
297
+
298
+ def get_dev_examples(self, data_dir):
299
+ """See base class."""
300
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
301
+
302
+ def get_test_examples(self, data_dir):
303
+ """See base class."""
304
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
305
+
306
+ def get_labels(self):
307
+ """See base class."""
308
+ return ["0", "1"]
309
+
310
+ def _create_examples(self, lines, set_type):
311
+ """Creates examples for the training, dev and test sets."""
312
+ test_mode = set_type == "test"
313
+ if test_mode:
314
+ lines = lines[1:]
315
+ text_index = 1 if test_mode else 3
316
+ examples = []
317
+ for i, line in enumerate(lines):
318
+ guid = f"{set_type}-{i}"
319
+ text_a = line[text_index]
320
+ label = None if test_mode else line[1]
321
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
322
+ return examples
323
+
324
+
325
+ class Sst2Processor(DataProcessor):
326
+ """Processor for the SST-2 data set (GLUE version)."""
327
+
328
+ def __init__(self, *args, **kwargs):
329
+ super().__init__(*args, **kwargs)
330
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
331
+
332
+ def get_example_from_tensor_dict(self, tensor_dict):
333
+ """See base class."""
334
+ return InputExample(
335
+ tensor_dict["idx"].numpy(),
336
+ tensor_dict["sentence"].numpy().decode("utf-8"),
337
+ None,
338
+ str(tensor_dict["label"].numpy()),
339
+ )
340
+
341
+ def get_train_examples(self, data_dir):
342
+ """See base class."""
343
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
344
+
345
+ def get_dev_examples(self, data_dir):
346
+ """See base class."""
347
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
348
+
349
+ def get_test_examples(self, data_dir):
350
+ """See base class."""
351
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
352
+
353
+ def get_labels(self):
354
+ """See base class."""
355
+ return ["0", "1"]
356
+
357
+ def _create_examples(self, lines, set_type):
358
+ """Creates examples for the training, dev and test sets."""
359
+ examples = []
360
+ text_index = 1 if set_type == "test" else 0
361
+ for i, line in enumerate(lines):
362
+ if i == 0:
363
+ continue
364
+ guid = f"{set_type}-{i}"
365
+ text_a = line[text_index]
366
+ label = None if set_type == "test" else line[1]
367
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
368
+ return examples
369
+
370
+
371
+ class StsbProcessor(DataProcessor):
372
+ """Processor for the STS-B data set (GLUE version)."""
373
+
374
+ def __init__(self, *args, **kwargs):
375
+ super().__init__(*args, **kwargs)
376
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
377
+
378
+ def get_example_from_tensor_dict(self, tensor_dict):
379
+ """See base class."""
380
+ return InputExample(
381
+ tensor_dict["idx"].numpy(),
382
+ tensor_dict["sentence1"].numpy().decode("utf-8"),
383
+ tensor_dict["sentence2"].numpy().decode("utf-8"),
384
+ str(tensor_dict["label"].numpy()),
385
+ )
386
+
387
+ def get_train_examples(self, data_dir):
388
+ """See base class."""
389
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
390
+
391
+ def get_dev_examples(self, data_dir):
392
+ """See base class."""
393
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
394
+
395
+ def get_test_examples(self, data_dir):
396
+ """See base class."""
397
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
398
+
399
+ def get_labels(self):
400
+ """See base class."""
401
+ return [None]
402
+
403
+ def _create_examples(self, lines, set_type):
404
+ """Creates examples for the training, dev and test sets."""
405
+ examples = []
406
+ for i, line in enumerate(lines):
407
+ if i == 0:
408
+ continue
409
+ guid = f"{set_type}-{line[0]}"
410
+ text_a = line[7]
411
+ text_b = line[8]
412
+ label = None if set_type == "test" else line[-1]
413
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
414
+ return examples
415
+
416
+
417
+ class QqpProcessor(DataProcessor):
418
+ """Processor for the QQP data set (GLUE version)."""
419
+
420
+ def __init__(self, *args, **kwargs):
421
+ super().__init__(*args, **kwargs)
422
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
423
+
424
+ def get_example_from_tensor_dict(self, tensor_dict):
425
+ """See base class."""
426
+ return InputExample(
427
+ tensor_dict["idx"].numpy(),
428
+ tensor_dict["question1"].numpy().decode("utf-8"),
429
+ tensor_dict["question2"].numpy().decode("utf-8"),
430
+ str(tensor_dict["label"].numpy()),
431
+ )
432
+
433
+ def get_train_examples(self, data_dir):
434
+ """See base class."""
435
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
436
+
437
+ def get_dev_examples(self, data_dir):
438
+ """See base class."""
439
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
440
+
441
+ def get_test_examples(self, data_dir):
442
+ """See base class."""
443
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
444
+
445
+ def get_labels(self):
446
+ """See base class."""
447
+ return ["0", "1"]
448
+
449
+ def _create_examples(self, lines, set_type):
450
+ """Creates examples for the training, dev and test sets."""
451
+ test_mode = set_type == "test"
452
+ q1_index = 1 if test_mode else 3
453
+ q2_index = 2 if test_mode else 4
454
+ examples = []
455
+ for i, line in enumerate(lines):
456
+ if i == 0:
457
+ continue
458
+ guid = f"{set_type}-{line[0]}"
459
+ try:
460
+ text_a = line[q1_index]
461
+ text_b = line[q2_index]
462
+ label = None if test_mode else line[5]
463
+ except IndexError:
464
+ continue
465
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
466
+ return examples
467
+
468
+
469
+ class QnliProcessor(DataProcessor):
470
+ """Processor for the QNLI data set (GLUE version)."""
471
+
472
+ def __init__(self, *args, **kwargs):
473
+ super().__init__(*args, **kwargs)
474
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
475
+
476
+ def get_example_from_tensor_dict(self, tensor_dict):
477
+ """See base class."""
478
+ return InputExample(
479
+ tensor_dict["idx"].numpy(),
480
+ tensor_dict["question"].numpy().decode("utf-8"),
481
+ tensor_dict["sentence"].numpy().decode("utf-8"),
482
+ str(tensor_dict["label"].numpy()),
483
+ )
484
+
485
+ def get_train_examples(self, data_dir):
486
+ """See base class."""
487
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
488
+
489
+ def get_dev_examples(self, data_dir):
490
+ """See base class."""
491
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
492
+
493
+ def get_test_examples(self, data_dir):
494
+ """See base class."""
495
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
496
+
497
+ def get_labels(self):
498
+ """See base class."""
499
+ return ["entailment", "not_entailment"]
500
+
501
+ def _create_examples(self, lines, set_type):
502
+ """Creates examples for the training, dev and test sets."""
503
+ examples = []
504
+ for i, line in enumerate(lines):
505
+ if i == 0:
506
+ continue
507
+ guid = f"{set_type}-{line[0]}"
508
+ text_a = line[1]
509
+ text_b = line[2]
510
+ label = None if set_type == "test" else line[-1]
511
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
512
+ return examples
513
+
514
+
515
+ class RteProcessor(DataProcessor):
516
+ """Processor for the RTE data set (GLUE version)."""
517
+
518
+ def __init__(self, *args, **kwargs):
519
+ super().__init__(*args, **kwargs)
520
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
521
+
522
+ def get_example_from_tensor_dict(self, tensor_dict):
523
+ """See base class."""
524
+ return InputExample(
525
+ tensor_dict["idx"].numpy(),
526
+ tensor_dict["sentence1"].numpy().decode("utf-8"),
527
+ tensor_dict["sentence2"].numpy().decode("utf-8"),
528
+ str(tensor_dict["label"].numpy()),
529
+ )
530
+
531
+ def get_train_examples(self, data_dir):
532
+ """See base class."""
533
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
534
+
535
+ def get_dev_examples(self, data_dir):
536
+ """See base class."""
537
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
538
+
539
+ def get_test_examples(self, data_dir):
540
+ """See base class."""
541
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
542
+
543
+ def get_labels(self):
544
+ """See base class."""
545
+ return ["entailment", "not_entailment"]
546
+
547
+ def _create_examples(self, lines, set_type):
548
+ """Creates examples for the training, dev and test sets."""
549
+ examples = []
550
+ for i, line in enumerate(lines):
551
+ if i == 0:
552
+ continue
553
+ guid = f"{set_type}-{line[0]}"
554
+ text_a = line[1]
555
+ text_b = line[2]
556
+ label = None if set_type == "test" else line[-1]
557
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
558
+ return examples
559
+
560
+
561
+ class WnliProcessor(DataProcessor):
562
+ """Processor for the WNLI data set (GLUE version)."""
563
+
564
+ def __init__(self, *args, **kwargs):
565
+ super().__init__(*args, **kwargs)
566
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
567
+
568
+ def get_example_from_tensor_dict(self, tensor_dict):
569
+ """See base class."""
570
+ return InputExample(
571
+ tensor_dict["idx"].numpy(),
572
+ tensor_dict["sentence1"].numpy().decode("utf-8"),
573
+ tensor_dict["sentence2"].numpy().decode("utf-8"),
574
+ str(tensor_dict["label"].numpy()),
575
+ )
576
+
577
+ def get_train_examples(self, data_dir):
578
+ """See base class."""
579
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
580
+
581
+ def get_dev_examples(self, data_dir):
582
+ """See base class."""
583
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
584
+
585
+ def get_test_examples(self, data_dir):
586
+ """See base class."""
587
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
588
+
589
+ def get_labels(self):
590
+ """See base class."""
591
+ return ["0", "1"]
592
+
593
+ def _create_examples(self, lines, set_type):
594
+ """Creates examples for the training, dev and test sets."""
595
+ examples = []
596
+ for i, line in enumerate(lines):
597
+ if i == 0:
598
+ continue
599
+ guid = f"{set_type}-{line[0]}"
600
+ text_a = line[1]
601
+ text_b = line[2]
602
+ label = None if set_type == "test" else line[-1]
603
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
604
+ return examples
605
+
606
+
607
+ glue_tasks_num_labels = {
608
+ "cola": 2,
609
+ "mnli": 3,
610
+ "mrpc": 2,
611
+ "sst-2": 2,
612
+ "sts-b": 1,
613
+ "qqp": 2,
614
+ "qnli": 2,
615
+ "rte": 2,
616
+ "wnli": 2,
617
+ }
618
+
619
+ glue_processors = {
620
+ "cola": ColaProcessor,
621
+ "mnli": MnliProcessor,
622
+ "mnli-mm": MnliMismatchedProcessor,
623
+ "mrpc": MrpcProcessor,
624
+ "sst-2": Sst2Processor,
625
+ "sts-b": StsbProcessor,
626
+ "qqp": QqpProcessor,
627
+ "qnli": QnliProcessor,
628
+ "rte": RteProcessor,
629
+ "wnli": WnliProcessor,
630
+ }
631
+
632
+ glue_output_modes = {
633
+ "cola": "classification",
634
+ "mnli": "classification",
635
+ "mnli-mm": "classification",
636
+ "mrpc": "classification",
637
+ "sst-2": "classification",
638
+ "sts-b": "regression",
639
+ "qqp": "classification",
640
+ "qnli": "classification",
641
+ "rte": "classification",
642
+ "wnli": "classification",
643
+ }
.venv/lib/python3.11/site-packages/transformers/data/processors/squad.py ADDED
@@ -0,0 +1,845 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import os
17
+ from functools import partial
18
+ from multiprocessing import Pool, cpu_count
19
+
20
+ import numpy as np
21
+ from tqdm import tqdm
22
+
23
+ from ...models.bert.tokenization_bert import whitespace_tokenize
24
+ from ...tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase, TruncationStrategy
25
+ from ...utils import is_tf_available, is_torch_available, logging
26
+ from .utils import DataProcessor
27
+
28
+
29
+ # Store the tokenizers which insert 2 separators tokens
30
+ MULTI_SEP_TOKENS_TOKENIZERS_SET = {"roberta", "camembert", "bart", "mpnet"}
31
+
32
+
33
+ if is_torch_available():
34
+ import torch
35
+ from torch.utils.data import TensorDataset
36
+
37
+ if is_tf_available():
38
+ import tensorflow as tf
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+
43
+ def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
44
+ """Returns tokenized answer spans that better match the annotated answer."""
45
+ tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
46
+
47
+ for new_start in range(input_start, input_end + 1):
48
+ for new_end in range(input_end, new_start - 1, -1):
49
+ text_span = " ".join(doc_tokens[new_start : (new_end + 1)])
50
+ if text_span == tok_answer_text:
51
+ return (new_start, new_end)
52
+
53
+ return (input_start, input_end)
54
+
55
+
56
+ def _check_is_max_context(doc_spans, cur_span_index, position):
57
+ """Check if this is the 'max context' doc span for the token."""
58
+ best_score = None
59
+ best_span_index = None
60
+ for span_index, doc_span in enumerate(doc_spans):
61
+ end = doc_span.start + doc_span.length - 1
62
+ if position < doc_span.start:
63
+ continue
64
+ if position > end:
65
+ continue
66
+ num_left_context = position - doc_span.start
67
+ num_right_context = end - position
68
+ score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
69
+ if best_score is None or score > best_score:
70
+ best_score = score
71
+ best_span_index = span_index
72
+
73
+ return cur_span_index == best_span_index
74
+
75
+
76
+ def _new_check_is_max_context(doc_spans, cur_span_index, position):
77
+ """Check if this is the 'max context' doc span for the token."""
78
+ # if len(doc_spans) == 1:
79
+ # return True
80
+ best_score = None
81
+ best_span_index = None
82
+ for span_index, doc_span in enumerate(doc_spans):
83
+ end = doc_span["start"] + doc_span["length"] - 1
84
+ if position < doc_span["start"]:
85
+ continue
86
+ if position > end:
87
+ continue
88
+ num_left_context = position - doc_span["start"]
89
+ num_right_context = end - position
90
+ score = min(num_left_context, num_right_context) + 0.01 * doc_span["length"]
91
+ if best_score is None or score > best_score:
92
+ best_score = score
93
+ best_span_index = span_index
94
+
95
+ return cur_span_index == best_span_index
96
+
97
+
98
+ def _is_whitespace(c):
99
+ if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
100
+ return True
101
+ return False
102
+
103
+
104
+ def squad_convert_example_to_features(
105
+ example, max_seq_length, doc_stride, max_query_length, padding_strategy, is_training
106
+ ):
107
+ features = []
108
+ if is_training and not example.is_impossible:
109
+ # Get start and end position
110
+ start_position = example.start_position
111
+ end_position = example.end_position
112
+
113
+ # If the answer cannot be found in the text, then skip this example.
114
+ actual_text = " ".join(example.doc_tokens[start_position : (end_position + 1)])
115
+ cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text))
116
+ if actual_text.find(cleaned_answer_text) == -1:
117
+ logger.warning(f"Could not find answer: '{actual_text}' vs. '{cleaned_answer_text}'")
118
+ return []
119
+
120
+ tok_to_orig_index = []
121
+ orig_to_tok_index = []
122
+ all_doc_tokens = []
123
+ for i, token in enumerate(example.doc_tokens):
124
+ orig_to_tok_index.append(len(all_doc_tokens))
125
+ if tokenizer.__class__.__name__ in [
126
+ "RobertaTokenizer",
127
+ "LongformerTokenizer",
128
+ "BartTokenizer",
129
+ "RobertaTokenizerFast",
130
+ "LongformerTokenizerFast",
131
+ "BartTokenizerFast",
132
+ ]:
133
+ sub_tokens = tokenizer.tokenize(token, add_prefix_space=True)
134
+ else:
135
+ sub_tokens = tokenizer.tokenize(token)
136
+ for sub_token in sub_tokens:
137
+ tok_to_orig_index.append(i)
138
+ all_doc_tokens.append(sub_token)
139
+
140
+ if is_training and not example.is_impossible:
141
+ tok_start_position = orig_to_tok_index[example.start_position]
142
+ if example.end_position < len(example.doc_tokens) - 1:
143
+ tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
144
+ else:
145
+ tok_end_position = len(all_doc_tokens) - 1
146
+
147
+ (tok_start_position, tok_end_position) = _improve_answer_span(
148
+ all_doc_tokens, tok_start_position, tok_end_position, tokenizer, example.answer_text
149
+ )
150
+
151
+ spans = []
152
+
153
+ truncated_query = tokenizer.encode(
154
+ example.question_text, add_special_tokens=False, truncation=True, max_length=max_query_length
155
+ )
156
+
157
+ # Tokenizers who insert 2 SEP tokens in-between <context> & <question> need to have special handling
158
+ # in the way they compute mask of added tokens.
159
+ tokenizer_type = type(tokenizer).__name__.replace("Tokenizer", "").lower()
160
+ sequence_added_tokens = (
161
+ tokenizer.model_max_length - tokenizer.max_len_single_sentence + 1
162
+ if tokenizer_type in MULTI_SEP_TOKENS_TOKENIZERS_SET
163
+ else tokenizer.model_max_length - tokenizer.max_len_single_sentence
164
+ )
165
+ sequence_pair_added_tokens = tokenizer.model_max_length - tokenizer.max_len_sentences_pair
166
+
167
+ span_doc_tokens = all_doc_tokens
168
+ while len(spans) * doc_stride < len(all_doc_tokens):
169
+ # Define the side we want to truncate / pad and the text/pair sorting
170
+ if tokenizer.padding_side == "right":
171
+ texts = truncated_query
172
+ pairs = span_doc_tokens
173
+ truncation = TruncationStrategy.ONLY_SECOND.value
174
+ else:
175
+ texts = span_doc_tokens
176
+ pairs = truncated_query
177
+ truncation = TruncationStrategy.ONLY_FIRST.value
178
+
179
+ encoded_dict = tokenizer.encode_plus( # TODO(thom) update this logic
180
+ texts,
181
+ pairs,
182
+ truncation=truncation,
183
+ padding=padding_strategy,
184
+ max_length=max_seq_length,
185
+ return_overflowing_tokens=True,
186
+ stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
187
+ return_token_type_ids=True,
188
+ )
189
+
190
+ paragraph_len = min(
191
+ len(all_doc_tokens) - len(spans) * doc_stride,
192
+ max_seq_length - len(truncated_query) - sequence_pair_added_tokens,
193
+ )
194
+
195
+ if tokenizer.pad_token_id in encoded_dict["input_ids"]:
196
+ if tokenizer.padding_side == "right":
197
+ non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)]
198
+ else:
199
+ last_padding_id_position = (
200
+ len(encoded_dict["input_ids"]) - 1 - encoded_dict["input_ids"][::-1].index(tokenizer.pad_token_id)
201
+ )
202
+ non_padded_ids = encoded_dict["input_ids"][last_padding_id_position + 1 :]
203
+
204
+ else:
205
+ non_padded_ids = encoded_dict["input_ids"]
206
+
207
+ tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
208
+
209
+ token_to_orig_map = {}
210
+ for i in range(paragraph_len):
211
+ index = len(truncated_query) + sequence_added_tokens + i if tokenizer.padding_side == "right" else i
212
+ token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i]
213
+
214
+ encoded_dict["paragraph_len"] = paragraph_len
215
+ encoded_dict["tokens"] = tokens
216
+ encoded_dict["token_to_orig_map"] = token_to_orig_map
217
+ encoded_dict["truncated_query_with_special_tokens_length"] = len(truncated_query) + sequence_added_tokens
218
+ encoded_dict["token_is_max_context"] = {}
219
+ encoded_dict["start"] = len(spans) * doc_stride
220
+ encoded_dict["length"] = paragraph_len
221
+
222
+ spans.append(encoded_dict)
223
+
224
+ if "overflowing_tokens" not in encoded_dict or (
225
+ "overflowing_tokens" in encoded_dict and len(encoded_dict["overflowing_tokens"]) == 0
226
+ ):
227
+ break
228
+ span_doc_tokens = encoded_dict["overflowing_tokens"]
229
+
230
+ for doc_span_index in range(len(spans)):
231
+ for j in range(spans[doc_span_index]["paragraph_len"]):
232
+ is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
233
+ index = (
234
+ j
235
+ if tokenizer.padding_side == "left"
236
+ else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
237
+ )
238
+ spans[doc_span_index]["token_is_max_context"][index] = is_max_context
239
+
240
+ for span in spans:
241
+ # Identify the position of the CLS token
242
+ cls_index = span["input_ids"].index(tokenizer.cls_token_id)
243
+
244
+ # p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
245
+ # Original TF implementation also keep the classification token (set to 0)
246
+ p_mask = np.ones_like(span["token_type_ids"])
247
+ if tokenizer.padding_side == "right":
248
+ p_mask[len(truncated_query) + sequence_added_tokens :] = 0
249
+ else:
250
+ p_mask[-len(span["tokens"]) : -(len(truncated_query) + sequence_added_tokens)] = 0
251
+
252
+ pad_token_indices = np.where(np.atleast_1d(span["input_ids"] == tokenizer.pad_token_id))
253
+ special_token_indices = np.asarray(
254
+ tokenizer.get_special_tokens_mask(span["input_ids"], already_has_special_tokens=True)
255
+ ).nonzero()
256
+
257
+ p_mask[pad_token_indices] = 1
258
+ p_mask[special_token_indices] = 1
259
+
260
+ # Set the cls index to 0: the CLS index can be used for impossible answers
261
+ p_mask[cls_index] = 0
262
+
263
+ span_is_impossible = example.is_impossible
264
+ start_position = 0
265
+ end_position = 0
266
+ if is_training and not span_is_impossible:
267
+ # For training, if our document chunk does not contain an annotation
268
+ # we throw it out, since there is nothing to predict.
269
+ doc_start = span["start"]
270
+ doc_end = span["start"] + span["length"] - 1
271
+ out_of_span = False
272
+
273
+ if not (tok_start_position >= doc_start and tok_end_position <= doc_end):
274
+ out_of_span = True
275
+
276
+ if out_of_span:
277
+ start_position = cls_index
278
+ end_position = cls_index
279
+ span_is_impossible = True
280
+ else:
281
+ if tokenizer.padding_side == "left":
282
+ doc_offset = 0
283
+ else:
284
+ doc_offset = len(truncated_query) + sequence_added_tokens
285
+
286
+ start_position = tok_start_position - doc_start + doc_offset
287
+ end_position = tok_end_position - doc_start + doc_offset
288
+
289
+ features.append(
290
+ SquadFeatures(
291
+ span["input_ids"],
292
+ span["attention_mask"],
293
+ span["token_type_ids"],
294
+ cls_index,
295
+ p_mask.tolist(),
296
+ example_index=0, # Can not set unique_id and example_index here. They will be set after multiple processing.
297
+ unique_id=0,
298
+ paragraph_len=span["paragraph_len"],
299
+ token_is_max_context=span["token_is_max_context"],
300
+ tokens=span["tokens"],
301
+ token_to_orig_map=span["token_to_orig_map"],
302
+ start_position=start_position,
303
+ end_position=end_position,
304
+ is_impossible=span_is_impossible,
305
+ qas_id=example.qas_id,
306
+ )
307
+ )
308
+ return features
309
+
310
+
311
+ def squad_convert_example_to_features_init(tokenizer_for_convert: PreTrainedTokenizerBase):
312
+ global tokenizer
313
+ tokenizer = tokenizer_for_convert
314
+
315
+
316
+ def squad_convert_examples_to_features(
317
+ examples,
318
+ tokenizer,
319
+ max_seq_length,
320
+ doc_stride,
321
+ max_query_length,
322
+ is_training,
323
+ padding_strategy="max_length",
324
+ return_dataset=False,
325
+ threads=1,
326
+ tqdm_enabled=True,
327
+ ):
328
+ """
329
+ Converts a list of examples into a list of features that can be directly given as input to a model. It is
330
+ model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs.
331
+
332
+ Args:
333
+ examples: list of [`~data.processors.squad.SquadExample`]
334
+ tokenizer: an instance of a child of [`PreTrainedTokenizer`]
335
+ max_seq_length: The maximum sequence length of the inputs.
336
+ doc_stride: The stride used when the context is too large and is split across several features.
337
+ max_query_length: The maximum length of the query.
338
+ is_training: whether to create features for model evaluation or model training.
339
+ padding_strategy: Default to "max_length". Which padding strategy to use
340
+ return_dataset: Default False. Either 'pt' or 'tf'.
341
+ if 'pt': returns a torch.data.TensorDataset, if 'tf': returns a tf.data.Dataset
342
+ threads: multiple processing threads.
343
+
344
+
345
+ Returns:
346
+ list of [`~data.processors.squad.SquadFeatures`]
347
+
348
+ Example:
349
+
350
+ ```python
351
+ processor = SquadV2Processor()
352
+ examples = processor.get_dev_examples(data_dir)
353
+
354
+ features = squad_convert_examples_to_features(
355
+ examples=examples,
356
+ tokenizer=tokenizer,
357
+ max_seq_length=args.max_seq_length,
358
+ doc_stride=args.doc_stride,
359
+ max_query_length=args.max_query_length,
360
+ is_training=not evaluate,
361
+ )
362
+ ```"""
363
+ # Defining helper methods
364
+ features = []
365
+
366
+ threads = min(threads, cpu_count())
367
+ with Pool(threads, initializer=squad_convert_example_to_features_init, initargs=(tokenizer,)) as p:
368
+ annotate_ = partial(
369
+ squad_convert_example_to_features,
370
+ max_seq_length=max_seq_length,
371
+ doc_stride=doc_stride,
372
+ max_query_length=max_query_length,
373
+ padding_strategy=padding_strategy,
374
+ is_training=is_training,
375
+ )
376
+ features = list(
377
+ tqdm(
378
+ p.imap(annotate_, examples, chunksize=32),
379
+ total=len(examples),
380
+ desc="convert squad examples to features",
381
+ disable=not tqdm_enabled,
382
+ )
383
+ )
384
+
385
+ new_features = []
386
+ unique_id = 1000000000
387
+ example_index = 0
388
+ for example_features in tqdm(
389
+ features, total=len(features), desc="add example index and unique id", disable=not tqdm_enabled
390
+ ):
391
+ if not example_features:
392
+ continue
393
+ for example_feature in example_features:
394
+ example_feature.example_index = example_index
395
+ example_feature.unique_id = unique_id
396
+ new_features.append(example_feature)
397
+ unique_id += 1
398
+ example_index += 1
399
+ features = new_features
400
+ del new_features
401
+ if return_dataset == "pt":
402
+ if not is_torch_available():
403
+ raise RuntimeError("PyTorch must be installed to return a PyTorch dataset.")
404
+
405
+ # Convert to Tensors and build dataset
406
+ all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
407
+ all_attention_masks = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
408
+ all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
409
+ all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
410
+ all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
411
+ all_is_impossible = torch.tensor([f.is_impossible for f in features], dtype=torch.float)
412
+
413
+ if not is_training:
414
+ all_feature_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
415
+ dataset = TensorDataset(
416
+ all_input_ids, all_attention_masks, all_token_type_ids, all_feature_index, all_cls_index, all_p_mask
417
+ )
418
+ else:
419
+ all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
420
+ all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
421
+ dataset = TensorDataset(
422
+ all_input_ids,
423
+ all_attention_masks,
424
+ all_token_type_ids,
425
+ all_start_positions,
426
+ all_end_positions,
427
+ all_cls_index,
428
+ all_p_mask,
429
+ all_is_impossible,
430
+ )
431
+
432
+ return features, dataset
433
+ elif return_dataset == "tf":
434
+ if not is_tf_available():
435
+ raise RuntimeError("TensorFlow must be installed to return a TensorFlow dataset.")
436
+
437
+ def gen():
438
+ for i, ex in enumerate(features):
439
+ if ex.token_type_ids is None:
440
+ yield (
441
+ {
442
+ "input_ids": ex.input_ids,
443
+ "attention_mask": ex.attention_mask,
444
+ "feature_index": i,
445
+ "qas_id": ex.qas_id,
446
+ },
447
+ {
448
+ "start_positions": ex.start_position,
449
+ "end_positions": ex.end_position,
450
+ "cls_index": ex.cls_index,
451
+ "p_mask": ex.p_mask,
452
+ "is_impossible": ex.is_impossible,
453
+ },
454
+ )
455
+ else:
456
+ yield (
457
+ {
458
+ "input_ids": ex.input_ids,
459
+ "attention_mask": ex.attention_mask,
460
+ "token_type_ids": ex.token_type_ids,
461
+ "feature_index": i,
462
+ "qas_id": ex.qas_id,
463
+ },
464
+ {
465
+ "start_positions": ex.start_position,
466
+ "end_positions": ex.end_position,
467
+ "cls_index": ex.cls_index,
468
+ "p_mask": ex.p_mask,
469
+ "is_impossible": ex.is_impossible,
470
+ },
471
+ )
472
+
473
+ # Why have we split the batch into a tuple? PyTorch just has a list of tensors.
474
+ if "token_type_ids" in tokenizer.model_input_names:
475
+ train_types = (
476
+ {
477
+ "input_ids": tf.int32,
478
+ "attention_mask": tf.int32,
479
+ "token_type_ids": tf.int32,
480
+ "feature_index": tf.int64,
481
+ "qas_id": tf.string,
482
+ },
483
+ {
484
+ "start_positions": tf.int64,
485
+ "end_positions": tf.int64,
486
+ "cls_index": tf.int64,
487
+ "p_mask": tf.int32,
488
+ "is_impossible": tf.int32,
489
+ },
490
+ )
491
+
492
+ train_shapes = (
493
+ {
494
+ "input_ids": tf.TensorShape([None]),
495
+ "attention_mask": tf.TensorShape([None]),
496
+ "token_type_ids": tf.TensorShape([None]),
497
+ "feature_index": tf.TensorShape([]),
498
+ "qas_id": tf.TensorShape([]),
499
+ },
500
+ {
501
+ "start_positions": tf.TensorShape([]),
502
+ "end_positions": tf.TensorShape([]),
503
+ "cls_index": tf.TensorShape([]),
504
+ "p_mask": tf.TensorShape([None]),
505
+ "is_impossible": tf.TensorShape([]),
506
+ },
507
+ )
508
+ else:
509
+ train_types = (
510
+ {"input_ids": tf.int32, "attention_mask": tf.int32, "feature_index": tf.int64, "qas_id": tf.string},
511
+ {
512
+ "start_positions": tf.int64,
513
+ "end_positions": tf.int64,
514
+ "cls_index": tf.int64,
515
+ "p_mask": tf.int32,
516
+ "is_impossible": tf.int32,
517
+ },
518
+ )
519
+
520
+ train_shapes = (
521
+ {
522
+ "input_ids": tf.TensorShape([None]),
523
+ "attention_mask": tf.TensorShape([None]),
524
+ "feature_index": tf.TensorShape([]),
525
+ "qas_id": tf.TensorShape([]),
526
+ },
527
+ {
528
+ "start_positions": tf.TensorShape([]),
529
+ "end_positions": tf.TensorShape([]),
530
+ "cls_index": tf.TensorShape([]),
531
+ "p_mask": tf.TensorShape([None]),
532
+ "is_impossible": tf.TensorShape([]),
533
+ },
534
+ )
535
+
536
+ return tf.data.Dataset.from_generator(gen, train_types, train_shapes)
537
+ else:
538
+ return features
539
+
540
+
541
+ class SquadProcessor(DataProcessor):
542
+ """
543
+ Processor for the SQuAD data set. overridden by SquadV1Processor and SquadV2Processor, used by the version 1.1 and
544
+ version 2.0 of SQuAD, respectively.
545
+ """
546
+
547
+ train_file = None
548
+ dev_file = None
549
+
550
+ def _get_example_from_tensor_dict(self, tensor_dict, evaluate=False):
551
+ if not evaluate:
552
+ answer = tensor_dict["answers"]["text"][0].numpy().decode("utf-8")
553
+ answer_start = tensor_dict["answers"]["answer_start"][0].numpy()
554
+ answers = []
555
+ else:
556
+ answers = [
557
+ {"answer_start": start.numpy(), "text": text.numpy().decode("utf-8")}
558
+ for start, text in zip(tensor_dict["answers"]["answer_start"], tensor_dict["answers"]["text"])
559
+ ]
560
+
561
+ answer = None
562
+ answer_start = None
563
+
564
+ return SquadExample(
565
+ qas_id=tensor_dict["id"].numpy().decode("utf-8"),
566
+ question_text=tensor_dict["question"].numpy().decode("utf-8"),
567
+ context_text=tensor_dict["context"].numpy().decode("utf-8"),
568
+ answer_text=answer,
569
+ start_position_character=answer_start,
570
+ title=tensor_dict["title"].numpy().decode("utf-8"),
571
+ answers=answers,
572
+ )
573
+
574
+ def get_examples_from_dataset(self, dataset, evaluate=False):
575
+ """
576
+ Creates a list of [`~data.processors.squad.SquadExample`] using a TFDS dataset.
577
+
578
+ Args:
579
+ dataset: The tfds dataset loaded from *tensorflow_datasets.load("squad")*
580
+ evaluate: Boolean specifying if in evaluation mode or in training mode
581
+
582
+ Returns:
583
+ List of SquadExample
584
+
585
+ Examples:
586
+
587
+ ```python
588
+ >>> import tensorflow_datasets as tfds
589
+
590
+ >>> dataset = tfds.load("squad")
591
+
592
+ >>> training_examples = get_examples_from_dataset(dataset, evaluate=False)
593
+ >>> evaluation_examples = get_examples_from_dataset(dataset, evaluate=True)
594
+ ```"""
595
+
596
+ if evaluate:
597
+ dataset = dataset["validation"]
598
+ else:
599
+ dataset = dataset["train"]
600
+
601
+ examples = []
602
+ for tensor_dict in tqdm(dataset):
603
+ examples.append(self._get_example_from_tensor_dict(tensor_dict, evaluate=evaluate))
604
+
605
+ return examples
606
+
607
+ def get_train_examples(self, data_dir, filename=None):
608
+ """
609
+ Returns the training examples from the data directory.
610
+
611
+ Args:
612
+ data_dir: Directory containing the data files used for training and evaluating.
613
+ filename: None by default, specify this if the training file has a different name than the original one
614
+ which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively.
615
+
616
+ """
617
+ if data_dir is None:
618
+ data_dir = ""
619
+
620
+ if self.train_file is None:
621
+ raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
622
+
623
+ with open(
624
+ os.path.join(data_dir, self.train_file if filename is None else filename), "r", encoding="utf-8"
625
+ ) as reader:
626
+ input_data = json.load(reader)["data"]
627
+ return self._create_examples(input_data, "train")
628
+
629
+ def get_dev_examples(self, data_dir, filename=None):
630
+ """
631
+ Returns the evaluation example from the data directory.
632
+
633
+ Args:
634
+ data_dir: Directory containing the data files used for training and evaluating.
635
+ filename: None by default, specify this if the evaluation file has a different name than the original one
636
+ which is `dev-v1.1.json` and `dev-v2.0.json` for squad versions 1.1 and 2.0 respectively.
637
+ """
638
+ if data_dir is None:
639
+ data_dir = ""
640
+
641
+ if self.dev_file is None:
642
+ raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
643
+
644
+ with open(
645
+ os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding="utf-8"
646
+ ) as reader:
647
+ input_data = json.load(reader)["data"]
648
+ return self._create_examples(input_data, "dev")
649
+
650
+ def _create_examples(self, input_data, set_type):
651
+ is_training = set_type == "train"
652
+ examples = []
653
+ for entry in tqdm(input_data):
654
+ title = entry["title"]
655
+ for paragraph in entry["paragraphs"]:
656
+ context_text = paragraph["context"]
657
+ for qa in paragraph["qas"]:
658
+ qas_id = qa["id"]
659
+ question_text = qa["question"]
660
+ start_position_character = None
661
+ answer_text = None
662
+ answers = []
663
+
664
+ is_impossible = qa.get("is_impossible", False)
665
+ if not is_impossible:
666
+ if is_training:
667
+ answer = qa["answers"][0]
668
+ answer_text = answer["text"]
669
+ start_position_character = answer["answer_start"]
670
+ else:
671
+ answers = qa["answers"]
672
+
673
+ example = SquadExample(
674
+ qas_id=qas_id,
675
+ question_text=question_text,
676
+ context_text=context_text,
677
+ answer_text=answer_text,
678
+ start_position_character=start_position_character,
679
+ title=title,
680
+ is_impossible=is_impossible,
681
+ answers=answers,
682
+ )
683
+ examples.append(example)
684
+ return examples
685
+
686
+
687
+ class SquadV1Processor(SquadProcessor):
688
+ train_file = "train-v1.1.json"
689
+ dev_file = "dev-v1.1.json"
690
+
691
+
692
+ class SquadV2Processor(SquadProcessor):
693
+ train_file = "train-v2.0.json"
694
+ dev_file = "dev-v2.0.json"
695
+
696
+
697
+ class SquadExample:
698
+ """
699
+ A single training/test example for the Squad dataset, as loaded from disk.
700
+
701
+ Args:
702
+ qas_id: The example's unique identifier
703
+ question_text: The question string
704
+ context_text: The context string
705
+ answer_text: The answer string
706
+ start_position_character: The character position of the start of the answer
707
+ title: The title of the example
708
+ answers: None by default, this is used during evaluation. Holds answers as well as their start positions.
709
+ is_impossible: False by default, set to True if the example has no possible answer.
710
+ """
711
+
712
+ def __init__(
713
+ self,
714
+ qas_id,
715
+ question_text,
716
+ context_text,
717
+ answer_text,
718
+ start_position_character,
719
+ title,
720
+ answers=[],
721
+ is_impossible=False,
722
+ ):
723
+ self.qas_id = qas_id
724
+ self.question_text = question_text
725
+ self.context_text = context_text
726
+ self.answer_text = answer_text
727
+ self.title = title
728
+ self.is_impossible = is_impossible
729
+ self.answers = answers
730
+
731
+ self.start_position, self.end_position = 0, 0
732
+
733
+ doc_tokens = []
734
+ char_to_word_offset = []
735
+ prev_is_whitespace = True
736
+
737
+ # Split on whitespace so that different tokens may be attributed to their original position.
738
+ for c in self.context_text:
739
+ if _is_whitespace(c):
740
+ prev_is_whitespace = True
741
+ else:
742
+ if prev_is_whitespace:
743
+ doc_tokens.append(c)
744
+ else:
745
+ doc_tokens[-1] += c
746
+ prev_is_whitespace = False
747
+ char_to_word_offset.append(len(doc_tokens) - 1)
748
+
749
+ self.doc_tokens = doc_tokens
750
+ self.char_to_word_offset = char_to_word_offset
751
+
752
+ # Start and end positions only has a value during evaluation.
753
+ if start_position_character is not None and not is_impossible:
754
+ self.start_position = char_to_word_offset[start_position_character]
755
+ self.end_position = char_to_word_offset[
756
+ min(start_position_character + len(answer_text) - 1, len(char_to_word_offset) - 1)
757
+ ]
758
+
759
+
760
+ class SquadFeatures:
761
+ """
762
+ Single squad example features to be fed to a model. Those features are model-specific and can be crafted from
763
+ [`~data.processors.squad.SquadExample`] using the
764
+ :method:*~transformers.data.processors.squad.squad_convert_examples_to_features* method.
765
+
766
+ Args:
767
+ input_ids: Indices of input sequence tokens in the vocabulary.
768
+ attention_mask: Mask to avoid performing attention on padding token indices.
769
+ token_type_ids: Segment token indices to indicate first and second portions of the inputs.
770
+ cls_index: the index of the CLS token.
771
+ p_mask: Mask identifying tokens that can be answers vs. tokens that cannot.
772
+ Mask with 1 for tokens than cannot be in the answer and 0 for token that can be in an answer
773
+ example_index: the index of the example
774
+ unique_id: The unique Feature identifier
775
+ paragraph_len: The length of the context
776
+ token_is_max_context:
777
+ List of booleans identifying which tokens have their maximum context in this feature object. If a token
778
+ does not have their maximum context in this feature object, it means that another feature object has more
779
+ information related to that token and should be prioritized over this feature for that token.
780
+ tokens: list of tokens corresponding to the input ids
781
+ token_to_orig_map: mapping between the tokens and the original text, needed in order to identify the answer.
782
+ start_position: start of the answer token index
783
+ end_position: end of the answer token index
784
+ encoding: optionally store the BatchEncoding with the fast-tokenizer alignment methods.
785
+ """
786
+
787
+ def __init__(
788
+ self,
789
+ input_ids,
790
+ attention_mask,
791
+ token_type_ids,
792
+ cls_index,
793
+ p_mask,
794
+ example_index,
795
+ unique_id,
796
+ paragraph_len,
797
+ token_is_max_context,
798
+ tokens,
799
+ token_to_orig_map,
800
+ start_position,
801
+ end_position,
802
+ is_impossible,
803
+ qas_id: str = None,
804
+ encoding: BatchEncoding = None,
805
+ ):
806
+ self.input_ids = input_ids
807
+ self.attention_mask = attention_mask
808
+ self.token_type_ids = token_type_ids
809
+ self.cls_index = cls_index
810
+ self.p_mask = p_mask
811
+
812
+ self.example_index = example_index
813
+ self.unique_id = unique_id
814
+ self.paragraph_len = paragraph_len
815
+ self.token_is_max_context = token_is_max_context
816
+ self.tokens = tokens
817
+ self.token_to_orig_map = token_to_orig_map
818
+
819
+ self.start_position = start_position
820
+ self.end_position = end_position
821
+ self.is_impossible = is_impossible
822
+ self.qas_id = qas_id
823
+
824
+ self.encoding = encoding
825
+
826
+
827
+ class SquadResult:
828
+ """
829
+ Constructs a SquadResult which can be used to evaluate a model's output on the SQuAD dataset.
830
+
831
+ Args:
832
+ unique_id: The unique identifier corresponding to that example.
833
+ start_logits: The logits corresponding to the start of the answer
834
+ end_logits: The logits corresponding to the end of the answer
835
+ """
836
+
837
+ def __init__(self, unique_id, start_logits, end_logits, start_top_index=None, end_top_index=None, cls_logits=None):
838
+ self.start_logits = start_logits
839
+ self.end_logits = end_logits
840
+ self.unique_id = unique_id
841
+
842
+ if start_top_index:
843
+ self.start_top_index = start_top_index
844
+ self.end_top_index = end_top_index
845
+ self.cls_logits = cls_logits
.venv/lib/python3.11/site-packages/transformers/data/processors/utils.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import csv
18
+ import dataclasses
19
+ import json
20
+ from dataclasses import dataclass
21
+ from typing import List, Optional, Union
22
+
23
+ from ...utils import is_tf_available, is_torch_available, logging
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ @dataclass
30
+ class InputExample:
31
+ """
32
+ A single training/test example for simple sequence classification.
33
+
34
+ Args:
35
+ guid: Unique id for the example.
36
+ text_a: string. The untokenized text of the first sequence. For single
37
+ sequence tasks, only this sequence must be specified.
38
+ text_b: (Optional) string. The untokenized text of the second sequence.
39
+ Only must be specified for sequence pair tasks.
40
+ label: (Optional) string. The label of the example. This should be
41
+ specified for train and dev examples, but not for test examples.
42
+ """
43
+
44
+ guid: str
45
+ text_a: str
46
+ text_b: Optional[str] = None
47
+ label: Optional[str] = None
48
+
49
+ def to_json_string(self):
50
+ """Serializes this instance to a JSON string."""
51
+ return json.dumps(dataclasses.asdict(self), indent=2) + "\n"
52
+
53
+
54
+ @dataclass(frozen=True)
55
+ class InputFeatures:
56
+ """
57
+ A single set of features of data. Property names are the same names as the corresponding inputs to a model.
58
+
59
+ Args:
60
+ input_ids: Indices of input sequence tokens in the vocabulary.
61
+ attention_mask: Mask to avoid performing attention on padding token indices.
62
+ Mask values selected in `[0, 1]`: Usually `1` for tokens that are NOT MASKED, `0` for MASKED (padded)
63
+ tokens.
64
+ token_type_ids: (Optional) Segment token indices to indicate first and second
65
+ portions of the inputs. Only some models use them.
66
+ label: (Optional) Label corresponding to the input. Int for classification problems,
67
+ float for regression problems.
68
+ """
69
+
70
+ input_ids: List[int]
71
+ attention_mask: Optional[List[int]] = None
72
+ token_type_ids: Optional[List[int]] = None
73
+ label: Optional[Union[int, float]] = None
74
+
75
+ def to_json_string(self):
76
+ """Serializes this instance to a JSON string."""
77
+ return json.dumps(dataclasses.asdict(self)) + "\n"
78
+
79
+
80
+ class DataProcessor:
81
+ """Base class for data converters for sequence classification data sets."""
82
+
83
+ def get_example_from_tensor_dict(self, tensor_dict):
84
+ """
85
+ Gets an example from a dict with tensorflow tensors.
86
+
87
+ Args:
88
+ tensor_dict: Keys and values should match the corresponding Glue
89
+ tensorflow_dataset examples.
90
+ """
91
+ raise NotImplementedError()
92
+
93
+ def get_train_examples(self, data_dir):
94
+ """Gets a collection of [`InputExample`] for the train set."""
95
+ raise NotImplementedError()
96
+
97
+ def get_dev_examples(self, data_dir):
98
+ """Gets a collection of [`InputExample`] for the dev set."""
99
+ raise NotImplementedError()
100
+
101
+ def get_test_examples(self, data_dir):
102
+ """Gets a collection of [`InputExample`] for the test set."""
103
+ raise NotImplementedError()
104
+
105
+ def get_labels(self):
106
+ """Gets the list of labels for this data set."""
107
+ raise NotImplementedError()
108
+
109
+ def tfds_map(self, example):
110
+ """
111
+ Some tensorflow_datasets datasets are not formatted the same way the GLUE datasets are. This method converts
112
+ examples to the correct format.
113
+ """
114
+ if len(self.get_labels()) > 1:
115
+ example.label = self.get_labels()[int(example.label)]
116
+ return example
117
+
118
+ @classmethod
119
+ def _read_tsv(cls, input_file, quotechar=None):
120
+ """Reads a tab separated value file."""
121
+ with open(input_file, "r", encoding="utf-8-sig") as f:
122
+ return list(csv.reader(f, delimiter="\t", quotechar=quotechar))
123
+
124
+
125
+ class SingleSentenceClassificationProcessor(DataProcessor):
126
+ """Generic processor for a single sentence classification data set."""
127
+
128
+ def __init__(self, labels=None, examples=None, mode="classification", verbose=False):
129
+ self.labels = [] if labels is None else labels
130
+ self.examples = [] if examples is None else examples
131
+ self.mode = mode
132
+ self.verbose = verbose
133
+
134
+ def __len__(self):
135
+ return len(self.examples)
136
+
137
+ def __getitem__(self, idx):
138
+ if isinstance(idx, slice):
139
+ return SingleSentenceClassificationProcessor(labels=self.labels, examples=self.examples[idx])
140
+ return self.examples[idx]
141
+
142
+ @classmethod
143
+ def create_from_csv(
144
+ cls, file_name, split_name="", column_label=0, column_text=1, column_id=None, skip_first_row=False, **kwargs
145
+ ):
146
+ processor = cls(**kwargs)
147
+ processor.add_examples_from_csv(
148
+ file_name,
149
+ split_name=split_name,
150
+ column_label=column_label,
151
+ column_text=column_text,
152
+ column_id=column_id,
153
+ skip_first_row=skip_first_row,
154
+ overwrite_labels=True,
155
+ overwrite_examples=True,
156
+ )
157
+ return processor
158
+
159
+ @classmethod
160
+ def create_from_examples(cls, texts_or_text_and_labels, labels=None, **kwargs):
161
+ processor = cls(**kwargs)
162
+ processor.add_examples(texts_or_text_and_labels, labels=labels)
163
+ return processor
164
+
165
+ def add_examples_from_csv(
166
+ self,
167
+ file_name,
168
+ split_name="",
169
+ column_label=0,
170
+ column_text=1,
171
+ column_id=None,
172
+ skip_first_row=False,
173
+ overwrite_labels=False,
174
+ overwrite_examples=False,
175
+ ):
176
+ lines = self._read_tsv(file_name)
177
+ if skip_first_row:
178
+ lines = lines[1:]
179
+ texts = []
180
+ labels = []
181
+ ids = []
182
+ for i, line in enumerate(lines):
183
+ texts.append(line[column_text])
184
+ labels.append(line[column_label])
185
+ if column_id is not None:
186
+ ids.append(line[column_id])
187
+ else:
188
+ guid = f"{split_name}-{i}" if split_name else str(i)
189
+ ids.append(guid)
190
+
191
+ return self.add_examples(
192
+ texts, labels, ids, overwrite_labels=overwrite_labels, overwrite_examples=overwrite_examples
193
+ )
194
+
195
+ def add_examples(
196
+ self, texts_or_text_and_labels, labels=None, ids=None, overwrite_labels=False, overwrite_examples=False
197
+ ):
198
+ if labels is not None and len(texts_or_text_and_labels) != len(labels):
199
+ raise ValueError(
200
+ f"Text and labels have mismatched lengths {len(texts_or_text_and_labels)} and {len(labels)}"
201
+ )
202
+ if ids is not None and len(texts_or_text_and_labels) != len(ids):
203
+ raise ValueError(f"Text and ids have mismatched lengths {len(texts_or_text_and_labels)} and {len(ids)}")
204
+ if ids is None:
205
+ ids = [None] * len(texts_or_text_and_labels)
206
+ if labels is None:
207
+ labels = [None] * len(texts_or_text_and_labels)
208
+ examples = []
209
+ added_labels = set()
210
+ for text_or_text_and_label, label, guid in zip(texts_or_text_and_labels, labels, ids):
211
+ if isinstance(text_or_text_and_label, (tuple, list)) and label is None:
212
+ text, label = text_or_text_and_label
213
+ else:
214
+ text = text_or_text_and_label
215
+ added_labels.add(label)
216
+ examples.append(InputExample(guid=guid, text_a=text, text_b=None, label=label))
217
+
218
+ # Update examples
219
+ if overwrite_examples:
220
+ self.examples = examples
221
+ else:
222
+ self.examples.extend(examples)
223
+
224
+ # Update labels
225
+ if overwrite_labels:
226
+ self.labels = list(added_labels)
227
+ else:
228
+ self.labels = list(set(self.labels).union(added_labels))
229
+
230
+ return self.examples
231
+
232
+ def get_features(
233
+ self,
234
+ tokenizer,
235
+ max_length=None,
236
+ pad_on_left=False,
237
+ pad_token=0,
238
+ mask_padding_with_zero=True,
239
+ return_tensors=None,
240
+ ):
241
+ """
242
+ Convert examples in a list of `InputFeatures`
243
+
244
+ Args:
245
+ tokenizer: Instance of a tokenizer that will tokenize the examples
246
+ max_length: Maximum example length
247
+ pad_on_left: If set to `True`, the examples will be padded on the left rather than on the right (default)
248
+ pad_token: Padding token
249
+ mask_padding_with_zero: If set to `True`, the attention mask will be filled by `1` for actual values
250
+ and by `0` for padded values. If set to `False`, inverts it (`1` for padded values, `0` for actual
251
+ values)
252
+
253
+ Returns:
254
+ If the `examples` input is a `tf.data.Dataset`, will return a `tf.data.Dataset` containing the
255
+ task-specific features. If the input is a list of `InputExamples`, will return a list of task-specific
256
+ `InputFeatures` which can be fed to the model.
257
+
258
+ """
259
+ if max_length is None:
260
+ max_length = tokenizer.max_len
261
+
262
+ label_map = {label: i for i, label in enumerate(self.labels)}
263
+
264
+ all_input_ids = []
265
+ for ex_index, example in enumerate(self.examples):
266
+ if ex_index % 10000 == 0:
267
+ logger.info(f"Tokenizing example {ex_index}")
268
+
269
+ input_ids = tokenizer.encode(
270
+ example.text_a,
271
+ add_special_tokens=True,
272
+ max_length=min(max_length, tokenizer.max_len),
273
+ )
274
+ all_input_ids.append(input_ids)
275
+
276
+ batch_length = max(len(input_ids) for input_ids in all_input_ids)
277
+
278
+ features = []
279
+ for ex_index, (input_ids, example) in enumerate(zip(all_input_ids, self.examples)):
280
+ if ex_index % 10000 == 0:
281
+ logger.info(f"Writing example {ex_index}/{len(self.examples)}")
282
+ # The mask has 1 for real tokens and 0 for padding tokens. Only real
283
+ # tokens are attended to.
284
+ attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
285
+
286
+ # Zero-pad up to the sequence length.
287
+ padding_length = batch_length - len(input_ids)
288
+ if pad_on_left:
289
+ input_ids = ([pad_token] * padding_length) + input_ids
290
+ attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
291
+ else:
292
+ input_ids = input_ids + ([pad_token] * padding_length)
293
+ attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
294
+
295
+ if len(input_ids) != batch_length:
296
+ raise ValueError(f"Error with input length {len(input_ids)} vs {batch_length}")
297
+ if len(attention_mask) != batch_length:
298
+ raise ValueError(f"Error with input length {len(attention_mask)} vs {batch_length}")
299
+
300
+ if self.mode == "classification":
301
+ label = label_map[example.label]
302
+ elif self.mode == "regression":
303
+ label = float(example.label)
304
+ else:
305
+ raise ValueError(self.mode)
306
+
307
+ if ex_index < 5 and self.verbose:
308
+ logger.info("*** Example ***")
309
+ logger.info(f"guid: {example.guid}")
310
+ logger.info(f"input_ids: {' '.join([str(x) for x in input_ids])}")
311
+ logger.info(f"attention_mask: {' '.join([str(x) for x in attention_mask])}")
312
+ logger.info(f"label: {example.label} (id = {label})")
313
+
314
+ features.append(InputFeatures(input_ids=input_ids, attention_mask=attention_mask, label=label))
315
+
316
+ if return_tensors is None:
317
+ return features
318
+ elif return_tensors == "tf":
319
+ if not is_tf_available():
320
+ raise RuntimeError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported")
321
+ import tensorflow as tf
322
+
323
+ def gen():
324
+ for ex in features:
325
+ yield ({"input_ids": ex.input_ids, "attention_mask": ex.attention_mask}, ex.label)
326
+
327
+ dataset = tf.data.Dataset.from_generator(
328
+ gen,
329
+ ({"input_ids": tf.int32, "attention_mask": tf.int32}, tf.int64),
330
+ ({"input_ids": tf.TensorShape([None]), "attention_mask": tf.TensorShape([None])}, tf.TensorShape([])),
331
+ )
332
+ return dataset
333
+ elif return_tensors == "pt":
334
+ if not is_torch_available():
335
+ raise RuntimeError("return_tensors set to 'pt' but PyTorch can't be imported")
336
+ import torch
337
+ from torch.utils.data import TensorDataset
338
+
339
+ all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
340
+ all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
341
+ if self.mode == "classification":
342
+ all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
343
+ elif self.mode == "regression":
344
+ all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
345
+
346
+ dataset = TensorDataset(all_input_ids, all_attention_mask, all_labels)
347
+ return dataset
348
+ else:
349
+ raise ValueError("return_tensors should be one of 'tf' or 'pt'")
.venv/lib/python3.11/site-packages/transformers/data/processors/xnli.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """XNLI utils (dataset loading and evaluation)"""
17
+
18
+ import os
19
+
20
+ from ...utils import logging
21
+ from .utils import DataProcessor, InputExample
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ class XnliProcessor(DataProcessor):
28
+ """
29
+ Processor for the XNLI dataset. Adapted from
30
+ https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/run_classifier.py#L207
31
+ """
32
+
33
+ def __init__(self, language, train_language=None):
34
+ self.language = language
35
+ self.train_language = train_language
36
+
37
+ def get_train_examples(self, data_dir):
38
+ """See base class."""
39
+ lg = self.language if self.train_language is None else self.train_language
40
+ lines = self._read_tsv(os.path.join(data_dir, f"XNLI-MT-1.0/multinli/multinli.train.{lg}.tsv"))
41
+ examples = []
42
+ for i, line in enumerate(lines):
43
+ if i == 0:
44
+ continue
45
+ guid = f"train-{i}"
46
+ text_a = line[0]
47
+ text_b = line[1]
48
+ label = "contradiction" if line[2] == "contradictory" else line[2]
49
+ if not isinstance(text_a, str):
50
+ raise TypeError(f"Training input {text_a} is not a string")
51
+ if not isinstance(text_b, str):
52
+ raise TypeError(f"Training input {text_b} is not a string")
53
+ if not isinstance(label, str):
54
+ raise TypeError(f"Training label {label} is not a string")
55
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
56
+ return examples
57
+
58
+ def get_test_examples(self, data_dir):
59
+ """See base class."""
60
+ lines = self._read_tsv(os.path.join(data_dir, "XNLI-1.0/xnli.test.tsv"))
61
+ examples = []
62
+ for i, line in enumerate(lines):
63
+ if i == 0:
64
+ continue
65
+ language = line[0]
66
+ if language != self.language:
67
+ continue
68
+ guid = f"test-{i}"
69
+ text_a = line[6]
70
+ text_b = line[7]
71
+ label = line[1]
72
+ if not isinstance(text_a, str):
73
+ raise TypeError(f"Training input {text_a} is not a string")
74
+ if not isinstance(text_b, str):
75
+ raise TypeError(f"Training input {text_b} is not a string")
76
+ if not isinstance(label, str):
77
+ raise TypeError(f"Training label {label} is not a string")
78
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
79
+ return examples
80
+
81
+ def get_labels(self):
82
+ """See base class."""
83
+ return ["contradiction", "entailment", "neutral"]
84
+
85
+
86
+ xnli_processors = {
87
+ "xnli": XnliProcessor,
88
+ }
89
+
90
+ xnli_output_modes = {
91
+ "xnli": "classification",
92
+ }
93
+
94
+ xnli_tasks_num_labels = {
95
+ "xnli": 3,
96
+ }
.venv/lib/python3.11/site-packages/transformers/generation/__init__.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_tf_available, is_torch_available
18
+
19
+
20
+ _import_structure = {
21
+ "configuration_utils": [
22
+ "BaseWatermarkingConfig",
23
+ "CompileConfig",
24
+ "GenerationConfig",
25
+ "GenerationMode",
26
+ "SynthIDTextWatermarkingConfig",
27
+ "WatermarkingConfig",
28
+ ],
29
+ "streamers": ["AsyncTextIteratorStreamer", "TextIteratorStreamer", "TextStreamer"],
30
+ }
31
+
32
+ try:
33
+ if not is_torch_available():
34
+ raise OptionalDependencyNotAvailable()
35
+ except OptionalDependencyNotAvailable:
36
+ pass
37
+ else:
38
+ _import_structure["beam_constraints"] = [
39
+ "Constraint",
40
+ "ConstraintListState",
41
+ "DisjunctiveConstraint",
42
+ "PhrasalConstraint",
43
+ ]
44
+ _import_structure["beam_search"] = [
45
+ "BeamHypotheses",
46
+ "BeamScorer",
47
+ "BeamSearchScorer",
48
+ "ConstrainedBeamSearchScorer",
49
+ ]
50
+ _import_structure["candidate_generator"] = [
51
+ "AssistedCandidateGenerator",
52
+ "CandidateGenerator",
53
+ "EarlyExitCandidateGenerator",
54
+ "PromptLookupCandidateGenerator",
55
+ ]
56
+ _import_structure["logits_process"] = [
57
+ "AlternatingCodebooksLogitsProcessor",
58
+ "ClassifierFreeGuidanceLogitsProcessor",
59
+ "EncoderNoRepeatNGramLogitsProcessor",
60
+ "EncoderRepetitionPenaltyLogitsProcessor",
61
+ "EpsilonLogitsWarper",
62
+ "EtaLogitsWarper",
63
+ "ExponentialDecayLengthPenalty",
64
+ "ForcedBOSTokenLogitsProcessor",
65
+ "ForcedEOSTokenLogitsProcessor",
66
+ "HammingDiversityLogitsProcessor",
67
+ "InfNanRemoveLogitsProcessor",
68
+ "LogitNormalization",
69
+ "LogitsProcessor",
70
+ "LogitsProcessorList",
71
+ "LogitsWarper",
72
+ "MinLengthLogitsProcessor",
73
+ "MinNewTokensLengthLogitsProcessor",
74
+ "MinPLogitsWarper",
75
+ "NoBadWordsLogitsProcessor",
76
+ "NoRepeatNGramLogitsProcessor",
77
+ "PrefixConstrainedLogitsProcessor",
78
+ "RepetitionPenaltyLogitsProcessor",
79
+ "SequenceBiasLogitsProcessor",
80
+ "SuppressTokensLogitsProcessor",
81
+ "SuppressTokensAtBeginLogitsProcessor",
82
+ "SynthIDTextWatermarkLogitsProcessor",
83
+ "TemperatureLogitsWarper",
84
+ "TopKLogitsWarper",
85
+ "TopPLogitsWarper",
86
+ "TypicalLogitsWarper",
87
+ "UnbatchedClassifierFreeGuidanceLogitsProcessor",
88
+ "WhisperTimeStampLogitsProcessor",
89
+ "WatermarkLogitsProcessor",
90
+ ]
91
+ _import_structure["stopping_criteria"] = [
92
+ "MaxNewTokensCriteria",
93
+ "MaxLengthCriteria",
94
+ "MaxTimeCriteria",
95
+ "ConfidenceCriteria",
96
+ "EosTokenCriteria",
97
+ "StoppingCriteria",
98
+ "StoppingCriteriaList",
99
+ "validate_stopping_criteria",
100
+ "StopStringCriteria",
101
+ ]
102
+ _import_structure["utils"] = [
103
+ "GenerationMixin",
104
+ "GreedySearchEncoderDecoderOutput",
105
+ "GreedySearchDecoderOnlyOutput",
106
+ "SampleEncoderDecoderOutput",
107
+ "SampleDecoderOnlyOutput",
108
+ "BeamSearchEncoderDecoderOutput",
109
+ "BeamSearchDecoderOnlyOutput",
110
+ "BeamSampleEncoderDecoderOutput",
111
+ "BeamSampleDecoderOnlyOutput",
112
+ "ContrastiveSearchEncoderDecoderOutput",
113
+ "ContrastiveSearchDecoderOnlyOutput",
114
+ "GenerateBeamDecoderOnlyOutput",
115
+ "GenerateBeamEncoderDecoderOutput",
116
+ "GenerateDecoderOnlyOutput",
117
+ "GenerateEncoderDecoderOutput",
118
+ ]
119
+ _import_structure["watermarking"] = [
120
+ "WatermarkDetector",
121
+ "WatermarkDetectorOutput",
122
+ "BayesianDetectorModel",
123
+ "BayesianDetectorConfig",
124
+ "SynthIDTextWatermarkDetector",
125
+ ]
126
+
127
+ try:
128
+ if not is_tf_available():
129
+ raise OptionalDependencyNotAvailable()
130
+ except OptionalDependencyNotAvailable:
131
+ pass
132
+ else:
133
+ _import_structure["tf_logits_process"] = [
134
+ "TFForcedBOSTokenLogitsProcessor",
135
+ "TFForcedEOSTokenLogitsProcessor",
136
+ "TFForceTokensLogitsProcessor",
137
+ "TFLogitsProcessor",
138
+ "TFLogitsProcessorList",
139
+ "TFLogitsWarper",
140
+ "TFMinLengthLogitsProcessor",
141
+ "TFNoBadWordsLogitsProcessor",
142
+ "TFNoRepeatNGramLogitsProcessor",
143
+ "TFRepetitionPenaltyLogitsProcessor",
144
+ "TFSuppressTokensAtBeginLogitsProcessor",
145
+ "TFSuppressTokensLogitsProcessor",
146
+ "TFTemperatureLogitsWarper",
147
+ "TFTopKLogitsWarper",
148
+ "TFTopPLogitsWarper",
149
+ ]
150
+ _import_structure["tf_utils"] = [
151
+ "TFGenerationMixin",
152
+ "TFGreedySearchDecoderOnlyOutput",
153
+ "TFGreedySearchEncoderDecoderOutput",
154
+ "TFSampleEncoderDecoderOutput",
155
+ "TFSampleDecoderOnlyOutput",
156
+ "TFBeamSearchEncoderDecoderOutput",
157
+ "TFBeamSearchDecoderOnlyOutput",
158
+ "TFBeamSampleEncoderDecoderOutput",
159
+ "TFBeamSampleDecoderOnlyOutput",
160
+ "TFContrastiveSearchEncoderDecoderOutput",
161
+ "TFContrastiveSearchDecoderOnlyOutput",
162
+ ]
163
+
164
+ try:
165
+ if not is_flax_available():
166
+ raise OptionalDependencyNotAvailable()
167
+ except OptionalDependencyNotAvailable:
168
+ pass
169
+ else:
170
+ _import_structure["flax_logits_process"] = [
171
+ "FlaxForcedBOSTokenLogitsProcessor",
172
+ "FlaxForcedEOSTokenLogitsProcessor",
173
+ "FlaxForceTokensLogitsProcessor",
174
+ "FlaxLogitsProcessor",
175
+ "FlaxLogitsProcessorList",
176
+ "FlaxLogitsWarper",
177
+ "FlaxMinLengthLogitsProcessor",
178
+ "FlaxSuppressTokensAtBeginLogitsProcessor",
179
+ "FlaxSuppressTokensLogitsProcessor",
180
+ "FlaxTemperatureLogitsWarper",
181
+ "FlaxTopKLogitsWarper",
182
+ "FlaxTopPLogitsWarper",
183
+ "FlaxWhisperTimeStampLogitsProcessor",
184
+ "FlaxNoRepeatNGramLogitsProcessor",
185
+ ]
186
+ _import_structure["flax_utils"] = [
187
+ "FlaxGenerationMixin",
188
+ "FlaxGreedySearchOutput",
189
+ "FlaxSampleOutput",
190
+ "FlaxBeamSearchOutput",
191
+ ]
192
+
193
+ if TYPE_CHECKING:
194
+ from .configuration_utils import (
195
+ BaseWatermarkingConfig,
196
+ CompileConfig,
197
+ GenerationConfig,
198
+ GenerationMode,
199
+ SynthIDTextWatermarkingConfig,
200
+ WatermarkingConfig,
201
+ )
202
+ from .streamers import AsyncTextIteratorStreamer, TextIteratorStreamer, TextStreamer
203
+
204
+ try:
205
+ if not is_torch_available():
206
+ raise OptionalDependencyNotAvailable()
207
+ except OptionalDependencyNotAvailable:
208
+ pass
209
+ else:
210
+ from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
211
+ from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
212
+ from .candidate_generator import (
213
+ AssistedCandidateGenerator,
214
+ CandidateGenerator,
215
+ EarlyExitCandidateGenerator,
216
+ PromptLookupCandidateGenerator,
217
+ )
218
+ from .logits_process import (
219
+ AlternatingCodebooksLogitsProcessor,
220
+ ClassifierFreeGuidanceLogitsProcessor,
221
+ EncoderNoRepeatNGramLogitsProcessor,
222
+ EncoderRepetitionPenaltyLogitsProcessor,
223
+ EpsilonLogitsWarper,
224
+ EtaLogitsWarper,
225
+ ExponentialDecayLengthPenalty,
226
+ ForcedBOSTokenLogitsProcessor,
227
+ ForcedEOSTokenLogitsProcessor,
228
+ HammingDiversityLogitsProcessor,
229
+ InfNanRemoveLogitsProcessor,
230
+ LogitNormalization,
231
+ LogitsProcessor,
232
+ LogitsProcessorList,
233
+ LogitsWarper,
234
+ MinLengthLogitsProcessor,
235
+ MinNewTokensLengthLogitsProcessor,
236
+ MinPLogitsWarper,
237
+ NoBadWordsLogitsProcessor,
238
+ NoRepeatNGramLogitsProcessor,
239
+ PrefixConstrainedLogitsProcessor,
240
+ RepetitionPenaltyLogitsProcessor,
241
+ SequenceBiasLogitsProcessor,
242
+ SuppressTokensAtBeginLogitsProcessor,
243
+ SuppressTokensLogitsProcessor,
244
+ SynthIDTextWatermarkLogitsProcessor,
245
+ TemperatureLogitsWarper,
246
+ TopKLogitsWarper,
247
+ TopPLogitsWarper,
248
+ TypicalLogitsWarper,
249
+ UnbatchedClassifierFreeGuidanceLogitsProcessor,
250
+ WatermarkLogitsProcessor,
251
+ WhisperTimeStampLogitsProcessor,
252
+ )
253
+ from .stopping_criteria import (
254
+ ConfidenceCriteria,
255
+ EosTokenCriteria,
256
+ MaxLengthCriteria,
257
+ MaxNewTokensCriteria,
258
+ MaxTimeCriteria,
259
+ StoppingCriteria,
260
+ StoppingCriteriaList,
261
+ StopStringCriteria,
262
+ validate_stopping_criteria,
263
+ )
264
+ from .utils import (
265
+ BeamSampleDecoderOnlyOutput,
266
+ BeamSampleEncoderDecoderOutput,
267
+ BeamSearchDecoderOnlyOutput,
268
+ BeamSearchEncoderDecoderOutput,
269
+ ContrastiveSearchDecoderOnlyOutput,
270
+ ContrastiveSearchEncoderDecoderOutput,
271
+ GenerateBeamDecoderOnlyOutput,
272
+ GenerateBeamEncoderDecoderOutput,
273
+ GenerateDecoderOnlyOutput,
274
+ GenerateEncoderDecoderOutput,
275
+ GenerationMixin,
276
+ GreedySearchDecoderOnlyOutput,
277
+ GreedySearchEncoderDecoderOutput,
278
+ SampleDecoderOnlyOutput,
279
+ SampleEncoderDecoderOutput,
280
+ )
281
+ from .watermarking import (
282
+ BayesianDetectorConfig,
283
+ BayesianDetectorModel,
284
+ SynthIDTextWatermarkDetector,
285
+ WatermarkDetector,
286
+ WatermarkDetectorOutput,
287
+ )
288
+
289
+ try:
290
+ if not is_tf_available():
291
+ raise OptionalDependencyNotAvailable()
292
+ except OptionalDependencyNotAvailable:
293
+ pass
294
+ else:
295
+ from .tf_logits_process import (
296
+ TFForcedBOSTokenLogitsProcessor,
297
+ TFForcedEOSTokenLogitsProcessor,
298
+ TFForceTokensLogitsProcessor,
299
+ TFLogitsProcessor,
300
+ TFLogitsProcessorList,
301
+ TFLogitsWarper,
302
+ TFMinLengthLogitsProcessor,
303
+ TFNoBadWordsLogitsProcessor,
304
+ TFNoRepeatNGramLogitsProcessor,
305
+ TFRepetitionPenaltyLogitsProcessor,
306
+ TFSuppressTokensAtBeginLogitsProcessor,
307
+ TFSuppressTokensLogitsProcessor,
308
+ TFTemperatureLogitsWarper,
309
+ TFTopKLogitsWarper,
310
+ TFTopPLogitsWarper,
311
+ )
312
+ from .tf_utils import (
313
+ TFBeamSampleDecoderOnlyOutput,
314
+ TFBeamSampleEncoderDecoderOutput,
315
+ TFBeamSearchDecoderOnlyOutput,
316
+ TFBeamSearchEncoderDecoderOutput,
317
+ TFContrastiveSearchDecoderOnlyOutput,
318
+ TFContrastiveSearchEncoderDecoderOutput,
319
+ TFGenerationMixin,
320
+ TFGreedySearchDecoderOnlyOutput,
321
+ TFGreedySearchEncoderDecoderOutput,
322
+ TFSampleDecoderOnlyOutput,
323
+ TFSampleEncoderDecoderOutput,
324
+ )
325
+
326
+ try:
327
+ if not is_flax_available():
328
+ raise OptionalDependencyNotAvailable()
329
+ except OptionalDependencyNotAvailable:
330
+ pass
331
+ else:
332
+ from .flax_logits_process import (
333
+ FlaxForcedBOSTokenLogitsProcessor,
334
+ FlaxForcedEOSTokenLogitsProcessor,
335
+ FlaxForceTokensLogitsProcessor,
336
+ FlaxLogitsProcessor,
337
+ FlaxLogitsProcessorList,
338
+ FlaxLogitsWarper,
339
+ FlaxMinLengthLogitsProcessor,
340
+ FlaxNoRepeatNGramLogitsProcessor,
341
+ FlaxSuppressTokensAtBeginLogitsProcessor,
342
+ FlaxSuppressTokensLogitsProcessor,
343
+ FlaxTemperatureLogitsWarper,
344
+ FlaxTopKLogitsWarper,
345
+ FlaxTopPLogitsWarper,
346
+ FlaxWhisperTimeStampLogitsProcessor,
347
+ )
348
+ from .flax_utils import FlaxBeamSearchOutput, FlaxGenerationMixin, FlaxGreedySearchOutput, FlaxSampleOutput
349
+ else:
350
+ import sys
351
+
352
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.venv/lib/python3.11/site-packages/transformers/generation/__pycache__/candidate_generator.cpython-311.pyc ADDED
Binary file (43.1 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/generation/__pycache__/configuration_utils.cpython-311.pyc ADDED
Binary file (89.4 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/generation/__pycache__/flax_logits_process.cpython-311.pyc ADDED
Binary file (34.1 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/generation/__pycache__/flax_utils.cpython-311.pyc ADDED
Binary file (47.2 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/generation/__pycache__/stopping_criteria.cpython-311.pyc ADDED
Binary file (33.2 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/generation/__pycache__/streamers.cpython-311.pyc ADDED
Binary file (15.6 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/generation/__pycache__/tf_logits_process.cpython-311.pyc ADDED
Binary file (44.5 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/generation/__pycache__/watermarking.cpython-311.pyc ADDED
Binary file (29.1 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/generation/beam_constraints.py ADDED
@@ -0,0 +1,524 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import List, Optional
3
+
4
+
5
+ class Constraint(ABC):
6
+ r"""Abstract base class for all constraints that can be applied during generation.
7
+ It must define how the constraint can be satisfied.
8
+
9
+ All classes that inherit Constraint must follow the requirement that
10
+
11
+ ```py
12
+ completed = False
13
+ while not completed:
14
+ _, completed = constraint.update(constraint.advance())
15
+ ```
16
+
17
+ will always terminate (halt).
18
+ """
19
+
20
+ def __init__(self):
21
+ # test for the above condition
22
+ self.test()
23
+
24
+ def test(self):
25
+ """
26
+ Tests whether this constraint has been properly defined.
27
+ """
28
+ counter = 0
29
+ completed = False
30
+ while not completed:
31
+ if counter == 1:
32
+ self.reset()
33
+ advance = self.advance()
34
+ if not self.does_advance(advance):
35
+ raise Exception(
36
+ "Custom Constraint is not defined correctly. self.does_advance(self.advance()) must be true."
37
+ )
38
+
39
+ stepped, completed, reset = self.update(advance)
40
+ counter += 1
41
+
42
+ if counter > 10000:
43
+ raise Exception("update() does not fulfill the constraint.")
44
+
45
+ if self.remaining() != 0:
46
+ raise Exception("Custom Constraint is not defined correctly.")
47
+
48
+ @abstractmethod
49
+ def advance(self):
50
+ """
51
+ When called, returns the token(s) that would take this constraint one step closer to being fulfilled.
52
+
53
+ Return:
54
+ token_ids (Union[int, List[int], None]):
55
+ - A single token ID (int) that advances the constraint, or
56
+ - A list of token IDs that could advance the constraint
57
+ - None if the constraint is completed or cannot be advanced
58
+ """
59
+ raise NotImplementedError(
60
+ f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
61
+ )
62
+
63
+ @abstractmethod
64
+ def does_advance(self, token_id: int):
65
+ """
66
+ Reads in a token and returns whether it creates progress.
67
+ """
68
+ raise NotImplementedError(
69
+ f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
70
+ )
71
+
72
+ @abstractmethod
73
+ def update(self, token_id: int):
74
+ """
75
+ Reads in a token and returns booleans that indicate the progress made by it. This function will update the
76
+ state of this object unlikes `does_advance(self, token_id: int)`.
77
+
78
+ This isn't to test whether a certain token will advance the progress; it's to update its state as if it has
79
+ been generated. This becomes important if token_id != desired token (refer to else statement in
80
+ PhrasalConstraint)
81
+
82
+ Args:
83
+ token_id(`int`):
84
+ The id of a newly generated token in the beam search.
85
+ Return:
86
+ stepped(`bool`):
87
+ Whether this constraint has become one step closer to being fulfuilled.
88
+ completed(`bool`):
89
+ Whether this constraint has been completely fulfilled by this token being generated.
90
+ reset (`bool`):
91
+ Whether this constraint has reset its progress by this token being generated.
92
+ """
93
+ raise NotImplementedError(
94
+ f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
95
+ )
96
+
97
+ @abstractmethod
98
+ def reset(self):
99
+ """
100
+ Resets the state of this constraint to its initialization. We would call this in cases where the fulfillment of
101
+ a constraint is abrupted by an unwanted token.
102
+ """
103
+ raise NotImplementedError(
104
+ f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
105
+ )
106
+
107
+ @abstractmethod
108
+ def remaining(self):
109
+ """
110
+ Returns the number of remaining steps of `advance()` in order to complete this constraint.
111
+ """
112
+ raise NotImplementedError(
113
+ f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
114
+ )
115
+
116
+ @abstractmethod
117
+ def copy(self, stateful=False):
118
+ """
119
+ Creates a new instance of this constraint.
120
+
121
+ Args:
122
+ stateful(`bool`): Whether to not only copy the constraint for new instance, but also its state.
123
+
124
+ Return:
125
+ constraint(`Constraint`): The same constraint as the one being called from.
126
+ """
127
+ raise NotImplementedError(
128
+ f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
129
+ )
130
+
131
+
132
+ class PhrasalConstraint(Constraint):
133
+ r"""
134
+ [`Constraint`] enforcing that an ordered sequence of tokens is included in the output.
135
+
136
+ Args:
137
+ token_ids (`List[int]`):
138
+ The id of the token that must be generated by the output.
139
+ """
140
+
141
+ def __init__(self, token_ids: List[int]):
142
+ super(Constraint, self).__init__()
143
+
144
+ if not isinstance(token_ids, list) or len(token_ids) == 0:
145
+ raise ValueError(f"`token_ids` has to be a non-empty list, but is {token_ids}.")
146
+ if any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids):
147
+ raise ValueError(f"Each list in `token_ids` has to be a list of positive integers, but is {token_ids}.")
148
+
149
+ self.token_ids = token_ids
150
+
151
+ self.seqlen = len(self.token_ids)
152
+ self.fulfilled_idx = -1 # the index of the currently fulfilled step
153
+ self.completed = False
154
+
155
+ def advance(self):
156
+ if self.completed:
157
+ return None
158
+ return self.token_ids[self.fulfilled_idx + 1]
159
+
160
+ def does_advance(self, token_id: int):
161
+ if not isinstance(token_id, int):
162
+ raise TypeError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}")
163
+
164
+ if self.completed:
165
+ return False
166
+
167
+ return token_id == self.token_ids[self.fulfilled_idx + 1]
168
+
169
+ def update(self, token_id: int):
170
+ if not isinstance(token_id, int):
171
+ raise TypeError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}")
172
+
173
+ stepped = False
174
+ completed = False
175
+ reset = False
176
+
177
+ if self.does_advance(token_id):
178
+ self.fulfilled_idx += 1
179
+ stepped = True
180
+ if self.fulfilled_idx == (self.seqlen - 1):
181
+ completed = True
182
+ self.completed = completed
183
+ else:
184
+ # failed to make progress.
185
+ reset = True
186
+ self.reset()
187
+ return stepped, completed, reset
188
+
189
+ def reset(self):
190
+ self.completed = False
191
+ self.fulfilled_idx = 0
192
+
193
+ def remaining(self):
194
+ return self.seqlen - (self.fulfilled_idx + 1)
195
+
196
+ def copy(self, stateful=False):
197
+ new_constraint = PhrasalConstraint(self.token_ids)
198
+
199
+ if stateful:
200
+ new_constraint.seq_len = self.seqlen
201
+ new_constraint.fulfilled_idx = self.fulfilled_idx
202
+ new_constraint.completed = self.completed
203
+
204
+ return new_constraint
205
+
206
+
207
+ class DisjunctiveTrie:
208
+ def __init__(self, nested_token_ids: List[List[int]], no_subsets=True):
209
+ r"""
210
+ A helper class that builds a trie with the words represented in `nested_token_ids`.
211
+ """
212
+ self.max_height = max([len(one) for one in nested_token_ids])
213
+
214
+ root = {}
215
+ for token_ids in nested_token_ids:
216
+ level = root
217
+ for tidx, token_id in enumerate(token_ids):
218
+ if token_id not in level:
219
+ level[token_id] = {}
220
+
221
+ level = level[token_id]
222
+
223
+ if no_subsets and self.has_subsets(root, nested_token_ids):
224
+ raise ValueError(
225
+ "Each list in `nested_token_ids` can't be a complete subset of another list, but is"
226
+ f" {nested_token_ids}."
227
+ )
228
+
229
+ self.trie = root
230
+
231
+ def next_tokens(self, current_seq):
232
+ """
233
+ The next possible tokens that will progress the trie, given the current sequence of tokens in `current_seq`.
234
+ """
235
+ start = self.trie
236
+
237
+ for current_token in current_seq:
238
+ start = start[current_token]
239
+
240
+ next_tokens = list(start.keys())
241
+
242
+ return next_tokens
243
+
244
+ def reached_leaf(self, current_seq):
245
+ next_tokens = self.next_tokens(current_seq)
246
+
247
+ return len(next_tokens) == 0
248
+
249
+ def count_leaves(self, root):
250
+ next_nodes = list(root.values())
251
+ if len(next_nodes) == 0:
252
+ return 1
253
+ else:
254
+ return sum([self.count_leaves(nn) for nn in next_nodes])
255
+
256
+ def has_subsets(self, trie, nested_token_ids):
257
+ """
258
+ Returns whether # of leaves == # of words. Otherwise some word is a subset of another.
259
+ """
260
+ leaf_count = self.count_leaves(trie)
261
+ return len(nested_token_ids) != leaf_count
262
+
263
+
264
+ class DisjunctiveConstraint(Constraint):
265
+ r"""
266
+ A special [`Constraint`] that is fulfilled by fulfilling just one of several constraints.
267
+
268
+ Args:
269
+ nested_token_ids (`List[List[int]]`):
270
+ A list of words, where each word is a list of ids. This constraint is fulfilled by generating just one from
271
+ the list of words.
272
+ """
273
+
274
+ def __init__(self, nested_token_ids: List[List[int]]):
275
+ super(Constraint, self).__init__()
276
+
277
+ if not isinstance(nested_token_ids, list) or len(nested_token_ids) == 0:
278
+ raise ValueError(f"`nested_token_ids` has to be a non-empty list, but is {nested_token_ids}.")
279
+ if any(not isinstance(token_ids, list) for token_ids in nested_token_ids):
280
+ raise ValueError(f"`nested_token_ids` has to be a list of lists, but is {nested_token_ids}.")
281
+ if any(
282
+ any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
283
+ for token_ids in nested_token_ids
284
+ ):
285
+ raise ValueError(
286
+ f"Each list in `nested_token_ids` has to be a list of positive integers, but is {nested_token_ids}."
287
+ )
288
+
289
+ self.trie = DisjunctiveTrie(nested_token_ids)
290
+ self.token_ids = nested_token_ids
291
+
292
+ self.seqlen = self.trie.max_height
293
+ self.current_seq = []
294
+ self.completed = False
295
+
296
+ def advance(self):
297
+ token_list = self.trie.next_tokens(self.current_seq)
298
+
299
+ if len(token_list) == 0:
300
+ return None
301
+ else:
302
+ return token_list
303
+
304
+ def does_advance(self, token_id: int):
305
+ if not isinstance(token_id, int):
306
+ raise TypeError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}")
307
+
308
+ next_tokens = self.trie.next_tokens(self.current_seq)
309
+
310
+ return token_id in next_tokens
311
+
312
+ def update(self, token_id: int):
313
+ if not isinstance(token_id, int):
314
+ raise TypeError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}")
315
+
316
+ stepped = False
317
+ completed = False
318
+ reset = False
319
+
320
+ if self.does_advance(token_id):
321
+ self.current_seq.append(token_id)
322
+ stepped = True
323
+ else:
324
+ reset = True
325
+ self.reset()
326
+
327
+ completed = self.trie.reached_leaf(self.current_seq)
328
+ self.completed = completed
329
+
330
+ return stepped, completed, reset
331
+
332
+ def reset(self):
333
+ self.completed = False
334
+ self.current_seq = []
335
+
336
+ def remaining(self):
337
+ if self.completed:
338
+ # since this can be completed without reaching max height
339
+ return 0
340
+ else:
341
+ return self.seqlen - len(self.current_seq)
342
+
343
+ def copy(self, stateful=False):
344
+ new_constraint = DisjunctiveConstraint(self.token_ids)
345
+
346
+ if stateful:
347
+ new_constraint.seq_len = self.seqlen
348
+ new_constraint.current_seq = self.current_seq
349
+ new_constraint.completed = self.completed
350
+
351
+ return new_constraint
352
+
353
+
354
+ class ConstraintListState:
355
+ r"""
356
+ A class for beam scorers to track its progress through a list of constraints.
357
+
358
+ Args:
359
+ constraints (`List[Constraint]`):
360
+ A list of [`Constraint`] objects that must be fulfilled by the beam scorer.
361
+ """
362
+
363
+ def __init__(self, constraints: List[Constraint]):
364
+ self.constraints = constraints
365
+
366
+ # max # of steps required to fulfill a given constraint
367
+ self.max_seqlen = max([c.seqlen for c in constraints])
368
+ self.n_constraints = len(constraints)
369
+ self.completed = False
370
+
371
+ self.init_state()
372
+
373
+ def init_state(self):
374
+ self.complete_constraints = []
375
+ self.inprogress_constraint = None
376
+ self.pending_constraints = [constraint.copy(stateful=False) for constraint in self.constraints]
377
+
378
+ def get_bank(self):
379
+ add = 0
380
+ if self.inprogress_constraint:
381
+ # extra points for having a constraint mid-fulfilled
382
+ add += self.max_seqlen - self.inprogress_constraint.remaining()
383
+
384
+ return (len(self.complete_constraints) * self.max_seqlen) + add
385
+
386
+ def advance(self):
387
+ """The list of tokens to generate such that we can make progress.
388
+ By "list" we don't mean the list of token that will fully fulfill a constraint.
389
+
390
+ Given constraints `c_i = {t_ij | j == # of tokens}`, If we're not in the middle of progressing through a
391
+ specific constraint `c_i`, we return:
392
+
393
+ `[t_k1 for k in indices of unfulfilled constraints]`
394
+
395
+ If we are in the middle of a constraint, then we return:
396
+ `[t_ij]`, where `i` is the index of the inprogress constraint, `j` is the next step for the constraint.
397
+
398
+ Though we don't care which constraint is fulfilled first, if we are in the progress of fulfilling a constraint,
399
+ that's the only one we'll return.
400
+ """
401
+ token_list = []
402
+ if self.inprogress_constraint is None:
403
+ for constraint in self.pending_constraints: # "pending" == "unfulfilled yet"
404
+ advance = constraint.advance()
405
+ if isinstance(advance, int):
406
+ token_list.append(advance)
407
+ elif isinstance(advance, list):
408
+ token_list.extend(advance)
409
+ else:
410
+ advance = self.inprogress_constraint.advance()
411
+ if isinstance(advance, int):
412
+ token_list.append(advance)
413
+ elif isinstance(advance, list):
414
+ token_list.extend(advance)
415
+
416
+ if len(token_list) == 0:
417
+ return None
418
+ else:
419
+ return token_list
420
+
421
+ def reset(self, token_ids: Optional[List[int]]):
422
+ """
423
+ token_ids: the tokens generated thus far to reset the state of the progress through constraints.
424
+ """
425
+ self.init_state()
426
+
427
+ if token_ids is not None:
428
+ for token in token_ids:
429
+ # completes or steps **one** constraint
430
+ complete, stepped = self.add(token)
431
+
432
+ # the entire list of constraints are fulfilled
433
+ if self.completed:
434
+ break
435
+
436
+ def add(self, token_id: int):
437
+ if not isinstance(token_id, int):
438
+ raise TypeError(f"`token_id` should be an `int`, but is `{token_id}`.")
439
+
440
+ complete, stepped = False, False
441
+
442
+ if self.completed:
443
+ complete = True
444
+ stepped = False
445
+ return complete, stepped
446
+
447
+ if self.inprogress_constraint is not None:
448
+ # In the middle of fulfilling a constraint. If the `token_id` *does* makes an incremental progress to current
449
+ # job, simply update the state
450
+
451
+ stepped, complete, reset = self.inprogress_constraint.update(token_id)
452
+ if reset:
453
+ # 1. If the next token breaks the progress, then we must restart.
454
+ # e.g. constraint = "I love pies" and sequence so far is "I love" but `token_id` == "books".
455
+
456
+ # But that doesn't mean we self.init_state(), since we only reset the state for this particular
457
+ # constraint, not the full list of constraints.
458
+
459
+ self.pending_constraints.append(self.inprogress_constraint.copy(stateful=False))
460
+ self.inprogress_constraint = None
461
+
462
+ if complete:
463
+ # 2. If the next token completes the constraint, move it to completed list, set
464
+ # inprogress to None. If there are no pending constraints either, then this full list of constraints
465
+ # is complete.
466
+
467
+ self.complete_constraints.append(self.inprogress_constraint)
468
+ self.inprogress_constraint = None
469
+
470
+ if len(self.pending_constraints) == 0:
471
+ # we're done!
472
+ self.completed = True
473
+
474
+ else:
475
+ # Not in the middle of fulfilling a constraint. So does this `token_id` helps us step towards any of our list
476
+ # of constraints?
477
+
478
+ for cidx, pending_constraint in enumerate(self.pending_constraints):
479
+ if pending_constraint.does_advance(token_id):
480
+ stepped, complete, reset = pending_constraint.update(token_id)
481
+
482
+ if not stepped:
483
+ raise Exception(
484
+ "`constraint.update(token_id)` is not yielding incremental progress, "
485
+ "even though `constraint.does_advance(token_id)` is true."
486
+ )
487
+
488
+ if complete:
489
+ self.complete_constraints.append(pending_constraint)
490
+ self.inprogress_constraint = None
491
+
492
+ if not complete and stepped:
493
+ self.inprogress_constraint = pending_constraint
494
+
495
+ if complete or stepped:
496
+ # If we made any progress at all, then it's at least not a "pending constraint".
497
+
498
+ self.pending_constraints = (
499
+ self.pending_constraints[:cidx] + self.pending_constraints[cidx + 1 :]
500
+ )
501
+
502
+ if len(self.pending_constraints) == 0 and self.inprogress_constraint is None:
503
+ # If there's no longer any pending after this and no inprogress either, then we must be
504
+ # complete.
505
+
506
+ self.completed = True
507
+
508
+ break # prevent accidentally stepping through multiple constraints with just one token.
509
+
510
+ return complete, stepped
511
+
512
+ def copy(self, stateful=True):
513
+ new_state = ConstraintListState(self.constraints) # we actually never though self.constraints objects
514
+ # throughout this process. So it's at initialization state.
515
+
516
+ if stateful:
517
+ new_state.complete_constraints = [
518
+ constraint.copy(stateful=True) for constraint in self.complete_constraints
519
+ ]
520
+ if self.inprogress_constraint is not None:
521
+ new_state.inprogress_constraint = self.inprogress_constraint.copy(stateful=True)
522
+ new_state.pending_constraints = [constraint.copy() for constraint in self.pending_constraints]
523
+
524
+ return new_state
.venv/lib/python3.11/site-packages/transformers/generation/beam_search.py ADDED
@@ -0,0 +1,1013 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2020 The HuggingFace Inc. team
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from abc import ABC, abstractmethod
17
+ from collections import UserDict
18
+ from typing import Dict, List, Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ from ..utils import add_start_docstrings
24
+ from .beam_constraints import Constraint, ConstraintListState
25
+
26
+
27
+ PROCESS_INPUTS_DOCSTRING = r"""
28
+ Args:
29
+ input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
30
+ Indices of input sequence tokens in the vocabulary.
31
+
32
+ Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
33
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
34
+
35
+ [What are input IDs?](../glossary#input-ids)
36
+ next_scores (`torch.FloatTensor` of shape `(batch_size, 2 * num_beams)`):
37
+ Current scores of the top `2 * num_beams` non-finished beam hypotheses.
38
+ next_tokens (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
39
+ `input_ids` of the tokens corresponding to the top `2 * num_beams` non-finished beam hypotheses.
40
+ next_indices (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
41
+ Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
42
+ pad_token_id (`int`, *optional*):
43
+ The id of the *padding* token.
44
+ eos_token_id (`Union[int, List[int]]`, *optional*):
45
+ The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
46
+ beam_indices (`torch.LongTensor`, *optional*):
47
+ Beam indices indicating to which beam hypothesis each token correspond.
48
+ group_index (`int`, *optional*):
49
+ The index of the group of beams. Used with [`~PreTrainedModel.group_beam_search`].
50
+
51
+ Return:
52
+ `UserDict`: A dictionary composed of the fields as defined above:
53
+
54
+ - **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of all
55
+ non-finished beams.
56
+ - **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be added
57
+ to the non-finished beam_hypotheses.
58
+ - **next_beam_indices** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Beam indices
59
+ indicating to which beam the next tokens shall be added.
60
+
61
+ """
62
+
63
+ FINALIZE_INPUTS_DOCSTRING = r"""
64
+ Args:
65
+ input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
66
+ Indices of input sequence tokens in the vocabulary.
67
+
68
+ Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
69
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
70
+
71
+ [What are input IDs?](../glossary#input-ids)
72
+ final_beam_scores (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
73
+ The final scores of all non-finished beams.
74
+ final_beam_tokens (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
75
+ The last tokens to be added to the non-finished beam_hypotheses.
76
+ final_beam_indices (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
77
+ The beam indices indicating to which beam the `final_beam_tokens` shall be added.
78
+ pad_token_id (`int`, *optional*):
79
+ The id of the *padding* token.
80
+ eos_token_id (`Union[int, List[int]]`, *optional*):
81
+ The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
82
+
83
+ Return:
84
+ `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences.
85
+ The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early
86
+ due to the `eos_token_id`.
87
+
88
+ """
89
+
90
+
91
+ class BeamScorer(ABC):
92
+ """
93
+ Abstract base class for all beam scorers that are used for [`~PreTrainedModel.beam_search`] and
94
+ [`~PreTrainedModel.beam_sample`].
95
+ """
96
+
97
+ @abstractmethod
98
+ @add_start_docstrings(PROCESS_INPUTS_DOCSTRING)
99
+ def process(
100
+ self,
101
+ input_ids: torch.LongTensor,
102
+ next_scores: torch.FloatTensor,
103
+ next_tokens: torch.LongTensor,
104
+ next_indices: torch.LongTensor,
105
+ **kwargs,
106
+ ) -> Tuple[torch.Tensor]:
107
+ raise NotImplementedError("This is an abstract method.")
108
+
109
+ @abstractmethod
110
+ @add_start_docstrings(FINALIZE_INPUTS_DOCSTRING)
111
+ def finalize(
112
+ self,
113
+ input_ids: torch.LongTensor,
114
+ next_scores: torch.FloatTensor,
115
+ next_tokens: torch.LongTensor,
116
+ next_indices: torch.LongTensor,
117
+ max_length: int,
118
+ **kwargs,
119
+ ) -> torch.LongTensor:
120
+ raise NotImplementedError("This is an abstract method.")
121
+
122
+
123
+ class BeamSearchScorer(BeamScorer):
124
+ r"""
125
+ [`BeamScorer`] implementing standard beam search decoding.
126
+
127
+ Adapted in part from [Facebook's XLM beam search
128
+ code](https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529).
129
+
130
+ Reference for the diverse beam search algorithm and implementation [Ashwin Kalyan's DBS
131
+ implementation](https://github.com/ashwinkalyan/dbs/blob/master/dbs/beam_utils.lua)
132
+
133
+ Args:
134
+ batch_size (`int`):
135
+ Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
136
+ num_beams (`int`):
137
+ Number of beams for beam search.
138
+ device (`torch.device`):
139
+ Defines the device type (*e.g.*, `"cpu"` or `"cuda"`) on which this instance of `BeamSearchScorer` will be
140
+ allocated.
141
+ length_penalty (`float`, *optional*, defaults to 1.0):
142
+ Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
143
+ the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
144
+ likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
145
+ `length_penalty` < 0.0 encourages shorter sequences.
146
+ do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
147
+ Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
148
+ `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an
149
+ heuristic is applied and the generation stops when is it very unlikely to find better candidates;
150
+ `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
151
+ beam search algorithm).
152
+ num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
153
+ The number of beam hypotheses that shall be returned upon calling
154
+ [`~transformers.BeamSearchScorer.finalize`].
155
+ num_beam_groups (`int`, *optional*, defaults to 1):
156
+ Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
157
+ See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
158
+ max_length (`int`, *optional*):
159
+ The maximum length of the sequence to be generated.
160
+ """
161
+
162
+ def __init__(
163
+ self,
164
+ batch_size: int,
165
+ num_beams: int,
166
+ device: torch.device,
167
+ length_penalty: Optional[float] = 1.0,
168
+ do_early_stopping: Optional[Union[bool, str]] = False,
169
+ num_beam_hyps_to_keep: Optional[int] = 1,
170
+ num_beam_groups: Optional[int] = 1,
171
+ max_length: Optional[int] = None,
172
+ ):
173
+ self.num_beams = num_beams
174
+ self.device = device
175
+ self.length_penalty = length_penalty
176
+ self.do_early_stopping = do_early_stopping
177
+ self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
178
+ self.num_beam_groups = num_beam_groups
179
+ self.group_size = self.num_beams // self.num_beam_groups
180
+
181
+ self._is_init = False
182
+ # self._beam_hyps[i*self.num_beam_groups+j] is the beam_hyps of the j-th group in the i-th mini-batch.
183
+ # If group_beam_search is not used, the list consists of `batch_size` beam_hyps.
184
+ self._beam_hyps = [
185
+ BeamHypotheses(
186
+ num_beams=self.group_size,
187
+ length_penalty=self.length_penalty,
188
+ early_stopping=self.do_early_stopping,
189
+ max_length=max_length,
190
+ )
191
+ for _ in range(batch_size * self.num_beam_groups)
192
+ ]
193
+ # self._done[i*self.num_beam_groups+j] indicates whether the generation of the beam_hyps of the j-th group
194
+ # in the i-th mini-batch is complete.
195
+ self._done = torch.tensor(
196
+ [False for _ in range(batch_size * self.num_beam_groups)], dtype=torch.bool, device=self.device
197
+ )
198
+
199
+ if not isinstance(num_beams, int) or num_beams <= 1:
200
+ raise ValueError(
201
+ f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
202
+ " one should make use of `greedy_search` instead."
203
+ )
204
+
205
+ if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
206
+ raise ValueError(
207
+ "`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
208
+ f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
209
+ )
210
+
211
+ @property
212
+ def is_done(self) -> bool:
213
+ return self._done.all()
214
+
215
+ def process(
216
+ self,
217
+ input_ids: torch.LongTensor,
218
+ next_scores: torch.FloatTensor,
219
+ next_tokens: torch.LongTensor,
220
+ next_indices: torch.LongTensor,
221
+ pad_token_id: Optional[Union[int, torch.Tensor]] = None,
222
+ eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
223
+ beam_indices: Optional[torch.LongTensor] = None,
224
+ group_index: Optional[int] = 0,
225
+ decoder_prompt_len: Optional[int] = 0,
226
+ ) -> Dict[str, torch.Tensor]:
227
+ # add up to the length which the next_scores is calculated on (including decoder prompt)
228
+ cur_len = input_ids.shape[-1] + 1
229
+ batch_size = len(self._beam_hyps) // self.num_beam_groups
230
+
231
+ if not (batch_size == (input_ids.shape[0] // self.group_size)):
232
+ if self.num_beam_groups > 1:
233
+ raise ValueError(
234
+ f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
235
+ f"size of {self.group_size} is expected by the beam scorer."
236
+ )
237
+ else:
238
+ raise ValueError(
239
+ f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
240
+ f"{self.group_size} is expected by the beam scorer."
241
+ )
242
+
243
+ device = input_ids.device
244
+ next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
245
+ next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
246
+ next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
247
+
248
+ if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
249
+ if isinstance(eos_token_id, int):
250
+ eos_token_id = [eos_token_id]
251
+ eos_token_id = torch.tensor(eos_token_id)
252
+
253
+ for batch_idx in range(batch_size):
254
+ batch_group_idx = batch_idx * self.num_beam_groups + group_index
255
+ if self._done[batch_group_idx]:
256
+ if self.num_beams < len(self._beam_hyps[batch_group_idx]):
257
+ raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
258
+ if eos_token_id is None or pad_token_id is None:
259
+ raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
260
+ # pad the batch
261
+ next_beam_scores[batch_idx, :] = 0
262
+ next_beam_tokens[batch_idx, :] = pad_token_id
263
+ next_beam_indices[batch_idx, :] = 0
264
+ continue
265
+
266
+ # next tokens for this sentence
267
+ beam_idx = 0
268
+ for beam_token_rank, (next_token, next_score, next_index) in enumerate(
269
+ zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
270
+ ):
271
+ batch_beam_idx = batch_idx * self.group_size + next_index
272
+ # add to generated hypotheses if end of sentence
273
+ if (eos_token_id is not None) and (next_token.item() in eos_token_id):
274
+ # if beam_token does not belong to top num_beams tokens, it should not be added
275
+ is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
276
+ if is_beam_token_worse_than_top_num_beams:
277
+ continue
278
+ if beam_indices is not None:
279
+ beam_index = beam_indices[batch_beam_idx]
280
+ beam_index = beam_index + (batch_beam_idx,)
281
+ else:
282
+ beam_index = None
283
+
284
+ self._beam_hyps[batch_group_idx].add(
285
+ input_ids[batch_beam_idx].clone(),
286
+ next_score.item(),
287
+ beam_indices=beam_index,
288
+ generated_len=cur_len - decoder_prompt_len,
289
+ )
290
+ else:
291
+ # add next predicted token since it is not eos_token
292
+ next_beam_scores[batch_idx, beam_idx] = next_score
293
+ next_beam_tokens[batch_idx, beam_idx] = next_token
294
+ next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
295
+ beam_idx += 1
296
+
297
+ # once the beam for next step is full, don't add more tokens to it.
298
+ if beam_idx == self.group_size:
299
+ break
300
+
301
+ if beam_idx < self.group_size:
302
+ raise ValueError(
303
+ f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
304
+ f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
305
+ )
306
+
307
+ # Check if we are done so that we can save a pad step if all(done)
308
+ self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
309
+ next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
310
+ )
311
+
312
+ return UserDict(
313
+ {
314
+ "next_beam_scores": next_beam_scores.view(-1),
315
+ "next_beam_tokens": next_beam_tokens.view(-1),
316
+ "next_beam_indices": next_beam_indices.view(-1),
317
+ }
318
+ )
319
+
320
+ def finalize(
321
+ self,
322
+ input_ids: torch.LongTensor,
323
+ final_beam_scores: torch.FloatTensor,
324
+ final_beam_tokens: torch.LongTensor,
325
+ final_beam_indices: torch.LongTensor,
326
+ max_length: int,
327
+ pad_token_id: Optional[Union[int, torch.Tensor]] = None,
328
+ eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
329
+ beam_indices: Optional[torch.LongTensor] = None,
330
+ decoder_prompt_len: Optional[int] = 0,
331
+ ) -> Tuple[torch.LongTensor]:
332
+ batch_size = len(self._beam_hyps) // self.num_beam_groups
333
+
334
+ if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
335
+ if isinstance(eos_token_id, int):
336
+ eos_token_id = [eos_token_id]
337
+ eos_token_id = torch.tensor(eos_token_id)
338
+
339
+ # finalize all open beam hypotheses and add to generated hypotheses
340
+ for batch_group_idx, beam_hyp in enumerate(self._beam_hyps):
341
+ if self._done[batch_group_idx]:
342
+ continue
343
+
344
+ # all open beam hypotheses are added to the beam hypothesis
345
+ # beam hypothesis class automatically keeps the best beams
346
+ for index_per_group in range(self.group_size):
347
+ batch_beam_idx = batch_group_idx * self.group_size + index_per_group
348
+ final_score = final_beam_scores[batch_beam_idx].item()
349
+ final_tokens = input_ids[batch_beam_idx]
350
+ beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
351
+ generated_len = final_tokens.shape[-1] - decoder_prompt_len
352
+ beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
353
+
354
+ # select the best hypotheses
355
+ sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
356
+ best = []
357
+ best_indices = []
358
+ best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
359
+
360
+ # retrieve best hypotheses
361
+ for i in range(batch_size):
362
+ beam_hyps_in_batch = self._beam_hyps[i * self.num_beam_groups : (i + 1) * self.num_beam_groups]
363
+ candidate_beams = [beam for beam_hyp in beam_hyps_in_batch for beam in beam_hyp.beams]
364
+ sorted_hyps = sorted(candidate_beams, key=lambda x: x[0])
365
+ for j in range(self.num_beam_hyps_to_keep):
366
+ best_hyp_tuple = sorted_hyps.pop()
367
+ best_score = best_hyp_tuple[0]
368
+ best_hyp = best_hyp_tuple[1]
369
+ best_index = best_hyp_tuple[2]
370
+ sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
371
+
372
+ # append hyp to lists
373
+ best.append(best_hyp)
374
+
375
+ # append indices to list
376
+ best_indices.append(best_index)
377
+
378
+ best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
379
+
380
+ # prepare for adding eos
381
+ sent_lengths_max = sent_lengths.max().item() + 1
382
+ sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
383
+ decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
384
+
385
+ if len(best_indices) > 0 and best_indices[0] is not None:
386
+ indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
387
+ else:
388
+ indices = None
389
+
390
+ # shorter batches are padded if needed
391
+ if sent_lengths.min().item() != sent_lengths.max().item():
392
+ if pad_token_id is None:
393
+ raise ValueError("`pad_token_id` has to be defined")
394
+ decoded.fill_(pad_token_id)
395
+
396
+ if indices is not None:
397
+ indices.fill_(-1)
398
+
399
+ # fill with hypotheses and eos_token_id if the latter fits in
400
+ for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
401
+ decoded[i, : sent_lengths[i]] = hypo
402
+
403
+ if indices is not None:
404
+ indices[i, : len(best_idx)] = torch.tensor(best_idx)
405
+
406
+ if sent_lengths[i] < sent_max_len:
407
+ # inserting only the first eos_token_id
408
+ decoded[i, sent_lengths[i]] = eos_token_id[0]
409
+
410
+ return UserDict(
411
+ {
412
+ "sequences": decoded,
413
+ "sequence_scores": best_scores,
414
+ "beam_indices": indices,
415
+ }
416
+ )
417
+
418
+
419
+ class ConstrainedBeamSearchScorer(BeamScorer):
420
+ r"""
421
+ [`BeamScorer`] implementing constrained beam search decoding.
422
+
423
+
424
+ Args:
425
+ batch_size (`int`):
426
+ Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
427
+ num_beams (`int`):
428
+ Number of beams for beam search.
429
+ constraints (`List[Constraint]`):
430
+ A list of positive constraints represented as `Constraint` objects that must be fulfilled in the generation
431
+ output. For more information, the documentation of [`Constraint`] should be read.
432
+ device (`torch.device`):
433
+ Defines the device type (*e.g.*, `"cpu"` or `"cuda"`) on which this instance of `BeamSearchScorer` will be
434
+ allocated.
435
+ length_penalty (`float`, *optional*, defaults to 1.0):
436
+ Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
437
+ the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
438
+ likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
439
+ `length_penalty` < 0.0 encourages shorter sequences.
440
+ do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
441
+ Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
442
+ `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an
443
+ heuristic is applied and the generation stops when is it very unlikely to find better candidates;
444
+ `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
445
+ beam search algorithm).
446
+ num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
447
+ The number of beam hypotheses that shall be returned upon calling
448
+ [`~transformers.BeamSearchScorer.finalize`].
449
+ num_beam_groups (`int`, *optional*, defaults to 1):
450
+ Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
451
+ See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
452
+ max_length (`int`, *optional*):
453
+ The maximum length of the sequence to be generated.
454
+ """
455
+
456
+ def __init__(
457
+ self,
458
+ batch_size: int,
459
+ num_beams: int,
460
+ constraints: List[Constraint],
461
+ device: torch.device,
462
+ length_penalty: Optional[float] = 1.0,
463
+ do_early_stopping: Optional[Union[bool, str]] = False,
464
+ num_beam_hyps_to_keep: Optional[int] = 1,
465
+ num_beam_groups: Optional[int] = 1,
466
+ max_length: Optional[int] = None,
467
+ ):
468
+ self.num_beams = num_beams
469
+ self.device = device
470
+ self.length_penalty = length_penalty
471
+ self.do_early_stopping = do_early_stopping
472
+ self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
473
+ self.num_beam_groups = num_beam_groups
474
+ self.group_size = self.num_beams // self.num_beam_groups
475
+ self.constraints = constraints
476
+
477
+ self._is_init = False
478
+ self._beam_hyps = [
479
+ BeamHypotheses(
480
+ num_beams=self.num_beams,
481
+ length_penalty=self.length_penalty,
482
+ early_stopping=self.do_early_stopping,
483
+ max_length=max_length,
484
+ )
485
+ for _ in range(batch_size)
486
+ ]
487
+ self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)
488
+
489
+ if not isinstance(num_beams, int) or num_beams <= 1:
490
+ raise ValueError(
491
+ f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
492
+ " one should make use of `greedy_search` instead."
493
+ )
494
+
495
+ if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
496
+ raise ValueError(
497
+ "`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
498
+ f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
499
+ )
500
+
501
+ @property
502
+ def is_done(self) -> bool:
503
+ return self._done.all()
504
+
505
+ def make_constraint_states(self, n):
506
+ return [ConstraintListState([constraint.copy() for constraint in self.constraints]) for _ in range(n)]
507
+
508
+ def check_completes_constraints(self, sequence):
509
+ new_state = self.make_constraint_states(1)[0]
510
+ new_state.reset(sequence)
511
+ return new_state.completed
512
+
513
+ def process(
514
+ self,
515
+ input_ids: torch.LongTensor,
516
+ next_scores: torch.FloatTensor,
517
+ next_tokens: torch.LongTensor,
518
+ next_indices: torch.LongTensor,
519
+ scores_for_all_vocab: torch.FloatTensor,
520
+ pad_token_id: Optional[Union[int, torch.Tensor]] = None,
521
+ eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
522
+ beam_indices: Optional[torch.LongTensor] = None,
523
+ decoder_prompt_len: Optional[int] = 0,
524
+ ) -> Tuple[torch.Tensor]:
525
+ r"""
526
+ Args:
527
+ input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
528
+ Indices of input sequence tokens in the vocabulary.
529
+
530
+ Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
531
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
532
+
533
+ [What are input IDs?](../glossary#input-ids)
534
+ next_scores (`torch.FloatTensor` of shape `(batch_size, 2 * num_beams)`):
535
+ Current scores of the top `2 * num_beams` non-finished beam hypotheses.
536
+ next_tokens (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
537
+ `input_ids` of the tokens corresponding to the top `2 * num_beams` non-finished beam hypotheses.
538
+ next_indices (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
539
+ Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
540
+ scores_for_all_vocab (`torch.FloatTensor` of shape `(batch_size * num_beams, sequence_length)`):
541
+ The scores of all tokens in the vocabulary for each of the beam hypotheses.
542
+ pad_token_id (`int`, *optional*):
543
+ The id of the *padding* token.
544
+ eos_token_id (`Union[int, List[int]]`, *optional*):
545
+ The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
546
+ beam_indices (`torch.LongTensor`, *optional*):
547
+ Beam indices indicating to which beam hypothesis each token correspond.
548
+ decoder_prompt_len (`int`, *optional*):
549
+ The length of prompt that is included in the input to decoder.
550
+ Return:
551
+ `UserDict`: A dictionary composed of the fields as defined above:
552
+
553
+ - **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of
554
+ all
555
+ non-finished beams.
556
+
557
+ - **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be
558
+ added
559
+ to the non-finished beam_hypotheses.
560
+ - **next_beam_indices** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Beam indices
561
+ indicating to which beam the next tokens shall be added.
562
+ """
563
+
564
+ # add up to the length which the next_scores is calculated on (including decoder prompt)
565
+ cur_len = input_ids.shape[-1] + 1
566
+ batch_size = len(self._beam_hyps)
567
+ if not (batch_size == (input_ids.shape[0] // self.group_size)):
568
+ if self.num_beam_groups > 1:
569
+ raise ValueError(
570
+ f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
571
+ f"size of {self.group_size} is expected by the beam scorer."
572
+ )
573
+ else:
574
+ raise ValueError(
575
+ f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
576
+ f"{self.group_size} is expected by the beam scorer."
577
+ )
578
+
579
+ device = input_ids.device
580
+
581
+ next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
582
+ next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
583
+ next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
584
+
585
+ if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
586
+ if isinstance(eos_token_id, int):
587
+ eos_token_id = [eos_token_id]
588
+ eos_token_id = torch.tensor(eos_token_id)
589
+
590
+ for batch_idx, beam_hyp in enumerate(self._beam_hyps):
591
+ if self._done[batch_idx]:
592
+ if self.num_beams < len(beam_hyp):
593
+ raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
594
+ if eos_token_id is None or pad_token_id is None:
595
+ raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
596
+ # pad the batch
597
+ next_beam_scores[batch_idx, :] = 0
598
+ next_beam_tokens[batch_idx, :] = pad_token_id
599
+ next_beam_indices[batch_idx, :] = 0
600
+ continue
601
+
602
+ # next tokens for this sentence.
603
+ beam_idx = 0
604
+ for beam_token_rank, (next_token, next_score, next_index) in enumerate(
605
+ zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
606
+ ):
607
+ batch_beam_idx = batch_idx * self.group_size + next_index
608
+ # add to generated hypotheses if end of sentence
609
+ if (eos_token_id is not None) and (next_token.item() in eos_token_id):
610
+ # if beam_token does not belong to top num_beams tokens, it should not be added
611
+ is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
612
+ if is_beam_token_worse_than_top_num_beams:
613
+ continue
614
+
615
+ completes_constraint = self.check_completes_constraints(input_ids[batch_beam_idx].cpu().tolist())
616
+ if completes_constraint:
617
+ if beam_indices is not None:
618
+ beam_index = beam_indices[batch_beam_idx]
619
+ beam_index = beam_index + (batch_beam_idx,)
620
+ else:
621
+ beam_index = None
622
+
623
+ beam_hyp.add(
624
+ input_ids[batch_beam_idx].clone(),
625
+ next_score.item(),
626
+ beam_indices=beam_index,
627
+ generated_len=cur_len - decoder_prompt_len,
628
+ )
629
+ else:
630
+ # add next predicted token since it is not eos_token
631
+ next_beam_scores[batch_idx, beam_idx] = next_score
632
+ next_beam_tokens[batch_idx, beam_idx] = next_token
633
+ next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
634
+ beam_idx += 1
635
+
636
+ # once the beam for next step is full, don't add more tokens to it.
637
+ if beam_idx == self.group_size:
638
+ break
639
+
640
+ new_scores, new_tokens, new_indices = self.step_sentence_constraint(
641
+ batch_idx,
642
+ input_ids,
643
+ scores_for_all_vocab,
644
+ next_beam_scores[batch_idx],
645
+ next_beam_tokens[batch_idx],
646
+ next_beam_indices[batch_idx],
647
+ )
648
+
649
+ next_beam_scores[batch_idx] = new_scores
650
+ next_beam_tokens[batch_idx] = new_tokens
651
+ next_beam_indices[batch_idx] = new_indices
652
+
653
+ if beam_idx < self.group_size:
654
+ raise ValueError(
655
+ f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
656
+ f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
657
+ )
658
+
659
+ # Check if we are done so that we can save a pad step if all(done)
660
+ self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
661
+ next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
662
+ )
663
+
664
+ return UserDict(
665
+ {
666
+ "next_beam_scores": next_beam_scores.view(-1),
667
+ "next_beam_tokens": next_beam_tokens.view(-1),
668
+ "next_beam_indices": next_beam_indices.view(-1),
669
+ }
670
+ )
671
+
672
+ def step_sentence_constraint(
673
+ self,
674
+ batch_idx: int,
675
+ input_ids: torch.LongTensor,
676
+ vocab_scores: torch.FloatTensor,
677
+ sent_beam_scores: torch.FloatTensor,
678
+ sent_beam_tokens: torch.LongTensor,
679
+ sent_beam_indices: torch.LongTensor,
680
+ push_progress: bool = False,
681
+ ):
682
+ # sent_beam_tokens are the next {num_beams} number of tokens that are under consideration for this beam
683
+ # (candidate next tokens)
684
+
685
+ # 1. Adding "advance_tokens"
686
+ # using ConstraintStateList.advance(), we propose new tokens to be added into this "candidate list" that will
687
+ # advance us in fulfilling the constraints.
688
+
689
+ # 2. Selecting best candidates such that we end up with highest probable candidates
690
+ # that fulfill our constraints.
691
+
692
+ orig_len = sent_beam_indices.size(0)
693
+ device = sent_beam_indices.device
694
+
695
+ # initialize states
696
+ topk_contraint_states = self.make_constraint_states(orig_len)
697
+ advance_constraint_states = self.make_constraint_states(orig_len)
698
+
699
+ sidx, eidx = batch_idx * orig_len, (batch_idx + 1) * orig_len
700
+ this_batch_input_ids = input_ids[sidx:eidx]
701
+ this_batch_token_scores = vocab_scores[sidx:eidx]
702
+ full_hypotheses = torch.cat((input_ids[sent_beam_indices], sent_beam_tokens.unsqueeze(-1)), dim=-1)
703
+
704
+ # need to make new hypothesis that advance the constraints
705
+ track_new = {
706
+ "new_seqs": full_hypotheses.tolist(),
707
+ "new_states": [],
708
+ "new_indices": [],
709
+ "new_tokens": [],
710
+ "new_scores": [],
711
+ }
712
+ for seq_idx, pre_seq in enumerate(this_batch_input_ids):
713
+ # pre_seq = ith sequence generated before this step.
714
+
715
+ # input_ids -> (topk) generic beam search best model next tokens
716
+ # -> (advance) constraints forcing the next token
717
+ # either way, we need to sort them into "banks" later, so store a "ConstraintListState" for all types of
718
+ # hypotheses.
719
+
720
+ topk_state = topk_contraint_states[seq_idx]
721
+ topk_state.reset(full_hypotheses[seq_idx].cpu().tolist())
722
+
723
+ advance_state = advance_constraint_states[seq_idx]
724
+ advance_state.reset(pre_seq.cpu().tolist())
725
+
726
+ if not advance_state.completed:
727
+ advance_tokens = torch.LongTensor(advance_state.advance()).to(device)
728
+ for advance_token in advance_tokens:
729
+ # since adding each `advance_token` leads to a different hypothesis, create new state instance.
730
+ new_state = advance_state.copy(stateful=True)
731
+ new_state.add(advance_token.cpu().tolist())
732
+
733
+ advance_seq = torch.cat((pre_seq, advance_token.unsqueeze(0)), -1).cpu().tolist()
734
+ if advance_seq not in track_new["new_seqs"]:
735
+ # prevent duplicates, which are basically bound to happen in this process.
736
+ track_new["new_seqs"].append(advance_seq)
737
+ track_new["new_indices"].append(sidx + seq_idx) # idx -> global idx across all the batches
738
+ track_new["new_tokens"].append(advance_token)
739
+ track_new["new_scores"].append(this_batch_token_scores[seq_idx].take(advance_token))
740
+ track_new["new_states"].append(new_state)
741
+ elif push_progress:
742
+ # Basically, `sent_beam_indices` often chooses very little among `input_ids` the generated sequences that
743
+ # actually fulfill our constraints. For example, let constraints == ["loves pies"] and
744
+
745
+ # pre_seq_1 = "The child loves pies and" pre_seq_2 = "The child plays in the playground and"
746
+
747
+ # Without this step, if `sent_beam_indices` is something like [1,1], then
748
+ # 1. `pre_seq_1` won't be added to the list of (topk) hypothesis since it's not in the indices and
749
+ # 2. it won't be added to the list of (advance) hypothesis since it's completed already. (this is
750
+ # the else part of `if constraints_completed[seq_idx]`)
751
+ # 3. it ends up simply getting removed from consideration.
752
+
753
+ # #3 might be fine and actually desired, since it's likely that it's a low-probability output anyways,
754
+ # especially if it's not in the list of `sent_beam_indices`. But this often leads to lengthened beam
755
+ # search times, since completed sequences keep getting removed after all this effort for constrained
756
+ # generation.
757
+
758
+ # Here, we basically take `pre_seq_1` and to "push" it into the considered list of hypotheses, by simply
759
+ # appending the next likely token in the vocabulary and adding it to the list of hypotheses.
760
+
761
+ new_score, new_token = torch.max(this_batch_token_scores[seq_idx], 0) # some next probable token
762
+ advance_seq = torch.cat((pre_seq, new_token.unsqueeze(0)), -1)
763
+
764
+ advance_state = advance_constraint_states[seq_idx]
765
+
766
+ advance_seq = advance_seq.cpu().tolist()
767
+
768
+ advance_state.reset(advance_seq)
769
+ if advance_seq not in track_new["new_seqs"]:
770
+ # but still don't want to have duplicates
771
+ track_new["new_seqs"].append(advance_seq)
772
+ track_new["new_indices"].append(seq_idx)
773
+ track_new["new_tokens"].append(new_token)
774
+ track_new["new_scores"].append(new_score)
775
+ track_new["new_states"].append(advance_state)
776
+
777
+ if len(track_new["new_indices"]) > 0:
778
+ new_indices = torch.tensor(track_new["new_indices"]).to(device)
779
+ new_tokens = torch.stack(track_new["new_tokens"]).to(device)
780
+ new_scores = torch.stack(track_new["new_scores"]).to(device)
781
+
782
+ all_states = topk_contraint_states + track_new["new_states"]
783
+ all_tokens = torch.cat((sent_beam_tokens, new_tokens), -1)
784
+ all_scores = torch.cat((sent_beam_scores, new_scores), -1)
785
+ all_banks = torch.tensor([one.get_bank() for one in all_states]).to(device)
786
+
787
+ zipped = all_banks * 100 + all_scores
788
+ indices = zipped.sort(descending=True).indices
789
+ sorted_banks = all_banks[indices]
790
+
791
+ # Then we end up with {sorted among bank C}, {sorted among bank C-1}, ..., {sorted among bank 0}
792
+
793
+ counter = -1
794
+ cur_bank = sorted_banks[0]
795
+ increments = []
796
+ for bank in sorted_banks:
797
+ if bank == cur_bank:
798
+ counter += 1
799
+ else:
800
+ counter = 0
801
+ cur_bank = bank
802
+ increments.append(counter)
803
+ rearrangers = torch.tensor(np.argsort(increments, kind="mergesort"))
804
+
805
+ indices = indices[rearrangers][:orig_len]
806
+
807
+ sent_beam_scores = all_scores[indices]
808
+ sent_beam_tokens = all_tokens[indices]
809
+ sent_beam_indices = torch.cat((sent_beam_indices, new_indices))[indices]
810
+
811
+ return sent_beam_scores, sent_beam_tokens, sent_beam_indices
812
+
813
+ def finalize(
814
+ self,
815
+ input_ids: torch.LongTensor,
816
+ final_beam_scores: torch.FloatTensor,
817
+ final_beam_tokens: torch.LongTensor,
818
+ final_beam_indices: torch.LongTensor,
819
+ max_length: int,
820
+ pad_token_id: Optional[Union[int, torch.Tensor]] = None,
821
+ eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
822
+ beam_indices: Optional[torch.LongTensor] = None,
823
+ decoder_prompt_len: Optional[int] = 0,
824
+ ) -> Tuple[torch.LongTensor]:
825
+ batch_size = len(self._beam_hyps)
826
+
827
+ if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
828
+ if isinstance(eos_token_id, int):
829
+ eos_token_id = [eos_token_id]
830
+ eos_token_id = torch.tensor(eos_token_id)
831
+
832
+ # finalize all open beam hypotheses and add to generated hypotheses
833
+ for batch_idx, beam_hyp in enumerate(self._beam_hyps):
834
+ if self._done[batch_idx]:
835
+ continue
836
+
837
+ # all open beam hypotheses are added to the beam hypothesis
838
+ # beam hypothesis class automatically keeps the best beams
839
+
840
+ ids_collect = []
841
+ for beam_id in range(self.num_beams):
842
+ batch_beam_idx = batch_idx * self.num_beams + beam_id
843
+ final_score = final_beam_scores[batch_beam_idx].item()
844
+ final_tokens = input_ids[batch_beam_idx]
845
+
846
+ completes_constraint = self.check_completes_constraints(final_tokens.cpu().tolist())
847
+ if completes_constraint:
848
+ beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
849
+ generated_len = final_tokens.shape[-1] - decoder_prompt_len
850
+ beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
851
+ ids_collect.append(beam_id)
852
+
853
+ # due to overly complex constraints or other factors, sometimes we can't gaurantee a successful
854
+ # generation. In these cases we simply return the highest scoring outputs.
855
+ if len(ids_collect) < self.num_beam_hyps_to_keep:
856
+ for beam_id in range(self.num_beams):
857
+ if beam_id not in ids_collect:
858
+ batch_beam_idx = batch_idx * self.num_beams + beam_id
859
+ final_score = final_beam_scores[batch_beam_idx].item()
860
+ final_tokens = input_ids[batch_beam_idx]
861
+ generated_len = final_tokens.shape[-1] - decoder_prompt_len
862
+ beam_hyp.add(final_tokens, final_score, generated_len=generated_len)
863
+ if len(ids_collect) >= self.num_beam_hyps_to_keep:
864
+ break
865
+
866
+ # select the best hypotheses
867
+ sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
868
+ best = []
869
+ best_indices = []
870
+ best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
871
+
872
+ # retrieve best hypotheses
873
+ for i, beam_hyp in enumerate(self._beam_hyps):
874
+ sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
875
+ for j in range(self.num_beam_hyps_to_keep):
876
+ best_hyp_tuple = sorted_hyps.pop()
877
+ best_score = best_hyp_tuple[0]
878
+ best_hyp = best_hyp_tuple[1]
879
+ best_index = best_hyp_tuple[2]
880
+ sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
881
+
882
+ # append to lists
883
+ best.append(best_hyp)
884
+
885
+ # append indices to list
886
+ best_indices.append(best_index)
887
+
888
+ best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
889
+
890
+ # prepare for adding eos
891
+ sent_lengths_max = sent_lengths.max().item() + 1
892
+
893
+ sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
894
+ decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
895
+
896
+ if len(best_indices) > 0 and best_indices[0] is not None:
897
+ indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
898
+ else:
899
+ indices = None
900
+
901
+ # shorter batches are padded if needed
902
+ if sent_lengths.min().item() != sent_lengths.max().item():
903
+ if pad_token_id is None:
904
+ raise ValueError("`pad_token_id` has to be defined")
905
+ decoded.fill_(pad_token_id)
906
+
907
+ if indices is not None:
908
+ indices.fill_(-1)
909
+
910
+ # fill with hypotheses and eos_token_id if the latter fits in
911
+ for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
912
+ decoded[i, : sent_lengths[i]] = hypo
913
+
914
+ if indices is not None:
915
+ indices[i, : len(best_idx)] = torch.tensor(best_idx)
916
+
917
+ if sent_lengths[i] < sent_max_len:
918
+ # inserting only the first eos_token_id
919
+ decoded[i, sent_lengths[i]] = eos_token_id[0]
920
+
921
+ return UserDict(
922
+ {
923
+ "sequences": decoded,
924
+ "sequence_scores": best_scores,
925
+ "beam_indices": indices,
926
+ }
927
+ )
928
+
929
+
930
+ class BeamHypotheses:
931
+ def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool, max_length: Optional[int] = None):
932
+ """
933
+ Initialize n-best list of hypotheses.
934
+ """
935
+ self.length_penalty = length_penalty
936
+ self.early_stopping = early_stopping
937
+ self.max_length = max_length
938
+ self.num_beams = num_beams
939
+ self.beams = []
940
+ self.worst_score = 1e9
941
+
942
+ if not isinstance(self.early_stopping, bool) and self.max_length is None:
943
+ raise ValueError(
944
+ "When `do_early_stopping` is set to a string, `max_length` must be defined. Ensure it is passed to the"
945
+ " BeamScorer class instance at initialization time."
946
+ )
947
+
948
+ def __len__(self):
949
+ """
950
+ Number of hypotheses in the list.
951
+ """
952
+ return len(self.beams)
953
+
954
+ def add(
955
+ self,
956
+ hyp: torch.LongTensor,
957
+ sum_logprobs: float,
958
+ beam_indices: Optional[torch.LongTensor] = None,
959
+ generated_len: Optional[int] = None,
960
+ ):
961
+ """
962
+ Add a new hypothesis to the list.
963
+ """
964
+ if generated_len is not None:
965
+ score = sum_logprobs / (generated_len**self.length_penalty)
966
+ # This 'else' case exists for retrocompatibility
967
+ else:
968
+ score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
969
+
970
+ if len(self) < self.num_beams or score > self.worst_score:
971
+ self.beams.append((score, hyp, beam_indices))
972
+ if len(self) > self.num_beams:
973
+ sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
974
+ del self.beams[sorted_next_scores[0][1]]
975
+ self.worst_score = sorted_next_scores[1][0]
976
+ else:
977
+ self.worst_score = min(score, self.worst_score)
978
+
979
+ def is_done(self, best_sum_logprobs: float, cur_len: int, decoder_prompt_len: Optional[int] = 0) -> bool:
980
+ """
981
+ If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst
982
+ one in the heap, then we are done with this sentence.
983
+ """
984
+
985
+ if len(self) < self.num_beams:
986
+ return False
987
+
988
+ # `True`: stop as soon as at least `num_beams` hypotheses are finished
989
+ if self.early_stopping is True:
990
+ return True
991
+ # `False`: heuristic -- compute best possible score from `cur_len`, even though it is not entirely accurate
992
+ # when `length_penalty` is positive. See the discussion below for more details.
993
+ # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
994
+ elif self.early_stopping is False:
995
+ highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
996
+ ret = self.worst_score >= highest_attainable_score
997
+ return ret
998
+ # `"never"`: compute the best possible score, depending on the signal of `length_penalty`
999
+ else:
1000
+ # `length_penalty` > 0.0 -> max denominator is obtaned from `max_length`, not from `cur_len` -> min
1001
+ # abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain
1002
+ # its max this way
1003
+ if self.length_penalty > 0.0:
1004
+ if self.max_length <= decoder_prompt_len:
1005
+ raise ValueError("max_length is not larger than decoder prompt length")
1006
+ highest_attainable_score = (
1007
+ best_sum_logprobs / (self.max_length - decoder_prompt_len) ** self.length_penalty
1008
+ )
1009
+ # the opposite logic applies here (max `highest_attainable_score` from `cur_len`)
1010
+ else:
1011
+ highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
1012
+ ret = self.worst_score >= highest_attainable_score
1013
+ return ret
.venv/lib/python3.11/site-packages/transformers/generation/candidate_generator.py ADDED
@@ -0,0 +1,871 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import copy
17
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from ..utils import is_sklearn_available
23
+
24
+
25
+ if is_sklearn_available():
26
+ from sklearn.metrics import roc_curve
27
+
28
+ from ..cache_utils import DynamicCache
29
+ from ..pytorch_utils import isin_mps_friendly
30
+ from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor
31
+
32
+
33
+ if TYPE_CHECKING:
34
+ from ..modeling_utils import PreTrainedModel
35
+ from ..tokenization_utils_base import PreTrainedTokenizerBase
36
+ from .configuration_utils import GenerationConfig
37
+
38
+
39
+ class CandidateGenerator:
40
+ """Abstract base class for all candidate generators that can be applied during assisted generation."""
41
+
42
+ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
43
+ """
44
+ Fetches the candidates to be tried for the current input.
45
+
46
+ Args:
47
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
48
+ Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
49
+
50
+ Return:
51
+ `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
52
+ assessed by the model and, optionally, a `torch.FloatTensor` of shape `(batch_size, candidate_length,
53
+ vocabulary_size)` containing the logits associated to each candidate.
54
+ """
55
+ raise NotImplementedError(
56
+ f"{self.__class__} is an abstract class. Only classes inheriting this class can call `get_candidates`."
57
+ )
58
+
59
+ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
60
+ """
61
+ Updates the candidate generation strategy based on the outcomes.
62
+
63
+ Args:
64
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
65
+ Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
66
+ scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
67
+ Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
68
+ beam search or log softmax for each vocabulary token when using beam search
69
+ num_matches (`int`):
70
+ The number of matches between the candidate sequences and the model predictions.
71
+ """
72
+ raise NotImplementedError(
73
+ f"{self.__class__} is an abstract class. Only classes inheriting this class can call "
74
+ "`update_candidate_strategy`."
75
+ )
76
+
77
+
78
+ class AssistedCandidateGenerator(CandidateGenerator):
79
+ """
80
+ `CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates
81
+ candidates through the use of a smaller model. Read the following blog post for more information:
82
+ https://huggingface.co/blog/assisted-generation
83
+
84
+ Args:
85
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
86
+ Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
87
+ assistant_model (`PreTrainedModel`):
88
+ The model to be used for generating candidates. This model should be smaller than the main model.
89
+ generation_config (`~generation.GenerationConfig`, *optional*):
90
+ The generation configuration to be used as base parametrization for the generation call.
91
+ logits_processor (`LogitsProcessorList`):
92
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
93
+ used to modify the prediction scores of the language modeling head applied at each generation step.
94
+ model_kwargs (`Dict`):
95
+ The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant
96
+ model as well.
97
+ inputs_tensor (`torch.Tensor`, *optional*):
98
+ The model input tensor. In encoder-decoder models, this is the encoder input.
99
+ """
100
+
101
+ def __init__(
102
+ self,
103
+ input_ids: torch.LongTensor,
104
+ assistant_model: "PreTrainedModel",
105
+ generation_config: "GenerationConfig",
106
+ model_kwargs: Dict,
107
+ inputs_tensor: Optional[torch.Tensor] = None,
108
+ logits_processor: "LogitsProcessorList" = None,
109
+ ):
110
+ # Make sure all data at the same device as assistant model
111
+ device = assistant_model.device
112
+ input_ids = input_ids.to(device)
113
+ if inputs_tensor is not None:
114
+ inputs_tensor = inputs_tensor.to(device)
115
+
116
+ # Prepare the assistant and the starting number of candidate tokens
117
+ self.assistant_model = assistant_model
118
+ self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
119
+ self.assistant_confidence_threshold = assistant_model.generation_config.assistant_confidence_threshold
120
+
121
+ # Set eos in assistant same as in target model
122
+ self.assistant_model.generation_config.eos_token_id = generation_config.eos_token_id
123
+
124
+ # Prepare the kwargs for the assistant model
125
+ assistant_kwargs = {}
126
+ for key, value in model_kwargs.items(): # deepcopy crashes if we attempt to copy encoder outputs with grads
127
+ if key not in ("encoder_outputs", "assistant_encoder_outputs", "past_key_values"):
128
+ assistant_kwargs[key] = (
129
+ value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
130
+ )
131
+
132
+ # Remove potential default "num_logits_to_keep" key
133
+ if "num_logits_to_keep" in assistant_kwargs.keys() and not assistant_model._supports_num_logits_to_keep():
134
+ del assistant_kwargs["num_logits_to_keep"]
135
+
136
+ if "assistant_encoder_outputs" in model_kwargs:
137
+ assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
138
+ elif assistant_model.config.is_encoder_decoder:
139
+ inputs_tensor, model_input_name, assistant_kwargs = assistant_model._prepare_model_inputs(
140
+ inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_kwargs
141
+ )
142
+ assistant_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation(
143
+ inputs_tensor, assistant_kwargs, model_input_name, assistant_model.generation_config
144
+ )
145
+ elif "encoder_outputs" in model_kwargs:
146
+ assistant_kwargs["encoder_outputs"] = model_kwargs["encoder_outputs"]
147
+ self.assistant_kwargs = assistant_kwargs
148
+
149
+ # Prepare assistant model's keys of inputs
150
+ if assistant_model.config.is_encoder_decoder:
151
+ # both are encoder-decoder
152
+ self.input_ids_key = "decoder_input_ids"
153
+ elif "encoder_outputs" in assistant_kwargs:
154
+ # special case for encoder-decoder with decoder-only assistant (like DistilWhisper)
155
+ self.input_ids_key = "input_ids"
156
+ self.assistant_kwargs["attention_mask"] = self.assistant_kwargs.get(
157
+ "decoder_attention_mask",
158
+ torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long),
159
+ )
160
+ else:
161
+ # both are decoder-only
162
+ self.input_ids_key = "input_ids"
163
+
164
+ # Prepare generation-related options.
165
+ self.logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
166
+ self.generation_config = copy.deepcopy(generation_config)
167
+
168
+ self.generation_config.return_dict_in_generate = True
169
+ self.generation_config.output_scores = True
170
+ self.generation_config.assistant_confidence_threshold = self.assistant_confidence_threshold
171
+ # this flag allow us set the confidence stopping criteria for assistant model generation.
172
+ self.generation_config.is_assistant = True
173
+
174
+ # avoid unnecessary warnings that min_length is larger than max_new_tokens
175
+ # remove the `MinLengthLogitsProcessor` if exists (NOTE: no need to check for `MinNewTokensLogitsProcessor`)
176
+ self.main_model_min_length = self.generation_config.min_length
177
+ self.generation_config.min_length = 0
178
+ self.generation_config.min_new_tokens = None
179
+ for processor in self.logits_processor:
180
+ if isinstance(processor, MinLengthLogitsProcessor):
181
+ raise ValueError(
182
+ "Passing `MinLengthLogitsProcessor` when using `assisted_generation is disabled. "
183
+ "Please pass in `min_length` into `.generate()` instead"
184
+ )
185
+
186
+ # We need to roll back the cache in assisted generation, only DynamicCache is supported
187
+ self.generation_config.cache_implementation = None
188
+
189
+ if (
190
+ is_sklearn_available()
191
+ and self.assistant_model.generation_config.assistant_confidence_threshold
192
+ and type(self) is AssistedCandidateGenerator
193
+ ):
194
+ self.probs = []
195
+ self.matches = []
196
+
197
+ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
198
+ """
199
+ Fetches the candidates to be tried for the current input.
200
+
201
+ Args:
202
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
203
+ Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
204
+
205
+ Return:
206
+ `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
207
+ assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
208
+ vocabulary_size)` containing the logits associated to each candidate.
209
+ """
210
+ input_ids = input_ids.to(self.assistant_model.device)
211
+ # Calculate new tokens to generate
212
+ min_new_tokens, max_new_tokens = self._calculate_new_tokens(input_ids)
213
+ if max_new_tokens == 0:
214
+ return input_ids, None
215
+ # Update past key values and masks
216
+ self._update_past_and_masks(input_ids)
217
+ # Generate candidates
218
+ generation_args = self._prepare_generation_args(input_ids, min_new_tokens, max_new_tokens)
219
+ candidate_ids, candidate_logits = self._generate_candidates(generation_args)
220
+ return candidate_ids, candidate_logits
221
+
222
+ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
223
+ """
224
+ Updates the candidate generation strategy based on the outcomes.
225
+
226
+ Args:
227
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
228
+ Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
229
+ scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
230
+ Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
231
+ beam search or log softmax for each vocabulary token when using beam search
232
+ num_matches (`int`):
233
+ The number of matches between the candidate sequences and the model predictions.
234
+ """
235
+ # Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
236
+ # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
237
+ # cost of forecasting incorrect assistant tokens.
238
+ if self.assistant_model.generation_config.num_assistant_tokens_schedule in {
239
+ "heuristic",
240
+ "heuristic_transient",
241
+ }:
242
+ # len(scores[0])-1 is the number of candidates according to the target tokenizer.
243
+ if num_matches == len(scores[0]) - 1:
244
+ self.num_assistant_tokens += 2.0
245
+ else:
246
+ self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)
247
+
248
+ # The assistant's confidence threshold is adjusted throughout the speculative iterations to reduce the number of unnecessary draft and target forward passes. The costs are estimated based on the ROC curve, which considers the probability of the draft token and its match with the target. A cost of 25% is assigned to false positives and 75% to false negatives.
249
+ # This adaptation is not compatible with UAG, as it relies on the number of matched tokens based on the draft vocabulary, which is unavailable in UAG.
250
+ if (
251
+ is_sklearn_available()
252
+ and self.assistant_model.generation_config.assistant_confidence_threshold
253
+ and type(self) is AssistedCandidateGenerator
254
+ ):
255
+ # update self.matches
256
+ self.matches.extend([1] * num_matches)
257
+ if len(self.probs) > len(self.matches):
258
+ self.matches.append(0)
259
+
260
+ # update self.probs
261
+ excess_length = len(self.probs) - len(self.matches)
262
+ if excess_length > 0:
263
+ del self.probs[-excess_length:]
264
+
265
+ if (
266
+ len(self.probs) > 5 and {0, 1}.issubset(self.matches)
267
+ ): # require at least 5 samples to calculate the ROC curve and at least one positive and one negative sample
268
+ fpr, tpr, thresholds = roc_curve(self.matches, self.probs)
269
+ fnr = 1 - tpr
270
+
271
+ # Calculate the cost for each threshold
272
+ costs = fpr + 3 * fnr
273
+
274
+ # Find the threshold that minimizes the cost
275
+ optimal_threshold_index = np.argmin(costs)
276
+ best_threshold = thresholds[optimal_threshold_index]
277
+
278
+ self.assistant_model.generation_config.assistant_confidence_threshold = best_threshold
279
+
280
+ def _calculate_new_tokens(self, input_ids: torch.LongTensor) -> Tuple[int, int]:
281
+ """Calculate the minimum and maximum number of new tokens to generate."""
282
+ new_cur_len = input_ids.shape[-1]
283
+ max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
284
+ min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
285
+ return min_new_tokens, max_new_tokens
286
+
287
+ def _update_past_and_masks(self, input_ids: torch.LongTensor, remove_from_pkv: int = 0) -> bool:
288
+ """Update past key values and attention masks for subsequent generation rounds."""
289
+ has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
290
+ if has_past_key_values:
291
+ new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv
292
+ self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
293
+ self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
294
+ )
295
+ self.assistant_kwargs = _prepare_attention_mask(
296
+ self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
297
+ )
298
+ self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, input_ids.shape[-1])
299
+ return has_past_key_values
300
+
301
+ def _prepare_generation_args(self, input_ids: torch.LongTensor, min_new_tokens: int, max_new_tokens: int) -> Dict:
302
+ """Prepare arguments for the generation call."""
303
+ return {
304
+ self.input_ids_key: input_ids,
305
+ "min_new_tokens": min_new_tokens,
306
+ "max_new_tokens": max_new_tokens,
307
+ "generation_config": self.generation_config,
308
+ "logits_processor": self.logits_processor,
309
+ }
310
+
311
+ def _generate_candidates(self, generation_args: Dict) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
312
+ """Generate candidate sequences using the assistant model."""
313
+ assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
314
+ self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
315
+ if (
316
+ is_sklearn_available()
317
+ and self.assistant_model.generation_config.assistant_confidence_threshold
318
+ and type(self) is AssistedCandidateGenerator
319
+ ):
320
+ scores_tensor = torch.cat(assistant_output.scores, dim=0)
321
+ scores_softmax = torch.softmax(scores_tensor, dim=-1)
322
+ ids = assistant_output.sequences[-1, -len(assistant_output.scores) :]
323
+ p = scores_softmax[range(len(ids)), ids]
324
+ self.probs.extend(p.tolist())
325
+ candidate_logits = torch.stack(assistant_output.scores, dim=1)
326
+ candidate_ids = assistant_output.sequences
327
+ return candidate_ids, candidate_logits
328
+
329
+
330
+ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
331
+ """
332
+ `CandidateGenerator` class to be used for Universal Assisted Generation (UAD): assisted generation with different tokenizers
333
+ for the assistant and main models. This class generates candidates through the use of a smaller
334
+ model.
335
+
336
+ The main model input tokens are re-encoded into assistant model tokens, then candidate tokens are generated in the assistant encoding, which are
337
+ in turn re-encoded into main model candidate tokens. Validation then proceeds as explained above.
338
+ The re-encoding steps involve decoding token ids into text and then encoding the text using a different tokenizer.
339
+ Since re-encoding the tokens may result in tokenization discrepancies, UAD finds the longest common subsequence between the source and target encodings,
340
+ to ensure the new tokens include the correct prompt suffix.
341
+
342
+ Args:
343
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
344
+ Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
345
+ assistant_model (`PreTrainedModel`):
346
+ The model to be used for generating candidates. This model should be smaller than the main model.
347
+ target_tokenizer (`PreTrainedTokenizerBase`):
348
+ The tokenizer used for the target model.
349
+ assistant_tokenizer (`PreTrainedTokenizerBase`):
350
+ The tokenizer used for the assistant model.
351
+ generation_config (`~generation.GenerationConfig`, *optional*):
352
+ The generation configuration to be used as base parametrization for the generation call.
353
+ logits_processor (`LogitsProcessorList`):
354
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
355
+ used to modify the prediction scores of the language modeling head applied at each generation step.
356
+ model_kwargs (`Dict`):
357
+ The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant
358
+ model as well.
359
+ inputs_tensor (`torch.Tensor`, *optional*):
360
+ The model input tensor. In encoder-decoder models, this is the encoder input.
361
+ """
362
+
363
+ def __init__(
364
+ self,
365
+ input_ids: torch.LongTensor,
366
+ assistant_model: "PreTrainedModel",
367
+ target_tokenizer: "PreTrainedTokenizerBase",
368
+ assistant_tokenizer: "PreTrainedTokenizerBase",
369
+ generation_config: "GenerationConfig",
370
+ model_kwargs: Dict,
371
+ inputs_tensor: Optional[torch.Tensor] = None,
372
+ logits_processor: "LogitsProcessorList" = None,
373
+ ):
374
+ super().__init__(input_ids, assistant_model, generation_config, model_kwargs, inputs_tensor, logits_processor)
375
+
376
+ self.target_tokenizer = target_tokenizer
377
+ self.assistant_tokenizer = assistant_tokenizer
378
+ self.prev_target_ids_len: Optional[int] = None
379
+ self.prev_assistant_ids = None
380
+ self.target_lookbehind = assistant_model.generation_config.target_lookbehind
381
+ self.assistant_lookbehind = assistant_model.generation_config.assistant_lookbehind
382
+
383
+ @staticmethod
384
+ def _get_longest_diag_dict(input_matrix, nonzero_idx):
385
+ """
386
+ Calculates the length of the longest diagonal sequence in a given matrix.
387
+ Args:
388
+ input_matrix (torch.Tensor): The input matrix.
389
+ nonzero_idx (torch.Tensor): The indices of the non-zero elements in the matrix.
390
+ Returns:
391
+ dict: A dictionary where the keys are the indices of the non-zero elements and the values are the lengths of the longest diagonal sequences starting from those indices.
392
+ """
393
+
394
+ visited = set()
395
+ diags = {}
396
+ for idx in nonzero_idx:
397
+ start_idx = torch.clone(idx)
398
+ tuple_start_idx = tuple(start_idx.tolist())
399
+
400
+ if tuple_start_idx in visited:
401
+ continue
402
+
403
+ visited.add(tuple_start_idx)
404
+ cur_diag_len = 1
405
+ start_idx += 1
406
+ while start_idx[0] < input_matrix.shape[0] and start_idx[1] < input_matrix.shape[1]:
407
+ tuple_start_idx = tuple(start_idx.tolist())
408
+ visited.add(tuple_start_idx)
409
+
410
+ if input_matrix[start_idx[0], start_idx[1]] == 1:
411
+ cur_diag_len += 1
412
+ start_idx += 1
413
+ else:
414
+ break
415
+
416
+ diags[idx] = cur_diag_len
417
+ return diags
418
+
419
+ @staticmethod
420
+ def _get_longest_diag_index(input_matrix):
421
+ """
422
+ Returns the start index and length of the longest diagonal in the given input.
423
+ Args:
424
+ input_matrix (numpy.ndarray): The input matrix.
425
+ Returns:
426
+ tuple: A tuple containing the start index and length of the longest diagonal.
427
+ """
428
+
429
+ diags = AssistedCandidateGeneratorDifferentTokenizers._get_longest_diag_dict(
430
+ input_matrix, input_matrix.nonzero()
431
+ )
432
+ diags_values = list(diags.values())
433
+ diags_keys = list(diags.keys())
434
+ best_diag = np.argmax(diags_values)
435
+ diag_start_index = diags_keys[best_diag]
436
+ diag_start_length = diags_values[best_diag]
437
+ return diag_start_index, diag_start_length
438
+
439
+ @staticmethod
440
+ def _get_tokens_diag(prompt, prompt_plus_new_tokens):
441
+ """
442
+ Input:
443
+ prompt: 2D array of shape (batch_size, prompt_length), represents the original prompt tokens
444
+ prompt_plus_new_tokens: 2D array of shape (batch_size, prompt_length), represents the suffix of the original prompt, with additional new tokens.
445
+ Output:
446
+ discrepancy_length: int, represents the number of tokens that need to be replaced from prompt
447
+ new_tokens_only: 2D array of shape (batch_size, new_token_length), represents the new tokens that are not in prompt
448
+ discrepancy_only: 2D array of shape (batch_size, discrepancy_length), represents the new tokens that are in prompt but not in prompt_plus_new_tokens
449
+ """
450
+ compare_mat = prompt_plus_new_tokens.T == prompt
451
+ if not torch.is_tensor(compare_mat):
452
+ compare_mat = torch.tensor(compare_mat)
453
+
454
+ compare_mat_int = compare_mat.to(int)
455
+
456
+ if not compare_mat_int.any().item():
457
+ # empty intersection between prompt and prompt_plus_new_tokens
458
+ return None, None, None
459
+
460
+ longest_location, longest_diag_length = AssistedCandidateGeneratorDifferentTokenizers._get_longest_diag_index(
461
+ compare_mat_int
462
+ )
463
+ new_token_start_index = longest_location[0] + longest_diag_length
464
+ discrepancy_with_old = longest_location[1] + longest_diag_length
465
+ discrepancy_length = (prompt.shape[1] - discrepancy_with_old).item()
466
+ new_tokens_only = prompt_plus_new_tokens[:, new_token_start_index + discrepancy_length :]
467
+ discrepancy_only = prompt_plus_new_tokens[
468
+ :, new_token_start_index : new_token_start_index + discrepancy_length
469
+ ]
470
+ return discrepancy_length, new_tokens_only, discrepancy_only
471
+
472
+ def convert_source_tokens_to_target_tokens(
473
+ self,
474
+ input_ids,
475
+ source_tokenizer,
476
+ destination_tokenizer,
477
+ ):
478
+ """
479
+ Convert token IDs from one tokenizer to another.
480
+ Args:
481
+ input_ids: The input token IDs.
482
+ source_tokenizer: The source tokenizer.
483
+ destination_tokenizer: The destination tokenizer.
484
+ Returns:
485
+ The converted token IDs.
486
+ """
487
+ text = source_tokenizer.batch_decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
488
+ dest_ids = destination_tokenizer(text, add_special_tokens=True, return_tensors="pt")["input_ids"]
489
+ return dest_ids.to(input_ids.device)
490
+
491
+ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
492
+ """
493
+ Fetches the candidates to be tried for the current input.
494
+
495
+ Args:
496
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
497
+ Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
498
+
499
+ Return:
500
+ `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
501
+ assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
502
+ vocabulary_size)` containing the logits associated to each candidate.
503
+ """
504
+ max_new_tokens = int(self.num_assistant_tokens)
505
+ if max_new_tokens == 0:
506
+ return input_ids, None
507
+
508
+ input_ids = input_ids.to(self.assistant_model.device)
509
+ remove_from_pkv = 0
510
+
511
+ assistant_input_ids, remove_from_pkv = self._prepare_assistant_input_ids(input_ids)
512
+ self.prev_assistant_ids = assistant_input_ids
513
+
514
+ min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - assistant_input_ids.shape[-1]), 0)
515
+
516
+ self._update_past_and_masks(assistant_input_ids, remove_from_pkv)
517
+ generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens)
518
+ self.assistant_kwargs.pop("attention_mask", None)
519
+
520
+ assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
521
+ new_target_ids = self._process_assistant_outputs(input_ids, assistant_output.sequences, assistant_input_ids)
522
+
523
+ # Update state
524
+ self.prev_target_ids_len = input_ids.shape[1]
525
+ self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
526
+ self.prev_assistant_ids = assistant_output.sequences
527
+
528
+ if self.prev_target_ids_len >= new_target_ids.shape[1]:
529
+ return input_ids, None
530
+
531
+ return new_target_ids, None
532
+
533
+ def _prepare_assistant_input_ids(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, int]:
534
+ """Converts target input IDs to assistant input IDs, handling discrepancies."""
535
+ convert_kwargs = {
536
+ "source_tokenizer": self.target_tokenizer,
537
+ "destination_tokenizer": self.assistant_tokenizer,
538
+ }
539
+ remove_from_pkv = 0
540
+
541
+ if self.prev_assistant_ids is not None and self.prev_target_ids_len > self.target_lookbehind:
542
+ # input_ids contains all target prompt input ids and some new target input ids
543
+ start_index_in_target_window = self.prev_target_ids_len - self.target_lookbehind
544
+
545
+ new_assistant_ids = self.convert_source_tokens_to_target_tokens(
546
+ input_ids[:, start_index_in_target_window:], **convert_kwargs
547
+ )
548
+ prompt_use_length = new_assistant_ids.shape[1]
549
+ prompt_use = self.prev_assistant_ids[:, -prompt_use_length:]
550
+
551
+ discrepancy_length, new_tokens_only, discrepancy_only = self._get_tokens_diag(
552
+ prompt_use, new_assistant_ids
553
+ )
554
+ assistant_input_ids = self.prev_assistant_ids
555
+
556
+ if new_tokens_only is not None:
557
+ if discrepancy_length > 0 and discrepancy_only.shape[1] > 0:
558
+ if discrepancy_length == discrepancy_only.shape[1]:
559
+ assistant_input_ids[:, -discrepancy_length:] = discrepancy_only
560
+
561
+ elif discrepancy_length > discrepancy_only.shape[1]:
562
+ discrepancy_length_diff = discrepancy_length - discrepancy_only.shape[1]
563
+ assistant_input_ids = assistant_input_ids[:, :-discrepancy_length_diff]
564
+ assistant_input_ids[:, -discrepancy_only.shape[1] :] = discrepancy_only
565
+
566
+ remove_from_pkv = discrepancy_length
567
+
568
+ if new_tokens_only.shape[1] > 0:
569
+ assistant_input_ids = torch.cat([assistant_input_ids, new_tokens_only], dim=-1)
570
+ else:
571
+ # edge case: in case of no intersection between prompt and new_assistant_ids
572
+ assistant_input_ids = torch.cat([assistant_input_ids, new_assistant_ids], dim=-1)
573
+ else:
574
+ assistant_input_ids = self.convert_source_tokens_to_target_tokens(input_ids, **convert_kwargs)
575
+ self.prev_target_ids_len = input_ids.shape[1]
576
+
577
+ return assistant_input_ids, remove_from_pkv
578
+
579
+ def _process_assistant_outputs(
580
+ self, input_ids: torch.LongTensor, assistant_sequences: torch.LongTensor, assistant_input_ids: torch.LongTensor
581
+ ) -> torch.LongTensor:
582
+ """Processes assistant outputs to obtain target input IDs."""
583
+ num_prev_assistant = self.prev_assistant_ids.shape[1]
584
+ start_assistant_look_index = num_prev_assistant - self.assistant_lookbehind
585
+
586
+ new_target_ids_from_window = self.convert_source_tokens_to_target_tokens(
587
+ assistant_sequences[:, start_assistant_look_index:],
588
+ source_tokenizer=self.assistant_tokenizer,
589
+ destination_tokenizer=self.target_tokenizer,
590
+ )
591
+ target_prompt_use_length = new_target_ids_from_window.shape[1]
592
+
593
+ target_prompt_use = input_ids[:, -target_prompt_use_length:]
594
+
595
+ _, target_new_tokens_only, _ = self._get_tokens_diag(target_prompt_use, new_target_ids_from_window)
596
+
597
+ new_target_ids = input_ids
598
+
599
+ if target_new_tokens_only is not None:
600
+ if target_new_tokens_only.shape[1] > 0:
601
+ new_target_ids = torch.cat([new_target_ids, target_new_tokens_only], dim=-1)
602
+ else:
603
+ # edge case: in case of no intersection between prompt and new_target_ids
604
+ new_target_ids = torch.cat([new_target_ids, new_target_ids_from_window], dim=-1)
605
+
606
+ if hasattr(self.generation_config, "max_length"):
607
+ new_target_ids = new_target_ids[:, : self.generation_config.max_length]
608
+
609
+ return new_target_ids
610
+
611
+
612
+ class PromptLookupCandidateGenerator(CandidateGenerator):
613
+ """
614
+ `CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking up
615
+ likely continuations in the provided prompt (input_ids) itself.
616
+ Read the following blog post for more information: https://github.com/apoorvumang/prompt-lookup-decoding
617
+
618
+ Args:
619
+ max_matching_ngram_size (`int`):
620
+ The maximum ngram size to be considered for matching in the prompt
621
+ num_output_tokens (`int`):
622
+ The number of tokens to be output as candidate tokens.
623
+ max_length (`int`):
624
+ The number of total maximum tokens that can be generated. For decoder-only models that includes the prompt length.
625
+ Defaults to 20, which is the max length used as default in generation config.
626
+ """
627
+
628
+ def __init__(
629
+ self,
630
+ eos_token_id: torch.Tensor = None,
631
+ num_output_tokens: int = 10,
632
+ max_matching_ngram_size: int = None,
633
+ max_length: int = 20,
634
+ ):
635
+ self.num_output_tokens = num_output_tokens
636
+ self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
637
+ self.max_length = max_length
638
+ self.eos_token_id = eos_token_id
639
+
640
+ if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
641
+ raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")
642
+
643
+ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
644
+ """
645
+ Fetches the candidates to be tried for the current input.
646
+
647
+ Args:
648
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
649
+ Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
650
+
651
+ Return:
652
+ `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
653
+ """
654
+ input_length = input_ids.size(1)
655
+
656
+ # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
657
+ if self.max_length == input_length + 1:
658
+ return input_ids, None
659
+
660
+ chosen_ids = None
661
+ match_found = False
662
+ for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
663
+ # Create sliding windows of size ngram_size
664
+ windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)
665
+
666
+ # Convert ngram to a tensor for comparison
667
+ ngram_tensor = input_ids[0, -ngram_size:]
668
+
669
+ # Find where the windows match the ngram
670
+ matches = (windows == ngram_tensor).all(dim=2)
671
+
672
+ # Get the indices of matches
673
+ match_indices = matches.nonzero(as_tuple=True)[1]
674
+
675
+ # Iterate through match indices to find a valid continuation
676
+ for idx in match_indices:
677
+ start_idx = idx + ngram_size
678
+ end_idx = start_idx + self.num_output_tokens
679
+ end_idx = min(end_idx, input_length, self.max_length)
680
+
681
+ if start_idx < end_idx:
682
+ chosen_ids = input_ids[0, start_idx:end_idx]
683
+ match_found = True
684
+
685
+ # remove remaining candidate ids if an "eos" token is found, otherwise the target model may
686
+ # accept eos and the rest as valid, thus not stopping generation after "eos"
687
+ # NOTE: below code is written based on the fact that assisted decoding supports only bs=1
688
+ mask = isin_mps_friendly(chosen_ids, self.eos_token_id)
689
+ match_indices_eos = torch.nonzero(mask)
690
+ if match_indices_eos.numel() > 0:
691
+ first_eos_index = match_indices_eos[0].item()
692
+ chosen_ids = chosen_ids[:first_eos_index]
693
+ break
694
+ if match_found:
695
+ break
696
+
697
+ if chosen_ids is None or len(chosen_ids) == 0:
698
+ # In case we didn't find a match return the input sequence unchanged, reverts back to autoregressive decoding
699
+ return input_ids, None
700
+
701
+ # Now need extend input_ids with chosen_ids
702
+ chosen_ids = chosen_ids.unsqueeze(0)
703
+ candidate_input_ids = torch.cat((input_ids, chosen_ids), dim=1)
704
+ # assisted_generation expects logits as well, but we don't have those here, so returning None
705
+ return candidate_input_ids, None
706
+
707
+ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
708
+ """
709
+ Updates the candidate generation strategy based on the outcomes.
710
+
711
+ Args:
712
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
713
+ Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
714
+ scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
715
+ Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
716
+ beam search or log softmax for each vocabulary token when using beam search
717
+ num_matches (`int`):
718
+ The number of matches between the candidate sequences and the model predictions.
719
+ """
720
+ # Currently does nothing
721
+ return
722
+
723
+
724
+ class EarlyExitCandidateGenerator(AssistedCandidateGenerator):
725
+ """
726
+ `CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates
727
+ candidates through the use of **the model itself**, exiting early. Can only be used with models that support early
728
+ exit, e.g., `facebook/layerskip-llama3.2-1B`.
729
+
730
+ Args:
731
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
732
+ Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
733
+ assistant_model (`PreTrainedModel`):
734
+ The original model. This model must support early exit (i.e. is trained to compute logits in earlier
735
+ layers).
736
+ generation_config (`~generation.GenerationConfig`, *optional*):
737
+ The generation configuration to be used as base parametrization for the generation call.
738
+ logits_processor (`LogitsProcessorList`):
739
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
740
+ used to modify the prediction scores of the language modeling head applied at each generation step.
741
+ model_kwargs (`Dict`):
742
+ The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant
743
+ model as well.
744
+ inputs_tensor (`torch.Tensor`, *optional*):
745
+ The model input tensor. In encoder-decoder models, this is the encoder input.
746
+ """
747
+
748
+ def __init__(
749
+ self,
750
+ input_ids: torch.LongTensor,
751
+ assistant_model: "PreTrainedModel",
752
+ generation_config: "GenerationConfig",
753
+ model_kwargs: Dict,
754
+ inputs_tensor: Optional[torch.Tensor] = None,
755
+ logits_processor: "LogitsProcessorList" = None,
756
+ ):
757
+ super().__init__(
758
+ input_ids=input_ids,
759
+ assistant_model=assistant_model,
760
+ generation_config=generation_config,
761
+ model_kwargs=model_kwargs,
762
+ inputs_tensor=inputs_tensor,
763
+ logits_processor=logits_processor,
764
+ )
765
+ # We have to move early exit out of the generation config, otherwise the assistant will also call `generate`
766
+ # with early exit
767
+ self.assistant_early_exit = self.generation_config.assistant_early_exit
768
+ self.generation_config.assistant_early_exit = None
769
+
770
+ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
771
+ # Temporarily sets the number of hidden layers to the early exit value
772
+ base_model = getattr(self.assistant_model, self.assistant_model.base_model_prefix)
773
+ original_num_hidden_layers = base_model.config.num_hidden_layers
774
+ base_model.config.num_hidden_layers = self.assistant_early_exit
775
+ candidate_ids, candidate_logits = super().get_candidates(input_ids)
776
+ base_model.config.num_hidden_layers = original_num_hidden_layers
777
+ return candidate_ids, candidate_logits
778
+
779
+
780
+ def _crop_past_key_values(model, past_key_values, max_length):
781
+ """Crops the past key values up to a certain maximum length."""
782
+ new_past = []
783
+ if model.config.is_encoder_decoder:
784
+ for idx in range(len(past_key_values)):
785
+ new_past.append(
786
+ (
787
+ past_key_values[idx][0][:, :, :max_length, :],
788
+ past_key_values[idx][1][:, :, :max_length, :],
789
+ past_key_values[idx][2],
790
+ past_key_values[idx][3],
791
+ )
792
+ )
793
+ past_key_values = tuple(new_past)
794
+ # gptbigcode is special and stores kv in shape (batch_size, seq_len, dim), if it's a multi_query model
795
+ elif "gptbigcode" in model.__class__.__name__.lower() or (
796
+ model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower()
797
+ ):
798
+ if model.config.multi_query:
799
+ for idx in range(len(past_key_values)):
800
+ past_key_values[idx] = past_key_values[idx][:, :max_length, :]
801
+ else:
802
+ for idx in range(len(past_key_values)):
803
+ past_key_values[idx] = past_key_values[idx][:, :, :max_length, :]
804
+ elif isinstance(past_key_values, DynamicCache):
805
+ past_key_values.crop(max_length)
806
+ elif past_key_values is not None:
807
+ for idx in range(len(past_key_values)):
808
+ if past_key_values[idx] != ([], []):
809
+ new_past.append(
810
+ (
811
+ past_key_values[idx][0][:, :, :max_length, :],
812
+ past_key_values[idx][1][:, :, :max_length, :],
813
+ )
814
+ )
815
+ else:
816
+ new_past.append((past_key_values[idx][0], past_key_values[idx][1]))
817
+ past_key_values = tuple(new_past)
818
+ return past_key_values
819
+
820
+
821
+ def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_encoder_decoder: bool) -> Dict[str, Any]:
822
+ """Expands or crops the model's mask for decoding purposes, to the defined length"""
823
+
824
+ mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
825
+ if mask_key not in model_kwargs:
826
+ return model_kwargs
827
+
828
+ mask = model_kwargs[mask_key]
829
+ mask_length_diff = new_length - mask.shape[1]
830
+
831
+ if mask_length_diff < 0:
832
+ model_kwargs[mask_key] = mask[:, :mask_length_diff]
833
+ elif mask_length_diff > 0:
834
+ model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1)
835
+
836
+ # Handle cross attention models
837
+ if "cross_attention_mask" in model_kwargs:
838
+ # Mllama case
839
+ cross_mask = model_kwargs["cross_attention_mask"]
840
+ if mask_length_diff < 0:
841
+ model_kwargs["cross_attention_mask"] = cross_mask[:, :mask_length_diff]
842
+ elif mask_length_diff > 0:
843
+ new_mask = cross_mask[:, -1:, :, :].repeat(1, mask_length_diff, 1, 1)
844
+ model_kwargs["cross_attention_mask"] = torch.cat([cross_mask, new_mask], dim=1)
845
+ elif "image_attention_mask" in model_kwargs:
846
+ # IDEFICS case
847
+ cross_mask = model_kwargs["image_attention_mask"]
848
+ if mask_length_diff < 0:
849
+ model_kwargs["image_attention_mask"] = cross_mask[:, :mask_length_diff]
850
+ elif mask_length_diff > 0:
851
+ new_mask = cross_mask[:, -1:, :].repeat(1, mask_length_diff, 1)
852
+ model_kwargs["image_attention_mask"] = torch.cat([cross_mask, new_mask], dim=1)
853
+
854
+ return model_kwargs
855
+
856
+
857
+ def _prepare_token_type_ids(model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]:
858
+ """Expands or crops the model's token_type_ids for decoding purposes, to the defined length"""
859
+ if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None:
860
+ return model_kwargs
861
+
862
+ token_type_ids = model_kwargs["token_type_ids"]
863
+ final_token_type = token_type_ids[:, -1].unsqueeze(-1)
864
+ type_length_diff = new_length - token_type_ids.shape[1]
865
+
866
+ if type_length_diff < 0:
867
+ token_type_ids = token_type_ids[:, :type_length_diff]
868
+ elif type_length_diff > 0:
869
+ token_type_copies = final_token_type.repeat(1, type_length_diff)
870
+ model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1)
871
+ return model_kwargs
.venv/lib/python3.11/site-packages/transformers/generation/configuration_utils.py ADDED
@@ -0,0 +1,1628 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Generation configuration class and utilities."""
16
+
17
+ import copy
18
+ import json
19
+ import os
20
+ import warnings
21
+ from abc import ABC, abstractmethod
22
+ from dataclasses import dataclass, is_dataclass
23
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
24
+
25
+ from .. import __version__
26
+ from ..configuration_utils import PretrainedConfig
27
+ from ..utils import (
28
+ GENERATION_CONFIG_NAME,
29
+ ExplicitEnum,
30
+ PushToHubMixin,
31
+ cached_file,
32
+ download_url,
33
+ extract_commit_hash,
34
+ is_remote_url,
35
+ is_torch_available,
36
+ logging,
37
+ )
38
+
39
+
40
+ if TYPE_CHECKING:
41
+ from ..modeling_utils import PreTrainedModel
42
+
43
+
44
+ logger = logging.get_logger(__name__)
45
+ METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version")
46
+ NEEDS_CACHE_CONFIG = {}
47
+ NEED_SETUP_CACHE_CLASSES_MAPPING = {}
48
+ QUANT_BACKEND_CLASSES_MAPPING = {}
49
+ ALL_CACHE_IMPLEMENTATIONS = []
50
+
51
+ if is_torch_available():
52
+ from ..cache_utils import (
53
+ HQQQuantizedCache,
54
+ HybridCache,
55
+ MambaCache,
56
+ OffloadedStaticCache,
57
+ QuantizedCacheConfig,
58
+ QuantoQuantizedCache,
59
+ SlidingWindowCache,
60
+ StaticCache,
61
+ StaticCacheConfig,
62
+ )
63
+ from .logits_process import SynthIDTextWatermarkLogitsProcessor, WatermarkLogitsProcessor
64
+
65
+ NEEDS_CACHE_CONFIG["quantized"] = QuantizedCacheConfig
66
+ NEEDS_CACHE_CONFIG["static"] = StaticCacheConfig
67
+ NEED_SETUP_CACHE_CLASSES_MAPPING = {
68
+ "static": StaticCache,
69
+ "offloaded_static": OffloadedStaticCache,
70
+ "sliding_window": SlidingWindowCache,
71
+ "hybrid": HybridCache,
72
+ "mamba": MambaCache,
73
+ }
74
+ QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
75
+ ALL_CACHE_IMPLEMENTATIONS = (
76
+ list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(NEEDS_CACHE_CONFIG.keys()) + ["offloaded"]
77
+ )
78
+
79
+
80
+ class GenerationMode(ExplicitEnum):
81
+ """
82
+ Possible generation modes, downstream of the [`~generation.GenerationMixin.generate`] method.
83
+ """
84
+
85
+ # Non-beam methods
86
+ CONTRASTIVE_SEARCH = "contrastive_search"
87
+ GREEDY_SEARCH = "greedy_search"
88
+ SAMPLE = "sample"
89
+ ASSISTED_GENERATION = "assisted_generation"
90
+ DOLA_GENERATION = "dola_generation"
91
+ # Beam methods
92
+ BEAM_SEARCH = "beam_search"
93
+ BEAM_SAMPLE = "beam_sample"
94
+ CONSTRAINED_BEAM_SEARCH = "constrained_beam_search"
95
+ GROUP_BEAM_SEARCH = "group_beam_search"
96
+
97
+
98
+ class GenerationConfig(PushToHubMixin):
99
+ # no-format
100
+ """
101
+ Class that holds a configuration for a generation task. A `generate` call supports the following generation methods
102
+ for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
103
+
104
+ - *greedy decoding* if `num_beams=1` and `do_sample=False`
105
+ - *contrastive search* if `penalty_alpha>0.` and `top_k>1`
106
+ - *multinomial sampling* if `num_beams=1` and `do_sample=True`
107
+ - *beam-search decoding* if `num_beams>1` and `do_sample=False`
108
+ - *beam-search multinomial sampling* if `num_beams>1` and `do_sample=True`
109
+ - *diverse beam-search decoding* if `num_beams>1` and `num_beam_groups>1`
110
+ - *constrained beam-search decoding* if `constraints!=None` or `force_words_ids!=None`
111
+ - *assisted decoding* if `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
112
+ - *dola decoding* if `dola_layers` is passed to `.generate()`
113
+
114
+ To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
115
+
116
+ <Tip>
117
+
118
+ A large number of these flags control the logits or the stopping criteria of the generation. Make sure you check
119
+ the [generate-related classes](https://huggingface.co/docs/transformers/internal/generation_utils) for a full
120
+ description of the possible manipulations, as well as examples of their usage.
121
+
122
+ </Tip>
123
+
124
+ Arg:
125
+ > Parameters that control the length of the output
126
+
127
+ max_length (`int`, *optional*, defaults to 20):
128
+ The maximum length the generated tokens can have. Corresponds to the length of the input prompt +
129
+ `max_new_tokens`. Its effect is overridden by `max_new_tokens`, if also set.
130
+ max_new_tokens (`int`, *optional*):
131
+ The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
132
+ min_length (`int`, *optional*, defaults to 0):
133
+ The minimum length of the sequence to be generated. Corresponds to the length of the input prompt +
134
+ `min_new_tokens`. Its effect is overridden by `min_new_tokens`, if also set.
135
+ min_new_tokens (`int`, *optional*):
136
+ The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt.
137
+ early_stopping (`bool` or `str`, *optional*, defaults to `False`):
138
+ Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
139
+ `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an
140
+ heuristic is applied and the generation stops when is it very unlikely to find better candidates;
141
+ `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
142
+ beam search algorithm).
143
+ max_time (`float`, *optional*):
144
+ The maximum amount of time you allow the computation to run for in seconds. generation will still finish
145
+ the current pass after allocated time has been passed.
146
+ stop_strings (`str or List[str]`, *optional*):
147
+ A string or a list of strings that should terminate generation if the model outputs them.
148
+
149
+ > Parameters that control the generation strategy used
150
+
151
+ do_sample (`bool`, *optional*, defaults to `False`):
152
+ Whether or not to use sampling ; use greedy decoding otherwise.
153
+ num_beams (`int`, *optional*, defaults to 1):
154
+ Number of beams for beam search. 1 means no beam search.
155
+ num_beam_groups (`int`, *optional*, defaults to 1):
156
+ Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
157
+ [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
158
+ penalty_alpha (`float`, *optional*):
159
+ The values balance the model confidence and the degeneration penalty in contrastive search decoding.
160
+ dola_layers (`str` or `List[int]`, *optional*):
161
+ The layers to use for DoLa decoding. If `None`, DoLa decoding is not used. If a string, it must
162
+ be one of "low" or "high", which means using the lower part or higher part of the model layers, respectively.
163
+ "low" means the first half of the layers up to the first 20 layers, and "high" means the last half of the
164
+ layers up to the last 20 layers.
165
+ If a list of integers, it must contain the indices of the layers to use for candidate premature layers in DoLa.
166
+ The 0-th layer is the word embedding layer of the model. Set to `'low'` to improve long-answer reasoning tasks,
167
+ `'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md)
168
+ or [the paper](https://arxiv.org/abs/2309.03883) for more details.
169
+
170
+ > Parameters that control the cache
171
+
172
+ use_cache (`bool`, *optional*, defaults to `True`):
173
+ Whether or not the model should use the past last key/values attentions (if applicable to the model) to
174
+ speed up decoding.
175
+ cache_implementation (`str`, *optional*, default to `None`):
176
+ Name of the cache class that will be instantiated in `generate`, for faster decoding. Possible values are:
177
+
178
+ - `"static"`: [`StaticCache`]
179
+ - `"offloaded_static"`: [`OffloadedStaticCache`]
180
+ - `"sliding_window"`: [`SlidingWindowCache`]
181
+ - `"hybrid"`: [`HybridCache`]
182
+ - `"mamba"`: [`MambaCache`]
183
+ - `"quantized"`: [`QuantizedCache`]
184
+
185
+ We support other cache types, but they must be manually instantiated and
186
+ passed to `generate` through the `past_key_values` argument. See our
187
+ [cache documentation](https://huggingface.co/docs/transformers/en/kv_cache) for further information.
188
+ cache_config (`CacheConfig` or `dict`, *optional*, default to `None`):
189
+ Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and
190
+ it will be converted to its repsective `CacheConfig` internally.
191
+ Otherwise can be passed as a `CacheConfig` class matching the indicated `cache_implementation`.
192
+ return_legacy_cache (`bool`, *optional*, default to `True`):
193
+ Whether to return the legacy or new format of the cache when `DynamicCache` is used by default.
194
+
195
+ > Parameters for manipulation of the model output logits
196
+
197
+ temperature (`float`, *optional*, defaults to 1.0):
198
+ The value used to module the next token probabilities. This value is set in a model's `generation_config.json` file. If it isn't set, the default value is 1.0
199
+ top_k (`int`, *optional*, defaults to 50):
200
+ The number of highest probability vocabulary tokens to keep for top-k-filtering. This value is set in a model's `generation_config.json` file. If it isn't set, the default value is 50.
201
+ top_p (`float`, *optional*, defaults to 1.0):
202
+ If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to
203
+ `top_p` or higher are kept for generation. This value is set in a model's `generation_config.json` file. If it isn't set, the default value is 1.0
204
+ min_p (`float`, *optional*):
205
+ Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
206
+ value between 0 and 1. Typical values are in the 0.01-0.2 range, comparably selective as setting `top_p` in
207
+ the 0.99-0.8 range (use the opposite of normal `top_p` values).
208
+ typical_p (`float`, *optional*, defaults to 1.0):
209
+ Local typicality measures how similar the conditional probability of predicting a target token next is to
210
+ the expected conditional probability of predicting a random token next, given the partial text already
211
+ generated. If set to float < 1, the smallest set of the most locally typical tokens with probabilities that
212
+ add up to `typical_p` or higher are kept for generation. See [this
213
+ paper](https://arxiv.org/pdf/2202.00666.pdf) for more details.
214
+ epsilon_cutoff (`float`, *optional*, defaults to 0.0):
215
+ If set to float strictly between 0 and 1, only tokens with a conditional probability greater than
216
+ `epsilon_cutoff` will be sampled. In the paper, suggested values range from 3e-4 to 9e-4, depending on the
217
+ size of the model. See [Truncation Sampling as Language Model
218
+ Desmoothing](https://arxiv.org/abs/2210.15191) for more details.
219
+ eta_cutoff (`float`, *optional*, defaults to 0.0):
220
+ Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to float strictly between
221
+ 0 and 1, a token is only considered if it is greater than either `eta_cutoff` or `sqrt(eta_cutoff) *
222
+ exp(-entropy(softmax(next_token_logits)))`. The latter term is intuitively the expected next token
223
+ probability, scaled by `sqrt(eta_cutoff)`. In the paper, suggested values range from 3e-4 to 2e-3,
224
+ depending on the size of the model. See [Truncation Sampling as Language Model
225
+ Desmoothing](https://arxiv.org/abs/2210.15191) for more details.
226
+ diversity_penalty (`float`, *optional*, defaults to 0.0):
227
+ This value is subtracted from a beam's score if it generates a token same as any beam from other group at a
228
+ particular time. Note that `diversity_penalty` is only effective if `group beam search` is enabled.
229
+ repetition_penalty (`float`, *optional*, defaults to 1.0):
230
+ The parameter for repetition penalty. 1.0 means no penalty. See [this
231
+ paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
232
+ encoder_repetition_penalty (`float`, *optional*, defaults to 1.0):
233
+ The paramater for encoder_repetition_penalty. An exponential penalty on sequences that are not in the
234
+ original input. 1.0 means no penalty.
235
+ length_penalty (`float`, *optional*, defaults to 1.0):
236
+ Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
237
+ the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
238
+ likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
239
+ `length_penalty` < 0.0 encourages shorter sequences.
240
+ no_repeat_ngram_size (`int`, *optional*, defaults to 0):
241
+ If set to int > 0, all ngrams of that size can only occur once.
242
+ bad_words_ids (`List[List[int]]`, *optional*):
243
+ List of list of token ids that are not allowed to be generated. Check
244
+ [`~generation.NoBadWordsLogitsProcessor`] for further documentation and examples.
245
+ force_words_ids (`List[List[int]]` or `List[List[List[int]]]`, *optional*):
246
+ List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple list of
247
+ words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, this
248
+ triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one
249
+ can allow different forms of each word.
250
+ renormalize_logits (`bool`, *optional*, defaults to `False`):
251
+ Whether to renormalize the logits after applying all the logits processors (including the custom
252
+ ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits
253
+ are normalized but some logit processors break the normalization.
254
+ constraints (`List[Constraint]`, *optional*):
255
+ Custom constraints that can be added to the generation to ensure that the output will contain the use of
256
+ certain tokens as defined by `Constraint` objects, in the most sensible way possible.
257
+ forced_bos_token_id (`int`, *optional*, defaults to `model.config.forced_bos_token_id`):
258
+ The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful for
259
+ multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be the target
260
+ language token.
261
+ forced_eos_token_id (`int` or List[int]`, *optional*, defaults to `model.config.forced_eos_token_id`):
262
+ The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a
263
+ list to set multiple *end-of-sequence* tokens.
264
+ remove_invalid_values (`bool`, *optional*, defaults to `model.config.remove_invalid_values`):
265
+ Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to crash.
266
+ Note that using `remove_invalid_values` can slow down generation.
267
+ exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
268
+ This Tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been
269
+ generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where
270
+ penalty starts and `decay_factor` represents the factor of exponential decay
271
+ suppress_tokens (`List[int]`, *optional*):
272
+ A list of tokens that will be suppressed at generation. The `SupressTokens` logit processor will set their
273
+ log probs to `-inf` so that they are not sampled.
274
+ begin_suppress_tokens (`List[int]`, *optional*):
275
+ A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens` logit
276
+ processor will set their log probs to `-inf` so that they are not sampled.
277
+ forced_decoder_ids (`List[List[int]]`, *optional*):
278
+ A list of pairs of integers which indicates a mapping from generation indices to token indices that will be
279
+ forced before sampling. For example, `[[1, 123]]` means the second generated token will always be a token
280
+ of index 123.
281
+ sequence_bias (`Dict[Tuple[int], float]`, *optional*)):
282
+ Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
283
+ sequence being selected, while negative biases do the opposite. Check
284
+ [`~generation.SequenceBiasLogitsProcessor`] for further documentation and examples.
285
+ token_healing (`bool`, *optional*, defaults to `False`):
286
+ Heal tail tokens of prompts by replacing them with their appropriate extensions.
287
+ This enhances the quality of completions for prompts affected by greedy tokenization bias.
288
+ guidance_scale (`float`, *optional*):
289
+ The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
290
+ Higher guidance scale encourages the model to generate samples that are more closely linked to the input
291
+ prompt, usually at the expense of poorer quality.
292
+ low_memory (`bool`, *optional*):
293
+ Switch to sequential beam search and sequential topk for contrastive search to reduce peak memory.
294
+ Used with beam search and contrastive search.
295
+ watermarking_config (`BaseWatermarkingConfig` or `dict`, *optional*):
296
+ Arguments used to watermark the model outputs by adding a small bias to randomly selected set of "green"
297
+ tokens. See the docs of [`SynthIDTextWatermarkingConfig`] and [`WatermarkingConfig`] for more
298
+ details. If passed as `Dict`, it will be converted to a `WatermarkingConfig` internally.
299
+
300
+ > Parameters that define the output variables of generate
301
+
302
+ num_return_sequences (`int`, *optional*, defaults to 1):
303
+ The number of independently computed returned sequences for each element in the batch.
304
+ output_attentions (`bool`, *optional*, defaults to `False`):
305
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
306
+ tensors for more details.
307
+ output_hidden_states (`bool`, *optional*, defaults to `False`):
308
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
309
+ more details.
310
+ output_scores (`bool`, *optional*, defaults to `False`):
311
+ Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
312
+ output_logits (`bool`, *optional*):
313
+ Whether or not to return the unprocessed prediction logit scores. See `logits` under returned tensors for
314
+ more details.
315
+ return_dict_in_generate (`bool`, *optional*, defaults to `False`):
316
+ Whether or not to return a [`~utils.ModelOutput`], as opposed to returning exclusively the generated
317
+ sequence. This flag must be set to `True` to return the generation cache (when `use_cache` is `True`)
318
+ or optional outputs (see flags starting with `output_`)
319
+
320
+ > Special tokens that can be used at generation time
321
+
322
+ pad_token_id (`int`, *optional*):
323
+ The id of the *padding* token.
324
+ bos_token_id (`int`, *optional*):
325
+ The id of the *beginning-of-sequence* token.
326
+ eos_token_id (`Union[int, List[int]]`, *optional*):
327
+ The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
328
+
329
+ > Generation parameters exclusive to encoder-decoder models
330
+
331
+ encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0):
332
+ If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the
333
+ `decoder_input_ids`.
334
+ decoder_start_token_id (`int` or `List[int]`, *optional*):
335
+ If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token or a list of length
336
+ `batch_size`. Indicating a list enables different start ids for each element in the batch
337
+ (e.g. multilingual models with different target languages in one batch)
338
+
339
+ > Generation parameters exclusive to assistant generation
340
+ is_assistant (`bool`, *optional*, defaults to `False`):
341
+ Whether the model is an assistant (draft) model.
342
+ num_assistant_tokens (`int`, *optional*, defaults to 20):
343
+ Defines the number of _speculative tokens_ that shall be generated by the assistant model before being
344
+ checked by the target model at each iteration. Higher values for `num_assistant_tokens` make the generation
345
+ more _speculative_ : If the assistant model is performant larger speed-ups can be reached, if the assistant
346
+ model requires lots of corrections, lower speed-ups are reached.
347
+ num_assistant_tokens_schedule (`str`, *optional*, defaults to `"constant"`):
348
+ Defines the schedule at which max assistant tokens shall be changed during inference.
349
+ - `"heuristic"`: When all speculative tokens are correct, increase `num_assistant_tokens` by 2 else
350
+ reduce by 1. `num_assistant_tokens` value is persistent over multiple generation calls with the same assistant model.
351
+ - `"heuristic_transient"`: Same as `"heuristic"` but `num_assistant_tokens` is reset to its initial value after each generation call.
352
+ - `"constant"`: `num_assistant_tokens` stays unchanged during generation
353
+ assistant_confidence_threshold (`float`, *optional*, defaults to 0.4):
354
+ The confidence threshold for the assistant model. If the assistant model's confidence in its prediction for the current token is lower
355
+ than this threshold, the assistant model stops the current token generation iteration, even if the number of _speculative tokens_
356
+ (defined by `num_assistant_tokens`) is not yet reached. The assistant's confidence threshold is adjusted throughout the speculative iterations to reduce the number of unnecessary draft and target forward passes, biased towards avoiding false negatives.
357
+ `assistant_confidence_threshold` value is persistent over multiple generation calls with the same assistant model.
358
+ It is an unsupervised version of the dynamic speculation lookahead
359
+ from Dynamic Speculation Lookahead Accelerates Speculative Decoding of Large Language Models <https://arxiv.org/abs/2405.04304>.
360
+ prompt_lookup_num_tokens (`int`, *optional*):
361
+ The number of tokens to be output as candidate tokens.
362
+ max_matching_ngram_size (`int`, *optional*):
363
+ The maximum ngram size to be considered for matching in the prompt. Default to 2 if not provided.
364
+ assistant_early_exit(`int`, *optional*):
365
+ If set to a positive integer, early exit of the model will be used as an assistant. Can only be used with
366
+ models that support early exit (i.e. models where logits from intermediate layers can be interpreted by the LM head).
367
+ assistant_lookbehind(`int`, *optional*, defaults to 10):
368
+ If set to a positive integer, the re-encodeing process will additionally consider the last `assistant_lookbehind` assistant tokens
369
+ to correctly align tokens. Can only be used with different tokenizers in speculative decoding.
370
+ See this [blog](https://huggingface.co/blog/universal_assisted_generation) for more details.
371
+ target_lookbehind(`int`, *optional*, defaults to 10):
372
+ If set to a positive integer, the re-encodeing process will additionally consider the last `target_lookbehind` target tokens
373
+ to correctly align tokens. Can only be used with different tokenizers in speculative decoding.
374
+ See this [blog](https://huggingface.co/blog/universal_assisted_generation) for more details.
375
+
376
+ > Parameters related to performances and compilation
377
+
378
+ compile_config (CompileConfig, *optional*):
379
+ If using a static cache, this controls how `generate` will `compile` the forward pass for performance
380
+ gains.
381
+
382
+ > Wild card
383
+
384
+ generation_kwargs:
385
+ Additional generation kwargs will be forwarded to the `generate` function of the model. Kwargs that are not
386
+ present in `generate`'s signature will be used in the model forward pass.
387
+ """
388
+
389
+ extra_output_flags = ("output_attentions", "output_hidden_states", "output_scores", "output_logits")
390
+
391
+ def __init__(self, **kwargs):
392
+ # Parameters that control the length of the output
393
+ self.max_length = kwargs.pop("max_length", 20)
394
+ self.max_new_tokens = kwargs.pop("max_new_tokens", None)
395
+ self.min_length = kwargs.pop("min_length", 0)
396
+ self.min_new_tokens = kwargs.pop("min_new_tokens", None)
397
+ self.early_stopping = kwargs.pop("early_stopping", False)
398
+ self.max_time = kwargs.pop("max_time", None)
399
+ self.stop_strings = kwargs.pop("stop_strings", None)
400
+
401
+ # Parameters that control the generation strategy used
402
+ self.do_sample = kwargs.pop("do_sample", False)
403
+ self.num_beams = kwargs.pop("num_beams", 1)
404
+ self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
405
+ self.penalty_alpha = kwargs.pop("penalty_alpha", None)
406
+ self.dola_layers = kwargs.pop("dola_layers", None)
407
+
408
+ # Parameters that control the cache
409
+ self.use_cache = kwargs.pop("use_cache", True)
410
+ self.cache_implementation = kwargs.pop("cache_implementation", None)
411
+ self.cache_config = kwargs.pop("cache_config", None)
412
+ if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG:
413
+ cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation]
414
+ if self.cache_config is None:
415
+ self.cache_config = cache_config_class()
416
+ elif isinstance(self.cache_config, dict):
417
+ self.cache_config = cache_config_class.from_dict(self.cache_config)
418
+ self.return_legacy_cache = kwargs.pop("return_legacy_cache", None)
419
+
420
+ # Parameters for manipulation of the model output logits
421
+ self.temperature = kwargs.pop("temperature", 1.0)
422
+ self.top_k = kwargs.pop("top_k", 50)
423
+ self.top_p = kwargs.pop("top_p", 1.0)
424
+ self.min_p = kwargs.pop("min_p", None)
425
+ self.typical_p = kwargs.pop("typical_p", 1.0)
426
+ self.epsilon_cutoff = kwargs.pop("epsilon_cutoff", 0.0)
427
+ self.eta_cutoff = kwargs.pop("eta_cutoff", 0.0)
428
+ self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
429
+ self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
430
+ self.encoder_repetition_penalty = kwargs.pop("encoder_repetition_penalty", 1.0)
431
+ self.length_penalty = kwargs.pop("length_penalty", 1.0)
432
+ self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
433
+ self.bad_words_ids = kwargs.pop("bad_words_ids", None)
434
+ self.force_words_ids = kwargs.pop("force_words_ids", None)
435
+ self.renormalize_logits = kwargs.pop("renormalize_logits", False)
436
+ self.constraints = kwargs.pop("constraints", None)
437
+ self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
438
+ self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
439
+ self.remove_invalid_values = kwargs.pop("remove_invalid_values", False)
440
+ self.exponential_decay_length_penalty = kwargs.pop("exponential_decay_length_penalty", None)
441
+ self.suppress_tokens = kwargs.pop("suppress_tokens", None)
442
+ self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None)
443
+ self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
444
+ self.sequence_bias = kwargs.pop("sequence_bias", None)
445
+ self.token_healing = kwargs.pop("token_healing", False)
446
+ self.guidance_scale = kwargs.pop("guidance_scale", None)
447
+ self.low_memory = kwargs.pop("low_memory", None)
448
+ watermarking_config = kwargs.pop("watermarking_config", None)
449
+ if watermarking_config is None:
450
+ self.watermarking_config = None
451
+ elif isinstance(watermarking_config, BaseWatermarkingConfig):
452
+ self.watermarking_config = watermarking_config
453
+ else:
454
+ self.watermarking_config = WatermarkingConfig.from_dict(watermarking_config)
455
+
456
+ # Parameters that define the output variables of `generate`
457
+ self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
458
+ self.output_attentions = kwargs.pop("output_attentions", False)
459
+ self.output_hidden_states = kwargs.pop("output_hidden_states", False)
460
+ self.output_scores = kwargs.pop("output_scores", False)
461
+ self.output_logits = kwargs.pop("output_logits", None)
462
+ self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
463
+
464
+ # Special tokens that can be used at generation time
465
+ self.pad_token_id = kwargs.pop("pad_token_id", None)
466
+ self.bos_token_id = kwargs.pop("bos_token_id", None)
467
+ self.eos_token_id = kwargs.pop("eos_token_id", None)
468
+
469
+ # Generation parameters exclusive to encoder-decoder models
470
+ self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
471
+ self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
472
+
473
+ # Assistant generation
474
+ self.is_assistant = False
475
+ self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 20)
476
+ self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "constant")
477
+ self.assistant_confidence_threshold = kwargs.pop("assistant_confidence_threshold", 0.4)
478
+ self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
479
+ self.max_matching_ngram_size = kwargs.pop("max_matching_ngram_size", None)
480
+ self.assistant_early_exit = kwargs.pop("assistant_early_exit", None)
481
+ ## assistant generation for different tokenizers, the windows size for assistant/target model
482
+ self.assistant_lookbehind = kwargs.pop("assistant_lookbehind", 10)
483
+ self.target_lookbehind = kwargs.pop("target_lookbehind", 10)
484
+
485
+ # Performances
486
+ self.compile_config = kwargs.pop("compile_config", CompileConfig())
487
+
488
+ # Wild card
489
+ self.generation_kwargs = kwargs.pop("generation_kwargs", {})
490
+
491
+ # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
492
+ # interface.
493
+ self._from_model_config = kwargs.pop("_from_model_config", False)
494
+ self._commit_hash = kwargs.pop("_commit_hash", None)
495
+ self.transformers_version = kwargs.pop("transformers_version", __version__)
496
+
497
+ # Additional attributes without default values
498
+ if not self._from_model_config:
499
+ # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a
500
+ # model's default configuration file
501
+ for key, value in kwargs.items():
502
+ try:
503
+ setattr(self, key, value)
504
+ except AttributeError as err:
505
+ logger.error(f"Can't set {key} with value {value} for {self}")
506
+ raise err
507
+
508
+ # Validate the values of the attributes
509
+ self.validate(is_init=True)
510
+
511
+ def __hash__(self):
512
+ return hash(self.to_json_string(ignore_metadata=True))
513
+
514
+ def __eq__(self, other):
515
+ if not isinstance(other, GenerationConfig):
516
+ return False
517
+
518
+ self_without_metadata = self.to_json_string(use_diff=False, ignore_metadata=True)
519
+ other_without_metadata = other.to_json_string(use_diff=False, ignore_metadata=True)
520
+ return self_without_metadata == other_without_metadata
521
+
522
+ def __repr__(self):
523
+ return f"{self.__class__.__name__} {self.to_json_string(ignore_metadata=True)}"
524
+
525
+ def get_generation_mode(self, assistant_model: Optional["PreTrainedModel"] = None) -> GenerationMode:
526
+ """
527
+ Returns the generation mode triggered by the [`GenerationConfig`] instance.
528
+
529
+ Arg:
530
+ assistant_model (`PreTrainedModel`, *optional*):
531
+ The assistant model to be used for assisted generation. If set, the generation mode will be
532
+ assisted generation.
533
+
534
+ Returns:
535
+ `GenerationMode`: The generation mode triggered by the instance.
536
+ """
537
+ # TODO joao: find out a way of not depending on external fields (e.g. `assistant_model`), then make this a
538
+ # property and part of the `__repr__`
539
+ if self.constraints is not None or self.force_words_ids is not None:
540
+ generation_mode = GenerationMode.CONSTRAINED_BEAM_SEARCH
541
+ elif self.num_beams == 1:
542
+ if self.do_sample is False:
543
+ if (
544
+ self.top_k is not None
545
+ and self.top_k > 1
546
+ and self.penalty_alpha is not None
547
+ and self.penalty_alpha > 0
548
+ ):
549
+ generation_mode = GenerationMode.CONTRASTIVE_SEARCH
550
+ else:
551
+ generation_mode = GenerationMode.GREEDY_SEARCH
552
+ else:
553
+ generation_mode = GenerationMode.SAMPLE
554
+ else:
555
+ if self.num_beam_groups > 1:
556
+ generation_mode = GenerationMode.GROUP_BEAM_SEARCH
557
+ elif self.do_sample is True:
558
+ generation_mode = GenerationMode.BEAM_SAMPLE
559
+ else:
560
+ generation_mode = GenerationMode.BEAM_SEARCH
561
+
562
+ # Assisted generation may extend some generation modes
563
+ if (
564
+ assistant_model is not None
565
+ or self.prompt_lookup_num_tokens is not None
566
+ or self.assistant_early_exit is not None
567
+ ):
568
+ if generation_mode in ("greedy_search", "sample"):
569
+ generation_mode = GenerationMode.ASSISTED_GENERATION
570
+ else:
571
+ raise ValueError(
572
+ "You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
573
+ "is only supported with Greedy Search and Sample."
574
+ )
575
+
576
+ # DoLa generation may extend some generation modes
577
+ if self.dola_layers is not None:
578
+ if generation_mode in ("greedy_search", "sample"):
579
+ generation_mode = GenerationMode.DOLA_GENERATION
580
+ else:
581
+ raise ValueError(
582
+ "You've set `dola_layers`, which triggers DoLa generate. Currently, DoLa generate "
583
+ "is only supported with Greedy Search and Sample."
584
+ )
585
+ return generation_mode
586
+
587
+ def validate(self, is_init=False):
588
+ """
589
+ Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
590
+ of parameterization that can be detected as incorrect from the configuration instance alone.
591
+
592
+ Note that some parameters not validated here are best validated at generate runtime, as they may depend on
593
+ other inputs and/or the model, such as parameters related to the generation length.
594
+
595
+ Arg:
596
+ is_init (`bool`, *optional*, defaults to `False`):
597
+ Whether the validation is performed during the initialization of the instance.
598
+ """
599
+
600
+ # Validation of individual attributes
601
+ if self.early_stopping not in {True, False, "never"}:
602
+ raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.")
603
+ if self.max_new_tokens is not None and self.max_new_tokens <= 0:
604
+ raise ValueError(f"`max_new_tokens` must be greater than 0, but is {self.max_new_tokens}.")
605
+ if self.pad_token_id is not None and self.pad_token_id < 0:
606
+ warnings.warn(
607
+ f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch "
608
+ "generating, if there is padding. Please set `pad_token_id` explicitly as "
609
+ "`model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation"
610
+ )
611
+
612
+ # Validation of attribute relations:
613
+ fix_location = ""
614
+ if is_init:
615
+ fix_location = (
616
+ " This was detected when initializing the generation config instance, which means the corresponding "
617
+ "file may hold incorrect parameterization and should be fixed."
618
+ )
619
+
620
+ # 1. detect sampling-only parameterization when not in sampling mode
621
+ if self.do_sample is False:
622
+ greedy_wrong_parameter_msg = (
623
+ "`do_sample` is set to `False`. However, `{flag_name}` is set to `{flag_value}` -- this flag is only "
624
+ "used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
625
+ + fix_location
626
+ )
627
+ if self.temperature is not None and self.temperature != 1.0:
628
+ warnings.warn(
629
+ greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature),
630
+ UserWarning,
631
+ )
632
+ if self.top_p is not None and self.top_p != 1.0:
633
+ warnings.warn(
634
+ greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p),
635
+ UserWarning,
636
+ )
637
+ if self.min_p is not None:
638
+ warnings.warn(
639
+ greedy_wrong_parameter_msg.format(flag_name="min_p", flag_value=self.min_p),
640
+ UserWarning,
641
+ )
642
+ if self.typical_p is not None and self.typical_p != 1.0:
643
+ warnings.warn(
644
+ greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p),
645
+ UserWarning,
646
+ )
647
+ if (
648
+ self.top_k is not None and self.top_k != 50 and self.penalty_alpha is None
649
+ ): # contrastive search uses top_k
650
+ warnings.warn(
651
+ greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k),
652
+ UserWarning,
653
+ )
654
+ if self.epsilon_cutoff is not None and self.epsilon_cutoff != 0.0:
655
+ warnings.warn(
656
+ greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff),
657
+ UserWarning,
658
+ )
659
+ if self.eta_cutoff is not None and self.eta_cutoff != 0.0:
660
+ warnings.warn(
661
+ greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff),
662
+ UserWarning,
663
+ )
664
+
665
+ # 2. detect beam-only parameterization when not in beam mode
666
+ if self.num_beams is None:
667
+ warnings.warn("`num_beams` is set to None - defaulting to 1.", UserWarning)
668
+ self.num_beams = 1
669
+
670
+ if self.num_beams == 1:
671
+ single_beam_wrong_parameter_msg = (
672
+ "`num_beams` is set to 1. However, `{flag_name}` is set to `{flag_value}` -- this flag is only used "
673
+ "in beam-based generation modes. You should set `num_beams>1` or unset `{flag_name}`." + fix_location
674
+ )
675
+ if self.early_stopping is not False:
676
+ warnings.warn(
677
+ single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping),
678
+ UserWarning,
679
+ )
680
+ if self.num_beam_groups is not None and self.num_beam_groups != 1:
681
+ warnings.warn(
682
+ single_beam_wrong_parameter_msg.format(
683
+ flag_name="num_beam_groups", flag_value=self.num_beam_groups
684
+ ),
685
+ UserWarning,
686
+ )
687
+ if self.diversity_penalty is not None and self.diversity_penalty != 0.0:
688
+ warnings.warn(
689
+ single_beam_wrong_parameter_msg.format(
690
+ flag_name="diversity_penalty", flag_value=self.diversity_penalty
691
+ ),
692
+ UserWarning,
693
+ )
694
+ if self.length_penalty is not None and self.length_penalty != 1.0:
695
+ warnings.warn(
696
+ single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty),
697
+ UserWarning,
698
+ )
699
+ if self.constraints is not None:
700
+ warnings.warn(
701
+ single_beam_wrong_parameter_msg.format(flag_name="constraints", flag_value=self.constraints),
702
+ UserWarning,
703
+ )
704
+
705
+ # 3. detect incorrect paramaterization specific to advanced beam modes
706
+ else:
707
+ # constrained beam search
708
+ if self.constraints is not None or self.force_words_ids is not None:
709
+ constrained_wrong_parameter_msg = (
710
+ "one of `constraints`, `force_words_ids` is not `None`, triggering constrained beam search. However, "
711
+ "`{flag_name}` is set to `{flag_value}`, which is incompatible with this generation mode. Set "
712
+ "`constraints` and `force_words_ids` to `None` or unset `{flag_name}` to continue." + fix_location
713
+ )
714
+ if self.do_sample is True:
715
+ raise ValueError(
716
+ constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
717
+ )
718
+ if self.num_beam_groups is not None and self.num_beam_groups != 1:
719
+ raise ValueError(
720
+ constrained_wrong_parameter_msg.format(
721
+ flag_name="num_beam_groups", flag_value=self.num_beam_groups
722
+ )
723
+ )
724
+ # group beam search
725
+ if self.diversity_penalty != 0.0 or self.num_beam_groups != 1:
726
+ group_error_prefix = (
727
+ "`diversity_penalty` is not 0.0 or `num_beam_groups` is not 1, triggering group beam search. In "
728
+ "this generation mode, "
729
+ )
730
+ if self.do_sample is True:
731
+ raise ValueError(group_error_prefix + "`do_sample` must be set to `False`")
732
+ if self.num_beams % self.num_beam_groups != 0:
733
+ raise ValueError(group_error_prefix + "`num_beams` should be divisible by `num_beam_groups`")
734
+ if self.diversity_penalty == 0.0:
735
+ raise ValueError(
736
+ group_error_prefix
737
+ + "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
738
+ )
739
+ # DoLa generation
740
+ if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
741
+ warnings.warn(
742
+ "`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of "
743
+ f"{self.repetition_penalty}, which could induce unwanted repetition. The recommended value for "
744
+ "DoLa decoding is `repetition_penalty>=1.2`.",
745
+ UserWarning,
746
+ )
747
+
748
+ # 4. check `num_return_sequences`
749
+ if self.num_return_sequences != 1:
750
+ if self.num_beams == 1:
751
+ if self.do_sample is False:
752
+ raise ValueError(
753
+ "Greedy methods without beam search do not support `num_return_sequences` different than 1 "
754
+ f"(got {self.num_return_sequences})."
755
+ )
756
+ elif self.num_return_sequences > self.num_beams:
757
+ raise ValueError(
758
+ f"`num_return_sequences` ({self.num_return_sequences}) has to be smaller or equal to `num_beams` "
759
+ f"({self.num_beams})."
760
+ )
761
+
762
+ # 5. check cache-related arguments
763
+ if self.cache_implementation is not None and self.cache_implementation not in ALL_CACHE_IMPLEMENTATIONS:
764
+ raise ValueError(
765
+ f"Invalid `cache_implementation` ({self.cache_implementation}). Choose one of: "
766
+ f"{ALL_CACHE_IMPLEMENTATIONS}"
767
+ )
768
+ if self.cache_config is not None:
769
+ cache_class = NEEDS_CACHE_CONFIG.get(self.cache_implementation)
770
+ if cache_class is None:
771
+ raise ValueError(
772
+ "You provided a `cache_config` but the cache implementation you are using "
773
+ f"({self.cache_implementation}) does not require any config. Make sure to use the "
774
+ "correct cache implementation matching your cache config."
775
+ )
776
+ if not isinstance(self.cache_config, cache_class):
777
+ self.cache_config = cache_class.from_dict(self.cache_config)
778
+ self.cache_config.validate()
779
+ if self.use_cache is False:
780
+ # In this case, all cache-related arguments should be unset. However, since `use_cache=False` is often used
781
+ # passed to `generate` directly to hot-fix cache issues, let's raise a warning instead of an error
782
+ # (otherwise a user might need to overwrite several parameters).
783
+ no_cache_warning = (
784
+ "You have set `use_cache` to `False`, but {cache_arg} is set to {cache_arg_value}. {cache_arg} will "
785
+ "have no effect."
786
+ )
787
+ for arg_name in ("cache_implementation", "cache_config", "return_legacy_cache"):
788
+ if getattr(self, arg_name) is not None:
789
+ logger.warning_once(
790
+ no_cache_warning.format(cache_arg=arg_name, cache_arg_value=getattr(self, arg_name)),
791
+ UserWarning,
792
+ )
793
+
794
+ # 6. check watermarking arguments
795
+ if self.watermarking_config is not None:
796
+ if not (
797
+ isinstance(self.watermarking_config, WatermarkingConfig)
798
+ or isinstance(self.watermarking_config, SynthIDTextWatermarkingConfig)
799
+ ):
800
+ warnings.warn(
801
+ "`watermarking_config` as a dict is deprecated. Please construct `watermarking_config` object with "
802
+ "`WatermarkingConfig` or `SynthIDTextWatermarkingConfig` class.",
803
+ FutureWarning,
804
+ )
805
+ self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
806
+ self.watermarking_config.validate()
807
+
808
+ # 7. performances arguments
809
+ if not isinstance(self.compile_config, CompileConfig):
810
+ raise ValueError(
811
+ f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an instance of `CompileConfig`."
812
+ )
813
+
814
+ # 8. other incorrect combinations
815
+ if self.return_dict_in_generate is not True:
816
+ for extra_output_flag in self.extra_output_flags:
817
+ if getattr(self, extra_output_flag) is True:
818
+ warnings.warn(
819
+ f"`return_dict_in_generate` is NOT set to `True`, but `{extra_output_flag}` is. When "
820
+ f"`return_dict_in_generate` is not `True`, `{extra_output_flag}` is ignored.",
821
+ UserWarning,
822
+ )
823
+
824
+ # 8. check common issue: passing `generate` arguments inside the generation config
825
+ generate_arguments = (
826
+ "logits_processor",
827
+ "stopping_criteria",
828
+ "prefix_allowed_tokens_fn",
829
+ "synced_gpus",
830
+ "assistant_model",
831
+ "streamer",
832
+ "negative_prompt_ids",
833
+ "negative_prompt_attention_mask",
834
+ )
835
+ for arg in generate_arguments:
836
+ if hasattr(self, arg):
837
+ raise ValueError(
838
+ f"Argument `{arg}` is not a valid argument of `GenerationConfig`. It should be passed to "
839
+ "`generate()` (or a pipeline) directly."
840
+ )
841
+
842
+ def save_pretrained(
843
+ self,
844
+ save_directory: Union[str, os.PathLike],
845
+ config_file_name: Optional[Union[str, os.PathLike]] = None,
846
+ push_to_hub: bool = False,
847
+ **kwargs,
848
+ ):
849
+ r"""
850
+ Save a generation configuration object to the directory `save_directory`, so that it can be re-loaded using the
851
+ [`~GenerationConfig.from_pretrained`] class method.
852
+
853
+ Args:
854
+ save_directory (`str` or `os.PathLike`):
855
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
856
+ config_file_name (`str` or `os.PathLike`, *optional*, defaults to `"generation_config.json"`):
857
+ Name of the generation configuration JSON file to be saved in `save_directory`.
858
+ push_to_hub (`bool`, *optional*, defaults to `False`):
859
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
860
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
861
+ namespace).
862
+ kwargs (`Dict[str, Any]`, *optional*):
863
+ Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
864
+ """
865
+
866
+ # At save time, validate the instance -- if any warning/exception is thrown, we refuse to save the instance.
867
+ # This strictness is enforced to prevent bad configurations from being saved and re-used.
868
+ try:
869
+ with warnings.catch_warnings(record=True) as caught_warnings:
870
+ self.validate()
871
+ if len(caught_warnings) > 0:
872
+ raise ValueError(str([w.message for w in caught_warnings]))
873
+ except ValueError as exc:
874
+ raise ValueError(
875
+ "The generation config instance is invalid -- `.validate()` throws warnings and/or exceptions. "
876
+ "Fix these issues to save the configuration.\n\nThrown during validation:\n" + str(exc)
877
+ )
878
+
879
+ use_auth_token = kwargs.pop("use_auth_token", None)
880
+
881
+ if use_auth_token is not None:
882
+ warnings.warn(
883
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. "
884
+ "Please use `token` instead.",
885
+ FutureWarning,
886
+ )
887
+ if kwargs.get("token", None) is not None:
888
+ raise ValueError(
889
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
890
+ )
891
+ kwargs["token"] = use_auth_token
892
+
893
+ config_file_name = config_file_name if config_file_name is not None else GENERATION_CONFIG_NAME
894
+
895
+ if os.path.isfile(save_directory):
896
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
897
+
898
+ os.makedirs(save_directory, exist_ok=True)
899
+
900
+ if push_to_hub:
901
+ commit_message = kwargs.pop("commit_message", None)
902
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
903
+ repo_id = self._create_repo(repo_id, **kwargs)
904
+ files_timestamps = self._get_files_timestamps(save_directory)
905
+
906
+ output_config_file = os.path.join(save_directory, config_file_name)
907
+
908
+ self.to_json_file(output_config_file, use_diff=True)
909
+ logger.info(f"Configuration saved in {output_config_file}")
910
+
911
+ if push_to_hub:
912
+ self._upload_modified_files(
913
+ save_directory,
914
+ repo_id,
915
+ files_timestamps,
916
+ commit_message=commit_message,
917
+ token=kwargs.get("token"),
918
+ )
919
+
920
+ @classmethod
921
+ def from_pretrained(
922
+ cls,
923
+ pretrained_model_name: Union[str, os.PathLike],
924
+ config_file_name: Optional[Union[str, os.PathLike]] = None,
925
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
926
+ force_download: bool = False,
927
+ local_files_only: bool = False,
928
+ token: Optional[Union[str, bool]] = None,
929
+ revision: str = "main",
930
+ **kwargs,
931
+ ) -> "GenerationConfig":
932
+ r"""
933
+ Instantiate a [`GenerationConfig`] from a generation configuration file.
934
+
935
+ Args:
936
+ pretrained_model_name (`str` or `os.PathLike`):
937
+ This can be either:
938
+
939
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
940
+ huggingface.co.
941
+ - a path to a *directory* containing a configuration file saved using the
942
+ [`~GenerationConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
943
+ config_file_name (`str` or `os.PathLike`, *optional*, defaults to `"generation_config.json"`):
944
+ Name of the generation configuration JSON file to be loaded from `pretrained_model_name`.
945
+ cache_dir (`str` or `os.PathLike`, *optional*):
946
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
947
+ standard cache should not be used.
948
+ force_download (`bool`, *optional*, defaults to `False`):
949
+ Whether or not to force to (re-)download the configuration files and override the cached versions if
950
+ they exist.
951
+ resume_download:
952
+ Deprecated and ignored. All downloads are now resumed by default when possible.
953
+ Will be removed in v5 of Transformers.
954
+ proxies (`Dict[str, str]`, *optional*):
955
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
956
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
957
+ token (`str` or `bool`, *optional*):
958
+ The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
959
+ the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
960
+ revision (`str`, *optional*, defaults to `"main"`):
961
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
962
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
963
+ identifier allowed by git.
964
+
965
+ <Tip>
966
+
967
+ To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
968
+
969
+ </Tip>
970
+
971
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
972
+ If `False`, then this function returns just the final configuration object.
973
+
974
+ If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
975
+ dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
976
+ part of `kwargs` which has not been used to update `config` and is otherwise ignored.
977
+ subfolder (`str`, *optional*, defaults to `""`):
978
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
979
+ specify the folder name here.
980
+ kwargs (`Dict[str, Any]`, *optional*):
981
+ The values in kwargs of any keys which are configuration attributes will be used to override the loaded
982
+ values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
983
+ by the `return_unused_kwargs` keyword parameter.
984
+
985
+ Returns:
986
+ [`GenerationConfig`]: The configuration object instantiated from this pretrained model.
987
+
988
+ Examples:
989
+
990
+ ```python
991
+ >>> from transformers import GenerationConfig
992
+
993
+ >>> # Download configuration from huggingface.co and cache.
994
+ >>> generation_config = GenerationConfig.from_pretrained("openai-community/gpt2")
995
+
996
+ >>> # E.g. config was saved using *save_pretrained('./test/saved_model/')*
997
+ >>> generation_config.save_pretrained("./test/saved_model/")
998
+ >>> generation_config = GenerationConfig.from_pretrained("./test/saved_model/")
999
+
1000
+ >>> # You can also specify configuration names to your generation configuration file
1001
+ >>> generation_config.save_pretrained("./test/saved_model/", config_file_name="my_configuration.json")
1002
+ >>> generation_config = GenerationConfig.from_pretrained("./test/saved_model/", "my_configuration.json")
1003
+
1004
+ >>> # If you'd like to try a minor variation to an existing configuration, you can also pass generation
1005
+ >>> # arguments to `.from_pretrained()`. Be mindful that typos and unused arguments will be ignored
1006
+ >>> generation_config, unused_kwargs = GenerationConfig.from_pretrained(
1007
+ ... "openai-community/gpt2", top_k=1, foo=False, do_sample=True, return_unused_kwargs=True
1008
+ ... )
1009
+ >>> generation_config.top_k
1010
+ 1
1011
+
1012
+ >>> unused_kwargs
1013
+ {'foo': False}
1014
+ ```"""
1015
+ config_file_name = config_file_name if config_file_name is not None else GENERATION_CONFIG_NAME
1016
+
1017
+ resume_download = kwargs.pop("resume_download", None)
1018
+ proxies = kwargs.pop("proxies", None)
1019
+ use_auth_token = kwargs.pop("use_auth_token", None)
1020
+ subfolder = kwargs.pop("subfolder", "")
1021
+ from_pipeline = kwargs.pop("_from_pipeline", None)
1022
+ from_auto_class = kwargs.pop("_from_auto", False)
1023
+ commit_hash = kwargs.pop("_commit_hash", None)
1024
+
1025
+ if use_auth_token is not None:
1026
+ warnings.warn(
1027
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
1028
+ FutureWarning,
1029
+ )
1030
+ if token is not None:
1031
+ raise ValueError(
1032
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
1033
+ )
1034
+ token = use_auth_token
1035
+
1036
+ user_agent = {"file_type": "config", "from_auto_class": from_auto_class}
1037
+ if from_pipeline is not None:
1038
+ user_agent["using_pipeline"] = from_pipeline
1039
+
1040
+ config_path = os.path.join(pretrained_model_name, config_file_name)
1041
+ config_path = str(config_path)
1042
+
1043
+ is_local = os.path.exists(config_path)
1044
+ if os.path.isfile(os.path.join(subfolder, config_path)):
1045
+ # Special case when config_path is a local file
1046
+ resolved_config_file = config_path
1047
+ is_local = True
1048
+ elif is_remote_url(config_path):
1049
+ configuration_file = config_path
1050
+ resolved_config_file = download_url(config_path)
1051
+ else:
1052
+ configuration_file = config_file_name
1053
+ try:
1054
+ # Load from local folder or from cache or download from model Hub and cache
1055
+ resolved_config_file = cached_file(
1056
+ pretrained_model_name,
1057
+ configuration_file,
1058
+ cache_dir=cache_dir,
1059
+ force_download=force_download,
1060
+ proxies=proxies,
1061
+ resume_download=resume_download,
1062
+ local_files_only=local_files_only,
1063
+ token=token,
1064
+ user_agent=user_agent,
1065
+ revision=revision,
1066
+ subfolder=subfolder,
1067
+ _commit_hash=commit_hash,
1068
+ )
1069
+ commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
1070
+ except EnvironmentError:
1071
+ # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
1072
+ # the original exception.
1073
+ raise
1074
+ except Exception:
1075
+ # For any other exception, we throw a generic error.
1076
+ raise EnvironmentError(
1077
+ f"Can't load the configuration of '{pretrained_model_name}'. If you were trying to load it"
1078
+ " from 'https://huggingface.co/models', make sure you don't have a local directory with the same"
1079
+ f" name. Otherwise, make sure '{pretrained_model_name}' is the correct path to a directory"
1080
+ f" containing a {configuration_file} file"
1081
+ )
1082
+
1083
+ try:
1084
+ # Load config dict
1085
+ config_dict = cls._dict_from_json_file(resolved_config_file)
1086
+ config_dict["_commit_hash"] = commit_hash
1087
+ except (json.JSONDecodeError, UnicodeDecodeError):
1088
+ raise EnvironmentError(
1089
+ f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
1090
+ )
1091
+
1092
+ if is_local:
1093
+ logger.info(f"loading configuration file {resolved_config_file}")
1094
+ else:
1095
+ logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
1096
+
1097
+ if kwargs.get("return_unused_kwargs") is True:
1098
+ config, unused_kwargs = cls.from_dict(config_dict, **kwargs)
1099
+ config._original_object_hash = hash(config) # Hash to detect whether the instance was modified
1100
+ return config, unused_kwargs
1101
+ else:
1102
+ config = cls.from_dict(config_dict, **kwargs)
1103
+ config._original_object_hash = hash(config) # Hash to detect whether the instance was modified
1104
+ return config
1105
+
1106
+ @classmethod
1107
+ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
1108
+ with open(json_file, "r", encoding="utf-8") as reader:
1109
+ text = reader.read()
1110
+ return json.loads(text)
1111
+
1112
+ @classmethod
1113
+ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "GenerationConfig":
1114
+ """
1115
+ Instantiates a [`GenerationConfig`] from a Python dictionary of parameters.
1116
+
1117
+ Args:
1118
+ config_dict (`Dict[str, Any]`):
1119
+ Dictionary that will be used to instantiate the configuration object.
1120
+ kwargs (`Dict[str, Any]`):
1121
+ Additional parameters from which to initialize the configuration object.
1122
+
1123
+ Returns:
1124
+ [`GenerationConfig`]: The configuration object instantiated from those parameters.
1125
+ """
1126
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
1127
+ # Those arguments may be passed along for our internal telemetry.
1128
+ # We remove them so they don't appear in `return_unused_kwargs`.
1129
+ kwargs.pop("_from_auto", None)
1130
+ kwargs.pop("_from_pipeline", None)
1131
+ # The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
1132
+ if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
1133
+ kwargs["_commit_hash"] = config_dict["_commit_hash"]
1134
+
1135
+ # The line below allows model-specific config to be loaded as well through kwargs, with safety checks.
1136
+ # See https://github.com/huggingface/transformers/pull/21269
1137
+ config = cls(**{**config_dict, **kwargs})
1138
+ unused_kwargs = config.update(**kwargs)
1139
+
1140
+ logger.info(f"Generate config {config}")
1141
+ if return_unused_kwargs:
1142
+ return config, unused_kwargs
1143
+ else:
1144
+ return config
1145
+
1146
+ def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
1147
+ """
1148
+ Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None,
1149
+ converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
1150
+ string, which can then be stored in the json format.
1151
+ """
1152
+ if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
1153
+ d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
1154
+ for value in d.values():
1155
+ if isinstance(value, dict):
1156
+ self.dict_torch_dtype_to_str(value)
1157
+
1158
+ def to_diff_dict(self) -> Dict[str, Any]:
1159
+ """
1160
+ Removes all attributes from config which correspond to the default config attributes for better readability and
1161
+ serializes to a Python dictionary.
1162
+
1163
+ Returns:
1164
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
1165
+ """
1166
+ config_dict = self.to_dict()
1167
+
1168
+ # get the default config dict
1169
+ default_config_dict = GenerationConfig().to_dict()
1170
+
1171
+ serializable_config_dict = {}
1172
+
1173
+ # only serialize values that differ from the default config
1174
+ for key, value in config_dict.items():
1175
+ if key not in default_config_dict or key == "transformers_version" or value != default_config_dict[key]:
1176
+ serializable_config_dict[key] = value
1177
+
1178
+ self.dict_torch_dtype_to_str(serializable_config_dict)
1179
+ return serializable_config_dict
1180
+
1181
+ def to_dict(self) -> Dict[str, Any]:
1182
+ """
1183
+ Serializes this instance to a Python dictionary.
1184
+
1185
+ Returns:
1186
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
1187
+ """
1188
+ output = copy.deepcopy(self.__dict__)
1189
+
1190
+ # Fields to ignore at serialization time
1191
+ if "_commit_hash" in output:
1192
+ del output["_commit_hash"]
1193
+ if "_original_object_hash" in output:
1194
+ del output["_original_object_hash"]
1195
+ if "compile_config" in output:
1196
+ del output["compile_config"]
1197
+
1198
+ # Transformers version when serializing this file
1199
+ output["transformers_version"] = __version__
1200
+
1201
+ self.dict_torch_dtype_to_str(output)
1202
+ return output
1203
+
1204
+ def to_json_string(self, use_diff: bool = True, ignore_metadata: bool = False) -> str:
1205
+ """
1206
+ Serializes this instance to a JSON string.
1207
+
1208
+ Args:
1209
+ use_diff (`bool`, *optional*, defaults to `True`):
1210
+ If set to `True`, only the difference between the config instance and the default `GenerationConfig()`
1211
+ is serialized to JSON string.
1212
+ ignore_metadata (`bool`, *optional*, defaults to `False`):
1213
+ Whether to ignore the metadata fields present in the instance
1214
+
1215
+ Returns:
1216
+ `str`: String containing all the attributes that make up this configuration instance in JSON format.
1217
+ """
1218
+ if use_diff is True:
1219
+ config_dict = self.to_diff_dict()
1220
+ else:
1221
+ config_dict = self.to_dict()
1222
+
1223
+ if ignore_metadata:
1224
+ for metadata_field in METADATA_FIELDS:
1225
+ config_dict.pop(metadata_field, None)
1226
+
1227
+ def convert_keys_to_string(obj):
1228
+ if isinstance(obj, dict):
1229
+ return {str(key): convert_keys_to_string(value) for key, value in obj.items()}
1230
+ elif isinstance(obj, list):
1231
+ return [convert_keys_to_string(item) for item in obj]
1232
+ else:
1233
+ return obj
1234
+
1235
+ def convert_dataclass_to_dict(obj):
1236
+ if isinstance(obj, dict):
1237
+ return {key: convert_dataclass_to_dict(value) for key, value in obj.items()}
1238
+ elif is_dataclass(obj):
1239
+ return obj.to_dict()
1240
+ else:
1241
+ return obj
1242
+
1243
+ config_dict = convert_keys_to_string(config_dict)
1244
+ config_dict = convert_dataclass_to_dict(config_dict)
1245
+
1246
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
1247
+
1248
+ def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
1249
+ """
1250
+ Save this instance to a JSON file.
1251
+
1252
+ Args:
1253
+ json_file_path (`str` or `os.PathLike`):
1254
+ Path to the JSON file in which this configuration instance's parameters will be saved.
1255
+ use_diff (`bool`, *optional*, defaults to `True`):
1256
+ If set to `True`, only the difference between the config instance and the default `GenerationConfig()`
1257
+ is serialized to JSON file.
1258
+ """
1259
+ with open(json_file_path, "w", encoding="utf-8") as writer:
1260
+ writer.write(self.to_json_string(use_diff=use_diff))
1261
+
1262
+ @classmethod
1263
+ def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig":
1264
+ """
1265
+ Instantiates a [`GenerationConfig`] from a [`PretrainedConfig`]. This function is useful to convert legacy
1266
+ [`PretrainedConfig`] objects, which may contain generation parameters, into a stand-alone [`GenerationConfig`].
1267
+
1268
+ Args:
1269
+ model_config (`PretrainedConfig`):
1270
+ The model config that will be used to instantiate the generation config.
1271
+
1272
+ Returns:
1273
+ [`GenerationConfig`]: The configuration object instantiated from those parameters.
1274
+ """
1275
+ config_dict = model_config.to_dict()
1276
+ config_dict.pop("_from_model_config", None)
1277
+
1278
+ # Removes all `None` from the model config dict -- this lets the generation config defaults to take hold
1279
+ config_dict = {key: value for key, value in config_dict.items() if value is not None}
1280
+
1281
+ generation_config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)
1282
+
1283
+ # Special case: some models have generation attributes set in the decoder. Use them if still unset in the
1284
+ # generation config (which in turn is defined from the outer attributes of model config).
1285
+ decoder_config = model_config.get_text_config(decoder=True)
1286
+ if decoder_config is not model_config:
1287
+ default_generation_config = GenerationConfig()
1288
+ decoder_config_dict = decoder_config.to_dict()
1289
+ for attr in generation_config.to_dict().keys():
1290
+ is_unset = getattr(generation_config, attr) == getattr(default_generation_config, attr)
1291
+ if attr in decoder_config_dict and is_unset:
1292
+ setattr(generation_config, attr, decoder_config_dict[attr])
1293
+
1294
+ # If any `output_...` flag is set to `True`, we ensure `return_dict_in_generate` is set to `True`.
1295
+ if generation_config.return_dict_in_generate is False:
1296
+ if any(
1297
+ getattr(generation_config, extra_output_flag, False)
1298
+ for extra_output_flag in generation_config.extra_output_flags
1299
+ ):
1300
+ generation_config.return_dict_in_generate = True
1301
+
1302
+ # Hash to detect whether the instance was modified
1303
+ generation_config._original_object_hash = hash(generation_config)
1304
+ return generation_config
1305
+
1306
+ def update(self, **kwargs):
1307
+ """
1308
+ Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
1309
+ returning all the unused kwargs.
1310
+
1311
+ Args:
1312
+ kwargs (`Dict[str, Any]`):
1313
+ Dictionary of attributes to tentatively update this class.
1314
+
1315
+ Returns:
1316
+ `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
1317
+ """
1318
+ to_remove = []
1319
+ for key, value in kwargs.items():
1320
+ if hasattr(self, key):
1321
+ setattr(self, key, value)
1322
+ to_remove.append(key)
1323
+
1324
+ # Confirm that the updated instance is still valid
1325
+ self.validate()
1326
+
1327
+ # Remove all the attributes that were updated, without modifying the input dict
1328
+ unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
1329
+ return unused_kwargs
1330
+
1331
+
1332
+ @dataclass
1333
+ class BaseWatermarkingConfig(ABC):
1334
+ """Generic watermarking config"""
1335
+
1336
+ @classmethod
1337
+ def from_dict(cls, config_dict, **kwargs):
1338
+ """
1339
+ Constructs a BaseWatermarkingConfig instance from a dictionary of parameters.
1340
+
1341
+ Args:
1342
+ config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
1343
+ **kwargs: Additional keyword arguments to override dictionary values.
1344
+
1345
+ Returns:
1346
+ BaseWatermarkingConfig: Instance of BaseWatermarkingConfig constructed from the dictionary.
1347
+ """
1348
+ config = cls(**config_dict)
1349
+ to_remove = []
1350
+ for key, value in kwargs.items():
1351
+ if hasattr(config, key):
1352
+ setattr(config, key, value)
1353
+ to_remove.append(key)
1354
+ for key in to_remove:
1355
+ kwargs.pop(key, None)
1356
+ return config
1357
+
1358
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
1359
+ """
1360
+ Save this instance to a JSON file.
1361
+
1362
+ Args:
1363
+ json_file_path (Union[str, os.PathLike]): Path to the JSON file in which this configuration instance's parameters will be saved.
1364
+ """
1365
+ with open(json_file_path, "w", encoding="utf-8") as writer:
1366
+ config_dict = self.to_dict()
1367
+ json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
1368
+
1369
+ writer.write(json_string)
1370
+
1371
+ def to_dict(self) -> Dict[str, Any]:
1372
+ """
1373
+ Serializes this instance to a Python dictionary.
1374
+
1375
+ Returns:
1376
+ Dict[str, Any]: Dictionary of all the attributes that make up this configuration instance.
1377
+ """
1378
+ output = copy.deepcopy(self.__dict__)
1379
+ return output
1380
+
1381
+ def __iter__(self):
1382
+ for attr, value in copy.deepcopy(self.__dict__).items():
1383
+ yield attr, value
1384
+
1385
+ def __repr__(self):
1386
+ return f"{self.__class__.__name__} {self.to_json_string()}"
1387
+
1388
+ def to_json_string(self):
1389
+ """
1390
+ Serializes this instance to a JSON formatted string.
1391
+
1392
+ Returns:
1393
+ str: JSON formatted string representing the configuration instance.
1394
+ """
1395
+ return json.dumps(self.__dict__, indent=2) + "\n"
1396
+
1397
+ def update(self, **kwargs):
1398
+ """
1399
+ Update the configuration attributes with new values.
1400
+
1401
+ Args:
1402
+ **kwargs: Keyword arguments representing configuration attributes and their new values.
1403
+ """
1404
+ for key, value in kwargs.items():
1405
+ if hasattr(self, key):
1406
+ setattr(self, key, value)
1407
+
1408
+ @abstractmethod
1409
+ def validate(self): ...
1410
+
1411
+ @abstractmethod
1412
+ def construct_processor(self, vocab_size): ...
1413
+
1414
+
1415
+ @dataclass
1416
+ class WatermarkingConfig(BaseWatermarkingConfig):
1417
+ """
1418
+ Class that holds arguments for watermark generation and should be passed into `GenerationConfig` during `generate`.
1419
+ See [this paper](https://arxiv.org/abs/2306.04634) for more details on the arguments.
1420
+
1421
+ Accepts the following keys:
1422
+ - greenlist_ratio (`float`):
1423
+ Used for watermarking. The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25.
1424
+ - bias (`float`):
1425
+ Used with watermarking. The bias added to the selected "green" tokens' logits. Defaults to 2.0.
1426
+ - hashing_key (`int`):
1427
+ Hashing key used for watermarking. Defaults to 15485863 (the millionth prime).
1428
+ - seeding_scheme (`str`):
1429
+ Algorithm to use for watermarking. Accepts values:
1430
+ - "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from the paper)
1431
+ - "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from the paper)
1432
+ The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".
1433
+ - context_width(`int`):
1434
+ The context length of previous tokens to use in seeding. Higher context length makes watermarking more robust.
1435
+ """
1436
+
1437
+ def __init__(
1438
+ self,
1439
+ greenlist_ratio: Optional[float] = 0.25,
1440
+ bias: Optional[float] = 2.0,
1441
+ hashing_key: Optional[int] = 15485863,
1442
+ seeding_scheme: Optional[str] = "lefthash",
1443
+ context_width: Optional[int] = 1,
1444
+ ):
1445
+ self.greenlist_ratio = greenlist_ratio
1446
+ self.bias = bias
1447
+ self.hashing_key = hashing_key
1448
+ self.seeding_scheme = seeding_scheme
1449
+ self.context_width = context_width
1450
+
1451
+ def validate(self):
1452
+ watermark_missing_arg_msg = (
1453
+ "Some of the keys in `watermarking_config` are defined incorrectly. `{key}` should be {correct_value}` "
1454
+ "but found {found_value}"
1455
+ )
1456
+ if self.seeding_scheme not in ["selfhash", "lefthash"]:
1457
+ raise ValueError(
1458
+ watermark_missing_arg_msg.format(
1459
+ key="seeding_scheme",
1460
+ correct_value="[`selfhash`, `lefthash`]",
1461
+ found_value=self.seeding_scheme,
1462
+ ),
1463
+ )
1464
+ if not 0.0 <= self.greenlist_ratio <= 1.0:
1465
+ raise ValueError(
1466
+ watermark_missing_arg_msg.format(
1467
+ key="greenlist_ratio",
1468
+ correct_value="in range between 0.0 and 1.0",
1469
+ found_value=self.seeding_scheme,
1470
+ ),
1471
+ )
1472
+ if not self.context_width >= 1:
1473
+ raise ValueError(
1474
+ watermark_missing_arg_msg.format(
1475
+ key="context_width",
1476
+ correct_value="a positive integer",
1477
+ found_value=self.context_width,
1478
+ ),
1479
+ )
1480
+
1481
+ def construct_processor(self, vocab_size: int, device) -> "WatermarkLogitsProcessor":
1482
+ return WatermarkLogitsProcessor(
1483
+ vocab_size=vocab_size,
1484
+ device=device,
1485
+ greenlist_ratio=self.greenlist_ratio,
1486
+ bias=self.bias,
1487
+ hashing_key=self.hashing_key,
1488
+ seeding_scheme=self.seeding_scheme,
1489
+ context_width=self.context_width,
1490
+ )
1491
+
1492
+
1493
+ @dataclass
1494
+ class SynthIDTextWatermarkingConfig(BaseWatermarkingConfig):
1495
+ """
1496
+ Class that holds arguments for watermark generation and should be passed into `GenerationConfig` during `generate`.
1497
+ See [this paper](https://www.nature.com/articles/s41586-024-08025-4) for more details on the arguments.
1498
+
1499
+ Args:
1500
+ ngram_len (`int`):
1501
+ Ngram length.
1502
+ keys (`List[int]`):
1503
+ A sequence of watermarking keys, one for each depth.
1504
+ context_history_size (`int`, *optional*, defaults to 1024):
1505
+ Size of the tensor to keep track of seen contexts.
1506
+ sampling_table_seed (`int`, *optional*, defaults to 0):
1507
+ Random seed to generate the sampling table.
1508
+ sampling_table_size (`int`, *optional*, defaults to 65536):
1509
+ Size of the sampling table.
1510
+ skip_first_ngram_calls (`bool`, *optional*, defaults to `False`):
1511
+ Whether to skip first ngram calls.
1512
+ debug_mode (`bool`, optional, *optional*, defaults to `False`):
1513
+ Logits are modified to uniform one got before watermarking modification is applied. This is to test the
1514
+ implementation.
1515
+
1516
+ Examples:
1517
+ ```python
1518
+ >>> from transformers import AutoModelForCausalLM, AutoTokenizer, SynthIDTextWatermarkingConfig
1519
+
1520
+ >>> tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b', padding_side="left")
1521
+ >>> model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b')
1522
+
1523
+ >>> # SynthID Text configuration
1524
+ >>> watermarking_config = SynthIDTextWatermarkingConfig(
1525
+ ... keys=[654, 400, 836, 123, 340, 443, 597, 160, 57],
1526
+ ... ngram_len=5,
1527
+ ... )
1528
+
1529
+ >>> # Generation with watermarking
1530
+ >>> tokenized_prompts = tokenizer(["Once upon a time, "], return_tensors="pt", padding=True)
1531
+ >>> output_sequences = model.generate(
1532
+ ... **tokenized_prompts, watermarking_config=watermarking_config, do_sample=True, max_new_tokens=10
1533
+ ... )
1534
+ >>> watermarked_text = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
1535
+ ```
1536
+ """
1537
+
1538
+ def __init__(
1539
+ self,
1540
+ ngram_len: int,
1541
+ keys: List[int],
1542
+ context_history_size: int = 1024,
1543
+ sampling_table_seed: int = 0,
1544
+ sampling_table_size: int = 2**16,
1545
+ skip_first_ngram_calls: bool = False,
1546
+ debug_mode: bool = False,
1547
+ ):
1548
+ self.ngram_len = ngram_len
1549
+ self.keys = keys
1550
+ self.sampling_table_size = sampling_table_size
1551
+ self.sampling_table_seed = sampling_table_seed
1552
+ self.context_history_size = context_history_size
1553
+ self.skip_first_ngram_calls = skip_first_ngram_calls
1554
+ self.debug_mode = debug_mode
1555
+
1556
+ def validate(self):
1557
+ watermark_missing_arg_msg = (
1558
+ "Some of the keys in `watermarking_config` are defined incorrectly. `{key}` should be {correct_value}` "
1559
+ "but found {found_value}"
1560
+ )
1561
+ if self.sampling_table_size > 2**24:
1562
+ raise ValueError(
1563
+ watermark_missing_arg_msg.format(
1564
+ key="sampling_table_size",
1565
+ correct_value="< 2**24",
1566
+ found_value=self.sampling_table_size,
1567
+ ),
1568
+ )
1569
+
1570
+ def construct_processor(self, vocab_size: int, device) -> "WatermarkLogitsProcessor":
1571
+ return SynthIDTextWatermarkLogitsProcessor(
1572
+ ngram_len=self.ngram_len,
1573
+ keys=self.keys,
1574
+ sampling_table_size=self.sampling_table_size,
1575
+ sampling_table_seed=self.sampling_table_seed,
1576
+ context_history_size=self.context_history_size,
1577
+ device=device,
1578
+ skip_first_ngram_calls=self.skip_first_ngram_calls,
1579
+ debug_mode=self.debug_mode,
1580
+ )
1581
+
1582
+
1583
+ @dataclass
1584
+ class CompileConfig(object):
1585
+ """
1586
+ Class that holds arguments relative to `torch.compile` behavior, when using automatic compilation in `generate`.
1587
+ See [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) for more details on the arguments.
1588
+
1589
+ Args:
1590
+ fullgraph (`bool`, *optional*, defaults to `True`):
1591
+ If `True`, requires that the whole forward be capturable in a single graph.
1592
+ dynamic (`bool` or `None`, *optional*):
1593
+ Whether to try to use dynamic shape graphs.
1594
+ backend (`str` or `Callable`, *optional*, defaults to `"inductor"`):
1595
+ Backend to be used.
1596
+ mode (`str`, *optional*, defaults to `"reduce-overhead"`):
1597
+ Controls balance between performance and overhead.
1598
+ options (`dict`, *optional*):
1599
+ A dictionary of options to pass to the backend.
1600
+
1601
+ Examples:
1602
+ ```python
1603
+ >>> from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig
1604
+
1605
+ >>> tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b')
1606
+ >>> model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b').cuda()
1607
+
1608
+ >>> # Automatic compile configuration, used with static cache
1609
+ >>> compile_config = CompileConfig(dynamic=True)
1610
+
1611
+ >>> # Generation with static cache and compile config
1612
+ >>> input = tokenizer.encode("Hello there, how", return_tensors="pt").cuda()
1613
+ >>> output = model.generate(
1614
+ ... input, do_sample=False, max_new_tokens=300, cache_implementation="static", compile_config=compile_config
1615
+ ... )
1616
+ >>> output_text = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
1617
+ ```
1618
+ """
1619
+
1620
+ fullgraph: bool = True
1621
+ dynamic: Optional[bool] = None
1622
+ backend: Union[str, Callable] = "inductor"
1623
+ mode: str = "reduce-overhead"
1624
+ options: Optional[dict] = None
1625
+
1626
+ def to_dict(self) -> Dict[str, Any]:
1627
+ """Serializes this instance to a Python dictionary."""
1628
+ return copy.deepcopy(self.__dict__)
.venv/lib/python3.11/site-packages/transformers/generation/flax_logits_process.py ADDED
@@ -0,0 +1,544 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+
18
+ import jax
19
+ import jax.lax as lax
20
+ import jax.numpy as jnp
21
+ from jax.experimental import sparse
22
+
23
+ from ..utils import add_start_docstrings
24
+ from ..utils.logging import get_logger
25
+
26
+
27
+ logger = get_logger(__name__)
28
+
29
+
30
+ LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
31
+ Args:
32
+ input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
33
+ Indices of input sequence tokens in the vocabulary.
34
+
35
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
36
+ [`PreTrainedTokenizer.__call__`] for details.
37
+
38
+ [What are input IDs?](../glossary#input-ids)
39
+ scores (`jnp.ndarray` of shape `(batch_size, config.vocab_size)`):
40
+ Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
41
+ search or log softmax for each vocabulary token when using beam search
42
+ kwargs (`Dict[str, Any]`, *optional*):
43
+ Additional logits processor specific kwargs.
44
+
45
+ Return:
46
+ `jnp.ndarray` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.
47
+
48
+ """
49
+
50
+
51
+ class FlaxLogitsProcessor:
52
+ """Abstract base class for all logit processors that can be applied during generation."""
53
+
54
+ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
55
+ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray) -> jnp.ndarray:
56
+ """Flax method for processing logits."""
57
+ raise NotImplementedError(
58
+ f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
59
+ )
60
+
61
+
62
+ class FlaxLogitsWarper:
63
+ """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
64
+
65
+ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
66
+ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray) -> jnp.ndarray:
67
+ """Flax method for warping logits."""
68
+ raise NotImplementedError(
69
+ f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
70
+ )
71
+
72
+
73
+ class FlaxLogitsProcessorList(list):
74
+ """
75
+ This class can be used to create a list of [`FlaxLogitsProcessor`] or [`FlaxLogitsWarper`] to subsequently process
76
+ a `scores` input tensor. This class inherits from list and adds a specific *__call__* method to apply each
77
+ [`FlaxLogitsProcessor`] or [`FlaxLogitsWarper`] to the inputs.
78
+ """
79
+
80
+ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
81
+ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int, **kwargs) -> jnp.ndarray:
82
+ for processor in self:
83
+ function_args = inspect.signature(processor.__call__).parameters
84
+ if len(function_args) > 3:
85
+ if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
86
+ raise ValueError(
87
+ f"Make sure that all the required parameters: {list(function_args.keys())} for "
88
+ f"{processor.__class__} are passed to the logits processor."
89
+ )
90
+ scores = processor(input_ids, scores, cur_len, **kwargs)
91
+ else:
92
+ scores = processor(input_ids, scores, cur_len)
93
+ return scores
94
+
95
+
96
+ class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
97
+ r"""
98
+ [`FlaxLogitsWarper`] for temperature (exponential scaling output probability distribution).
99
+
100
+ Args:
101
+ temperature (`float`):
102
+ The value used to module the logits distribution.
103
+ """
104
+
105
+ def __init__(self, temperature: float):
106
+ if not isinstance(temperature, float) or not (temperature > 0):
107
+ raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
108
+
109
+ self.temperature = temperature
110
+
111
+ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
112
+ scores = scores / self.temperature
113
+ return scores
114
+
115
+
116
+ class FlaxTopPLogitsWarper(FlaxLogitsWarper):
117
+ """
118
+ [`FlaxLogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
119
+
120
+ Args:
121
+ top_p (`float`):
122
+ If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
123
+ higher are kept for generation.
124
+ filter_value (`float`, *optional*, defaults to -inf):
125
+ All filtered values will be set to this float value.
126
+ min_tokens_to_keep (`int`, *optional*, defaults to 1):
127
+ Minimum number of tokens that cannot be filtered.
128
+ """
129
+
130
+ def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
131
+ if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
132
+ raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
133
+ if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
134
+ raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")
135
+
136
+ self.top_p = top_p
137
+ self.filter_value = filter_value
138
+ self.min_tokens_to_keep = min_tokens_to_keep
139
+
140
+ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
141
+ topk_scores, topk_indices = lax.top_k(scores, scores.shape[-1])
142
+
143
+ mask_scores = jnp.full_like(scores, self.filter_value)
144
+ cumulative_probs = jax.nn.softmax(topk_scores, axis=-1).cumsum(axis=-1)
145
+ score_mask = cumulative_probs < self.top_p
146
+
147
+ # include the token that is higher than top_p as well
148
+ score_mask = jnp.roll(score_mask, 1)
149
+ score_mask |= score_mask.at[:, 0].set(True)
150
+
151
+ # min tokens to keep
152
+ score_mask = score_mask.at[:, : self.min_tokens_to_keep].set(True)
153
+
154
+ topk_next_scores = jnp.where(score_mask, topk_scores, mask_scores)
155
+ next_scores = jax.lax.sort_key_val(topk_indices, topk_next_scores)[-1]
156
+
157
+ return next_scores
158
+
159
+
160
+ class FlaxTopKLogitsWarper(FlaxLogitsWarper):
161
+ r"""
162
+ [`FlaxLogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
163
+
164
+ Args:
165
+ top_k (`int`):
166
+ The number of highest probability vocabulary tokens to keep for top-k-filtering.
167
+ filter_value (`float`, *optional*, defaults to -inf):
168
+ All filtered values will be set to this float value.
169
+ min_tokens_to_keep (`int`, *optional*, defaults to 1):
170
+ Minimum number of tokens that cannot be filtered.
171
+ """
172
+
173
+ def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
174
+ if not isinstance(top_k, int) or top_k <= 0:
175
+ raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
176
+
177
+ self.top_k = max(top_k, min_tokens_to_keep)
178
+ self.filter_value = filter_value
179
+
180
+ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
181
+ batch_size, vocab_size = scores.shape
182
+ next_scores_flat = jnp.full(batch_size * vocab_size, self.filter_value)
183
+
184
+ topk = min(self.top_k, scores.shape[-1]) # Safety check
185
+ topk_scores, topk_indices = lax.top_k(scores, topk)
186
+ shift = jnp.broadcast_to((jnp.arange(batch_size) * vocab_size)[:, None], (batch_size, topk)).flatten()
187
+ topk_scores_flat = topk_scores.flatten()
188
+ topk_indices_flat = topk_indices.flatten() + shift
189
+
190
+ next_scores_flat = next_scores_flat.at[topk_indices_flat].set(topk_scores_flat)
191
+ next_scores = next_scores_flat.reshape(batch_size, vocab_size)
192
+ return next_scores
193
+
194
+
195
+ class FlaxForcedBOSTokenLogitsProcessor(FlaxLogitsProcessor):
196
+ r"""
197
+ [`FlaxLogitsProcessor`] that enforces the specified token as the first generated token.
198
+
199
+ Args:
200
+ bos_token_id (`int`):
201
+ The id of the token to force as the first generated token.
202
+ """
203
+
204
+ def __init__(self, bos_token_id: int):
205
+ self.bos_token_id = bos_token_id
206
+
207
+ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
208
+ new_scores = jnp.full(scores.shape, -float("inf"))
209
+
210
+ apply_penalty = 1 - jnp.bool_(cur_len - 1)
211
+
212
+ scores = jnp.where(apply_penalty, new_scores.at[:, self.bos_token_id].set(0), scores)
213
+
214
+ return scores
215
+
216
+
217
+ class FlaxForcedEOSTokenLogitsProcessor(FlaxLogitsProcessor):
218
+ r"""
219
+ [`FlaxLogitsProcessor`] that enforces the specified token as the last generated token when `max_length` is reached.
220
+
221
+ Args:
222
+ max_length (`int`):
223
+ The maximum length of the sequence to be generated.
224
+ eos_token_id (`int`):
225
+ The id of the token to force as the last generated token when `max_length` is reached.
226
+ """
227
+
228
+ def __init__(self, max_length: int, eos_token_id: int):
229
+ self.max_length = max_length
230
+ self.eos_token_id = eos_token_id
231
+
232
+ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
233
+ new_scores = jnp.full(scores.shape, -float("inf"))
234
+
235
+ apply_penalty = 1 - jnp.bool_(cur_len - self.max_length + 1)
236
+
237
+ scores = jnp.where(apply_penalty, new_scores.at[:, self.eos_token_id].set(0), scores)
238
+
239
+ return scores
240
+
241
+
242
+ class FlaxMinLengthLogitsProcessor(FlaxLogitsProcessor):
243
+ r"""
244
+ [`FlaxLogitsProcessor`] enforcing a min-length by setting EOS probability to 0.
245
+
246
+ Args:
247
+ min_length (`int`):
248
+ The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
249
+ eos_token_id (`int`):
250
+ The id of the *end-of-sequence* token.
251
+ """
252
+
253
+ def __init__(self, min_length: int, eos_token_id: int):
254
+ if not isinstance(min_length, int) or min_length < 0:
255
+ raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")
256
+
257
+ if not isinstance(eos_token_id, int) or eos_token_id < 0:
258
+ raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
259
+
260
+ self.min_length = min_length
261
+ self.eos_token_id = eos_token_id
262
+
263
+ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
264
+ # create boolean flag to decide if min length penalty should be applied
265
+ apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1)
266
+
267
+ scores = jnp.where(apply_penalty, scores.at[:, self.eos_token_id].set(-float("inf")), scores)
268
+
269
+ return scores
270
+
271
+
272
+ class FlaxSuppressTokensAtBeginLogitsProcessor(FlaxLogitsProcessor):
273
+ r"""
274
+ [`FlaxLogitsProcessor`] supressing a list of tokens as soon as the `generate` function starts generating using
275
+ `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` are not sampled at the
276
+ beginning of the generation.
277
+
278
+ Args:
279
+ begin_suppress_tokens (`List[int]`):
280
+ Tokens to not sample.
281
+ begin_index (`int`):
282
+ Index where the tokens are suppressed.
283
+ """
284
+
285
+ def __init__(self, begin_suppress_tokens, begin_index):
286
+ self.begin_suppress_tokens = list(begin_suppress_tokens)
287
+ self.begin_index = begin_index
288
+
289
+ def __call__(self, input_ids, scores, cur_len: int):
290
+ apply_penalty = 1 - jnp.bool_(cur_len - self.begin_index)
291
+
292
+ scores = jnp.where(apply_penalty, scores.at[:, self.begin_suppress_tokens].set(-float("inf")), scores)
293
+
294
+ return scores
295
+
296
+
297
+ class FlaxSuppressTokensLogitsProcessor(FlaxLogitsProcessor):
298
+ r"""
299
+ [`FlaxLogitsProcessor`] suppressing a list of tokens at each decoding step. The processor will set their log probs
300
+ to be `-inf` so they are not sampled.
301
+
302
+ Args:
303
+ suppress_tokens (`list`):
304
+ Tokens to not sample.
305
+ """
306
+
307
+ def __init__(self, suppress_tokens: list):
308
+ self.suppress_tokens = list(suppress_tokens)
309
+
310
+ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
311
+ scores = scores.at[..., self.suppress_tokens].set(-float("inf"))
312
+
313
+ return scores
314
+
315
+
316
+ class FlaxForceTokensLogitsProcessor(FlaxLogitsProcessor):
317
+ r"""
318
+ [`FlaxLogitsProcessor`] that takes a list of pairs of integers which indicates a mapping from generation indices to
319
+ token indices that will be forced before sampling. The processor will set their log probs to 0 and all other tokens
320
+ to `-inf` so that they are sampled at their corresponding index.
321
+
322
+ Args:
323
+ force_token_map (`list`):
324
+ Map giving token ids and indices where they will be forced to be sampled.
325
+ """
326
+
327
+ def __init__(self, force_token_map):
328
+ force_token_map = dict(force_token_map)
329
+ # Converts the dictionary of format {index: token} containing the tokens to be forced to an array, where the
330
+ # index of the array corresponds to the index of the token to be forced, for XLA compatibility.
331
+ # Indexes without forced tokens will have a negative value.
332
+ force_token_array = jnp.ones((max(force_token_map.keys()) + 1), dtype=jnp.int32) * -1
333
+ for index, token in force_token_map.items():
334
+ if token is not None:
335
+ force_token_array = force_token_array.at[index].set(token)
336
+ self.force_token_array = jnp.int32(force_token_array)
337
+
338
+ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
339
+ def _force_token(generation_idx):
340
+ batch_size = scores.shape[0]
341
+ current_token = self.force_token_array[generation_idx]
342
+
343
+ new_scores = jnp.ones_like(scores, dtype=scores.dtype) * -float("inf")
344
+ updates = jnp.zeros((batch_size, 1), dtype=scores.dtype)
345
+ new_scores = lax.dynamic_update_slice(new_scores, updates, (0, current_token))
346
+ return new_scores
347
+
348
+ scores = lax.cond(
349
+ cur_len >= self.force_token_array.shape[0],
350
+ # If the current length is geq than the length of force_token_array, the processor does nothing.
351
+ lambda: scores,
352
+ # Otherwise, it may force a certain token.
353
+ lambda: lax.cond(
354
+ self.force_token_array[cur_len] >= 0,
355
+ # Only valid (positive) tokens are forced
356
+ lambda: _force_token(cur_len),
357
+ # Otherwise, the processor does nothing.
358
+ lambda: scores,
359
+ ),
360
+ )
361
+ return scores
362
+
363
+
364
+ class FlaxWhisperTimeStampLogitsProcessor(FlaxLogitsProcessor):
365
+ r"""
366
+ Whisper specific Processor. This processor can be used to force a list of tokens. The processor will set their log
367
+ probs to `inf` so that they are sampled at their corresponding index.
368
+
369
+ Args:
370
+ generate_config (`GenerateConfig`):
371
+ The generate config used to generate the output. The following parameters are required:
372
+ eos_token_id (`int`, *optional*, defaults to 50257):
373
+ The id of the *end-of-sequence* token.
374
+ no_timestamps_token_id (`int`, *optional*, defaults to 50363):
375
+ The id of the `"<|notimestamps|>"` token.
376
+ max_initial_timestamp_index (`int`, *optional*, defaults to 1):
377
+ Used to set the maximum value of the initial timestamp. This is used to prevent the model from
378
+ predicting timestamps that are too far in the future.
379
+ """
380
+
381
+ def __init__(self, generate_config, model_config, decoder_input_length):
382
+ self.eos_token_id = generate_config.eos_token_id
383
+ self.no_timestamps_token_id = generate_config.no_timestamps_token_id
384
+ self.timestamp_begin = generate_config.no_timestamps_token_id + 1
385
+
386
+ self.begin_index = decoder_input_length + 1
387
+
388
+ if generate_config.is_multilingual:
389
+ # room for language token and task token
390
+ self.begin_index += 2
391
+ if hasattr(generate_config, "max_initial_timestamp_index"):
392
+ self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index
393
+ else:
394
+ self.max_initial_timestamp_index = model_config.vocab_size
395
+ if self.max_initial_timestamp_index is None:
396
+ self.max_initial_timestamp_index = model_config.vocab_size
397
+
398
+ def __call__(self, input_ids, scores, cur_len):
399
+ # suppress <|notimestamps|> which is handled by without_timestamps
400
+ scores = scores.at[:, self.no_timestamps_token_id].set(-float("inf"))
401
+
402
+ def handle_pairs(input_ids_k, scores_k):
403
+ last_was_timestamp = jnp.where((cur_len - self.begin_index) >= 1, True, False)
404
+ last_was_timestamp = jnp.where(
405
+ input_ids_k[cur_len - 1] >= self.timestamp_begin,
406
+ True and last_was_timestamp,
407
+ False,
408
+ )
409
+
410
+ penultimate_was_timestamp = jnp.where((cur_len - self.begin_index) < 2, True, False)
411
+ penultimate_was_timestamp = jnp.where(
412
+ input_ids_k[cur_len - 2] >= self.timestamp_begin,
413
+ True,
414
+ penultimate_was_timestamp,
415
+ )
416
+
417
+ return jnp.where(
418
+ last_was_timestamp,
419
+ jnp.where(
420
+ penultimate_was_timestamp > 0,
421
+ scores_k.at[self.timestamp_begin :].set(-float("inf")),
422
+ scores_k.at[: self.eos_token_id].set(-float("inf")),
423
+ ),
424
+ scores_k,
425
+ )
426
+
427
+ scores = jax.vmap(handle_pairs)(input_ids, scores)
428
+
429
+ apply_max_initial_timestamp = jnp.where(cur_len == self.begin_index, True, False)
430
+ apply_max_initial_timestamp = jnp.where(
431
+ self.max_initial_timestamp_index is not None,
432
+ True and apply_max_initial_timestamp,
433
+ False,
434
+ )
435
+
436
+ last_allowed = self.timestamp_begin + self.max_initial_timestamp_index
437
+
438
+ scores = jnp.where(
439
+ apply_max_initial_timestamp,
440
+ scores.at[:, last_allowed + 1 :].set(-float("inf")),
441
+ scores,
442
+ )
443
+
444
+ # if sum of probability over timestamps is above any other token, sample timestamp
445
+ logprobs = jax.nn.log_softmax(scores, axis=-1)
446
+
447
+ def handle_cumulative_probs(logprobs_k, scores_k):
448
+ timestamp_logprob = jax.nn.logsumexp(logprobs_k[self.timestamp_begin :], axis=-1)
449
+ max_text_token_logprob = jnp.max(logprobs_k[: self.timestamp_begin])
450
+ return jnp.where(
451
+ timestamp_logprob > max_text_token_logprob,
452
+ scores_k.at[: self.timestamp_begin].set(-float("inf")),
453
+ scores_k,
454
+ )
455
+
456
+ scores = jax.vmap(handle_cumulative_probs)(logprobs, scores)
457
+
458
+ return scores
459
+
460
+
461
+ class FlaxNoRepeatNGramLogitsProcessor(FlaxLogitsProcessor):
462
+ r"""
463
+ [`FlaxLogitsProcessor`] that enforces no repetition of n-grams. See
464
+ [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).
465
+
466
+ Args:
467
+ ngram_size (`int`):
468
+ All ngrams of size `ngram_size` can only occur once.
469
+ """
470
+
471
+ def __init__(self, ngram_size: int):
472
+ if not isinstance(ngram_size, int) or ngram_size <= 0:
473
+ raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
474
+ self.ngram_size = ngram_size
475
+
476
+ def get_previous_ngrams(self, input_ids: jnp.ndarray, vocab_size: int, cur_len: int):
477
+ """
478
+ get a matrix of size (batch_size,) + (vocab_size,)*n (for n-grams) that
479
+ represent the n-grams that occurred previously.
480
+ The BCOO representation allow to store only the few non-zero entries, instead of the full (huge) matrix
481
+ """
482
+ batch_size, seq_len = input_ids.shape
483
+ # number of n-grams in the whole sequence
484
+ seq_ngrams = seq_len - (self.ngram_size - 1)
485
+ # number of n-grams in the currently generated sequence
486
+ cur_ngrams = cur_len - (self.ngram_size - 1)
487
+
488
+ def body_fun(i, val):
489
+ b = i % batch_size
490
+ pos = i // batch_size
491
+ return val.at[i].set(
492
+ jnp.array(
493
+ [
494
+ b,
495
+ ]
496
+ + [jnp.array(input_ids)[b, pos + j] for j in range(self.ngram_size)]
497
+ )
498
+ )
499
+
500
+ shape = (batch_size * seq_ngrams, self.ngram_size + 1)
501
+ all_update_indices = jax.lax.fori_loop(
502
+ 0, batch_size * cur_ngrams, body_fun, jnp.zeros(shape, dtype=input_ids.dtype)
503
+ )
504
+
505
+ # ignore the n-grams not yet generated
506
+ data = (jnp.arange(batch_size * seq_ngrams) < batch_size * cur_ngrams).astype("float32")
507
+
508
+ return sparse.BCOO((data, all_update_indices), shape=(batch_size,) + (vocab_size,) * self.ngram_size)
509
+
510
+ def get_banned_tokens_mask(self, latest_tokens: jnp.ndarray, previous_ngrams) -> jnp.ndarray:
511
+ """
512
+ Determines which tokens must be banned given latest tokens and the previously seen
513
+ ngrams.
514
+ """
515
+
516
+ @sparse.sparsify
517
+ @jax.vmap
518
+ def inner_fn(latest_tokens, previous_ngrams):
519
+ return previous_ngrams[tuple(latest_tokens)]
520
+
521
+ return sparse.bcoo_todense(inner_fn(latest_tokens, previous_ngrams))
522
+
523
+ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
524
+ def true_fn():
525
+ _, vocab_size = scores.shape
526
+ # store the previously seen n-grams
527
+ previous_ngrams = self.get_previous_ngrams(input_ids, vocab_size, cur_len)
528
+
529
+ # get the n-1 last tokens that prefix the n-gram being generated
530
+ latest_tokens = jnp.zeros((input_ids.shape[0], self.ngram_size - 1), dtype=input_ids.dtype)
531
+ latest_tokens = jax.lax.dynamic_update_slice(
532
+ latest_tokens,
533
+ jax.lax.dynamic_slice(
534
+ input_ids, (0, cur_len - (self.ngram_size - 1)), (input_ids.shape[0], (self.ngram_size - 1))
535
+ ),
536
+ (0, 0),
537
+ )
538
+
539
+ # compute the banned tokens, ie all the tokens that when added to the latest tokens lead to a n-gram that was previously generated
540
+ banned_tokens_indices_mask = self.get_banned_tokens_mask(latest_tokens, previous_ngrams).astype("bool")
541
+ return jnp.where(banned_tokens_indices_mask, -float("inf"), scores)
542
+
543
+ output = jax.lax.cond((cur_len >= self.ngram_size - 1), true_fn, lambda: scores)
544
+ return output
.venv/lib/python3.11/site-packages/transformers/generation/flax_utils.py ADDED
@@ -0,0 +1,1027 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Google AI Flax Team Authors, and The HuggingFace Inc. team.
3
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ import copy
19
+ import inspect
20
+ import warnings
21
+ from functools import partial
22
+ from typing import Any, Dict, Optional, Union
23
+
24
+ import flax
25
+ import jax
26
+ import jax.numpy as jnp
27
+ import numpy as np
28
+ from jax import lax
29
+
30
+ from ..models.auto import (
31
+ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
32
+ FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
33
+ FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
34
+ )
35
+ from ..utils import ModelOutput, logging
36
+ from .configuration_utils import GenerationConfig
37
+ from .flax_logits_process import (
38
+ FlaxForcedBOSTokenLogitsProcessor,
39
+ FlaxForcedEOSTokenLogitsProcessor,
40
+ FlaxForceTokensLogitsProcessor,
41
+ FlaxLogitsProcessorList,
42
+ FlaxMinLengthLogitsProcessor,
43
+ FlaxNoRepeatNGramLogitsProcessor,
44
+ FlaxSuppressTokensAtBeginLogitsProcessor,
45
+ FlaxSuppressTokensLogitsProcessor,
46
+ FlaxTemperatureLogitsWarper,
47
+ FlaxTopKLogitsWarper,
48
+ FlaxTopPLogitsWarper,
49
+ )
50
+
51
+
52
+ logger = logging.get_logger(__name__)
53
+
54
+
55
+ @flax.struct.dataclass
56
+ class FlaxGreedySearchOutput(ModelOutput):
57
+ """
58
+ Flax Base class for outputs of decoder-only generation models using greedy search.
59
+
60
+
61
+ Args:
62
+ sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
63
+ The generated sequences.
64
+ """
65
+
66
+ sequences: jnp.ndarray = None
67
+
68
+
69
+ @flax.struct.dataclass
70
+ class FlaxSampleOutput(ModelOutput):
71
+ """
72
+ Flax Base class for outputs of decoder-only generation models using sampling.
73
+
74
+
75
+ Args:
76
+ sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
77
+ The generated sequences.
78
+ """
79
+
80
+ sequences: jnp.ndarray = None
81
+
82
+
83
+ @flax.struct.dataclass
84
+ class FlaxBeamSearchOutput(ModelOutput):
85
+ """
86
+ Flax Base class for outputs of decoder-only generation models using greedy search.
87
+
88
+
89
+ Args:
90
+ sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
91
+ The generated sequences.
92
+ scores (`jnp.ndarray` of shape `(batch_size,)`):
93
+ The scores (log probabilities) of the generated sequences.
94
+ """
95
+
96
+ sequences: jnp.ndarray = None
97
+ scores: jnp.ndarray = None
98
+
99
+
100
+ @flax.struct.dataclass
101
+ class GreedyState:
102
+ cur_len: jnp.ndarray
103
+ sequences: jnp.ndarray
104
+ running_token: jnp.ndarray
105
+ is_sent_finished: jnp.ndarray
106
+ model_kwargs: Dict[str, jnp.ndarray]
107
+
108
+
109
+ @flax.struct.dataclass
110
+ class SampleState:
111
+ cur_len: jnp.ndarray
112
+ sequences: jnp.ndarray
113
+ running_token: jnp.ndarray
114
+ is_sent_finished: jnp.ndarray
115
+ prng_key: jnp.ndarray
116
+ model_kwargs: Dict[str, jnp.ndarray]
117
+
118
+
119
+ @flax.struct.dataclass
120
+ class BeamSearchState:
121
+ cur_len: jnp.ndarray
122
+ running_sequences: jnp.ndarray
123
+ running_scores: jnp.ndarray
124
+ sequences: jnp.ndarray
125
+ scores: jnp.ndarray
126
+ is_sent_finished: jnp.ndarray
127
+ model_kwargs: Dict[str, jnp.ndarray]
128
+
129
+
130
+ class FlaxGenerationMixin:
131
+ """
132
+ A class containing all functions for auto-regressive text generation, to be used as a mixin in
133
+ [`FlaxPreTrainedModel`].
134
+
135
+ The class exposes [`~generation.FlaxGenerationMixin.generate`], which can be used for:
136
+ - *greedy decoding* by calling [`~generation.FlaxGenerationMixin._greedy_search`] if `num_beams=1` and
137
+ `do_sample=False`
138
+ - *multinomial sampling* by calling [`~generation.FlaxGenerationMixin._sample`] if `num_beams=1` and
139
+ `do_sample=True`
140
+ - *beam-search decoding* by calling [`~generation.FlaxGenerationMixin._beam_search`] if `num_beams>1` and
141
+ `do_sample=False`
142
+
143
+ You do not need to call any of the above methods directly. Pass custom parameter values to 'generate' instead. To
144
+ learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
145
+ """
146
+
147
+ def prepare_inputs_for_generation(self, *args, **kwargs):
148
+ raise NotImplementedError(
149
+ "A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`."
150
+ )
151
+
152
+ @staticmethod
153
+ def _run_loop_in_debug(cond_fn, body_fn, init_state):
154
+ """
155
+ Run generation in untraced mode. This should only be used for debugging purposes.
156
+ """
157
+ state = init_state
158
+ while cond_fn(state):
159
+ state = body_fn(state)
160
+ return state
161
+
162
+ def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, params, model_kwargs):
163
+ encoder_kwargs = {
164
+ argument: value
165
+ for argument, value in model_kwargs.items()
166
+ if not (argument.startswith("decoder_") or argument.startswith("cross_attn"))
167
+ }
168
+ model_kwargs["encoder_outputs"] = self.encode(input_ids, params=params, return_dict=True, **encoder_kwargs)
169
+ return model_kwargs
170
+
171
+ def _prepare_decoder_input_ids_for_generation(
172
+ self,
173
+ batch_size: int,
174
+ decoder_start_token_id: int = None,
175
+ bos_token_id: int = None,
176
+ model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
177
+ ) -> jnp.ndarray:
178
+ if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
179
+ # Only use this arg if not None, otherwise just remove from model_kwargs
180
+ decoder_input_ids = model_kwargs.pop("decoder_input_ids")
181
+ if decoder_input_ids is not None:
182
+ return decoder_input_ids
183
+ decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
184
+ return jnp.array(decoder_start_token_id, dtype="i4").reshape(1, -1).repeat(batch_size, axis=0)
185
+
186
+ def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
187
+ # retrieve decoder_start_token_id for encoder-decoder models
188
+ # fall back to bos_token_id if necessary
189
+ decoder_start_token_id = (
190
+ decoder_start_token_id
191
+ if decoder_start_token_id is not None
192
+ else self.generation_config.decoder_start_token_id
193
+ )
194
+ bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
195
+ if decoder_start_token_id is not None:
196
+ return decoder_start_token_id
197
+ elif (
198
+ hasattr(self.config, "decoder")
199
+ and hasattr(self.config.decoder, "decoder_start_token_id")
200
+ and self.config.decoder.decoder_start_token_id is not None
201
+ ):
202
+ return self.config.decoder.decoder_start_token_id
203
+ elif bos_token_id is not None:
204
+ return bos_token_id
205
+ elif (
206
+ hasattr(self.config, "decoder")
207
+ and hasattr(self.config.decoder, "bos_token_id")
208
+ and self.config.decoder.bos_token_id is not None
209
+ ):
210
+ return self.config.decoder.bos_token_id
211
+ raise ValueError(
212
+ "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
213
+ )
214
+
215
+ @staticmethod
216
+ def _expand_to_num_beams(tensor, num_beams):
217
+ return jnp.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:])
218
+
219
+ def _adapt_logits_for_beam_search(self, logits):
220
+ """
221
+ This function can be overwritten in the specific modeling_flax_<model-name>.py classes to allow for custom beam
222
+ search behavior. Note that the only model that overwrites this method is [`~transformes.FlaxMarianMTModel`].
223
+ """
224
+ return logits
225
+
226
+ def _validate_model_class(self):
227
+ """
228
+ Confirms that the model class is compatible with generation. If not, raises an exception that points to the
229
+ right class to use.
230
+ """
231
+ if not self.can_generate():
232
+ generate_compatible_mappings = [
233
+ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
234
+ FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
235
+ FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
236
+ ]
237
+ generate_compatible_classes = set()
238
+ for model_mapping in generate_compatible_mappings:
239
+ supported_models = model_mapping.get(type(self.config), default=None)
240
+ if supported_models is not None:
241
+ generate_compatible_classes.add(supported_models.__name__)
242
+ exception_message = (
243
+ f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
244
+ "it doesn't have a language model head."
245
+ )
246
+ if generate_compatible_classes:
247
+ exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
248
+ raise TypeError(exception_message)
249
+
250
+ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
251
+ """Validates model kwargs for generation. Generate argument typos will also be caught here."""
252
+ unused_model_args = []
253
+ model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
254
+ # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
255
+ # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
256
+ if "kwargs" in model_args or "model_kwargs" in model_args:
257
+ model_args |= set(inspect.signature(self.__call__).parameters)
258
+ for key, value in model_kwargs.items():
259
+ if value is not None and key not in model_args:
260
+ unused_model_args.append(key)
261
+
262
+ if unused_model_args:
263
+ raise ValueError(
264
+ f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
265
+ " generate arguments will also show up in this list)"
266
+ )
267
+
268
+ def generate(
269
+ self,
270
+ input_ids: jnp.ndarray,
271
+ generation_config: Optional[GenerationConfig] = None,
272
+ prng_key: Optional[jnp.ndarray] = None,
273
+ trace: bool = True,
274
+ params: Optional[Dict[str, jnp.ndarray]] = None,
275
+ logits_processor: Optional[FlaxLogitsProcessorList] = None,
276
+ **kwargs,
277
+ ):
278
+ r"""
279
+ Generates sequences of token ids for models with a language modeling head.
280
+
281
+ Parameters:
282
+ input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
283
+ The sequence used as a prompt for the generation.
284
+ generation_config (`~generation.GenerationConfig`, *optional*):
285
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
286
+ passed to generate matching the attributes of `generation_config` will override them. If
287
+ `generation_config` is not provided, the default will be used, which had the following loading
288
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
289
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
290
+ default values, whose documentation should be checked to parameterize generation.
291
+ trace (`bool`, *optional*, defaults to `True`):
292
+ Whether to trace generation. Setting `trace=False` should only be used for debugging and will lead to a
293
+ considerably slower runtime.
294
+ params (`Dict[str, jnp.ndarray]`, *optional*):
295
+ Optionally the model parameters can be passed. Can be useful for parallelized generation.
296
+ logits_processor (`FlaxLogitsProcessorList `, *optional*):
297
+ Custom logits processors that complement the default logits processors built from arguments and
298
+ generation config. If a logit processor is passed that is already created with the arguments or a
299
+ generation config an error is thrown. This feature is intended for advanced users.
300
+ kwargs (`Dict[str, Any]`, *optional*):
301
+ Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
302
+ forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
303
+ specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
304
+
305
+ Return:
306
+ [`~utils.ModelOutput`].
307
+
308
+ """
309
+ # Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
310
+ self._validate_model_class()
311
+
312
+ # priority: `generation_config` argument > `model.generation_config` (the default generation config)
313
+ if generation_config is None:
314
+ # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
315
+ # two conditions must be met
316
+ # 1) the generation config must have been created from the model config (`_from_model_config` field);
317
+ # 2) the generation config must have seen no modification since its creation (the hash is the same).
318
+ if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(
319
+ self.generation_config
320
+ ):
321
+ new_generation_config = GenerationConfig.from_model_config(self.config)
322
+ if new_generation_config != self.generation_config:
323
+ warnings.warn(
324
+ "You have modified the pretrained model configuration to control generation. This is a"
325
+ " deprecated strategy to control generation and will be removed soon, in a future version."
326
+ " Please use and modify the model generation configuration (see"
327
+ " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
328
+ )
329
+ self.generation_config = new_generation_config
330
+ generation_config = self.generation_config
331
+
332
+ generation_config = copy.deepcopy(generation_config)
333
+ model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
334
+ self._validate_model_kwargs(model_kwargs.copy())
335
+
336
+ logits_processor = logits_processor if logits_processor is not None else FlaxLogitsProcessorList()
337
+
338
+ # set init values
339
+ prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
340
+
341
+ if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
342
+ if model_kwargs.get("attention_mask") is None:
343
+ logger.warning(
344
+ "The attention mask and the pad token id were not set. As a consequence, you may observe "
345
+ "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
346
+ )
347
+ eos_token_id = generation_config.eos_token_id
348
+ if isinstance(eos_token_id, list):
349
+ eos_token_id = eos_token_id[0]
350
+ generation_config.pad_token_id = eos_token_id
351
+
352
+ if generation_config.decoder_start_token_id is None and self.config.is_encoder_decoder:
353
+ raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
354
+
355
+ # decoder-only models should use left-padding for generation (can't be checked with `trace=True`)
356
+ if not self.config.is_encoder_decoder and not trace:
357
+ if (
358
+ generation_config.pad_token_id is not None
359
+ and jnp.sum(input_ids[:, -1] == generation_config.pad_token_id) > 0
360
+ ):
361
+ logger.warning(
362
+ "A decoder-only architecture is being used, but right-padding was detected! For correct "
363
+ "generation results, please set `padding_side='left'` when initializing the tokenizer."
364
+ )
365
+
366
+ batch_size = input_ids.shape[0]
367
+
368
+ if self.config.is_encoder_decoder:
369
+ # add encoder_outputs to model_kwargs
370
+ if model_kwargs.get("encoder_outputs") is None:
371
+ model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs)
372
+ # prepare decoder_input_ids for generation
373
+ input_ids = self._prepare_decoder_input_ids_for_generation(
374
+ batch_size,
375
+ decoder_start_token_id=generation_config.decoder_start_token_id,
376
+ bos_token_id=generation_config.bos_token_id,
377
+ model_kwargs=model_kwargs,
378
+ )
379
+
380
+ # Prepare `max_length` depending on other stopping criteria.
381
+ input_ids_seq_length = input_ids.shape[-1]
382
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
383
+ if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
384
+ # 20 is the default max_length of the generation config
385
+ warnings.warn(
386
+ f"Using the model-agnostic default `max_length` (={generation_config.max_length}) "
387
+ "to control the generation length. recommend setting `max_new_tokens` to control the maximum length of the generation.",
388
+ UserWarning,
389
+ )
390
+ elif generation_config.max_new_tokens is not None:
391
+ if not has_default_max_length and generation_config.max_length is not None:
392
+ logger.warning(
393
+ f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
394
+ f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
395
+ "Please refer to the documentation for more information. "
396
+ "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
397
+ )
398
+ generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
399
+ else: # by default let's always generate 10 new tokens
400
+ if generation_config.max_length == GenerationConfig().max_length:
401
+ generation_config.max_length = generation_config.max_length + input_ids_seq_length
402
+ max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
403
+ if max_position_embeddings is not None:
404
+ generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
405
+
406
+ if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
407
+ raise ValueError(
408
+ f"Unfeasable length constraints: the minimum length ({generation_config.min_length}) is larger than"
409
+ f" the maximum length ({generation_config.max_length})"
410
+ )
411
+ if input_ids_seq_length >= generation_config.max_length:
412
+ input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
413
+ logger.warning(
414
+ f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
415
+ f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
416
+ " increasing`max_new_tokens`."
417
+ )
418
+
419
+ logits_processor = self._get_logits_processor(
420
+ generation_config=generation_config,
421
+ input_ids_seq_length=input_ids_seq_length,
422
+ logits_processor=logits_processor,
423
+ )
424
+
425
+ if not generation_config.do_sample and generation_config.num_beams == 1:
426
+ return self._greedy_search(
427
+ input_ids,
428
+ generation_config.max_length,
429
+ generation_config.pad_token_id,
430
+ generation_config.eos_token_id,
431
+ logits_processor=logits_processor,
432
+ trace=trace,
433
+ params=params,
434
+ model_kwargs=model_kwargs,
435
+ )
436
+ elif generation_config.do_sample and generation_config.num_beams == 1:
437
+ logits_warper = self._get_logits_warper(generation_config=generation_config)
438
+ return self._sample(
439
+ input_ids,
440
+ generation_config.max_length,
441
+ generation_config.pad_token_id,
442
+ generation_config.eos_token_id,
443
+ prng_key,
444
+ logits_warper=logits_warper,
445
+ logits_processor=logits_processor,
446
+ trace=trace,
447
+ params=params,
448
+ model_kwargs=model_kwargs,
449
+ )
450
+ elif not generation_config.do_sample and generation_config.num_beams > 1:
451
+ # broadcast input_ids & encoder_outputs
452
+ input_ids = self._expand_to_num_beams(input_ids, num_beams=generation_config.num_beams)
453
+
454
+ if "encoder_outputs" in model_kwargs:
455
+ model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
456
+ model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=generation_config.num_beams
457
+ )
458
+
459
+ for kwarg in ["attention_mask", "decoder_attention_mask"]:
460
+ if kwarg in model_kwargs:
461
+ model_kwargs[kwarg] = self._expand_to_num_beams(
462
+ model_kwargs[kwarg], num_beams=generation_config.num_beams
463
+ )
464
+
465
+ return self._beam_search(
466
+ input_ids,
467
+ generation_config.max_length,
468
+ generation_config.pad_token_id,
469
+ generation_config.eos_token_id,
470
+ length_penalty=generation_config.length_penalty,
471
+ early_stopping=generation_config.early_stopping,
472
+ logits_processor=logits_processor,
473
+ trace=trace,
474
+ params=params,
475
+ num_return_sequences=generation_config.num_return_sequences,
476
+ model_kwargs=model_kwargs,
477
+ )
478
+ else:
479
+ raise NotImplementedError("`Beam sampling is currently not implemented.")
480
+
481
+ def _get_logits_warper(self, generation_config: GenerationConfig) -> FlaxLogitsProcessorList:
482
+ """
483
+ This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsWarper`]
484
+ instances used for multinomial sampling.
485
+ """
486
+ warpers = FlaxLogitsProcessorList()
487
+
488
+ if generation_config.temperature is not None and generation_config.temperature != 1.0:
489
+ warpers.append(FlaxTemperatureLogitsWarper(generation_config.temperature))
490
+ if generation_config.top_k is not None and generation_config.top_k != 0:
491
+ warpers.append(FlaxTopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=1))
492
+ if generation_config.top_p is not None and generation_config.top_p < 1.0:
493
+ warpers.append(FlaxTopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=1))
494
+
495
+ return warpers
496
+
497
+ def _get_logits_processor(
498
+ self,
499
+ generation_config: GenerationConfig,
500
+ input_ids_seq_length: int,
501
+ logits_processor: Optional[FlaxLogitsProcessorList],
502
+ ) -> FlaxLogitsProcessorList:
503
+ """
504
+ This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsProcessor`]
505
+ instances used to modify the scores of the language model head.
506
+ """
507
+ processors = FlaxLogitsProcessorList()
508
+
509
+ if (
510
+ generation_config.min_length is not None
511
+ and generation_config.eos_token_id is not None
512
+ and generation_config.min_length > -1
513
+ ):
514
+ processors.append(
515
+ FlaxMinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id)
516
+ )
517
+ if generation_config.forced_bos_token_id is not None:
518
+ processors.append(FlaxForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id))
519
+ if generation_config.forced_eos_token_id is not None:
520
+ processors.append(
521
+ FlaxForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id)
522
+ )
523
+ if generation_config.suppress_tokens is not None:
524
+ processors.append(FlaxSuppressTokensLogitsProcessor(generation_config.suppress_tokens))
525
+ if generation_config.begin_suppress_tokens is not None:
526
+ begin_index = input_ids_seq_length
527
+ begin_index = (
528
+ begin_index
529
+ if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
530
+ else begin_index + 1
531
+ )
532
+ if generation_config.forced_decoder_ids is not None and len(generation_config.forced_decoder_ids) > 0:
533
+ # generation starts after the last token that is forced
534
+ begin_index += generation_config.forced_decoder_ids[-1][0]
535
+ processors.append(
536
+ FlaxSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
537
+ )
538
+ if generation_config.forced_decoder_ids is not None:
539
+ forced_decoder_ids = [
540
+ [input_ids_seq_length + i[0] - 1, i[1]] for i in generation_config.forced_decoder_ids
541
+ ]
542
+ processors.append(FlaxForceTokensLogitsProcessor(forced_decoder_ids))
543
+ if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0:
544
+ processors.append(FlaxNoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size))
545
+ processors = self._merge_criteria_processor_list(processors, logits_processor)
546
+
547
+ return processors
548
+
549
+ def _merge_criteria_processor_list(
550
+ self,
551
+ default_list: FlaxLogitsProcessorList,
552
+ custom_list: FlaxLogitsProcessorList,
553
+ ) -> FlaxLogitsProcessorList:
554
+ if len(custom_list) == 0:
555
+ return default_list
556
+ for default in default_list:
557
+ for custom in custom_list:
558
+ if type(custom) is type(default):
559
+ object_type = "logits processor"
560
+ raise ValueError(
561
+ f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
562
+ f" `generate`, but it has already been created with the values {default}. {default} has been"
563
+ " created by passing the corresponding arguments to generate or by the model's config default"
564
+ f" values. If you just want to change the default values of {object_type} consider passing"
565
+ f" them as arguments to `generate` instead of using a custom {object_type}."
566
+ )
567
+ default_list.extend(custom_list)
568
+ return default_list
569
+
570
+ def _greedy_search(
571
+ self,
572
+ input_ids: None,
573
+ max_length: Optional[int] = None,
574
+ pad_token_id: Optional[int] = None,
575
+ eos_token_id: Optional[int] = None,
576
+ logits_processor: Optional[FlaxLogitsProcessorList] = None,
577
+ trace: bool = True,
578
+ params: Optional[Dict[str, jnp.ndarray]] = None,
579
+ model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
580
+ ):
581
+ # init values
582
+ max_length = max_length if max_length is not None else self.generation_config.max_length
583
+ pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
584
+ eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
585
+
586
+ batch_size, cur_len = input_ids.shape
587
+
588
+ eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
589
+ pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
590
+ cur_len = jnp.array(cur_len)
591
+
592
+ # per batch-item holding current token in loop.
593
+ sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
594
+ sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))
595
+
596
+ # per batch-item state bit indicating if sentence has finished.
597
+ is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
598
+
599
+ # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
600
+ # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
601
+ model = self.decode if self.config.is_encoder_decoder else self
602
+ # initialize model specific kwargs
603
+ model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
604
+
605
+ # initialize state
606
+ state = GreedyState(
607
+ cur_len=cur_len,
608
+ sequences=sequences,
609
+ running_token=input_ids,
610
+ is_sent_finished=is_sent_finished,
611
+ model_kwargs=model_kwargs,
612
+ )
613
+
614
+ def greedy_search_cond_fn(state):
615
+ """state termination condition fn."""
616
+ has_reached_max_length = state.cur_len == max_length
617
+ all_sequence_finished = jnp.all(state.is_sent_finished)
618
+ finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)
619
+ return ~finish_generation
620
+
621
+ def greedy_search_body_fn(state):
622
+ """state update fn."""
623
+ model_outputs = model(state.running_token, params=params, **state.model_kwargs)
624
+ logits = model_outputs.logits[:, -1]
625
+
626
+ # apply min_length, ...
627
+ logits = logits_processor(state.sequences, logits, state.cur_len)
628
+
629
+ next_token = jnp.argmax(logits, axis=-1)
630
+
631
+ next_token = next_token * ~state.is_sent_finished + pad_token_id * state.is_sent_finished
632
+ next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
633
+ next_token = next_token[:, None]
634
+
635
+ next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
636
+ next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
637
+ return GreedyState(
638
+ cur_len=state.cur_len + 1,
639
+ sequences=next_sequences,
640
+ running_token=next_token,
641
+ is_sent_finished=next_is_sent_finished,
642
+ model_kwargs=next_model_kwargs,
643
+ )
644
+
645
+ # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
646
+ if input_ids.shape[1] > 1:
647
+ state = greedy_search_body_fn(state)
648
+
649
+ if not trace:
650
+ state = self._run_loop_in_debug(greedy_search_cond_fn, greedy_search_body_fn, state)
651
+ else:
652
+ state = lax.while_loop(greedy_search_cond_fn, greedy_search_body_fn, state)
653
+
654
+ return FlaxGreedySearchOutput(sequences=state.sequences)
655
+
656
+ def _sample(
657
+ self,
658
+ input_ids: None,
659
+ max_length: Optional[int] = None,
660
+ pad_token_id: Optional[int] = None,
661
+ eos_token_id: Optional[int] = None,
662
+ prng_key: Optional[jnp.ndarray] = None,
663
+ logits_processor: Optional[FlaxLogitsProcessorList] = None,
664
+ logits_warper: Optional[FlaxLogitsProcessorList] = None,
665
+ trace: bool = True,
666
+ params: Optional[Dict[str, jnp.ndarray]] = None,
667
+ model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
668
+ ):
669
+ # init values
670
+ max_length = max_length if max_length is not None else self.generation_config.max_length
671
+ pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
672
+ eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
673
+ prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
674
+
675
+ batch_size, cur_len = input_ids.shape
676
+
677
+ eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
678
+ pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
679
+ cur_len = jnp.array(cur_len)
680
+
681
+ # per batch-item holding current token in loop.
682
+ sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
683
+ sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))
684
+
685
+ # per batch-item state bit indicating if sentence has finished.
686
+ is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
687
+
688
+ # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
689
+ # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
690
+ model = self.decode if self.config.is_encoder_decoder else self
691
+
692
+ # initialize model specific kwargs
693
+ model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
694
+
695
+ # initialize state
696
+ state = SampleState(
697
+ cur_len=cur_len,
698
+ sequences=sequences,
699
+ running_token=input_ids,
700
+ is_sent_finished=is_sent_finished,
701
+ prng_key=prng_key,
702
+ model_kwargs=model_kwargs,
703
+ )
704
+
705
+ def sample_search_cond_fn(state):
706
+ """state termination condition fn."""
707
+ has_reached_max_length = state.cur_len == max_length
708
+ all_sequence_finished = jnp.all(state.is_sent_finished)
709
+ finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)
710
+ return ~finish_generation
711
+
712
+ def sample_search_body_fn(state):
713
+ """state update fn."""
714
+ prng_key, prng_key_next = jax.random.split(state.prng_key)
715
+ model_outputs = model(state.running_token, params=params, **state.model_kwargs)
716
+
717
+ logits = model_outputs.logits[:, -1]
718
+
719
+ # apply min_length, ...
720
+ logits = logits_processor(state.sequences, logits, state.cur_len)
721
+ # apply top_p, top_k, temperature
722
+ logits = logits_warper(logits, logits, state.cur_len)
723
+
724
+ next_token = jax.random.categorical(prng_key, logits, axis=-1)
725
+
726
+ next_token = next_token * ~state.is_sent_finished + pad_token_id * state.is_sent_finished
727
+ next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
728
+ next_token = next_token[:, None]
729
+
730
+ next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
731
+ next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
732
+
733
+ return SampleState(
734
+ cur_len=state.cur_len + 1,
735
+ sequences=next_sequences,
736
+ running_token=next_token,
737
+ is_sent_finished=next_is_sent_finished,
738
+ model_kwargs=next_model_kwargs,
739
+ prng_key=prng_key_next,
740
+ )
741
+
742
+ # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
743
+ if input_ids.shape[1] > 1:
744
+ state = sample_search_body_fn(state)
745
+
746
+ if not trace:
747
+ state = self._run_loop_in_debug(sample_search_cond_fn, sample_search_body_fn, state)
748
+ else:
749
+ state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)
750
+
751
+ return FlaxSampleOutput(sequences=state.sequences)
752
+
753
+ def _beam_search(
754
+ self,
755
+ input_ids: None,
756
+ max_length: Optional[int] = None,
757
+ pad_token_id: Optional[int] = None,
758
+ eos_token_id: Optional[int] = None,
759
+ length_penalty: Optional[float] = None,
760
+ early_stopping: Optional[Union[bool, str]] = None,
761
+ logits_processor: Optional[FlaxLogitsProcessorList] = None,
762
+ trace: bool = True,
763
+ params: Optional[Dict[str, jnp.ndarray]] = None,
764
+ num_return_sequences: Optional[int] = None,
765
+ model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
766
+ ):
767
+ """
768
+ This beam search function is heavily inspired by Flax's official example:
769
+ https://github.com/google/flax/blob/main/examples/wmt/decode.py
770
+ """
771
+
772
+ def flatten_beam_dim(tensor):
773
+ """Flattens the first two dimensions of a non-scalar array."""
774
+ # ignore scalars (e.g. cache index)
775
+ if tensor.ndim == 0:
776
+ return tensor
777
+ return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])
778
+
779
+ def unflatten_beam_dim(tensor, batch_size, num_beams):
780
+ """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
781
+ # ignore scalars (e.g. cache index)
782
+ if tensor.ndim == 0:
783
+ return tensor
784
+ return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])
785
+
786
+ def gather_beams(nested, beam_indices, batch_size, new_num_beams):
787
+ """
788
+ Gathers the beam slices indexed by beam_indices into new beam array.
789
+ """
790
+ batch_indices = jnp.reshape(
791
+ jnp.arange(batch_size * new_num_beams) // new_num_beams, (batch_size, new_num_beams)
792
+ )
793
+
794
+ def gather_fn(tensor):
795
+ # ignore scalars (e.g. cache index)
796
+ if tensor.ndim == 0:
797
+ return tensor
798
+ else:
799
+ return tensor[batch_indices, beam_indices]
800
+
801
+ return jax.tree_util.tree_map(gather_fn, nested)
802
+
803
+ # init values
804
+ max_length = max_length if max_length is not None else self.generation_config.max_length
805
+ pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
806
+ eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
807
+ length_penalty = length_penalty if length_penalty is not None else self.generation_config.length_penalty
808
+ early_stopping = early_stopping if early_stopping is not None else self.generation_config.early_stopping
809
+ num_return_sequences = (
810
+ num_return_sequences if num_return_sequences is not None else self.generation_config.num_return_sequences
811
+ )
812
+
813
+ batch_size, num_beams, cur_len = input_ids.shape
814
+
815
+ eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
816
+ pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
817
+ cur_len = jnp.array(cur_len)
818
+
819
+ # record the prompt length of decoder
820
+ decoder_prompt_len = input_ids.shape[-1]
821
+
822
+ # per batch,beam-item holding current token in loop.
823
+ sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
824
+ running_sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
825
+ running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0))
826
+
827
+ # per batch,beam-item state bit indicating if sentence has finished.
828
+ is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)
829
+
830
+ # per batch,beam-item score, logprobs
831
+ running_scores = jnp.tile(jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1])
832
+ scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)
833
+
834
+ # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
835
+ # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
836
+ model = self.decode if self.config.is_encoder_decoder else self
837
+
838
+ # flatten beam dim
839
+ if "encoder_outputs" in model_kwargs:
840
+ model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
841
+ model_kwargs["encoder_outputs"]["last_hidden_state"]
842
+ )
843
+ for kwarg in ["attention_mask", "decoder_attention_mask"]:
844
+ if kwarg in model_kwargs:
845
+ model_kwargs[kwarg] = flatten_beam_dim(model_kwargs[kwarg])
846
+
847
+ # initialize model specific kwargs
848
+ model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs)
849
+
850
+ # initialize state
851
+ state = BeamSearchState(
852
+ cur_len=cur_len,
853
+ running_sequences=running_sequences,
854
+ running_scores=running_scores,
855
+ sequences=sequences,
856
+ scores=scores,
857
+ is_sent_finished=is_sent_finished,
858
+ model_kwargs=model_kwargs,
859
+ )
860
+
861
+ def beam_search_cond_fn(state):
862
+ """beam search state termination condition fn."""
863
+
864
+ # 1. is less than max length?
865
+ not_max_length_yet = state.cur_len < max_length
866
+
867
+ # 2. can the new beams still improve?
868
+ # early_stopping == False -> apply heuristic = always get the best score from `cur_len`. See the discussion
869
+ # below for more details.
870
+ # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
871
+ # early_stopping == "never" -> compute the best score from max_length or cur_len, depending on the sign of
872
+ # length_penalty. Positive length_penalty favors longer sequences, thus we use max_length there.
873
+ if early_stopping == "never" and length_penalty > 0.0:
874
+ best_running_score = state.running_scores[:, :1] / (
875
+ (max_length - decoder_prompt_len) ** length_penalty
876
+ )
877
+ else:
878
+ best_running_score = state.running_scores[:, :1] / (
879
+ (state.cur_len - decoder_prompt_len) ** length_penalty
880
+ )
881
+ worst_finished_score = jnp.where(
882
+ state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
883
+ )
884
+ improvement_still_possible = jnp.any(best_running_score > worst_finished_score)
885
+
886
+ # 3. is there still a beam that has not finished?
887
+ still_open_beam = ~(jnp.all(state.is_sent_finished) & (early_stopping is True))
888
+
889
+ return not_max_length_yet & still_open_beam & improvement_still_possible
890
+
891
+ def beam_search_body_fn(state, input_ids_length=1):
892
+ """beam search state update fn."""
893
+ # 1. Forward current tokens
894
+ # Collect the current position slice along length to feed the fast
895
+ # autoregressive decoder model. Flatten the beam dimension into batch
896
+ # dimension for feeding into the model.
897
+ # unflatten beam dimension
898
+ # Unflatten beam dimension in attention cache arrays
899
+ input_token = flatten_beam_dim(
900
+ lax.dynamic_slice(
901
+ state.running_sequences,
902
+ (0, 0, state.cur_len - input_ids_length),
903
+ (batch_size, num_beams, input_ids_length),
904
+ )
905
+ )
906
+ model_outputs = model(input_token, params=params, **state.model_kwargs)
907
+
908
+ logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams)
909
+ cache = jax.tree_util.tree_map(
910
+ lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values
911
+ )
912
+
913
+ # adapt logits for FlaxMarianMTModel
914
+ logits = self._adapt_logits_for_beam_search(logits)
915
+
916
+ # 2. Compute log probs
917
+ # get log probabilities from logits,
918
+ # process logits with processors (*e.g.* min_length, ...), and
919
+ # add new logprobs to existing running logprobs scores.
920
+ log_probs = jax.nn.log_softmax(logits)
921
+ log_probs = logits_processor(
922
+ flatten_beam_dim(state.running_sequences), flatten_beam_dim(log_probs), state.cur_len
923
+ )
924
+ log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
925
+ log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
926
+ vocab_size = log_probs.shape[2]
927
+ log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))
928
+
929
+ # 3. Retrieve top-K
930
+ # Each item in batch has num_beams * vocab_size candidate sequences.
931
+ # For each item, get the top 2*k candidates with the highest log-
932
+ # probabilities. We gather the top 2*K beams here so that even if the best
933
+ # K sequences reach EOS simultaneously, we have another K sequences
934
+ # remaining to continue the live beam search.
935
+ # Gather the top 2*K scores from _all_ beams.
936
+ # Gather 2*k top beams.
937
+ # Recover the beam index by floor division.
938
+ # Recover token id by modulo division and expand Id array for broadcasting.
939
+ # Update sequences for the 2*K top-k new sequences.
940
+ beams_to_keep = 2 * num_beams
941
+ topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
942
+ topk_beam_indices = topk_indices // vocab_size
943
+ topk_running_sequences = gather_beams(
944
+ state.running_sequences, topk_beam_indices, batch_size, beams_to_keep
945
+ )
946
+ topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
947
+ topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len))
948
+
949
+ # 4. Check which sequences have ended
950
+ # Update current sequences:
951
+ # Did any of these sequences reach an end marker?
952
+ # To prevent these just finished sequences from being added to the current sequences
953
+ # set of active beam search sequences, set their log probs to a very large
954
+ # negative value.
955
+ did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
956
+ running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)
957
+ # 5. Get running sequences scores for next
958
+ # Determine the top k beam indices (from top 2*k beams) from log probs
959
+ # and gather top k beams (from top 2*k beams).
960
+ next_topk_indices = lax.top_k(running_topk_log_probs, k=num_beams)[1]
961
+ next_running_sequences, next_running_scores = gather_beams(
962
+ [topk_sequences, running_topk_log_probs], next_topk_indices, batch_size, num_beams
963
+ )
964
+
965
+ # 6. Process topk logits
966
+ # Further process log probs:
967
+ # - add length penalty
968
+ # - make sure no scores can be added anymore if beam is full
969
+ # - make sure still running sequences cannot be chosen as finalized beam
970
+ topk_log_probs = topk_log_probs / ((state.cur_len + 1 - decoder_prompt_len) ** length_penalty)
971
+ beams_in_batch_are_full = jnp.broadcast_to(
972
+ state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape
973
+ ) & (early_stopping is True)
974
+ add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
975
+ topk_log_probs += add_penalty * np.array(-1.0e7)
976
+
977
+ # 7. Get scores, sequences, is sentence finished for next.
978
+ # Combine sequences, scores, and flags along the beam dimension and compare
979
+ # new finished sequence scores to existing finished scores and select the
980
+ # best from the new set of beams
981
+ merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
982
+ merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
983
+ merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
984
+ topk_merged_indices = lax.top_k(merged_scores, k=num_beams)[1]
985
+ next_sequences, next_scores, next_is_sent_finished = gather_beams(
986
+ [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
987
+ )
988
+
989
+ # 8. Update model kwargs.
990
+ # Determine the top k beam indices from the original set of all beams.
991
+ # With these, gather the top k beam-associated caches.
992
+ next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams)
993
+ next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
994
+ model_outputs["past_key_values"] = jax.tree_util.tree_map(lambda x: flatten_beam_dim(x), next_cache)
995
+ next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
996
+
997
+ return BeamSearchState(
998
+ cur_len=state.cur_len + 1,
999
+ running_scores=next_running_scores,
1000
+ running_sequences=next_running_sequences,
1001
+ scores=next_scores,
1002
+ sequences=next_sequences,
1003
+ is_sent_finished=next_is_sent_finished,
1004
+ model_kwargs=next_model_kwargs,
1005
+ )
1006
+
1007
+ # Always run first iteration outside of `lax.while_loop` to avoid calling `beam_search_cond_fn`
1008
+ # when `state.cur_len` equals `decoder_prompt_len`. This also helps to comply with TPU when
1009
+ # the very first prompt has sequence length > 1.
1010
+ state = partial(beam_search_body_fn, input_ids_length=input_ids.shape[-1])(state)
1011
+
1012
+ if not trace:
1013
+ state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state)
1014
+ else:
1015
+ state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state)
1016
+
1017
+ # Account for the edge-case where there are no finished sequences for a
1018
+ # particular batch item. If so, return running sequences for that batch item.
1019
+ none_finished = jnp.any(state.is_sent_finished, axis=1)
1020
+ sequences = jnp.where(none_finished[:, None, None], state.sequences, state.running_sequences)
1021
+ scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)
1022
+
1023
+ # Take best beams for each batch (the score is sorted in descending order)
1024
+ sequences = flatten_beam_dim(sequences[:, :num_return_sequences, :])
1025
+ scores = flatten_beam_dim(scores[:, :num_return_sequences])
1026
+
1027
+ return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
.venv/lib/python3.11/site-packages/transformers/generation/logits_process.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/transformers/generation/stopping_criteria.py ADDED
@@ -0,0 +1,514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import warnings
3
+ from abc import ABC
4
+ from collections import OrderedDict
5
+ from copy import deepcopy
6
+ from typing import Dict, List, Optional, Tuple, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ from torch.nn import functional as F
11
+
12
+ from ..pytorch_utils import isin_mps_friendly
13
+ from ..tokenization_utils_base import PreTrainedTokenizerBase
14
+ from ..utils import add_start_docstrings, logging
15
+
16
+
17
+ logger = logging.get_logger(__name__)
18
+ # We maintain a module-level cache of the embedding vectors for the stop string criterion
19
+ # because they are slow to compute
20
+ STOP_STRING_EMBEDDING_CACHE = OrderedDict()
21
+
22
+
23
+ STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
24
+ Args:
25
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
26
+ Indices of input sequence tokens in the vocabulary.
27
+
28
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
29
+ [`PreTrainedTokenizer.__call__`] for details.
30
+
31
+ [What are input IDs?](../glossary#input-ids)
32
+ scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
33
+ Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
34
+ or scores for each vocabulary token after SoftMax. If this stopping criteria depends on the `scores` input,
35
+ make sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`.
36
+ kwargs (`Dict[str, Any]`, *optional*):
37
+ Additional stopping criteria specific kwargs.
38
+
39
+ Return:
40
+ `torch.BoolTensor`. (`torch.BoolTensor` of shape `(batch_size, 1)`), where `True` indicates we stop generation
41
+ for a particular row, `True` indicates we should continue.
42
+
43
+ """
44
+
45
+
46
+ class StoppingCriteria(ABC):
47
+ """Abstract base class for all stopping criteria that can be applied during generation.
48
+
49
+ If your stopping criteria depends on the `scores` input, make sure you pass `return_dict_in_generate=True,
50
+ output_scores=True` to `generate`.
51
+ """
52
+
53
+ @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
54
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
55
+ raise NotImplementedError("StoppingCriteria needs to be subclassed")
56
+
57
+
58
+ class MaxLengthCriteria(StoppingCriteria):
59
+ """
60
+ This class can be used to stop generation whenever the full generated number of tokens exceeds `max_length`. Keep
61
+ in mind for decoder-only type of transformers, this will include the initial prompted tokens.
62
+
63
+ Args:
64
+ max_length (`int`):
65
+ The maximum length that the output sequence can have in number of tokens.
66
+ max_position_embeddings (`int`, *optional*):
67
+ The maximum model length, as defined by the model's `config.max_position_embeddings` attribute.
68
+ """
69
+
70
+ def __init__(self, max_length: int, max_position_embeddings: Optional[int] = None):
71
+ self.max_length = max_length
72
+ self.max_position_embeddings = max_position_embeddings
73
+
74
+ @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
75
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
76
+ cur_len = input_ids.shape[-1]
77
+ is_done = cur_len >= self.max_length
78
+ if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings:
79
+ logger.warning_once(
80
+ "This is a friendly reminder - the current text generation call will exceed the model's predefined "
81
+ f"maximum length ({self.max_position_embeddings}). Depending on the model, you may observe "
82
+ "exceptions, performance degradation, or nothing at all."
83
+ )
84
+ return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)
85
+
86
+
87
+ class MaxTimeCriteria(StoppingCriteria):
88
+ """
89
+ This class can be used to stop generation whenever the full generation exceeds some amount of time. By default, the
90
+ time will start being counted when you initialize this function. You can override this by passing an
91
+ `initial_time`.
92
+
93
+ Args:
94
+ max_time (`float`):
95
+ The maximum allowed time in seconds for the generation.
96
+ initial_time (`float`, *optional*, defaults to `time.time()`):
97
+ The start of the generation allowed time.
98
+ """
99
+
100
+ def __init__(self, max_time: float, initial_timestamp: Optional[float] = None):
101
+ self.max_time = max_time
102
+ self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp
103
+
104
+ @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
105
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
106
+ is_done = time.time() - self.initial_timestamp > self.max_time
107
+ return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)
108
+
109
+
110
+ class StopStringCriteria(StoppingCriteria):
111
+ """
112
+ This class can be used to stop generation whenever specific string sequences are generated. It preprocesses
113
+ the strings together with the tokenizer vocab to find positions where tokens can validly complete the stop strings.
114
+
115
+ Generation is stopped as soon as a token is generated that completes any of the stop strings.
116
+ We want to catch any instance in which the stop string would be present in the decoded output, which means
117
+ we must also catch cases with "overhangs" off one or both ends. To make this more concrete, for the stop string
118
+ "stop", any of the following token sequences would trigger the match:
119
+
120
+ - ["st", "op"]
121
+ - ["stop"]
122
+ - ["st", "opera"]
123
+ - ["sto", "pper"]
124
+ - ["las", "topper"]
125
+ - ["s", "to", "pped"]
126
+
127
+ Note that a match will only be triggered if the stop string is at the end of the generated sequence. In other
128
+ words, these sequences will not trigger a match:
129
+
130
+ - ["stop", "at"]
131
+ - ["st", "op", "at"]
132
+ - ["st", "opera", "tion"]
133
+
134
+ The reason these are not a match is that the stop string does not overlap with the final token. If you can remove
135
+ one or more tokens from the end of the sequence without destroying the stop string, then this criterion will not
136
+ match that stop string. This is by design; because this check is run after each token is generated, we can't miss a
137
+ valid stop string if one is generated, but we don't want to halt generation just because the stop string exists
138
+ somewhere in the past input_ids.
139
+
140
+ How is the match actually performed, though? We do it in quite a confusing way, because we want the entire match
141
+ process to be compilable with Torch or XLA, which means we cannot use standard string methods. However, it is possible,
142
+ with some work, to do string matching with pure tensor operations. We'll begin by describing the algorithm we use
143
+ with standard string operations, and then at the end we'll explain how this is converted to pure tensor operations.
144
+
145
+ The key to the algorithm is an observation: Because the stop string must overlap with the end of the token sequence, we can start at
146
+ the end of the sequence and work backwards. Specifically, we check that there is an overlap between the start of
147
+ the final token and the end of the stop_string, or to put it another way, stop_string[-i:] == token[:i] for
148
+ some i > 0. If you look at the positive examples above, you'll see the last token in all of them fulfills this
149
+ property:
150
+
151
+ - ["st", "op"] (overlap is "op", overlap length == 2)
152
+ - ["stop"] (overlap is "stop", overlap length == 4)
153
+ - ["st", "opera"] (overlap is "op", overlap length == 2)
154
+ - ["sto", "pper"] (overlap is "p", overlap length == 1)
155
+ - ["las", "topper"] (overlap is "top", overlap length == 3)
156
+ - ["s", "to", "pped"] (overlap is "p", overlap length == 1)
157
+
158
+ It's impossible to construct a matching sequence that does not have this property (feel free to verify this
159
+ yourself). However, although this overlap between the start of the final token and the end of the stop string is
160
+ necessary for a match, it is not sufficient. We also need to check that the rest of the token sequence is
161
+ consistent with the stop string.
162
+
163
+ How do we do that? Let's use ["s", "to", "pped"] as an example. We know that the final token, "pped", has an
164
+ overlap of 1 with the stop string, "stop". We then go back to the previous token, "to". Since we have already
165
+ matched 1 character from the stop string, the remainder to check is "sto". We check that the next token "to"
166
+ matches the end of the remainder, which it does. We have now matched 3 characters from the stop string, and the
167
+ remainder to match is "s". We go back to the previous token again, which is also "s". This is a match, and so
168
+ we have matched the entire stop string.
169
+
170
+ How does it work when the tokens run off the start of the stop string, though? Let's consider the example of
171
+ ["las", "topper"]. The final token, "topper", has an overlap of 3 with the stop string, "stop". Therefore,
172
+ the remaining stop string to match is "s". We go back to the previous token, "las". Because the remainder to
173
+ match is just "s", with length 1, we consider only the final 1 character from the token, which is "s". This
174
+ matches the stop string, and so the entire string is matched.
175
+
176
+ How do we compute these matches with tensor operations, though? Simply: we efficiently precompute the necessary
177
+ information for all tokens! For every token, we compute:
178
+ - Its overlap with the end of the stop string, if any
179
+ - The positions inside the stop string where the token matches, including matches that run off the start.
180
+ - The total length of the token
181
+
182
+ For example, for the token "pped", we would compute an end overlap of 1, no internal matching positions,
183
+ and a length of 4. For the token "to", we would compute no end overlap, a single internal matching position
184
+ of 1 (counting from the end), and a length of 2. For the token "s", we would compute no end overlap,
185
+ a single internal matching position of 3 (again counting from the end) and a length of 1.
186
+
187
+ As long as we have this information, we can execute the algorithm above without any string comparison
188
+ operations. We simply perform the following steps:
189
+ - Check if the final token has an end-overlap with the start string
190
+ - Continue backwards, keeping track of how much of the stop string we've matched so far
191
+ - At each point, check if the next token has the current position as one of its valid positions
192
+ - Continue until either a match fails, or we completely match the whole stop string
193
+
194
+ Again, consider ["s", "to", "pped"] as an example. "pped" has an end overlap of 1, so we can begin a match.
195
+ We have matched 1 character so far, so we check that the next token "to", has 1 as a valid position (again,
196
+ counting from the end). It does, so we add the length of "to" to our position tracker. We have now matched
197
+ 3 characters, so we check that the next token "s" has 3 as a valid position. It does, so we add its length
198
+ to the position tracker. The position tracker is now 4, which is the length of the stop string. We have matched the
199
+ entire stop string.
200
+
201
+ In the second case, ["las", "topper"], "topper" has an end overlap of 3, so we can begin a match. We have
202
+ matched 3 characters so far, so we check that the next token "las" has 3 as a valid position. It does, because we
203
+ allow tokens to match positions that run off the start of the stop string. We add its length to the position
204
+ tracker. The position tracker is now 6, which is greater than the length of the stop string! Don't panic, though -
205
+ this also counts as a match of the stop string. We have matched the entire stop string.
206
+
207
+
208
+ Args:
209
+ tokenizer (`PreTrainedTokenizer`):
210
+ The model's associated tokenizer (necessary to extract vocab and tokenize the termination sequences)
211
+ stop_strings (`Union[str, List[str]]`):
212
+ A list of strings that should end generation. If a string is passed, it will be treated like a
213
+ list with a single element.
214
+
215
+ Examples:
216
+
217
+ ```python
218
+ >>> from transformers import AutoModelForCausalLM, AutoTokenizer
219
+
220
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
221
+ >>> model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")
222
+ >>> inputs = tokenizer("The biggest states in the USA by land area:", return_tensors="pt")
223
+
224
+ >>> gen_out = model.generate(**inputs)
225
+ >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
226
+ The biggest states in the USA by land area:
227
+ - Alaska
228
+ - Texas
229
+ - California
230
+
231
+ >>> # Passing one or more stop strings will halt generation after those strings are emitted
232
+ >>> # Note that generating with stop strings requires you to pass the tokenizer too
233
+ >>> gen_out = model.generate(**inputs, stop_strings=["Texas"], tokenizer=tokenizer)
234
+ >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
235
+ The biggest states in the USA by land area:
236
+ - Alaska
237
+ - Texas
238
+ ```
239
+ """
240
+
241
+ def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, List[str]]):
242
+ if isinstance(stop_strings, str):
243
+ stop_strings = [stop_strings]
244
+ self.stop_strings: Tuple[str, ...] = tuple(stop_strings)
245
+ vocab = tokenizer.get_vocab()
246
+ token_list, token_indices = tuple(vocab.keys()), tuple(vocab.values())
247
+ self.embedding_vec, self.max_valid_positions, self.max_valid_end_lens = self.clean_and_embed_tokens_with_cache(
248
+ token_list, token_indices, self.stop_strings, tokenizer
249
+ )
250
+
251
+ self.maximum_token_len = max([len(stop_string) for stop_string in self.stop_strings])
252
+ self.num_stop_strings = len(self.stop_strings)
253
+ self.target_lens = torch.tensor([len(stop_string) for stop_string in stop_strings], dtype=torch.int32)
254
+
255
+ def clean_and_embed_tokens_with_cache(self, token_list, token_indices, stop_strings, tokenizer):
256
+ # We don't use the tokenizer in the cache key, because I don't trust it to have well-behaved equality
257
+ if (token_list, token_indices, stop_strings) in STOP_STRING_EMBEDDING_CACHE:
258
+ embedding_vec, max_valid_positions, max_valid_end_lens = STOP_STRING_EMBEDDING_CACHE[
259
+ (token_list, token_indices, self.stop_strings)
260
+ ]
261
+ STOP_STRING_EMBEDDING_CACHE.move_to_end((token_list, token_indices, stop_strings))
262
+ else:
263
+ clean_token_list, clean_token_indices = self.clean_tokenizer_vocab(tokenizer)
264
+ embedding_vec, max_valid_positions, max_valid_end_lens = self._stop_string_create_embedding_vec(
265
+ clean_token_list, clean_token_indices, stop_strings
266
+ )
267
+ STOP_STRING_EMBEDDING_CACHE[(token_list, token_indices, stop_strings)] = (
268
+ embedding_vec,
269
+ max_valid_positions,
270
+ max_valid_end_lens,
271
+ )
272
+ if len(STOP_STRING_EMBEDDING_CACHE) > 8:
273
+ STOP_STRING_EMBEDDING_CACHE.popitem(last=False) # Pop from the start, the least recently used item
274
+ return embedding_vec, max_valid_positions, max_valid_end_lens
275
+
276
+ @staticmethod
277
+ def clean_tokenizer_vocab(tokenizer, static_prefix="abcdef"):
278
+ """
279
+ This method turns a tokenizer vocab into a "clean" vocab where each token represents the actual string
280
+ it will yield, without any special prefixes like "##" or "Ġ". This is trickier than it looks - the method
281
+ tokenizer.convert_tokens_to_string() does not always return the correct string because of issues with prefix
282
+ space addition/removal. To work around this, we add a static prefix to the start of the token, then remove
283
+ it (and any prefix that may have been introduced with it) after calling convert_tokens_to_string().
284
+ """
285
+ vocab = tokenizer.get_vocab()
286
+ clean_token_list = []
287
+ clean_token_indices = []
288
+ sentence_base = tokenizer(static_prefix, add_special_tokens=False)["input_ids"]
289
+ tokens_base = [tokenizer._convert_id_to_token(tok) for tok in sentence_base]
290
+ for token, token_idx in vocab.items():
291
+ token_string = tokenizer.convert_tokens_to_string(tokens_base + [token])
292
+ token_string = token_string[token_string.index(static_prefix) + len(static_prefix) :]
293
+ clean_token_list.append(token_string)
294
+ clean_token_indices.append(token_idx)
295
+ return tuple(clean_token_list), tuple(clean_token_indices)
296
+
297
+ @staticmethod
298
+ def _stop_string_get_matching_positions(
299
+ token_list, token_indices, stop_strings
300
+ ) -> Tuple[Dict[str, Dict[str, List[int]]], Dict[str, Dict[str, List[int]]]]:
301
+ """This function preprocesses stop strings and the tokenizer vocabulary to determine where tokens can
302
+ validly appear in the stop strings. For each token, it computes a list of positions in the stop string where the
303
+ token appears, as well as a list of the possible "end overlaps" for that token - that is, the number of characters
304
+ from the end of the stop string that overlap with the start of the token, which can have more than one value.
305
+
306
+ The reason for computing these may seem a bit cryptic - please see the docstring for StopStringCriteria for a full
307
+ explanation of what these values are for!"""
308
+
309
+ token_valid_positions = {}
310
+ token_end_overlaps = {}
311
+ for stop_string in stop_strings:
312
+ reversed_stop_string = stop_string[::-1]
313
+ token_valid_positions[stop_string] = {}
314
+ token_end_overlaps[stop_string] = {}
315
+ for token, tok_idx in zip(token_list, token_indices):
316
+ reversed_token = token[::-1]
317
+ matching_positions = []
318
+ possible_end_lengths = []
319
+ for i in range(1 - len(token), len(stop_string)):
320
+ if i < 0:
321
+ tok = reversed_token[-i:]
322
+ i = 0
323
+ else:
324
+ tok = reversed_token
325
+ stop = reversed_stop_string[i : i + len(tok)]
326
+ if tok.startswith(stop):
327
+ if i == 0:
328
+ possible_end_lengths.append(min(len(tok), len(stop)))
329
+ else:
330
+ matching_positions.append(i)
331
+
332
+ if matching_positions:
333
+ token_valid_positions[stop_string][tok_idx] = matching_positions
334
+ if possible_end_lengths:
335
+ token_end_overlaps[stop_string][tok_idx] = possible_end_lengths
336
+ return token_valid_positions, token_end_overlaps
337
+
338
+ @staticmethod
339
+ def _stop_string_create_embedding_vec(token_list, token_indices, stop_strings) -> Dict[str, torch.tensor]:
340
+ """This function precomputes everything needed for the run-time checks in StopStringCriteria, and packs
341
+ them into an embedding tensor that can be accessed with pure tensor operations. For the specifics of the values
342
+ that are precomputed and what they are used for, please refer to the StopStringCriteria docstring!"""
343
+ token_valid_positions, token_end_overlaps = StopStringCriteria._stop_string_get_matching_positions(
344
+ token_list, token_indices, stop_strings
345
+ )
346
+ all_valid_positions = [len(val) for positions in token_valid_positions.values() for val in positions.values()]
347
+ # In some cases, tokens may have no valid internal positions (such as single-character stop strings), so
348
+ # we need a fallback to handle this case
349
+ max_valid_positions = max(all_valid_positions) if all_valid_positions else 1
350
+ # There should always be at least one valid end_len, however, so no fallback needed here
351
+ valid_end_lens = [len(val) for positions in token_end_overlaps.values() for val in positions.values()]
352
+ if not valid_end_lens:
353
+ raise ValueError(
354
+ "Stop string preprocessing was unable to identify tokens matching one or more of the "
355
+ "supplied stop string(s). This is most often caused by the stop "
356
+ "strings containing unusual characters that are not in the tokenizer vocabulary."
357
+ )
358
+ max_valid_end_lens = max(valid_end_lens)
359
+ vec_size = len(stop_strings) * (max_valid_positions + max_valid_end_lens) + 1
360
+ gather_vec = np.full((len(token_list), vec_size), dtype=np.int32, fill_value=-1)
361
+
362
+ for i, stop_string in enumerate(stop_strings):
363
+ positions = token_valid_positions[stop_string]
364
+ end_lens = token_end_overlaps[stop_string]
365
+
366
+ # Since this is lots of very small assignments of lists, we build it with numpy rather
367
+ # than torch for speed + simplicity, then convert to torch at the end
368
+ for token_idx, valid_positions in positions.items():
369
+ gather_vec[token_idx, max_valid_positions * i : max_valid_positions * i + len(valid_positions)] = (
370
+ valid_positions
371
+ )
372
+ for token_idx, possible_end_lens in end_lens.items():
373
+ gather_vec[
374
+ token_idx,
375
+ max_valid_positions * len(stop_strings) + max_valid_end_lens * i : max_valid_positions
376
+ * len(stop_strings)
377
+ + max_valid_end_lens * i
378
+ + len(possible_end_lens),
379
+ ] = possible_end_lens
380
+ for token, token_idx in zip(token_list, token_indices):
381
+ gather_vec[token_idx, -1] = len(token)
382
+
383
+ gather_vec = torch.tensor(gather_vec, dtype=torch.int32)
384
+
385
+ return gather_vec, max_valid_positions, max_valid_end_lens
386
+
387
+ @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
388
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.Tensor:
389
+ self.embedding_vec = self.embedding_vec.to(input_ids.device)
390
+ self.target_lens = self.target_lens.to(input_ids.device)
391
+ # The maximum length we need to consider is 1 token per character. Note that input_ids can also be
392
+ # *shorter* than the global max, and the code below should be ready for that
393
+ input_ids = input_ids[:, -self.maximum_token_len :]
394
+
395
+ # Flip input_ids because we're only matching strings at the end of the generated sequence
396
+ flipped_ids = torch.flip(input_ids, (1,))
397
+
398
+ # Size of the vector of positions a single token can match
399
+ max_valid_positions = self.max_valid_positions
400
+
401
+ # The embedding vec contains the valid positions, end_lengths and total lengths for each token
402
+ embedded = F.embedding(flipped_ids, self.embedding_vec)
403
+
404
+ # Now we split the embedding vector. valid_positions is the positions in the stop string the token can fit
405
+ valid_positions = embedded[:, 1:, : max_valid_positions * self.num_stop_strings].unflatten(
406
+ -1, (self.num_stop_strings, -1)
407
+ )
408
+ # end_lengths is the number of characters from the string, counting from the end, that the token
409
+ # contains. It can have multiple values if the same token can overlap different end lengths
410
+ end_lengths = embedded[:, :1, max_valid_positions * self.num_stop_strings : -1].unflatten(
411
+ -1, (self.num_stop_strings, -1)
412
+ )
413
+ # Lengths is the total length of each token. Unlike the others, it always has a single value
414
+ lengths = embedded[:, 1:, None, -1:] # Insert a dummy dimension for stop_strings even though lengths are const
415
+
416
+ # Concatenate lengths onto each possible end_lengths value
417
+ lengths = lengths.expand((-1, -1, end_lengths.shape[-2], end_lengths.shape[-1]))
418
+ lengths_with_ends = torch.cat([end_lengths, lengths], dim=1)
419
+
420
+ # cumsum() to get the number of matched characters in the stop string after each token
421
+ cumsum = lengths_with_ends.cumsum(dim=1) # B x maximum_token_len x num_stop_strings x max_valid_end_lens
422
+
423
+ # The calculation above assumes that all tokens are in valid positions. Now we mask the ones that are not.
424
+ # First, tokens match the start of the string if they have a positive value in the end_lengths vector
425
+ initial_match = end_lengths > 0
426
+
427
+ # Tokens continue the string if the cumsum() so far is one of the valid positions for that token
428
+ # Note that we're actually tracking one cumsum() for for each possible end_length
429
+ later_match = torch.any(cumsum[:, :-1, :, None] == valid_positions[:, :, :, :, None], axis=-2)
430
+
431
+ # The match vector is a boolean vector that indicates which positions have valid tokens
432
+ match = torch.cat([initial_match, later_match], dim=1)
433
+
434
+ # Once a single position does not match, all positions following that position are masked
435
+ mask = (~match).cumsum(dim=1, dtype=torch.int32)
436
+ mask = mask == 0
437
+
438
+ # The string is matched if we reached a cumsum equal to or greater than the length of the string
439
+ # before hitting the mask
440
+ string_matches = torch.amax(cumsum * mask, dim=(1, -1)) >= self.target_lens[None, :]
441
+
442
+ # We return a per-sample vector that is True if any stop string is matched for that sample
443
+ return torch.any(string_matches, dim=-1)
444
+
445
+
446
+ class EosTokenCriteria(StoppingCriteria):
447
+ """
448
+ This class can be used to stop generation whenever the "end-of-sequence" token is generated.
449
+ By default, it uses the `model.generation_config.eos_token_id`.
450
+
451
+ Args:
452
+ eos_token_id (`Union[int, List[int], torch.Tensor]`):
453
+ The id(s) of the *end-of-sequence* token.
454
+ """
455
+
456
+ def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor]):
457
+ if not isinstance(eos_token_id, torch.Tensor):
458
+ if isinstance(eos_token_id, int):
459
+ eos_token_id = [eos_token_id]
460
+ eos_token_id = torch.tensor(eos_token_id)
461
+ self.eos_token_id = eos_token_id
462
+
463
+ @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
464
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
465
+ self.eos_token_id = self.eos_token_id.to(input_ids.device)
466
+ is_done = isin_mps_friendly(input_ids[:, -1], self.eos_token_id)
467
+ return is_done
468
+
469
+
470
+ class ConfidenceCriteria(StoppingCriteria):
471
+ """
472
+ This class can be used to stop generation whenever assistant model's confidence in its prediction for the current token is lower than the threshold
473
+ `model.generation_config.assistant_confidence_threshold` even if the number of speculative tokens (defined by `num_assistant_tokens`) is not yet reached.
474
+
475
+ Args:
476
+ assistant_confidence_threshold (`float`):
477
+ The value of the threshold.
478
+ """
479
+
480
+ def __init__(self, assistant_confidence_threshold):
481
+ self.assistant_confidence_threshold = assistant_confidence_threshold
482
+
483
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
484
+ probs = scores[-1].softmax(-1)
485
+ p = probs[0, input_ids[0, -1]].item()
486
+ if p < self.assistant_confidence_threshold:
487
+ return True
488
+ return False
489
+
490
+
491
+ class StoppingCriteriaList(list):
492
+ @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
493
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
494
+ is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device, dtype=torch.bool)
495
+ for criteria in self:
496
+ is_done = is_done | criteria(input_ids, scores, **kwargs)
497
+ return is_done
498
+
499
+ @property
500
+ def max_length(self) -> Optional[int]:
501
+ for stopping_criterium in self:
502
+ if isinstance(stopping_criterium, MaxLengthCriteria):
503
+ return stopping_criterium.max_length
504
+ return None
505
+
506
+
507
+ def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int) -> StoppingCriteriaList:
508
+ stopping_max_length = stopping_criteria.max_length
509
+ new_stopping_criteria = deepcopy(stopping_criteria)
510
+ if stopping_max_length is not None and stopping_max_length != max_length:
511
+ warnings.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning)
512
+ elif stopping_max_length is None:
513
+ new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
514
+ return new_stopping_criteria
.venv/lib/python3.11/site-packages/transformers/generation/streamers.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import asyncio
19
+ from queue import Queue
20
+ from typing import TYPE_CHECKING, Optional
21
+
22
+
23
+ if TYPE_CHECKING:
24
+ from ..models.auto import AutoTokenizer
25
+
26
+
27
+ class BaseStreamer:
28
+ """
29
+ Base class from which `.generate()` streamers should inherit.
30
+ """
31
+
32
+ def put(self, value):
33
+ """Function that is called by `.generate()` to push new tokens"""
34
+ raise NotImplementedError()
35
+
36
+ def end(self):
37
+ """Function that is called by `.generate()` to signal the end of generation"""
38
+ raise NotImplementedError()
39
+
40
+
41
+ class TextStreamer(BaseStreamer):
42
+ """
43
+ Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.
44
+
45
+ <Tip warning={true}>
46
+
47
+ The API for the streamer classes is still under development and may change in the future.
48
+
49
+ </Tip>
50
+
51
+ Parameters:
52
+ tokenizer (`AutoTokenizer`):
53
+ The tokenized used to decode the tokens.
54
+ skip_prompt (`bool`, *optional*, defaults to `False`):
55
+ Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
56
+ decode_kwargs (`dict`, *optional*):
57
+ Additional keyword arguments to pass to the tokenizer's `decode` method.
58
+
59
+ Examples:
60
+
61
+ ```python
62
+ >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
63
+
64
+ >>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
65
+ >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
66
+ >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
67
+ >>> streamer = TextStreamer(tok)
68
+
69
+ >>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
70
+ >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
71
+ An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
72
+ ```
73
+ """
74
+
75
+ def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
76
+ self.tokenizer = tokenizer
77
+ self.skip_prompt = skip_prompt
78
+ self.decode_kwargs = decode_kwargs
79
+
80
+ # variables used in the streaming process
81
+ self.token_cache = []
82
+ self.print_len = 0
83
+ self.next_tokens_are_prompt = True
84
+
85
+ def put(self, value):
86
+ """
87
+ Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
88
+ """
89
+ if len(value.shape) > 1 and value.shape[0] > 1:
90
+ raise ValueError("TextStreamer only supports batch size 1")
91
+ elif len(value.shape) > 1:
92
+ value = value[0]
93
+
94
+ if self.skip_prompt and self.next_tokens_are_prompt:
95
+ self.next_tokens_are_prompt = False
96
+ return
97
+
98
+ # Add the new token to the cache and decodes the entire thing.
99
+ self.token_cache.extend(value.tolist())
100
+ text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
101
+
102
+ # After the symbol for a new line, we flush the cache.
103
+ if text.endswith("\n"):
104
+ printable_text = text[self.print_len :]
105
+ self.token_cache = []
106
+ self.print_len = 0
107
+ # If the last token is a CJK character, we print the characters.
108
+ elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
109
+ printable_text = text[self.print_len :]
110
+ self.print_len += len(printable_text)
111
+ # Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
112
+ # which may change with the subsequent token -- there are probably smarter ways to do this!)
113
+ else:
114
+ printable_text = text[self.print_len : text.rfind(" ") + 1]
115
+ self.print_len += len(printable_text)
116
+
117
+ self.on_finalized_text(printable_text)
118
+
119
+ def end(self):
120
+ """Flushes any remaining cache and prints a newline to stdout."""
121
+ # Flush the cache, if it exists
122
+ if len(self.token_cache) > 0:
123
+ text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
124
+ printable_text = text[self.print_len :]
125
+ self.token_cache = []
126
+ self.print_len = 0
127
+ else:
128
+ printable_text = ""
129
+
130
+ self.next_tokens_are_prompt = True
131
+ self.on_finalized_text(printable_text, stream_end=True)
132
+
133
+ def on_finalized_text(self, text: str, stream_end: bool = False):
134
+ """Prints the new text to stdout. If the stream is ending, also prints a newline."""
135
+ print(text, flush=True, end="" if not stream_end else None)
136
+
137
+ def _is_chinese_char(self, cp):
138
+ """Checks whether CP is the codepoint of a CJK character."""
139
+ # This defines a "chinese character" as anything in the CJK Unicode block:
140
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
141
+ #
142
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
143
+ # despite its name. The modern Korean Hangul alphabet is a different block,
144
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
145
+ # space-separated words, so they are not treated specially and handled
146
+ # like the all of the other languages.
147
+ if (
148
+ (cp >= 0x4E00 and cp <= 0x9FFF)
149
+ or (cp >= 0x3400 and cp <= 0x4DBF) #
150
+ or (cp >= 0x20000 and cp <= 0x2A6DF) #
151
+ or (cp >= 0x2A700 and cp <= 0x2B73F) #
152
+ or (cp >= 0x2B740 and cp <= 0x2B81F) #
153
+ or (cp >= 0x2B820 and cp <= 0x2CEAF) #
154
+ or (cp >= 0xF900 and cp <= 0xFAFF)
155
+ or (cp >= 0x2F800 and cp <= 0x2FA1F) #
156
+ ): #
157
+ return True
158
+
159
+ return False
160
+
161
+
162
+ class TextIteratorStreamer(TextStreamer):
163
+ """
164
+ Streamer that stores print-ready text in a queue, to be used by a downstream application as an iterator. This is
165
+ useful for applications that benefit from acessing the generated text in a non-blocking way (e.g. in an interactive
166
+ Gradio demo).
167
+
168
+ <Tip warning={true}>
169
+
170
+ The API for the streamer classes is still under development and may change in the future.
171
+
172
+ </Tip>
173
+
174
+ Parameters:
175
+ tokenizer (`AutoTokenizer`):
176
+ The tokenized used to decode the tokens.
177
+ skip_prompt (`bool`, *optional*, defaults to `False`):
178
+ Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
179
+ timeout (`float`, *optional*):
180
+ The timeout for the text queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
181
+ in `.generate()`, when it is called in a separate thread.
182
+ decode_kwargs (`dict`, *optional*):
183
+ Additional keyword arguments to pass to the tokenizer's `decode` method.
184
+
185
+ Examples:
186
+
187
+ ```python
188
+ >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
189
+ >>> from threading import Thread
190
+
191
+ >>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
192
+ >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
193
+ >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
194
+ >>> streamer = TextIteratorStreamer(tok)
195
+
196
+ >>> # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
197
+ >>> generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
198
+ >>> thread = Thread(target=model.generate, kwargs=generation_kwargs)
199
+ >>> thread.start()
200
+ >>> generated_text = ""
201
+ >>> for new_text in streamer:
202
+ ... generated_text += new_text
203
+ >>> generated_text
204
+ 'An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,'
205
+ ```
206
+ """
207
+
208
+ def __init__(
209
+ self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, timeout: Optional[float] = None, **decode_kwargs
210
+ ):
211
+ super().__init__(tokenizer, skip_prompt, **decode_kwargs)
212
+ self.text_queue = Queue()
213
+ self.stop_signal = None
214
+ self.timeout = timeout
215
+
216
+ def on_finalized_text(self, text: str, stream_end: bool = False):
217
+ """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
218
+ self.text_queue.put(text, timeout=self.timeout)
219
+ if stream_end:
220
+ self.text_queue.put(self.stop_signal, timeout=self.timeout)
221
+
222
+ def __iter__(self):
223
+ return self
224
+
225
+ def __next__(self):
226
+ value = self.text_queue.get(timeout=self.timeout)
227
+ if value == self.stop_signal:
228
+ raise StopIteration()
229
+ else:
230
+ return value
231
+
232
+
233
+ class AsyncTextIteratorStreamer(TextStreamer):
234
+ """
235
+ Streamer that stores print-ready text in a queue, to be used by a downstream application as an async iterator.
236
+ This is useful for applications that benefit from acessing the generated text asynchronously (e.g. in an
237
+ interactive Gradio demo).
238
+
239
+ <Tip warning={true}>
240
+
241
+ The API for the streamer classes is still under development and may change in the future.
242
+
243
+ </Tip>
244
+
245
+ Parameters:
246
+ tokenizer (`AutoTokenizer`):
247
+ The tokenized used to decode the tokens.
248
+ skip_prompt (`bool`, *optional*, defaults to `False`):
249
+ Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
250
+ timeout (`float`, *optional*):
251
+ The timeout for the text queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
252
+ in `.generate()`, when it is called in a separate thread.
253
+ decode_kwargs (`dict`, *optional*):
254
+ Additional keyword arguments to pass to the tokenizer's `decode` method.
255
+
256
+ Raises:
257
+ TimeoutError: If token generation time exceeds timeout value.
258
+
259
+ Examples:
260
+
261
+ ```python
262
+ >>> from transformers import AutoModelForCausalLM, AutoTokenizer, AsyncTextIteratorStreamer
263
+ >>> from threading import Thread
264
+ >>> import asyncio
265
+
266
+ >>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
267
+ >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
268
+ >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
269
+
270
+ >>> # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
271
+ >>> async def main():
272
+ ... # Important: AsyncTextIteratorStreamer must be initialized inside a coroutine!
273
+ ... streamer = AsyncTextIteratorStreamer(tok)
274
+ ... generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
275
+ ... thread = Thread(target=model.generate, kwargs=generation_kwargs)
276
+ ... thread.start()
277
+ ... generated_text = ""
278
+ ... async for new_text in streamer:
279
+ ... generated_text += new_text
280
+ >>> print(generated_text)
281
+ >>> asyncio.run(main())
282
+ An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
283
+ ```
284
+ """
285
+
286
+ def __init__(
287
+ self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, timeout: float | None = None, **decode_kwargs
288
+ ):
289
+ super().__init__(tokenizer, skip_prompt, **decode_kwargs)
290
+ self.text_queue = asyncio.Queue()
291
+ self.stop_signal = None
292
+ self.timeout = timeout
293
+ self.loop = asyncio.get_running_loop()
294
+ self.has_asyncio_timeout = hasattr(asyncio, "timeout")
295
+
296
+ def on_finalized_text(self, text: str, stream_end: bool = False):
297
+ """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
298
+ self.loop.call_soon_threadsafe(self.text_queue.put_nowait, text)
299
+ if stream_end:
300
+ self.loop.call_soon_threadsafe(self.text_queue.put_nowait, self.stop_signal)
301
+
302
+ def __aiter__(self):
303
+ return self
304
+
305
+ async def __anext__(self):
306
+ try:
307
+ if self.has_asyncio_timeout:
308
+ async with asyncio.timeout(self.timeout):
309
+ value = await self.text_queue.get()
310
+ else:
311
+ value = await asyncio.wait_for(self.text_queue.get(), timeout=self.timeout)
312
+ except asyncio.TimeoutError:
313
+ raise TimeoutError()
314
+ else:
315
+ if value == self.stop_signal:
316
+ raise StopAsyncIteration()
317
+ else:
318
+ return value
.venv/lib/python3.11/site-packages/transformers/generation/tf_logits_process.py ADDED
@@ -0,0 +1,603 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ from typing import List, Tuple
18
+
19
+ import numpy as np
20
+ import tensorflow as tf
21
+
22
+ from ..tf_utils import stable_softmax
23
+ from ..utils import add_start_docstrings
24
+ from ..utils.logging import get_logger
25
+
26
+
27
+ logger = get_logger(__name__)
28
+
29
+
30
+ TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
31
+ Args:
32
+ input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
33
+ Indices of input sequence tokens in the vocabulary.
34
+
35
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
36
+ [`PreTrainedTokenizer.__call__`] for details.
37
+
38
+ [What are input IDs?](../glossary#input-ids)
39
+ scores (`tf.Tensor` of shape `(batch_size, config.vocab_size)`):
40
+ Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
41
+ search or log softmax for each vocabulary token when using beam search.
42
+ cur_len (`int`):
43
+ The current length of valid input sequence tokens. In the TF implementation, the input_ids' sequence length
44
+ is the maximum length generate can produce, and we need to know which of its tokens are valid.
45
+ kwargs (`Dict[str, Any]`, *optional*):
46
+ Additional logits processor specific kwargs.
47
+
48
+ Return:
49
+ `tf.Tensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.
50
+ """
51
+
52
+
53
+ class TFLogitsProcessor:
54
+ """Abstract base class for all logit processors that can be applied during generation."""
55
+
56
+ @add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
57
+ def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
58
+ """TF method for processing logits."""
59
+ raise NotImplementedError(
60
+ f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
61
+ )
62
+
63
+
64
+ class TFLogitsWarper:
65
+ """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
66
+
67
+ @add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
68
+ def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
69
+ """TF method for warping logits."""
70
+ raise NotImplementedError(
71
+ f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
72
+ )
73
+
74
+
75
+ class TFLogitsProcessorList(list):
76
+ """
77
+ This class can be used to create a list of [`TFLogitsProcessor`] to subsequently process a `scores` input tensor.
78
+ This class inherits from list and adds a specific *__call__* method to apply each [`TFLogitsProcessor`] to the
79
+ inputs.
80
+ """
81
+
82
+ @add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
83
+ def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int, **kwargs) -> tf.Tensor:
84
+ for processor in self:
85
+ function_args = inspect.signature(processor.__call__).parameters
86
+ if len(function_args) > 3:
87
+ if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
88
+ raise ValueError(
89
+ f"Make sure that all the required parameters: {list(function_args.keys())} for "
90
+ f"{processor.__class__} are passed to the logits processor."
91
+ )
92
+ scores = processor(input_ids, scores, cur_len, **kwargs)
93
+ else:
94
+ scores = processor(input_ids, scores, cur_len)
95
+ return scores
96
+
97
+
98
+ class TFTemperatureLogitsWarper(TFLogitsWarper):
99
+ r"""
100
+ [`TFLogitsWarper`] for temperature (exponential scaling output probability distribution).
101
+
102
+ Args:
103
+ temperature (`float`):
104
+ The value used to module the logits distribution.
105
+ """
106
+
107
+ def __init__(self, temperature: float):
108
+ if not isinstance(temperature, float) or not (temperature > 0):
109
+ raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
110
+
111
+ self.temperature = temperature
112
+
113
+ def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
114
+ scores = scores / self.temperature
115
+ return scores
116
+
117
+
118
+ class TFTopKLogitsWarper(TFLogitsWarper):
119
+ r"""
120
+ [`TFLogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
121
+
122
+ Args:
123
+ top_k (`int`):
124
+ The number of highest probability vocabulary tokens to keep for top-k-filtering.
125
+ filter_value (`float`, *optional*, defaults to -inf):
126
+ All filtered values will be set to this float value.
127
+ min_tokens_to_keep (`int`, *optional*, defaults to 1):
128
+ Minimum number of tokens that cannot be filtered.
129
+ """
130
+
131
+ def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
132
+ if not isinstance(top_k, int) or top_k <= 0:
133
+ raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
134
+
135
+ self.top_k = max(top_k, min_tokens_to_keep)
136
+ self.filter_value = filter_value
137
+
138
+ def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
139
+ top_k = min(self.top_k, scores.shape[-1]) # Safety check
140
+ # Boolean mask containing all tokens with a probability less than the last token of the top-k
141
+ indices_to_remove = scores < tf.math.top_k(scores, k=top_k)[0][..., -1:]
142
+ next_scores = tf.where(indices_to_remove, self.filter_value, scores)
143
+ return next_scores
144
+
145
+
146
+ class TFTopPLogitsWarper(TFLogitsWarper):
147
+ """
148
+ [`TFLogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to <= prob_cut_off.
149
+
150
+ Args:
151
+ top_p (`float`):
152
+ If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
153
+ higher are kept for generation.
154
+ filter_value (`float`, *optional*, defaults to -inf):
155
+ All filtered values will be set to this float value.
156
+ min_tokens_to_keep (`int`, *optional*, defaults to 1):
157
+ Minimum number of tokens that cannot be filtered.
158
+ """
159
+
160
+ def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
161
+ if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
162
+ raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
163
+ if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
164
+ raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")
165
+
166
+ self.top_p = top_p
167
+ self.filter_value = filter_value
168
+ self.min_tokens_to_keep = min_tokens_to_keep
169
+
170
+ def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
171
+ topk_scores, topk_indices = tf.math.top_k(scores, scores.shape[-1])
172
+
173
+ mask_scores = tf.fill(scores.shape, self.filter_value)
174
+ cumulative_probs = tf.math.cumsum(stable_softmax(topk_scores, axis=-1), axis=-1)
175
+ score_mask = cumulative_probs < self.top_p
176
+
177
+ # Also include the token that is higher than top_p (the first false = shift and insert a True on the left)
178
+ score_mask = tf.concat((tf.ones([score_mask.shape[0], 1], dtype=tf.bool), score_mask[:, :-1]), axis=-1)
179
+
180
+ # Ensure min tokens to keep
181
+ score_mask = tf.concat(
182
+ (
183
+ tf.ones([score_mask.shape[0], self.min_tokens_to_keep], dtype=tf.bool),
184
+ score_mask[:, self.min_tokens_to_keep :],
185
+ ),
186
+ axis=-1,
187
+ )
188
+
189
+ # Mask the values that do not fit the criteria
190
+ topk_next_scores = tf.where(score_mask, topk_scores, mask_scores)
191
+
192
+ # Undo the topk sorting: converts the 2D matrix of per-row original indices of shape (batch_size, vocab_size)
193
+ # to a 3D tensor of shape (batch_size, vocab_size, 2) containing the original score coordinate, from which we
194
+ # can scatter (i.e. `scatter_indices[row, col, :]` is a tensor containing `[row, topk_indices[row, col]]`)
195
+ scatter_rows = tf.tile(tf.expand_dims(tf.range(topk_indices.shape[0]), axis=-1), [1, topk_indices.shape[-1]])
196
+ scatter_indices = tf.stack((scatter_rows, topk_indices), axis=-1)
197
+ next_scores = tf.scatter_nd(scatter_indices, topk_next_scores, shape=topk_next_scores.shape)
198
+
199
+ return next_scores
200
+
201
+
202
+ class TFMinLengthLogitsProcessor(TFLogitsProcessor):
203
+ r"""
204
+ [`TFLogitsProcessor`] enforcing a min-length by setting EOS probability to 0.
205
+
206
+ Args:
207
+ min_length (`int`):
208
+ The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
209
+ eos_token_id (`int`):
210
+ The id of the *end-of-sequence* token.
211
+ """
212
+
213
+ def __init__(self, min_length: int, eos_token_id: int):
214
+ if not isinstance(min_length, int) or min_length < 0:
215
+ raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")
216
+
217
+ if not isinstance(eos_token_id, int) or eos_token_id < 0:
218
+ raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
219
+
220
+ self.min_length = min_length
221
+ self.eos_token_id = eos_token_id
222
+
223
+ def _apply_eos_token_mask(self, scores: tf.Tensor) -> tf.Tensor:
224
+ eos_token_id_mask = tf.range(scores.shape[-1]) == self.eos_token_id
225
+ scores = tf.where(eos_token_id_mask, float("-inf"), scores)
226
+ return scores
227
+
228
+ def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
229
+ # applies eos token masking if the first argument is true
230
+ scores = tf.cond(
231
+ tf.less(cur_len, self.min_length),
232
+ lambda: self._apply_eos_token_mask(scores),
233
+ lambda: tf.identity(scores),
234
+ )
235
+ return scores
236
+
237
+
238
+ class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor):
239
+ r"""
240
+ [`TFLogitsProcessor`] enforcing an exponential penalty on repeated sequences.
241
+
242
+ Args:
243
+ repetition_penalty (`float`):
244
+ The parameter for repetition penalty. 1.0 means no penalty. See [this
245
+ paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
246
+ """
247
+
248
+ def __init__(self, penalty: float):
249
+ if not isinstance(penalty, float) or not (penalty > 0):
250
+ raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
251
+
252
+ self.penalty = penalty
253
+
254
+ def _create_score_penalties(self, input_ids: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
255
+ # We want to populate the penalties in the positions of `input_ids`. Since XLA can't handle shapes unknown
256
+ # before runtime, `tf.unique` can't be used. Therefore, we may have redundant updates, when a given row has
257
+ # the same token multiple times.
258
+
259
+ # Gathers the penalties to apply
260
+ logit_penalties = tf.gather(logits, input_ids, axis=1, batch_dims=1)
261
+ logit_penalties = tf.where(logit_penalties > 0, 1 / self.penalty, logit_penalties)
262
+ logit_penalties = tf.where(logit_penalties < 0, self.penalty, logit_penalties)
263
+
264
+ # Scatters the penalties
265
+ token_penalties = tf.ones(logits.shape)
266
+ batch_size = input_ids.shape[0]
267
+ seq_len = tf.shape(input_ids)[1] # the sequence length has dynamic size, hence the dynamic shape
268
+ indexable_prev_input_ids = tf.concat(
269
+ (
270
+ tf.expand_dims(tf.repeat(tf.range(batch_size), seq_len), axis=-1),
271
+ tf.expand_dims(tf.reshape(input_ids, [-1]), axis=-1),
272
+ ),
273
+ axis=1,
274
+ )
275
+ token_penalties = tf.tensor_scatter_nd_update(
276
+ token_penalties, indices=indexable_prev_input_ids, updates=tf.reshape(logit_penalties, [-1])
277
+ )
278
+ return token_penalties
279
+
280
+ def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
281
+ score_penalties = self._create_score_penalties(input_ids[:, :cur_len], scores)
282
+
283
+ scores = tf.math.multiply(scores, score_penalties)
284
+
285
+ return scores
286
+
287
+
288
+ class TFNoBadWordsLogitsProcessor(TFLogitsProcessor):
289
+ """
290
+ [`TFLogitsProcessor`] that enforces that specified sequences will never be sampled.
291
+
292
+ Args:
293
+ bad_words_ids (`List[List[int]]`):
294
+ List of list of token ids that are not allowed to be generated. In order to get the tokens of the words
295
+ that should not appear in the generated text, make sure to set `add_prefix_space=True` when initializing
296
+ the tokenizer, and use `tokenizer(bad_words, add_special_tokens=False).input_ids`. The `add_prefix_space`
297
+ argument is only supported for some slow tokenizers, as fast tokenizers' prefixing behaviours come from
298
+ `pre tokenizers`. Read more [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).
299
+ eos_token_id (`int`):
300
+ The id of the *end-of-sequence* token.
301
+ """
302
+
303
+ def __init__(self, bad_words_ids: List[List[int]], eos_token_id: int):
304
+ if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0:
305
+ raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
306
+ if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
307
+ raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.")
308
+ if any(
309
+ any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids)
310
+ for bad_word_ids in bad_words_ids
311
+ ):
312
+ raise ValueError(
313
+ f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
314
+ )
315
+
316
+ # stores the information about bad words in three tensors:
317
+ # 1. a rectangular tensor with the forbidden sequences (padded with `-1`), for full data comparisons
318
+ self.bad_word_seqs_ids = tf.ragged.constant(bad_words_ids).to_tensor(default_value=-1)
319
+ # 2. a tensor with the unpadded length of each forbidden sequence, for quick length comparisons
320
+ bad_word_seqs_len = [len(bad_words) for bad_words in bad_words_ids]
321
+ if any(word_len == 0 for word_len in bad_word_seqs_len):
322
+ raise ValueError(f"Banned words token sequences {bad_words_ids} cannot have an empty list")
323
+ self.bad_word_seqs_len = tf.convert_to_tensor(bad_word_seqs_len, dtype=tf.int32)
324
+ # 3. a tensor containing the last token for each sequence, for easy access to the tokens that may be banned
325
+ self.seq_forbidden_tokens = tf.convert_to_tensor([bad_words[-1] for bad_words in bad_words_ids])
326
+
327
+ def _calc_row_banned_bad_tokens(self, row_input_ids: tf.Tensor) -> tf.Tensor:
328
+ def _tokens_match(bad_word_seq_number):
329
+ def _len_one():
330
+ # If the bad sequence only has one token, always mask it
331
+ return tf.cond(
332
+ tf.math.equal(self.bad_word_seqs_len[bad_word_seq_number], 1),
333
+ lambda: tf.ones((), dtype=tf.bool),
334
+ _len_greater_than_cur_len,
335
+ )
336
+
337
+ def _len_greater_than_cur_len():
338
+ # Otherwise, if the bad sequence is longer than the current length they can't ever match
339
+ return tf.cond(
340
+ tf.math.greater(self.bad_word_seqs_len[bad_word_seq_number], tf.shape(row_input_ids)[0]),
341
+ lambda: tf.zeros((), dtype=tf.bool),
342
+ _match_found,
343
+ )
344
+
345
+ def _match_found():
346
+ # Finaly, runs the actual comparison. Can only be called if the previous comparisons do not yield
347
+ # an answer (otherwise we get indexing exceptions)
348
+ compare_len = self.bad_word_seqs_len[bad_word_seq_number] - 1
349
+ return tf.cond(
350
+ tf.math.reduce_all(
351
+ tf.math.equal(
352
+ row_input_ids[-compare_len:], self.bad_word_seqs_ids[bad_word_seq_number, :compare_len]
353
+ )
354
+ ),
355
+ lambda: tf.ones((), dtype=tf.bool),
356
+ lambda: tf.zeros((), dtype=tf.bool),
357
+ )
358
+
359
+ match = _len_one()
360
+ return match
361
+
362
+ # Compares the current row against all bad word sequences, obtaining a mask with the matches.
363
+ match_mask = tf.map_fn(_tokens_match, tf.range(self.bad_word_seqs_ids.shape[0]), fn_output_signature=tf.bool)
364
+ row_banned_tokens = self.seq_forbidden_tokens[match_mask]
365
+ return row_banned_tokens
366
+
367
+ def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
368
+ # We want to mask some banned tokens, at a score level. Since the banned tokens depend on the previous
369
+ # `input_ids`, they may have a different length for each row, and they may even be empty for some rows.
370
+ # To remain simple and XLA-compatible, we work on a per-row fashion.
371
+ # TODO (Joao): this function might trigger XLA retracing as `cur_len` increases. Fix it if it becomes
372
+ # a frequent choke point. (make `cur_len` a tensor?)
373
+ def _get_row_updated_score(row_inputs: Tuple[tf.Tensor]) -> tf.Tensor:
374
+ row_input_ids, row_score = row_inputs
375
+ banned_tokens = self._calc_row_banned_bad_tokens(row_input_ids[:cur_len])
376
+ banned_tokens_mask = tf.scatter_nd(
377
+ indices=tf.expand_dims(banned_tokens, axis=-1),
378
+ updates=tf.ones_like(banned_tokens, dtype=tf.bool),
379
+ shape=row_score.shape,
380
+ )
381
+ row_score = tf.where(banned_tokens_mask, -float("inf"), row_score)
382
+ return row_score
383
+
384
+ scores = tf.map_fn(_get_row_updated_score, (input_ids, scores), fn_output_signature=tf.float32)
385
+ return scores
386
+
387
+
388
+ class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor):
389
+ r"""
390
+ [`TFLogitsProcessor`] that enforces no repetition of n-grams. See
391
+ [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).
392
+
393
+ Args:
394
+ ngram_size (`int`):
395
+ All ngrams of size `ngram_size` can only occur once.
396
+ """
397
+
398
+ def __init__(self, ngram_size: int):
399
+ if not isinstance(ngram_size, int) or ngram_size <= 0:
400
+ raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
401
+ self.ngram_size = ngram_size
402
+
403
+ def calc_banned_ngram_tokens(self, input_ids, num_hypos, cur_len):
404
+ # Copied from fairseq for no_repeat_ngram in beam_search
405
+ if cur_len + 1 < self.ngram_size:
406
+ # return no banned tokens if we haven't generated ngram_size tokens yet
407
+ return [[] for _ in range(num_hypos)]
408
+ generated_ngrams = [{} for _ in range(num_hypos)]
409
+ prev_input_ids = input_ids[:, :cur_len]
410
+ for idx in range(num_hypos):
411
+ gen_tokens = prev_input_ids[idx].numpy().tolist()
412
+ generated_ngram = generated_ngrams[idx]
413
+ for ngram in zip(*[gen_tokens[i:] for i in range(self.ngram_size)]):
414
+ prev_ngram_tuple = tuple(ngram[:-1])
415
+ generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
416
+
417
+ def _get_generated_ngrams(hypo_idx):
418
+ # Before decoding the next token, prevent decoding of ngrams that have already appeared
419
+ start_idx = cur_len + 1 - self.ngram_size
420
+ ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].numpy().tolist())
421
+ return generated_ngrams[hypo_idx].get(ngram_idx, [])
422
+
423
+ banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
424
+
425
+ return banned_tokens
426
+
427
+ def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
428
+ # TODO (joao): enable XLA on this logits processor. See discussion and attempts in
429
+ # https://github.com/huggingface/transformers/pull/16974
430
+ if not tf.executing_eagerly():
431
+ raise NotImplementedError("TFNoRepeatNGramLogitsProcessor is only implemented for eager execution.")
432
+
433
+ batch_size, vocab_size = scores.shape
434
+ banned_tokens = self.calc_banned_ngram_tokens(input_ids, batch_size, cur_len)
435
+
436
+ # create banned_tokens boolean mask
437
+ banned_tokens_indices_mask = []
438
+ for banned_tokens_slice in banned_tokens:
439
+ banned_tokens_indices_mask.append(
440
+ [True if token in banned_tokens_slice else False for token in range(vocab_size)]
441
+ )
442
+
443
+ scores = tf.where(tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores)
444
+
445
+ return scores
446
+
447
+
448
+ class TFForcedBOSTokenLogitsProcessor(TFLogitsProcessor):
449
+ r"""
450
+ [`TFLogitsProcessor`] that enforces the specified token as the first generated token.
451
+
452
+ Args:
453
+ bos_token_id (`int`):
454
+ The id of the token to force as the first generated token.
455
+ """
456
+
457
+ def __init__(self, bos_token_id: int):
458
+ if bos_token_id < 0:
459
+ raise ValueError(f"The forced bos token id must be a non-negative integer, got {bos_token_id}")
460
+ self.bos_token_id = bos_token_id
461
+
462
+ def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
463
+ if cur_len == 1:
464
+ batch_size, num_tokens = scores.shape
465
+ # sets the score to 0 in the bos_token_id column
466
+ scores = tf.zeros((batch_size, 1))
467
+ # sets the score to -inf everywhere else
468
+ if self.bos_token_id > 0:
469
+ scores = tf.concat((tf.broadcast_to(-float("inf"), (batch_size, self.bos_token_id)), scores), axis=-1)
470
+ if self.bos_token_id < (num_tokens - 1):
471
+ scores = tf.concat(
472
+ (scores, tf.broadcast_to(-float("inf"), (batch_size, (num_tokens - 1) - self.bos_token_id))),
473
+ axis=-1,
474
+ )
475
+ return scores
476
+
477
+
478
+ class TFForcedEOSTokenLogitsProcessor(TFLogitsProcessor):
479
+ r"""
480
+ [`TFLogitsProcessor`] that enforces the specified token as the last generated token when `max_length` is reached.
481
+
482
+ Args:
483
+ max_length (`int`):
484
+ The maximum length of the sequence to be generated.
485
+ eos_token_id (`int`):
486
+ The id of the token to force as the last generated token when `max_length` is reached.
487
+ """
488
+
489
+ def __init__(self, max_length: int, eos_token_id: int):
490
+ self.max_length = max_length
491
+ if eos_token_id < 0:
492
+ raise ValueError(f"The forced eos token id must be a non-negative integer, got {eos_token_id}")
493
+ self.eos_token_id = eos_token_id
494
+
495
+ def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
496
+ if cur_len == self.max_length - 1:
497
+ batch_size, num_tokens = scores.shape
498
+ # sets the score to 0 in the eos_token_id column
499
+ scores = tf.zeros((batch_size, 1))
500
+ # sets the score to -inf everywhere else
501
+ if self.eos_token_id > 0:
502
+ scores = tf.concat((tf.broadcast_to(-float("inf"), (batch_size, self.eos_token_id)), scores), axis=-1)
503
+ if self.eos_token_id < (num_tokens - 1):
504
+ scores = tf.concat(
505
+ (scores, tf.broadcast_to(-float("inf"), (batch_size, (num_tokens - 1) - self.eos_token_id))),
506
+ axis=-1,
507
+ )
508
+ return scores
509
+
510
+
511
+ class TFSuppressTokensAtBeginLogitsProcessor(TFLogitsProcessor):
512
+ r"""
513
+ [`TFSuppressTokensAtBeginLogitsProcessor`] suppresses a list of tokens as soon as the `generate` function starts
514
+ generating using `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` at not
515
+ sampled at the beginning of the generation.
516
+ """
517
+
518
+ def __init__(self, begin_suppress_tokens, begin_index):
519
+ self.begin_suppress_tokens = list(begin_suppress_tokens)
520
+ self.begin_index = begin_index
521
+
522
+ def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
523
+ suppressed_indices = []
524
+ for token in self.begin_suppress_tokens:
525
+ if token < scores.shape[-1]: # to ensure we don't go beyond the vocab size
526
+ suppressed_indices.extend([[i, token] for i in range(scores.shape[0])])
527
+
528
+ if len(suppressed_indices) > 0:
529
+ scores = tf.cond(
530
+ tf.equal(cur_len, self.begin_index),
531
+ lambda: tf.tensor_scatter_nd_update(
532
+ scores,
533
+ indices=suppressed_indices,
534
+ updates=[-float("inf") for _ in range(scores.shape[0] * len(self.begin_suppress_tokens))],
535
+ ),
536
+ lambda: scores,
537
+ )
538
+ return scores
539
+
540
+
541
+ class TFSuppressTokensLogitsProcessor(TFLogitsProcessor):
542
+ r"""This processor can be used to suppress a list of tokens. The processor will set their log probs to `-inf` so that they
543
+ are not sampled."""
544
+
545
+ def __init__(self, suppress_tokens):
546
+ self.suppress_tokens = list(suppress_tokens)
547
+
548
+ def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
549
+ suppressed_indices = []
550
+ for token in self.suppress_tokens:
551
+ if token < scores.shape[-1]: # to ensure we don't go beyond the vocab size
552
+ suppressed_indices.extend([[i, token] for i in range(scores.shape[0])])
553
+
554
+ if len(suppressed_indices) > 0:
555
+ scores = tf.tensor_scatter_nd_update(
556
+ scores,
557
+ indices=[[i, token] for i in range(scores.shape[0]) for token in self.suppress_tokens],
558
+ updates=[-float("inf") for _ in range(scores.shape[0] * len(self.suppress_tokens))],
559
+ )
560
+ return scores
561
+
562
+
563
+ class TFForceTokensLogitsProcessor(TFLogitsProcessor):
564
+ r"""This processor takes a list of pairs of integers which indicates a mapping from generation indices to token
565
+ indices that will be forced before sampling. The processor will set their log probs to `0` and all other tokens to
566
+ `-inf` so that they are sampled at their corresponding index."""
567
+
568
+ def __init__(self, force_token_map: List[List[int]]):
569
+ force_token_map = dict(force_token_map)
570
+ # Converts the dictionary of format {index: token} containing the tokens to be forced to an array, where the
571
+ # index of the array corresponds to the index of the token to be forced, for XLA compatibility.
572
+ # Indexes without forced tokens will have an negative value.
573
+ force_token_array = np.ones((max(force_token_map.keys()) + 1), dtype=np.int32) * -1
574
+ for index, token in force_token_map.items():
575
+ if token is not None:
576
+ force_token_array[index] = token
577
+ self.force_token_array = tf.convert_to_tensor(force_token_array, dtype=tf.int32)
578
+
579
+ def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
580
+ def _force_token(generation_idx):
581
+ batch_size = scores.shape[0]
582
+ current_token = self.force_token_array[generation_idx]
583
+
584
+ new_scores = tf.zeros_like(scores, dtype=scores.dtype) + tf.constant([scores.dtype.min])
585
+ indices = tf.stack((tf.range(batch_size), tf.tile([current_token], [batch_size])), axis=1)
586
+ updates = tf.zeros((batch_size,), dtype=scores.dtype)
587
+ new_scores = tf.tensor_scatter_nd_update(new_scores, indices, updates)
588
+ return new_scores
589
+
590
+ scores = tf.cond(
591
+ tf.greater_equal(cur_len, tf.shape(self.force_token_array)[0]),
592
+ # If the current length is geq than the length of force_token_array, the processor does nothing.
593
+ lambda: tf.identity(scores),
594
+ # Otherwise, it may force a certain token.
595
+ lambda: tf.cond(
596
+ tf.greater_equal(self.force_token_array[cur_len], 0),
597
+ # Only valid (positive) tokens are forced
598
+ lambda: _force_token(cur_len),
599
+ # Otherwise, the processor does nothing.
600
+ lambda: scores,
601
+ ),
602
+ )
603
+ return scores
.venv/lib/python3.11/site-packages/transformers/generation/tf_utils.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/transformers/generation/utils.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/transformers/generation/watermarking.py ADDED
@@ -0,0 +1,549 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team and Google DeepMind.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import collections
17
+ from dataclasses import dataclass
18
+ from functools import lru_cache
19
+ from typing import Any, Dict, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ from torch import nn
24
+ from torch.nn import BCELoss
25
+
26
+ from ..modeling_utils import PreTrainedModel
27
+ from ..utils import ModelOutput, is_torch_available, logging
28
+ from .configuration_utils import PretrainedConfig, WatermarkingConfig
29
+
30
+
31
+ if is_torch_available():
32
+ import torch
33
+
34
+ from .logits_process import SynthIDTextWatermarkLogitsProcessor, WatermarkLogitsProcessor
35
+
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+
40
+ @dataclass
41
+ class WatermarkDetectorOutput:
42
+ """
43
+ Outputs of a watermark detector.
44
+
45
+ Args:
46
+ num_tokens_scored (np.array of shape (batch_size)):
47
+ Array containing the number of tokens scored for each element in the batch.
48
+ num_green_tokens (np.array of shape (batch_size)):
49
+ Array containing the number of green tokens for each element in the batch.
50
+ green_fraction (np.array of shape (batch_size)):
51
+ Array containing the fraction of green tokens for each element in the batch.
52
+ z_score (np.array of shape (batch_size)):
53
+ Array containing the z-score for each element in the batch. Z-score here shows
54
+ how many standard deviations away is the green token count in the input text
55
+ from the expected green token count for machine-generated text.
56
+ p_value (np.array of shape (batch_size)):
57
+ Array containing the p-value for each batch obtained from z-scores.
58
+ prediction (np.array of shape (batch_size)), *optional*:
59
+ Array containing boolean predictions whether a text is machine-generated for each element in the batch.
60
+ confidence (np.array of shape (batch_size)), *optional*:
61
+ Array containing confidence scores of a text being machine-generated for each element in the batch.
62
+ """
63
+
64
+ num_tokens_scored: np.array = None
65
+ num_green_tokens: np.array = None
66
+ green_fraction: np.array = None
67
+ z_score: np.array = None
68
+ p_value: np.array = None
69
+ prediction: Optional[np.array] = None
70
+ confidence: Optional[np.array] = None
71
+
72
+
73
+ class WatermarkDetector:
74
+ r"""
75
+ Detector for detection of watermark generated text. The detector needs to be given the exact same settings that were
76
+ given during text generation to replicate the watermark greenlist generation and so detect the watermark. This includes
77
+ the correct device that was used during text generation, the correct watermarking arguments and the correct tokenizer vocab size.
78
+ The code was based on the [original repo](https://github.com/jwkirchenbauer/lm-watermarking/tree/main).
79
+
80
+ See [the paper](https://arxiv.org/abs/2306.04634) for more information.
81
+
82
+ Args:
83
+ model_config (`PretrainedConfig`):
84
+ The model config that will be used to get model specific arguments used when generating.
85
+ device (`str`):
86
+ The device which was used during watermarked text generation.
87
+ watermarking_config (Union[`WatermarkingConfig`, `Dict`]):
88
+ The exact same watermarking config and arguments used when generating text.
89
+ ignore_repeated_ngrams (`bool`, *optional*, defaults to `False`):
90
+ Whether to count every unique ngram only once or not.
91
+ max_cache_size (`int`, *optional*, defaults to 128):
92
+ The max size to be used for LRU caching of seeding/sampling algorithms called for every token.
93
+
94
+ Examples:
95
+
96
+ ```python
97
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM, WatermarkDetector, WatermarkingConfig
98
+
99
+ >>> model_id = "openai-community/gpt2"
100
+ >>> model = AutoModelForCausalLM.from_pretrained(model_id)
101
+ >>> tok = AutoTokenizer.from_pretrained(model_id)
102
+ >>> tok.pad_token_id = tok.eos_token_id
103
+ >>> tok.padding_side = "left"
104
+
105
+ >>> inputs = tok(["This is the beginning of a long story", "Alice and Bob are"], padding=True, return_tensors="pt")
106
+ >>> input_len = inputs["input_ids"].shape[-1]
107
+
108
+ >>> # first generate text with watermark and without
109
+ >>> watermarking_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash")
110
+ >>> out_watermarked = model.generate(**inputs, watermarking_config=watermarking_config, do_sample=False, max_length=20)
111
+ >>> out = model.generate(**inputs, do_sample=False, max_length=20)
112
+
113
+ >>> # now we can instantiate the detector and check the generated text
114
+ >>> detector = WatermarkDetector(model_config=model.config, device="cpu", watermarking_config=watermarking_config)
115
+ >>> detection_out_watermarked = detector(out_watermarked, return_dict=True)
116
+ >>> detection_out = detector(out, return_dict=True)
117
+ >>> detection_out_watermarked.prediction
118
+ array([ True, True])
119
+
120
+ >>> detection_out.prediction
121
+ array([False, False])
122
+ ```
123
+ """
124
+
125
+ def __init__(
126
+ self,
127
+ model_config: PretrainedConfig,
128
+ device: str,
129
+ watermarking_config: Union[WatermarkingConfig, Dict],
130
+ ignore_repeated_ngrams: bool = False,
131
+ max_cache_size: int = 128,
132
+ ):
133
+ if isinstance(watermarking_config, WatermarkingConfig):
134
+ watermarking_config = watermarking_config.to_dict()
135
+
136
+ self.bos_token_id = (
137
+ model_config.bos_token_id if not model_config.is_encoder_decoder else model_config.decoder_start_token_id
138
+ )
139
+ self.greenlist_ratio = watermarking_config["greenlist_ratio"]
140
+ self.ignore_repeated_ngrams = ignore_repeated_ngrams
141
+ self.processor = WatermarkLogitsProcessor(
142
+ vocab_size=model_config.vocab_size, device=device, **watermarking_config
143
+ )
144
+
145
+ # Expensive re-seeding and sampling is cached.
146
+ self._get_ngram_score_cached = lru_cache(maxsize=max_cache_size)(self._get_ngram_score)
147
+
148
+ def _get_ngram_score(self, prefix: torch.LongTensor, target: int):
149
+ greenlist_ids = self.processor._get_greenlist_ids(prefix)
150
+ return target in greenlist_ids
151
+
152
+ def _score_ngrams_in_passage(self, input_ids: torch.LongTensor):
153
+ batch_size, seq_length = input_ids.shape
154
+ selfhash = int(self.processor.seeding_scheme == "selfhash")
155
+ n = self.processor.context_width + 1 - selfhash
156
+ indices = torch.arange(n).unsqueeze(0) + torch.arange(seq_length - n + 1).unsqueeze(1)
157
+ ngram_tensors = input_ids[:, indices]
158
+
159
+ num_tokens_scored_batch = np.zeros(batch_size)
160
+ green_token_count_batch = np.zeros(batch_size)
161
+ for batch_idx in range(ngram_tensors.shape[0]):
162
+ frequencies_table = collections.Counter(ngram_tensors[batch_idx])
163
+ ngram_to_watermark_lookup = {}
164
+ for ngram_example in frequencies_table.keys():
165
+ prefix = ngram_example if selfhash else ngram_example[:-1]
166
+ target = ngram_example[-1]
167
+ ngram_to_watermark_lookup[ngram_example] = self._get_ngram_score_cached(prefix, target)
168
+
169
+ if self.ignore_repeated_ngrams:
170
+ # counts a green/red hit once per unique ngram.
171
+ # num total tokens scored becomes the number unique ngrams.
172
+ num_tokens_scored_batch[batch_idx] = len(frequencies_table.keys())
173
+ green_token_count_batch[batch_idx] = sum(ngram_to_watermark_lookup.values())
174
+ else:
175
+ num_tokens_scored_batch[batch_idx] = sum(frequencies_table.values())
176
+ green_token_count_batch[batch_idx] = sum(
177
+ freq * outcome
178
+ for freq, outcome in zip(frequencies_table.values(), ngram_to_watermark_lookup.values())
179
+ )
180
+ return num_tokens_scored_batch, green_token_count_batch
181
+
182
+ def _compute_z_score(self, green_token_count: np.array, total_num_tokens: np.array) -> np.array:
183
+ expected_count = self.greenlist_ratio
184
+ numer = green_token_count - expected_count * total_num_tokens
185
+ denom = np.sqrt(total_num_tokens * expected_count * (1 - expected_count))
186
+ z = numer / denom
187
+ return z
188
+
189
+ def _compute_pval(self, x, loc=0, scale=1):
190
+ z = (x - loc) / scale
191
+ return 1 - (0.5 * (1 + np.sign(z) * (1 - np.exp(-2 * z**2 / np.pi))))
192
+
193
+ def __call__(
194
+ self,
195
+ input_ids: torch.LongTensor,
196
+ z_threshold: float = 3.0,
197
+ return_dict: bool = False,
198
+ ) -> Union[WatermarkDetectorOutput, np.array]:
199
+ """
200
+ Args:
201
+ input_ids (`torch.LongTensor`):
202
+ The watermark generated text. It is advised to remove the prompt, which can affect the detection.
203
+ z_threshold (`Dict`, *optional*, defaults to `3.0`):
204
+ Changing this threshold will change the sensitivity of the detector. Higher z threshold gives less
205
+ sensitivity and vice versa for lower z threshold.
206
+ return_dict (`bool`, *optional*, defaults to `False`):
207
+ Whether to return `~generation.WatermarkDetectorOutput` or not. If not it will return boolean predictions,
208
+ ma
209
+ Return:
210
+ [`~generation.WatermarkDetectorOutput`] or `np.array`: A [`~generation.WatermarkDetectorOutput`]
211
+ if `return_dict=True` otherwise a `np.array`.
212
+
213
+ """
214
+
215
+ # Let's assume that if one batch start with `bos`, all batched also do
216
+ if input_ids[0, 0] == self.bos_token_id:
217
+ input_ids = input_ids[:, 1:]
218
+
219
+ if input_ids.shape[-1] - self.processor.context_width < 1:
220
+ raise ValueError(
221
+ f"Must have at least `1` token to score after the first "
222
+ f"min_prefix_len={self.processor.context_width} tokens required by the seeding scheme."
223
+ )
224
+
225
+ num_tokens_scored, green_token_count = self._score_ngrams_in_passage(input_ids)
226
+ z_score = self._compute_z_score(green_token_count, num_tokens_scored)
227
+ prediction = z_score > z_threshold
228
+
229
+ if return_dict:
230
+ p_value = self._compute_pval(z_score)
231
+ confidence = 1 - p_value
232
+
233
+ return WatermarkDetectorOutput(
234
+ num_tokens_scored=num_tokens_scored,
235
+ num_green_tokens=green_token_count,
236
+ green_fraction=green_token_count / num_tokens_scored,
237
+ z_score=z_score,
238
+ p_value=p_value,
239
+ prediction=prediction,
240
+ confidence=confidence,
241
+ )
242
+ return prediction
243
+
244
+
245
+ class BayesianDetectorConfig(PretrainedConfig):
246
+ """
247
+ This is the configuration class to store the configuration of a [`BayesianDetectorModel`]. It is used to
248
+ instantiate a Bayesian Detector model according to the specified arguments.
249
+
250
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
251
+ documentation from [`PretrainedConfig`] for more information.
252
+
253
+ Args:
254
+ watermarking_depth (`int`, *optional*):
255
+ The number of tournament layers.
256
+ base_rate (`float1`, *optional*, defaults to 0.5):
257
+ Prior probability P(w) that a text is watermarked.
258
+ """
259
+
260
+ def __init__(self, watermarking_depth: int = None, base_rate: float = 0.5, **kwargs):
261
+ self.watermarking_depth = watermarking_depth
262
+ self.base_rate = base_rate
263
+ # These can be set later to store information about this detector.
264
+ self.model_name = None
265
+ self.watermarking_config = None
266
+
267
+ super().__init__(**kwargs)
268
+
269
+ def set_detector_information(self, model_name, watermarking_config):
270
+ self.model_name = model_name
271
+ self.watermarking_config = watermarking_config
272
+
273
+
274
+ @dataclass
275
+ class BayesianWatermarkDetectorModelOutput(ModelOutput):
276
+ """
277
+ Base class for outputs of models predicting if the text is watermarked.
278
+
279
+ Args:
280
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
281
+ Language modeling loss.
282
+ posterior_probabilities (`torch.FloatTensor` of shape `(1,)`):
283
+ Multiple choice classification loss.
284
+ """
285
+
286
+ loss: Optional[torch.FloatTensor] = None
287
+ posterior_probabilities: Optional[torch.FloatTensor] = None
288
+
289
+
290
+ class BayesianDetectorWatermarkedLikelihood(nn.Module):
291
+ """Watermarked likelihood model for binary-valued g-values.
292
+
293
+ This takes in g-values and returns p(g_values|watermarked).
294
+ """
295
+
296
+ def __init__(self, watermarking_depth: int):
297
+ """Initializes the model parameters."""
298
+ super().__init__()
299
+ self.watermarking_depth = watermarking_depth
300
+ self.beta = torch.nn.Parameter(-2.5 + 0.001 * torch.randn(1, 1, watermarking_depth))
301
+ self.delta = torch.nn.Parameter(0.001 * torch.randn(1, 1, self.watermarking_depth, watermarking_depth))
302
+
303
+ def _compute_latents(self, g_values: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
304
+ """Computes the unique token probability distribution given g-values.
305
+
306
+ Args:
307
+ g_values (`torch.Tensor` of shape `(batch_size, seq_len, watermarking_depth)`):
308
+ PRF values.
309
+
310
+ Returns:
311
+ p_one_unique_token and p_two_unique_tokens, both of shape
312
+ [batch_size, seq_len, watermarking_depth]. p_one_unique_token[i,t,l]
313
+ gives the probability of there being one unique token in a tournament
314
+ match on layer l, on timestep t, for batch item i.
315
+ p_one_unique_token[i,t,l] + p_two_unique_token[i,t,l] = 1.
316
+ """
317
+ # Tile g-values to produce feature vectors for predicting the latents
318
+ # for each layer in the tournament; our model for the latents psi is a
319
+ # logistic regression model psi = sigmoid(delta * x + beta).
320
+
321
+ # [batch_size, seq_len, watermarking_depth, watermarking_depth]
322
+ x = torch.repeat_interleave(torch.unsqueeze(g_values, dim=-2), self.watermarking_depth, axis=-2)
323
+
324
+ # mask all elements above -1 diagonal for autoregressive factorization
325
+ x = torch.tril(x, diagonal=-1)
326
+
327
+ # [batch_size, seq_len, watermarking_depth]
328
+ # (i, j, k, l) x (i, j, k, l) -> (i, j, k) einsum equivalent
329
+ logits = (self.delta[..., None, :] @ x.type(self.delta.dtype)[..., None]).squeeze() + self.beta
330
+
331
+ p_two_unique_tokens = torch.sigmoid(logits)
332
+ p_one_unique_token = 1 - p_two_unique_tokens
333
+ return p_one_unique_token, p_two_unique_tokens
334
+
335
+ def forward(self, g_values: torch.Tensor) -> torch.Tensor:
336
+ """Computes the likelihoods P(g_values|watermarked).
337
+
338
+ Args:
339
+ g_values (`torch.Tensor` of shape `(batch_size, seq_len, watermarking_depth)`):
340
+ g-values (values 0 or 1)
341
+
342
+ Returns:
343
+ p(g_values|watermarked) of shape [batch_size, seq_len, watermarking_depth].
344
+ """
345
+ p_one_unique_token, p_two_unique_tokens = self._compute_latents(g_values)
346
+
347
+ # P(g_tl | watermarked) is equal to
348
+ # 0.5 * [ (g_tl+0.5) * p_two_unique_tokens + p_one_unique_token].
349
+ return 0.5 * ((g_values + 0.5) * p_two_unique_tokens + p_one_unique_token)
350
+
351
+
352
+ class BayesianDetectorModel(PreTrainedModel):
353
+ r"""
354
+ Bayesian classifier for watermark detection.
355
+
356
+ This detector uses Bayes' rule to compute a watermarking score, which is the sigmoid of the log of ratio of the
357
+ posterior probabilities P(watermarked|g_values) and P(unwatermarked|g_values). Please see the section on
358
+ BayesianScore in the paper for further details.
359
+ Paper URL: https://www.nature.com/articles/s41586-024-08025-4
360
+
361
+ Note that this detector only works with non-distortionary Tournament-based watermarking using the Bernoulli(0.5)
362
+ g-value distribution.
363
+
364
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
365
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
366
+ etc.)
367
+
368
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
369
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
370
+ and behavior.
371
+
372
+ Parameters:
373
+ config ([`BayesianDetectorConfig`]): Model configuration class with all the parameters of the model.
374
+ Initializing with a config file does not load the weights associated with the model, only the
375
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
376
+ """
377
+
378
+ config_class = BayesianDetectorConfig
379
+ base_model_prefix = "model"
380
+
381
+ def __init__(self, config):
382
+ super().__init__(config)
383
+
384
+ self.watermarking_depth = config.watermarking_depth
385
+ self.base_rate = config.base_rate
386
+ self.likelihood_model_watermarked = BayesianDetectorWatermarkedLikelihood(
387
+ watermarking_depth=self.watermarking_depth
388
+ )
389
+ self.prior = torch.nn.Parameter(torch.tensor([self.base_rate]))
390
+
391
+ def _init_weights(self, module):
392
+ """Initialize the weights."""
393
+ if isinstance(module, nn.Parameter):
394
+ module.weight.data.normal_(mean=0.0, std=0.02)
395
+
396
+ def _compute_posterior(
397
+ self,
398
+ likelihoods_watermarked: torch.Tensor,
399
+ likelihoods_unwatermarked: torch.Tensor,
400
+ mask: torch.Tensor,
401
+ prior: float,
402
+ ) -> torch.Tensor:
403
+ """
404
+ Compute posterior P(w|g) given likelihoods, mask and prior.
405
+
406
+ Args:
407
+ likelihoods_watermarked (`torch.Tensor` of shape `(batch, length, depth)`):
408
+ Likelihoods P(g_values|watermarked) of g-values under watermarked model.
409
+ likelihoods_unwatermarked (`torch.Tensor` of shape `(batch, length, depth)`):
410
+ Likelihoods P(g_values|unwatermarked) of g-values under unwatermarked model.
411
+ mask (`torch.Tensor` of shape `(batch, length)`):
412
+ A binary array indicating which g-values should be used. g-values with mask value 0 are discarded.
413
+ prior (`float`):
414
+ the prior probability P(w) that the text is watermarked.
415
+
416
+ Returns:
417
+ Posterior probability P(watermarked|g_values), shape [batch].
418
+ """
419
+ mask = torch.unsqueeze(mask, dim=-1)
420
+ prior = torch.clamp(prior, min=1e-5, max=1 - 1e-5)
421
+ log_likelihoods_watermarked = torch.log(torch.clamp(likelihoods_watermarked, min=1e-30, max=float("inf")))
422
+ log_likelihoods_unwatermarked = torch.log(torch.clamp(likelihoods_unwatermarked, min=1e-30, max=float("inf")))
423
+ log_odds = log_likelihoods_watermarked - log_likelihoods_unwatermarked
424
+
425
+ # Sum relative surprisals (log odds) across all token positions and layers.
426
+ relative_surprisal_likelihood = torch.einsum("i...->i", log_odds * mask)
427
+
428
+ # Compute the relative surprisal prior
429
+ relative_surprisal_prior = torch.log(prior) - torch.log(1 - prior)
430
+
431
+ # Combine prior and likelihood.
432
+ # [batch_size]
433
+ relative_surprisal = relative_surprisal_prior + relative_surprisal_likelihood
434
+
435
+ # Compute the posterior probability P(w|g) = sigmoid(relative_surprisal).
436
+ return torch.sigmoid(relative_surprisal)
437
+
438
+ def forward(
439
+ self,
440
+ g_values: torch.Tensor,
441
+ mask: torch.Tensor,
442
+ labels: Optional[torch.Tensor] = None,
443
+ loss_batch_weight=1,
444
+ return_dict=False,
445
+ ) -> BayesianWatermarkDetectorModelOutput:
446
+ """
447
+ Computes the watermarked posterior P(watermarked|g_values).
448
+
449
+ Args:
450
+ g_values (`torch.Tensor` of shape `(batch_size, seq_len, watermarking_depth, ...)`):
451
+ g-values (with values 0 or 1)
452
+ mask:
453
+ A binary array shape [batch_size, seq_len] indicating which g-values should be used. g-values with mask
454
+ value 0 are discarded.
455
+
456
+ Returns:
457
+ p(watermarked | g_values), of shape [batch_size].
458
+ """
459
+
460
+ likelihoods_watermarked = self.likelihood_model_watermarked(g_values)
461
+ likelihoods_unwatermarked = 0.5 * torch.ones_like(g_values)
462
+ out = self._compute_posterior(
463
+ likelihoods_watermarked=likelihoods_watermarked,
464
+ likelihoods_unwatermarked=likelihoods_unwatermarked,
465
+ mask=mask,
466
+ prior=self.prior,
467
+ )
468
+
469
+ loss = None
470
+ if labels is not None:
471
+ loss_fct = BCELoss()
472
+ loss_unwweight = torch.sum(self.likelihood_model_watermarked.delta**2)
473
+ loss_weight = loss_unwweight * loss_batch_weight
474
+ loss = loss_fct(torch.clamp(out, 1e-5, 1 - 1e-5), labels) + loss_weight
475
+
476
+ if not return_dict:
477
+ return (out,) if loss is None else (out, loss)
478
+
479
+ return BayesianWatermarkDetectorModelOutput(loss=loss, posterior_probabilities=out)
480
+
481
+
482
+ class SynthIDTextWatermarkDetector:
483
+ r"""
484
+ SynthID text watermark detector class.
485
+
486
+ This class has to be initialized with the trained bayesian detector module check script
487
+ in examples/synthid_text/detector_training.py for example in training/saving/loading this
488
+ detector module. The folder also showcases example use case of this detector.
489
+
490
+ Parameters:
491
+ detector_module ([`BayesianDetectorModel`]):
492
+ Bayesian detector module object initialized with parameters.
493
+ Check examples/research_projects/synthid_text/detector_training.py for usage.
494
+ logits_processor (`SynthIDTextWatermarkLogitsProcessor`):
495
+ The logits processor used for watermarking.
496
+ tokenizer (`Any`):
497
+ The tokenizer used for the model.
498
+
499
+ Examples:
500
+ ```python
501
+ >>> from transformers import (
502
+ ... AutoTokenizer, BayesianDetectorModel, SynthIDTextWatermarkLogitsProcessor, SynthIDTextWatermarkDetector
503
+ ... )
504
+
505
+ >>> # Load the detector. See examples/research_projects/synthid_text for training a detector.
506
+ >>> detector_model = BayesianDetectorModel.from_pretrained("joaogante/dummy_synthid_detector")
507
+ >>> logits_processor = SynthIDTextWatermarkLogitsProcessor(
508
+ ... **detector_model.config.watermarking_config, device="cpu"
509
+ ... )
510
+ >>> tokenizer = AutoTokenizer.from_pretrained(detector_model.config.model_name)
511
+ >>> detector = SynthIDTextWatermarkDetector(detector_model, logits_processor, tokenizer)
512
+
513
+ >>> # Test whether a certain string is watermarked
514
+ >>> test_input = tokenizer(["This is a test input"], return_tensors="pt")
515
+ >>> is_watermarked = detector(test_input.input_ids)
516
+ ```
517
+ """
518
+
519
+ def __init__(
520
+ self,
521
+ detector_module: BayesianDetectorModel,
522
+ logits_processor: SynthIDTextWatermarkLogitsProcessor,
523
+ tokenizer: Any,
524
+ ):
525
+ self.detector_module = detector_module
526
+ self.logits_processor = logits_processor
527
+ self.tokenizer = tokenizer
528
+
529
+ def __call__(self, tokenized_outputs: torch.Tensor):
530
+ # eos mask is computed, skip first ngram_len - 1 tokens
531
+ # eos_mask will be of shape [batch_size, output_len]
532
+ eos_token_mask = self.logits_processor.compute_eos_token_mask(
533
+ input_ids=tokenized_outputs,
534
+ eos_token_id=self.tokenizer.eos_token_id,
535
+ )[:, self.logits_processor.ngram_len - 1 :]
536
+
537
+ # context repetition mask is computed
538
+ context_repetition_mask = self.logits_processor.compute_context_repetition_mask(
539
+ input_ids=tokenized_outputs,
540
+ )
541
+ # context repitition mask shape [batch_size, output_len - (ngram_len - 1)]
542
+
543
+ combined_mask = context_repetition_mask * eos_token_mask
544
+
545
+ g_values = self.logits_processor.compute_g_values(
546
+ input_ids=tokenized_outputs,
547
+ )
548
+ # g values shape [batch_size, output_len - (ngram_len - 1), depth]
549
+ return self.detector_module(g_values, combined_mask)