Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- .venv/lib/python3.11/site-packages/pkg_resources/_vendor/more_itertools/__pycache__/more.cpython-311.pyc +3 -0
- .venv/lib/python3.11/site-packages/transformers/data/__init__.py +45 -0
- .venv/lib/python3.11/site-packages/transformers/data/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/data/__pycache__/data_collator.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/data/data_collator.py +1656 -0
- .venv/lib/python3.11/site-packages/transformers/data/datasets/__init__.py +23 -0
- .venv/lib/python3.11/site-packages/transformers/data/datasets/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/data/datasets/__pycache__/glue.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/data/datasets/__pycache__/language_modeling.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/data/datasets/__pycache__/squad.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/data/datasets/glue.py +161 -0
- .venv/lib/python3.11/site-packages/transformers/data/datasets/language_modeling.py +530 -0
- .venv/lib/python3.11/site-packages/transformers/data/datasets/squad.py +229 -0
- .venv/lib/python3.11/site-packages/transformers/data/metrics/__init__.py +98 -0
- .venv/lib/python3.11/site-packages/transformers/data/metrics/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/data/metrics/__pycache__/squad_metrics.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/data/metrics/squad_metrics.py +779 -0
- .venv/lib/python3.11/site-packages/transformers/data/processors/__init__.py +18 -0
- .venv/lib/python3.11/site-packages/transformers/data/processors/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/data/processors/__pycache__/glue.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/data/processors/__pycache__/squad.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/data/processors/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/data/processors/__pycache__/xnli.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/data/processors/glue.py +643 -0
- .venv/lib/python3.11/site-packages/transformers/data/processors/squad.py +845 -0
- .venv/lib/python3.11/site-packages/transformers/data/processors/utils.py +349 -0
- .venv/lib/python3.11/site-packages/transformers/data/processors/xnli.py +96 -0
- .venv/lib/python3.11/site-packages/transformers/generation/__init__.py +352 -0
- .venv/lib/python3.11/site-packages/transformers/generation/__pycache__/candidate_generator.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/generation/__pycache__/configuration_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/generation/__pycache__/flax_logits_process.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/generation/__pycache__/flax_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/generation/__pycache__/stopping_criteria.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/generation/__pycache__/streamers.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/generation/__pycache__/tf_logits_process.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/generation/__pycache__/watermarking.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/generation/beam_constraints.py +524 -0
- .venv/lib/python3.11/site-packages/transformers/generation/beam_search.py +1013 -0
- .venv/lib/python3.11/site-packages/transformers/generation/candidate_generator.py +871 -0
- .venv/lib/python3.11/site-packages/transformers/generation/configuration_utils.py +1628 -0
- .venv/lib/python3.11/site-packages/transformers/generation/flax_logits_process.py +544 -0
- .venv/lib/python3.11/site-packages/transformers/generation/flax_utils.py +1027 -0
- .venv/lib/python3.11/site-packages/transformers/generation/logits_process.py +0 -0
- .venv/lib/python3.11/site-packages/transformers/generation/stopping_criteria.py +514 -0
- .venv/lib/python3.11/site-packages/transformers/generation/streamers.py +318 -0
- .venv/lib/python3.11/site-packages/transformers/generation/tf_logits_process.py +603 -0
- .venv/lib/python3.11/site-packages/transformers/generation/tf_utils.py +0 -0
- .venv/lib/python3.11/site-packages/transformers/generation/utils.py +0 -0
- .venv/lib/python3.11/site-packages/transformers/generation/watermarking.py +549 -0
.gitattributes
CHANGED
|
@@ -421,3 +421,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/
|
|
| 421 |
.venv/lib/python3.11/site-packages/idna/__pycache__/idnadata.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 422 |
.venv/lib/python3.11/site-packages/pkg_resources/_vendor/pyparsing/__pycache__/core.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 423 |
.venv/lib/python3.11/site-packages/pkg_resources/__pycache__/__init__.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 421 |
.venv/lib/python3.11/site-packages/idna/__pycache__/idnadata.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 422 |
.venv/lib/python3.11/site-packages/pkg_resources/_vendor/pyparsing/__pycache__/core.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 423 |
.venv/lib/python3.11/site-packages/pkg_resources/__pycache__/__init__.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 424 |
+
.venv/lib/python3.11/site-packages/pkg_resources/_vendor/more_itertools/__pycache__/more.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
.venv/lib/python3.11/site-packages/pkg_resources/_vendor/more_itertools/__pycache__/more.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c2915b1a6eb03bcc7f50087391a236782ed9422d5e5eb9ed0a937ca85ce5e73a
|
| 3 |
+
size 167956
|
.venv/lib/python3.11/site-packages/transformers/data/__init__.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .data_collator import (
|
| 16 |
+
DataCollatorForLanguageModeling,
|
| 17 |
+
DataCollatorForPermutationLanguageModeling,
|
| 18 |
+
DataCollatorForSeq2Seq,
|
| 19 |
+
DataCollatorForSOP,
|
| 20 |
+
DataCollatorForTokenClassification,
|
| 21 |
+
DataCollatorForWholeWordMask,
|
| 22 |
+
DataCollatorWithFlattening,
|
| 23 |
+
DataCollatorWithPadding,
|
| 24 |
+
DefaultDataCollator,
|
| 25 |
+
default_data_collator,
|
| 26 |
+
)
|
| 27 |
+
from .metrics import glue_compute_metrics, xnli_compute_metrics
|
| 28 |
+
from .processors import (
|
| 29 |
+
DataProcessor,
|
| 30 |
+
InputExample,
|
| 31 |
+
InputFeatures,
|
| 32 |
+
SingleSentenceClassificationProcessor,
|
| 33 |
+
SquadExample,
|
| 34 |
+
SquadFeatures,
|
| 35 |
+
SquadV1Processor,
|
| 36 |
+
SquadV2Processor,
|
| 37 |
+
glue_convert_examples_to_features,
|
| 38 |
+
glue_output_modes,
|
| 39 |
+
glue_processors,
|
| 40 |
+
glue_tasks_num_labels,
|
| 41 |
+
squad_convert_examples_to_features,
|
| 42 |
+
xnli_output_modes,
|
| 43 |
+
xnli_processors,
|
| 44 |
+
xnli_tasks_num_labels,
|
| 45 |
+
)
|
.venv/lib/python3.11/site-packages/transformers/data/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/transformers/data/__pycache__/data_collator.cpython-311.pyc
ADDED
|
Binary file (94.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/transformers/data/data_collator.py
ADDED
|
@@ -0,0 +1,1656 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import random
|
| 16 |
+
import warnings
|
| 17 |
+
from collections.abc import Mapping
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from random import randint
|
| 20 |
+
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
|
| 24 |
+
from ..models.bert import BertTokenizer, BertTokenizerFast
|
| 25 |
+
from ..tokenization_utils_base import PreTrainedTokenizerBase
|
| 26 |
+
from ..utils import PaddingStrategy
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
InputDataClass = NewType("InputDataClass", Any)
|
| 30 |
+
|
| 31 |
+
"""
|
| 32 |
+
A DataCollator is a function that takes a list of samples from a Dataset and collate them into a batch, as a dictionary
|
| 33 |
+
of PyTorch/TensorFlow tensors or NumPy arrays.
|
| 34 |
+
"""
|
| 35 |
+
DataCollator = NewType("DataCollator", Callable[[List[InputDataClass]], Dict[str, Any]])
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class DataCollatorMixin:
|
| 39 |
+
def __call__(self, features, return_tensors=None):
|
| 40 |
+
if return_tensors is None:
|
| 41 |
+
return_tensors = self.return_tensors
|
| 42 |
+
if return_tensors == "tf":
|
| 43 |
+
return self.tf_call(features)
|
| 44 |
+
elif return_tensors == "pt":
|
| 45 |
+
return self.torch_call(features)
|
| 46 |
+
elif return_tensors == "np":
|
| 47 |
+
return self.numpy_call(features)
|
| 48 |
+
else:
|
| 49 |
+
raise ValueError(f"Framework '{return_tensors}' not recognized!")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs):
|
| 53 |
+
"""
|
| 54 |
+
Pads without triggering the warning about how using the pad function is sub-optimal when using a fast tokenizer.
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
# To avoid errors when using Feature extractors
|
| 58 |
+
if not hasattr(tokenizer, "deprecation_warnings"):
|
| 59 |
+
return tokenizer.pad(*pad_args, **pad_kwargs)
|
| 60 |
+
|
| 61 |
+
# Save the state of the warning, then disable it
|
| 62 |
+
warning_state = tokenizer.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False)
|
| 63 |
+
tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
|
| 64 |
+
|
| 65 |
+
try:
|
| 66 |
+
padded = tokenizer.pad(*pad_args, **pad_kwargs)
|
| 67 |
+
finally:
|
| 68 |
+
# Restore the state of the warning.
|
| 69 |
+
tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = warning_state
|
| 70 |
+
|
| 71 |
+
return padded
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def default_data_collator(features: List[InputDataClass], return_tensors="pt") -> Dict[str, Any]:
|
| 75 |
+
"""
|
| 76 |
+
Very simple data collator that simply collates batches of dict-like objects and performs special handling for
|
| 77 |
+
potential keys named:
|
| 78 |
+
|
| 79 |
+
- `label`: handles a single value (int or float) per object
|
| 80 |
+
- `label_ids`: handles a list of values per object
|
| 81 |
+
|
| 82 |
+
Does not do any additional preprocessing: property names of the input object will be used as corresponding inputs
|
| 83 |
+
to the model. See glue and ner for example of how it's useful.
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
# In this function we'll make the assumption that all `features` in the batch
|
| 87 |
+
# have the same attributes.
|
| 88 |
+
# So we will look at the first element as a proxy for what attributes exist
|
| 89 |
+
# on the whole batch.
|
| 90 |
+
|
| 91 |
+
if return_tensors == "pt":
|
| 92 |
+
return torch_default_data_collator(features)
|
| 93 |
+
elif return_tensors == "tf":
|
| 94 |
+
return tf_default_data_collator(features)
|
| 95 |
+
elif return_tensors == "np":
|
| 96 |
+
return numpy_default_data_collator(features)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
@dataclass
|
| 100 |
+
class DefaultDataCollator(DataCollatorMixin):
|
| 101 |
+
"""
|
| 102 |
+
Very simple data collator that simply collates batches of dict-like objects and performs special handling for
|
| 103 |
+
potential keys named:
|
| 104 |
+
|
| 105 |
+
- `label`: handles a single value (int or float) per object
|
| 106 |
+
- `label_ids`: handles a list of values per object
|
| 107 |
+
|
| 108 |
+
Does not do any additional preprocessing: property names of the input object will be used as corresponding inputs
|
| 109 |
+
to the model. See glue and ner for example of how it's useful.
|
| 110 |
+
|
| 111 |
+
This is an object (like other data collators) rather than a pure function like default_data_collator. This can be
|
| 112 |
+
helpful if you need to set a return_tensors value at initialization.
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
return_tensors (`str`, *optional*, defaults to `"pt"`):
|
| 116 |
+
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
| 117 |
+
"""
|
| 118 |
+
|
| 119 |
+
return_tensors: str = "pt"
|
| 120 |
+
|
| 121 |
+
def __call__(self, features: List[Dict[str, Any]], return_tensors=None) -> Dict[str, Any]:
|
| 122 |
+
if return_tensors is None:
|
| 123 |
+
return_tensors = self.return_tensors
|
| 124 |
+
return default_data_collator(features, return_tensors)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def torch_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
|
| 128 |
+
import torch
|
| 129 |
+
|
| 130 |
+
if not isinstance(features[0], Mapping):
|
| 131 |
+
features = [vars(f) for f in features]
|
| 132 |
+
first = features[0]
|
| 133 |
+
batch = {}
|
| 134 |
+
|
| 135 |
+
# Special handling for labels.
|
| 136 |
+
# Ensure that tensor is created with the correct type
|
| 137 |
+
# (it should be automatically the case, but let's make sure of it.)
|
| 138 |
+
if "label" in first and first["label"] is not None:
|
| 139 |
+
label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
|
| 140 |
+
dtype = torch.long if isinstance(label, int) else torch.float
|
| 141 |
+
batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
|
| 142 |
+
elif "label_ids" in first and first["label_ids"] is not None:
|
| 143 |
+
if isinstance(first["label_ids"], torch.Tensor):
|
| 144 |
+
batch["labels"] = torch.stack([f["label_ids"] for f in features])
|
| 145 |
+
else:
|
| 146 |
+
dtype = torch.long if isinstance(first["label_ids"][0], int) else torch.float
|
| 147 |
+
batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)
|
| 148 |
+
|
| 149 |
+
# Handling of all other possible keys.
|
| 150 |
+
# Again, we will use the first element to figure out which key/values are not None for this model.
|
| 151 |
+
for k, v in first.items():
|
| 152 |
+
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
|
| 153 |
+
if isinstance(v, torch.Tensor):
|
| 154 |
+
batch[k] = torch.stack([f[k] for f in features])
|
| 155 |
+
elif isinstance(v, np.ndarray):
|
| 156 |
+
batch[k] = torch.from_numpy(np.stack([f[k] for f in features]))
|
| 157 |
+
else:
|
| 158 |
+
batch[k] = torch.tensor([f[k] for f in features])
|
| 159 |
+
|
| 160 |
+
return batch
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
|
| 164 |
+
import tensorflow as tf
|
| 165 |
+
|
| 166 |
+
if not isinstance(features[0], Mapping):
|
| 167 |
+
features = [vars(f) for f in features]
|
| 168 |
+
first = features[0]
|
| 169 |
+
batch = {}
|
| 170 |
+
|
| 171 |
+
# Special handling for labels.
|
| 172 |
+
# Ensure that tensor is created with the correct type
|
| 173 |
+
# (it should be automatically the case, but let's make sure of it.)
|
| 174 |
+
if "label" in first and first["label"] is not None:
|
| 175 |
+
label_col_name = "label"
|
| 176 |
+
elif "label_ids" in first and first["label_ids"] is not None:
|
| 177 |
+
label_col_name = "label_ids"
|
| 178 |
+
elif "labels" in first and first["labels"] is not None:
|
| 179 |
+
label_col_name = "labels"
|
| 180 |
+
else:
|
| 181 |
+
label_col_name = None
|
| 182 |
+
if label_col_name is not None:
|
| 183 |
+
if isinstance(first[label_col_name], tf.Tensor):
|
| 184 |
+
dtype = tf.int64 if first[label_col_name].dtype.is_integer else tf.float32
|
| 185 |
+
elif isinstance(first[label_col_name], np.ndarray) or isinstance(first[label_col_name], np.generic):
|
| 186 |
+
dtype = tf.int64 if np.issubdtype(first[label_col_name].dtype, np.integer) else tf.float32
|
| 187 |
+
elif isinstance(first[label_col_name], (tuple, list)):
|
| 188 |
+
dtype = tf.int64 if isinstance(first[label_col_name][0], int) else tf.float32
|
| 189 |
+
else:
|
| 190 |
+
dtype = tf.int64 if isinstance(first[label_col_name], int) else tf.float32
|
| 191 |
+
batch["labels"] = tf.convert_to_tensor([f[label_col_name] for f in features], dtype=dtype)
|
| 192 |
+
# Handling of all other possible keys.
|
| 193 |
+
# Again, we will use the first element to figure out which key/values are not None for this model.
|
| 194 |
+
for k, v in first.items():
|
| 195 |
+
if k not in ("label", "label_ids", "labels") and v is not None and not isinstance(v, str):
|
| 196 |
+
if isinstance(v, (tf.Tensor, np.ndarray)):
|
| 197 |
+
batch[k] = tf.stack([f[k] for f in features])
|
| 198 |
+
else:
|
| 199 |
+
batch[k] = tf.convert_to_tensor([f[k] for f in features])
|
| 200 |
+
|
| 201 |
+
return batch
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def numpy_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
|
| 205 |
+
if not isinstance(features[0], Mapping):
|
| 206 |
+
features = [vars(f) for f in features]
|
| 207 |
+
first = features[0]
|
| 208 |
+
batch = {}
|
| 209 |
+
|
| 210 |
+
# Special handling for labels.
|
| 211 |
+
# Ensure that tensor is created with the correct type
|
| 212 |
+
# (it should be automatically the case, but let's make sure of it.)
|
| 213 |
+
if "label" in first and first["label"] is not None:
|
| 214 |
+
label = first["label"].item() if isinstance(first["label"], np.ndarray) else first["label"]
|
| 215 |
+
dtype = np.int64 if isinstance(label, int) else np.float32
|
| 216 |
+
batch["labels"] = np.array([f["label"] for f in features], dtype=dtype)
|
| 217 |
+
elif "label_ids" in first and first["label_ids"] is not None:
|
| 218 |
+
if isinstance(first["label_ids"], np.ndarray):
|
| 219 |
+
batch["labels"] = np.stack([f["label_ids"] for f in features])
|
| 220 |
+
else:
|
| 221 |
+
dtype = np.int64 if isinstance(first["label_ids"][0], int) else np.float32
|
| 222 |
+
batch["labels"] = np.array([f["label_ids"] for f in features], dtype=dtype)
|
| 223 |
+
|
| 224 |
+
# Handling of all other possible keys.
|
| 225 |
+
# Again, we will use the first element to figure out which key/values are not None for this model.
|
| 226 |
+
for k, v in first.items():
|
| 227 |
+
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
|
| 228 |
+
if isinstance(v, np.ndarray):
|
| 229 |
+
batch[k] = np.stack([f[k] for f in features])
|
| 230 |
+
else:
|
| 231 |
+
batch[k] = np.array([f[k] for f in features])
|
| 232 |
+
|
| 233 |
+
return batch
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
@dataclass
|
| 237 |
+
class DataCollatorWithPadding:
|
| 238 |
+
"""
|
| 239 |
+
Data collator that will dynamically pad the inputs received.
|
| 240 |
+
|
| 241 |
+
Args:
|
| 242 |
+
tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
|
| 243 |
+
The tokenizer used for encoding the data.
|
| 244 |
+
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
|
| 245 |
+
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
|
| 246 |
+
among:
|
| 247 |
+
|
| 248 |
+
- `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
|
| 249 |
+
sequence is provided).
|
| 250 |
+
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
| 251 |
+
acceptable input length for the model if that argument is not provided.
|
| 252 |
+
- `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
|
| 253 |
+
max_length (`int`, *optional*):
|
| 254 |
+
Maximum length of the returned list and optionally padding length (see above).
|
| 255 |
+
pad_to_multiple_of (`int`, *optional*):
|
| 256 |
+
If set will pad the sequence to a multiple of the provided value.
|
| 257 |
+
|
| 258 |
+
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
|
| 259 |
+
7.0 (Volta).
|
| 260 |
+
return_tensors (`str`, *optional*, defaults to `"pt"`):
|
| 261 |
+
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
| 262 |
+
"""
|
| 263 |
+
|
| 264 |
+
tokenizer: PreTrainedTokenizerBase
|
| 265 |
+
padding: Union[bool, str, PaddingStrategy] = True
|
| 266 |
+
max_length: Optional[int] = None
|
| 267 |
+
pad_to_multiple_of: Optional[int] = None
|
| 268 |
+
return_tensors: str = "pt"
|
| 269 |
+
|
| 270 |
+
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 271 |
+
batch = pad_without_fast_tokenizer_warning(
|
| 272 |
+
self.tokenizer,
|
| 273 |
+
features,
|
| 274 |
+
padding=self.padding,
|
| 275 |
+
max_length=self.max_length,
|
| 276 |
+
pad_to_multiple_of=self.pad_to_multiple_of,
|
| 277 |
+
return_tensors=self.return_tensors,
|
| 278 |
+
)
|
| 279 |
+
if "label" in batch:
|
| 280 |
+
batch["labels"] = batch["label"]
|
| 281 |
+
del batch["label"]
|
| 282 |
+
if "label_ids" in batch:
|
| 283 |
+
batch["labels"] = batch["label_ids"]
|
| 284 |
+
del batch["label_ids"]
|
| 285 |
+
return batch
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
@dataclass
|
| 289 |
+
class DataCollatorForTokenClassification(DataCollatorMixin):
|
| 290 |
+
"""
|
| 291 |
+
Data collator that will dynamically pad the inputs received, as well as the labels.
|
| 292 |
+
|
| 293 |
+
Args:
|
| 294 |
+
tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
|
| 295 |
+
The tokenizer used for encoding the data.
|
| 296 |
+
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
|
| 297 |
+
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
|
| 298 |
+
among:
|
| 299 |
+
|
| 300 |
+
- `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
|
| 301 |
+
sequence is provided).
|
| 302 |
+
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
| 303 |
+
acceptable input length for the model if that argument is not provided.
|
| 304 |
+
- `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
|
| 305 |
+
max_length (`int`, *optional*):
|
| 306 |
+
Maximum length of the returned list and optionally padding length (see above).
|
| 307 |
+
pad_to_multiple_of (`int`, *optional*):
|
| 308 |
+
If set will pad the sequence to a multiple of the provided value.
|
| 309 |
+
|
| 310 |
+
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
|
| 311 |
+
7.0 (Volta).
|
| 312 |
+
label_pad_token_id (`int`, *optional*, defaults to -100):
|
| 313 |
+
The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
|
| 314 |
+
return_tensors (`str`, *optional*, defaults to `"pt"`):
|
| 315 |
+
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
| 316 |
+
"""
|
| 317 |
+
|
| 318 |
+
tokenizer: PreTrainedTokenizerBase
|
| 319 |
+
padding: Union[bool, str, PaddingStrategy] = True
|
| 320 |
+
max_length: Optional[int] = None
|
| 321 |
+
pad_to_multiple_of: Optional[int] = None
|
| 322 |
+
label_pad_token_id: int = -100
|
| 323 |
+
return_tensors: str = "pt"
|
| 324 |
+
|
| 325 |
+
def torch_call(self, features):
|
| 326 |
+
import torch
|
| 327 |
+
|
| 328 |
+
label_name = "label" if "label" in features[0].keys() else "labels"
|
| 329 |
+
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
|
| 330 |
+
|
| 331 |
+
no_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]
|
| 332 |
+
|
| 333 |
+
batch = pad_without_fast_tokenizer_warning(
|
| 334 |
+
self.tokenizer,
|
| 335 |
+
no_labels_features,
|
| 336 |
+
padding=self.padding,
|
| 337 |
+
max_length=self.max_length,
|
| 338 |
+
pad_to_multiple_of=self.pad_to_multiple_of,
|
| 339 |
+
return_tensors="pt",
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
if labels is None:
|
| 343 |
+
return batch
|
| 344 |
+
|
| 345 |
+
sequence_length = batch["input_ids"].shape[1]
|
| 346 |
+
padding_side = self.tokenizer.padding_side
|
| 347 |
+
|
| 348 |
+
def to_list(tensor_or_iterable):
|
| 349 |
+
if isinstance(tensor_or_iterable, torch.Tensor):
|
| 350 |
+
return tensor_or_iterable.tolist()
|
| 351 |
+
return list(tensor_or_iterable)
|
| 352 |
+
|
| 353 |
+
if padding_side == "right":
|
| 354 |
+
batch[label_name] = [
|
| 355 |
+
to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
|
| 356 |
+
]
|
| 357 |
+
else:
|
| 358 |
+
batch[label_name] = [
|
| 359 |
+
[self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
|
| 360 |
+
]
|
| 361 |
+
|
| 362 |
+
batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
|
| 363 |
+
return batch
|
| 364 |
+
|
| 365 |
+
def tf_call(self, features):
|
| 366 |
+
import tensorflow as tf
|
| 367 |
+
|
| 368 |
+
label_name = "label" if "label" in features[0].keys() else "labels"
|
| 369 |
+
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
|
| 370 |
+
batch = pad_without_fast_tokenizer_warning(
|
| 371 |
+
self.tokenizer,
|
| 372 |
+
features,
|
| 373 |
+
padding=self.padding,
|
| 374 |
+
max_length=self.max_length,
|
| 375 |
+
pad_to_multiple_of=self.pad_to_multiple_of,
|
| 376 |
+
# Conversion to tensors will fail if we have labels as they are not of the same length yet.
|
| 377 |
+
return_tensors="tf" if labels is None else None,
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
if labels is None:
|
| 381 |
+
return batch
|
| 382 |
+
|
| 383 |
+
sequence_length = tf.convert_to_tensor(batch["input_ids"]).shape[1]
|
| 384 |
+
padding_side = self.tokenizer.padding_side
|
| 385 |
+
if padding_side == "right":
|
| 386 |
+
batch["labels"] = [
|
| 387 |
+
list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
|
| 388 |
+
]
|
| 389 |
+
else:
|
| 390 |
+
batch["labels"] = [
|
| 391 |
+
[self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels
|
| 392 |
+
]
|
| 393 |
+
|
| 394 |
+
batch = {k: tf.convert_to_tensor(v, dtype=tf.int64) for k, v in batch.items()}
|
| 395 |
+
return batch
|
| 396 |
+
|
| 397 |
+
def numpy_call(self, features):
|
| 398 |
+
label_name = "label" if "label" in features[0].keys() else "labels"
|
| 399 |
+
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
|
| 400 |
+
batch = pad_without_fast_tokenizer_warning(
|
| 401 |
+
self.tokenizer,
|
| 402 |
+
features,
|
| 403 |
+
padding=self.padding,
|
| 404 |
+
max_length=self.max_length,
|
| 405 |
+
pad_to_multiple_of=self.pad_to_multiple_of,
|
| 406 |
+
# Conversion to tensors will fail if we have labels as they are not of the same length yet.
|
| 407 |
+
return_tensors="np" if labels is None else None,
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
if labels is None:
|
| 411 |
+
return batch
|
| 412 |
+
|
| 413 |
+
sequence_length = np.array(batch["input_ids"]).shape[1]
|
| 414 |
+
padding_side = self.tokenizer.padding_side
|
| 415 |
+
if padding_side == "right":
|
| 416 |
+
batch["labels"] = [
|
| 417 |
+
list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
|
| 418 |
+
]
|
| 419 |
+
else:
|
| 420 |
+
batch["labels"] = [
|
| 421 |
+
[self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels
|
| 422 |
+
]
|
| 423 |
+
|
| 424 |
+
batch = {k: np.array(v, dtype=np.int64) for k, v in batch.items()}
|
| 425 |
+
return batch
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
|
| 429 |
+
"""Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
|
| 430 |
+
import torch
|
| 431 |
+
|
| 432 |
+
# Tensorize if necessary.
|
| 433 |
+
if isinstance(examples[0], (list, tuple, np.ndarray)):
|
| 434 |
+
examples = [torch.tensor(e, dtype=torch.long) for e in examples]
|
| 435 |
+
|
| 436 |
+
length_of_first = examples[0].size(0)
|
| 437 |
+
|
| 438 |
+
# Check if padding is necessary.
|
| 439 |
+
|
| 440 |
+
are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
|
| 441 |
+
if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
|
| 442 |
+
if not isinstance(examples, torch.Tensor):
|
| 443 |
+
return torch.stack(examples, dim=0)
|
| 444 |
+
|
| 445 |
+
# If yes, check if we have a `pad_token`.
|
| 446 |
+
if tokenizer.pad_token is None:
|
| 447 |
+
raise ValueError(
|
| 448 |
+
"You are attempting to pad samples but the tokenizer you are using"
|
| 449 |
+
f" ({tokenizer.__class__.__name__}) does not have a pad token."
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
# Creating the full tensor and filling it with our data.
|
| 453 |
+
max_length = max(x.size(0) for x in examples)
|
| 454 |
+
if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
| 455 |
+
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
| 456 |
+
result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
|
| 457 |
+
for i, example in enumerate(examples):
|
| 458 |
+
if tokenizer.padding_side == "right":
|
| 459 |
+
result[i, : example.shape[0]] = example
|
| 460 |
+
else:
|
| 461 |
+
result[i, -example.shape[0] :] = example
|
| 462 |
+
return result
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
def _tf_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
|
| 466 |
+
import tensorflow as tf
|
| 467 |
+
|
| 468 |
+
"""Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
|
| 469 |
+
# Tensorize if necessary.
|
| 470 |
+
if isinstance(examples[0], (list, tuple)):
|
| 471 |
+
examples = [tf.convert_to_tensor(e, dtype=tf.int64) for e in examples]
|
| 472 |
+
|
| 473 |
+
# Check if padding is necessary.
|
| 474 |
+
length_of_first = len(examples[0])
|
| 475 |
+
are_tensors_same_length = all(len(x) == length_of_first for x in examples)
|
| 476 |
+
if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
|
| 477 |
+
return tf.stack(examples, axis=0)
|
| 478 |
+
|
| 479 |
+
# If yes, check if we have a `pad_token`.
|
| 480 |
+
if tokenizer.pad_token is None:
|
| 481 |
+
raise ValueError(
|
| 482 |
+
"You are attempting to pad samples but the tokenizer you are using"
|
| 483 |
+
f" ({tokenizer.__class__.__name__}) does not have a pad token."
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
# Creating the full tensor and filling it with our data.
|
| 487 |
+
max_length = max(len(x) for x in examples)
|
| 488 |
+
if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
| 489 |
+
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
| 490 |
+
# result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
|
| 491 |
+
result = []
|
| 492 |
+
rank = tf.rank(examples[0])
|
| 493 |
+
paddings = np.zeros((rank, 2), dtype=np.int32)
|
| 494 |
+
for example in examples:
|
| 495 |
+
if tokenizer.padding_side == "right":
|
| 496 |
+
paddings[0, 1] = max_length - len(example)
|
| 497 |
+
else:
|
| 498 |
+
paddings[0, 0] = max_length - len(example)
|
| 499 |
+
result.append(tf.pad(example, paddings, constant_values=tokenizer.pad_token_id))
|
| 500 |
+
return tf.stack(result, axis=0)
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
def _numpy_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
|
| 504 |
+
"""Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
|
| 505 |
+
# Tensorize if necessary.
|
| 506 |
+
if isinstance(examples[0], (list, tuple)):
|
| 507 |
+
examples = [np.array(e, dtype=np.int64) for e in examples]
|
| 508 |
+
|
| 509 |
+
# Check if padding is necessary.
|
| 510 |
+
length_of_first = len(examples[0])
|
| 511 |
+
are_tensors_same_length = all(len(x) == length_of_first for x in examples)
|
| 512 |
+
if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
|
| 513 |
+
return np.stack(examples, axis=0)
|
| 514 |
+
|
| 515 |
+
# If yes, check if we have a `pad_token`.
|
| 516 |
+
if tokenizer.pad_token is None:
|
| 517 |
+
raise ValueError(
|
| 518 |
+
"You are attempting to pad samples but the tokenizer you are using"
|
| 519 |
+
f" ({tokenizer.__class__.__name__}) does not have a pad token."
|
| 520 |
+
)
|
| 521 |
+
|
| 522 |
+
# Creating the full tensor and filling it with our data.
|
| 523 |
+
max_length = max(len(x) for x in examples)
|
| 524 |
+
if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
| 525 |
+
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
| 526 |
+
result = np.full(shape=(len(examples), max_length), fill_value=tokenizer.pad_token_id, dtype=examples[0].dtype)
|
| 527 |
+
for i, example in enumerate(examples):
|
| 528 |
+
if tokenizer.padding_side == "right":
|
| 529 |
+
result[i, : example.shape[0]] = example
|
| 530 |
+
else:
|
| 531 |
+
result[i, -example.shape[0] :] = example
|
| 532 |
+
return result
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
def tolist(x):
|
| 536 |
+
if isinstance(x, list):
|
| 537 |
+
return x
|
| 538 |
+
elif hasattr(x, "numpy"): # Checks for TF tensors without needing the import
|
| 539 |
+
x = x.numpy()
|
| 540 |
+
return x.tolist()
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
@dataclass
|
| 544 |
+
class DataCollatorForSeq2Seq:
|
| 545 |
+
"""
|
| 546 |
+
Data collator that will dynamically pad the inputs received, as well as the labels.
|
| 547 |
+
|
| 548 |
+
Args:
|
| 549 |
+
tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
|
| 550 |
+
The tokenizer used for encoding the data.
|
| 551 |
+
model ([`PreTrainedModel`], *optional*):
|
| 552 |
+
The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to
|
| 553 |
+
prepare the *decoder_input_ids*
|
| 554 |
+
|
| 555 |
+
This is useful when using *label_smoothing* to avoid calculating loss twice.
|
| 556 |
+
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
|
| 557 |
+
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
|
| 558 |
+
among:
|
| 559 |
+
|
| 560 |
+
- `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
|
| 561 |
+
sequence is provided).
|
| 562 |
+
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
| 563 |
+
acceptable input length for the model if that argument is not provided.
|
| 564 |
+
- `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
|
| 565 |
+
max_length (`int`, *optional*):
|
| 566 |
+
Maximum length of the returned list and optionally padding length (see above).
|
| 567 |
+
pad_to_multiple_of (`int`, *optional*):
|
| 568 |
+
If set will pad the sequence to a multiple of the provided value.
|
| 569 |
+
|
| 570 |
+
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
|
| 571 |
+
7.0 (Volta).
|
| 572 |
+
label_pad_token_id (`int`, *optional*, defaults to -100):
|
| 573 |
+
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
|
| 574 |
+
return_tensors (`str`, *optional*, defaults to `"pt"`):
|
| 575 |
+
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
| 576 |
+
"""
|
| 577 |
+
|
| 578 |
+
tokenizer: PreTrainedTokenizerBase
|
| 579 |
+
model: Optional[Any] = None
|
| 580 |
+
padding: Union[bool, str, PaddingStrategy] = True
|
| 581 |
+
max_length: Optional[int] = None
|
| 582 |
+
pad_to_multiple_of: Optional[int] = None
|
| 583 |
+
label_pad_token_id: int = -100
|
| 584 |
+
return_tensors: str = "pt"
|
| 585 |
+
|
| 586 |
+
def __call__(self, features, return_tensors=None):
|
| 587 |
+
if return_tensors is None:
|
| 588 |
+
return_tensors = self.return_tensors
|
| 589 |
+
|
| 590 |
+
label_name = "label" if "label" in features[0].keys() else "labels"
|
| 591 |
+
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
|
| 592 |
+
# reconvert list[None] to None if necessary
|
| 593 |
+
# this might occur when we pass {..., "labels": None}
|
| 594 |
+
if labels is not None and all(label is None for label in labels):
|
| 595 |
+
labels = None
|
| 596 |
+
non_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]
|
| 597 |
+
|
| 598 |
+
# run through tokenizer without labels to ensure no side effects
|
| 599 |
+
batch = pad_without_fast_tokenizer_warning(
|
| 600 |
+
self.tokenizer,
|
| 601 |
+
non_labels_features,
|
| 602 |
+
padding=self.padding,
|
| 603 |
+
max_length=self.max_length,
|
| 604 |
+
pad_to_multiple_of=self.pad_to_multiple_of,
|
| 605 |
+
return_tensors=return_tensors,
|
| 606 |
+
)
|
| 607 |
+
|
| 608 |
+
# we have to pad the labels manually as we cannot rely on `tokenizer.pad` and we need them to be of the same length to return tensors
|
| 609 |
+
no_padding = self.padding is False or self.padding == PaddingStrategy.DO_NOT_PAD
|
| 610 |
+
if labels is not None:
|
| 611 |
+
if no_padding:
|
| 612 |
+
if isinstance(features[0][label_name], list):
|
| 613 |
+
batch["labels"] = list(labels)
|
| 614 |
+
else:
|
| 615 |
+
batch["labels"] = [np.concatenate([label, []]) for label in labels]
|
| 616 |
+
else:
|
| 617 |
+
max_padding = self.padding == PaddingStrategy.MAX_LENGTH and self.max_length is not None
|
| 618 |
+
max_label_length = max(len(l) for l in labels) if not max_padding else self.max_length
|
| 619 |
+
if self.pad_to_multiple_of is not None:
|
| 620 |
+
max_label_length = (
|
| 621 |
+
(max_label_length + self.pad_to_multiple_of - 1)
|
| 622 |
+
// self.pad_to_multiple_of
|
| 623 |
+
* self.pad_to_multiple_of
|
| 624 |
+
)
|
| 625 |
+
|
| 626 |
+
padding_side = self.tokenizer.padding_side
|
| 627 |
+
if isinstance(features[0][label_name], list):
|
| 628 |
+
batch["labels"] = [
|
| 629 |
+
label + [self.label_pad_token_id] * (max_label_length - len(label))
|
| 630 |
+
if padding_side == "right"
|
| 631 |
+
else [self.label_pad_token_id] * (max_label_length - len(label)) + label
|
| 632 |
+
for label in labels
|
| 633 |
+
]
|
| 634 |
+
else:
|
| 635 |
+
batch["labels"] = [
|
| 636 |
+
np.concatenate(
|
| 637 |
+
[
|
| 638 |
+
label,
|
| 639 |
+
np.array([self.label_pad_token_id] * (max_label_length - len(label)), dtype=np.int64),
|
| 640 |
+
]
|
| 641 |
+
)
|
| 642 |
+
if padding_side == "right"
|
| 643 |
+
else np.concatenate(
|
| 644 |
+
[
|
| 645 |
+
np.array([self.label_pad_token_id] * (max_label_length - len(label)), dtype=np.int64),
|
| 646 |
+
label,
|
| 647 |
+
]
|
| 648 |
+
)
|
| 649 |
+
for label in labels
|
| 650 |
+
]
|
| 651 |
+
|
| 652 |
+
# reintroduce side effects via tokenizer that return respective datatypes for the `return_tensors` argument
|
| 653 |
+
if batch.get("labels", None) is not None:
|
| 654 |
+
if return_tensors == "pt":
|
| 655 |
+
import torch
|
| 656 |
+
|
| 657 |
+
batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)
|
| 658 |
+
elif return_tensors == "tf":
|
| 659 |
+
import tensorflow as tf
|
| 660 |
+
|
| 661 |
+
batch["labels"] = tf.constant(batch["labels"], dtype=tf.int64)
|
| 662 |
+
else:
|
| 663 |
+
batch["labels"] = np.array(batch["labels"], dtype=np.int64)
|
| 664 |
+
else:
|
| 665 |
+
batch["labels"] = None
|
| 666 |
+
|
| 667 |
+
# prepare decoder_input_ids
|
| 668 |
+
if (
|
| 669 |
+
labels is not None
|
| 670 |
+
and self.model is not None
|
| 671 |
+
and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
|
| 672 |
+
):
|
| 673 |
+
decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=batch["labels"])
|
| 674 |
+
batch["decoder_input_ids"] = decoder_input_ids
|
| 675 |
+
|
| 676 |
+
return batch
|
| 677 |
+
|
| 678 |
+
|
| 679 |
+
@dataclass
|
| 680 |
+
class DataCollatorForLanguageModeling(DataCollatorMixin):
|
| 681 |
+
"""
|
| 682 |
+
Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
|
| 683 |
+
are not all of the same length.
|
| 684 |
+
|
| 685 |
+
Args:
|
| 686 |
+
tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
|
| 687 |
+
The tokenizer used for encoding the data.
|
| 688 |
+
mlm (`bool`, *optional*, defaults to `True`):
|
| 689 |
+
Whether or not to use masked language modeling. If set to `False`, the labels are the same as the inputs
|
| 690 |
+
with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for non-masked
|
| 691 |
+
tokens and the value to predict for the masked token.
|
| 692 |
+
mlm_probability (`float`, *optional*, defaults to 0.15):
|
| 693 |
+
The probability with which to (randomly) mask tokens in the input, when `mlm` is set to `True`.
|
| 694 |
+
pad_to_multiple_of (`int`, *optional*):
|
| 695 |
+
If set will pad the sequence to a multiple of the provided value.
|
| 696 |
+
|
| 697 |
+
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
|
| 698 |
+
7.0 (Volta).
|
| 699 |
+
return_tensors (`str`):
|
| 700 |
+
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
| 701 |
+
|
| 702 |
+
<Tip>
|
| 703 |
+
|
| 704 |
+
For best performance, this data collator should be used with a dataset having items that are dictionaries or
|
| 705 |
+
BatchEncoding, with the `"special_tokens_mask"` key, as returned by a [`PreTrainedTokenizer`] or a
|
| 706 |
+
[`PreTrainedTokenizerFast`] with the argument `return_special_tokens_mask=True`.
|
| 707 |
+
|
| 708 |
+
</Tip>"""
|
| 709 |
+
|
| 710 |
+
tokenizer: PreTrainedTokenizerBase
|
| 711 |
+
mlm: bool = True
|
| 712 |
+
mlm_probability: float = 0.15
|
| 713 |
+
pad_to_multiple_of: Optional[int] = None
|
| 714 |
+
tf_experimental_compile: bool = False
|
| 715 |
+
return_tensors: str = "pt"
|
| 716 |
+
|
| 717 |
+
def __post_init__(self):
|
| 718 |
+
if self.mlm and self.tokenizer.mask_token is None:
|
| 719 |
+
raise ValueError(
|
| 720 |
+
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
|
| 721 |
+
"You should pass `mlm=False` to train on causal language modeling instead."
|
| 722 |
+
)
|
| 723 |
+
if self.tf_experimental_compile:
|
| 724 |
+
import tensorflow as tf
|
| 725 |
+
|
| 726 |
+
self.tf_mask_tokens = tf.function(self.tf_mask_tokens, jit_compile=True)
|
| 727 |
+
|
| 728 |
+
@staticmethod
|
| 729 |
+
def tf_bernoulli(shape, probability):
|
| 730 |
+
import tensorflow as tf
|
| 731 |
+
|
| 732 |
+
prob_matrix = tf.fill(shape, probability)
|
| 733 |
+
return tf.cast(prob_matrix - tf.random.uniform(shape, 0, 1) >= 0, tf.bool)
|
| 734 |
+
|
| 735 |
+
def tf_mask_tokens(
|
| 736 |
+
self, inputs: Any, vocab_size, mask_token_id, special_tokens_mask: Optional[Any] = None
|
| 737 |
+
) -> Tuple[Any, Any]:
|
| 738 |
+
"""
|
| 739 |
+
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
|
| 740 |
+
"""
|
| 741 |
+
import tensorflow as tf
|
| 742 |
+
|
| 743 |
+
mask_token_id = tf.cast(mask_token_id, inputs.dtype)
|
| 744 |
+
|
| 745 |
+
input_shape = tf.shape(inputs)
|
| 746 |
+
# 1 for a special token, 0 for a normal token in the special tokens mask
|
| 747 |
+
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
|
| 748 |
+
masked_indices = self.tf_bernoulli(input_shape, self.mlm_probability) & ~special_tokens_mask
|
| 749 |
+
# Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
|
| 750 |
+
labels = tf.where(masked_indices, inputs, -100)
|
| 751 |
+
|
| 752 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
| 753 |
+
indices_replaced = self.tf_bernoulli(input_shape, 0.8) & masked_indices
|
| 754 |
+
|
| 755 |
+
inputs = tf.where(indices_replaced, mask_token_id, inputs)
|
| 756 |
+
|
| 757 |
+
# 10% of the time, we replace masked input tokens with random word
|
| 758 |
+
indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced
|
| 759 |
+
random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype)
|
| 760 |
+
|
| 761 |
+
inputs = tf.where(indices_random, random_words, inputs)
|
| 762 |
+
|
| 763 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
| 764 |
+
return inputs, labels
|
| 765 |
+
|
| 766 |
+
def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
| 767 |
+
import tensorflow as tf
|
| 768 |
+
|
| 769 |
+
# Handle dict or lists with proper padding and conversion to tensor.
|
| 770 |
+
if isinstance(examples[0], Mapping):
|
| 771 |
+
batch = pad_without_fast_tokenizer_warning(
|
| 772 |
+
self.tokenizer, examples, return_tensors="tf", pad_to_multiple_of=self.pad_to_multiple_of
|
| 773 |
+
)
|
| 774 |
+
else:
|
| 775 |
+
batch = {
|
| 776 |
+
"input_ids": _tf_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
|
| 777 |
+
}
|
| 778 |
+
|
| 779 |
+
# If special token mask has been preprocessed, pop it from the dict.
|
| 780 |
+
special_tokens_mask = batch.pop("special_tokens_mask", None)
|
| 781 |
+
if self.mlm:
|
| 782 |
+
if special_tokens_mask is None:
|
| 783 |
+
special_tokens_mask = [
|
| 784 |
+
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
|
| 785 |
+
for val in batch["input_ids"].numpy().tolist()
|
| 786 |
+
]
|
| 787 |
+
# Cannot directly create as bool
|
| 788 |
+
special_tokens_mask = tf.cast(tf.convert_to_tensor(special_tokens_mask, dtype=tf.int64), tf.bool)
|
| 789 |
+
else:
|
| 790 |
+
special_tokens_mask = tf.cast(special_tokens_mask, tf.bool)
|
| 791 |
+
batch["input_ids"], batch["labels"] = self.tf_mask_tokens(
|
| 792 |
+
tf.cast(batch["input_ids"], tf.int64),
|
| 793 |
+
special_tokens_mask=special_tokens_mask,
|
| 794 |
+
mask_token_id=self.tokenizer.mask_token_id,
|
| 795 |
+
vocab_size=len(self.tokenizer),
|
| 796 |
+
)
|
| 797 |
+
else:
|
| 798 |
+
labels = batch["input_ids"]
|
| 799 |
+
if self.tokenizer.pad_token_id is not None:
|
| 800 |
+
# Replace self.tokenizer.pad_token_id with -100
|
| 801 |
+
labels = tf.where(labels == self.tokenizer.pad_token_id, -100, labels)
|
| 802 |
+
else:
|
| 803 |
+
labels = tf.identity(labels) # Makes a copy, just in case
|
| 804 |
+
batch["labels"] = labels
|
| 805 |
+
return batch
|
| 806 |
+
|
| 807 |
+
def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
| 808 |
+
# Handle dict or lists with proper padding and conversion to tensor.
|
| 809 |
+
if isinstance(examples[0], Mapping):
|
| 810 |
+
batch = pad_without_fast_tokenizer_warning(
|
| 811 |
+
self.tokenizer, examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of
|
| 812 |
+
)
|
| 813 |
+
else:
|
| 814 |
+
batch = {
|
| 815 |
+
"input_ids": _torch_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
|
| 816 |
+
}
|
| 817 |
+
|
| 818 |
+
# If special token mask has been preprocessed, pop it from the dict.
|
| 819 |
+
special_tokens_mask = batch.pop("special_tokens_mask", None)
|
| 820 |
+
if self.mlm:
|
| 821 |
+
batch["input_ids"], batch["labels"] = self.torch_mask_tokens(
|
| 822 |
+
batch["input_ids"], special_tokens_mask=special_tokens_mask
|
| 823 |
+
)
|
| 824 |
+
else:
|
| 825 |
+
labels = batch["input_ids"].clone()
|
| 826 |
+
if self.tokenizer.pad_token_id is not None:
|
| 827 |
+
labels[labels == self.tokenizer.pad_token_id] = -100
|
| 828 |
+
batch["labels"] = labels
|
| 829 |
+
return batch
|
| 830 |
+
|
| 831 |
+
def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
|
| 832 |
+
"""
|
| 833 |
+
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
|
| 834 |
+
"""
|
| 835 |
+
import torch
|
| 836 |
+
|
| 837 |
+
labels = inputs.clone()
|
| 838 |
+
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
|
| 839 |
+
probability_matrix = torch.full(labels.shape, self.mlm_probability)
|
| 840 |
+
if special_tokens_mask is None:
|
| 841 |
+
special_tokens_mask = [
|
| 842 |
+
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
|
| 843 |
+
]
|
| 844 |
+
special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
|
| 845 |
+
else:
|
| 846 |
+
special_tokens_mask = special_tokens_mask.bool()
|
| 847 |
+
|
| 848 |
+
probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
|
| 849 |
+
masked_indices = torch.bernoulli(probability_matrix).bool()
|
| 850 |
+
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
| 851 |
+
|
| 852 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
| 853 |
+
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
|
| 854 |
+
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
|
| 855 |
+
|
| 856 |
+
# 10% of the time, we replace masked input tokens with random word
|
| 857 |
+
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
|
| 858 |
+
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
|
| 859 |
+
inputs[indices_random] = random_words[indices_random]
|
| 860 |
+
|
| 861 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
| 862 |
+
return inputs, labels
|
| 863 |
+
|
| 864 |
+
def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
| 865 |
+
# Handle dict or lists with proper padding and conversion to tensor.
|
| 866 |
+
if isinstance(examples[0], Mapping):
|
| 867 |
+
batch = pad_without_fast_tokenizer_warning(
|
| 868 |
+
self.tokenizer, examples, return_tensors="np", pad_to_multiple_of=self.pad_to_multiple_of
|
| 869 |
+
)
|
| 870 |
+
else:
|
| 871 |
+
batch = {
|
| 872 |
+
"input_ids": _numpy_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
|
| 873 |
+
}
|
| 874 |
+
|
| 875 |
+
# If special token mask has been preprocessed, pop it from the dict.
|
| 876 |
+
special_tokens_mask = batch.pop("special_tokens_mask", None)
|
| 877 |
+
if self.mlm:
|
| 878 |
+
batch["input_ids"], batch["labels"] = self.numpy_mask_tokens(
|
| 879 |
+
batch["input_ids"], special_tokens_mask=special_tokens_mask
|
| 880 |
+
)
|
| 881 |
+
else:
|
| 882 |
+
labels = np.copy(batch["input_ids"])
|
| 883 |
+
if self.tokenizer.pad_token_id is not None:
|
| 884 |
+
labels[labels == self.tokenizer.pad_token_id] = -100
|
| 885 |
+
batch["labels"] = labels
|
| 886 |
+
return batch
|
| 887 |
+
|
| 888 |
+
def numpy_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
|
| 889 |
+
"""
|
| 890 |
+
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
|
| 891 |
+
"""
|
| 892 |
+
labels = np.copy(inputs)
|
| 893 |
+
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
|
| 894 |
+
probability_matrix = np.full(labels.shape, self.mlm_probability)
|
| 895 |
+
if special_tokens_mask is None:
|
| 896 |
+
special_tokens_mask = [
|
| 897 |
+
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
|
| 898 |
+
]
|
| 899 |
+
special_tokens_mask = np.array(special_tokens_mask, dtype=bool)
|
| 900 |
+
else:
|
| 901 |
+
special_tokens_mask = special_tokens_mask.astype(bool)
|
| 902 |
+
|
| 903 |
+
probability_matrix[special_tokens_mask] = 0
|
| 904 |
+
# Numpy doesn't have bernoulli, so we use a binomial with 1 trial
|
| 905 |
+
masked_indices = np.random.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
|
| 906 |
+
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
| 907 |
+
|
| 908 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
| 909 |
+
indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices
|
| 910 |
+
inputs[indices_replaced] = self.tokenizer.mask_token_id
|
| 911 |
+
|
| 912 |
+
# 10% of the time, we replace masked input tokens with random word
|
| 913 |
+
# indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
|
| 914 |
+
indices_random = (
|
| 915 |
+
np.random.binomial(1, 0.5, size=labels.shape).astype(bool) & masked_indices & ~indices_replaced
|
| 916 |
+
)
|
| 917 |
+
random_words = np.random.randint(
|
| 918 |
+
low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
|
| 919 |
+
)
|
| 920 |
+
inputs[indices_random] = random_words
|
| 921 |
+
|
| 922 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
| 923 |
+
return inputs, labels
|
| 924 |
+
|
| 925 |
+
|
| 926 |
+
@dataclass
|
| 927 |
+
class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
|
| 928 |
+
"""
|
| 929 |
+
Data collator used for language modeling that masks entire words.
|
| 930 |
+
|
| 931 |
+
- collates batches of tensors, honoring their tokenizer's pad_token
|
| 932 |
+
- preprocesses batches for masked language modeling
|
| 933 |
+
|
| 934 |
+
<Tip>
|
| 935 |
+
|
| 936 |
+
This collator relies on details of the implementation of subword tokenization by [`BertTokenizer`], specifically
|
| 937 |
+
that subword tokens are prefixed with *##*. For tokenizers that do not adhere to this scheme, this collator will
|
| 938 |
+
produce an output that is roughly equivalent to [`.DataCollatorForLanguageModeling`].
|
| 939 |
+
|
| 940 |
+
</Tip>"""
|
| 941 |
+
|
| 942 |
+
def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
| 943 |
+
if isinstance(examples[0], Mapping):
|
| 944 |
+
input_ids = [e["input_ids"] for e in examples]
|
| 945 |
+
else:
|
| 946 |
+
input_ids = examples
|
| 947 |
+
examples = [{"input_ids": e} for e in examples]
|
| 948 |
+
|
| 949 |
+
batch_input = _torch_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
|
| 950 |
+
|
| 951 |
+
mask_labels = []
|
| 952 |
+
for e in examples:
|
| 953 |
+
ref_tokens = []
|
| 954 |
+
for id in tolist(e["input_ids"]):
|
| 955 |
+
token = self.tokenizer._convert_id_to_token(id)
|
| 956 |
+
ref_tokens.append(token)
|
| 957 |
+
|
| 958 |
+
# For Chinese tokens, we need extra inf to mark sub-word, e.g [喜,欢]-> [喜,##欢]
|
| 959 |
+
if "chinese_ref" in e:
|
| 960 |
+
ref_pos = tolist(e["chinese_ref"])
|
| 961 |
+
len_seq = len(e["input_ids"])
|
| 962 |
+
for i in range(len_seq):
|
| 963 |
+
if i in ref_pos:
|
| 964 |
+
ref_tokens[i] = "##" + ref_tokens[i]
|
| 965 |
+
mask_labels.append(self._whole_word_mask(ref_tokens))
|
| 966 |
+
batch_mask = _torch_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
|
| 967 |
+
inputs, labels = self.torch_mask_tokens(batch_input, batch_mask)
|
| 968 |
+
return {"input_ids": inputs, "labels": labels}
|
| 969 |
+
|
| 970 |
+
def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
| 971 |
+
import tensorflow as tf
|
| 972 |
+
|
| 973 |
+
if isinstance(examples[0], Mapping):
|
| 974 |
+
input_ids = [e["input_ids"] for e in examples]
|
| 975 |
+
else:
|
| 976 |
+
input_ids = examples
|
| 977 |
+
examples = [{"input_ids": e} for e in examples]
|
| 978 |
+
|
| 979 |
+
batch_input = _tf_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
|
| 980 |
+
|
| 981 |
+
mask_labels = []
|
| 982 |
+
for e in examples:
|
| 983 |
+
ref_tokens = []
|
| 984 |
+
for id in tolist(e["input_ids"]):
|
| 985 |
+
token = self.tokenizer._convert_id_to_token(id)
|
| 986 |
+
ref_tokens.append(token)
|
| 987 |
+
|
| 988 |
+
# For Chinese tokens, we need extra inf to mark sub-word, e.g [喜,欢]-> [喜,##欢]
|
| 989 |
+
if "chinese_ref" in e:
|
| 990 |
+
ref_pos = tolist(e["chinese_ref"])
|
| 991 |
+
len_seq = len(e["input_ids"])
|
| 992 |
+
for i in range(len_seq):
|
| 993 |
+
if i in ref_pos:
|
| 994 |
+
ref_tokens[i] = "##" + ref_tokens[i]
|
| 995 |
+
mask_labels.append(self._whole_word_mask(ref_tokens))
|
| 996 |
+
batch_mask = _tf_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
|
| 997 |
+
inputs, labels = self.tf_mask_tokens(tf.cast(batch_input, tf.int64), batch_mask)
|
| 998 |
+
return {"input_ids": inputs, "labels": labels}
|
| 999 |
+
|
| 1000 |
+
def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
| 1001 |
+
if isinstance(examples[0], Mapping):
|
| 1002 |
+
input_ids = [e["input_ids"] for e in examples]
|
| 1003 |
+
else:
|
| 1004 |
+
input_ids = examples
|
| 1005 |
+
examples = [{"input_ids": e} for e in examples]
|
| 1006 |
+
|
| 1007 |
+
batch_input = _numpy_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
|
| 1008 |
+
|
| 1009 |
+
mask_labels = []
|
| 1010 |
+
for e in examples:
|
| 1011 |
+
ref_tokens = []
|
| 1012 |
+
for id in tolist(e["input_ids"]):
|
| 1013 |
+
token = self.tokenizer._convert_id_to_token(id)
|
| 1014 |
+
ref_tokens.append(token)
|
| 1015 |
+
|
| 1016 |
+
# For Chinese tokens, we need extra inf to mark sub-word, e.g [喜,欢]-> [喜,##欢]
|
| 1017 |
+
if "chinese_ref" in e:
|
| 1018 |
+
ref_pos = tolist(e["chinese_ref"])
|
| 1019 |
+
len_seq = len(e["input_ids"])
|
| 1020 |
+
for i in range(len_seq):
|
| 1021 |
+
if i in ref_pos:
|
| 1022 |
+
ref_tokens[i] = "##" + ref_tokens[i]
|
| 1023 |
+
mask_labels.append(self._whole_word_mask(ref_tokens))
|
| 1024 |
+
batch_mask = _numpy_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
|
| 1025 |
+
inputs, labels = self.numpy_mask_tokens(batch_input, batch_mask)
|
| 1026 |
+
return {"input_ids": inputs, "labels": labels}
|
| 1027 |
+
|
| 1028 |
+
def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
|
| 1029 |
+
"""
|
| 1030 |
+
Get 0/1 labels for masked tokens with whole word mask proxy
|
| 1031 |
+
"""
|
| 1032 |
+
if not isinstance(self.tokenizer, (BertTokenizer, BertTokenizerFast)):
|
| 1033 |
+
warnings.warn(
|
| 1034 |
+
"DataCollatorForWholeWordMask is only suitable for BertTokenizer-like tokenizers. "
|
| 1035 |
+
"Please refer to the documentation for more information."
|
| 1036 |
+
)
|
| 1037 |
+
|
| 1038 |
+
cand_indexes = []
|
| 1039 |
+
for i, token in enumerate(input_tokens):
|
| 1040 |
+
if token == "[CLS]" or token == "[SEP]":
|
| 1041 |
+
continue
|
| 1042 |
+
|
| 1043 |
+
if len(cand_indexes) >= 1 and token.startswith("##"):
|
| 1044 |
+
cand_indexes[-1].append(i)
|
| 1045 |
+
else:
|
| 1046 |
+
cand_indexes.append([i])
|
| 1047 |
+
|
| 1048 |
+
random.shuffle(cand_indexes)
|
| 1049 |
+
num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * self.mlm_probability))))
|
| 1050 |
+
masked_lms = []
|
| 1051 |
+
covered_indexes = set()
|
| 1052 |
+
for index_set in cand_indexes:
|
| 1053 |
+
if len(masked_lms) >= num_to_predict:
|
| 1054 |
+
break
|
| 1055 |
+
# If adding a whole-word mask would exceed the maximum number of
|
| 1056 |
+
# predictions, then just skip this candidate.
|
| 1057 |
+
if len(masked_lms) + len(index_set) > num_to_predict:
|
| 1058 |
+
continue
|
| 1059 |
+
is_any_index_covered = False
|
| 1060 |
+
for index in index_set:
|
| 1061 |
+
if index in covered_indexes:
|
| 1062 |
+
is_any_index_covered = True
|
| 1063 |
+
break
|
| 1064 |
+
if is_any_index_covered:
|
| 1065 |
+
continue
|
| 1066 |
+
for index in index_set:
|
| 1067 |
+
covered_indexes.add(index)
|
| 1068 |
+
masked_lms.append(index)
|
| 1069 |
+
|
| 1070 |
+
if len(covered_indexes) != len(masked_lms):
|
| 1071 |
+
raise ValueError("Length of covered_indexes is not equal to length of masked_lms.")
|
| 1072 |
+
mask_labels = [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]
|
| 1073 |
+
return mask_labels
|
| 1074 |
+
|
| 1075 |
+
def torch_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
|
| 1076 |
+
"""
|
| 1077 |
+
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Set
|
| 1078 |
+
'mask_labels' means we use whole word mask (wwm), we directly mask idxs according to it's ref.
|
| 1079 |
+
"""
|
| 1080 |
+
import torch
|
| 1081 |
+
|
| 1082 |
+
if self.tokenizer.mask_token is None:
|
| 1083 |
+
raise ValueError(
|
| 1084 |
+
"This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
|
| 1085 |
+
" --mlm flag if you want to use this tokenizer."
|
| 1086 |
+
)
|
| 1087 |
+
labels = inputs.clone()
|
| 1088 |
+
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
|
| 1089 |
+
|
| 1090 |
+
probability_matrix = mask_labels
|
| 1091 |
+
|
| 1092 |
+
special_tokens_mask = [
|
| 1093 |
+
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
|
| 1094 |
+
]
|
| 1095 |
+
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
|
| 1096 |
+
if self.tokenizer.pad_token is not None:
|
| 1097 |
+
padding_mask = labels.eq(self.tokenizer.pad_token_id)
|
| 1098 |
+
probability_matrix.masked_fill_(padding_mask, value=0.0)
|
| 1099 |
+
|
| 1100 |
+
masked_indices = probability_matrix.bool()
|
| 1101 |
+
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
| 1102 |
+
|
| 1103 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
| 1104 |
+
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
|
| 1105 |
+
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
|
| 1106 |
+
|
| 1107 |
+
# 10% of the time, we replace masked input tokens with random word
|
| 1108 |
+
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
|
| 1109 |
+
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
|
| 1110 |
+
inputs[indices_random] = random_words[indices_random]
|
| 1111 |
+
|
| 1112 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
| 1113 |
+
return inputs, labels
|
| 1114 |
+
|
| 1115 |
+
def tf_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
|
| 1116 |
+
"""
|
| 1117 |
+
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Set
|
| 1118 |
+
'mask_labels' means we use whole word mask (wwm), we directly mask idxs according to it's ref.
|
| 1119 |
+
"""
|
| 1120 |
+
import tensorflow as tf
|
| 1121 |
+
|
| 1122 |
+
input_shape = tf.shape(inputs)
|
| 1123 |
+
if self.tokenizer.mask_token is None:
|
| 1124 |
+
raise ValueError(
|
| 1125 |
+
"This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
|
| 1126 |
+
" --mlm flag if you want to use this tokenizer."
|
| 1127 |
+
)
|
| 1128 |
+
labels = tf.identity(inputs)
|
| 1129 |
+
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
|
| 1130 |
+
|
| 1131 |
+
masked_indices = tf.cast(mask_labels, tf.bool)
|
| 1132 |
+
|
| 1133 |
+
special_tokens_mask = [
|
| 1134 |
+
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels
|
| 1135 |
+
]
|
| 1136 |
+
masked_indices = masked_indices & ~tf.cast(special_tokens_mask, dtype=tf.bool)
|
| 1137 |
+
if self.tokenizer.pad_token is not None:
|
| 1138 |
+
padding_mask = inputs == self.tokenizer.pad_token_id
|
| 1139 |
+
masked_indices = masked_indices & ~padding_mask
|
| 1140 |
+
|
| 1141 |
+
# Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
|
| 1142 |
+
labels = tf.where(masked_indices, inputs, -100)
|
| 1143 |
+
|
| 1144 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
| 1145 |
+
indices_replaced = self.tf_bernoulli(input_shape, 0.8) & masked_indices
|
| 1146 |
+
|
| 1147 |
+
inputs = tf.where(indices_replaced, self.tokenizer.mask_token_id, inputs)
|
| 1148 |
+
|
| 1149 |
+
# 10% of the time, we replace masked input tokens with random word
|
| 1150 |
+
indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced
|
| 1151 |
+
random_words = tf.random.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64)
|
| 1152 |
+
inputs = tf.where(indices_random, random_words, inputs)
|
| 1153 |
+
|
| 1154 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
| 1155 |
+
return inputs, labels
|
| 1156 |
+
|
| 1157 |
+
def numpy_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
|
| 1158 |
+
"""
|
| 1159 |
+
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Set
|
| 1160 |
+
'mask_labels' means we use whole word mask (wwm), we directly mask idxs according to it's ref.
|
| 1161 |
+
"""
|
| 1162 |
+
if self.tokenizer.mask_token is None:
|
| 1163 |
+
raise ValueError(
|
| 1164 |
+
"This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
|
| 1165 |
+
" --mlm flag if you want to use this tokenizer."
|
| 1166 |
+
)
|
| 1167 |
+
labels = np.copy(inputs)
|
| 1168 |
+
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
|
| 1169 |
+
|
| 1170 |
+
masked_indices = mask_labels.astype(bool)
|
| 1171 |
+
|
| 1172 |
+
special_tokens_mask = [
|
| 1173 |
+
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
|
| 1174 |
+
]
|
| 1175 |
+
masked_indices[np.array(special_tokens_mask, dtype=bool)] = 0
|
| 1176 |
+
if self.tokenizer.pad_token is not None:
|
| 1177 |
+
padding_mask = labels == self.tokenizer.pad_token_id
|
| 1178 |
+
masked_indices[padding_mask] = 0
|
| 1179 |
+
|
| 1180 |
+
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
| 1181 |
+
|
| 1182 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
| 1183 |
+
indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices
|
| 1184 |
+
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
|
| 1185 |
+
|
| 1186 |
+
# 10% of the time, we replace masked input tokens with random word
|
| 1187 |
+
# indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
|
| 1188 |
+
indices_random = (
|
| 1189 |
+
np.random.binomial(1, 0.5, size=labels.shape).astype(bool) & masked_indices & ~indices_replaced
|
| 1190 |
+
)
|
| 1191 |
+
random_words = np.random.randint(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)
|
| 1192 |
+
inputs[indices_random] = random_words[indices_random]
|
| 1193 |
+
|
| 1194 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
| 1195 |
+
return inputs, labels
|
| 1196 |
+
|
| 1197 |
+
|
| 1198 |
+
@dataclass
|
| 1199 |
+
class DataCollatorForSOP(DataCollatorForLanguageModeling):
|
| 1200 |
+
"""
|
| 1201 |
+
Data collator used for sentence order prediction task.
|
| 1202 |
+
|
| 1203 |
+
- collates batches of tensors, honoring their tokenizer's pad_token
|
| 1204 |
+
- preprocesses batches for both masked language modeling and sentence order prediction
|
| 1205 |
+
"""
|
| 1206 |
+
|
| 1207 |
+
def __init__(self, *args, **kwargs):
|
| 1208 |
+
warnings.warn(
|
| 1209 |
+
"DataCollatorForSOP is deprecated and will be removed in a future version, you can now use "
|
| 1210 |
+
"DataCollatorForLanguageModeling instead.",
|
| 1211 |
+
FutureWarning,
|
| 1212 |
+
)
|
| 1213 |
+
|
| 1214 |
+
def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 1215 |
+
import torch
|
| 1216 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 1217 |
+
|
| 1218 |
+
input_ids = [example["input_ids"] for example in examples]
|
| 1219 |
+
input_ids = _torch_collate_batch(input_ids, self.tokenizer)
|
| 1220 |
+
input_ids, labels, attention_mask = self.mask_tokens(input_ids)
|
| 1221 |
+
|
| 1222 |
+
token_type_ids = [example["token_type_ids"] for example in examples]
|
| 1223 |
+
# size of segment_ids varied because randomness, padding zero to the end as the original implementation
|
| 1224 |
+
token_type_ids = pad_sequence(token_type_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
|
| 1225 |
+
|
| 1226 |
+
sop_label_list = [example["sentence_order_label"] for example in examples]
|
| 1227 |
+
sentence_order_label = torch.stack(sop_label_list)
|
| 1228 |
+
|
| 1229 |
+
return {
|
| 1230 |
+
"input_ids": input_ids,
|
| 1231 |
+
"labels": labels,
|
| 1232 |
+
"attention_mask": attention_mask,
|
| 1233 |
+
"token_type_ids": token_type_ids,
|
| 1234 |
+
"sentence_order_label": sentence_order_label,
|
| 1235 |
+
}
|
| 1236 |
+
|
| 1237 |
+
def mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any]:
|
| 1238 |
+
"""
|
| 1239 |
+
Prepare masked tokens inputs/labels/attention_mask for masked language modeling: 80% MASK, 10% random, 10%
|
| 1240 |
+
original. N-gram not applied yet.
|
| 1241 |
+
"""
|
| 1242 |
+
import torch
|
| 1243 |
+
|
| 1244 |
+
if self.tokenizer.mask_token is None:
|
| 1245 |
+
raise ValueError(
|
| 1246 |
+
"This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
|
| 1247 |
+
" --mlm flag if you want to use this tokenizer."
|
| 1248 |
+
)
|
| 1249 |
+
|
| 1250 |
+
labels = inputs.clone()
|
| 1251 |
+
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
|
| 1252 |
+
probability_matrix = torch.full(labels.shape, self.mlm_probability)
|
| 1253 |
+
special_tokens_mask = [
|
| 1254 |
+
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
|
| 1255 |
+
]
|
| 1256 |
+
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
|
| 1257 |
+
if self.tokenizer.pad_token is not None:
|
| 1258 |
+
padding_mask = labels.eq(self.tokenizer.pad_token_id)
|
| 1259 |
+
probability_matrix.masked_fill_(padding_mask, value=0.0)
|
| 1260 |
+
masked_indices = torch.bernoulli(probability_matrix).bool()
|
| 1261 |
+
# probability be `1` (masked), however in albert model attention mask `0` means masked, revert the value
|
| 1262 |
+
attention_mask = (~masked_indices).float()
|
| 1263 |
+
if self.tokenizer.pad_token is not None:
|
| 1264 |
+
attention_padding_mask = labels.eq(self.tokenizer.pad_token_id)
|
| 1265 |
+
attention_mask.masked_fill_(attention_padding_mask, value=1.0)
|
| 1266 |
+
labels[~masked_indices] = -100 # We only compute loss on masked tokens, -100 is default for CE compute
|
| 1267 |
+
|
| 1268 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
| 1269 |
+
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
|
| 1270 |
+
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
|
| 1271 |
+
|
| 1272 |
+
# 10% of the time, we replace masked input tokens with random word
|
| 1273 |
+
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
|
| 1274 |
+
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
|
| 1275 |
+
inputs[indices_random] = random_words[indices_random]
|
| 1276 |
+
|
| 1277 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
| 1278 |
+
return inputs, labels, attention_mask
|
| 1279 |
+
|
| 1280 |
+
|
| 1281 |
+
@dataclass
|
| 1282 |
+
class DataCollatorForPermutationLanguageModeling(DataCollatorMixin):
|
| 1283 |
+
"""
|
| 1284 |
+
Data collator used for permutation language modeling.
|
| 1285 |
+
|
| 1286 |
+
- collates batches of tensors, honoring their tokenizer's pad_token
|
| 1287 |
+
- preprocesses batches for permutation language modeling with procedures specific to XLNet
|
| 1288 |
+
"""
|
| 1289 |
+
|
| 1290 |
+
tokenizer: PreTrainedTokenizerBase
|
| 1291 |
+
plm_probability: float = 1 / 6
|
| 1292 |
+
max_span_length: int = 5 # maximum length of a span of masked tokens
|
| 1293 |
+
return_tensors: str = "pt"
|
| 1294 |
+
|
| 1295 |
+
def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
| 1296 |
+
if isinstance(examples[0], Mapping):
|
| 1297 |
+
examples = [e["input_ids"] for e in examples]
|
| 1298 |
+
batch = _torch_collate_batch(examples, self.tokenizer)
|
| 1299 |
+
inputs, perm_mask, target_mapping, labels = self.torch_mask_tokens(batch)
|
| 1300 |
+
return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
|
| 1301 |
+
|
| 1302 |
+
def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
| 1303 |
+
if isinstance(examples[0], Mapping):
|
| 1304 |
+
examples = [e["input_ids"] for e in examples]
|
| 1305 |
+
batch = _tf_collate_batch(examples, self.tokenizer)
|
| 1306 |
+
inputs, perm_mask, target_mapping, labels = self.tf_mask_tokens(batch)
|
| 1307 |
+
return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
|
| 1308 |
+
|
| 1309 |
+
def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
| 1310 |
+
if isinstance(examples[0], Mapping):
|
| 1311 |
+
examples = [e["input_ids"] for e in examples]
|
| 1312 |
+
batch = _numpy_collate_batch(examples, self.tokenizer)
|
| 1313 |
+
inputs, perm_mask, target_mapping, labels = self.numpy_mask_tokens(batch)
|
| 1314 |
+
return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
|
| 1315 |
+
|
| 1316 |
+
def torch_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
|
| 1317 |
+
"""
|
| 1318 |
+
The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
|
| 1319 |
+
|
| 1320 |
+
0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
|
| 1321 |
+
1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
|
| 1322 |
+
2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be
|
| 1323 |
+
masked
|
| 1324 |
+
3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length -
|
| 1325 |
+
span_length]` and mask tokens `start_index:start_index + span_length`
|
| 1326 |
+
4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the
|
| 1327 |
+
sequence to be processed), repeat from Step 1.
|
| 1328 |
+
"""
|
| 1329 |
+
import torch
|
| 1330 |
+
|
| 1331 |
+
if self.tokenizer.mask_token is None:
|
| 1332 |
+
raise ValueError(
|
| 1333 |
+
"This tokenizer does not have a mask token which is necessary for permutation language modeling."
|
| 1334 |
+
" Please add a mask token if you want to use this tokenizer."
|
| 1335 |
+
)
|
| 1336 |
+
|
| 1337 |
+
if inputs.size(1) % 2 != 0:
|
| 1338 |
+
raise ValueError(
|
| 1339 |
+
"This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see"
|
| 1340 |
+
" relevant comments in source code for details."
|
| 1341 |
+
)
|
| 1342 |
+
|
| 1343 |
+
labels = inputs.clone()
|
| 1344 |
+
# Creating the mask and target_mapping tensors
|
| 1345 |
+
masked_indices = torch.full(labels.shape, 0, dtype=torch.bool)
|
| 1346 |
+
target_mapping = torch.zeros((labels.size(0), labels.size(1), labels.size(1)), dtype=torch.float32)
|
| 1347 |
+
|
| 1348 |
+
for i in range(labels.size(0)):
|
| 1349 |
+
# Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
|
| 1350 |
+
cur_len = 0
|
| 1351 |
+
max_len = labels.size(1)
|
| 1352 |
+
|
| 1353 |
+
while cur_len < max_len:
|
| 1354 |
+
# Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
|
| 1355 |
+
span_length = torch.randint(1, self.max_span_length + 1, (1,)).item()
|
| 1356 |
+
# Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
|
| 1357 |
+
context_length = int(span_length / self.plm_probability)
|
| 1358 |
+
# Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
|
| 1359 |
+
start_index = cur_len + torch.randint(context_length - span_length + 1, (1,)).item()
|
| 1360 |
+
masked_indices[i, start_index : start_index + span_length] = 1
|
| 1361 |
+
# Set `cur_len = cur_len + context_length`
|
| 1362 |
+
cur_len += context_length
|
| 1363 |
+
|
| 1364 |
+
# Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
|
| 1365 |
+
# the i-th predict corresponds to the i-th token.
|
| 1366 |
+
target_mapping[i] = torch.eye(labels.size(1))
|
| 1367 |
+
|
| 1368 |
+
special_tokens_mask = torch.tensor(
|
| 1369 |
+
[self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()],
|
| 1370 |
+
dtype=torch.bool,
|
| 1371 |
+
)
|
| 1372 |
+
masked_indices.masked_fill_(special_tokens_mask, value=0.0)
|
| 1373 |
+
if self.tokenizer.pad_token is not None:
|
| 1374 |
+
padding_mask = labels.eq(self.tokenizer.pad_token_id)
|
| 1375 |
+
masked_indices.masked_fill_(padding_mask, value=0.0)
|
| 1376 |
+
|
| 1377 |
+
# Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
|
| 1378 |
+
non_func_mask = ~(padding_mask | special_tokens_mask)
|
| 1379 |
+
|
| 1380 |
+
inputs[masked_indices] = self.tokenizer.mask_token_id
|
| 1381 |
+
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
| 1382 |
+
|
| 1383 |
+
perm_mask = torch.zeros((labels.size(0), labels.size(1), labels.size(1)), dtype=torch.float32)
|
| 1384 |
+
|
| 1385 |
+
for i in range(labels.size(0)):
|
| 1386 |
+
# Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
|
| 1387 |
+
# determine which tokens a given token can attend to (encoded in `perm_mask`).
|
| 1388 |
+
# Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
|
| 1389 |
+
# (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
|
| 1390 |
+
# we assume that reused length is half of sequence length and permutation length is equal to reused length.
|
| 1391 |
+
# This requires that the sequence length be even.
|
| 1392 |
+
|
| 1393 |
+
# Create a linear factorisation order
|
| 1394 |
+
perm_index = torch.arange(labels.size(1))
|
| 1395 |
+
# Split this into two halves, assuming that half the sequence is reused each time
|
| 1396 |
+
perm_index = perm_index.reshape((-1, labels.size(1) // 2)).transpose(0, 1)
|
| 1397 |
+
# Permute the two halves such that they do not cross over
|
| 1398 |
+
perm_index = perm_index[torch.randperm(labels.size(1) // 2)]
|
| 1399 |
+
# Flatten this out into the desired permuted factorisation order
|
| 1400 |
+
perm_index = torch.flatten(perm_index.transpose(0, 1))
|
| 1401 |
+
# Set the permutation indices of non-masked (non-functional) tokens to the
|
| 1402 |
+
# smallest index (-1) so that:
|
| 1403 |
+
# (1) They can be seen by all other positions
|
| 1404 |
+
# (2) They cannot see masked positions, so there won't be information leak
|
| 1405 |
+
perm_index.masked_fill_(~masked_indices[i] & non_func_mask[i], -1)
|
| 1406 |
+
# The logic for whether the i-th token can attend on the j-th token based on the factorisation order:
|
| 1407 |
+
# 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
|
| 1408 |
+
# 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
|
| 1409 |
+
perm_mask[i] = (
|
| 1410 |
+
perm_index.reshape((labels.size(1), 1)) <= perm_index.reshape((1, labels.size(1)))
|
| 1411 |
+
) & masked_indices[i]
|
| 1412 |
+
|
| 1413 |
+
return inputs.long(), perm_mask, target_mapping, labels.long()
|
| 1414 |
+
|
| 1415 |
+
def tf_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
|
| 1416 |
+
"""
|
| 1417 |
+
The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
|
| 1418 |
+
|
| 1419 |
+
0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
|
| 1420 |
+
1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
|
| 1421 |
+
2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be
|
| 1422 |
+
masked
|
| 1423 |
+
3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length -
|
| 1424 |
+
span_length]` and mask tokens `start_index:start_index + span_length`
|
| 1425 |
+
4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the
|
| 1426 |
+
sequence to be processed), repeat from Step 1.
|
| 1427 |
+
"""
|
| 1428 |
+
import tensorflow as tf
|
| 1429 |
+
|
| 1430 |
+
if self.tokenizer.mask_token is None:
|
| 1431 |
+
raise ValueError(
|
| 1432 |
+
"This tokenizer does not have a mask token which is necessary for permutation language modeling."
|
| 1433 |
+
" Please add a mask token if you want to use this tokenizer."
|
| 1434 |
+
)
|
| 1435 |
+
|
| 1436 |
+
if tf.shape(inputs)[1] % 2 != 0:
|
| 1437 |
+
raise ValueError(
|
| 1438 |
+
"This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see"
|
| 1439 |
+
" relevant comments in source code for details."
|
| 1440 |
+
)
|
| 1441 |
+
|
| 1442 |
+
labels = tf.identity(inputs)
|
| 1443 |
+
# Creating the mask and target_mapping tensors
|
| 1444 |
+
masked_indices = np.full(labels.shape.as_list(), 0, dtype=bool)
|
| 1445 |
+
labels_shape = tf.shape(labels)
|
| 1446 |
+
target_mapping = np.zeros((labels_shape[0], labels_shape[1], labels_shape[1]), dtype=np.float32)
|
| 1447 |
+
|
| 1448 |
+
for i in range(len(labels)):
|
| 1449 |
+
# Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
|
| 1450 |
+
cur_len = 0
|
| 1451 |
+
max_len = tf.shape(labels)[1]
|
| 1452 |
+
|
| 1453 |
+
while cur_len < max_len:
|
| 1454 |
+
# Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
|
| 1455 |
+
span_length = randint(1, self.max_span_length + 1)
|
| 1456 |
+
# Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
|
| 1457 |
+
context_length = int(span_length / self.plm_probability)
|
| 1458 |
+
# Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
|
| 1459 |
+
start_index = cur_len + randint(0, context_length - span_length + 1)
|
| 1460 |
+
masked_indices[i, start_index : start_index + span_length] = 1
|
| 1461 |
+
# Set `cur_len = cur_len + context_length`
|
| 1462 |
+
cur_len += context_length
|
| 1463 |
+
|
| 1464 |
+
# Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
|
| 1465 |
+
# the i-th predict corresponds to the i-th token.
|
| 1466 |
+
target_mapping[i] = np.eye(labels_shape[1])
|
| 1467 |
+
masked_indices = tf.cast(tf.convert_to_tensor(masked_indices), dtype=tf.bool)
|
| 1468 |
+
target_mapping = tf.convert_to_tensor(target_mapping)
|
| 1469 |
+
special_tokens_mask = tf.convert_to_tensor(
|
| 1470 |
+
[
|
| 1471 |
+
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
|
| 1472 |
+
for val in labels.numpy().tolist()
|
| 1473 |
+
],
|
| 1474 |
+
)
|
| 1475 |
+
special_tokens_mask = tf.cast(special_tokens_mask, dtype=tf.bool)
|
| 1476 |
+
masked_indices = masked_indices & ~special_tokens_mask
|
| 1477 |
+
if self.tokenizer.pad_token is not None:
|
| 1478 |
+
padding_mask = labels == self.tokenizer.pad_token_id
|
| 1479 |
+
masked_indices = masked_indices & ~padding_mask
|
| 1480 |
+
|
| 1481 |
+
# Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
|
| 1482 |
+
non_func_mask = ~(padding_mask | special_tokens_mask)
|
| 1483 |
+
|
| 1484 |
+
inputs = tf.where(masked_indices, self.tokenizer.mask_token_id, inputs)
|
| 1485 |
+
labels = tf.where(masked_indices, labels, -100) # We only compute loss on masked tokens
|
| 1486 |
+
|
| 1487 |
+
perm_mask = []
|
| 1488 |
+
|
| 1489 |
+
for i in range(len(labels)):
|
| 1490 |
+
# Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
|
| 1491 |
+
# determine which tokens a given token can attend to (encoded in `perm_mask`).
|
| 1492 |
+
# Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
|
| 1493 |
+
# (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
|
| 1494 |
+
# we assume that reused length is half of sequence length and permutation length is equal to reused length.
|
| 1495 |
+
# This requires that the sequence length be even.
|
| 1496 |
+
|
| 1497 |
+
# Create a linear factorisation order
|
| 1498 |
+
# tf.range is the equivalent of torch.arange
|
| 1499 |
+
perm_index = tf.range(labels_shape[1])
|
| 1500 |
+
# Split this into two halves, assuming that half the sequence is reused each time
|
| 1501 |
+
perm_index = tf.transpose(tf.reshape(perm_index, (-1, labels_shape[1] // 2)))
|
| 1502 |
+
# Permute the two halves such that they do not cross over
|
| 1503 |
+
perm_index = tf.random.shuffle(perm_index) # Shuffles along the first dimension
|
| 1504 |
+
# Flatten this out into the desired permuted factorisation order
|
| 1505 |
+
perm_index = tf.reshape(tf.transpose(perm_index), (-1,))
|
| 1506 |
+
# Set the permutation indices of non-masked (non-functional) tokens to the
|
| 1507 |
+
# smallest index (-1) so that:
|
| 1508 |
+
# (1) They can be seen by all other positions
|
| 1509 |
+
# (2) They cannot see masked positions, so there won't be information leak
|
| 1510 |
+
perm_index = tf.where(~masked_indices[i] & non_func_mask[i], -1, perm_index)
|
| 1511 |
+
# The logic for whether the i-th token can attend on the j-th token based on the factorisation order:
|
| 1512 |
+
# 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
|
| 1513 |
+
# 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
|
| 1514 |
+
perm_mask.append(
|
| 1515 |
+
(tf.reshape(perm_index, (labels_shape[1], 1)) <= tf.reshape(perm_index, (1, labels_shape[1])))
|
| 1516 |
+
& masked_indices[i]
|
| 1517 |
+
)
|
| 1518 |
+
perm_mask = tf.stack(perm_mask, axis=0)
|
| 1519 |
+
|
| 1520 |
+
return tf.cast(inputs, tf.int64), tf.cast(perm_mask, tf.float32), target_mapping, tf.cast(labels, tf.int64)
|
| 1521 |
+
|
| 1522 |
+
def numpy_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
|
| 1523 |
+
"""
|
| 1524 |
+
The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
|
| 1525 |
+
|
| 1526 |
+
0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
|
| 1527 |
+
1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
|
| 1528 |
+
2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be
|
| 1529 |
+
masked
|
| 1530 |
+
3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length -
|
| 1531 |
+
span_length]` and mask tokens `start_index:start_index + span_length`
|
| 1532 |
+
4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the
|
| 1533 |
+
sequence to be processed), repeat from Step 1.
|
| 1534 |
+
"""
|
| 1535 |
+
if self.tokenizer.mask_token is None:
|
| 1536 |
+
raise ValueError(
|
| 1537 |
+
"This tokenizer does not have a mask token which is necessary for permutation language modeling."
|
| 1538 |
+
" Please add a mask token if you want to use this tokenizer."
|
| 1539 |
+
)
|
| 1540 |
+
|
| 1541 |
+
if inputs.shape[1] % 2 != 0:
|
| 1542 |
+
raise ValueError(
|
| 1543 |
+
"This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see"
|
| 1544 |
+
" relevant comments in source code for details."
|
| 1545 |
+
)
|
| 1546 |
+
|
| 1547 |
+
labels = np.copy(inputs)
|
| 1548 |
+
# Creating the mask and target_mapping tensors
|
| 1549 |
+
masked_indices = np.full(labels.shape, 0, dtype=bool)
|
| 1550 |
+
target_mapping = np.zeros((labels.shape[0], labels.shape[1], labels.shape[1]), dtype=np.float32)
|
| 1551 |
+
|
| 1552 |
+
for i in range(labels.shape[0]):
|
| 1553 |
+
# Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
|
| 1554 |
+
cur_len = 0
|
| 1555 |
+
max_len = labels.shape[1]
|
| 1556 |
+
|
| 1557 |
+
while cur_len < max_len:
|
| 1558 |
+
# Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
|
| 1559 |
+
span_length = randint(1, self.max_span_length + 1)
|
| 1560 |
+
# Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
|
| 1561 |
+
context_length = int(span_length / self.plm_probability)
|
| 1562 |
+
# Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
|
| 1563 |
+
start_index = cur_len + randint(0, context_length - span_length + 1)
|
| 1564 |
+
masked_indices[i, start_index : start_index + span_length] = 1
|
| 1565 |
+
# Set `cur_len = cur_len + context_length`
|
| 1566 |
+
cur_len += context_length
|
| 1567 |
+
|
| 1568 |
+
# Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
|
| 1569 |
+
# the i-th predict corresponds to the i-th token.
|
| 1570 |
+
target_mapping[i] = np.eye(labels.shape[1])
|
| 1571 |
+
|
| 1572 |
+
special_tokens_mask = np.array(
|
| 1573 |
+
[self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()],
|
| 1574 |
+
dtype=bool,
|
| 1575 |
+
)
|
| 1576 |
+
masked_indices[special_tokens_mask] = 0
|
| 1577 |
+
if self.tokenizer.pad_token is not None:
|
| 1578 |
+
padding_mask = labels == self.tokenizer.pad_token_id
|
| 1579 |
+
masked_indices[padding_mask] = 0.0
|
| 1580 |
+
|
| 1581 |
+
# Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
|
| 1582 |
+
non_func_mask = ~(padding_mask | special_tokens_mask)
|
| 1583 |
+
|
| 1584 |
+
inputs[masked_indices] = self.tokenizer.mask_token_id
|
| 1585 |
+
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
| 1586 |
+
|
| 1587 |
+
perm_mask = np.zeros((labels.shape[0], labels.shape[1], labels.shape[1]), dtype=np.float32)
|
| 1588 |
+
|
| 1589 |
+
for i in range(labels.shape[0]):
|
| 1590 |
+
# Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
|
| 1591 |
+
# determine which tokens a given token can attend to (encoded in `perm_mask`).
|
| 1592 |
+
# Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
|
| 1593 |
+
# (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
|
| 1594 |
+
# we assume that reused length is half of sequence length and permutation length is equal to reused length.
|
| 1595 |
+
# This requires that the sequence length be even.
|
| 1596 |
+
|
| 1597 |
+
# Create a linear factorisation order
|
| 1598 |
+
perm_index = np.arange(labels.shape[1])
|
| 1599 |
+
# Split this into two halves, assuming that half the sequence is reused each time
|
| 1600 |
+
perm_index = perm_index.reshape((-1, labels.shape[1] // 2)).T
|
| 1601 |
+
# Permute the two halves such that they do not cross over
|
| 1602 |
+
np.random.shuffle(perm_index)
|
| 1603 |
+
# Flatten this out into the desired permuted factorisation order
|
| 1604 |
+
perm_index = perm_index.T.flatten()
|
| 1605 |
+
# Set the permutation indices of non-masked (non-functional) tokens to the
|
| 1606 |
+
# smallest index (-1) so that:
|
| 1607 |
+
# (1) They can be seen by all other positions
|
| 1608 |
+
# (2) They cannot see masked positions, so there won't be information leak
|
| 1609 |
+
perm_index[~masked_indices[i] & non_func_mask[i]] = -1
|
| 1610 |
+
# The logic for whether the i-th token can attend on the j-th token based on the factorisation order:
|
| 1611 |
+
# 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
|
| 1612 |
+
# 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
|
| 1613 |
+
perm_mask[i] = (
|
| 1614 |
+
perm_index.reshape((labels.shape[1], 1)) <= perm_index.reshape((1, labels.shape[1]))
|
| 1615 |
+
) & masked_indices[i]
|
| 1616 |
+
|
| 1617 |
+
return inputs.astype(np.int64), perm_mask, target_mapping, labels.astype(np.int64)
|
| 1618 |
+
|
| 1619 |
+
|
| 1620 |
+
@dataclass
|
| 1621 |
+
class DataCollatorWithFlattening(DefaultDataCollator):
|
| 1622 |
+
"""
|
| 1623 |
+
Data collator used for padding free approach. Does the following:
|
| 1624 |
+
|
| 1625 |
+
- concatate the entire mini batch into single long sequence [1, total_tokens]
|
| 1626 |
+
- uses `separator_id` to separate sequences within the concatenated `labels`, default value is -100
|
| 1627 |
+
- no padding will be added, returns `input_ids`, `labels` and `position_ids`
|
| 1628 |
+
"""
|
| 1629 |
+
|
| 1630 |
+
def __init__(self, *args, return_position_ids=True, separator_id=-100, **kwargs):
|
| 1631 |
+
super().__init__(*args, **kwargs)
|
| 1632 |
+
self.return_position_ids = return_position_ids
|
| 1633 |
+
self.separator_id = separator_id
|
| 1634 |
+
warnings.warn(
|
| 1635 |
+
"Using `DataCollatorWithFlattening` will flatten the entire mini batch into single long sequence."
|
| 1636 |
+
"Make sure your attention computation is able to handle it!"
|
| 1637 |
+
)
|
| 1638 |
+
|
| 1639 |
+
def __call__(self, features, return_tensors=None, separator_id=None):
|
| 1640 |
+
if return_tensors is None:
|
| 1641 |
+
return_tensors = self.return_tensors
|
| 1642 |
+
if separator_id is None:
|
| 1643 |
+
separator_id = self.separator_id
|
| 1644 |
+
is_labels_provided = "labels" in features[0]
|
| 1645 |
+
ret = {"input_ids": [], "labels": []}
|
| 1646 |
+
if self.return_position_ids:
|
| 1647 |
+
ret.update({"position_ids": []})
|
| 1648 |
+
for idx in range(0, len(features)):
|
| 1649 |
+
ret["input_ids"] += features[idx]["input_ids"]
|
| 1650 |
+
if is_labels_provided:
|
| 1651 |
+
ret["labels"] += [separator_id] + features[idx]["labels"][1:]
|
| 1652 |
+
else:
|
| 1653 |
+
ret["labels"] += [separator_id] + features[idx]["input_ids"][1:]
|
| 1654 |
+
if self.return_position_ids:
|
| 1655 |
+
ret["position_ids"] += list(range(len(features[idx]["input_ids"])))
|
| 1656 |
+
return default_data_collator([ret], return_tensors)
|
.venv/lib/python3.11/site-packages/transformers/data/datasets/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .glue import GlueDataset, GlueDataTrainingArguments
|
| 16 |
+
from .language_modeling import (
|
| 17 |
+
LineByLineTextDataset,
|
| 18 |
+
LineByLineWithRefDataset,
|
| 19 |
+
LineByLineWithSOPTextDataset,
|
| 20 |
+
TextDataset,
|
| 21 |
+
TextDatasetForNextSentencePrediction,
|
| 22 |
+
)
|
| 23 |
+
from .squad import SquadDataset, SquadDataTrainingArguments
|
.venv/lib/python3.11/site-packages/transformers/data/datasets/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (674 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/transformers/data/datasets/__pycache__/glue.cpython-311.pyc
ADDED
|
Binary file (7.97 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/transformers/data/datasets/__pycache__/language_modeling.cpython-311.pyc
ADDED
|
Binary file (27.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/transformers/data/datasets/__pycache__/squad.cpython-311.pyc
ADDED
|
Binary file (11.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/transformers/data/datasets/glue.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import time
|
| 17 |
+
import warnings
|
| 18 |
+
from dataclasses import dataclass, field
|
| 19 |
+
from enum import Enum
|
| 20 |
+
from typing import List, Optional, Union
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
from filelock import FileLock
|
| 24 |
+
from torch.utils.data import Dataset
|
| 25 |
+
|
| 26 |
+
from ...tokenization_utils_base import PreTrainedTokenizerBase
|
| 27 |
+
from ...utils import logging
|
| 28 |
+
from ..processors.glue import glue_convert_examples_to_features, glue_output_modes, glue_processors
|
| 29 |
+
from ..processors.utils import InputFeatures
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
logger = logging.get_logger(__name__)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
|
| 36 |
+
class GlueDataTrainingArguments:
|
| 37 |
+
"""
|
| 38 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
| 39 |
+
|
| 40 |
+
Using `HfArgumentParser` we can turn this class into argparse arguments to be able to specify them on the command
|
| 41 |
+
line.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(glue_processors.keys())})
|
| 45 |
+
data_dir: str = field(
|
| 46 |
+
metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."}
|
| 47 |
+
)
|
| 48 |
+
max_seq_length: int = field(
|
| 49 |
+
default=128,
|
| 50 |
+
metadata={
|
| 51 |
+
"help": (
|
| 52 |
+
"The maximum total input sequence length after tokenization. Sequences longer "
|
| 53 |
+
"than this will be truncated, sequences shorter will be padded."
|
| 54 |
+
)
|
| 55 |
+
},
|
| 56 |
+
)
|
| 57 |
+
overwrite_cache: bool = field(
|
| 58 |
+
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
def __post_init__(self):
|
| 62 |
+
self.task_name = self.task_name.lower()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class Split(Enum):
|
| 66 |
+
train = "train"
|
| 67 |
+
dev = "dev"
|
| 68 |
+
test = "test"
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class GlueDataset(Dataset):
|
| 72 |
+
"""
|
| 73 |
+
This will be superseded by a framework-agnostic approach soon.
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
args: GlueDataTrainingArguments
|
| 77 |
+
output_mode: str
|
| 78 |
+
features: List[InputFeatures]
|
| 79 |
+
|
| 80 |
+
def __init__(
|
| 81 |
+
self,
|
| 82 |
+
args: GlueDataTrainingArguments,
|
| 83 |
+
tokenizer: PreTrainedTokenizerBase,
|
| 84 |
+
limit_length: Optional[int] = None,
|
| 85 |
+
mode: Union[str, Split] = Split.train,
|
| 86 |
+
cache_dir: Optional[str] = None,
|
| 87 |
+
):
|
| 88 |
+
warnings.warn(
|
| 89 |
+
"This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
|
| 90 |
+
"library. You can have a look at this example script for pointers: "
|
| 91 |
+
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py",
|
| 92 |
+
FutureWarning,
|
| 93 |
+
)
|
| 94 |
+
self.args = args
|
| 95 |
+
self.processor = glue_processors[args.task_name]()
|
| 96 |
+
self.output_mode = glue_output_modes[args.task_name]
|
| 97 |
+
if isinstance(mode, str):
|
| 98 |
+
try:
|
| 99 |
+
mode = Split[mode]
|
| 100 |
+
except KeyError:
|
| 101 |
+
raise KeyError("mode is not a valid split name")
|
| 102 |
+
# Load data features from cache or dataset file
|
| 103 |
+
cached_features_file = os.path.join(
|
| 104 |
+
cache_dir if cache_dir is not None else args.data_dir,
|
| 105 |
+
f"cached_{mode.value}_{tokenizer.__class__.__name__}_{args.max_seq_length}_{args.task_name}",
|
| 106 |
+
)
|
| 107 |
+
label_list = self.processor.get_labels()
|
| 108 |
+
if args.task_name in ["mnli", "mnli-mm"] and tokenizer.__class__.__name__ in (
|
| 109 |
+
"RobertaTokenizer",
|
| 110 |
+
"RobertaTokenizerFast",
|
| 111 |
+
"XLMRobertaTokenizer",
|
| 112 |
+
"BartTokenizer",
|
| 113 |
+
"BartTokenizerFast",
|
| 114 |
+
):
|
| 115 |
+
# HACK(label indices are swapped in RoBERTa pretrained model)
|
| 116 |
+
label_list[1], label_list[2] = label_list[2], label_list[1]
|
| 117 |
+
self.label_list = label_list
|
| 118 |
+
|
| 119 |
+
# Make sure only the first process in distributed training processes the dataset,
|
| 120 |
+
# and the others will use the cache.
|
| 121 |
+
lock_path = cached_features_file + ".lock"
|
| 122 |
+
with FileLock(lock_path):
|
| 123 |
+
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
| 124 |
+
start = time.time()
|
| 125 |
+
self.features = torch.load(cached_features_file)
|
| 126 |
+
logger.info(
|
| 127 |
+
f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
|
| 128 |
+
)
|
| 129 |
+
else:
|
| 130 |
+
logger.info(f"Creating features from dataset file at {args.data_dir}")
|
| 131 |
+
|
| 132 |
+
if mode == Split.dev:
|
| 133 |
+
examples = self.processor.get_dev_examples(args.data_dir)
|
| 134 |
+
elif mode == Split.test:
|
| 135 |
+
examples = self.processor.get_test_examples(args.data_dir)
|
| 136 |
+
else:
|
| 137 |
+
examples = self.processor.get_train_examples(args.data_dir)
|
| 138 |
+
if limit_length is not None:
|
| 139 |
+
examples = examples[:limit_length]
|
| 140 |
+
self.features = glue_convert_examples_to_features(
|
| 141 |
+
examples,
|
| 142 |
+
tokenizer,
|
| 143 |
+
max_length=args.max_seq_length,
|
| 144 |
+
label_list=label_list,
|
| 145 |
+
output_mode=self.output_mode,
|
| 146 |
+
)
|
| 147 |
+
start = time.time()
|
| 148 |
+
torch.save(self.features, cached_features_file)
|
| 149 |
+
# ^ This seems to take a lot of time so I want to investigate why and how we can improve.
|
| 150 |
+
logger.info(
|
| 151 |
+
f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
def __len__(self):
|
| 155 |
+
return len(self.features)
|
| 156 |
+
|
| 157 |
+
def __getitem__(self, i) -> InputFeatures:
|
| 158 |
+
return self.features[i]
|
| 159 |
+
|
| 160 |
+
def get_labels(self):
|
| 161 |
+
return self.label_list
|
.venv/lib/python3.11/site-packages/transformers/data/datasets/language_modeling.py
ADDED
|
@@ -0,0 +1,530 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import os
|
| 17 |
+
import pickle
|
| 18 |
+
import random
|
| 19 |
+
import time
|
| 20 |
+
import warnings
|
| 21 |
+
from typing import Dict, List, Optional
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
from filelock import FileLock
|
| 25 |
+
from torch.utils.data import Dataset
|
| 26 |
+
|
| 27 |
+
from ...tokenization_utils import PreTrainedTokenizer
|
| 28 |
+
from ...utils import logging
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
logger = logging.get_logger(__name__)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
DEPRECATION_WARNING = (
|
| 35 |
+
"This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
|
| 36 |
+
"library. You can have a look at this example script for pointers: {0}"
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class TextDataset(Dataset):
|
| 41 |
+
"""
|
| 42 |
+
This will be superseded by a framework-agnostic approach soon.
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
def __init__(
|
| 46 |
+
self,
|
| 47 |
+
tokenizer: PreTrainedTokenizer,
|
| 48 |
+
file_path: str,
|
| 49 |
+
block_size: int,
|
| 50 |
+
overwrite_cache=False,
|
| 51 |
+
cache_dir: Optional[str] = None,
|
| 52 |
+
):
|
| 53 |
+
warnings.warn(
|
| 54 |
+
DEPRECATION_WARNING.format(
|
| 55 |
+
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
|
| 56 |
+
),
|
| 57 |
+
FutureWarning,
|
| 58 |
+
)
|
| 59 |
+
if os.path.isfile(file_path) is False:
|
| 60 |
+
raise ValueError(f"Input file path {file_path} not found")
|
| 61 |
+
|
| 62 |
+
block_size = block_size - tokenizer.num_special_tokens_to_add(pair=False)
|
| 63 |
+
|
| 64 |
+
directory, filename = os.path.split(file_path)
|
| 65 |
+
cached_features_file = os.path.join(
|
| 66 |
+
cache_dir if cache_dir is not None else directory,
|
| 67 |
+
f"cached_lm_{tokenizer.__class__.__name__}_{block_size}_{filename}",
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
# Make sure only the first process in distributed training processes the dataset,
|
| 71 |
+
# and the others will use the cache.
|
| 72 |
+
lock_path = cached_features_file + ".lock"
|
| 73 |
+
with FileLock(lock_path):
|
| 74 |
+
if os.path.exists(cached_features_file) and not overwrite_cache:
|
| 75 |
+
start = time.time()
|
| 76 |
+
with open(cached_features_file, "rb") as handle:
|
| 77 |
+
self.examples = pickle.load(handle)
|
| 78 |
+
logger.info(
|
| 79 |
+
f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
else:
|
| 83 |
+
logger.info(f"Creating features from dataset file at {directory}")
|
| 84 |
+
|
| 85 |
+
self.examples = []
|
| 86 |
+
with open(file_path, encoding="utf-8") as f:
|
| 87 |
+
text = f.read()
|
| 88 |
+
|
| 89 |
+
tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
|
| 90 |
+
|
| 91 |
+
for i in range(0, len(tokenized_text) - block_size + 1, block_size): # Truncate in block of block_size
|
| 92 |
+
self.examples.append(
|
| 93 |
+
tokenizer.build_inputs_with_special_tokens(tokenized_text[i : i + block_size])
|
| 94 |
+
)
|
| 95 |
+
# Note that we are losing the last truncated example here for the sake of simplicity (no padding)
|
| 96 |
+
# If your dataset is small, first you should look for a bigger one :-) and second you
|
| 97 |
+
# can change this behavior by adding (model specific) padding.
|
| 98 |
+
|
| 99 |
+
start = time.time()
|
| 100 |
+
with open(cached_features_file, "wb") as handle:
|
| 101 |
+
pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
| 102 |
+
logger.info(
|
| 103 |
+
f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
def __len__(self):
|
| 107 |
+
return len(self.examples)
|
| 108 |
+
|
| 109 |
+
def __getitem__(self, i) -> torch.Tensor:
|
| 110 |
+
return torch.tensor(self.examples[i], dtype=torch.long)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class LineByLineTextDataset(Dataset):
|
| 114 |
+
"""
|
| 115 |
+
This will be superseded by a framework-agnostic approach soon.
|
| 116 |
+
"""
|
| 117 |
+
|
| 118 |
+
def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int):
|
| 119 |
+
warnings.warn(
|
| 120 |
+
DEPRECATION_WARNING.format(
|
| 121 |
+
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
|
| 122 |
+
),
|
| 123 |
+
FutureWarning,
|
| 124 |
+
)
|
| 125 |
+
if os.path.isfile(file_path) is False:
|
| 126 |
+
raise ValueError(f"Input file path {file_path} not found")
|
| 127 |
+
# Here, we do not cache the features, operating under the assumption
|
| 128 |
+
# that we will soon use fast multithreaded tokenizers from the
|
| 129 |
+
# `tokenizers` repo everywhere =)
|
| 130 |
+
logger.info(f"Creating features from dataset file at {file_path}")
|
| 131 |
+
|
| 132 |
+
with open(file_path, encoding="utf-8") as f:
|
| 133 |
+
lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
|
| 134 |
+
|
| 135 |
+
batch_encoding = tokenizer(lines, add_special_tokens=True, truncation=True, max_length=block_size)
|
| 136 |
+
self.examples = batch_encoding["input_ids"]
|
| 137 |
+
self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples]
|
| 138 |
+
|
| 139 |
+
def __len__(self):
|
| 140 |
+
return len(self.examples)
|
| 141 |
+
|
| 142 |
+
def __getitem__(self, i) -> Dict[str, torch.tensor]:
|
| 143 |
+
return self.examples[i]
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class LineByLineWithRefDataset(Dataset):
|
| 147 |
+
"""
|
| 148 |
+
This will be superseded by a framework-agnostic approach soon.
|
| 149 |
+
"""
|
| 150 |
+
|
| 151 |
+
def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, ref_path: str):
|
| 152 |
+
warnings.warn(
|
| 153 |
+
DEPRECATION_WARNING.format(
|
| 154 |
+
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm_wwm.py"
|
| 155 |
+
),
|
| 156 |
+
FutureWarning,
|
| 157 |
+
)
|
| 158 |
+
if os.path.isfile(file_path) is False:
|
| 159 |
+
raise ValueError(f"Input file path {file_path} not found")
|
| 160 |
+
if os.path.isfile(ref_path) is False:
|
| 161 |
+
raise ValueError(f"Ref file path {file_path} not found")
|
| 162 |
+
# Here, we do not cache the features, operating under the assumption
|
| 163 |
+
# that we will soon use fast multithreaded tokenizers from the
|
| 164 |
+
# `tokenizers` repo everywhere =)
|
| 165 |
+
logger.info(f"Creating features from dataset file at {file_path}")
|
| 166 |
+
logger.info(f"Use ref segment results at {ref_path}")
|
| 167 |
+
with open(file_path, encoding="utf-8") as f:
|
| 168 |
+
data = f.readlines() # use this method to avoid delimiter '\u2029' to split a line
|
| 169 |
+
data = [line.strip() for line in data if len(line) > 0 and not line.isspace()]
|
| 170 |
+
# Get ref inf from file
|
| 171 |
+
with open(ref_path, encoding="utf-8") as f:
|
| 172 |
+
ref = [json.loads(line) for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
|
| 173 |
+
if len(data) != len(ref):
|
| 174 |
+
raise ValueError(
|
| 175 |
+
f"Length of Input file should be equal to Ref file. But the length of {file_path} is {len(data)} "
|
| 176 |
+
f"while length of {ref_path} is {len(ref)}"
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
batch_encoding = tokenizer(data, add_special_tokens=True, truncation=True, max_length=block_size)
|
| 180 |
+
self.examples = batch_encoding["input_ids"]
|
| 181 |
+
self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples]
|
| 182 |
+
|
| 183 |
+
n = len(self.examples)
|
| 184 |
+
for i in range(n):
|
| 185 |
+
self.examples[i]["chinese_ref"] = torch.tensor(ref[i], dtype=torch.long)
|
| 186 |
+
|
| 187 |
+
def __len__(self):
|
| 188 |
+
return len(self.examples)
|
| 189 |
+
|
| 190 |
+
def __getitem__(self, i) -> Dict[str, torch.tensor]:
|
| 191 |
+
return self.examples[i]
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
class LineByLineWithSOPTextDataset(Dataset):
|
| 195 |
+
"""
|
| 196 |
+
Dataset for sentence order prediction task, prepare sentence pairs for SOP task
|
| 197 |
+
"""
|
| 198 |
+
|
| 199 |
+
def __init__(self, tokenizer: PreTrainedTokenizer, file_dir: str, block_size: int):
|
| 200 |
+
warnings.warn(
|
| 201 |
+
DEPRECATION_WARNING.format(
|
| 202 |
+
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
|
| 203 |
+
),
|
| 204 |
+
FutureWarning,
|
| 205 |
+
)
|
| 206 |
+
if os.path.isdir(file_dir) is False:
|
| 207 |
+
raise ValueError(f"{file_dir} is not a directory")
|
| 208 |
+
logger.info(f"Creating features from dataset file folder at {file_dir}")
|
| 209 |
+
self.examples = []
|
| 210 |
+
# TODO: randomness could apply a random seed, ex. rng = random.Random(random_seed)
|
| 211 |
+
# file path looks like ./dataset/wiki_1, ./dataset/wiki_2
|
| 212 |
+
for file_name in os.listdir(file_dir):
|
| 213 |
+
file_path = os.path.join(file_dir, file_name)
|
| 214 |
+
if os.path.isfile(file_path) is False:
|
| 215 |
+
raise ValueError(f"{file_path} is not a file")
|
| 216 |
+
article_open = False
|
| 217 |
+
with open(file_path, encoding="utf-8") as f:
|
| 218 |
+
original_lines = f.readlines()
|
| 219 |
+
article_lines = []
|
| 220 |
+
for line in original_lines:
|
| 221 |
+
if "<doc id=" in line:
|
| 222 |
+
article_open = True
|
| 223 |
+
elif "</doc>" in line:
|
| 224 |
+
article_open = False
|
| 225 |
+
document = [
|
| 226 |
+
tokenizer.convert_tokens_to_ids(tokenizer.tokenize(line))
|
| 227 |
+
for line in article_lines[1:]
|
| 228 |
+
if (len(line) > 0 and not line.isspace())
|
| 229 |
+
]
|
| 230 |
+
|
| 231 |
+
examples = self.create_examples_from_document(document, block_size, tokenizer)
|
| 232 |
+
self.examples.extend(examples)
|
| 233 |
+
article_lines = []
|
| 234 |
+
else:
|
| 235 |
+
if article_open:
|
| 236 |
+
article_lines.append(line)
|
| 237 |
+
|
| 238 |
+
logger.info("Dataset parse finished.")
|
| 239 |
+
|
| 240 |
+
def create_examples_from_document(self, document, block_size, tokenizer, short_seq_prob=0.1):
|
| 241 |
+
"""Creates examples for a single document."""
|
| 242 |
+
|
| 243 |
+
# Account for special tokens
|
| 244 |
+
max_num_tokens = block_size - tokenizer.num_special_tokens_to_add(pair=True)
|
| 245 |
+
|
| 246 |
+
# We *usually* want to fill up the entire sequence since we are padding
|
| 247 |
+
# to `block_size` anyways, so short sequences are generally wasted
|
| 248 |
+
# computation. However, we *sometimes*
|
| 249 |
+
# (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
|
| 250 |
+
# sequences to minimize the mismatch between pretraining and fine-tuning.
|
| 251 |
+
# The `target_seq_length` is just a rough target however, whereas
|
| 252 |
+
# `block_size` is a hard limit.
|
| 253 |
+
target_seq_length = max_num_tokens
|
| 254 |
+
if random.random() < short_seq_prob:
|
| 255 |
+
target_seq_length = random.randint(2, max_num_tokens)
|
| 256 |
+
|
| 257 |
+
# We DON'T just concatenate all of the tokens from a document into a long
|
| 258 |
+
# sequence and choose an arbitrary split point because this would make the
|
| 259 |
+
# next sentence prediction task too easy. Instead, we split the input into
|
| 260 |
+
# segments "A" and "B" based on the actual "sentences" provided by the user
|
| 261 |
+
# input.
|
| 262 |
+
examples = []
|
| 263 |
+
current_chunk = [] # a buffer stored current working segments
|
| 264 |
+
current_length = 0
|
| 265 |
+
i = 0
|
| 266 |
+
while i < len(document):
|
| 267 |
+
segment = document[i] # get a segment
|
| 268 |
+
if not segment:
|
| 269 |
+
i += 1
|
| 270 |
+
continue
|
| 271 |
+
current_chunk.append(segment) # add a segment to current chunk
|
| 272 |
+
current_length += len(segment) # overall token length
|
| 273 |
+
# if current length goes to the target length or reaches the end of file, start building token a and b
|
| 274 |
+
if i == len(document) - 1 or current_length >= target_seq_length:
|
| 275 |
+
if current_chunk:
|
| 276 |
+
# `a_end` is how many segments from `current_chunk` go into the `A` (first) sentence.
|
| 277 |
+
a_end = 1
|
| 278 |
+
# if current chunk has more than 2 sentences, pick part of it `A` (first) sentence
|
| 279 |
+
if len(current_chunk) >= 2:
|
| 280 |
+
a_end = random.randint(1, len(current_chunk) - 1)
|
| 281 |
+
# token a
|
| 282 |
+
tokens_a = []
|
| 283 |
+
for j in range(a_end):
|
| 284 |
+
tokens_a.extend(current_chunk[j])
|
| 285 |
+
|
| 286 |
+
# token b
|
| 287 |
+
tokens_b = []
|
| 288 |
+
for j in range(a_end, len(current_chunk)):
|
| 289 |
+
tokens_b.extend(current_chunk[j])
|
| 290 |
+
|
| 291 |
+
if len(tokens_a) == 0 or len(tokens_b) == 0:
|
| 292 |
+
continue
|
| 293 |
+
|
| 294 |
+
# switch tokens_a and tokens_b randomly
|
| 295 |
+
if random.random() < 0.5:
|
| 296 |
+
is_next = False
|
| 297 |
+
tokens_a, tokens_b = tokens_b, tokens_a
|
| 298 |
+
else:
|
| 299 |
+
is_next = True
|
| 300 |
+
|
| 301 |
+
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
|
| 302 |
+
"""Truncates a pair of sequences to a maximum sequence length."""
|
| 303 |
+
while True:
|
| 304 |
+
total_length = len(tokens_a) + len(tokens_b)
|
| 305 |
+
if total_length <= max_num_tokens:
|
| 306 |
+
break
|
| 307 |
+
trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
|
| 308 |
+
if not (len(trunc_tokens) >= 1):
|
| 309 |
+
raise ValueError("Sequence length to be truncated must be no less than one")
|
| 310 |
+
# We want to sometimes truncate from the front and sometimes from the
|
| 311 |
+
# back to add more randomness and avoid biases.
|
| 312 |
+
if random.random() < 0.5:
|
| 313 |
+
del trunc_tokens[0]
|
| 314 |
+
else:
|
| 315 |
+
trunc_tokens.pop()
|
| 316 |
+
|
| 317 |
+
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens)
|
| 318 |
+
if not (len(tokens_a) >= 1):
|
| 319 |
+
raise ValueError(f"Length of sequence a is {len(tokens_a)} which must be no less than 1")
|
| 320 |
+
if not (len(tokens_b) >= 1):
|
| 321 |
+
raise ValueError(f"Length of sequence b is {len(tokens_b)} which must be no less than 1")
|
| 322 |
+
|
| 323 |
+
# add special tokens
|
| 324 |
+
input_ids = tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
|
| 325 |
+
# add token type ids, 0 for sentence a, 1 for sentence b
|
| 326 |
+
token_type_ids = tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)
|
| 327 |
+
|
| 328 |
+
example = {
|
| 329 |
+
"input_ids": torch.tensor(input_ids, dtype=torch.long),
|
| 330 |
+
"token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
|
| 331 |
+
"sentence_order_label": torch.tensor(0 if is_next else 1, dtype=torch.long),
|
| 332 |
+
}
|
| 333 |
+
examples.append(example)
|
| 334 |
+
current_chunk = [] # clear current chunk
|
| 335 |
+
current_length = 0 # reset current text length
|
| 336 |
+
i += 1 # go to next line
|
| 337 |
+
return examples
|
| 338 |
+
|
| 339 |
+
def __len__(self):
|
| 340 |
+
return len(self.examples)
|
| 341 |
+
|
| 342 |
+
def __getitem__(self, i) -> Dict[str, torch.tensor]:
|
| 343 |
+
return self.examples[i]
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
class TextDatasetForNextSentencePrediction(Dataset):
|
| 347 |
+
"""
|
| 348 |
+
This will be superseded by a framework-agnostic approach soon.
|
| 349 |
+
"""
|
| 350 |
+
|
| 351 |
+
def __init__(
|
| 352 |
+
self,
|
| 353 |
+
tokenizer: PreTrainedTokenizer,
|
| 354 |
+
file_path: str,
|
| 355 |
+
block_size: int,
|
| 356 |
+
overwrite_cache=False,
|
| 357 |
+
short_seq_probability=0.1,
|
| 358 |
+
nsp_probability=0.5,
|
| 359 |
+
):
|
| 360 |
+
warnings.warn(
|
| 361 |
+
DEPRECATION_WARNING.format(
|
| 362 |
+
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
|
| 363 |
+
),
|
| 364 |
+
FutureWarning,
|
| 365 |
+
)
|
| 366 |
+
if not os.path.isfile(file_path):
|
| 367 |
+
raise ValueError(f"Input file path {file_path} not found")
|
| 368 |
+
|
| 369 |
+
self.short_seq_probability = short_seq_probability
|
| 370 |
+
self.nsp_probability = nsp_probability
|
| 371 |
+
|
| 372 |
+
directory, filename = os.path.split(file_path)
|
| 373 |
+
cached_features_file = os.path.join(
|
| 374 |
+
directory,
|
| 375 |
+
f"cached_nsp_{tokenizer.__class__.__name__}_{block_size}_{filename}",
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
self.tokenizer = tokenizer
|
| 379 |
+
|
| 380 |
+
# Make sure only the first process in distributed training processes the dataset,
|
| 381 |
+
# and the others will use the cache.
|
| 382 |
+
lock_path = cached_features_file + ".lock"
|
| 383 |
+
|
| 384 |
+
# Input file format:
|
| 385 |
+
# (1) One sentence per line. These should ideally be actual sentences, not
|
| 386 |
+
# entire paragraphs or arbitrary spans of text. (Because we use the
|
| 387 |
+
# sentence boundaries for the "next sentence prediction" task).
|
| 388 |
+
# (2) Blank lines between documents. Document boundaries are needed so
|
| 389 |
+
# that the "next sentence prediction" task doesn't span between documents.
|
| 390 |
+
#
|
| 391 |
+
# Example:
|
| 392 |
+
# I am very happy.
|
| 393 |
+
# Here is the second sentence.
|
| 394 |
+
#
|
| 395 |
+
# A new document.
|
| 396 |
+
|
| 397 |
+
with FileLock(lock_path):
|
| 398 |
+
if os.path.exists(cached_features_file) and not overwrite_cache:
|
| 399 |
+
start = time.time()
|
| 400 |
+
with open(cached_features_file, "rb") as handle:
|
| 401 |
+
self.examples = pickle.load(handle)
|
| 402 |
+
logger.info(
|
| 403 |
+
f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
|
| 404 |
+
)
|
| 405 |
+
else:
|
| 406 |
+
logger.info(f"Creating features from dataset file at {directory}")
|
| 407 |
+
|
| 408 |
+
self.documents = [[]]
|
| 409 |
+
with open(file_path, encoding="utf-8") as f:
|
| 410 |
+
while True:
|
| 411 |
+
line = f.readline()
|
| 412 |
+
if not line:
|
| 413 |
+
break
|
| 414 |
+
line = line.strip()
|
| 415 |
+
|
| 416 |
+
# Empty lines are used as document delimiters
|
| 417 |
+
if not line and len(self.documents[-1]) != 0:
|
| 418 |
+
self.documents.append([])
|
| 419 |
+
tokens = tokenizer.tokenize(line)
|
| 420 |
+
tokens = tokenizer.convert_tokens_to_ids(tokens)
|
| 421 |
+
if tokens:
|
| 422 |
+
self.documents[-1].append(tokens)
|
| 423 |
+
|
| 424 |
+
logger.info(f"Creating examples from {len(self.documents)} documents.")
|
| 425 |
+
self.examples = []
|
| 426 |
+
for doc_index, document in enumerate(self.documents):
|
| 427 |
+
self.create_examples_from_document(document, doc_index, block_size)
|
| 428 |
+
|
| 429 |
+
start = time.time()
|
| 430 |
+
with open(cached_features_file, "wb") as handle:
|
| 431 |
+
pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
| 432 |
+
logger.info(
|
| 433 |
+
f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
def create_examples_from_document(self, document: List[List[int]], doc_index: int, block_size: int):
|
| 437 |
+
"""Creates examples for a single document."""
|
| 438 |
+
|
| 439 |
+
max_num_tokens = block_size - self.tokenizer.num_special_tokens_to_add(pair=True)
|
| 440 |
+
|
| 441 |
+
# We *usually* want to fill up the entire sequence since we are padding
|
| 442 |
+
# to `block_size` anyways, so short sequences are generally wasted
|
| 443 |
+
# computation. However, we *sometimes*
|
| 444 |
+
# (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
|
| 445 |
+
# sequences to minimize the mismatch between pretraining and fine-tuning.
|
| 446 |
+
# The `target_seq_length` is just a rough target however, whereas
|
| 447 |
+
# `block_size` is a hard limit.
|
| 448 |
+
target_seq_length = max_num_tokens
|
| 449 |
+
if random.random() < self.short_seq_probability:
|
| 450 |
+
target_seq_length = random.randint(2, max_num_tokens)
|
| 451 |
+
|
| 452 |
+
current_chunk = [] # a buffer stored current working segments
|
| 453 |
+
current_length = 0
|
| 454 |
+
i = 0
|
| 455 |
+
|
| 456 |
+
while i < len(document):
|
| 457 |
+
segment = document[i]
|
| 458 |
+
current_chunk.append(segment)
|
| 459 |
+
current_length += len(segment)
|
| 460 |
+
if i == len(document) - 1 or current_length >= target_seq_length:
|
| 461 |
+
if current_chunk:
|
| 462 |
+
# `a_end` is how many segments from `current_chunk` go into the `A`
|
| 463 |
+
# (first) sentence.
|
| 464 |
+
a_end = 1
|
| 465 |
+
if len(current_chunk) >= 2:
|
| 466 |
+
a_end = random.randint(1, len(current_chunk) - 1)
|
| 467 |
+
|
| 468 |
+
tokens_a = []
|
| 469 |
+
for j in range(a_end):
|
| 470 |
+
tokens_a.extend(current_chunk[j])
|
| 471 |
+
|
| 472 |
+
tokens_b = []
|
| 473 |
+
|
| 474 |
+
if len(current_chunk) == 1 or random.random() < self.nsp_probability:
|
| 475 |
+
is_random_next = True
|
| 476 |
+
target_b_length = target_seq_length - len(tokens_a)
|
| 477 |
+
|
| 478 |
+
# This should rarely go for more than one iteration for large
|
| 479 |
+
# corpora. However, just to be careful, we try to make sure that
|
| 480 |
+
# the random document is not the same as the document
|
| 481 |
+
# we're processing.
|
| 482 |
+
for _ in range(10):
|
| 483 |
+
random_document_index = random.randint(0, len(self.documents) - 1)
|
| 484 |
+
if random_document_index != doc_index:
|
| 485 |
+
break
|
| 486 |
+
|
| 487 |
+
random_document = self.documents[random_document_index]
|
| 488 |
+
random_start = random.randint(0, len(random_document) - 1)
|
| 489 |
+
for j in range(random_start, len(random_document)):
|
| 490 |
+
tokens_b.extend(random_document[j])
|
| 491 |
+
if len(tokens_b) >= target_b_length:
|
| 492 |
+
break
|
| 493 |
+
# We didn't actually use these segments so we "put them back" so
|
| 494 |
+
# they don't go to waste.
|
| 495 |
+
num_unused_segments = len(current_chunk) - a_end
|
| 496 |
+
i -= num_unused_segments
|
| 497 |
+
# Actual next
|
| 498 |
+
else:
|
| 499 |
+
is_random_next = False
|
| 500 |
+
for j in range(a_end, len(current_chunk)):
|
| 501 |
+
tokens_b.extend(current_chunk[j])
|
| 502 |
+
|
| 503 |
+
if not (len(tokens_a) >= 1):
|
| 504 |
+
raise ValueError(f"Length of sequence a is {len(tokens_a)} which must be no less than 1")
|
| 505 |
+
if not (len(tokens_b) >= 1):
|
| 506 |
+
raise ValueError(f"Length of sequence b is {len(tokens_b)} which must be no less than 1")
|
| 507 |
+
|
| 508 |
+
# add special tokens
|
| 509 |
+
input_ids = self.tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
|
| 510 |
+
# add token type ids, 0 for sentence a, 1 for sentence b
|
| 511 |
+
token_type_ids = self.tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)
|
| 512 |
+
|
| 513 |
+
example = {
|
| 514 |
+
"input_ids": torch.tensor(input_ids, dtype=torch.long),
|
| 515 |
+
"token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
|
| 516 |
+
"next_sentence_label": torch.tensor(1 if is_random_next else 0, dtype=torch.long),
|
| 517 |
+
}
|
| 518 |
+
|
| 519 |
+
self.examples.append(example)
|
| 520 |
+
|
| 521 |
+
current_chunk = []
|
| 522 |
+
current_length = 0
|
| 523 |
+
|
| 524 |
+
i += 1
|
| 525 |
+
|
| 526 |
+
def __len__(self):
|
| 527 |
+
return len(self.examples)
|
| 528 |
+
|
| 529 |
+
def __getitem__(self, i):
|
| 530 |
+
return self.examples[i]
|
.venv/lib/python3.11/site-packages/transformers/data/datasets/squad.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import time
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from enum import Enum
|
| 19 |
+
from typing import Dict, List, Optional, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
from filelock import FileLock
|
| 23 |
+
from torch.utils.data import Dataset
|
| 24 |
+
|
| 25 |
+
from ...models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
| 26 |
+
from ...tokenization_utils import PreTrainedTokenizer
|
| 27 |
+
from ...utils import logging
|
| 28 |
+
from ..processors.squad import SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
logger = logging.get_logger(__name__)
|
| 32 |
+
|
| 33 |
+
MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
|
| 34 |
+
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclass
|
| 38 |
+
class SquadDataTrainingArguments:
|
| 39 |
+
"""
|
| 40 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
model_type: str = field(
|
| 44 |
+
default=None, metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_TYPES)}
|
| 45 |
+
)
|
| 46 |
+
data_dir: str = field(
|
| 47 |
+
default=None, metadata={"help": "The input data dir. Should contain the .json files for the SQuAD task."}
|
| 48 |
+
)
|
| 49 |
+
max_seq_length: int = field(
|
| 50 |
+
default=128,
|
| 51 |
+
metadata={
|
| 52 |
+
"help": (
|
| 53 |
+
"The maximum total input sequence length after tokenization. Sequences longer "
|
| 54 |
+
"than this will be truncated, sequences shorter will be padded."
|
| 55 |
+
)
|
| 56 |
+
},
|
| 57 |
+
)
|
| 58 |
+
doc_stride: int = field(
|
| 59 |
+
default=128,
|
| 60 |
+
metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."},
|
| 61 |
+
)
|
| 62 |
+
max_query_length: int = field(
|
| 63 |
+
default=64,
|
| 64 |
+
metadata={
|
| 65 |
+
"help": (
|
| 66 |
+
"The maximum number of tokens for the question. Questions longer than this will "
|
| 67 |
+
"be truncated to this length."
|
| 68 |
+
)
|
| 69 |
+
},
|
| 70 |
+
)
|
| 71 |
+
max_answer_length: int = field(
|
| 72 |
+
default=30,
|
| 73 |
+
metadata={
|
| 74 |
+
"help": (
|
| 75 |
+
"The maximum length of an answer that can be generated. This is needed because the start "
|
| 76 |
+
"and end predictions are not conditioned on one another."
|
| 77 |
+
)
|
| 78 |
+
},
|
| 79 |
+
)
|
| 80 |
+
overwrite_cache: bool = field(
|
| 81 |
+
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
| 82 |
+
)
|
| 83 |
+
version_2_with_negative: bool = field(
|
| 84 |
+
default=False, metadata={"help": "If true, the SQuAD examples contain some that do not have an answer."}
|
| 85 |
+
)
|
| 86 |
+
null_score_diff_threshold: float = field(
|
| 87 |
+
default=0.0, metadata={"help": "If null_score - best_non_null is greater than the threshold predict null."}
|
| 88 |
+
)
|
| 89 |
+
n_best_size: int = field(
|
| 90 |
+
default=20, metadata={"help": "If null_score - best_non_null is greater than the threshold predict null."}
|
| 91 |
+
)
|
| 92 |
+
lang_id: int = field(
|
| 93 |
+
default=0,
|
| 94 |
+
metadata={
|
| 95 |
+
"help": (
|
| 96 |
+
"language id of input for language-specific xlm models (see"
|
| 97 |
+
" tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)"
|
| 98 |
+
)
|
| 99 |
+
},
|
| 100 |
+
)
|
| 101 |
+
threads: int = field(default=1, metadata={"help": "multiple threads for converting example to features"})
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class Split(Enum):
|
| 105 |
+
train = "train"
|
| 106 |
+
dev = "dev"
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class SquadDataset(Dataset):
|
| 110 |
+
"""
|
| 111 |
+
This will be superseded by a framework-agnostic approach soon.
|
| 112 |
+
"""
|
| 113 |
+
|
| 114 |
+
args: SquadDataTrainingArguments
|
| 115 |
+
features: List[SquadFeatures]
|
| 116 |
+
mode: Split
|
| 117 |
+
is_language_sensitive: bool
|
| 118 |
+
|
| 119 |
+
def __init__(
|
| 120 |
+
self,
|
| 121 |
+
args: SquadDataTrainingArguments,
|
| 122 |
+
tokenizer: PreTrainedTokenizer,
|
| 123 |
+
limit_length: Optional[int] = None,
|
| 124 |
+
mode: Union[str, Split] = Split.train,
|
| 125 |
+
is_language_sensitive: Optional[bool] = False,
|
| 126 |
+
cache_dir: Optional[str] = None,
|
| 127 |
+
dataset_format: Optional[str] = "pt",
|
| 128 |
+
):
|
| 129 |
+
self.args = args
|
| 130 |
+
self.is_language_sensitive = is_language_sensitive
|
| 131 |
+
self.processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
|
| 132 |
+
if isinstance(mode, str):
|
| 133 |
+
try:
|
| 134 |
+
mode = Split[mode]
|
| 135 |
+
except KeyError:
|
| 136 |
+
raise KeyError("mode is not a valid split name")
|
| 137 |
+
self.mode = mode
|
| 138 |
+
# Load data features from cache or dataset file
|
| 139 |
+
version_tag = "v2" if args.version_2_with_negative else "v1"
|
| 140 |
+
cached_features_file = os.path.join(
|
| 141 |
+
cache_dir if cache_dir is not None else args.data_dir,
|
| 142 |
+
f"cached_{mode.value}_{tokenizer.__class__.__name__}_{args.max_seq_length}_{version_tag}",
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
# Make sure only the first process in distributed training processes the dataset,
|
| 146 |
+
# and the others will use the cache.
|
| 147 |
+
lock_path = cached_features_file + ".lock"
|
| 148 |
+
with FileLock(lock_path):
|
| 149 |
+
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
| 150 |
+
start = time.time()
|
| 151 |
+
self.old_features = torch.load(cached_features_file)
|
| 152 |
+
|
| 153 |
+
# Legacy cache files have only features, while new cache files
|
| 154 |
+
# will have dataset and examples also.
|
| 155 |
+
self.features = self.old_features["features"]
|
| 156 |
+
self.dataset = self.old_features.get("dataset", None)
|
| 157 |
+
self.examples = self.old_features.get("examples", None)
|
| 158 |
+
logger.info(
|
| 159 |
+
f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
if self.dataset is None or self.examples is None:
|
| 163 |
+
logger.warning(
|
| 164 |
+
f"Deleting cached file {cached_features_file} will allow dataset and examples to be cached in"
|
| 165 |
+
" future run"
|
| 166 |
+
)
|
| 167 |
+
else:
|
| 168 |
+
if mode == Split.dev:
|
| 169 |
+
self.examples = self.processor.get_dev_examples(args.data_dir)
|
| 170 |
+
else:
|
| 171 |
+
self.examples = self.processor.get_train_examples(args.data_dir)
|
| 172 |
+
|
| 173 |
+
self.features, self.dataset = squad_convert_examples_to_features(
|
| 174 |
+
examples=self.examples,
|
| 175 |
+
tokenizer=tokenizer,
|
| 176 |
+
max_seq_length=args.max_seq_length,
|
| 177 |
+
doc_stride=args.doc_stride,
|
| 178 |
+
max_query_length=args.max_query_length,
|
| 179 |
+
is_training=mode == Split.train,
|
| 180 |
+
threads=args.threads,
|
| 181 |
+
return_dataset=dataset_format,
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
start = time.time()
|
| 185 |
+
torch.save(
|
| 186 |
+
{"features": self.features, "dataset": self.dataset, "examples": self.examples},
|
| 187 |
+
cached_features_file,
|
| 188 |
+
)
|
| 189 |
+
# ^ This seems to take a lot of time so I want to investigate why and how we can improve.
|
| 190 |
+
logger.info(
|
| 191 |
+
f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
def __len__(self):
|
| 195 |
+
return len(self.features)
|
| 196 |
+
|
| 197 |
+
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
| 198 |
+
# Convert to Tensors and build dataset
|
| 199 |
+
feature = self.features[i]
|
| 200 |
+
|
| 201 |
+
input_ids = torch.tensor(feature.input_ids, dtype=torch.long)
|
| 202 |
+
attention_mask = torch.tensor(feature.attention_mask, dtype=torch.long)
|
| 203 |
+
token_type_ids = torch.tensor(feature.token_type_ids, dtype=torch.long)
|
| 204 |
+
cls_index = torch.tensor(feature.cls_index, dtype=torch.long)
|
| 205 |
+
p_mask = torch.tensor(feature.p_mask, dtype=torch.float)
|
| 206 |
+
is_impossible = torch.tensor(feature.is_impossible, dtype=torch.float)
|
| 207 |
+
|
| 208 |
+
inputs = {
|
| 209 |
+
"input_ids": input_ids,
|
| 210 |
+
"attention_mask": attention_mask,
|
| 211 |
+
"token_type_ids": token_type_ids,
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
if self.args.model_type in ["xlm", "roberta", "distilbert", "camembert"]:
|
| 215 |
+
del inputs["token_type_ids"]
|
| 216 |
+
|
| 217 |
+
if self.args.model_type in ["xlnet", "xlm"]:
|
| 218 |
+
inputs.update({"cls_index": cls_index, "p_mask": p_mask})
|
| 219 |
+
if self.args.version_2_with_negative:
|
| 220 |
+
inputs.update({"is_impossible": is_impossible})
|
| 221 |
+
if self.is_language_sensitive:
|
| 222 |
+
inputs.update({"langs": (torch.ones(input_ids.shape, dtype=torch.int64) * self.args.lang_id)})
|
| 223 |
+
|
| 224 |
+
if self.mode == Split.train:
|
| 225 |
+
start_positions = torch.tensor(feature.start_position, dtype=torch.long)
|
| 226 |
+
end_positions = torch.tensor(feature.end_position, dtype=torch.long)
|
| 227 |
+
inputs.update({"start_positions": start_positions, "end_positions": end_positions})
|
| 228 |
+
|
| 229 |
+
return inputs
|
.venv/lib/python3.11/site-packages/transformers/data/metrics/__init__.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 2 |
+
# you may not use this file except in compliance with the License.
|
| 3 |
+
# You may obtain a copy of the License at
|
| 4 |
+
#
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
#
|
| 7 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 8 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 9 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 10 |
+
# See the License for the specific language governing permissions and
|
| 11 |
+
# limitations under the License.
|
| 12 |
+
|
| 13 |
+
import warnings
|
| 14 |
+
|
| 15 |
+
from ...utils import is_sklearn_available, requires_backends
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
if is_sklearn_available():
|
| 19 |
+
from scipy.stats import pearsonr, spearmanr
|
| 20 |
+
from sklearn.metrics import f1_score, matthews_corrcoef
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
DEPRECATION_WARNING = (
|
| 24 |
+
"This metric will be removed from the library soon, metrics should be handled with the 🤗 Evaluate "
|
| 25 |
+
"library. You can have a look at this example script for pointers: "
|
| 26 |
+
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py"
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def simple_accuracy(preds, labels):
|
| 31 |
+
warnings.warn(DEPRECATION_WARNING, FutureWarning)
|
| 32 |
+
requires_backends(simple_accuracy, "sklearn")
|
| 33 |
+
return (preds == labels).mean()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def acc_and_f1(preds, labels):
|
| 37 |
+
warnings.warn(DEPRECATION_WARNING, FutureWarning)
|
| 38 |
+
requires_backends(acc_and_f1, "sklearn")
|
| 39 |
+
acc = simple_accuracy(preds, labels)
|
| 40 |
+
f1 = f1_score(y_true=labels, y_pred=preds)
|
| 41 |
+
return {
|
| 42 |
+
"acc": acc,
|
| 43 |
+
"f1": f1,
|
| 44 |
+
"acc_and_f1": (acc + f1) / 2,
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def pearson_and_spearman(preds, labels):
|
| 49 |
+
warnings.warn(DEPRECATION_WARNING, FutureWarning)
|
| 50 |
+
requires_backends(pearson_and_spearman, "sklearn")
|
| 51 |
+
pearson_corr = pearsonr(preds, labels)[0]
|
| 52 |
+
spearman_corr = spearmanr(preds, labels)[0]
|
| 53 |
+
return {
|
| 54 |
+
"pearson": pearson_corr,
|
| 55 |
+
"spearmanr": spearman_corr,
|
| 56 |
+
"corr": (pearson_corr + spearman_corr) / 2,
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def glue_compute_metrics(task_name, preds, labels):
|
| 61 |
+
warnings.warn(DEPRECATION_WARNING, FutureWarning)
|
| 62 |
+
requires_backends(glue_compute_metrics, "sklearn")
|
| 63 |
+
assert len(preds) == len(labels), f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}"
|
| 64 |
+
if task_name == "cola":
|
| 65 |
+
return {"mcc": matthews_corrcoef(labels, preds)}
|
| 66 |
+
elif task_name == "sst-2":
|
| 67 |
+
return {"acc": simple_accuracy(preds, labels)}
|
| 68 |
+
elif task_name == "mrpc":
|
| 69 |
+
return acc_and_f1(preds, labels)
|
| 70 |
+
elif task_name == "sts-b":
|
| 71 |
+
return pearson_and_spearman(preds, labels)
|
| 72 |
+
elif task_name == "qqp":
|
| 73 |
+
return acc_and_f1(preds, labels)
|
| 74 |
+
elif task_name == "mnli":
|
| 75 |
+
return {"mnli/acc": simple_accuracy(preds, labels)}
|
| 76 |
+
elif task_name == "mnli-mm":
|
| 77 |
+
return {"mnli-mm/acc": simple_accuracy(preds, labels)}
|
| 78 |
+
elif task_name == "qnli":
|
| 79 |
+
return {"acc": simple_accuracy(preds, labels)}
|
| 80 |
+
elif task_name == "rte":
|
| 81 |
+
return {"acc": simple_accuracy(preds, labels)}
|
| 82 |
+
elif task_name == "wnli":
|
| 83 |
+
return {"acc": simple_accuracy(preds, labels)}
|
| 84 |
+
elif task_name == "hans":
|
| 85 |
+
return {"acc": simple_accuracy(preds, labels)}
|
| 86 |
+
else:
|
| 87 |
+
raise KeyError(task_name)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def xnli_compute_metrics(task_name, preds, labels):
|
| 91 |
+
warnings.warn(DEPRECATION_WARNING, FutureWarning)
|
| 92 |
+
requires_backends(xnli_compute_metrics, "sklearn")
|
| 93 |
+
if len(preds) != len(labels):
|
| 94 |
+
raise ValueError(f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}")
|
| 95 |
+
if task_name == "xnli":
|
| 96 |
+
return {"acc": simple_accuracy(preds, labels)}
|
| 97 |
+
else:
|
| 98 |
+
raise KeyError(task_name)
|
.venv/lib/python3.11/site-packages/transformers/data/metrics/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (4.64 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/transformers/data/metrics/__pycache__/squad_metrics.cpython-311.pyc
ADDED
|
Binary file (31 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/transformers/data/metrics/squad_metrics.py
ADDED
|
@@ -0,0 +1,779 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
Very heavily inspired by the official evaluation script for SQuAD version 2.0 which was modified by XLNet authors to
|
| 16 |
+
update `find_best_threshold` scripts for SQuAD V2.0
|
| 17 |
+
|
| 18 |
+
In addition to basic functionality, we also compute additional statistics and plot precision-recall curves if an
|
| 19 |
+
additional na_prob.json file is provided. This file is expected to map question ID's to the model's predicted
|
| 20 |
+
probability that a question is unanswerable.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import collections
|
| 24 |
+
import json
|
| 25 |
+
import math
|
| 26 |
+
import re
|
| 27 |
+
import string
|
| 28 |
+
|
| 29 |
+
from ...models.bert import BasicTokenizer
|
| 30 |
+
from ...utils import logging
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
logger = logging.get_logger(__name__)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def normalize_answer(s):
|
| 37 |
+
"""Lower text and remove punctuation, articles and extra whitespace."""
|
| 38 |
+
|
| 39 |
+
def remove_articles(text):
|
| 40 |
+
regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
|
| 41 |
+
return re.sub(regex, " ", text)
|
| 42 |
+
|
| 43 |
+
def white_space_fix(text):
|
| 44 |
+
return " ".join(text.split())
|
| 45 |
+
|
| 46 |
+
def remove_punc(text):
|
| 47 |
+
exclude = set(string.punctuation)
|
| 48 |
+
return "".join(ch for ch in text if ch not in exclude)
|
| 49 |
+
|
| 50 |
+
def lower(text):
|
| 51 |
+
return text.lower()
|
| 52 |
+
|
| 53 |
+
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def get_tokens(s):
|
| 57 |
+
if not s:
|
| 58 |
+
return []
|
| 59 |
+
return normalize_answer(s).split()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def compute_exact(a_gold, a_pred):
|
| 63 |
+
return int(normalize_answer(a_gold) == normalize_answer(a_pred))
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def compute_f1(a_gold, a_pred):
|
| 67 |
+
gold_toks = get_tokens(a_gold)
|
| 68 |
+
pred_toks = get_tokens(a_pred)
|
| 69 |
+
common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
|
| 70 |
+
num_same = sum(common.values())
|
| 71 |
+
if len(gold_toks) == 0 or len(pred_toks) == 0:
|
| 72 |
+
# If either is no-answer, then F1 is 1 if they agree, 0 otherwise
|
| 73 |
+
return int(gold_toks == pred_toks)
|
| 74 |
+
if num_same == 0:
|
| 75 |
+
return 0
|
| 76 |
+
precision = 1.0 * num_same / len(pred_toks)
|
| 77 |
+
recall = 1.0 * num_same / len(gold_toks)
|
| 78 |
+
f1 = (2 * precision * recall) / (precision + recall)
|
| 79 |
+
return f1
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def get_raw_scores(examples, preds):
|
| 83 |
+
"""
|
| 84 |
+
Computes the exact and f1 scores from the examples and the model predictions
|
| 85 |
+
"""
|
| 86 |
+
exact_scores = {}
|
| 87 |
+
f1_scores = {}
|
| 88 |
+
|
| 89 |
+
for example in examples:
|
| 90 |
+
qas_id = example.qas_id
|
| 91 |
+
gold_answers = [answer["text"] for answer in example.answers if normalize_answer(answer["text"])]
|
| 92 |
+
|
| 93 |
+
if not gold_answers:
|
| 94 |
+
# For unanswerable questions, only correct answer is empty string
|
| 95 |
+
gold_answers = [""]
|
| 96 |
+
|
| 97 |
+
if qas_id not in preds:
|
| 98 |
+
print(f"Missing prediction for {qas_id}")
|
| 99 |
+
continue
|
| 100 |
+
|
| 101 |
+
prediction = preds[qas_id]
|
| 102 |
+
exact_scores[qas_id] = max(compute_exact(a, prediction) for a in gold_answers)
|
| 103 |
+
f1_scores[qas_id] = max(compute_f1(a, prediction) for a in gold_answers)
|
| 104 |
+
|
| 105 |
+
return exact_scores, f1_scores
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
|
| 109 |
+
new_scores = {}
|
| 110 |
+
for qid, s in scores.items():
|
| 111 |
+
pred_na = na_probs[qid] > na_prob_thresh
|
| 112 |
+
if pred_na:
|
| 113 |
+
new_scores[qid] = float(not qid_to_has_ans[qid])
|
| 114 |
+
else:
|
| 115 |
+
new_scores[qid] = s
|
| 116 |
+
return new_scores
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def make_eval_dict(exact_scores, f1_scores, qid_list=None):
|
| 120 |
+
if not qid_list:
|
| 121 |
+
total = len(exact_scores)
|
| 122 |
+
return collections.OrderedDict(
|
| 123 |
+
[
|
| 124 |
+
("exact", 100.0 * sum(exact_scores.values()) / total),
|
| 125 |
+
("f1", 100.0 * sum(f1_scores.values()) / total),
|
| 126 |
+
("total", total),
|
| 127 |
+
]
|
| 128 |
+
)
|
| 129 |
+
else:
|
| 130 |
+
total = len(qid_list)
|
| 131 |
+
return collections.OrderedDict(
|
| 132 |
+
[
|
| 133 |
+
("exact", 100.0 * sum(exact_scores[k] for k in qid_list) / total),
|
| 134 |
+
("f1", 100.0 * sum(f1_scores[k] for k in qid_list) / total),
|
| 135 |
+
("total", total),
|
| 136 |
+
]
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def merge_eval(main_eval, new_eval, prefix):
|
| 141 |
+
for k in new_eval:
|
| 142 |
+
main_eval[f"{prefix}_{k}"] = new_eval[k]
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
|
| 146 |
+
num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
|
| 147 |
+
cur_score = num_no_ans
|
| 148 |
+
best_score = cur_score
|
| 149 |
+
best_thresh = 0.0
|
| 150 |
+
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
|
| 151 |
+
for i, qid in enumerate(qid_list):
|
| 152 |
+
if qid not in scores:
|
| 153 |
+
continue
|
| 154 |
+
if qid_to_has_ans[qid]:
|
| 155 |
+
diff = scores[qid]
|
| 156 |
+
else:
|
| 157 |
+
if preds[qid]:
|
| 158 |
+
diff = -1
|
| 159 |
+
else:
|
| 160 |
+
diff = 0
|
| 161 |
+
cur_score += diff
|
| 162 |
+
if cur_score > best_score:
|
| 163 |
+
best_score = cur_score
|
| 164 |
+
best_thresh = na_probs[qid]
|
| 165 |
+
|
| 166 |
+
has_ans_score, has_ans_cnt = 0, 0
|
| 167 |
+
for qid in qid_list:
|
| 168 |
+
if not qid_to_has_ans[qid]:
|
| 169 |
+
continue
|
| 170 |
+
has_ans_cnt += 1
|
| 171 |
+
|
| 172 |
+
if qid not in scores:
|
| 173 |
+
continue
|
| 174 |
+
has_ans_score += scores[qid]
|
| 175 |
+
|
| 176 |
+
return 100.0 * best_score / len(scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
|
| 180 |
+
best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw, na_probs, qid_to_has_ans)
|
| 181 |
+
best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, f1_raw, na_probs, qid_to_has_ans)
|
| 182 |
+
main_eval["best_exact"] = best_exact
|
| 183 |
+
main_eval["best_exact_thresh"] = exact_thresh
|
| 184 |
+
main_eval["best_f1"] = best_f1
|
| 185 |
+
main_eval["best_f1_thresh"] = f1_thresh
|
| 186 |
+
main_eval["has_ans_exact"] = has_ans_exact
|
| 187 |
+
main_eval["has_ans_f1"] = has_ans_f1
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
|
| 191 |
+
num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
|
| 192 |
+
cur_score = num_no_ans
|
| 193 |
+
best_score = cur_score
|
| 194 |
+
best_thresh = 0.0
|
| 195 |
+
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
|
| 196 |
+
for _, qid in enumerate(qid_list):
|
| 197 |
+
if qid not in scores:
|
| 198 |
+
continue
|
| 199 |
+
if qid_to_has_ans[qid]:
|
| 200 |
+
diff = scores[qid]
|
| 201 |
+
else:
|
| 202 |
+
if preds[qid]:
|
| 203 |
+
diff = -1
|
| 204 |
+
else:
|
| 205 |
+
diff = 0
|
| 206 |
+
cur_score += diff
|
| 207 |
+
if cur_score > best_score:
|
| 208 |
+
best_score = cur_score
|
| 209 |
+
best_thresh = na_probs[qid]
|
| 210 |
+
return 100.0 * best_score / len(scores), best_thresh
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
|
| 214 |
+
best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
|
| 215 |
+
best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
|
| 216 |
+
|
| 217 |
+
main_eval["best_exact"] = best_exact
|
| 218 |
+
main_eval["best_exact_thresh"] = exact_thresh
|
| 219 |
+
main_eval["best_f1"] = best_f1
|
| 220 |
+
main_eval["best_f1_thresh"] = f1_thresh
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_threshold=1.0):
|
| 224 |
+
qas_id_to_has_answer = {example.qas_id: bool(example.answers) for example in examples}
|
| 225 |
+
has_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if has_answer]
|
| 226 |
+
no_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if not has_answer]
|
| 227 |
+
|
| 228 |
+
if no_answer_probs is None:
|
| 229 |
+
no_answer_probs = {k: 0.0 for k in preds}
|
| 230 |
+
|
| 231 |
+
exact, f1 = get_raw_scores(examples, preds)
|
| 232 |
+
|
| 233 |
+
exact_threshold = apply_no_ans_threshold(
|
| 234 |
+
exact, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold
|
| 235 |
+
)
|
| 236 |
+
f1_threshold = apply_no_ans_threshold(f1, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold)
|
| 237 |
+
|
| 238 |
+
evaluation = make_eval_dict(exact_threshold, f1_threshold)
|
| 239 |
+
|
| 240 |
+
if has_answer_qids:
|
| 241 |
+
has_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=has_answer_qids)
|
| 242 |
+
merge_eval(evaluation, has_ans_eval, "HasAns")
|
| 243 |
+
|
| 244 |
+
if no_answer_qids:
|
| 245 |
+
no_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=no_answer_qids)
|
| 246 |
+
merge_eval(evaluation, no_ans_eval, "NoAns")
|
| 247 |
+
|
| 248 |
+
if no_answer_probs:
|
| 249 |
+
find_all_best_thresh(evaluation, preds, exact, f1, no_answer_probs, qas_id_to_has_answer)
|
| 250 |
+
|
| 251 |
+
return evaluation
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
| 255 |
+
"""Project the tokenized prediction back to the original text."""
|
| 256 |
+
|
| 257 |
+
# When we created the data, we kept track of the alignment between original
|
| 258 |
+
# (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
|
| 259 |
+
# now `orig_text` contains the span of our original text corresponding to the
|
| 260 |
+
# span that we predicted.
|
| 261 |
+
#
|
| 262 |
+
# However, `orig_text` may contain extra characters that we don't want in
|
| 263 |
+
# our prediction.
|
| 264 |
+
#
|
| 265 |
+
# For example, let's say:
|
| 266 |
+
# pred_text = steve smith
|
| 267 |
+
# orig_text = Steve Smith's
|
| 268 |
+
#
|
| 269 |
+
# We don't want to return `orig_text` because it contains the extra "'s".
|
| 270 |
+
#
|
| 271 |
+
# We don't want to return `pred_text` because it's already been normalized
|
| 272 |
+
# (the SQuAD eval script also does punctuation stripping/lower casing but
|
| 273 |
+
# our tokenizer does additional normalization like stripping accent
|
| 274 |
+
# characters).
|
| 275 |
+
#
|
| 276 |
+
# What we really want to return is "Steve Smith".
|
| 277 |
+
#
|
| 278 |
+
# Therefore, we have to apply a semi-complicated alignment heuristic between
|
| 279 |
+
# `pred_text` and `orig_text` to get a character-to-character alignment. This
|
| 280 |
+
# can fail in certain cases in which case we just return `orig_text`.
|
| 281 |
+
|
| 282 |
+
def _strip_spaces(text):
|
| 283 |
+
ns_chars = []
|
| 284 |
+
ns_to_s_map = collections.OrderedDict()
|
| 285 |
+
for i, c in enumerate(text):
|
| 286 |
+
if c == " ":
|
| 287 |
+
continue
|
| 288 |
+
ns_to_s_map[len(ns_chars)] = i
|
| 289 |
+
ns_chars.append(c)
|
| 290 |
+
ns_text = "".join(ns_chars)
|
| 291 |
+
return (ns_text, ns_to_s_map)
|
| 292 |
+
|
| 293 |
+
# We first tokenize `orig_text`, strip whitespace from the result
|
| 294 |
+
# and `pred_text`, and check if they are the same length. If they are
|
| 295 |
+
# NOT the same length, the heuristic has failed. If they are the same
|
| 296 |
+
# length, we assume the characters are one-to-one aligned.
|
| 297 |
+
tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
|
| 298 |
+
|
| 299 |
+
tok_text = " ".join(tokenizer.tokenize(orig_text))
|
| 300 |
+
|
| 301 |
+
start_position = tok_text.find(pred_text)
|
| 302 |
+
if start_position == -1:
|
| 303 |
+
if verbose_logging:
|
| 304 |
+
logger.info(f"Unable to find text: '{pred_text}' in '{orig_text}'")
|
| 305 |
+
return orig_text
|
| 306 |
+
end_position = start_position + len(pred_text) - 1
|
| 307 |
+
|
| 308 |
+
(orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
|
| 309 |
+
(tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
|
| 310 |
+
|
| 311 |
+
if len(orig_ns_text) != len(tok_ns_text):
|
| 312 |
+
if verbose_logging:
|
| 313 |
+
logger.info(f"Length not equal after stripping spaces: '{orig_ns_text}' vs '{tok_ns_text}'")
|
| 314 |
+
return orig_text
|
| 315 |
+
|
| 316 |
+
# We then project the characters in `pred_text` back to `orig_text` using
|
| 317 |
+
# the character-to-character alignment.
|
| 318 |
+
tok_s_to_ns_map = {}
|
| 319 |
+
for i, tok_index in tok_ns_to_s_map.items():
|
| 320 |
+
tok_s_to_ns_map[tok_index] = i
|
| 321 |
+
|
| 322 |
+
orig_start_position = None
|
| 323 |
+
if start_position in tok_s_to_ns_map:
|
| 324 |
+
ns_start_position = tok_s_to_ns_map[start_position]
|
| 325 |
+
if ns_start_position in orig_ns_to_s_map:
|
| 326 |
+
orig_start_position = orig_ns_to_s_map[ns_start_position]
|
| 327 |
+
|
| 328 |
+
if orig_start_position is None:
|
| 329 |
+
if verbose_logging:
|
| 330 |
+
logger.info("Couldn't map start position")
|
| 331 |
+
return orig_text
|
| 332 |
+
|
| 333 |
+
orig_end_position = None
|
| 334 |
+
if end_position in tok_s_to_ns_map:
|
| 335 |
+
ns_end_position = tok_s_to_ns_map[end_position]
|
| 336 |
+
if ns_end_position in orig_ns_to_s_map:
|
| 337 |
+
orig_end_position = orig_ns_to_s_map[ns_end_position]
|
| 338 |
+
|
| 339 |
+
if orig_end_position is None:
|
| 340 |
+
if verbose_logging:
|
| 341 |
+
logger.info("Couldn't map end position")
|
| 342 |
+
return orig_text
|
| 343 |
+
|
| 344 |
+
output_text = orig_text[orig_start_position : (orig_end_position + 1)]
|
| 345 |
+
return output_text
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def _get_best_indexes(logits, n_best_size):
|
| 349 |
+
"""Get the n-best logits from a list."""
|
| 350 |
+
index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
|
| 351 |
+
|
| 352 |
+
best_indexes = []
|
| 353 |
+
for i in range(len(index_and_score)):
|
| 354 |
+
if i >= n_best_size:
|
| 355 |
+
break
|
| 356 |
+
best_indexes.append(index_and_score[i][0])
|
| 357 |
+
return best_indexes
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def _compute_softmax(scores):
|
| 361 |
+
"""Compute softmax probability over raw logits."""
|
| 362 |
+
if not scores:
|
| 363 |
+
return []
|
| 364 |
+
|
| 365 |
+
max_score = None
|
| 366 |
+
for score in scores:
|
| 367 |
+
if max_score is None or score > max_score:
|
| 368 |
+
max_score = score
|
| 369 |
+
|
| 370 |
+
exp_scores = []
|
| 371 |
+
total_sum = 0.0
|
| 372 |
+
for score in scores:
|
| 373 |
+
x = math.exp(score - max_score)
|
| 374 |
+
exp_scores.append(x)
|
| 375 |
+
total_sum += x
|
| 376 |
+
|
| 377 |
+
probs = []
|
| 378 |
+
for score in exp_scores:
|
| 379 |
+
probs.append(score / total_sum)
|
| 380 |
+
return probs
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def compute_predictions_logits(
|
| 384 |
+
all_examples,
|
| 385 |
+
all_features,
|
| 386 |
+
all_results,
|
| 387 |
+
n_best_size,
|
| 388 |
+
max_answer_length,
|
| 389 |
+
do_lower_case,
|
| 390 |
+
output_prediction_file,
|
| 391 |
+
output_nbest_file,
|
| 392 |
+
output_null_log_odds_file,
|
| 393 |
+
verbose_logging,
|
| 394 |
+
version_2_with_negative,
|
| 395 |
+
null_score_diff_threshold,
|
| 396 |
+
tokenizer,
|
| 397 |
+
):
|
| 398 |
+
"""Write final predictions to the json file and log-odds of null if needed."""
|
| 399 |
+
if output_prediction_file:
|
| 400 |
+
logger.info(f"Writing predictions to: {output_prediction_file}")
|
| 401 |
+
if output_nbest_file:
|
| 402 |
+
logger.info(f"Writing nbest to: {output_nbest_file}")
|
| 403 |
+
if output_null_log_odds_file and version_2_with_negative:
|
| 404 |
+
logger.info(f"Writing null_log_odds to: {output_null_log_odds_file}")
|
| 405 |
+
|
| 406 |
+
example_index_to_features = collections.defaultdict(list)
|
| 407 |
+
for feature in all_features:
|
| 408 |
+
example_index_to_features[feature.example_index].append(feature)
|
| 409 |
+
|
| 410 |
+
unique_id_to_result = {}
|
| 411 |
+
for result in all_results:
|
| 412 |
+
unique_id_to_result[result.unique_id] = result
|
| 413 |
+
|
| 414 |
+
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
| 415 |
+
"PrelimPrediction", ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
all_predictions = collections.OrderedDict()
|
| 419 |
+
all_nbest_json = collections.OrderedDict()
|
| 420 |
+
scores_diff_json = collections.OrderedDict()
|
| 421 |
+
|
| 422 |
+
for example_index, example in enumerate(all_examples):
|
| 423 |
+
features = example_index_to_features[example_index]
|
| 424 |
+
|
| 425 |
+
prelim_predictions = []
|
| 426 |
+
# keep track of the minimum score of null start+end of position 0
|
| 427 |
+
score_null = 1000000 # large and positive
|
| 428 |
+
min_null_feature_index = 0 # the paragraph slice with min null score
|
| 429 |
+
null_start_logit = 0 # the start logit at the slice with min null score
|
| 430 |
+
null_end_logit = 0 # the end logit at the slice with min null score
|
| 431 |
+
for feature_index, feature in enumerate(features):
|
| 432 |
+
result = unique_id_to_result[feature.unique_id]
|
| 433 |
+
start_indexes = _get_best_indexes(result.start_logits, n_best_size)
|
| 434 |
+
end_indexes = _get_best_indexes(result.end_logits, n_best_size)
|
| 435 |
+
# if we could have irrelevant answers, get the min score of irrelevant
|
| 436 |
+
if version_2_with_negative:
|
| 437 |
+
feature_null_score = result.start_logits[0] + result.end_logits[0]
|
| 438 |
+
if feature_null_score < score_null:
|
| 439 |
+
score_null = feature_null_score
|
| 440 |
+
min_null_feature_index = feature_index
|
| 441 |
+
null_start_logit = result.start_logits[0]
|
| 442 |
+
null_end_logit = result.end_logits[0]
|
| 443 |
+
for start_index in start_indexes:
|
| 444 |
+
for end_index in end_indexes:
|
| 445 |
+
# We could hypothetically create invalid predictions, e.g., predict
|
| 446 |
+
# that the start of the span is in the question. We throw out all
|
| 447 |
+
# invalid predictions.
|
| 448 |
+
if start_index >= len(feature.tokens):
|
| 449 |
+
continue
|
| 450 |
+
if end_index >= len(feature.tokens):
|
| 451 |
+
continue
|
| 452 |
+
if start_index not in feature.token_to_orig_map:
|
| 453 |
+
continue
|
| 454 |
+
if end_index not in feature.token_to_orig_map:
|
| 455 |
+
continue
|
| 456 |
+
if not feature.token_is_max_context.get(start_index, False):
|
| 457 |
+
continue
|
| 458 |
+
if end_index < start_index:
|
| 459 |
+
continue
|
| 460 |
+
length = end_index - start_index + 1
|
| 461 |
+
if length > max_answer_length:
|
| 462 |
+
continue
|
| 463 |
+
prelim_predictions.append(
|
| 464 |
+
_PrelimPrediction(
|
| 465 |
+
feature_index=feature_index,
|
| 466 |
+
start_index=start_index,
|
| 467 |
+
end_index=end_index,
|
| 468 |
+
start_logit=result.start_logits[start_index],
|
| 469 |
+
end_logit=result.end_logits[end_index],
|
| 470 |
+
)
|
| 471 |
+
)
|
| 472 |
+
if version_2_with_negative:
|
| 473 |
+
prelim_predictions.append(
|
| 474 |
+
_PrelimPrediction(
|
| 475 |
+
feature_index=min_null_feature_index,
|
| 476 |
+
start_index=0,
|
| 477 |
+
end_index=0,
|
| 478 |
+
start_logit=null_start_logit,
|
| 479 |
+
end_logit=null_end_logit,
|
| 480 |
+
)
|
| 481 |
+
)
|
| 482 |
+
prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True)
|
| 483 |
+
|
| 484 |
+
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
| 485 |
+
"NbestPrediction", ["text", "start_logit", "end_logit"]
|
| 486 |
+
)
|
| 487 |
+
|
| 488 |
+
seen_predictions = {}
|
| 489 |
+
nbest = []
|
| 490 |
+
for pred in prelim_predictions:
|
| 491 |
+
if len(nbest) >= n_best_size:
|
| 492 |
+
break
|
| 493 |
+
feature = features[pred.feature_index]
|
| 494 |
+
if pred.start_index > 0: # this is a non-null prediction
|
| 495 |
+
tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
|
| 496 |
+
orig_doc_start = feature.token_to_orig_map[pred.start_index]
|
| 497 |
+
orig_doc_end = feature.token_to_orig_map[pred.end_index]
|
| 498 |
+
orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
|
| 499 |
+
|
| 500 |
+
tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
|
| 501 |
+
|
| 502 |
+
# tok_text = " ".join(tok_tokens)
|
| 503 |
+
#
|
| 504 |
+
# # De-tokenize WordPieces that have been split off.
|
| 505 |
+
# tok_text = tok_text.replace(" ##", "")
|
| 506 |
+
# tok_text = tok_text.replace("##", "")
|
| 507 |
+
|
| 508 |
+
# Clean whitespace
|
| 509 |
+
tok_text = tok_text.strip()
|
| 510 |
+
tok_text = " ".join(tok_text.split())
|
| 511 |
+
orig_text = " ".join(orig_tokens)
|
| 512 |
+
|
| 513 |
+
final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
|
| 514 |
+
if final_text in seen_predictions:
|
| 515 |
+
continue
|
| 516 |
+
|
| 517 |
+
seen_predictions[final_text] = True
|
| 518 |
+
else:
|
| 519 |
+
final_text = ""
|
| 520 |
+
seen_predictions[final_text] = True
|
| 521 |
+
|
| 522 |
+
nbest.append(_NbestPrediction(text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit))
|
| 523 |
+
# if we didn't include the empty option in the n-best, include it
|
| 524 |
+
if version_2_with_negative:
|
| 525 |
+
if "" not in seen_predictions:
|
| 526 |
+
nbest.append(_NbestPrediction(text="", start_logit=null_start_logit, end_logit=null_end_logit))
|
| 527 |
+
|
| 528 |
+
# In very rare edge cases we could only have single null prediction.
|
| 529 |
+
# So we just create a nonce prediction in this case to avoid failure.
|
| 530 |
+
if len(nbest) == 1:
|
| 531 |
+
nbest.insert(0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
| 532 |
+
|
| 533 |
+
# In very rare edge cases we could have no valid predictions. So we
|
| 534 |
+
# just create a nonce prediction in this case to avoid failure.
|
| 535 |
+
if not nbest:
|
| 536 |
+
nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
| 537 |
+
|
| 538 |
+
if len(nbest) < 1:
|
| 539 |
+
raise ValueError("No valid predictions")
|
| 540 |
+
|
| 541 |
+
total_scores = []
|
| 542 |
+
best_non_null_entry = None
|
| 543 |
+
for entry in nbest:
|
| 544 |
+
total_scores.append(entry.start_logit + entry.end_logit)
|
| 545 |
+
if not best_non_null_entry:
|
| 546 |
+
if entry.text:
|
| 547 |
+
best_non_null_entry = entry
|
| 548 |
+
|
| 549 |
+
probs = _compute_softmax(total_scores)
|
| 550 |
+
|
| 551 |
+
nbest_json = []
|
| 552 |
+
for i, entry in enumerate(nbest):
|
| 553 |
+
output = collections.OrderedDict()
|
| 554 |
+
output["text"] = entry.text
|
| 555 |
+
output["probability"] = probs[i]
|
| 556 |
+
output["start_logit"] = entry.start_logit
|
| 557 |
+
output["end_logit"] = entry.end_logit
|
| 558 |
+
nbest_json.append(output)
|
| 559 |
+
|
| 560 |
+
if len(nbest_json) < 1:
|
| 561 |
+
raise ValueError("No valid predictions")
|
| 562 |
+
|
| 563 |
+
if not version_2_with_negative:
|
| 564 |
+
all_predictions[example.qas_id] = nbest_json[0]["text"]
|
| 565 |
+
else:
|
| 566 |
+
# predict "" iff the null score - the score of best non-null > threshold
|
| 567 |
+
score_diff = score_null - best_non_null_entry.start_logit - (best_non_null_entry.end_logit)
|
| 568 |
+
scores_diff_json[example.qas_id] = score_diff
|
| 569 |
+
if score_diff > null_score_diff_threshold:
|
| 570 |
+
all_predictions[example.qas_id] = ""
|
| 571 |
+
else:
|
| 572 |
+
all_predictions[example.qas_id] = best_non_null_entry.text
|
| 573 |
+
all_nbest_json[example.qas_id] = nbest_json
|
| 574 |
+
|
| 575 |
+
if output_prediction_file:
|
| 576 |
+
with open(output_prediction_file, "w") as writer:
|
| 577 |
+
writer.write(json.dumps(all_predictions, indent=4) + "\n")
|
| 578 |
+
|
| 579 |
+
if output_nbest_file:
|
| 580 |
+
with open(output_nbest_file, "w") as writer:
|
| 581 |
+
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
|
| 582 |
+
|
| 583 |
+
if output_null_log_odds_file and version_2_with_negative:
|
| 584 |
+
with open(output_null_log_odds_file, "w") as writer:
|
| 585 |
+
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
|
| 586 |
+
|
| 587 |
+
return all_predictions
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
def compute_predictions_log_probs(
|
| 591 |
+
all_examples,
|
| 592 |
+
all_features,
|
| 593 |
+
all_results,
|
| 594 |
+
n_best_size,
|
| 595 |
+
max_answer_length,
|
| 596 |
+
output_prediction_file,
|
| 597 |
+
output_nbest_file,
|
| 598 |
+
output_null_log_odds_file,
|
| 599 |
+
start_n_top,
|
| 600 |
+
end_n_top,
|
| 601 |
+
version_2_with_negative,
|
| 602 |
+
tokenizer,
|
| 603 |
+
verbose_logging,
|
| 604 |
+
):
|
| 605 |
+
"""
|
| 606 |
+
XLNet write prediction logic (more complex than Bert's). Write final predictions to the json file and log-odds of
|
| 607 |
+
null if needed.
|
| 608 |
+
|
| 609 |
+
Requires utils_squad_evaluate.py
|
| 610 |
+
"""
|
| 611 |
+
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
| 612 |
+
"PrelimPrediction", ["feature_index", "start_index", "end_index", "start_log_prob", "end_log_prob"]
|
| 613 |
+
)
|
| 614 |
+
|
| 615 |
+
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
| 616 |
+
"NbestPrediction", ["text", "start_log_prob", "end_log_prob"]
|
| 617 |
+
)
|
| 618 |
+
|
| 619 |
+
logger.info(f"Writing predictions to: {output_prediction_file}")
|
| 620 |
+
|
| 621 |
+
example_index_to_features = collections.defaultdict(list)
|
| 622 |
+
for feature in all_features:
|
| 623 |
+
example_index_to_features[feature.example_index].append(feature)
|
| 624 |
+
|
| 625 |
+
unique_id_to_result = {}
|
| 626 |
+
for result in all_results:
|
| 627 |
+
unique_id_to_result[result.unique_id] = result
|
| 628 |
+
|
| 629 |
+
all_predictions = collections.OrderedDict()
|
| 630 |
+
all_nbest_json = collections.OrderedDict()
|
| 631 |
+
scores_diff_json = collections.OrderedDict()
|
| 632 |
+
|
| 633 |
+
for example_index, example in enumerate(all_examples):
|
| 634 |
+
features = example_index_to_features[example_index]
|
| 635 |
+
|
| 636 |
+
prelim_predictions = []
|
| 637 |
+
# keep track of the minimum score of null start+end of position 0
|
| 638 |
+
score_null = 1000000 # large and positive
|
| 639 |
+
|
| 640 |
+
for feature_index, feature in enumerate(features):
|
| 641 |
+
result = unique_id_to_result[feature.unique_id]
|
| 642 |
+
|
| 643 |
+
cur_null_score = result.cls_logits
|
| 644 |
+
|
| 645 |
+
# if we could have irrelevant answers, get the min score of irrelevant
|
| 646 |
+
score_null = min(score_null, cur_null_score)
|
| 647 |
+
|
| 648 |
+
for i in range(start_n_top):
|
| 649 |
+
for j in range(end_n_top):
|
| 650 |
+
start_log_prob = result.start_logits[i]
|
| 651 |
+
start_index = result.start_top_index[i]
|
| 652 |
+
|
| 653 |
+
j_index = i * end_n_top + j
|
| 654 |
+
|
| 655 |
+
end_log_prob = result.end_logits[j_index]
|
| 656 |
+
end_index = result.end_top_index[j_index]
|
| 657 |
+
|
| 658 |
+
# We could hypothetically create invalid predictions, e.g., predict
|
| 659 |
+
# that the start of the span is in the question. We throw out all
|
| 660 |
+
# invalid predictions.
|
| 661 |
+
if start_index >= feature.paragraph_len - 1:
|
| 662 |
+
continue
|
| 663 |
+
if end_index >= feature.paragraph_len - 1:
|
| 664 |
+
continue
|
| 665 |
+
|
| 666 |
+
if not feature.token_is_max_context.get(start_index, False):
|
| 667 |
+
continue
|
| 668 |
+
if end_index < start_index:
|
| 669 |
+
continue
|
| 670 |
+
length = end_index - start_index + 1
|
| 671 |
+
if length > max_answer_length:
|
| 672 |
+
continue
|
| 673 |
+
|
| 674 |
+
prelim_predictions.append(
|
| 675 |
+
_PrelimPrediction(
|
| 676 |
+
feature_index=feature_index,
|
| 677 |
+
start_index=start_index,
|
| 678 |
+
end_index=end_index,
|
| 679 |
+
start_log_prob=start_log_prob,
|
| 680 |
+
end_log_prob=end_log_prob,
|
| 681 |
+
)
|
| 682 |
+
)
|
| 683 |
+
|
| 684 |
+
prelim_predictions = sorted(
|
| 685 |
+
prelim_predictions, key=lambda x: (x.start_log_prob + x.end_log_prob), reverse=True
|
| 686 |
+
)
|
| 687 |
+
|
| 688 |
+
seen_predictions = {}
|
| 689 |
+
nbest = []
|
| 690 |
+
for pred in prelim_predictions:
|
| 691 |
+
if len(nbest) >= n_best_size:
|
| 692 |
+
break
|
| 693 |
+
feature = features[pred.feature_index]
|
| 694 |
+
|
| 695 |
+
# XLNet un-tokenizer
|
| 696 |
+
# Let's keep it simple for now and see if we need all this later.
|
| 697 |
+
#
|
| 698 |
+
# tok_start_to_orig_index = feature.tok_start_to_orig_index
|
| 699 |
+
# tok_end_to_orig_index = feature.tok_end_to_orig_index
|
| 700 |
+
# start_orig_pos = tok_start_to_orig_index[pred.start_index]
|
| 701 |
+
# end_orig_pos = tok_end_to_orig_index[pred.end_index]
|
| 702 |
+
# paragraph_text = example.paragraph_text
|
| 703 |
+
# final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
|
| 704 |
+
|
| 705 |
+
# Previously used Bert untokenizer
|
| 706 |
+
tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
|
| 707 |
+
orig_doc_start = feature.token_to_orig_map[pred.start_index]
|
| 708 |
+
orig_doc_end = feature.token_to_orig_map[pred.end_index]
|
| 709 |
+
orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
|
| 710 |
+
tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
|
| 711 |
+
|
| 712 |
+
# Clean whitespace
|
| 713 |
+
tok_text = tok_text.strip()
|
| 714 |
+
tok_text = " ".join(tok_text.split())
|
| 715 |
+
orig_text = " ".join(orig_tokens)
|
| 716 |
+
|
| 717 |
+
if hasattr(tokenizer, "do_lower_case"):
|
| 718 |
+
do_lower_case = tokenizer.do_lower_case
|
| 719 |
+
else:
|
| 720 |
+
do_lower_case = tokenizer.do_lowercase_and_remove_accent
|
| 721 |
+
|
| 722 |
+
final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
|
| 723 |
+
|
| 724 |
+
if final_text in seen_predictions:
|
| 725 |
+
continue
|
| 726 |
+
|
| 727 |
+
seen_predictions[final_text] = True
|
| 728 |
+
|
| 729 |
+
nbest.append(
|
| 730 |
+
_NbestPrediction(text=final_text, start_log_prob=pred.start_log_prob, end_log_prob=pred.end_log_prob)
|
| 731 |
+
)
|
| 732 |
+
|
| 733 |
+
# In very rare edge cases we could have no valid predictions. So we
|
| 734 |
+
# just create a nonce prediction in this case to avoid failure.
|
| 735 |
+
if not nbest:
|
| 736 |
+
nbest.append(_NbestPrediction(text="", start_log_prob=-1e6, end_log_prob=-1e6))
|
| 737 |
+
|
| 738 |
+
total_scores = []
|
| 739 |
+
best_non_null_entry = None
|
| 740 |
+
for entry in nbest:
|
| 741 |
+
total_scores.append(entry.start_log_prob + entry.end_log_prob)
|
| 742 |
+
if not best_non_null_entry:
|
| 743 |
+
best_non_null_entry = entry
|
| 744 |
+
|
| 745 |
+
probs = _compute_softmax(total_scores)
|
| 746 |
+
|
| 747 |
+
nbest_json = []
|
| 748 |
+
for i, entry in enumerate(nbest):
|
| 749 |
+
output = collections.OrderedDict()
|
| 750 |
+
output["text"] = entry.text
|
| 751 |
+
output["probability"] = probs[i]
|
| 752 |
+
output["start_log_prob"] = entry.start_log_prob
|
| 753 |
+
output["end_log_prob"] = entry.end_log_prob
|
| 754 |
+
nbest_json.append(output)
|
| 755 |
+
|
| 756 |
+
if len(nbest_json) < 1:
|
| 757 |
+
raise ValueError("No valid predictions")
|
| 758 |
+
if best_non_null_entry is None:
|
| 759 |
+
raise ValueError("No valid predictions")
|
| 760 |
+
|
| 761 |
+
score_diff = score_null
|
| 762 |
+
scores_diff_json[example.qas_id] = score_diff
|
| 763 |
+
# note(zhiliny): always predict best_non_null_entry
|
| 764 |
+
# and the evaluation script will search for the best threshold
|
| 765 |
+
all_predictions[example.qas_id] = best_non_null_entry.text
|
| 766 |
+
|
| 767 |
+
all_nbest_json[example.qas_id] = nbest_json
|
| 768 |
+
|
| 769 |
+
with open(output_prediction_file, "w") as writer:
|
| 770 |
+
writer.write(json.dumps(all_predictions, indent=4) + "\n")
|
| 771 |
+
|
| 772 |
+
with open(output_nbest_file, "w") as writer:
|
| 773 |
+
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
|
| 774 |
+
|
| 775 |
+
if version_2_with_negative:
|
| 776 |
+
with open(output_null_log_odds_file, "w") as writer:
|
| 777 |
+
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
|
| 778 |
+
|
| 779 |
+
return all_predictions
|
.venv/lib/python3.11/site-packages/transformers/data/processors/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .glue import glue_convert_examples_to_features, glue_output_modes, glue_processors, glue_tasks_num_labels
|
| 16 |
+
from .squad import SquadExample, SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features
|
| 17 |
+
from .utils import DataProcessor, InputExample, InputFeatures, SingleSentenceClassificationProcessor
|
| 18 |
+
from .xnli import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
|
.venv/lib/python3.11/site-packages/transformers/data/processors/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (893 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/transformers/data/processors/__pycache__/glue.cpython-311.pyc
ADDED
|
Binary file (35.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/transformers/data/processors/__pycache__/squad.cpython-311.pyc
ADDED
|
Binary file (34.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/transformers/data/processors/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (19.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/transformers/data/processors/__pycache__/xnli.cpython-311.pyc
ADDED
|
Binary file (4.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/transformers/data/processors/glue.py
ADDED
|
@@ -0,0 +1,643 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""GLUE processors and helpers"""
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
import warnings
|
| 20 |
+
from dataclasses import asdict
|
| 21 |
+
from enum import Enum
|
| 22 |
+
from typing import List, Optional, Union
|
| 23 |
+
|
| 24 |
+
from ...tokenization_utils import PreTrainedTokenizer
|
| 25 |
+
from ...utils import is_tf_available, logging
|
| 26 |
+
from .utils import DataProcessor, InputExample, InputFeatures
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
if is_tf_available():
|
| 30 |
+
import tensorflow as tf
|
| 31 |
+
|
| 32 |
+
logger = logging.get_logger(__name__)
|
| 33 |
+
|
| 34 |
+
DEPRECATION_WARNING = (
|
| 35 |
+
"This {0} will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
|
| 36 |
+
"library. You can have a look at this example script for pointers: "
|
| 37 |
+
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py"
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def glue_convert_examples_to_features(
|
| 42 |
+
examples: Union[List[InputExample], "tf.data.Dataset"],
|
| 43 |
+
tokenizer: PreTrainedTokenizer,
|
| 44 |
+
max_length: Optional[int] = None,
|
| 45 |
+
task=None,
|
| 46 |
+
label_list=None,
|
| 47 |
+
output_mode=None,
|
| 48 |
+
):
|
| 49 |
+
"""
|
| 50 |
+
Loads a data file into a list of `InputFeatures`
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
examples: List of `InputExamples` or `tf.data.Dataset` containing the examples.
|
| 54 |
+
tokenizer: Instance of a tokenizer that will tokenize the examples
|
| 55 |
+
max_length: Maximum example length. Defaults to the tokenizer's max_len
|
| 56 |
+
task: GLUE task
|
| 57 |
+
label_list: List of labels. Can be obtained from the processor using the `processor.get_labels()` method
|
| 58 |
+
output_mode: String indicating the output mode. Either `regression` or `classification`
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
If the `examples` input is a `tf.data.Dataset`, will return a `tf.data.Dataset` containing the task-specific
|
| 62 |
+
features. If the input is a list of `InputExamples`, will return a list of task-specific `InputFeatures` which
|
| 63 |
+
can be fed to the model.
|
| 64 |
+
|
| 65 |
+
"""
|
| 66 |
+
warnings.warn(DEPRECATION_WARNING.format("function"), FutureWarning)
|
| 67 |
+
if is_tf_available() and isinstance(examples, tf.data.Dataset):
|
| 68 |
+
if task is None:
|
| 69 |
+
raise ValueError("When calling glue_convert_examples_to_features from TF, the task parameter is required.")
|
| 70 |
+
return _tf_glue_convert_examples_to_features(examples, tokenizer, max_length=max_length, task=task)
|
| 71 |
+
return _glue_convert_examples_to_features(
|
| 72 |
+
examples, tokenizer, max_length=max_length, task=task, label_list=label_list, output_mode=output_mode
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
if is_tf_available():
|
| 77 |
+
|
| 78 |
+
def _tf_glue_convert_examples_to_features(
|
| 79 |
+
examples: tf.data.Dataset,
|
| 80 |
+
tokenizer: PreTrainedTokenizer,
|
| 81 |
+
task=str,
|
| 82 |
+
max_length: Optional[int] = None,
|
| 83 |
+
) -> tf.data.Dataset:
|
| 84 |
+
"""
|
| 85 |
+
Returns:
|
| 86 |
+
A `tf.data.Dataset` containing the task-specific features.
|
| 87 |
+
|
| 88 |
+
"""
|
| 89 |
+
processor = glue_processors[task]()
|
| 90 |
+
examples = [processor.tfds_map(processor.get_example_from_tensor_dict(example)) for example in examples]
|
| 91 |
+
features = glue_convert_examples_to_features(examples, tokenizer, max_length=max_length, task=task)
|
| 92 |
+
label_type = tf.float32 if task == "sts-b" else tf.int64
|
| 93 |
+
|
| 94 |
+
def gen():
|
| 95 |
+
for ex in features:
|
| 96 |
+
d = {k: v for k, v in asdict(ex).items() if v is not None}
|
| 97 |
+
label = d.pop("label")
|
| 98 |
+
yield (d, label)
|
| 99 |
+
|
| 100 |
+
input_names = tokenizer.model_input_names
|
| 101 |
+
|
| 102 |
+
return tf.data.Dataset.from_generator(
|
| 103 |
+
gen,
|
| 104 |
+
({k: tf.int32 for k in input_names}, label_type),
|
| 105 |
+
({k: tf.TensorShape([None]) for k in input_names}, tf.TensorShape([])),
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def _glue_convert_examples_to_features(
|
| 110 |
+
examples: List[InputExample],
|
| 111 |
+
tokenizer: PreTrainedTokenizer,
|
| 112 |
+
max_length: Optional[int] = None,
|
| 113 |
+
task=None,
|
| 114 |
+
label_list=None,
|
| 115 |
+
output_mode=None,
|
| 116 |
+
):
|
| 117 |
+
if max_length is None:
|
| 118 |
+
max_length = tokenizer.model_max_length
|
| 119 |
+
|
| 120 |
+
if task is not None:
|
| 121 |
+
processor = glue_processors[task]()
|
| 122 |
+
if label_list is None:
|
| 123 |
+
label_list = processor.get_labels()
|
| 124 |
+
logger.info(f"Using label list {label_list} for task {task}")
|
| 125 |
+
if output_mode is None:
|
| 126 |
+
output_mode = glue_output_modes[task]
|
| 127 |
+
logger.info(f"Using output mode {output_mode} for task {task}")
|
| 128 |
+
|
| 129 |
+
label_map = {label: i for i, label in enumerate(label_list)}
|
| 130 |
+
|
| 131 |
+
def label_from_example(example: InputExample) -> Union[int, float, None]:
|
| 132 |
+
if example.label is None:
|
| 133 |
+
return None
|
| 134 |
+
if output_mode == "classification":
|
| 135 |
+
return label_map[example.label]
|
| 136 |
+
elif output_mode == "regression":
|
| 137 |
+
return float(example.label)
|
| 138 |
+
raise KeyError(output_mode)
|
| 139 |
+
|
| 140 |
+
labels = [label_from_example(example) for example in examples]
|
| 141 |
+
|
| 142 |
+
batch_encoding = tokenizer(
|
| 143 |
+
[(example.text_a, example.text_b) for example in examples],
|
| 144 |
+
max_length=max_length,
|
| 145 |
+
padding="max_length",
|
| 146 |
+
truncation=True,
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
features = []
|
| 150 |
+
for i in range(len(examples)):
|
| 151 |
+
inputs = {k: batch_encoding[k][i] for k in batch_encoding}
|
| 152 |
+
|
| 153 |
+
feature = InputFeatures(**inputs, label=labels[i])
|
| 154 |
+
features.append(feature)
|
| 155 |
+
|
| 156 |
+
for i, example in enumerate(examples[:5]):
|
| 157 |
+
logger.info("*** Example ***")
|
| 158 |
+
logger.info(f"guid: {example.guid}")
|
| 159 |
+
logger.info(f"features: {features[i]}")
|
| 160 |
+
|
| 161 |
+
return features
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class OutputMode(Enum):
|
| 165 |
+
classification = "classification"
|
| 166 |
+
regression = "regression"
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class MrpcProcessor(DataProcessor):
|
| 170 |
+
"""Processor for the MRPC data set (GLUE version)."""
|
| 171 |
+
|
| 172 |
+
def __init__(self, *args, **kwargs):
|
| 173 |
+
super().__init__(*args, **kwargs)
|
| 174 |
+
warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
|
| 175 |
+
|
| 176 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
| 177 |
+
"""See base class."""
|
| 178 |
+
return InputExample(
|
| 179 |
+
tensor_dict["idx"].numpy(),
|
| 180 |
+
tensor_dict["sentence1"].numpy().decode("utf-8"),
|
| 181 |
+
tensor_dict["sentence2"].numpy().decode("utf-8"),
|
| 182 |
+
str(tensor_dict["label"].numpy()),
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
def get_train_examples(self, data_dir):
|
| 186 |
+
"""See base class."""
|
| 187 |
+
logger.info(f"LOOKING AT {os.path.join(data_dir, 'train.tsv')}")
|
| 188 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
| 189 |
+
|
| 190 |
+
def get_dev_examples(self, data_dir):
|
| 191 |
+
"""See base class."""
|
| 192 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
| 193 |
+
|
| 194 |
+
def get_test_examples(self, data_dir):
|
| 195 |
+
"""See base class."""
|
| 196 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
| 197 |
+
|
| 198 |
+
def get_labels(self):
|
| 199 |
+
"""See base class."""
|
| 200 |
+
return ["0", "1"]
|
| 201 |
+
|
| 202 |
+
def _create_examples(self, lines, set_type):
|
| 203 |
+
"""Creates examples for the training, dev and test sets."""
|
| 204 |
+
examples = []
|
| 205 |
+
for i, line in enumerate(lines):
|
| 206 |
+
if i == 0:
|
| 207 |
+
continue
|
| 208 |
+
guid = f"{set_type}-{i}"
|
| 209 |
+
text_a = line[3]
|
| 210 |
+
text_b = line[4]
|
| 211 |
+
label = None if set_type == "test" else line[0]
|
| 212 |
+
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
| 213 |
+
return examples
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
class MnliProcessor(DataProcessor):
|
| 217 |
+
"""Processor for the MultiNLI data set (GLUE version)."""
|
| 218 |
+
|
| 219 |
+
def __init__(self, *args, **kwargs):
|
| 220 |
+
super().__init__(*args, **kwargs)
|
| 221 |
+
warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
|
| 222 |
+
|
| 223 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
| 224 |
+
"""See base class."""
|
| 225 |
+
return InputExample(
|
| 226 |
+
tensor_dict["idx"].numpy(),
|
| 227 |
+
tensor_dict["premise"].numpy().decode("utf-8"),
|
| 228 |
+
tensor_dict["hypothesis"].numpy().decode("utf-8"),
|
| 229 |
+
str(tensor_dict["label"].numpy()),
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
def get_train_examples(self, data_dir):
|
| 233 |
+
"""See base class."""
|
| 234 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
| 235 |
+
|
| 236 |
+
def get_dev_examples(self, data_dir):
|
| 237 |
+
"""See base class."""
|
| 238 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")
|
| 239 |
+
|
| 240 |
+
def get_test_examples(self, data_dir):
|
| 241 |
+
"""See base class."""
|
| 242 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test_matched")
|
| 243 |
+
|
| 244 |
+
def get_labels(self):
|
| 245 |
+
"""See base class."""
|
| 246 |
+
return ["contradiction", "entailment", "neutral"]
|
| 247 |
+
|
| 248 |
+
def _create_examples(self, lines, set_type):
|
| 249 |
+
"""Creates examples for the training, dev and test sets."""
|
| 250 |
+
examples = []
|
| 251 |
+
for i, line in enumerate(lines):
|
| 252 |
+
if i == 0:
|
| 253 |
+
continue
|
| 254 |
+
guid = f"{set_type}-{line[0]}"
|
| 255 |
+
text_a = line[8]
|
| 256 |
+
text_b = line[9]
|
| 257 |
+
label = None if set_type.startswith("test") else line[-1]
|
| 258 |
+
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
| 259 |
+
return examples
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
class MnliMismatchedProcessor(MnliProcessor):
|
| 263 |
+
"""Processor for the MultiNLI Mismatched data set (GLUE version)."""
|
| 264 |
+
|
| 265 |
+
def __init__(self, *args, **kwargs):
|
| 266 |
+
super().__init__(*args, **kwargs)
|
| 267 |
+
warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
|
| 268 |
+
|
| 269 |
+
def get_dev_examples(self, data_dir):
|
| 270 |
+
"""See base class."""
|
| 271 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_mismatched")
|
| 272 |
+
|
| 273 |
+
def get_test_examples(self, data_dir):
|
| 274 |
+
"""See base class."""
|
| 275 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test_mismatched")
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
class ColaProcessor(DataProcessor):
|
| 279 |
+
"""Processor for the CoLA data set (GLUE version)."""
|
| 280 |
+
|
| 281 |
+
def __init__(self, *args, **kwargs):
|
| 282 |
+
super().__init__(*args, **kwargs)
|
| 283 |
+
warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
|
| 284 |
+
|
| 285 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
| 286 |
+
"""See base class."""
|
| 287 |
+
return InputExample(
|
| 288 |
+
tensor_dict["idx"].numpy(),
|
| 289 |
+
tensor_dict["sentence"].numpy().decode("utf-8"),
|
| 290 |
+
None,
|
| 291 |
+
str(tensor_dict["label"].numpy()),
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
def get_train_examples(self, data_dir):
|
| 295 |
+
"""See base class."""
|
| 296 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
| 297 |
+
|
| 298 |
+
def get_dev_examples(self, data_dir):
|
| 299 |
+
"""See base class."""
|
| 300 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
| 301 |
+
|
| 302 |
+
def get_test_examples(self, data_dir):
|
| 303 |
+
"""See base class."""
|
| 304 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
| 305 |
+
|
| 306 |
+
def get_labels(self):
|
| 307 |
+
"""See base class."""
|
| 308 |
+
return ["0", "1"]
|
| 309 |
+
|
| 310 |
+
def _create_examples(self, lines, set_type):
|
| 311 |
+
"""Creates examples for the training, dev and test sets."""
|
| 312 |
+
test_mode = set_type == "test"
|
| 313 |
+
if test_mode:
|
| 314 |
+
lines = lines[1:]
|
| 315 |
+
text_index = 1 if test_mode else 3
|
| 316 |
+
examples = []
|
| 317 |
+
for i, line in enumerate(lines):
|
| 318 |
+
guid = f"{set_type}-{i}"
|
| 319 |
+
text_a = line[text_index]
|
| 320 |
+
label = None if test_mode else line[1]
|
| 321 |
+
examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
| 322 |
+
return examples
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
class Sst2Processor(DataProcessor):
|
| 326 |
+
"""Processor for the SST-2 data set (GLUE version)."""
|
| 327 |
+
|
| 328 |
+
def __init__(self, *args, **kwargs):
|
| 329 |
+
super().__init__(*args, **kwargs)
|
| 330 |
+
warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
|
| 331 |
+
|
| 332 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
| 333 |
+
"""See base class."""
|
| 334 |
+
return InputExample(
|
| 335 |
+
tensor_dict["idx"].numpy(),
|
| 336 |
+
tensor_dict["sentence"].numpy().decode("utf-8"),
|
| 337 |
+
None,
|
| 338 |
+
str(tensor_dict["label"].numpy()),
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
def get_train_examples(self, data_dir):
|
| 342 |
+
"""See base class."""
|
| 343 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
| 344 |
+
|
| 345 |
+
def get_dev_examples(self, data_dir):
|
| 346 |
+
"""See base class."""
|
| 347 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
| 348 |
+
|
| 349 |
+
def get_test_examples(self, data_dir):
|
| 350 |
+
"""See base class."""
|
| 351 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
| 352 |
+
|
| 353 |
+
def get_labels(self):
|
| 354 |
+
"""See base class."""
|
| 355 |
+
return ["0", "1"]
|
| 356 |
+
|
| 357 |
+
def _create_examples(self, lines, set_type):
|
| 358 |
+
"""Creates examples for the training, dev and test sets."""
|
| 359 |
+
examples = []
|
| 360 |
+
text_index = 1 if set_type == "test" else 0
|
| 361 |
+
for i, line in enumerate(lines):
|
| 362 |
+
if i == 0:
|
| 363 |
+
continue
|
| 364 |
+
guid = f"{set_type}-{i}"
|
| 365 |
+
text_a = line[text_index]
|
| 366 |
+
label = None if set_type == "test" else line[1]
|
| 367 |
+
examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
| 368 |
+
return examples
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
class StsbProcessor(DataProcessor):
|
| 372 |
+
"""Processor for the STS-B data set (GLUE version)."""
|
| 373 |
+
|
| 374 |
+
def __init__(self, *args, **kwargs):
|
| 375 |
+
super().__init__(*args, **kwargs)
|
| 376 |
+
warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
|
| 377 |
+
|
| 378 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
| 379 |
+
"""See base class."""
|
| 380 |
+
return InputExample(
|
| 381 |
+
tensor_dict["idx"].numpy(),
|
| 382 |
+
tensor_dict["sentence1"].numpy().decode("utf-8"),
|
| 383 |
+
tensor_dict["sentence2"].numpy().decode("utf-8"),
|
| 384 |
+
str(tensor_dict["label"].numpy()),
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
def get_train_examples(self, data_dir):
|
| 388 |
+
"""See base class."""
|
| 389 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
| 390 |
+
|
| 391 |
+
def get_dev_examples(self, data_dir):
|
| 392 |
+
"""See base class."""
|
| 393 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
| 394 |
+
|
| 395 |
+
def get_test_examples(self, data_dir):
|
| 396 |
+
"""See base class."""
|
| 397 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
| 398 |
+
|
| 399 |
+
def get_labels(self):
|
| 400 |
+
"""See base class."""
|
| 401 |
+
return [None]
|
| 402 |
+
|
| 403 |
+
def _create_examples(self, lines, set_type):
|
| 404 |
+
"""Creates examples for the training, dev and test sets."""
|
| 405 |
+
examples = []
|
| 406 |
+
for i, line in enumerate(lines):
|
| 407 |
+
if i == 0:
|
| 408 |
+
continue
|
| 409 |
+
guid = f"{set_type}-{line[0]}"
|
| 410 |
+
text_a = line[7]
|
| 411 |
+
text_b = line[8]
|
| 412 |
+
label = None if set_type == "test" else line[-1]
|
| 413 |
+
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
| 414 |
+
return examples
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
class QqpProcessor(DataProcessor):
|
| 418 |
+
"""Processor for the QQP data set (GLUE version)."""
|
| 419 |
+
|
| 420 |
+
def __init__(self, *args, **kwargs):
|
| 421 |
+
super().__init__(*args, **kwargs)
|
| 422 |
+
warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
|
| 423 |
+
|
| 424 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
| 425 |
+
"""See base class."""
|
| 426 |
+
return InputExample(
|
| 427 |
+
tensor_dict["idx"].numpy(),
|
| 428 |
+
tensor_dict["question1"].numpy().decode("utf-8"),
|
| 429 |
+
tensor_dict["question2"].numpy().decode("utf-8"),
|
| 430 |
+
str(tensor_dict["label"].numpy()),
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
def get_train_examples(self, data_dir):
|
| 434 |
+
"""See base class."""
|
| 435 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
| 436 |
+
|
| 437 |
+
def get_dev_examples(self, data_dir):
|
| 438 |
+
"""See base class."""
|
| 439 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
| 440 |
+
|
| 441 |
+
def get_test_examples(self, data_dir):
|
| 442 |
+
"""See base class."""
|
| 443 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
| 444 |
+
|
| 445 |
+
def get_labels(self):
|
| 446 |
+
"""See base class."""
|
| 447 |
+
return ["0", "1"]
|
| 448 |
+
|
| 449 |
+
def _create_examples(self, lines, set_type):
|
| 450 |
+
"""Creates examples for the training, dev and test sets."""
|
| 451 |
+
test_mode = set_type == "test"
|
| 452 |
+
q1_index = 1 if test_mode else 3
|
| 453 |
+
q2_index = 2 if test_mode else 4
|
| 454 |
+
examples = []
|
| 455 |
+
for i, line in enumerate(lines):
|
| 456 |
+
if i == 0:
|
| 457 |
+
continue
|
| 458 |
+
guid = f"{set_type}-{line[0]}"
|
| 459 |
+
try:
|
| 460 |
+
text_a = line[q1_index]
|
| 461 |
+
text_b = line[q2_index]
|
| 462 |
+
label = None if test_mode else line[5]
|
| 463 |
+
except IndexError:
|
| 464 |
+
continue
|
| 465 |
+
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
| 466 |
+
return examples
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
class QnliProcessor(DataProcessor):
|
| 470 |
+
"""Processor for the QNLI data set (GLUE version)."""
|
| 471 |
+
|
| 472 |
+
def __init__(self, *args, **kwargs):
|
| 473 |
+
super().__init__(*args, **kwargs)
|
| 474 |
+
warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
|
| 475 |
+
|
| 476 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
| 477 |
+
"""See base class."""
|
| 478 |
+
return InputExample(
|
| 479 |
+
tensor_dict["idx"].numpy(),
|
| 480 |
+
tensor_dict["question"].numpy().decode("utf-8"),
|
| 481 |
+
tensor_dict["sentence"].numpy().decode("utf-8"),
|
| 482 |
+
str(tensor_dict["label"].numpy()),
|
| 483 |
+
)
|
| 484 |
+
|
| 485 |
+
def get_train_examples(self, data_dir):
|
| 486 |
+
"""See base class."""
|
| 487 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
| 488 |
+
|
| 489 |
+
def get_dev_examples(self, data_dir):
|
| 490 |
+
"""See base class."""
|
| 491 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
| 492 |
+
|
| 493 |
+
def get_test_examples(self, data_dir):
|
| 494 |
+
"""See base class."""
|
| 495 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
| 496 |
+
|
| 497 |
+
def get_labels(self):
|
| 498 |
+
"""See base class."""
|
| 499 |
+
return ["entailment", "not_entailment"]
|
| 500 |
+
|
| 501 |
+
def _create_examples(self, lines, set_type):
|
| 502 |
+
"""Creates examples for the training, dev and test sets."""
|
| 503 |
+
examples = []
|
| 504 |
+
for i, line in enumerate(lines):
|
| 505 |
+
if i == 0:
|
| 506 |
+
continue
|
| 507 |
+
guid = f"{set_type}-{line[0]}"
|
| 508 |
+
text_a = line[1]
|
| 509 |
+
text_b = line[2]
|
| 510 |
+
label = None if set_type == "test" else line[-1]
|
| 511 |
+
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
| 512 |
+
return examples
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
class RteProcessor(DataProcessor):
|
| 516 |
+
"""Processor for the RTE data set (GLUE version)."""
|
| 517 |
+
|
| 518 |
+
def __init__(self, *args, **kwargs):
|
| 519 |
+
super().__init__(*args, **kwargs)
|
| 520 |
+
warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
|
| 521 |
+
|
| 522 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
| 523 |
+
"""See base class."""
|
| 524 |
+
return InputExample(
|
| 525 |
+
tensor_dict["idx"].numpy(),
|
| 526 |
+
tensor_dict["sentence1"].numpy().decode("utf-8"),
|
| 527 |
+
tensor_dict["sentence2"].numpy().decode("utf-8"),
|
| 528 |
+
str(tensor_dict["label"].numpy()),
|
| 529 |
+
)
|
| 530 |
+
|
| 531 |
+
def get_train_examples(self, data_dir):
|
| 532 |
+
"""See base class."""
|
| 533 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
| 534 |
+
|
| 535 |
+
def get_dev_examples(self, data_dir):
|
| 536 |
+
"""See base class."""
|
| 537 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
| 538 |
+
|
| 539 |
+
def get_test_examples(self, data_dir):
|
| 540 |
+
"""See base class."""
|
| 541 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
| 542 |
+
|
| 543 |
+
def get_labels(self):
|
| 544 |
+
"""See base class."""
|
| 545 |
+
return ["entailment", "not_entailment"]
|
| 546 |
+
|
| 547 |
+
def _create_examples(self, lines, set_type):
|
| 548 |
+
"""Creates examples for the training, dev and test sets."""
|
| 549 |
+
examples = []
|
| 550 |
+
for i, line in enumerate(lines):
|
| 551 |
+
if i == 0:
|
| 552 |
+
continue
|
| 553 |
+
guid = f"{set_type}-{line[0]}"
|
| 554 |
+
text_a = line[1]
|
| 555 |
+
text_b = line[2]
|
| 556 |
+
label = None if set_type == "test" else line[-1]
|
| 557 |
+
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
| 558 |
+
return examples
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
class WnliProcessor(DataProcessor):
|
| 562 |
+
"""Processor for the WNLI data set (GLUE version)."""
|
| 563 |
+
|
| 564 |
+
def __init__(self, *args, **kwargs):
|
| 565 |
+
super().__init__(*args, **kwargs)
|
| 566 |
+
warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
|
| 567 |
+
|
| 568 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
| 569 |
+
"""See base class."""
|
| 570 |
+
return InputExample(
|
| 571 |
+
tensor_dict["idx"].numpy(),
|
| 572 |
+
tensor_dict["sentence1"].numpy().decode("utf-8"),
|
| 573 |
+
tensor_dict["sentence2"].numpy().decode("utf-8"),
|
| 574 |
+
str(tensor_dict["label"].numpy()),
|
| 575 |
+
)
|
| 576 |
+
|
| 577 |
+
def get_train_examples(self, data_dir):
|
| 578 |
+
"""See base class."""
|
| 579 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
| 580 |
+
|
| 581 |
+
def get_dev_examples(self, data_dir):
|
| 582 |
+
"""See base class."""
|
| 583 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
| 584 |
+
|
| 585 |
+
def get_test_examples(self, data_dir):
|
| 586 |
+
"""See base class."""
|
| 587 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
| 588 |
+
|
| 589 |
+
def get_labels(self):
|
| 590 |
+
"""See base class."""
|
| 591 |
+
return ["0", "1"]
|
| 592 |
+
|
| 593 |
+
def _create_examples(self, lines, set_type):
|
| 594 |
+
"""Creates examples for the training, dev and test sets."""
|
| 595 |
+
examples = []
|
| 596 |
+
for i, line in enumerate(lines):
|
| 597 |
+
if i == 0:
|
| 598 |
+
continue
|
| 599 |
+
guid = f"{set_type}-{line[0]}"
|
| 600 |
+
text_a = line[1]
|
| 601 |
+
text_b = line[2]
|
| 602 |
+
label = None if set_type == "test" else line[-1]
|
| 603 |
+
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
| 604 |
+
return examples
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
glue_tasks_num_labels = {
|
| 608 |
+
"cola": 2,
|
| 609 |
+
"mnli": 3,
|
| 610 |
+
"mrpc": 2,
|
| 611 |
+
"sst-2": 2,
|
| 612 |
+
"sts-b": 1,
|
| 613 |
+
"qqp": 2,
|
| 614 |
+
"qnli": 2,
|
| 615 |
+
"rte": 2,
|
| 616 |
+
"wnli": 2,
|
| 617 |
+
}
|
| 618 |
+
|
| 619 |
+
glue_processors = {
|
| 620 |
+
"cola": ColaProcessor,
|
| 621 |
+
"mnli": MnliProcessor,
|
| 622 |
+
"mnli-mm": MnliMismatchedProcessor,
|
| 623 |
+
"mrpc": MrpcProcessor,
|
| 624 |
+
"sst-2": Sst2Processor,
|
| 625 |
+
"sts-b": StsbProcessor,
|
| 626 |
+
"qqp": QqpProcessor,
|
| 627 |
+
"qnli": QnliProcessor,
|
| 628 |
+
"rte": RteProcessor,
|
| 629 |
+
"wnli": WnliProcessor,
|
| 630 |
+
}
|
| 631 |
+
|
| 632 |
+
glue_output_modes = {
|
| 633 |
+
"cola": "classification",
|
| 634 |
+
"mnli": "classification",
|
| 635 |
+
"mnli-mm": "classification",
|
| 636 |
+
"mrpc": "classification",
|
| 637 |
+
"sst-2": "classification",
|
| 638 |
+
"sts-b": "regression",
|
| 639 |
+
"qqp": "classification",
|
| 640 |
+
"qnli": "classification",
|
| 641 |
+
"rte": "classification",
|
| 642 |
+
"wnli": "classification",
|
| 643 |
+
}
|
.venv/lib/python3.11/site-packages/transformers/data/processors/squad.py
ADDED
|
@@ -0,0 +1,845 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import os
|
| 17 |
+
from functools import partial
|
| 18 |
+
from multiprocessing import Pool, cpu_count
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
from tqdm import tqdm
|
| 22 |
+
|
| 23 |
+
from ...models.bert.tokenization_bert import whitespace_tokenize
|
| 24 |
+
from ...tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase, TruncationStrategy
|
| 25 |
+
from ...utils import is_tf_available, is_torch_available, logging
|
| 26 |
+
from .utils import DataProcessor
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Store the tokenizers which insert 2 separators tokens
|
| 30 |
+
MULTI_SEP_TOKENS_TOKENIZERS_SET = {"roberta", "camembert", "bart", "mpnet"}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
if is_torch_available():
|
| 34 |
+
import torch
|
| 35 |
+
from torch.utils.data import TensorDataset
|
| 36 |
+
|
| 37 |
+
if is_tf_available():
|
| 38 |
+
import tensorflow as tf
|
| 39 |
+
|
| 40 |
+
logger = logging.get_logger(__name__)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
|
| 44 |
+
"""Returns tokenized answer spans that better match the annotated answer."""
|
| 45 |
+
tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
|
| 46 |
+
|
| 47 |
+
for new_start in range(input_start, input_end + 1):
|
| 48 |
+
for new_end in range(input_end, new_start - 1, -1):
|
| 49 |
+
text_span = " ".join(doc_tokens[new_start : (new_end + 1)])
|
| 50 |
+
if text_span == tok_answer_text:
|
| 51 |
+
return (new_start, new_end)
|
| 52 |
+
|
| 53 |
+
return (input_start, input_end)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _check_is_max_context(doc_spans, cur_span_index, position):
|
| 57 |
+
"""Check if this is the 'max context' doc span for the token."""
|
| 58 |
+
best_score = None
|
| 59 |
+
best_span_index = None
|
| 60 |
+
for span_index, doc_span in enumerate(doc_spans):
|
| 61 |
+
end = doc_span.start + doc_span.length - 1
|
| 62 |
+
if position < doc_span.start:
|
| 63 |
+
continue
|
| 64 |
+
if position > end:
|
| 65 |
+
continue
|
| 66 |
+
num_left_context = position - doc_span.start
|
| 67 |
+
num_right_context = end - position
|
| 68 |
+
score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
|
| 69 |
+
if best_score is None or score > best_score:
|
| 70 |
+
best_score = score
|
| 71 |
+
best_span_index = span_index
|
| 72 |
+
|
| 73 |
+
return cur_span_index == best_span_index
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _new_check_is_max_context(doc_spans, cur_span_index, position):
|
| 77 |
+
"""Check if this is the 'max context' doc span for the token."""
|
| 78 |
+
# if len(doc_spans) == 1:
|
| 79 |
+
# return True
|
| 80 |
+
best_score = None
|
| 81 |
+
best_span_index = None
|
| 82 |
+
for span_index, doc_span in enumerate(doc_spans):
|
| 83 |
+
end = doc_span["start"] + doc_span["length"] - 1
|
| 84 |
+
if position < doc_span["start"]:
|
| 85 |
+
continue
|
| 86 |
+
if position > end:
|
| 87 |
+
continue
|
| 88 |
+
num_left_context = position - doc_span["start"]
|
| 89 |
+
num_right_context = end - position
|
| 90 |
+
score = min(num_left_context, num_right_context) + 0.01 * doc_span["length"]
|
| 91 |
+
if best_score is None or score > best_score:
|
| 92 |
+
best_score = score
|
| 93 |
+
best_span_index = span_index
|
| 94 |
+
|
| 95 |
+
return cur_span_index == best_span_index
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def _is_whitespace(c):
|
| 99 |
+
if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
|
| 100 |
+
return True
|
| 101 |
+
return False
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def squad_convert_example_to_features(
|
| 105 |
+
example, max_seq_length, doc_stride, max_query_length, padding_strategy, is_training
|
| 106 |
+
):
|
| 107 |
+
features = []
|
| 108 |
+
if is_training and not example.is_impossible:
|
| 109 |
+
# Get start and end position
|
| 110 |
+
start_position = example.start_position
|
| 111 |
+
end_position = example.end_position
|
| 112 |
+
|
| 113 |
+
# If the answer cannot be found in the text, then skip this example.
|
| 114 |
+
actual_text = " ".join(example.doc_tokens[start_position : (end_position + 1)])
|
| 115 |
+
cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text))
|
| 116 |
+
if actual_text.find(cleaned_answer_text) == -1:
|
| 117 |
+
logger.warning(f"Could not find answer: '{actual_text}' vs. '{cleaned_answer_text}'")
|
| 118 |
+
return []
|
| 119 |
+
|
| 120 |
+
tok_to_orig_index = []
|
| 121 |
+
orig_to_tok_index = []
|
| 122 |
+
all_doc_tokens = []
|
| 123 |
+
for i, token in enumerate(example.doc_tokens):
|
| 124 |
+
orig_to_tok_index.append(len(all_doc_tokens))
|
| 125 |
+
if tokenizer.__class__.__name__ in [
|
| 126 |
+
"RobertaTokenizer",
|
| 127 |
+
"LongformerTokenizer",
|
| 128 |
+
"BartTokenizer",
|
| 129 |
+
"RobertaTokenizerFast",
|
| 130 |
+
"LongformerTokenizerFast",
|
| 131 |
+
"BartTokenizerFast",
|
| 132 |
+
]:
|
| 133 |
+
sub_tokens = tokenizer.tokenize(token, add_prefix_space=True)
|
| 134 |
+
else:
|
| 135 |
+
sub_tokens = tokenizer.tokenize(token)
|
| 136 |
+
for sub_token in sub_tokens:
|
| 137 |
+
tok_to_orig_index.append(i)
|
| 138 |
+
all_doc_tokens.append(sub_token)
|
| 139 |
+
|
| 140 |
+
if is_training and not example.is_impossible:
|
| 141 |
+
tok_start_position = orig_to_tok_index[example.start_position]
|
| 142 |
+
if example.end_position < len(example.doc_tokens) - 1:
|
| 143 |
+
tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
|
| 144 |
+
else:
|
| 145 |
+
tok_end_position = len(all_doc_tokens) - 1
|
| 146 |
+
|
| 147 |
+
(tok_start_position, tok_end_position) = _improve_answer_span(
|
| 148 |
+
all_doc_tokens, tok_start_position, tok_end_position, tokenizer, example.answer_text
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
spans = []
|
| 152 |
+
|
| 153 |
+
truncated_query = tokenizer.encode(
|
| 154 |
+
example.question_text, add_special_tokens=False, truncation=True, max_length=max_query_length
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
# Tokenizers who insert 2 SEP tokens in-between <context> & <question> need to have special handling
|
| 158 |
+
# in the way they compute mask of added tokens.
|
| 159 |
+
tokenizer_type = type(tokenizer).__name__.replace("Tokenizer", "").lower()
|
| 160 |
+
sequence_added_tokens = (
|
| 161 |
+
tokenizer.model_max_length - tokenizer.max_len_single_sentence + 1
|
| 162 |
+
if tokenizer_type in MULTI_SEP_TOKENS_TOKENIZERS_SET
|
| 163 |
+
else tokenizer.model_max_length - tokenizer.max_len_single_sentence
|
| 164 |
+
)
|
| 165 |
+
sequence_pair_added_tokens = tokenizer.model_max_length - tokenizer.max_len_sentences_pair
|
| 166 |
+
|
| 167 |
+
span_doc_tokens = all_doc_tokens
|
| 168 |
+
while len(spans) * doc_stride < len(all_doc_tokens):
|
| 169 |
+
# Define the side we want to truncate / pad and the text/pair sorting
|
| 170 |
+
if tokenizer.padding_side == "right":
|
| 171 |
+
texts = truncated_query
|
| 172 |
+
pairs = span_doc_tokens
|
| 173 |
+
truncation = TruncationStrategy.ONLY_SECOND.value
|
| 174 |
+
else:
|
| 175 |
+
texts = span_doc_tokens
|
| 176 |
+
pairs = truncated_query
|
| 177 |
+
truncation = TruncationStrategy.ONLY_FIRST.value
|
| 178 |
+
|
| 179 |
+
encoded_dict = tokenizer.encode_plus( # TODO(thom) update this logic
|
| 180 |
+
texts,
|
| 181 |
+
pairs,
|
| 182 |
+
truncation=truncation,
|
| 183 |
+
padding=padding_strategy,
|
| 184 |
+
max_length=max_seq_length,
|
| 185 |
+
return_overflowing_tokens=True,
|
| 186 |
+
stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
|
| 187 |
+
return_token_type_ids=True,
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
paragraph_len = min(
|
| 191 |
+
len(all_doc_tokens) - len(spans) * doc_stride,
|
| 192 |
+
max_seq_length - len(truncated_query) - sequence_pair_added_tokens,
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
if tokenizer.pad_token_id in encoded_dict["input_ids"]:
|
| 196 |
+
if tokenizer.padding_side == "right":
|
| 197 |
+
non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)]
|
| 198 |
+
else:
|
| 199 |
+
last_padding_id_position = (
|
| 200 |
+
len(encoded_dict["input_ids"]) - 1 - encoded_dict["input_ids"][::-1].index(tokenizer.pad_token_id)
|
| 201 |
+
)
|
| 202 |
+
non_padded_ids = encoded_dict["input_ids"][last_padding_id_position + 1 :]
|
| 203 |
+
|
| 204 |
+
else:
|
| 205 |
+
non_padded_ids = encoded_dict["input_ids"]
|
| 206 |
+
|
| 207 |
+
tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
|
| 208 |
+
|
| 209 |
+
token_to_orig_map = {}
|
| 210 |
+
for i in range(paragraph_len):
|
| 211 |
+
index = len(truncated_query) + sequence_added_tokens + i if tokenizer.padding_side == "right" else i
|
| 212 |
+
token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i]
|
| 213 |
+
|
| 214 |
+
encoded_dict["paragraph_len"] = paragraph_len
|
| 215 |
+
encoded_dict["tokens"] = tokens
|
| 216 |
+
encoded_dict["token_to_orig_map"] = token_to_orig_map
|
| 217 |
+
encoded_dict["truncated_query_with_special_tokens_length"] = len(truncated_query) + sequence_added_tokens
|
| 218 |
+
encoded_dict["token_is_max_context"] = {}
|
| 219 |
+
encoded_dict["start"] = len(spans) * doc_stride
|
| 220 |
+
encoded_dict["length"] = paragraph_len
|
| 221 |
+
|
| 222 |
+
spans.append(encoded_dict)
|
| 223 |
+
|
| 224 |
+
if "overflowing_tokens" not in encoded_dict or (
|
| 225 |
+
"overflowing_tokens" in encoded_dict and len(encoded_dict["overflowing_tokens"]) == 0
|
| 226 |
+
):
|
| 227 |
+
break
|
| 228 |
+
span_doc_tokens = encoded_dict["overflowing_tokens"]
|
| 229 |
+
|
| 230 |
+
for doc_span_index in range(len(spans)):
|
| 231 |
+
for j in range(spans[doc_span_index]["paragraph_len"]):
|
| 232 |
+
is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
|
| 233 |
+
index = (
|
| 234 |
+
j
|
| 235 |
+
if tokenizer.padding_side == "left"
|
| 236 |
+
else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
|
| 237 |
+
)
|
| 238 |
+
spans[doc_span_index]["token_is_max_context"][index] = is_max_context
|
| 239 |
+
|
| 240 |
+
for span in spans:
|
| 241 |
+
# Identify the position of the CLS token
|
| 242 |
+
cls_index = span["input_ids"].index(tokenizer.cls_token_id)
|
| 243 |
+
|
| 244 |
+
# p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
|
| 245 |
+
# Original TF implementation also keep the classification token (set to 0)
|
| 246 |
+
p_mask = np.ones_like(span["token_type_ids"])
|
| 247 |
+
if tokenizer.padding_side == "right":
|
| 248 |
+
p_mask[len(truncated_query) + sequence_added_tokens :] = 0
|
| 249 |
+
else:
|
| 250 |
+
p_mask[-len(span["tokens"]) : -(len(truncated_query) + sequence_added_tokens)] = 0
|
| 251 |
+
|
| 252 |
+
pad_token_indices = np.where(np.atleast_1d(span["input_ids"] == tokenizer.pad_token_id))
|
| 253 |
+
special_token_indices = np.asarray(
|
| 254 |
+
tokenizer.get_special_tokens_mask(span["input_ids"], already_has_special_tokens=True)
|
| 255 |
+
).nonzero()
|
| 256 |
+
|
| 257 |
+
p_mask[pad_token_indices] = 1
|
| 258 |
+
p_mask[special_token_indices] = 1
|
| 259 |
+
|
| 260 |
+
# Set the cls index to 0: the CLS index can be used for impossible answers
|
| 261 |
+
p_mask[cls_index] = 0
|
| 262 |
+
|
| 263 |
+
span_is_impossible = example.is_impossible
|
| 264 |
+
start_position = 0
|
| 265 |
+
end_position = 0
|
| 266 |
+
if is_training and not span_is_impossible:
|
| 267 |
+
# For training, if our document chunk does not contain an annotation
|
| 268 |
+
# we throw it out, since there is nothing to predict.
|
| 269 |
+
doc_start = span["start"]
|
| 270 |
+
doc_end = span["start"] + span["length"] - 1
|
| 271 |
+
out_of_span = False
|
| 272 |
+
|
| 273 |
+
if not (tok_start_position >= doc_start and tok_end_position <= doc_end):
|
| 274 |
+
out_of_span = True
|
| 275 |
+
|
| 276 |
+
if out_of_span:
|
| 277 |
+
start_position = cls_index
|
| 278 |
+
end_position = cls_index
|
| 279 |
+
span_is_impossible = True
|
| 280 |
+
else:
|
| 281 |
+
if tokenizer.padding_side == "left":
|
| 282 |
+
doc_offset = 0
|
| 283 |
+
else:
|
| 284 |
+
doc_offset = len(truncated_query) + sequence_added_tokens
|
| 285 |
+
|
| 286 |
+
start_position = tok_start_position - doc_start + doc_offset
|
| 287 |
+
end_position = tok_end_position - doc_start + doc_offset
|
| 288 |
+
|
| 289 |
+
features.append(
|
| 290 |
+
SquadFeatures(
|
| 291 |
+
span["input_ids"],
|
| 292 |
+
span["attention_mask"],
|
| 293 |
+
span["token_type_ids"],
|
| 294 |
+
cls_index,
|
| 295 |
+
p_mask.tolist(),
|
| 296 |
+
example_index=0, # Can not set unique_id and example_index here. They will be set after multiple processing.
|
| 297 |
+
unique_id=0,
|
| 298 |
+
paragraph_len=span["paragraph_len"],
|
| 299 |
+
token_is_max_context=span["token_is_max_context"],
|
| 300 |
+
tokens=span["tokens"],
|
| 301 |
+
token_to_orig_map=span["token_to_orig_map"],
|
| 302 |
+
start_position=start_position,
|
| 303 |
+
end_position=end_position,
|
| 304 |
+
is_impossible=span_is_impossible,
|
| 305 |
+
qas_id=example.qas_id,
|
| 306 |
+
)
|
| 307 |
+
)
|
| 308 |
+
return features
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def squad_convert_example_to_features_init(tokenizer_for_convert: PreTrainedTokenizerBase):
|
| 312 |
+
global tokenizer
|
| 313 |
+
tokenizer = tokenizer_for_convert
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def squad_convert_examples_to_features(
|
| 317 |
+
examples,
|
| 318 |
+
tokenizer,
|
| 319 |
+
max_seq_length,
|
| 320 |
+
doc_stride,
|
| 321 |
+
max_query_length,
|
| 322 |
+
is_training,
|
| 323 |
+
padding_strategy="max_length",
|
| 324 |
+
return_dataset=False,
|
| 325 |
+
threads=1,
|
| 326 |
+
tqdm_enabled=True,
|
| 327 |
+
):
|
| 328 |
+
"""
|
| 329 |
+
Converts a list of examples into a list of features that can be directly given as input to a model. It is
|
| 330 |
+
model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs.
|
| 331 |
+
|
| 332 |
+
Args:
|
| 333 |
+
examples: list of [`~data.processors.squad.SquadExample`]
|
| 334 |
+
tokenizer: an instance of a child of [`PreTrainedTokenizer`]
|
| 335 |
+
max_seq_length: The maximum sequence length of the inputs.
|
| 336 |
+
doc_stride: The stride used when the context is too large and is split across several features.
|
| 337 |
+
max_query_length: The maximum length of the query.
|
| 338 |
+
is_training: whether to create features for model evaluation or model training.
|
| 339 |
+
padding_strategy: Default to "max_length". Which padding strategy to use
|
| 340 |
+
return_dataset: Default False. Either 'pt' or 'tf'.
|
| 341 |
+
if 'pt': returns a torch.data.TensorDataset, if 'tf': returns a tf.data.Dataset
|
| 342 |
+
threads: multiple processing threads.
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
Returns:
|
| 346 |
+
list of [`~data.processors.squad.SquadFeatures`]
|
| 347 |
+
|
| 348 |
+
Example:
|
| 349 |
+
|
| 350 |
+
```python
|
| 351 |
+
processor = SquadV2Processor()
|
| 352 |
+
examples = processor.get_dev_examples(data_dir)
|
| 353 |
+
|
| 354 |
+
features = squad_convert_examples_to_features(
|
| 355 |
+
examples=examples,
|
| 356 |
+
tokenizer=tokenizer,
|
| 357 |
+
max_seq_length=args.max_seq_length,
|
| 358 |
+
doc_stride=args.doc_stride,
|
| 359 |
+
max_query_length=args.max_query_length,
|
| 360 |
+
is_training=not evaluate,
|
| 361 |
+
)
|
| 362 |
+
```"""
|
| 363 |
+
# Defining helper methods
|
| 364 |
+
features = []
|
| 365 |
+
|
| 366 |
+
threads = min(threads, cpu_count())
|
| 367 |
+
with Pool(threads, initializer=squad_convert_example_to_features_init, initargs=(tokenizer,)) as p:
|
| 368 |
+
annotate_ = partial(
|
| 369 |
+
squad_convert_example_to_features,
|
| 370 |
+
max_seq_length=max_seq_length,
|
| 371 |
+
doc_stride=doc_stride,
|
| 372 |
+
max_query_length=max_query_length,
|
| 373 |
+
padding_strategy=padding_strategy,
|
| 374 |
+
is_training=is_training,
|
| 375 |
+
)
|
| 376 |
+
features = list(
|
| 377 |
+
tqdm(
|
| 378 |
+
p.imap(annotate_, examples, chunksize=32),
|
| 379 |
+
total=len(examples),
|
| 380 |
+
desc="convert squad examples to features",
|
| 381 |
+
disable=not tqdm_enabled,
|
| 382 |
+
)
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
new_features = []
|
| 386 |
+
unique_id = 1000000000
|
| 387 |
+
example_index = 0
|
| 388 |
+
for example_features in tqdm(
|
| 389 |
+
features, total=len(features), desc="add example index and unique id", disable=not tqdm_enabled
|
| 390 |
+
):
|
| 391 |
+
if not example_features:
|
| 392 |
+
continue
|
| 393 |
+
for example_feature in example_features:
|
| 394 |
+
example_feature.example_index = example_index
|
| 395 |
+
example_feature.unique_id = unique_id
|
| 396 |
+
new_features.append(example_feature)
|
| 397 |
+
unique_id += 1
|
| 398 |
+
example_index += 1
|
| 399 |
+
features = new_features
|
| 400 |
+
del new_features
|
| 401 |
+
if return_dataset == "pt":
|
| 402 |
+
if not is_torch_available():
|
| 403 |
+
raise RuntimeError("PyTorch must be installed to return a PyTorch dataset.")
|
| 404 |
+
|
| 405 |
+
# Convert to Tensors and build dataset
|
| 406 |
+
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
| 407 |
+
all_attention_masks = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
|
| 408 |
+
all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
|
| 409 |
+
all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
|
| 410 |
+
all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
|
| 411 |
+
all_is_impossible = torch.tensor([f.is_impossible for f in features], dtype=torch.float)
|
| 412 |
+
|
| 413 |
+
if not is_training:
|
| 414 |
+
all_feature_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
| 415 |
+
dataset = TensorDataset(
|
| 416 |
+
all_input_ids, all_attention_masks, all_token_type_ids, all_feature_index, all_cls_index, all_p_mask
|
| 417 |
+
)
|
| 418 |
+
else:
|
| 419 |
+
all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
|
| 420 |
+
all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
|
| 421 |
+
dataset = TensorDataset(
|
| 422 |
+
all_input_ids,
|
| 423 |
+
all_attention_masks,
|
| 424 |
+
all_token_type_ids,
|
| 425 |
+
all_start_positions,
|
| 426 |
+
all_end_positions,
|
| 427 |
+
all_cls_index,
|
| 428 |
+
all_p_mask,
|
| 429 |
+
all_is_impossible,
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
return features, dataset
|
| 433 |
+
elif return_dataset == "tf":
|
| 434 |
+
if not is_tf_available():
|
| 435 |
+
raise RuntimeError("TensorFlow must be installed to return a TensorFlow dataset.")
|
| 436 |
+
|
| 437 |
+
def gen():
|
| 438 |
+
for i, ex in enumerate(features):
|
| 439 |
+
if ex.token_type_ids is None:
|
| 440 |
+
yield (
|
| 441 |
+
{
|
| 442 |
+
"input_ids": ex.input_ids,
|
| 443 |
+
"attention_mask": ex.attention_mask,
|
| 444 |
+
"feature_index": i,
|
| 445 |
+
"qas_id": ex.qas_id,
|
| 446 |
+
},
|
| 447 |
+
{
|
| 448 |
+
"start_positions": ex.start_position,
|
| 449 |
+
"end_positions": ex.end_position,
|
| 450 |
+
"cls_index": ex.cls_index,
|
| 451 |
+
"p_mask": ex.p_mask,
|
| 452 |
+
"is_impossible": ex.is_impossible,
|
| 453 |
+
},
|
| 454 |
+
)
|
| 455 |
+
else:
|
| 456 |
+
yield (
|
| 457 |
+
{
|
| 458 |
+
"input_ids": ex.input_ids,
|
| 459 |
+
"attention_mask": ex.attention_mask,
|
| 460 |
+
"token_type_ids": ex.token_type_ids,
|
| 461 |
+
"feature_index": i,
|
| 462 |
+
"qas_id": ex.qas_id,
|
| 463 |
+
},
|
| 464 |
+
{
|
| 465 |
+
"start_positions": ex.start_position,
|
| 466 |
+
"end_positions": ex.end_position,
|
| 467 |
+
"cls_index": ex.cls_index,
|
| 468 |
+
"p_mask": ex.p_mask,
|
| 469 |
+
"is_impossible": ex.is_impossible,
|
| 470 |
+
},
|
| 471 |
+
)
|
| 472 |
+
|
| 473 |
+
# Why have we split the batch into a tuple? PyTorch just has a list of tensors.
|
| 474 |
+
if "token_type_ids" in tokenizer.model_input_names:
|
| 475 |
+
train_types = (
|
| 476 |
+
{
|
| 477 |
+
"input_ids": tf.int32,
|
| 478 |
+
"attention_mask": tf.int32,
|
| 479 |
+
"token_type_ids": tf.int32,
|
| 480 |
+
"feature_index": tf.int64,
|
| 481 |
+
"qas_id": tf.string,
|
| 482 |
+
},
|
| 483 |
+
{
|
| 484 |
+
"start_positions": tf.int64,
|
| 485 |
+
"end_positions": tf.int64,
|
| 486 |
+
"cls_index": tf.int64,
|
| 487 |
+
"p_mask": tf.int32,
|
| 488 |
+
"is_impossible": tf.int32,
|
| 489 |
+
},
|
| 490 |
+
)
|
| 491 |
+
|
| 492 |
+
train_shapes = (
|
| 493 |
+
{
|
| 494 |
+
"input_ids": tf.TensorShape([None]),
|
| 495 |
+
"attention_mask": tf.TensorShape([None]),
|
| 496 |
+
"token_type_ids": tf.TensorShape([None]),
|
| 497 |
+
"feature_index": tf.TensorShape([]),
|
| 498 |
+
"qas_id": tf.TensorShape([]),
|
| 499 |
+
},
|
| 500 |
+
{
|
| 501 |
+
"start_positions": tf.TensorShape([]),
|
| 502 |
+
"end_positions": tf.TensorShape([]),
|
| 503 |
+
"cls_index": tf.TensorShape([]),
|
| 504 |
+
"p_mask": tf.TensorShape([None]),
|
| 505 |
+
"is_impossible": tf.TensorShape([]),
|
| 506 |
+
},
|
| 507 |
+
)
|
| 508 |
+
else:
|
| 509 |
+
train_types = (
|
| 510 |
+
{"input_ids": tf.int32, "attention_mask": tf.int32, "feature_index": tf.int64, "qas_id": tf.string},
|
| 511 |
+
{
|
| 512 |
+
"start_positions": tf.int64,
|
| 513 |
+
"end_positions": tf.int64,
|
| 514 |
+
"cls_index": tf.int64,
|
| 515 |
+
"p_mask": tf.int32,
|
| 516 |
+
"is_impossible": tf.int32,
|
| 517 |
+
},
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
train_shapes = (
|
| 521 |
+
{
|
| 522 |
+
"input_ids": tf.TensorShape([None]),
|
| 523 |
+
"attention_mask": tf.TensorShape([None]),
|
| 524 |
+
"feature_index": tf.TensorShape([]),
|
| 525 |
+
"qas_id": tf.TensorShape([]),
|
| 526 |
+
},
|
| 527 |
+
{
|
| 528 |
+
"start_positions": tf.TensorShape([]),
|
| 529 |
+
"end_positions": tf.TensorShape([]),
|
| 530 |
+
"cls_index": tf.TensorShape([]),
|
| 531 |
+
"p_mask": tf.TensorShape([None]),
|
| 532 |
+
"is_impossible": tf.TensorShape([]),
|
| 533 |
+
},
|
| 534 |
+
)
|
| 535 |
+
|
| 536 |
+
return tf.data.Dataset.from_generator(gen, train_types, train_shapes)
|
| 537 |
+
else:
|
| 538 |
+
return features
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
class SquadProcessor(DataProcessor):
|
| 542 |
+
"""
|
| 543 |
+
Processor for the SQuAD data set. overridden by SquadV1Processor and SquadV2Processor, used by the version 1.1 and
|
| 544 |
+
version 2.0 of SQuAD, respectively.
|
| 545 |
+
"""
|
| 546 |
+
|
| 547 |
+
train_file = None
|
| 548 |
+
dev_file = None
|
| 549 |
+
|
| 550 |
+
def _get_example_from_tensor_dict(self, tensor_dict, evaluate=False):
|
| 551 |
+
if not evaluate:
|
| 552 |
+
answer = tensor_dict["answers"]["text"][0].numpy().decode("utf-8")
|
| 553 |
+
answer_start = tensor_dict["answers"]["answer_start"][0].numpy()
|
| 554 |
+
answers = []
|
| 555 |
+
else:
|
| 556 |
+
answers = [
|
| 557 |
+
{"answer_start": start.numpy(), "text": text.numpy().decode("utf-8")}
|
| 558 |
+
for start, text in zip(tensor_dict["answers"]["answer_start"], tensor_dict["answers"]["text"])
|
| 559 |
+
]
|
| 560 |
+
|
| 561 |
+
answer = None
|
| 562 |
+
answer_start = None
|
| 563 |
+
|
| 564 |
+
return SquadExample(
|
| 565 |
+
qas_id=tensor_dict["id"].numpy().decode("utf-8"),
|
| 566 |
+
question_text=tensor_dict["question"].numpy().decode("utf-8"),
|
| 567 |
+
context_text=tensor_dict["context"].numpy().decode("utf-8"),
|
| 568 |
+
answer_text=answer,
|
| 569 |
+
start_position_character=answer_start,
|
| 570 |
+
title=tensor_dict["title"].numpy().decode("utf-8"),
|
| 571 |
+
answers=answers,
|
| 572 |
+
)
|
| 573 |
+
|
| 574 |
+
def get_examples_from_dataset(self, dataset, evaluate=False):
|
| 575 |
+
"""
|
| 576 |
+
Creates a list of [`~data.processors.squad.SquadExample`] using a TFDS dataset.
|
| 577 |
+
|
| 578 |
+
Args:
|
| 579 |
+
dataset: The tfds dataset loaded from *tensorflow_datasets.load("squad")*
|
| 580 |
+
evaluate: Boolean specifying if in evaluation mode or in training mode
|
| 581 |
+
|
| 582 |
+
Returns:
|
| 583 |
+
List of SquadExample
|
| 584 |
+
|
| 585 |
+
Examples:
|
| 586 |
+
|
| 587 |
+
```python
|
| 588 |
+
>>> import tensorflow_datasets as tfds
|
| 589 |
+
|
| 590 |
+
>>> dataset = tfds.load("squad")
|
| 591 |
+
|
| 592 |
+
>>> training_examples = get_examples_from_dataset(dataset, evaluate=False)
|
| 593 |
+
>>> evaluation_examples = get_examples_from_dataset(dataset, evaluate=True)
|
| 594 |
+
```"""
|
| 595 |
+
|
| 596 |
+
if evaluate:
|
| 597 |
+
dataset = dataset["validation"]
|
| 598 |
+
else:
|
| 599 |
+
dataset = dataset["train"]
|
| 600 |
+
|
| 601 |
+
examples = []
|
| 602 |
+
for tensor_dict in tqdm(dataset):
|
| 603 |
+
examples.append(self._get_example_from_tensor_dict(tensor_dict, evaluate=evaluate))
|
| 604 |
+
|
| 605 |
+
return examples
|
| 606 |
+
|
| 607 |
+
def get_train_examples(self, data_dir, filename=None):
|
| 608 |
+
"""
|
| 609 |
+
Returns the training examples from the data directory.
|
| 610 |
+
|
| 611 |
+
Args:
|
| 612 |
+
data_dir: Directory containing the data files used for training and evaluating.
|
| 613 |
+
filename: None by default, specify this if the training file has a different name than the original one
|
| 614 |
+
which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively.
|
| 615 |
+
|
| 616 |
+
"""
|
| 617 |
+
if data_dir is None:
|
| 618 |
+
data_dir = ""
|
| 619 |
+
|
| 620 |
+
if self.train_file is None:
|
| 621 |
+
raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
|
| 622 |
+
|
| 623 |
+
with open(
|
| 624 |
+
os.path.join(data_dir, self.train_file if filename is None else filename), "r", encoding="utf-8"
|
| 625 |
+
) as reader:
|
| 626 |
+
input_data = json.load(reader)["data"]
|
| 627 |
+
return self._create_examples(input_data, "train")
|
| 628 |
+
|
| 629 |
+
def get_dev_examples(self, data_dir, filename=None):
|
| 630 |
+
"""
|
| 631 |
+
Returns the evaluation example from the data directory.
|
| 632 |
+
|
| 633 |
+
Args:
|
| 634 |
+
data_dir: Directory containing the data files used for training and evaluating.
|
| 635 |
+
filename: None by default, specify this if the evaluation file has a different name than the original one
|
| 636 |
+
which is `dev-v1.1.json` and `dev-v2.0.json` for squad versions 1.1 and 2.0 respectively.
|
| 637 |
+
"""
|
| 638 |
+
if data_dir is None:
|
| 639 |
+
data_dir = ""
|
| 640 |
+
|
| 641 |
+
if self.dev_file is None:
|
| 642 |
+
raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
|
| 643 |
+
|
| 644 |
+
with open(
|
| 645 |
+
os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding="utf-8"
|
| 646 |
+
) as reader:
|
| 647 |
+
input_data = json.load(reader)["data"]
|
| 648 |
+
return self._create_examples(input_data, "dev")
|
| 649 |
+
|
| 650 |
+
def _create_examples(self, input_data, set_type):
|
| 651 |
+
is_training = set_type == "train"
|
| 652 |
+
examples = []
|
| 653 |
+
for entry in tqdm(input_data):
|
| 654 |
+
title = entry["title"]
|
| 655 |
+
for paragraph in entry["paragraphs"]:
|
| 656 |
+
context_text = paragraph["context"]
|
| 657 |
+
for qa in paragraph["qas"]:
|
| 658 |
+
qas_id = qa["id"]
|
| 659 |
+
question_text = qa["question"]
|
| 660 |
+
start_position_character = None
|
| 661 |
+
answer_text = None
|
| 662 |
+
answers = []
|
| 663 |
+
|
| 664 |
+
is_impossible = qa.get("is_impossible", False)
|
| 665 |
+
if not is_impossible:
|
| 666 |
+
if is_training:
|
| 667 |
+
answer = qa["answers"][0]
|
| 668 |
+
answer_text = answer["text"]
|
| 669 |
+
start_position_character = answer["answer_start"]
|
| 670 |
+
else:
|
| 671 |
+
answers = qa["answers"]
|
| 672 |
+
|
| 673 |
+
example = SquadExample(
|
| 674 |
+
qas_id=qas_id,
|
| 675 |
+
question_text=question_text,
|
| 676 |
+
context_text=context_text,
|
| 677 |
+
answer_text=answer_text,
|
| 678 |
+
start_position_character=start_position_character,
|
| 679 |
+
title=title,
|
| 680 |
+
is_impossible=is_impossible,
|
| 681 |
+
answers=answers,
|
| 682 |
+
)
|
| 683 |
+
examples.append(example)
|
| 684 |
+
return examples
|
| 685 |
+
|
| 686 |
+
|
| 687 |
+
class SquadV1Processor(SquadProcessor):
|
| 688 |
+
train_file = "train-v1.1.json"
|
| 689 |
+
dev_file = "dev-v1.1.json"
|
| 690 |
+
|
| 691 |
+
|
| 692 |
+
class SquadV2Processor(SquadProcessor):
|
| 693 |
+
train_file = "train-v2.0.json"
|
| 694 |
+
dev_file = "dev-v2.0.json"
|
| 695 |
+
|
| 696 |
+
|
| 697 |
+
class SquadExample:
|
| 698 |
+
"""
|
| 699 |
+
A single training/test example for the Squad dataset, as loaded from disk.
|
| 700 |
+
|
| 701 |
+
Args:
|
| 702 |
+
qas_id: The example's unique identifier
|
| 703 |
+
question_text: The question string
|
| 704 |
+
context_text: The context string
|
| 705 |
+
answer_text: The answer string
|
| 706 |
+
start_position_character: The character position of the start of the answer
|
| 707 |
+
title: The title of the example
|
| 708 |
+
answers: None by default, this is used during evaluation. Holds answers as well as their start positions.
|
| 709 |
+
is_impossible: False by default, set to True if the example has no possible answer.
|
| 710 |
+
"""
|
| 711 |
+
|
| 712 |
+
def __init__(
|
| 713 |
+
self,
|
| 714 |
+
qas_id,
|
| 715 |
+
question_text,
|
| 716 |
+
context_text,
|
| 717 |
+
answer_text,
|
| 718 |
+
start_position_character,
|
| 719 |
+
title,
|
| 720 |
+
answers=[],
|
| 721 |
+
is_impossible=False,
|
| 722 |
+
):
|
| 723 |
+
self.qas_id = qas_id
|
| 724 |
+
self.question_text = question_text
|
| 725 |
+
self.context_text = context_text
|
| 726 |
+
self.answer_text = answer_text
|
| 727 |
+
self.title = title
|
| 728 |
+
self.is_impossible = is_impossible
|
| 729 |
+
self.answers = answers
|
| 730 |
+
|
| 731 |
+
self.start_position, self.end_position = 0, 0
|
| 732 |
+
|
| 733 |
+
doc_tokens = []
|
| 734 |
+
char_to_word_offset = []
|
| 735 |
+
prev_is_whitespace = True
|
| 736 |
+
|
| 737 |
+
# Split on whitespace so that different tokens may be attributed to their original position.
|
| 738 |
+
for c in self.context_text:
|
| 739 |
+
if _is_whitespace(c):
|
| 740 |
+
prev_is_whitespace = True
|
| 741 |
+
else:
|
| 742 |
+
if prev_is_whitespace:
|
| 743 |
+
doc_tokens.append(c)
|
| 744 |
+
else:
|
| 745 |
+
doc_tokens[-1] += c
|
| 746 |
+
prev_is_whitespace = False
|
| 747 |
+
char_to_word_offset.append(len(doc_tokens) - 1)
|
| 748 |
+
|
| 749 |
+
self.doc_tokens = doc_tokens
|
| 750 |
+
self.char_to_word_offset = char_to_word_offset
|
| 751 |
+
|
| 752 |
+
# Start and end positions only has a value during evaluation.
|
| 753 |
+
if start_position_character is not None and not is_impossible:
|
| 754 |
+
self.start_position = char_to_word_offset[start_position_character]
|
| 755 |
+
self.end_position = char_to_word_offset[
|
| 756 |
+
min(start_position_character + len(answer_text) - 1, len(char_to_word_offset) - 1)
|
| 757 |
+
]
|
| 758 |
+
|
| 759 |
+
|
| 760 |
+
class SquadFeatures:
|
| 761 |
+
"""
|
| 762 |
+
Single squad example features to be fed to a model. Those features are model-specific and can be crafted from
|
| 763 |
+
[`~data.processors.squad.SquadExample`] using the
|
| 764 |
+
:method:*~transformers.data.processors.squad.squad_convert_examples_to_features* method.
|
| 765 |
+
|
| 766 |
+
Args:
|
| 767 |
+
input_ids: Indices of input sequence tokens in the vocabulary.
|
| 768 |
+
attention_mask: Mask to avoid performing attention on padding token indices.
|
| 769 |
+
token_type_ids: Segment token indices to indicate first and second portions of the inputs.
|
| 770 |
+
cls_index: the index of the CLS token.
|
| 771 |
+
p_mask: Mask identifying tokens that can be answers vs. tokens that cannot.
|
| 772 |
+
Mask with 1 for tokens than cannot be in the answer and 0 for token that can be in an answer
|
| 773 |
+
example_index: the index of the example
|
| 774 |
+
unique_id: The unique Feature identifier
|
| 775 |
+
paragraph_len: The length of the context
|
| 776 |
+
token_is_max_context:
|
| 777 |
+
List of booleans identifying which tokens have their maximum context in this feature object. If a token
|
| 778 |
+
does not have their maximum context in this feature object, it means that another feature object has more
|
| 779 |
+
information related to that token and should be prioritized over this feature for that token.
|
| 780 |
+
tokens: list of tokens corresponding to the input ids
|
| 781 |
+
token_to_orig_map: mapping between the tokens and the original text, needed in order to identify the answer.
|
| 782 |
+
start_position: start of the answer token index
|
| 783 |
+
end_position: end of the answer token index
|
| 784 |
+
encoding: optionally store the BatchEncoding with the fast-tokenizer alignment methods.
|
| 785 |
+
"""
|
| 786 |
+
|
| 787 |
+
def __init__(
|
| 788 |
+
self,
|
| 789 |
+
input_ids,
|
| 790 |
+
attention_mask,
|
| 791 |
+
token_type_ids,
|
| 792 |
+
cls_index,
|
| 793 |
+
p_mask,
|
| 794 |
+
example_index,
|
| 795 |
+
unique_id,
|
| 796 |
+
paragraph_len,
|
| 797 |
+
token_is_max_context,
|
| 798 |
+
tokens,
|
| 799 |
+
token_to_orig_map,
|
| 800 |
+
start_position,
|
| 801 |
+
end_position,
|
| 802 |
+
is_impossible,
|
| 803 |
+
qas_id: str = None,
|
| 804 |
+
encoding: BatchEncoding = None,
|
| 805 |
+
):
|
| 806 |
+
self.input_ids = input_ids
|
| 807 |
+
self.attention_mask = attention_mask
|
| 808 |
+
self.token_type_ids = token_type_ids
|
| 809 |
+
self.cls_index = cls_index
|
| 810 |
+
self.p_mask = p_mask
|
| 811 |
+
|
| 812 |
+
self.example_index = example_index
|
| 813 |
+
self.unique_id = unique_id
|
| 814 |
+
self.paragraph_len = paragraph_len
|
| 815 |
+
self.token_is_max_context = token_is_max_context
|
| 816 |
+
self.tokens = tokens
|
| 817 |
+
self.token_to_orig_map = token_to_orig_map
|
| 818 |
+
|
| 819 |
+
self.start_position = start_position
|
| 820 |
+
self.end_position = end_position
|
| 821 |
+
self.is_impossible = is_impossible
|
| 822 |
+
self.qas_id = qas_id
|
| 823 |
+
|
| 824 |
+
self.encoding = encoding
|
| 825 |
+
|
| 826 |
+
|
| 827 |
+
class SquadResult:
|
| 828 |
+
"""
|
| 829 |
+
Constructs a SquadResult which can be used to evaluate a model's output on the SQuAD dataset.
|
| 830 |
+
|
| 831 |
+
Args:
|
| 832 |
+
unique_id: The unique identifier corresponding to that example.
|
| 833 |
+
start_logits: The logits corresponding to the start of the answer
|
| 834 |
+
end_logits: The logits corresponding to the end of the answer
|
| 835 |
+
"""
|
| 836 |
+
|
| 837 |
+
def __init__(self, unique_id, start_logits, end_logits, start_top_index=None, end_top_index=None, cls_logits=None):
|
| 838 |
+
self.start_logits = start_logits
|
| 839 |
+
self.end_logits = end_logits
|
| 840 |
+
self.unique_id = unique_id
|
| 841 |
+
|
| 842 |
+
if start_top_index:
|
| 843 |
+
self.start_top_index = start_top_index
|
| 844 |
+
self.end_top_index = end_top_index
|
| 845 |
+
self.cls_logits = cls_logits
|
.venv/lib/python3.11/site-packages/transformers/data/processors/utils.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import csv
|
| 18 |
+
import dataclasses
|
| 19 |
+
import json
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
from typing import List, Optional, Union
|
| 22 |
+
|
| 23 |
+
from ...utils import is_tf_available, is_torch_available, logging
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
logger = logging.get_logger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
|
| 30 |
+
class InputExample:
|
| 31 |
+
"""
|
| 32 |
+
A single training/test example for simple sequence classification.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
guid: Unique id for the example.
|
| 36 |
+
text_a: string. The untokenized text of the first sequence. For single
|
| 37 |
+
sequence tasks, only this sequence must be specified.
|
| 38 |
+
text_b: (Optional) string. The untokenized text of the second sequence.
|
| 39 |
+
Only must be specified for sequence pair tasks.
|
| 40 |
+
label: (Optional) string. The label of the example. This should be
|
| 41 |
+
specified for train and dev examples, but not for test examples.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
guid: str
|
| 45 |
+
text_a: str
|
| 46 |
+
text_b: Optional[str] = None
|
| 47 |
+
label: Optional[str] = None
|
| 48 |
+
|
| 49 |
+
def to_json_string(self):
|
| 50 |
+
"""Serializes this instance to a JSON string."""
|
| 51 |
+
return json.dumps(dataclasses.asdict(self), indent=2) + "\n"
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@dataclass(frozen=True)
|
| 55 |
+
class InputFeatures:
|
| 56 |
+
"""
|
| 57 |
+
A single set of features of data. Property names are the same names as the corresponding inputs to a model.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
input_ids: Indices of input sequence tokens in the vocabulary.
|
| 61 |
+
attention_mask: Mask to avoid performing attention on padding token indices.
|
| 62 |
+
Mask values selected in `[0, 1]`: Usually `1` for tokens that are NOT MASKED, `0` for MASKED (padded)
|
| 63 |
+
tokens.
|
| 64 |
+
token_type_ids: (Optional) Segment token indices to indicate first and second
|
| 65 |
+
portions of the inputs. Only some models use them.
|
| 66 |
+
label: (Optional) Label corresponding to the input. Int for classification problems,
|
| 67 |
+
float for regression problems.
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
input_ids: List[int]
|
| 71 |
+
attention_mask: Optional[List[int]] = None
|
| 72 |
+
token_type_ids: Optional[List[int]] = None
|
| 73 |
+
label: Optional[Union[int, float]] = None
|
| 74 |
+
|
| 75 |
+
def to_json_string(self):
|
| 76 |
+
"""Serializes this instance to a JSON string."""
|
| 77 |
+
return json.dumps(dataclasses.asdict(self)) + "\n"
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class DataProcessor:
|
| 81 |
+
"""Base class for data converters for sequence classification data sets."""
|
| 82 |
+
|
| 83 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
| 84 |
+
"""
|
| 85 |
+
Gets an example from a dict with tensorflow tensors.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
tensor_dict: Keys and values should match the corresponding Glue
|
| 89 |
+
tensorflow_dataset examples.
|
| 90 |
+
"""
|
| 91 |
+
raise NotImplementedError()
|
| 92 |
+
|
| 93 |
+
def get_train_examples(self, data_dir):
|
| 94 |
+
"""Gets a collection of [`InputExample`] for the train set."""
|
| 95 |
+
raise NotImplementedError()
|
| 96 |
+
|
| 97 |
+
def get_dev_examples(self, data_dir):
|
| 98 |
+
"""Gets a collection of [`InputExample`] for the dev set."""
|
| 99 |
+
raise NotImplementedError()
|
| 100 |
+
|
| 101 |
+
def get_test_examples(self, data_dir):
|
| 102 |
+
"""Gets a collection of [`InputExample`] for the test set."""
|
| 103 |
+
raise NotImplementedError()
|
| 104 |
+
|
| 105 |
+
def get_labels(self):
|
| 106 |
+
"""Gets the list of labels for this data set."""
|
| 107 |
+
raise NotImplementedError()
|
| 108 |
+
|
| 109 |
+
def tfds_map(self, example):
|
| 110 |
+
"""
|
| 111 |
+
Some tensorflow_datasets datasets are not formatted the same way the GLUE datasets are. This method converts
|
| 112 |
+
examples to the correct format.
|
| 113 |
+
"""
|
| 114 |
+
if len(self.get_labels()) > 1:
|
| 115 |
+
example.label = self.get_labels()[int(example.label)]
|
| 116 |
+
return example
|
| 117 |
+
|
| 118 |
+
@classmethod
|
| 119 |
+
def _read_tsv(cls, input_file, quotechar=None):
|
| 120 |
+
"""Reads a tab separated value file."""
|
| 121 |
+
with open(input_file, "r", encoding="utf-8-sig") as f:
|
| 122 |
+
return list(csv.reader(f, delimiter="\t", quotechar=quotechar))
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class SingleSentenceClassificationProcessor(DataProcessor):
|
| 126 |
+
"""Generic processor for a single sentence classification data set."""
|
| 127 |
+
|
| 128 |
+
def __init__(self, labels=None, examples=None, mode="classification", verbose=False):
|
| 129 |
+
self.labels = [] if labels is None else labels
|
| 130 |
+
self.examples = [] if examples is None else examples
|
| 131 |
+
self.mode = mode
|
| 132 |
+
self.verbose = verbose
|
| 133 |
+
|
| 134 |
+
def __len__(self):
|
| 135 |
+
return len(self.examples)
|
| 136 |
+
|
| 137 |
+
def __getitem__(self, idx):
|
| 138 |
+
if isinstance(idx, slice):
|
| 139 |
+
return SingleSentenceClassificationProcessor(labels=self.labels, examples=self.examples[idx])
|
| 140 |
+
return self.examples[idx]
|
| 141 |
+
|
| 142 |
+
@classmethod
|
| 143 |
+
def create_from_csv(
|
| 144 |
+
cls, file_name, split_name="", column_label=0, column_text=1, column_id=None, skip_first_row=False, **kwargs
|
| 145 |
+
):
|
| 146 |
+
processor = cls(**kwargs)
|
| 147 |
+
processor.add_examples_from_csv(
|
| 148 |
+
file_name,
|
| 149 |
+
split_name=split_name,
|
| 150 |
+
column_label=column_label,
|
| 151 |
+
column_text=column_text,
|
| 152 |
+
column_id=column_id,
|
| 153 |
+
skip_first_row=skip_first_row,
|
| 154 |
+
overwrite_labels=True,
|
| 155 |
+
overwrite_examples=True,
|
| 156 |
+
)
|
| 157 |
+
return processor
|
| 158 |
+
|
| 159 |
+
@classmethod
|
| 160 |
+
def create_from_examples(cls, texts_or_text_and_labels, labels=None, **kwargs):
|
| 161 |
+
processor = cls(**kwargs)
|
| 162 |
+
processor.add_examples(texts_or_text_and_labels, labels=labels)
|
| 163 |
+
return processor
|
| 164 |
+
|
| 165 |
+
def add_examples_from_csv(
|
| 166 |
+
self,
|
| 167 |
+
file_name,
|
| 168 |
+
split_name="",
|
| 169 |
+
column_label=0,
|
| 170 |
+
column_text=1,
|
| 171 |
+
column_id=None,
|
| 172 |
+
skip_first_row=False,
|
| 173 |
+
overwrite_labels=False,
|
| 174 |
+
overwrite_examples=False,
|
| 175 |
+
):
|
| 176 |
+
lines = self._read_tsv(file_name)
|
| 177 |
+
if skip_first_row:
|
| 178 |
+
lines = lines[1:]
|
| 179 |
+
texts = []
|
| 180 |
+
labels = []
|
| 181 |
+
ids = []
|
| 182 |
+
for i, line in enumerate(lines):
|
| 183 |
+
texts.append(line[column_text])
|
| 184 |
+
labels.append(line[column_label])
|
| 185 |
+
if column_id is not None:
|
| 186 |
+
ids.append(line[column_id])
|
| 187 |
+
else:
|
| 188 |
+
guid = f"{split_name}-{i}" if split_name else str(i)
|
| 189 |
+
ids.append(guid)
|
| 190 |
+
|
| 191 |
+
return self.add_examples(
|
| 192 |
+
texts, labels, ids, overwrite_labels=overwrite_labels, overwrite_examples=overwrite_examples
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
def add_examples(
|
| 196 |
+
self, texts_or_text_and_labels, labels=None, ids=None, overwrite_labels=False, overwrite_examples=False
|
| 197 |
+
):
|
| 198 |
+
if labels is not None and len(texts_or_text_and_labels) != len(labels):
|
| 199 |
+
raise ValueError(
|
| 200 |
+
f"Text and labels have mismatched lengths {len(texts_or_text_and_labels)} and {len(labels)}"
|
| 201 |
+
)
|
| 202 |
+
if ids is not None and len(texts_or_text_and_labels) != len(ids):
|
| 203 |
+
raise ValueError(f"Text and ids have mismatched lengths {len(texts_or_text_and_labels)} and {len(ids)}")
|
| 204 |
+
if ids is None:
|
| 205 |
+
ids = [None] * len(texts_or_text_and_labels)
|
| 206 |
+
if labels is None:
|
| 207 |
+
labels = [None] * len(texts_or_text_and_labels)
|
| 208 |
+
examples = []
|
| 209 |
+
added_labels = set()
|
| 210 |
+
for text_or_text_and_label, label, guid in zip(texts_or_text_and_labels, labels, ids):
|
| 211 |
+
if isinstance(text_or_text_and_label, (tuple, list)) and label is None:
|
| 212 |
+
text, label = text_or_text_and_label
|
| 213 |
+
else:
|
| 214 |
+
text = text_or_text_and_label
|
| 215 |
+
added_labels.add(label)
|
| 216 |
+
examples.append(InputExample(guid=guid, text_a=text, text_b=None, label=label))
|
| 217 |
+
|
| 218 |
+
# Update examples
|
| 219 |
+
if overwrite_examples:
|
| 220 |
+
self.examples = examples
|
| 221 |
+
else:
|
| 222 |
+
self.examples.extend(examples)
|
| 223 |
+
|
| 224 |
+
# Update labels
|
| 225 |
+
if overwrite_labels:
|
| 226 |
+
self.labels = list(added_labels)
|
| 227 |
+
else:
|
| 228 |
+
self.labels = list(set(self.labels).union(added_labels))
|
| 229 |
+
|
| 230 |
+
return self.examples
|
| 231 |
+
|
| 232 |
+
def get_features(
|
| 233 |
+
self,
|
| 234 |
+
tokenizer,
|
| 235 |
+
max_length=None,
|
| 236 |
+
pad_on_left=False,
|
| 237 |
+
pad_token=0,
|
| 238 |
+
mask_padding_with_zero=True,
|
| 239 |
+
return_tensors=None,
|
| 240 |
+
):
|
| 241 |
+
"""
|
| 242 |
+
Convert examples in a list of `InputFeatures`
|
| 243 |
+
|
| 244 |
+
Args:
|
| 245 |
+
tokenizer: Instance of a tokenizer that will tokenize the examples
|
| 246 |
+
max_length: Maximum example length
|
| 247 |
+
pad_on_left: If set to `True`, the examples will be padded on the left rather than on the right (default)
|
| 248 |
+
pad_token: Padding token
|
| 249 |
+
mask_padding_with_zero: If set to `True`, the attention mask will be filled by `1` for actual values
|
| 250 |
+
and by `0` for padded values. If set to `False`, inverts it (`1` for padded values, `0` for actual
|
| 251 |
+
values)
|
| 252 |
+
|
| 253 |
+
Returns:
|
| 254 |
+
If the `examples` input is a `tf.data.Dataset`, will return a `tf.data.Dataset` containing the
|
| 255 |
+
task-specific features. If the input is a list of `InputExamples`, will return a list of task-specific
|
| 256 |
+
`InputFeatures` which can be fed to the model.
|
| 257 |
+
|
| 258 |
+
"""
|
| 259 |
+
if max_length is None:
|
| 260 |
+
max_length = tokenizer.max_len
|
| 261 |
+
|
| 262 |
+
label_map = {label: i for i, label in enumerate(self.labels)}
|
| 263 |
+
|
| 264 |
+
all_input_ids = []
|
| 265 |
+
for ex_index, example in enumerate(self.examples):
|
| 266 |
+
if ex_index % 10000 == 0:
|
| 267 |
+
logger.info(f"Tokenizing example {ex_index}")
|
| 268 |
+
|
| 269 |
+
input_ids = tokenizer.encode(
|
| 270 |
+
example.text_a,
|
| 271 |
+
add_special_tokens=True,
|
| 272 |
+
max_length=min(max_length, tokenizer.max_len),
|
| 273 |
+
)
|
| 274 |
+
all_input_ids.append(input_ids)
|
| 275 |
+
|
| 276 |
+
batch_length = max(len(input_ids) for input_ids in all_input_ids)
|
| 277 |
+
|
| 278 |
+
features = []
|
| 279 |
+
for ex_index, (input_ids, example) in enumerate(zip(all_input_ids, self.examples)):
|
| 280 |
+
if ex_index % 10000 == 0:
|
| 281 |
+
logger.info(f"Writing example {ex_index}/{len(self.examples)}")
|
| 282 |
+
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
| 283 |
+
# tokens are attended to.
|
| 284 |
+
attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
|
| 285 |
+
|
| 286 |
+
# Zero-pad up to the sequence length.
|
| 287 |
+
padding_length = batch_length - len(input_ids)
|
| 288 |
+
if pad_on_left:
|
| 289 |
+
input_ids = ([pad_token] * padding_length) + input_ids
|
| 290 |
+
attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
|
| 291 |
+
else:
|
| 292 |
+
input_ids = input_ids + ([pad_token] * padding_length)
|
| 293 |
+
attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
|
| 294 |
+
|
| 295 |
+
if len(input_ids) != batch_length:
|
| 296 |
+
raise ValueError(f"Error with input length {len(input_ids)} vs {batch_length}")
|
| 297 |
+
if len(attention_mask) != batch_length:
|
| 298 |
+
raise ValueError(f"Error with input length {len(attention_mask)} vs {batch_length}")
|
| 299 |
+
|
| 300 |
+
if self.mode == "classification":
|
| 301 |
+
label = label_map[example.label]
|
| 302 |
+
elif self.mode == "regression":
|
| 303 |
+
label = float(example.label)
|
| 304 |
+
else:
|
| 305 |
+
raise ValueError(self.mode)
|
| 306 |
+
|
| 307 |
+
if ex_index < 5 and self.verbose:
|
| 308 |
+
logger.info("*** Example ***")
|
| 309 |
+
logger.info(f"guid: {example.guid}")
|
| 310 |
+
logger.info(f"input_ids: {' '.join([str(x) for x in input_ids])}")
|
| 311 |
+
logger.info(f"attention_mask: {' '.join([str(x) for x in attention_mask])}")
|
| 312 |
+
logger.info(f"label: {example.label} (id = {label})")
|
| 313 |
+
|
| 314 |
+
features.append(InputFeatures(input_ids=input_ids, attention_mask=attention_mask, label=label))
|
| 315 |
+
|
| 316 |
+
if return_tensors is None:
|
| 317 |
+
return features
|
| 318 |
+
elif return_tensors == "tf":
|
| 319 |
+
if not is_tf_available():
|
| 320 |
+
raise RuntimeError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported")
|
| 321 |
+
import tensorflow as tf
|
| 322 |
+
|
| 323 |
+
def gen():
|
| 324 |
+
for ex in features:
|
| 325 |
+
yield ({"input_ids": ex.input_ids, "attention_mask": ex.attention_mask}, ex.label)
|
| 326 |
+
|
| 327 |
+
dataset = tf.data.Dataset.from_generator(
|
| 328 |
+
gen,
|
| 329 |
+
({"input_ids": tf.int32, "attention_mask": tf.int32}, tf.int64),
|
| 330 |
+
({"input_ids": tf.TensorShape([None]), "attention_mask": tf.TensorShape([None])}, tf.TensorShape([])),
|
| 331 |
+
)
|
| 332 |
+
return dataset
|
| 333 |
+
elif return_tensors == "pt":
|
| 334 |
+
if not is_torch_available():
|
| 335 |
+
raise RuntimeError("return_tensors set to 'pt' but PyTorch can't be imported")
|
| 336 |
+
import torch
|
| 337 |
+
from torch.utils.data import TensorDataset
|
| 338 |
+
|
| 339 |
+
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
| 340 |
+
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
|
| 341 |
+
if self.mode == "classification":
|
| 342 |
+
all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
|
| 343 |
+
elif self.mode == "regression":
|
| 344 |
+
all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
|
| 345 |
+
|
| 346 |
+
dataset = TensorDataset(all_input_ids, all_attention_mask, all_labels)
|
| 347 |
+
return dataset
|
| 348 |
+
else:
|
| 349 |
+
raise ValueError("return_tensors should be one of 'tf' or 'pt'")
|
.venv/lib/python3.11/site-packages/transformers/data/processors/xnli.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""XNLI utils (dataset loading and evaluation)"""
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
|
| 20 |
+
from ...utils import logging
|
| 21 |
+
from .utils import DataProcessor, InputExample
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
logger = logging.get_logger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class XnliProcessor(DataProcessor):
|
| 28 |
+
"""
|
| 29 |
+
Processor for the XNLI dataset. Adapted from
|
| 30 |
+
https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/run_classifier.py#L207
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(self, language, train_language=None):
|
| 34 |
+
self.language = language
|
| 35 |
+
self.train_language = train_language
|
| 36 |
+
|
| 37 |
+
def get_train_examples(self, data_dir):
|
| 38 |
+
"""See base class."""
|
| 39 |
+
lg = self.language if self.train_language is None else self.train_language
|
| 40 |
+
lines = self._read_tsv(os.path.join(data_dir, f"XNLI-MT-1.0/multinli/multinli.train.{lg}.tsv"))
|
| 41 |
+
examples = []
|
| 42 |
+
for i, line in enumerate(lines):
|
| 43 |
+
if i == 0:
|
| 44 |
+
continue
|
| 45 |
+
guid = f"train-{i}"
|
| 46 |
+
text_a = line[0]
|
| 47 |
+
text_b = line[1]
|
| 48 |
+
label = "contradiction" if line[2] == "contradictory" else line[2]
|
| 49 |
+
if not isinstance(text_a, str):
|
| 50 |
+
raise TypeError(f"Training input {text_a} is not a string")
|
| 51 |
+
if not isinstance(text_b, str):
|
| 52 |
+
raise TypeError(f"Training input {text_b} is not a string")
|
| 53 |
+
if not isinstance(label, str):
|
| 54 |
+
raise TypeError(f"Training label {label} is not a string")
|
| 55 |
+
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
| 56 |
+
return examples
|
| 57 |
+
|
| 58 |
+
def get_test_examples(self, data_dir):
|
| 59 |
+
"""See base class."""
|
| 60 |
+
lines = self._read_tsv(os.path.join(data_dir, "XNLI-1.0/xnli.test.tsv"))
|
| 61 |
+
examples = []
|
| 62 |
+
for i, line in enumerate(lines):
|
| 63 |
+
if i == 0:
|
| 64 |
+
continue
|
| 65 |
+
language = line[0]
|
| 66 |
+
if language != self.language:
|
| 67 |
+
continue
|
| 68 |
+
guid = f"test-{i}"
|
| 69 |
+
text_a = line[6]
|
| 70 |
+
text_b = line[7]
|
| 71 |
+
label = line[1]
|
| 72 |
+
if not isinstance(text_a, str):
|
| 73 |
+
raise TypeError(f"Training input {text_a} is not a string")
|
| 74 |
+
if not isinstance(text_b, str):
|
| 75 |
+
raise TypeError(f"Training input {text_b} is not a string")
|
| 76 |
+
if not isinstance(label, str):
|
| 77 |
+
raise TypeError(f"Training label {label} is not a string")
|
| 78 |
+
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
| 79 |
+
return examples
|
| 80 |
+
|
| 81 |
+
def get_labels(self):
|
| 82 |
+
"""See base class."""
|
| 83 |
+
return ["contradiction", "entailment", "neutral"]
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
xnli_processors = {
|
| 87 |
+
"xnli": XnliProcessor,
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
xnli_output_modes = {
|
| 91 |
+
"xnli": "classification",
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
xnli_tasks_num_labels = {
|
| 95 |
+
"xnli": 3,
|
| 96 |
+
}
|
.venv/lib/python3.11/site-packages/transformers/generation/__init__.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import TYPE_CHECKING
|
| 16 |
+
|
| 17 |
+
from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_tf_available, is_torch_available
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
_import_structure = {
|
| 21 |
+
"configuration_utils": [
|
| 22 |
+
"BaseWatermarkingConfig",
|
| 23 |
+
"CompileConfig",
|
| 24 |
+
"GenerationConfig",
|
| 25 |
+
"GenerationMode",
|
| 26 |
+
"SynthIDTextWatermarkingConfig",
|
| 27 |
+
"WatermarkingConfig",
|
| 28 |
+
],
|
| 29 |
+
"streamers": ["AsyncTextIteratorStreamer", "TextIteratorStreamer", "TextStreamer"],
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
if not is_torch_available():
|
| 34 |
+
raise OptionalDependencyNotAvailable()
|
| 35 |
+
except OptionalDependencyNotAvailable:
|
| 36 |
+
pass
|
| 37 |
+
else:
|
| 38 |
+
_import_structure["beam_constraints"] = [
|
| 39 |
+
"Constraint",
|
| 40 |
+
"ConstraintListState",
|
| 41 |
+
"DisjunctiveConstraint",
|
| 42 |
+
"PhrasalConstraint",
|
| 43 |
+
]
|
| 44 |
+
_import_structure["beam_search"] = [
|
| 45 |
+
"BeamHypotheses",
|
| 46 |
+
"BeamScorer",
|
| 47 |
+
"BeamSearchScorer",
|
| 48 |
+
"ConstrainedBeamSearchScorer",
|
| 49 |
+
]
|
| 50 |
+
_import_structure["candidate_generator"] = [
|
| 51 |
+
"AssistedCandidateGenerator",
|
| 52 |
+
"CandidateGenerator",
|
| 53 |
+
"EarlyExitCandidateGenerator",
|
| 54 |
+
"PromptLookupCandidateGenerator",
|
| 55 |
+
]
|
| 56 |
+
_import_structure["logits_process"] = [
|
| 57 |
+
"AlternatingCodebooksLogitsProcessor",
|
| 58 |
+
"ClassifierFreeGuidanceLogitsProcessor",
|
| 59 |
+
"EncoderNoRepeatNGramLogitsProcessor",
|
| 60 |
+
"EncoderRepetitionPenaltyLogitsProcessor",
|
| 61 |
+
"EpsilonLogitsWarper",
|
| 62 |
+
"EtaLogitsWarper",
|
| 63 |
+
"ExponentialDecayLengthPenalty",
|
| 64 |
+
"ForcedBOSTokenLogitsProcessor",
|
| 65 |
+
"ForcedEOSTokenLogitsProcessor",
|
| 66 |
+
"HammingDiversityLogitsProcessor",
|
| 67 |
+
"InfNanRemoveLogitsProcessor",
|
| 68 |
+
"LogitNormalization",
|
| 69 |
+
"LogitsProcessor",
|
| 70 |
+
"LogitsProcessorList",
|
| 71 |
+
"LogitsWarper",
|
| 72 |
+
"MinLengthLogitsProcessor",
|
| 73 |
+
"MinNewTokensLengthLogitsProcessor",
|
| 74 |
+
"MinPLogitsWarper",
|
| 75 |
+
"NoBadWordsLogitsProcessor",
|
| 76 |
+
"NoRepeatNGramLogitsProcessor",
|
| 77 |
+
"PrefixConstrainedLogitsProcessor",
|
| 78 |
+
"RepetitionPenaltyLogitsProcessor",
|
| 79 |
+
"SequenceBiasLogitsProcessor",
|
| 80 |
+
"SuppressTokensLogitsProcessor",
|
| 81 |
+
"SuppressTokensAtBeginLogitsProcessor",
|
| 82 |
+
"SynthIDTextWatermarkLogitsProcessor",
|
| 83 |
+
"TemperatureLogitsWarper",
|
| 84 |
+
"TopKLogitsWarper",
|
| 85 |
+
"TopPLogitsWarper",
|
| 86 |
+
"TypicalLogitsWarper",
|
| 87 |
+
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
|
| 88 |
+
"WhisperTimeStampLogitsProcessor",
|
| 89 |
+
"WatermarkLogitsProcessor",
|
| 90 |
+
]
|
| 91 |
+
_import_structure["stopping_criteria"] = [
|
| 92 |
+
"MaxNewTokensCriteria",
|
| 93 |
+
"MaxLengthCriteria",
|
| 94 |
+
"MaxTimeCriteria",
|
| 95 |
+
"ConfidenceCriteria",
|
| 96 |
+
"EosTokenCriteria",
|
| 97 |
+
"StoppingCriteria",
|
| 98 |
+
"StoppingCriteriaList",
|
| 99 |
+
"validate_stopping_criteria",
|
| 100 |
+
"StopStringCriteria",
|
| 101 |
+
]
|
| 102 |
+
_import_structure["utils"] = [
|
| 103 |
+
"GenerationMixin",
|
| 104 |
+
"GreedySearchEncoderDecoderOutput",
|
| 105 |
+
"GreedySearchDecoderOnlyOutput",
|
| 106 |
+
"SampleEncoderDecoderOutput",
|
| 107 |
+
"SampleDecoderOnlyOutput",
|
| 108 |
+
"BeamSearchEncoderDecoderOutput",
|
| 109 |
+
"BeamSearchDecoderOnlyOutput",
|
| 110 |
+
"BeamSampleEncoderDecoderOutput",
|
| 111 |
+
"BeamSampleDecoderOnlyOutput",
|
| 112 |
+
"ContrastiveSearchEncoderDecoderOutput",
|
| 113 |
+
"ContrastiveSearchDecoderOnlyOutput",
|
| 114 |
+
"GenerateBeamDecoderOnlyOutput",
|
| 115 |
+
"GenerateBeamEncoderDecoderOutput",
|
| 116 |
+
"GenerateDecoderOnlyOutput",
|
| 117 |
+
"GenerateEncoderDecoderOutput",
|
| 118 |
+
]
|
| 119 |
+
_import_structure["watermarking"] = [
|
| 120 |
+
"WatermarkDetector",
|
| 121 |
+
"WatermarkDetectorOutput",
|
| 122 |
+
"BayesianDetectorModel",
|
| 123 |
+
"BayesianDetectorConfig",
|
| 124 |
+
"SynthIDTextWatermarkDetector",
|
| 125 |
+
]
|
| 126 |
+
|
| 127 |
+
try:
|
| 128 |
+
if not is_tf_available():
|
| 129 |
+
raise OptionalDependencyNotAvailable()
|
| 130 |
+
except OptionalDependencyNotAvailable:
|
| 131 |
+
pass
|
| 132 |
+
else:
|
| 133 |
+
_import_structure["tf_logits_process"] = [
|
| 134 |
+
"TFForcedBOSTokenLogitsProcessor",
|
| 135 |
+
"TFForcedEOSTokenLogitsProcessor",
|
| 136 |
+
"TFForceTokensLogitsProcessor",
|
| 137 |
+
"TFLogitsProcessor",
|
| 138 |
+
"TFLogitsProcessorList",
|
| 139 |
+
"TFLogitsWarper",
|
| 140 |
+
"TFMinLengthLogitsProcessor",
|
| 141 |
+
"TFNoBadWordsLogitsProcessor",
|
| 142 |
+
"TFNoRepeatNGramLogitsProcessor",
|
| 143 |
+
"TFRepetitionPenaltyLogitsProcessor",
|
| 144 |
+
"TFSuppressTokensAtBeginLogitsProcessor",
|
| 145 |
+
"TFSuppressTokensLogitsProcessor",
|
| 146 |
+
"TFTemperatureLogitsWarper",
|
| 147 |
+
"TFTopKLogitsWarper",
|
| 148 |
+
"TFTopPLogitsWarper",
|
| 149 |
+
]
|
| 150 |
+
_import_structure["tf_utils"] = [
|
| 151 |
+
"TFGenerationMixin",
|
| 152 |
+
"TFGreedySearchDecoderOnlyOutput",
|
| 153 |
+
"TFGreedySearchEncoderDecoderOutput",
|
| 154 |
+
"TFSampleEncoderDecoderOutput",
|
| 155 |
+
"TFSampleDecoderOnlyOutput",
|
| 156 |
+
"TFBeamSearchEncoderDecoderOutput",
|
| 157 |
+
"TFBeamSearchDecoderOnlyOutput",
|
| 158 |
+
"TFBeamSampleEncoderDecoderOutput",
|
| 159 |
+
"TFBeamSampleDecoderOnlyOutput",
|
| 160 |
+
"TFContrastiveSearchEncoderDecoderOutput",
|
| 161 |
+
"TFContrastiveSearchDecoderOnlyOutput",
|
| 162 |
+
]
|
| 163 |
+
|
| 164 |
+
try:
|
| 165 |
+
if not is_flax_available():
|
| 166 |
+
raise OptionalDependencyNotAvailable()
|
| 167 |
+
except OptionalDependencyNotAvailable:
|
| 168 |
+
pass
|
| 169 |
+
else:
|
| 170 |
+
_import_structure["flax_logits_process"] = [
|
| 171 |
+
"FlaxForcedBOSTokenLogitsProcessor",
|
| 172 |
+
"FlaxForcedEOSTokenLogitsProcessor",
|
| 173 |
+
"FlaxForceTokensLogitsProcessor",
|
| 174 |
+
"FlaxLogitsProcessor",
|
| 175 |
+
"FlaxLogitsProcessorList",
|
| 176 |
+
"FlaxLogitsWarper",
|
| 177 |
+
"FlaxMinLengthLogitsProcessor",
|
| 178 |
+
"FlaxSuppressTokensAtBeginLogitsProcessor",
|
| 179 |
+
"FlaxSuppressTokensLogitsProcessor",
|
| 180 |
+
"FlaxTemperatureLogitsWarper",
|
| 181 |
+
"FlaxTopKLogitsWarper",
|
| 182 |
+
"FlaxTopPLogitsWarper",
|
| 183 |
+
"FlaxWhisperTimeStampLogitsProcessor",
|
| 184 |
+
"FlaxNoRepeatNGramLogitsProcessor",
|
| 185 |
+
]
|
| 186 |
+
_import_structure["flax_utils"] = [
|
| 187 |
+
"FlaxGenerationMixin",
|
| 188 |
+
"FlaxGreedySearchOutput",
|
| 189 |
+
"FlaxSampleOutput",
|
| 190 |
+
"FlaxBeamSearchOutput",
|
| 191 |
+
]
|
| 192 |
+
|
| 193 |
+
if TYPE_CHECKING:
|
| 194 |
+
from .configuration_utils import (
|
| 195 |
+
BaseWatermarkingConfig,
|
| 196 |
+
CompileConfig,
|
| 197 |
+
GenerationConfig,
|
| 198 |
+
GenerationMode,
|
| 199 |
+
SynthIDTextWatermarkingConfig,
|
| 200 |
+
WatermarkingConfig,
|
| 201 |
+
)
|
| 202 |
+
from .streamers import AsyncTextIteratorStreamer, TextIteratorStreamer, TextStreamer
|
| 203 |
+
|
| 204 |
+
try:
|
| 205 |
+
if not is_torch_available():
|
| 206 |
+
raise OptionalDependencyNotAvailable()
|
| 207 |
+
except OptionalDependencyNotAvailable:
|
| 208 |
+
pass
|
| 209 |
+
else:
|
| 210 |
+
from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
|
| 211 |
+
from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
| 212 |
+
from .candidate_generator import (
|
| 213 |
+
AssistedCandidateGenerator,
|
| 214 |
+
CandidateGenerator,
|
| 215 |
+
EarlyExitCandidateGenerator,
|
| 216 |
+
PromptLookupCandidateGenerator,
|
| 217 |
+
)
|
| 218 |
+
from .logits_process import (
|
| 219 |
+
AlternatingCodebooksLogitsProcessor,
|
| 220 |
+
ClassifierFreeGuidanceLogitsProcessor,
|
| 221 |
+
EncoderNoRepeatNGramLogitsProcessor,
|
| 222 |
+
EncoderRepetitionPenaltyLogitsProcessor,
|
| 223 |
+
EpsilonLogitsWarper,
|
| 224 |
+
EtaLogitsWarper,
|
| 225 |
+
ExponentialDecayLengthPenalty,
|
| 226 |
+
ForcedBOSTokenLogitsProcessor,
|
| 227 |
+
ForcedEOSTokenLogitsProcessor,
|
| 228 |
+
HammingDiversityLogitsProcessor,
|
| 229 |
+
InfNanRemoveLogitsProcessor,
|
| 230 |
+
LogitNormalization,
|
| 231 |
+
LogitsProcessor,
|
| 232 |
+
LogitsProcessorList,
|
| 233 |
+
LogitsWarper,
|
| 234 |
+
MinLengthLogitsProcessor,
|
| 235 |
+
MinNewTokensLengthLogitsProcessor,
|
| 236 |
+
MinPLogitsWarper,
|
| 237 |
+
NoBadWordsLogitsProcessor,
|
| 238 |
+
NoRepeatNGramLogitsProcessor,
|
| 239 |
+
PrefixConstrainedLogitsProcessor,
|
| 240 |
+
RepetitionPenaltyLogitsProcessor,
|
| 241 |
+
SequenceBiasLogitsProcessor,
|
| 242 |
+
SuppressTokensAtBeginLogitsProcessor,
|
| 243 |
+
SuppressTokensLogitsProcessor,
|
| 244 |
+
SynthIDTextWatermarkLogitsProcessor,
|
| 245 |
+
TemperatureLogitsWarper,
|
| 246 |
+
TopKLogitsWarper,
|
| 247 |
+
TopPLogitsWarper,
|
| 248 |
+
TypicalLogitsWarper,
|
| 249 |
+
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
| 250 |
+
WatermarkLogitsProcessor,
|
| 251 |
+
WhisperTimeStampLogitsProcessor,
|
| 252 |
+
)
|
| 253 |
+
from .stopping_criteria import (
|
| 254 |
+
ConfidenceCriteria,
|
| 255 |
+
EosTokenCriteria,
|
| 256 |
+
MaxLengthCriteria,
|
| 257 |
+
MaxNewTokensCriteria,
|
| 258 |
+
MaxTimeCriteria,
|
| 259 |
+
StoppingCriteria,
|
| 260 |
+
StoppingCriteriaList,
|
| 261 |
+
StopStringCriteria,
|
| 262 |
+
validate_stopping_criteria,
|
| 263 |
+
)
|
| 264 |
+
from .utils import (
|
| 265 |
+
BeamSampleDecoderOnlyOutput,
|
| 266 |
+
BeamSampleEncoderDecoderOutput,
|
| 267 |
+
BeamSearchDecoderOnlyOutput,
|
| 268 |
+
BeamSearchEncoderDecoderOutput,
|
| 269 |
+
ContrastiveSearchDecoderOnlyOutput,
|
| 270 |
+
ContrastiveSearchEncoderDecoderOutput,
|
| 271 |
+
GenerateBeamDecoderOnlyOutput,
|
| 272 |
+
GenerateBeamEncoderDecoderOutput,
|
| 273 |
+
GenerateDecoderOnlyOutput,
|
| 274 |
+
GenerateEncoderDecoderOutput,
|
| 275 |
+
GenerationMixin,
|
| 276 |
+
GreedySearchDecoderOnlyOutput,
|
| 277 |
+
GreedySearchEncoderDecoderOutput,
|
| 278 |
+
SampleDecoderOnlyOutput,
|
| 279 |
+
SampleEncoderDecoderOutput,
|
| 280 |
+
)
|
| 281 |
+
from .watermarking import (
|
| 282 |
+
BayesianDetectorConfig,
|
| 283 |
+
BayesianDetectorModel,
|
| 284 |
+
SynthIDTextWatermarkDetector,
|
| 285 |
+
WatermarkDetector,
|
| 286 |
+
WatermarkDetectorOutput,
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
try:
|
| 290 |
+
if not is_tf_available():
|
| 291 |
+
raise OptionalDependencyNotAvailable()
|
| 292 |
+
except OptionalDependencyNotAvailable:
|
| 293 |
+
pass
|
| 294 |
+
else:
|
| 295 |
+
from .tf_logits_process import (
|
| 296 |
+
TFForcedBOSTokenLogitsProcessor,
|
| 297 |
+
TFForcedEOSTokenLogitsProcessor,
|
| 298 |
+
TFForceTokensLogitsProcessor,
|
| 299 |
+
TFLogitsProcessor,
|
| 300 |
+
TFLogitsProcessorList,
|
| 301 |
+
TFLogitsWarper,
|
| 302 |
+
TFMinLengthLogitsProcessor,
|
| 303 |
+
TFNoBadWordsLogitsProcessor,
|
| 304 |
+
TFNoRepeatNGramLogitsProcessor,
|
| 305 |
+
TFRepetitionPenaltyLogitsProcessor,
|
| 306 |
+
TFSuppressTokensAtBeginLogitsProcessor,
|
| 307 |
+
TFSuppressTokensLogitsProcessor,
|
| 308 |
+
TFTemperatureLogitsWarper,
|
| 309 |
+
TFTopKLogitsWarper,
|
| 310 |
+
TFTopPLogitsWarper,
|
| 311 |
+
)
|
| 312 |
+
from .tf_utils import (
|
| 313 |
+
TFBeamSampleDecoderOnlyOutput,
|
| 314 |
+
TFBeamSampleEncoderDecoderOutput,
|
| 315 |
+
TFBeamSearchDecoderOnlyOutput,
|
| 316 |
+
TFBeamSearchEncoderDecoderOutput,
|
| 317 |
+
TFContrastiveSearchDecoderOnlyOutput,
|
| 318 |
+
TFContrastiveSearchEncoderDecoderOutput,
|
| 319 |
+
TFGenerationMixin,
|
| 320 |
+
TFGreedySearchDecoderOnlyOutput,
|
| 321 |
+
TFGreedySearchEncoderDecoderOutput,
|
| 322 |
+
TFSampleDecoderOnlyOutput,
|
| 323 |
+
TFSampleEncoderDecoderOutput,
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
try:
|
| 327 |
+
if not is_flax_available():
|
| 328 |
+
raise OptionalDependencyNotAvailable()
|
| 329 |
+
except OptionalDependencyNotAvailable:
|
| 330 |
+
pass
|
| 331 |
+
else:
|
| 332 |
+
from .flax_logits_process import (
|
| 333 |
+
FlaxForcedBOSTokenLogitsProcessor,
|
| 334 |
+
FlaxForcedEOSTokenLogitsProcessor,
|
| 335 |
+
FlaxForceTokensLogitsProcessor,
|
| 336 |
+
FlaxLogitsProcessor,
|
| 337 |
+
FlaxLogitsProcessorList,
|
| 338 |
+
FlaxLogitsWarper,
|
| 339 |
+
FlaxMinLengthLogitsProcessor,
|
| 340 |
+
FlaxNoRepeatNGramLogitsProcessor,
|
| 341 |
+
FlaxSuppressTokensAtBeginLogitsProcessor,
|
| 342 |
+
FlaxSuppressTokensLogitsProcessor,
|
| 343 |
+
FlaxTemperatureLogitsWarper,
|
| 344 |
+
FlaxTopKLogitsWarper,
|
| 345 |
+
FlaxTopPLogitsWarper,
|
| 346 |
+
FlaxWhisperTimeStampLogitsProcessor,
|
| 347 |
+
)
|
| 348 |
+
from .flax_utils import FlaxBeamSearchOutput, FlaxGenerationMixin, FlaxGreedySearchOutput, FlaxSampleOutput
|
| 349 |
+
else:
|
| 350 |
+
import sys
|
| 351 |
+
|
| 352 |
+
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
.venv/lib/python3.11/site-packages/transformers/generation/__pycache__/candidate_generator.cpython-311.pyc
ADDED
|
Binary file (43.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/transformers/generation/__pycache__/configuration_utils.cpython-311.pyc
ADDED
|
Binary file (89.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/transformers/generation/__pycache__/flax_logits_process.cpython-311.pyc
ADDED
|
Binary file (34.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/transformers/generation/__pycache__/flax_utils.cpython-311.pyc
ADDED
|
Binary file (47.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/transformers/generation/__pycache__/stopping_criteria.cpython-311.pyc
ADDED
|
Binary file (33.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/transformers/generation/__pycache__/streamers.cpython-311.pyc
ADDED
|
Binary file (15.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/transformers/generation/__pycache__/tf_logits_process.cpython-311.pyc
ADDED
|
Binary file (44.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/transformers/generation/__pycache__/watermarking.cpython-311.pyc
ADDED
|
Binary file (29.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/transformers/generation/beam_constraints.py
ADDED
|
@@ -0,0 +1,524 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Constraint(ABC):
|
| 6 |
+
r"""Abstract base class for all constraints that can be applied during generation.
|
| 7 |
+
It must define how the constraint can be satisfied.
|
| 8 |
+
|
| 9 |
+
All classes that inherit Constraint must follow the requirement that
|
| 10 |
+
|
| 11 |
+
```py
|
| 12 |
+
completed = False
|
| 13 |
+
while not completed:
|
| 14 |
+
_, completed = constraint.update(constraint.advance())
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
will always terminate (halt).
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
def __init__(self):
|
| 21 |
+
# test for the above condition
|
| 22 |
+
self.test()
|
| 23 |
+
|
| 24 |
+
def test(self):
|
| 25 |
+
"""
|
| 26 |
+
Tests whether this constraint has been properly defined.
|
| 27 |
+
"""
|
| 28 |
+
counter = 0
|
| 29 |
+
completed = False
|
| 30 |
+
while not completed:
|
| 31 |
+
if counter == 1:
|
| 32 |
+
self.reset()
|
| 33 |
+
advance = self.advance()
|
| 34 |
+
if not self.does_advance(advance):
|
| 35 |
+
raise Exception(
|
| 36 |
+
"Custom Constraint is not defined correctly. self.does_advance(self.advance()) must be true."
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
stepped, completed, reset = self.update(advance)
|
| 40 |
+
counter += 1
|
| 41 |
+
|
| 42 |
+
if counter > 10000:
|
| 43 |
+
raise Exception("update() does not fulfill the constraint.")
|
| 44 |
+
|
| 45 |
+
if self.remaining() != 0:
|
| 46 |
+
raise Exception("Custom Constraint is not defined correctly.")
|
| 47 |
+
|
| 48 |
+
@abstractmethod
|
| 49 |
+
def advance(self):
|
| 50 |
+
"""
|
| 51 |
+
When called, returns the token(s) that would take this constraint one step closer to being fulfilled.
|
| 52 |
+
|
| 53 |
+
Return:
|
| 54 |
+
token_ids (Union[int, List[int], None]):
|
| 55 |
+
- A single token ID (int) that advances the constraint, or
|
| 56 |
+
- A list of token IDs that could advance the constraint
|
| 57 |
+
- None if the constraint is completed or cannot be advanced
|
| 58 |
+
"""
|
| 59 |
+
raise NotImplementedError(
|
| 60 |
+
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
@abstractmethod
|
| 64 |
+
def does_advance(self, token_id: int):
|
| 65 |
+
"""
|
| 66 |
+
Reads in a token and returns whether it creates progress.
|
| 67 |
+
"""
|
| 68 |
+
raise NotImplementedError(
|
| 69 |
+
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
@abstractmethod
|
| 73 |
+
def update(self, token_id: int):
|
| 74 |
+
"""
|
| 75 |
+
Reads in a token and returns booleans that indicate the progress made by it. This function will update the
|
| 76 |
+
state of this object unlikes `does_advance(self, token_id: int)`.
|
| 77 |
+
|
| 78 |
+
This isn't to test whether a certain token will advance the progress; it's to update its state as if it has
|
| 79 |
+
been generated. This becomes important if token_id != desired token (refer to else statement in
|
| 80 |
+
PhrasalConstraint)
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
token_id(`int`):
|
| 84 |
+
The id of a newly generated token in the beam search.
|
| 85 |
+
Return:
|
| 86 |
+
stepped(`bool`):
|
| 87 |
+
Whether this constraint has become one step closer to being fulfuilled.
|
| 88 |
+
completed(`bool`):
|
| 89 |
+
Whether this constraint has been completely fulfilled by this token being generated.
|
| 90 |
+
reset (`bool`):
|
| 91 |
+
Whether this constraint has reset its progress by this token being generated.
|
| 92 |
+
"""
|
| 93 |
+
raise NotImplementedError(
|
| 94 |
+
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
@abstractmethod
|
| 98 |
+
def reset(self):
|
| 99 |
+
"""
|
| 100 |
+
Resets the state of this constraint to its initialization. We would call this in cases where the fulfillment of
|
| 101 |
+
a constraint is abrupted by an unwanted token.
|
| 102 |
+
"""
|
| 103 |
+
raise NotImplementedError(
|
| 104 |
+
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
@abstractmethod
|
| 108 |
+
def remaining(self):
|
| 109 |
+
"""
|
| 110 |
+
Returns the number of remaining steps of `advance()` in order to complete this constraint.
|
| 111 |
+
"""
|
| 112 |
+
raise NotImplementedError(
|
| 113 |
+
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
@abstractmethod
|
| 117 |
+
def copy(self, stateful=False):
|
| 118 |
+
"""
|
| 119 |
+
Creates a new instance of this constraint.
|
| 120 |
+
|
| 121 |
+
Args:
|
| 122 |
+
stateful(`bool`): Whether to not only copy the constraint for new instance, but also its state.
|
| 123 |
+
|
| 124 |
+
Return:
|
| 125 |
+
constraint(`Constraint`): The same constraint as the one being called from.
|
| 126 |
+
"""
|
| 127 |
+
raise NotImplementedError(
|
| 128 |
+
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class PhrasalConstraint(Constraint):
|
| 133 |
+
r"""
|
| 134 |
+
[`Constraint`] enforcing that an ordered sequence of tokens is included in the output.
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
token_ids (`List[int]`):
|
| 138 |
+
The id of the token that must be generated by the output.
|
| 139 |
+
"""
|
| 140 |
+
|
| 141 |
+
def __init__(self, token_ids: List[int]):
|
| 142 |
+
super(Constraint, self).__init__()
|
| 143 |
+
|
| 144 |
+
if not isinstance(token_ids, list) or len(token_ids) == 0:
|
| 145 |
+
raise ValueError(f"`token_ids` has to be a non-empty list, but is {token_ids}.")
|
| 146 |
+
if any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids):
|
| 147 |
+
raise ValueError(f"Each list in `token_ids` has to be a list of positive integers, but is {token_ids}.")
|
| 148 |
+
|
| 149 |
+
self.token_ids = token_ids
|
| 150 |
+
|
| 151 |
+
self.seqlen = len(self.token_ids)
|
| 152 |
+
self.fulfilled_idx = -1 # the index of the currently fulfilled step
|
| 153 |
+
self.completed = False
|
| 154 |
+
|
| 155 |
+
def advance(self):
|
| 156 |
+
if self.completed:
|
| 157 |
+
return None
|
| 158 |
+
return self.token_ids[self.fulfilled_idx + 1]
|
| 159 |
+
|
| 160 |
+
def does_advance(self, token_id: int):
|
| 161 |
+
if not isinstance(token_id, int):
|
| 162 |
+
raise TypeError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}")
|
| 163 |
+
|
| 164 |
+
if self.completed:
|
| 165 |
+
return False
|
| 166 |
+
|
| 167 |
+
return token_id == self.token_ids[self.fulfilled_idx + 1]
|
| 168 |
+
|
| 169 |
+
def update(self, token_id: int):
|
| 170 |
+
if not isinstance(token_id, int):
|
| 171 |
+
raise TypeError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}")
|
| 172 |
+
|
| 173 |
+
stepped = False
|
| 174 |
+
completed = False
|
| 175 |
+
reset = False
|
| 176 |
+
|
| 177 |
+
if self.does_advance(token_id):
|
| 178 |
+
self.fulfilled_idx += 1
|
| 179 |
+
stepped = True
|
| 180 |
+
if self.fulfilled_idx == (self.seqlen - 1):
|
| 181 |
+
completed = True
|
| 182 |
+
self.completed = completed
|
| 183 |
+
else:
|
| 184 |
+
# failed to make progress.
|
| 185 |
+
reset = True
|
| 186 |
+
self.reset()
|
| 187 |
+
return stepped, completed, reset
|
| 188 |
+
|
| 189 |
+
def reset(self):
|
| 190 |
+
self.completed = False
|
| 191 |
+
self.fulfilled_idx = 0
|
| 192 |
+
|
| 193 |
+
def remaining(self):
|
| 194 |
+
return self.seqlen - (self.fulfilled_idx + 1)
|
| 195 |
+
|
| 196 |
+
def copy(self, stateful=False):
|
| 197 |
+
new_constraint = PhrasalConstraint(self.token_ids)
|
| 198 |
+
|
| 199 |
+
if stateful:
|
| 200 |
+
new_constraint.seq_len = self.seqlen
|
| 201 |
+
new_constraint.fulfilled_idx = self.fulfilled_idx
|
| 202 |
+
new_constraint.completed = self.completed
|
| 203 |
+
|
| 204 |
+
return new_constraint
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class DisjunctiveTrie:
|
| 208 |
+
def __init__(self, nested_token_ids: List[List[int]], no_subsets=True):
|
| 209 |
+
r"""
|
| 210 |
+
A helper class that builds a trie with the words represented in `nested_token_ids`.
|
| 211 |
+
"""
|
| 212 |
+
self.max_height = max([len(one) for one in nested_token_ids])
|
| 213 |
+
|
| 214 |
+
root = {}
|
| 215 |
+
for token_ids in nested_token_ids:
|
| 216 |
+
level = root
|
| 217 |
+
for tidx, token_id in enumerate(token_ids):
|
| 218 |
+
if token_id not in level:
|
| 219 |
+
level[token_id] = {}
|
| 220 |
+
|
| 221 |
+
level = level[token_id]
|
| 222 |
+
|
| 223 |
+
if no_subsets and self.has_subsets(root, nested_token_ids):
|
| 224 |
+
raise ValueError(
|
| 225 |
+
"Each list in `nested_token_ids` can't be a complete subset of another list, but is"
|
| 226 |
+
f" {nested_token_ids}."
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
self.trie = root
|
| 230 |
+
|
| 231 |
+
def next_tokens(self, current_seq):
|
| 232 |
+
"""
|
| 233 |
+
The next possible tokens that will progress the trie, given the current sequence of tokens in `current_seq`.
|
| 234 |
+
"""
|
| 235 |
+
start = self.trie
|
| 236 |
+
|
| 237 |
+
for current_token in current_seq:
|
| 238 |
+
start = start[current_token]
|
| 239 |
+
|
| 240 |
+
next_tokens = list(start.keys())
|
| 241 |
+
|
| 242 |
+
return next_tokens
|
| 243 |
+
|
| 244 |
+
def reached_leaf(self, current_seq):
|
| 245 |
+
next_tokens = self.next_tokens(current_seq)
|
| 246 |
+
|
| 247 |
+
return len(next_tokens) == 0
|
| 248 |
+
|
| 249 |
+
def count_leaves(self, root):
|
| 250 |
+
next_nodes = list(root.values())
|
| 251 |
+
if len(next_nodes) == 0:
|
| 252 |
+
return 1
|
| 253 |
+
else:
|
| 254 |
+
return sum([self.count_leaves(nn) for nn in next_nodes])
|
| 255 |
+
|
| 256 |
+
def has_subsets(self, trie, nested_token_ids):
|
| 257 |
+
"""
|
| 258 |
+
Returns whether # of leaves == # of words. Otherwise some word is a subset of another.
|
| 259 |
+
"""
|
| 260 |
+
leaf_count = self.count_leaves(trie)
|
| 261 |
+
return len(nested_token_ids) != leaf_count
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
class DisjunctiveConstraint(Constraint):
|
| 265 |
+
r"""
|
| 266 |
+
A special [`Constraint`] that is fulfilled by fulfilling just one of several constraints.
|
| 267 |
+
|
| 268 |
+
Args:
|
| 269 |
+
nested_token_ids (`List[List[int]]`):
|
| 270 |
+
A list of words, where each word is a list of ids. This constraint is fulfilled by generating just one from
|
| 271 |
+
the list of words.
|
| 272 |
+
"""
|
| 273 |
+
|
| 274 |
+
def __init__(self, nested_token_ids: List[List[int]]):
|
| 275 |
+
super(Constraint, self).__init__()
|
| 276 |
+
|
| 277 |
+
if not isinstance(nested_token_ids, list) or len(nested_token_ids) == 0:
|
| 278 |
+
raise ValueError(f"`nested_token_ids` has to be a non-empty list, but is {nested_token_ids}.")
|
| 279 |
+
if any(not isinstance(token_ids, list) for token_ids in nested_token_ids):
|
| 280 |
+
raise ValueError(f"`nested_token_ids` has to be a list of lists, but is {nested_token_ids}.")
|
| 281 |
+
if any(
|
| 282 |
+
any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
|
| 283 |
+
for token_ids in nested_token_ids
|
| 284 |
+
):
|
| 285 |
+
raise ValueError(
|
| 286 |
+
f"Each list in `nested_token_ids` has to be a list of positive integers, but is {nested_token_ids}."
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
self.trie = DisjunctiveTrie(nested_token_ids)
|
| 290 |
+
self.token_ids = nested_token_ids
|
| 291 |
+
|
| 292 |
+
self.seqlen = self.trie.max_height
|
| 293 |
+
self.current_seq = []
|
| 294 |
+
self.completed = False
|
| 295 |
+
|
| 296 |
+
def advance(self):
|
| 297 |
+
token_list = self.trie.next_tokens(self.current_seq)
|
| 298 |
+
|
| 299 |
+
if len(token_list) == 0:
|
| 300 |
+
return None
|
| 301 |
+
else:
|
| 302 |
+
return token_list
|
| 303 |
+
|
| 304 |
+
def does_advance(self, token_id: int):
|
| 305 |
+
if not isinstance(token_id, int):
|
| 306 |
+
raise TypeError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}")
|
| 307 |
+
|
| 308 |
+
next_tokens = self.trie.next_tokens(self.current_seq)
|
| 309 |
+
|
| 310 |
+
return token_id in next_tokens
|
| 311 |
+
|
| 312 |
+
def update(self, token_id: int):
|
| 313 |
+
if not isinstance(token_id, int):
|
| 314 |
+
raise TypeError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}")
|
| 315 |
+
|
| 316 |
+
stepped = False
|
| 317 |
+
completed = False
|
| 318 |
+
reset = False
|
| 319 |
+
|
| 320 |
+
if self.does_advance(token_id):
|
| 321 |
+
self.current_seq.append(token_id)
|
| 322 |
+
stepped = True
|
| 323 |
+
else:
|
| 324 |
+
reset = True
|
| 325 |
+
self.reset()
|
| 326 |
+
|
| 327 |
+
completed = self.trie.reached_leaf(self.current_seq)
|
| 328 |
+
self.completed = completed
|
| 329 |
+
|
| 330 |
+
return stepped, completed, reset
|
| 331 |
+
|
| 332 |
+
def reset(self):
|
| 333 |
+
self.completed = False
|
| 334 |
+
self.current_seq = []
|
| 335 |
+
|
| 336 |
+
def remaining(self):
|
| 337 |
+
if self.completed:
|
| 338 |
+
# since this can be completed without reaching max height
|
| 339 |
+
return 0
|
| 340 |
+
else:
|
| 341 |
+
return self.seqlen - len(self.current_seq)
|
| 342 |
+
|
| 343 |
+
def copy(self, stateful=False):
|
| 344 |
+
new_constraint = DisjunctiveConstraint(self.token_ids)
|
| 345 |
+
|
| 346 |
+
if stateful:
|
| 347 |
+
new_constraint.seq_len = self.seqlen
|
| 348 |
+
new_constraint.current_seq = self.current_seq
|
| 349 |
+
new_constraint.completed = self.completed
|
| 350 |
+
|
| 351 |
+
return new_constraint
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
class ConstraintListState:
|
| 355 |
+
r"""
|
| 356 |
+
A class for beam scorers to track its progress through a list of constraints.
|
| 357 |
+
|
| 358 |
+
Args:
|
| 359 |
+
constraints (`List[Constraint]`):
|
| 360 |
+
A list of [`Constraint`] objects that must be fulfilled by the beam scorer.
|
| 361 |
+
"""
|
| 362 |
+
|
| 363 |
+
def __init__(self, constraints: List[Constraint]):
|
| 364 |
+
self.constraints = constraints
|
| 365 |
+
|
| 366 |
+
# max # of steps required to fulfill a given constraint
|
| 367 |
+
self.max_seqlen = max([c.seqlen for c in constraints])
|
| 368 |
+
self.n_constraints = len(constraints)
|
| 369 |
+
self.completed = False
|
| 370 |
+
|
| 371 |
+
self.init_state()
|
| 372 |
+
|
| 373 |
+
def init_state(self):
|
| 374 |
+
self.complete_constraints = []
|
| 375 |
+
self.inprogress_constraint = None
|
| 376 |
+
self.pending_constraints = [constraint.copy(stateful=False) for constraint in self.constraints]
|
| 377 |
+
|
| 378 |
+
def get_bank(self):
|
| 379 |
+
add = 0
|
| 380 |
+
if self.inprogress_constraint:
|
| 381 |
+
# extra points for having a constraint mid-fulfilled
|
| 382 |
+
add += self.max_seqlen - self.inprogress_constraint.remaining()
|
| 383 |
+
|
| 384 |
+
return (len(self.complete_constraints) * self.max_seqlen) + add
|
| 385 |
+
|
| 386 |
+
def advance(self):
|
| 387 |
+
"""The list of tokens to generate such that we can make progress.
|
| 388 |
+
By "list" we don't mean the list of token that will fully fulfill a constraint.
|
| 389 |
+
|
| 390 |
+
Given constraints `c_i = {t_ij | j == # of tokens}`, If we're not in the middle of progressing through a
|
| 391 |
+
specific constraint `c_i`, we return:
|
| 392 |
+
|
| 393 |
+
`[t_k1 for k in indices of unfulfilled constraints]`
|
| 394 |
+
|
| 395 |
+
If we are in the middle of a constraint, then we return:
|
| 396 |
+
`[t_ij]`, where `i` is the index of the inprogress constraint, `j` is the next step for the constraint.
|
| 397 |
+
|
| 398 |
+
Though we don't care which constraint is fulfilled first, if we are in the progress of fulfilling a constraint,
|
| 399 |
+
that's the only one we'll return.
|
| 400 |
+
"""
|
| 401 |
+
token_list = []
|
| 402 |
+
if self.inprogress_constraint is None:
|
| 403 |
+
for constraint in self.pending_constraints: # "pending" == "unfulfilled yet"
|
| 404 |
+
advance = constraint.advance()
|
| 405 |
+
if isinstance(advance, int):
|
| 406 |
+
token_list.append(advance)
|
| 407 |
+
elif isinstance(advance, list):
|
| 408 |
+
token_list.extend(advance)
|
| 409 |
+
else:
|
| 410 |
+
advance = self.inprogress_constraint.advance()
|
| 411 |
+
if isinstance(advance, int):
|
| 412 |
+
token_list.append(advance)
|
| 413 |
+
elif isinstance(advance, list):
|
| 414 |
+
token_list.extend(advance)
|
| 415 |
+
|
| 416 |
+
if len(token_list) == 0:
|
| 417 |
+
return None
|
| 418 |
+
else:
|
| 419 |
+
return token_list
|
| 420 |
+
|
| 421 |
+
def reset(self, token_ids: Optional[List[int]]):
|
| 422 |
+
"""
|
| 423 |
+
token_ids: the tokens generated thus far to reset the state of the progress through constraints.
|
| 424 |
+
"""
|
| 425 |
+
self.init_state()
|
| 426 |
+
|
| 427 |
+
if token_ids is not None:
|
| 428 |
+
for token in token_ids:
|
| 429 |
+
# completes or steps **one** constraint
|
| 430 |
+
complete, stepped = self.add(token)
|
| 431 |
+
|
| 432 |
+
# the entire list of constraints are fulfilled
|
| 433 |
+
if self.completed:
|
| 434 |
+
break
|
| 435 |
+
|
| 436 |
+
def add(self, token_id: int):
|
| 437 |
+
if not isinstance(token_id, int):
|
| 438 |
+
raise TypeError(f"`token_id` should be an `int`, but is `{token_id}`.")
|
| 439 |
+
|
| 440 |
+
complete, stepped = False, False
|
| 441 |
+
|
| 442 |
+
if self.completed:
|
| 443 |
+
complete = True
|
| 444 |
+
stepped = False
|
| 445 |
+
return complete, stepped
|
| 446 |
+
|
| 447 |
+
if self.inprogress_constraint is not None:
|
| 448 |
+
# In the middle of fulfilling a constraint. If the `token_id` *does* makes an incremental progress to current
|
| 449 |
+
# job, simply update the state
|
| 450 |
+
|
| 451 |
+
stepped, complete, reset = self.inprogress_constraint.update(token_id)
|
| 452 |
+
if reset:
|
| 453 |
+
# 1. If the next token breaks the progress, then we must restart.
|
| 454 |
+
# e.g. constraint = "I love pies" and sequence so far is "I love" but `token_id` == "books".
|
| 455 |
+
|
| 456 |
+
# But that doesn't mean we self.init_state(), since we only reset the state for this particular
|
| 457 |
+
# constraint, not the full list of constraints.
|
| 458 |
+
|
| 459 |
+
self.pending_constraints.append(self.inprogress_constraint.copy(stateful=False))
|
| 460 |
+
self.inprogress_constraint = None
|
| 461 |
+
|
| 462 |
+
if complete:
|
| 463 |
+
# 2. If the next token completes the constraint, move it to completed list, set
|
| 464 |
+
# inprogress to None. If there are no pending constraints either, then this full list of constraints
|
| 465 |
+
# is complete.
|
| 466 |
+
|
| 467 |
+
self.complete_constraints.append(self.inprogress_constraint)
|
| 468 |
+
self.inprogress_constraint = None
|
| 469 |
+
|
| 470 |
+
if len(self.pending_constraints) == 0:
|
| 471 |
+
# we're done!
|
| 472 |
+
self.completed = True
|
| 473 |
+
|
| 474 |
+
else:
|
| 475 |
+
# Not in the middle of fulfilling a constraint. So does this `token_id` helps us step towards any of our list
|
| 476 |
+
# of constraints?
|
| 477 |
+
|
| 478 |
+
for cidx, pending_constraint in enumerate(self.pending_constraints):
|
| 479 |
+
if pending_constraint.does_advance(token_id):
|
| 480 |
+
stepped, complete, reset = pending_constraint.update(token_id)
|
| 481 |
+
|
| 482 |
+
if not stepped:
|
| 483 |
+
raise Exception(
|
| 484 |
+
"`constraint.update(token_id)` is not yielding incremental progress, "
|
| 485 |
+
"even though `constraint.does_advance(token_id)` is true."
|
| 486 |
+
)
|
| 487 |
+
|
| 488 |
+
if complete:
|
| 489 |
+
self.complete_constraints.append(pending_constraint)
|
| 490 |
+
self.inprogress_constraint = None
|
| 491 |
+
|
| 492 |
+
if not complete and stepped:
|
| 493 |
+
self.inprogress_constraint = pending_constraint
|
| 494 |
+
|
| 495 |
+
if complete or stepped:
|
| 496 |
+
# If we made any progress at all, then it's at least not a "pending constraint".
|
| 497 |
+
|
| 498 |
+
self.pending_constraints = (
|
| 499 |
+
self.pending_constraints[:cidx] + self.pending_constraints[cidx + 1 :]
|
| 500 |
+
)
|
| 501 |
+
|
| 502 |
+
if len(self.pending_constraints) == 0 and self.inprogress_constraint is None:
|
| 503 |
+
# If there's no longer any pending after this and no inprogress either, then we must be
|
| 504 |
+
# complete.
|
| 505 |
+
|
| 506 |
+
self.completed = True
|
| 507 |
+
|
| 508 |
+
break # prevent accidentally stepping through multiple constraints with just one token.
|
| 509 |
+
|
| 510 |
+
return complete, stepped
|
| 511 |
+
|
| 512 |
+
def copy(self, stateful=True):
|
| 513 |
+
new_state = ConstraintListState(self.constraints) # we actually never though self.constraints objects
|
| 514 |
+
# throughout this process. So it's at initialization state.
|
| 515 |
+
|
| 516 |
+
if stateful:
|
| 517 |
+
new_state.complete_constraints = [
|
| 518 |
+
constraint.copy(stateful=True) for constraint in self.complete_constraints
|
| 519 |
+
]
|
| 520 |
+
if self.inprogress_constraint is not None:
|
| 521 |
+
new_state.inprogress_constraint = self.inprogress_constraint.copy(stateful=True)
|
| 522 |
+
new_state.pending_constraints = [constraint.copy() for constraint in self.pending_constraints]
|
| 523 |
+
|
| 524 |
+
return new_state
|
.venv/lib/python3.11/site-packages/transformers/generation/beam_search.py
ADDED
|
@@ -0,0 +1,1013 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2020 The HuggingFace Inc. team
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from abc import ABC, abstractmethod
|
| 17 |
+
from collections import UserDict
|
| 18 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
from ..utils import add_start_docstrings
|
| 24 |
+
from .beam_constraints import Constraint, ConstraintListState
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
PROCESS_INPUTS_DOCSTRING = r"""
|
| 28 |
+
Args:
|
| 29 |
+
input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
|
| 30 |
+
Indices of input sequence tokens in the vocabulary.
|
| 31 |
+
|
| 32 |
+
Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
|
| 33 |
+
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
|
| 34 |
+
|
| 35 |
+
[What are input IDs?](../glossary#input-ids)
|
| 36 |
+
next_scores (`torch.FloatTensor` of shape `(batch_size, 2 * num_beams)`):
|
| 37 |
+
Current scores of the top `2 * num_beams` non-finished beam hypotheses.
|
| 38 |
+
next_tokens (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
|
| 39 |
+
`input_ids` of the tokens corresponding to the top `2 * num_beams` non-finished beam hypotheses.
|
| 40 |
+
next_indices (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
|
| 41 |
+
Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
|
| 42 |
+
pad_token_id (`int`, *optional*):
|
| 43 |
+
The id of the *padding* token.
|
| 44 |
+
eos_token_id (`Union[int, List[int]]`, *optional*):
|
| 45 |
+
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
| 46 |
+
beam_indices (`torch.LongTensor`, *optional*):
|
| 47 |
+
Beam indices indicating to which beam hypothesis each token correspond.
|
| 48 |
+
group_index (`int`, *optional*):
|
| 49 |
+
The index of the group of beams. Used with [`~PreTrainedModel.group_beam_search`].
|
| 50 |
+
|
| 51 |
+
Return:
|
| 52 |
+
`UserDict`: A dictionary composed of the fields as defined above:
|
| 53 |
+
|
| 54 |
+
- **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of all
|
| 55 |
+
non-finished beams.
|
| 56 |
+
- **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be added
|
| 57 |
+
to the non-finished beam_hypotheses.
|
| 58 |
+
- **next_beam_indices** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Beam indices
|
| 59 |
+
indicating to which beam the next tokens shall be added.
|
| 60 |
+
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
FINALIZE_INPUTS_DOCSTRING = r"""
|
| 64 |
+
Args:
|
| 65 |
+
input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
|
| 66 |
+
Indices of input sequence tokens in the vocabulary.
|
| 67 |
+
|
| 68 |
+
Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
|
| 69 |
+
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
|
| 70 |
+
|
| 71 |
+
[What are input IDs?](../glossary#input-ids)
|
| 72 |
+
final_beam_scores (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
|
| 73 |
+
The final scores of all non-finished beams.
|
| 74 |
+
final_beam_tokens (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
|
| 75 |
+
The last tokens to be added to the non-finished beam_hypotheses.
|
| 76 |
+
final_beam_indices (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
|
| 77 |
+
The beam indices indicating to which beam the `final_beam_tokens` shall be added.
|
| 78 |
+
pad_token_id (`int`, *optional*):
|
| 79 |
+
The id of the *padding* token.
|
| 80 |
+
eos_token_id (`Union[int, List[int]]`, *optional*):
|
| 81 |
+
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
| 82 |
+
|
| 83 |
+
Return:
|
| 84 |
+
`torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences.
|
| 85 |
+
The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early
|
| 86 |
+
due to the `eos_token_id`.
|
| 87 |
+
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class BeamScorer(ABC):
|
| 92 |
+
"""
|
| 93 |
+
Abstract base class for all beam scorers that are used for [`~PreTrainedModel.beam_search`] and
|
| 94 |
+
[`~PreTrainedModel.beam_sample`].
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
@abstractmethod
|
| 98 |
+
@add_start_docstrings(PROCESS_INPUTS_DOCSTRING)
|
| 99 |
+
def process(
|
| 100 |
+
self,
|
| 101 |
+
input_ids: torch.LongTensor,
|
| 102 |
+
next_scores: torch.FloatTensor,
|
| 103 |
+
next_tokens: torch.LongTensor,
|
| 104 |
+
next_indices: torch.LongTensor,
|
| 105 |
+
**kwargs,
|
| 106 |
+
) -> Tuple[torch.Tensor]:
|
| 107 |
+
raise NotImplementedError("This is an abstract method.")
|
| 108 |
+
|
| 109 |
+
@abstractmethod
|
| 110 |
+
@add_start_docstrings(FINALIZE_INPUTS_DOCSTRING)
|
| 111 |
+
def finalize(
|
| 112 |
+
self,
|
| 113 |
+
input_ids: torch.LongTensor,
|
| 114 |
+
next_scores: torch.FloatTensor,
|
| 115 |
+
next_tokens: torch.LongTensor,
|
| 116 |
+
next_indices: torch.LongTensor,
|
| 117 |
+
max_length: int,
|
| 118 |
+
**kwargs,
|
| 119 |
+
) -> torch.LongTensor:
|
| 120 |
+
raise NotImplementedError("This is an abstract method.")
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class BeamSearchScorer(BeamScorer):
|
| 124 |
+
r"""
|
| 125 |
+
[`BeamScorer`] implementing standard beam search decoding.
|
| 126 |
+
|
| 127 |
+
Adapted in part from [Facebook's XLM beam search
|
| 128 |
+
code](https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529).
|
| 129 |
+
|
| 130 |
+
Reference for the diverse beam search algorithm and implementation [Ashwin Kalyan's DBS
|
| 131 |
+
implementation](https://github.com/ashwinkalyan/dbs/blob/master/dbs/beam_utils.lua)
|
| 132 |
+
|
| 133 |
+
Args:
|
| 134 |
+
batch_size (`int`):
|
| 135 |
+
Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
|
| 136 |
+
num_beams (`int`):
|
| 137 |
+
Number of beams for beam search.
|
| 138 |
+
device (`torch.device`):
|
| 139 |
+
Defines the device type (*e.g.*, `"cpu"` or `"cuda"`) on which this instance of `BeamSearchScorer` will be
|
| 140 |
+
allocated.
|
| 141 |
+
length_penalty (`float`, *optional*, defaults to 1.0):
|
| 142 |
+
Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
|
| 143 |
+
the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
|
| 144 |
+
likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
|
| 145 |
+
`length_penalty` < 0.0 encourages shorter sequences.
|
| 146 |
+
do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
|
| 147 |
+
Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
|
| 148 |
+
`True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an
|
| 149 |
+
heuristic is applied and the generation stops when is it very unlikely to find better candidates;
|
| 150 |
+
`"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
|
| 151 |
+
beam search algorithm).
|
| 152 |
+
num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
|
| 153 |
+
The number of beam hypotheses that shall be returned upon calling
|
| 154 |
+
[`~transformers.BeamSearchScorer.finalize`].
|
| 155 |
+
num_beam_groups (`int`, *optional*, defaults to 1):
|
| 156 |
+
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
|
| 157 |
+
See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
|
| 158 |
+
max_length (`int`, *optional*):
|
| 159 |
+
The maximum length of the sequence to be generated.
|
| 160 |
+
"""
|
| 161 |
+
|
| 162 |
+
def __init__(
|
| 163 |
+
self,
|
| 164 |
+
batch_size: int,
|
| 165 |
+
num_beams: int,
|
| 166 |
+
device: torch.device,
|
| 167 |
+
length_penalty: Optional[float] = 1.0,
|
| 168 |
+
do_early_stopping: Optional[Union[bool, str]] = False,
|
| 169 |
+
num_beam_hyps_to_keep: Optional[int] = 1,
|
| 170 |
+
num_beam_groups: Optional[int] = 1,
|
| 171 |
+
max_length: Optional[int] = None,
|
| 172 |
+
):
|
| 173 |
+
self.num_beams = num_beams
|
| 174 |
+
self.device = device
|
| 175 |
+
self.length_penalty = length_penalty
|
| 176 |
+
self.do_early_stopping = do_early_stopping
|
| 177 |
+
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
|
| 178 |
+
self.num_beam_groups = num_beam_groups
|
| 179 |
+
self.group_size = self.num_beams // self.num_beam_groups
|
| 180 |
+
|
| 181 |
+
self._is_init = False
|
| 182 |
+
# self._beam_hyps[i*self.num_beam_groups+j] is the beam_hyps of the j-th group in the i-th mini-batch.
|
| 183 |
+
# If group_beam_search is not used, the list consists of `batch_size` beam_hyps.
|
| 184 |
+
self._beam_hyps = [
|
| 185 |
+
BeamHypotheses(
|
| 186 |
+
num_beams=self.group_size,
|
| 187 |
+
length_penalty=self.length_penalty,
|
| 188 |
+
early_stopping=self.do_early_stopping,
|
| 189 |
+
max_length=max_length,
|
| 190 |
+
)
|
| 191 |
+
for _ in range(batch_size * self.num_beam_groups)
|
| 192 |
+
]
|
| 193 |
+
# self._done[i*self.num_beam_groups+j] indicates whether the generation of the beam_hyps of the j-th group
|
| 194 |
+
# in the i-th mini-batch is complete.
|
| 195 |
+
self._done = torch.tensor(
|
| 196 |
+
[False for _ in range(batch_size * self.num_beam_groups)], dtype=torch.bool, device=self.device
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
if not isinstance(num_beams, int) or num_beams <= 1:
|
| 200 |
+
raise ValueError(
|
| 201 |
+
f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
|
| 202 |
+
" one should make use of `greedy_search` instead."
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
|
| 206 |
+
raise ValueError(
|
| 207 |
+
"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
|
| 208 |
+
f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
@property
|
| 212 |
+
def is_done(self) -> bool:
|
| 213 |
+
return self._done.all()
|
| 214 |
+
|
| 215 |
+
def process(
|
| 216 |
+
self,
|
| 217 |
+
input_ids: torch.LongTensor,
|
| 218 |
+
next_scores: torch.FloatTensor,
|
| 219 |
+
next_tokens: torch.LongTensor,
|
| 220 |
+
next_indices: torch.LongTensor,
|
| 221 |
+
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
|
| 222 |
+
eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
|
| 223 |
+
beam_indices: Optional[torch.LongTensor] = None,
|
| 224 |
+
group_index: Optional[int] = 0,
|
| 225 |
+
decoder_prompt_len: Optional[int] = 0,
|
| 226 |
+
) -> Dict[str, torch.Tensor]:
|
| 227 |
+
# add up to the length which the next_scores is calculated on (including decoder prompt)
|
| 228 |
+
cur_len = input_ids.shape[-1] + 1
|
| 229 |
+
batch_size = len(self._beam_hyps) // self.num_beam_groups
|
| 230 |
+
|
| 231 |
+
if not (batch_size == (input_ids.shape[0] // self.group_size)):
|
| 232 |
+
if self.num_beam_groups > 1:
|
| 233 |
+
raise ValueError(
|
| 234 |
+
f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
|
| 235 |
+
f"size of {self.group_size} is expected by the beam scorer."
|
| 236 |
+
)
|
| 237 |
+
else:
|
| 238 |
+
raise ValueError(
|
| 239 |
+
f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
|
| 240 |
+
f"{self.group_size} is expected by the beam scorer."
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
device = input_ids.device
|
| 244 |
+
next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
|
| 245 |
+
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
|
| 246 |
+
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
|
| 247 |
+
|
| 248 |
+
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
|
| 249 |
+
if isinstance(eos_token_id, int):
|
| 250 |
+
eos_token_id = [eos_token_id]
|
| 251 |
+
eos_token_id = torch.tensor(eos_token_id)
|
| 252 |
+
|
| 253 |
+
for batch_idx in range(batch_size):
|
| 254 |
+
batch_group_idx = batch_idx * self.num_beam_groups + group_index
|
| 255 |
+
if self._done[batch_group_idx]:
|
| 256 |
+
if self.num_beams < len(self._beam_hyps[batch_group_idx]):
|
| 257 |
+
raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
|
| 258 |
+
if eos_token_id is None or pad_token_id is None:
|
| 259 |
+
raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
|
| 260 |
+
# pad the batch
|
| 261 |
+
next_beam_scores[batch_idx, :] = 0
|
| 262 |
+
next_beam_tokens[batch_idx, :] = pad_token_id
|
| 263 |
+
next_beam_indices[batch_idx, :] = 0
|
| 264 |
+
continue
|
| 265 |
+
|
| 266 |
+
# next tokens for this sentence
|
| 267 |
+
beam_idx = 0
|
| 268 |
+
for beam_token_rank, (next_token, next_score, next_index) in enumerate(
|
| 269 |
+
zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
|
| 270 |
+
):
|
| 271 |
+
batch_beam_idx = batch_idx * self.group_size + next_index
|
| 272 |
+
# add to generated hypotheses if end of sentence
|
| 273 |
+
if (eos_token_id is not None) and (next_token.item() in eos_token_id):
|
| 274 |
+
# if beam_token does not belong to top num_beams tokens, it should not be added
|
| 275 |
+
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
|
| 276 |
+
if is_beam_token_worse_than_top_num_beams:
|
| 277 |
+
continue
|
| 278 |
+
if beam_indices is not None:
|
| 279 |
+
beam_index = beam_indices[batch_beam_idx]
|
| 280 |
+
beam_index = beam_index + (batch_beam_idx,)
|
| 281 |
+
else:
|
| 282 |
+
beam_index = None
|
| 283 |
+
|
| 284 |
+
self._beam_hyps[batch_group_idx].add(
|
| 285 |
+
input_ids[batch_beam_idx].clone(),
|
| 286 |
+
next_score.item(),
|
| 287 |
+
beam_indices=beam_index,
|
| 288 |
+
generated_len=cur_len - decoder_prompt_len,
|
| 289 |
+
)
|
| 290 |
+
else:
|
| 291 |
+
# add next predicted token since it is not eos_token
|
| 292 |
+
next_beam_scores[batch_idx, beam_idx] = next_score
|
| 293 |
+
next_beam_tokens[batch_idx, beam_idx] = next_token
|
| 294 |
+
next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
|
| 295 |
+
beam_idx += 1
|
| 296 |
+
|
| 297 |
+
# once the beam for next step is full, don't add more tokens to it.
|
| 298 |
+
if beam_idx == self.group_size:
|
| 299 |
+
break
|
| 300 |
+
|
| 301 |
+
if beam_idx < self.group_size:
|
| 302 |
+
raise ValueError(
|
| 303 |
+
f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
|
| 304 |
+
f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
# Check if we are done so that we can save a pad step if all(done)
|
| 308 |
+
self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
|
| 309 |
+
next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
return UserDict(
|
| 313 |
+
{
|
| 314 |
+
"next_beam_scores": next_beam_scores.view(-1),
|
| 315 |
+
"next_beam_tokens": next_beam_tokens.view(-1),
|
| 316 |
+
"next_beam_indices": next_beam_indices.view(-1),
|
| 317 |
+
}
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
def finalize(
|
| 321 |
+
self,
|
| 322 |
+
input_ids: torch.LongTensor,
|
| 323 |
+
final_beam_scores: torch.FloatTensor,
|
| 324 |
+
final_beam_tokens: torch.LongTensor,
|
| 325 |
+
final_beam_indices: torch.LongTensor,
|
| 326 |
+
max_length: int,
|
| 327 |
+
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
|
| 328 |
+
eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
|
| 329 |
+
beam_indices: Optional[torch.LongTensor] = None,
|
| 330 |
+
decoder_prompt_len: Optional[int] = 0,
|
| 331 |
+
) -> Tuple[torch.LongTensor]:
|
| 332 |
+
batch_size = len(self._beam_hyps) // self.num_beam_groups
|
| 333 |
+
|
| 334 |
+
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
|
| 335 |
+
if isinstance(eos_token_id, int):
|
| 336 |
+
eos_token_id = [eos_token_id]
|
| 337 |
+
eos_token_id = torch.tensor(eos_token_id)
|
| 338 |
+
|
| 339 |
+
# finalize all open beam hypotheses and add to generated hypotheses
|
| 340 |
+
for batch_group_idx, beam_hyp in enumerate(self._beam_hyps):
|
| 341 |
+
if self._done[batch_group_idx]:
|
| 342 |
+
continue
|
| 343 |
+
|
| 344 |
+
# all open beam hypotheses are added to the beam hypothesis
|
| 345 |
+
# beam hypothesis class automatically keeps the best beams
|
| 346 |
+
for index_per_group in range(self.group_size):
|
| 347 |
+
batch_beam_idx = batch_group_idx * self.group_size + index_per_group
|
| 348 |
+
final_score = final_beam_scores[batch_beam_idx].item()
|
| 349 |
+
final_tokens = input_ids[batch_beam_idx]
|
| 350 |
+
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
|
| 351 |
+
generated_len = final_tokens.shape[-1] - decoder_prompt_len
|
| 352 |
+
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
|
| 353 |
+
|
| 354 |
+
# select the best hypotheses
|
| 355 |
+
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
|
| 356 |
+
best = []
|
| 357 |
+
best_indices = []
|
| 358 |
+
best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
|
| 359 |
+
|
| 360 |
+
# retrieve best hypotheses
|
| 361 |
+
for i in range(batch_size):
|
| 362 |
+
beam_hyps_in_batch = self._beam_hyps[i * self.num_beam_groups : (i + 1) * self.num_beam_groups]
|
| 363 |
+
candidate_beams = [beam for beam_hyp in beam_hyps_in_batch for beam in beam_hyp.beams]
|
| 364 |
+
sorted_hyps = sorted(candidate_beams, key=lambda x: x[0])
|
| 365 |
+
for j in range(self.num_beam_hyps_to_keep):
|
| 366 |
+
best_hyp_tuple = sorted_hyps.pop()
|
| 367 |
+
best_score = best_hyp_tuple[0]
|
| 368 |
+
best_hyp = best_hyp_tuple[1]
|
| 369 |
+
best_index = best_hyp_tuple[2]
|
| 370 |
+
sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
|
| 371 |
+
|
| 372 |
+
# append hyp to lists
|
| 373 |
+
best.append(best_hyp)
|
| 374 |
+
|
| 375 |
+
# append indices to list
|
| 376 |
+
best_indices.append(best_index)
|
| 377 |
+
|
| 378 |
+
best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
|
| 379 |
+
|
| 380 |
+
# prepare for adding eos
|
| 381 |
+
sent_lengths_max = sent_lengths.max().item() + 1
|
| 382 |
+
sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
|
| 383 |
+
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
|
| 384 |
+
|
| 385 |
+
if len(best_indices) > 0 and best_indices[0] is not None:
|
| 386 |
+
indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
|
| 387 |
+
else:
|
| 388 |
+
indices = None
|
| 389 |
+
|
| 390 |
+
# shorter batches are padded if needed
|
| 391 |
+
if sent_lengths.min().item() != sent_lengths.max().item():
|
| 392 |
+
if pad_token_id is None:
|
| 393 |
+
raise ValueError("`pad_token_id` has to be defined")
|
| 394 |
+
decoded.fill_(pad_token_id)
|
| 395 |
+
|
| 396 |
+
if indices is not None:
|
| 397 |
+
indices.fill_(-1)
|
| 398 |
+
|
| 399 |
+
# fill with hypotheses and eos_token_id if the latter fits in
|
| 400 |
+
for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
|
| 401 |
+
decoded[i, : sent_lengths[i]] = hypo
|
| 402 |
+
|
| 403 |
+
if indices is not None:
|
| 404 |
+
indices[i, : len(best_idx)] = torch.tensor(best_idx)
|
| 405 |
+
|
| 406 |
+
if sent_lengths[i] < sent_max_len:
|
| 407 |
+
# inserting only the first eos_token_id
|
| 408 |
+
decoded[i, sent_lengths[i]] = eos_token_id[0]
|
| 409 |
+
|
| 410 |
+
return UserDict(
|
| 411 |
+
{
|
| 412 |
+
"sequences": decoded,
|
| 413 |
+
"sequence_scores": best_scores,
|
| 414 |
+
"beam_indices": indices,
|
| 415 |
+
}
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
class ConstrainedBeamSearchScorer(BeamScorer):
|
| 420 |
+
r"""
|
| 421 |
+
[`BeamScorer`] implementing constrained beam search decoding.
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
Args:
|
| 425 |
+
batch_size (`int`):
|
| 426 |
+
Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
|
| 427 |
+
num_beams (`int`):
|
| 428 |
+
Number of beams for beam search.
|
| 429 |
+
constraints (`List[Constraint]`):
|
| 430 |
+
A list of positive constraints represented as `Constraint` objects that must be fulfilled in the generation
|
| 431 |
+
output. For more information, the documentation of [`Constraint`] should be read.
|
| 432 |
+
device (`torch.device`):
|
| 433 |
+
Defines the device type (*e.g.*, `"cpu"` or `"cuda"`) on which this instance of `BeamSearchScorer` will be
|
| 434 |
+
allocated.
|
| 435 |
+
length_penalty (`float`, *optional*, defaults to 1.0):
|
| 436 |
+
Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
|
| 437 |
+
the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
|
| 438 |
+
likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
|
| 439 |
+
`length_penalty` < 0.0 encourages shorter sequences.
|
| 440 |
+
do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
|
| 441 |
+
Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
|
| 442 |
+
`True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an
|
| 443 |
+
heuristic is applied and the generation stops when is it very unlikely to find better candidates;
|
| 444 |
+
`"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
|
| 445 |
+
beam search algorithm).
|
| 446 |
+
num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
|
| 447 |
+
The number of beam hypotheses that shall be returned upon calling
|
| 448 |
+
[`~transformers.BeamSearchScorer.finalize`].
|
| 449 |
+
num_beam_groups (`int`, *optional*, defaults to 1):
|
| 450 |
+
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
|
| 451 |
+
See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
|
| 452 |
+
max_length (`int`, *optional*):
|
| 453 |
+
The maximum length of the sequence to be generated.
|
| 454 |
+
"""
|
| 455 |
+
|
| 456 |
+
def __init__(
|
| 457 |
+
self,
|
| 458 |
+
batch_size: int,
|
| 459 |
+
num_beams: int,
|
| 460 |
+
constraints: List[Constraint],
|
| 461 |
+
device: torch.device,
|
| 462 |
+
length_penalty: Optional[float] = 1.0,
|
| 463 |
+
do_early_stopping: Optional[Union[bool, str]] = False,
|
| 464 |
+
num_beam_hyps_to_keep: Optional[int] = 1,
|
| 465 |
+
num_beam_groups: Optional[int] = 1,
|
| 466 |
+
max_length: Optional[int] = None,
|
| 467 |
+
):
|
| 468 |
+
self.num_beams = num_beams
|
| 469 |
+
self.device = device
|
| 470 |
+
self.length_penalty = length_penalty
|
| 471 |
+
self.do_early_stopping = do_early_stopping
|
| 472 |
+
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
|
| 473 |
+
self.num_beam_groups = num_beam_groups
|
| 474 |
+
self.group_size = self.num_beams // self.num_beam_groups
|
| 475 |
+
self.constraints = constraints
|
| 476 |
+
|
| 477 |
+
self._is_init = False
|
| 478 |
+
self._beam_hyps = [
|
| 479 |
+
BeamHypotheses(
|
| 480 |
+
num_beams=self.num_beams,
|
| 481 |
+
length_penalty=self.length_penalty,
|
| 482 |
+
early_stopping=self.do_early_stopping,
|
| 483 |
+
max_length=max_length,
|
| 484 |
+
)
|
| 485 |
+
for _ in range(batch_size)
|
| 486 |
+
]
|
| 487 |
+
self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)
|
| 488 |
+
|
| 489 |
+
if not isinstance(num_beams, int) or num_beams <= 1:
|
| 490 |
+
raise ValueError(
|
| 491 |
+
f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
|
| 492 |
+
" one should make use of `greedy_search` instead."
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
|
| 496 |
+
raise ValueError(
|
| 497 |
+
"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
|
| 498 |
+
f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
|
| 499 |
+
)
|
| 500 |
+
|
| 501 |
+
@property
|
| 502 |
+
def is_done(self) -> bool:
|
| 503 |
+
return self._done.all()
|
| 504 |
+
|
| 505 |
+
def make_constraint_states(self, n):
|
| 506 |
+
return [ConstraintListState([constraint.copy() for constraint in self.constraints]) for _ in range(n)]
|
| 507 |
+
|
| 508 |
+
def check_completes_constraints(self, sequence):
|
| 509 |
+
new_state = self.make_constraint_states(1)[0]
|
| 510 |
+
new_state.reset(sequence)
|
| 511 |
+
return new_state.completed
|
| 512 |
+
|
| 513 |
+
def process(
|
| 514 |
+
self,
|
| 515 |
+
input_ids: torch.LongTensor,
|
| 516 |
+
next_scores: torch.FloatTensor,
|
| 517 |
+
next_tokens: torch.LongTensor,
|
| 518 |
+
next_indices: torch.LongTensor,
|
| 519 |
+
scores_for_all_vocab: torch.FloatTensor,
|
| 520 |
+
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
|
| 521 |
+
eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
|
| 522 |
+
beam_indices: Optional[torch.LongTensor] = None,
|
| 523 |
+
decoder_prompt_len: Optional[int] = 0,
|
| 524 |
+
) -> Tuple[torch.Tensor]:
|
| 525 |
+
r"""
|
| 526 |
+
Args:
|
| 527 |
+
input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
|
| 528 |
+
Indices of input sequence tokens in the vocabulary.
|
| 529 |
+
|
| 530 |
+
Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
|
| 531 |
+
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
|
| 532 |
+
|
| 533 |
+
[What are input IDs?](../glossary#input-ids)
|
| 534 |
+
next_scores (`torch.FloatTensor` of shape `(batch_size, 2 * num_beams)`):
|
| 535 |
+
Current scores of the top `2 * num_beams` non-finished beam hypotheses.
|
| 536 |
+
next_tokens (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
|
| 537 |
+
`input_ids` of the tokens corresponding to the top `2 * num_beams` non-finished beam hypotheses.
|
| 538 |
+
next_indices (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
|
| 539 |
+
Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
|
| 540 |
+
scores_for_all_vocab (`torch.FloatTensor` of shape `(batch_size * num_beams, sequence_length)`):
|
| 541 |
+
The scores of all tokens in the vocabulary for each of the beam hypotheses.
|
| 542 |
+
pad_token_id (`int`, *optional*):
|
| 543 |
+
The id of the *padding* token.
|
| 544 |
+
eos_token_id (`Union[int, List[int]]`, *optional*):
|
| 545 |
+
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
| 546 |
+
beam_indices (`torch.LongTensor`, *optional*):
|
| 547 |
+
Beam indices indicating to which beam hypothesis each token correspond.
|
| 548 |
+
decoder_prompt_len (`int`, *optional*):
|
| 549 |
+
The length of prompt that is included in the input to decoder.
|
| 550 |
+
Return:
|
| 551 |
+
`UserDict`: A dictionary composed of the fields as defined above:
|
| 552 |
+
|
| 553 |
+
- **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of
|
| 554 |
+
all
|
| 555 |
+
non-finished beams.
|
| 556 |
+
|
| 557 |
+
- **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be
|
| 558 |
+
added
|
| 559 |
+
to the non-finished beam_hypotheses.
|
| 560 |
+
- **next_beam_indices** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Beam indices
|
| 561 |
+
indicating to which beam the next tokens shall be added.
|
| 562 |
+
"""
|
| 563 |
+
|
| 564 |
+
# add up to the length which the next_scores is calculated on (including decoder prompt)
|
| 565 |
+
cur_len = input_ids.shape[-1] + 1
|
| 566 |
+
batch_size = len(self._beam_hyps)
|
| 567 |
+
if not (batch_size == (input_ids.shape[0] // self.group_size)):
|
| 568 |
+
if self.num_beam_groups > 1:
|
| 569 |
+
raise ValueError(
|
| 570 |
+
f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
|
| 571 |
+
f"size of {self.group_size} is expected by the beam scorer."
|
| 572 |
+
)
|
| 573 |
+
else:
|
| 574 |
+
raise ValueError(
|
| 575 |
+
f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
|
| 576 |
+
f"{self.group_size} is expected by the beam scorer."
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
device = input_ids.device
|
| 580 |
+
|
| 581 |
+
next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
|
| 582 |
+
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
|
| 583 |
+
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
|
| 584 |
+
|
| 585 |
+
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
|
| 586 |
+
if isinstance(eos_token_id, int):
|
| 587 |
+
eos_token_id = [eos_token_id]
|
| 588 |
+
eos_token_id = torch.tensor(eos_token_id)
|
| 589 |
+
|
| 590 |
+
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
|
| 591 |
+
if self._done[batch_idx]:
|
| 592 |
+
if self.num_beams < len(beam_hyp):
|
| 593 |
+
raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
|
| 594 |
+
if eos_token_id is None or pad_token_id is None:
|
| 595 |
+
raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
|
| 596 |
+
# pad the batch
|
| 597 |
+
next_beam_scores[batch_idx, :] = 0
|
| 598 |
+
next_beam_tokens[batch_idx, :] = pad_token_id
|
| 599 |
+
next_beam_indices[batch_idx, :] = 0
|
| 600 |
+
continue
|
| 601 |
+
|
| 602 |
+
# next tokens for this sentence.
|
| 603 |
+
beam_idx = 0
|
| 604 |
+
for beam_token_rank, (next_token, next_score, next_index) in enumerate(
|
| 605 |
+
zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
|
| 606 |
+
):
|
| 607 |
+
batch_beam_idx = batch_idx * self.group_size + next_index
|
| 608 |
+
# add to generated hypotheses if end of sentence
|
| 609 |
+
if (eos_token_id is not None) and (next_token.item() in eos_token_id):
|
| 610 |
+
# if beam_token does not belong to top num_beams tokens, it should not be added
|
| 611 |
+
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
|
| 612 |
+
if is_beam_token_worse_than_top_num_beams:
|
| 613 |
+
continue
|
| 614 |
+
|
| 615 |
+
completes_constraint = self.check_completes_constraints(input_ids[batch_beam_idx].cpu().tolist())
|
| 616 |
+
if completes_constraint:
|
| 617 |
+
if beam_indices is not None:
|
| 618 |
+
beam_index = beam_indices[batch_beam_idx]
|
| 619 |
+
beam_index = beam_index + (batch_beam_idx,)
|
| 620 |
+
else:
|
| 621 |
+
beam_index = None
|
| 622 |
+
|
| 623 |
+
beam_hyp.add(
|
| 624 |
+
input_ids[batch_beam_idx].clone(),
|
| 625 |
+
next_score.item(),
|
| 626 |
+
beam_indices=beam_index,
|
| 627 |
+
generated_len=cur_len - decoder_prompt_len,
|
| 628 |
+
)
|
| 629 |
+
else:
|
| 630 |
+
# add next predicted token since it is not eos_token
|
| 631 |
+
next_beam_scores[batch_idx, beam_idx] = next_score
|
| 632 |
+
next_beam_tokens[batch_idx, beam_idx] = next_token
|
| 633 |
+
next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
|
| 634 |
+
beam_idx += 1
|
| 635 |
+
|
| 636 |
+
# once the beam for next step is full, don't add more tokens to it.
|
| 637 |
+
if beam_idx == self.group_size:
|
| 638 |
+
break
|
| 639 |
+
|
| 640 |
+
new_scores, new_tokens, new_indices = self.step_sentence_constraint(
|
| 641 |
+
batch_idx,
|
| 642 |
+
input_ids,
|
| 643 |
+
scores_for_all_vocab,
|
| 644 |
+
next_beam_scores[batch_idx],
|
| 645 |
+
next_beam_tokens[batch_idx],
|
| 646 |
+
next_beam_indices[batch_idx],
|
| 647 |
+
)
|
| 648 |
+
|
| 649 |
+
next_beam_scores[batch_idx] = new_scores
|
| 650 |
+
next_beam_tokens[batch_idx] = new_tokens
|
| 651 |
+
next_beam_indices[batch_idx] = new_indices
|
| 652 |
+
|
| 653 |
+
if beam_idx < self.group_size:
|
| 654 |
+
raise ValueError(
|
| 655 |
+
f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
|
| 656 |
+
f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
|
| 657 |
+
)
|
| 658 |
+
|
| 659 |
+
# Check if we are done so that we can save a pad step if all(done)
|
| 660 |
+
self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
|
| 661 |
+
next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
|
| 662 |
+
)
|
| 663 |
+
|
| 664 |
+
return UserDict(
|
| 665 |
+
{
|
| 666 |
+
"next_beam_scores": next_beam_scores.view(-1),
|
| 667 |
+
"next_beam_tokens": next_beam_tokens.view(-1),
|
| 668 |
+
"next_beam_indices": next_beam_indices.view(-1),
|
| 669 |
+
}
|
| 670 |
+
)
|
| 671 |
+
|
| 672 |
+
def step_sentence_constraint(
|
| 673 |
+
self,
|
| 674 |
+
batch_idx: int,
|
| 675 |
+
input_ids: torch.LongTensor,
|
| 676 |
+
vocab_scores: torch.FloatTensor,
|
| 677 |
+
sent_beam_scores: torch.FloatTensor,
|
| 678 |
+
sent_beam_tokens: torch.LongTensor,
|
| 679 |
+
sent_beam_indices: torch.LongTensor,
|
| 680 |
+
push_progress: bool = False,
|
| 681 |
+
):
|
| 682 |
+
# sent_beam_tokens are the next {num_beams} number of tokens that are under consideration for this beam
|
| 683 |
+
# (candidate next tokens)
|
| 684 |
+
|
| 685 |
+
# 1. Adding "advance_tokens"
|
| 686 |
+
# using ConstraintStateList.advance(), we propose new tokens to be added into this "candidate list" that will
|
| 687 |
+
# advance us in fulfilling the constraints.
|
| 688 |
+
|
| 689 |
+
# 2. Selecting best candidates such that we end up with highest probable candidates
|
| 690 |
+
# that fulfill our constraints.
|
| 691 |
+
|
| 692 |
+
orig_len = sent_beam_indices.size(0)
|
| 693 |
+
device = sent_beam_indices.device
|
| 694 |
+
|
| 695 |
+
# initialize states
|
| 696 |
+
topk_contraint_states = self.make_constraint_states(orig_len)
|
| 697 |
+
advance_constraint_states = self.make_constraint_states(orig_len)
|
| 698 |
+
|
| 699 |
+
sidx, eidx = batch_idx * orig_len, (batch_idx + 1) * orig_len
|
| 700 |
+
this_batch_input_ids = input_ids[sidx:eidx]
|
| 701 |
+
this_batch_token_scores = vocab_scores[sidx:eidx]
|
| 702 |
+
full_hypotheses = torch.cat((input_ids[sent_beam_indices], sent_beam_tokens.unsqueeze(-1)), dim=-1)
|
| 703 |
+
|
| 704 |
+
# need to make new hypothesis that advance the constraints
|
| 705 |
+
track_new = {
|
| 706 |
+
"new_seqs": full_hypotheses.tolist(),
|
| 707 |
+
"new_states": [],
|
| 708 |
+
"new_indices": [],
|
| 709 |
+
"new_tokens": [],
|
| 710 |
+
"new_scores": [],
|
| 711 |
+
}
|
| 712 |
+
for seq_idx, pre_seq in enumerate(this_batch_input_ids):
|
| 713 |
+
# pre_seq = ith sequence generated before this step.
|
| 714 |
+
|
| 715 |
+
# input_ids -> (topk) generic beam search best model next tokens
|
| 716 |
+
# -> (advance) constraints forcing the next token
|
| 717 |
+
# either way, we need to sort them into "banks" later, so store a "ConstraintListState" for all types of
|
| 718 |
+
# hypotheses.
|
| 719 |
+
|
| 720 |
+
topk_state = topk_contraint_states[seq_idx]
|
| 721 |
+
topk_state.reset(full_hypotheses[seq_idx].cpu().tolist())
|
| 722 |
+
|
| 723 |
+
advance_state = advance_constraint_states[seq_idx]
|
| 724 |
+
advance_state.reset(pre_seq.cpu().tolist())
|
| 725 |
+
|
| 726 |
+
if not advance_state.completed:
|
| 727 |
+
advance_tokens = torch.LongTensor(advance_state.advance()).to(device)
|
| 728 |
+
for advance_token in advance_tokens:
|
| 729 |
+
# since adding each `advance_token` leads to a different hypothesis, create new state instance.
|
| 730 |
+
new_state = advance_state.copy(stateful=True)
|
| 731 |
+
new_state.add(advance_token.cpu().tolist())
|
| 732 |
+
|
| 733 |
+
advance_seq = torch.cat((pre_seq, advance_token.unsqueeze(0)), -1).cpu().tolist()
|
| 734 |
+
if advance_seq not in track_new["new_seqs"]:
|
| 735 |
+
# prevent duplicates, which are basically bound to happen in this process.
|
| 736 |
+
track_new["new_seqs"].append(advance_seq)
|
| 737 |
+
track_new["new_indices"].append(sidx + seq_idx) # idx -> global idx across all the batches
|
| 738 |
+
track_new["new_tokens"].append(advance_token)
|
| 739 |
+
track_new["new_scores"].append(this_batch_token_scores[seq_idx].take(advance_token))
|
| 740 |
+
track_new["new_states"].append(new_state)
|
| 741 |
+
elif push_progress:
|
| 742 |
+
# Basically, `sent_beam_indices` often chooses very little among `input_ids` the generated sequences that
|
| 743 |
+
# actually fulfill our constraints. For example, let constraints == ["loves pies"] and
|
| 744 |
+
|
| 745 |
+
# pre_seq_1 = "The child loves pies and" pre_seq_2 = "The child plays in the playground and"
|
| 746 |
+
|
| 747 |
+
# Without this step, if `sent_beam_indices` is something like [1,1], then
|
| 748 |
+
# 1. `pre_seq_1` won't be added to the list of (topk) hypothesis since it's not in the indices and
|
| 749 |
+
# 2. it won't be added to the list of (advance) hypothesis since it's completed already. (this is
|
| 750 |
+
# the else part of `if constraints_completed[seq_idx]`)
|
| 751 |
+
# 3. it ends up simply getting removed from consideration.
|
| 752 |
+
|
| 753 |
+
# #3 might be fine and actually desired, since it's likely that it's a low-probability output anyways,
|
| 754 |
+
# especially if it's not in the list of `sent_beam_indices`. But this often leads to lengthened beam
|
| 755 |
+
# search times, since completed sequences keep getting removed after all this effort for constrained
|
| 756 |
+
# generation.
|
| 757 |
+
|
| 758 |
+
# Here, we basically take `pre_seq_1` and to "push" it into the considered list of hypotheses, by simply
|
| 759 |
+
# appending the next likely token in the vocabulary and adding it to the list of hypotheses.
|
| 760 |
+
|
| 761 |
+
new_score, new_token = torch.max(this_batch_token_scores[seq_idx], 0) # some next probable token
|
| 762 |
+
advance_seq = torch.cat((pre_seq, new_token.unsqueeze(0)), -1)
|
| 763 |
+
|
| 764 |
+
advance_state = advance_constraint_states[seq_idx]
|
| 765 |
+
|
| 766 |
+
advance_seq = advance_seq.cpu().tolist()
|
| 767 |
+
|
| 768 |
+
advance_state.reset(advance_seq)
|
| 769 |
+
if advance_seq not in track_new["new_seqs"]:
|
| 770 |
+
# but still don't want to have duplicates
|
| 771 |
+
track_new["new_seqs"].append(advance_seq)
|
| 772 |
+
track_new["new_indices"].append(seq_idx)
|
| 773 |
+
track_new["new_tokens"].append(new_token)
|
| 774 |
+
track_new["new_scores"].append(new_score)
|
| 775 |
+
track_new["new_states"].append(advance_state)
|
| 776 |
+
|
| 777 |
+
if len(track_new["new_indices"]) > 0:
|
| 778 |
+
new_indices = torch.tensor(track_new["new_indices"]).to(device)
|
| 779 |
+
new_tokens = torch.stack(track_new["new_tokens"]).to(device)
|
| 780 |
+
new_scores = torch.stack(track_new["new_scores"]).to(device)
|
| 781 |
+
|
| 782 |
+
all_states = topk_contraint_states + track_new["new_states"]
|
| 783 |
+
all_tokens = torch.cat((sent_beam_tokens, new_tokens), -1)
|
| 784 |
+
all_scores = torch.cat((sent_beam_scores, new_scores), -1)
|
| 785 |
+
all_banks = torch.tensor([one.get_bank() for one in all_states]).to(device)
|
| 786 |
+
|
| 787 |
+
zipped = all_banks * 100 + all_scores
|
| 788 |
+
indices = zipped.sort(descending=True).indices
|
| 789 |
+
sorted_banks = all_banks[indices]
|
| 790 |
+
|
| 791 |
+
# Then we end up with {sorted among bank C}, {sorted among bank C-1}, ..., {sorted among bank 0}
|
| 792 |
+
|
| 793 |
+
counter = -1
|
| 794 |
+
cur_bank = sorted_banks[0]
|
| 795 |
+
increments = []
|
| 796 |
+
for bank in sorted_banks:
|
| 797 |
+
if bank == cur_bank:
|
| 798 |
+
counter += 1
|
| 799 |
+
else:
|
| 800 |
+
counter = 0
|
| 801 |
+
cur_bank = bank
|
| 802 |
+
increments.append(counter)
|
| 803 |
+
rearrangers = torch.tensor(np.argsort(increments, kind="mergesort"))
|
| 804 |
+
|
| 805 |
+
indices = indices[rearrangers][:orig_len]
|
| 806 |
+
|
| 807 |
+
sent_beam_scores = all_scores[indices]
|
| 808 |
+
sent_beam_tokens = all_tokens[indices]
|
| 809 |
+
sent_beam_indices = torch.cat((sent_beam_indices, new_indices))[indices]
|
| 810 |
+
|
| 811 |
+
return sent_beam_scores, sent_beam_tokens, sent_beam_indices
|
| 812 |
+
|
| 813 |
+
def finalize(
|
| 814 |
+
self,
|
| 815 |
+
input_ids: torch.LongTensor,
|
| 816 |
+
final_beam_scores: torch.FloatTensor,
|
| 817 |
+
final_beam_tokens: torch.LongTensor,
|
| 818 |
+
final_beam_indices: torch.LongTensor,
|
| 819 |
+
max_length: int,
|
| 820 |
+
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
|
| 821 |
+
eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
|
| 822 |
+
beam_indices: Optional[torch.LongTensor] = None,
|
| 823 |
+
decoder_prompt_len: Optional[int] = 0,
|
| 824 |
+
) -> Tuple[torch.LongTensor]:
|
| 825 |
+
batch_size = len(self._beam_hyps)
|
| 826 |
+
|
| 827 |
+
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
|
| 828 |
+
if isinstance(eos_token_id, int):
|
| 829 |
+
eos_token_id = [eos_token_id]
|
| 830 |
+
eos_token_id = torch.tensor(eos_token_id)
|
| 831 |
+
|
| 832 |
+
# finalize all open beam hypotheses and add to generated hypotheses
|
| 833 |
+
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
|
| 834 |
+
if self._done[batch_idx]:
|
| 835 |
+
continue
|
| 836 |
+
|
| 837 |
+
# all open beam hypotheses are added to the beam hypothesis
|
| 838 |
+
# beam hypothesis class automatically keeps the best beams
|
| 839 |
+
|
| 840 |
+
ids_collect = []
|
| 841 |
+
for beam_id in range(self.num_beams):
|
| 842 |
+
batch_beam_idx = batch_idx * self.num_beams + beam_id
|
| 843 |
+
final_score = final_beam_scores[batch_beam_idx].item()
|
| 844 |
+
final_tokens = input_ids[batch_beam_idx]
|
| 845 |
+
|
| 846 |
+
completes_constraint = self.check_completes_constraints(final_tokens.cpu().tolist())
|
| 847 |
+
if completes_constraint:
|
| 848 |
+
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
|
| 849 |
+
generated_len = final_tokens.shape[-1] - decoder_prompt_len
|
| 850 |
+
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
|
| 851 |
+
ids_collect.append(beam_id)
|
| 852 |
+
|
| 853 |
+
# due to overly complex constraints or other factors, sometimes we can't gaurantee a successful
|
| 854 |
+
# generation. In these cases we simply return the highest scoring outputs.
|
| 855 |
+
if len(ids_collect) < self.num_beam_hyps_to_keep:
|
| 856 |
+
for beam_id in range(self.num_beams):
|
| 857 |
+
if beam_id not in ids_collect:
|
| 858 |
+
batch_beam_idx = batch_idx * self.num_beams + beam_id
|
| 859 |
+
final_score = final_beam_scores[batch_beam_idx].item()
|
| 860 |
+
final_tokens = input_ids[batch_beam_idx]
|
| 861 |
+
generated_len = final_tokens.shape[-1] - decoder_prompt_len
|
| 862 |
+
beam_hyp.add(final_tokens, final_score, generated_len=generated_len)
|
| 863 |
+
if len(ids_collect) >= self.num_beam_hyps_to_keep:
|
| 864 |
+
break
|
| 865 |
+
|
| 866 |
+
# select the best hypotheses
|
| 867 |
+
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
|
| 868 |
+
best = []
|
| 869 |
+
best_indices = []
|
| 870 |
+
best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
|
| 871 |
+
|
| 872 |
+
# retrieve best hypotheses
|
| 873 |
+
for i, beam_hyp in enumerate(self._beam_hyps):
|
| 874 |
+
sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
|
| 875 |
+
for j in range(self.num_beam_hyps_to_keep):
|
| 876 |
+
best_hyp_tuple = sorted_hyps.pop()
|
| 877 |
+
best_score = best_hyp_tuple[0]
|
| 878 |
+
best_hyp = best_hyp_tuple[1]
|
| 879 |
+
best_index = best_hyp_tuple[2]
|
| 880 |
+
sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
|
| 881 |
+
|
| 882 |
+
# append to lists
|
| 883 |
+
best.append(best_hyp)
|
| 884 |
+
|
| 885 |
+
# append indices to list
|
| 886 |
+
best_indices.append(best_index)
|
| 887 |
+
|
| 888 |
+
best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
|
| 889 |
+
|
| 890 |
+
# prepare for adding eos
|
| 891 |
+
sent_lengths_max = sent_lengths.max().item() + 1
|
| 892 |
+
|
| 893 |
+
sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
|
| 894 |
+
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
|
| 895 |
+
|
| 896 |
+
if len(best_indices) > 0 and best_indices[0] is not None:
|
| 897 |
+
indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
|
| 898 |
+
else:
|
| 899 |
+
indices = None
|
| 900 |
+
|
| 901 |
+
# shorter batches are padded if needed
|
| 902 |
+
if sent_lengths.min().item() != sent_lengths.max().item():
|
| 903 |
+
if pad_token_id is None:
|
| 904 |
+
raise ValueError("`pad_token_id` has to be defined")
|
| 905 |
+
decoded.fill_(pad_token_id)
|
| 906 |
+
|
| 907 |
+
if indices is not None:
|
| 908 |
+
indices.fill_(-1)
|
| 909 |
+
|
| 910 |
+
# fill with hypotheses and eos_token_id if the latter fits in
|
| 911 |
+
for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
|
| 912 |
+
decoded[i, : sent_lengths[i]] = hypo
|
| 913 |
+
|
| 914 |
+
if indices is not None:
|
| 915 |
+
indices[i, : len(best_idx)] = torch.tensor(best_idx)
|
| 916 |
+
|
| 917 |
+
if sent_lengths[i] < sent_max_len:
|
| 918 |
+
# inserting only the first eos_token_id
|
| 919 |
+
decoded[i, sent_lengths[i]] = eos_token_id[0]
|
| 920 |
+
|
| 921 |
+
return UserDict(
|
| 922 |
+
{
|
| 923 |
+
"sequences": decoded,
|
| 924 |
+
"sequence_scores": best_scores,
|
| 925 |
+
"beam_indices": indices,
|
| 926 |
+
}
|
| 927 |
+
)
|
| 928 |
+
|
| 929 |
+
|
| 930 |
+
class BeamHypotheses:
|
| 931 |
+
def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool, max_length: Optional[int] = None):
|
| 932 |
+
"""
|
| 933 |
+
Initialize n-best list of hypotheses.
|
| 934 |
+
"""
|
| 935 |
+
self.length_penalty = length_penalty
|
| 936 |
+
self.early_stopping = early_stopping
|
| 937 |
+
self.max_length = max_length
|
| 938 |
+
self.num_beams = num_beams
|
| 939 |
+
self.beams = []
|
| 940 |
+
self.worst_score = 1e9
|
| 941 |
+
|
| 942 |
+
if not isinstance(self.early_stopping, bool) and self.max_length is None:
|
| 943 |
+
raise ValueError(
|
| 944 |
+
"When `do_early_stopping` is set to a string, `max_length` must be defined. Ensure it is passed to the"
|
| 945 |
+
" BeamScorer class instance at initialization time."
|
| 946 |
+
)
|
| 947 |
+
|
| 948 |
+
def __len__(self):
|
| 949 |
+
"""
|
| 950 |
+
Number of hypotheses in the list.
|
| 951 |
+
"""
|
| 952 |
+
return len(self.beams)
|
| 953 |
+
|
| 954 |
+
def add(
|
| 955 |
+
self,
|
| 956 |
+
hyp: torch.LongTensor,
|
| 957 |
+
sum_logprobs: float,
|
| 958 |
+
beam_indices: Optional[torch.LongTensor] = None,
|
| 959 |
+
generated_len: Optional[int] = None,
|
| 960 |
+
):
|
| 961 |
+
"""
|
| 962 |
+
Add a new hypothesis to the list.
|
| 963 |
+
"""
|
| 964 |
+
if generated_len is not None:
|
| 965 |
+
score = sum_logprobs / (generated_len**self.length_penalty)
|
| 966 |
+
# This 'else' case exists for retrocompatibility
|
| 967 |
+
else:
|
| 968 |
+
score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
|
| 969 |
+
|
| 970 |
+
if len(self) < self.num_beams or score > self.worst_score:
|
| 971 |
+
self.beams.append((score, hyp, beam_indices))
|
| 972 |
+
if len(self) > self.num_beams:
|
| 973 |
+
sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
|
| 974 |
+
del self.beams[sorted_next_scores[0][1]]
|
| 975 |
+
self.worst_score = sorted_next_scores[1][0]
|
| 976 |
+
else:
|
| 977 |
+
self.worst_score = min(score, self.worst_score)
|
| 978 |
+
|
| 979 |
+
def is_done(self, best_sum_logprobs: float, cur_len: int, decoder_prompt_len: Optional[int] = 0) -> bool:
|
| 980 |
+
"""
|
| 981 |
+
If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst
|
| 982 |
+
one in the heap, then we are done with this sentence.
|
| 983 |
+
"""
|
| 984 |
+
|
| 985 |
+
if len(self) < self.num_beams:
|
| 986 |
+
return False
|
| 987 |
+
|
| 988 |
+
# `True`: stop as soon as at least `num_beams` hypotheses are finished
|
| 989 |
+
if self.early_stopping is True:
|
| 990 |
+
return True
|
| 991 |
+
# `False`: heuristic -- compute best possible score from `cur_len`, even though it is not entirely accurate
|
| 992 |
+
# when `length_penalty` is positive. See the discussion below for more details.
|
| 993 |
+
# https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
|
| 994 |
+
elif self.early_stopping is False:
|
| 995 |
+
highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
|
| 996 |
+
ret = self.worst_score >= highest_attainable_score
|
| 997 |
+
return ret
|
| 998 |
+
# `"never"`: compute the best possible score, depending on the signal of `length_penalty`
|
| 999 |
+
else:
|
| 1000 |
+
# `length_penalty` > 0.0 -> max denominator is obtaned from `max_length`, not from `cur_len` -> min
|
| 1001 |
+
# abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain
|
| 1002 |
+
# its max this way
|
| 1003 |
+
if self.length_penalty > 0.0:
|
| 1004 |
+
if self.max_length <= decoder_prompt_len:
|
| 1005 |
+
raise ValueError("max_length is not larger than decoder prompt length")
|
| 1006 |
+
highest_attainable_score = (
|
| 1007 |
+
best_sum_logprobs / (self.max_length - decoder_prompt_len) ** self.length_penalty
|
| 1008 |
+
)
|
| 1009 |
+
# the opposite logic applies here (max `highest_attainable_score` from `cur_len`)
|
| 1010 |
+
else:
|
| 1011 |
+
highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
|
| 1012 |
+
ret = self.worst_score >= highest_attainable_score
|
| 1013 |
+
return ret
|
.venv/lib/python3.11/site-packages/transformers/generation/candidate_generator.py
ADDED
|
@@ -0,0 +1,871 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import copy
|
| 17 |
+
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from ..utils import is_sklearn_available
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
if is_sklearn_available():
|
| 26 |
+
from sklearn.metrics import roc_curve
|
| 27 |
+
|
| 28 |
+
from ..cache_utils import DynamicCache
|
| 29 |
+
from ..pytorch_utils import isin_mps_friendly
|
| 30 |
+
from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
if TYPE_CHECKING:
|
| 34 |
+
from ..modeling_utils import PreTrainedModel
|
| 35 |
+
from ..tokenization_utils_base import PreTrainedTokenizerBase
|
| 36 |
+
from .configuration_utils import GenerationConfig
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class CandidateGenerator:
|
| 40 |
+
"""Abstract base class for all candidate generators that can be applied during assisted generation."""
|
| 41 |
+
|
| 42 |
+
def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
|
| 43 |
+
"""
|
| 44 |
+
Fetches the candidates to be tried for the current input.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 48 |
+
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
|
| 49 |
+
|
| 50 |
+
Return:
|
| 51 |
+
`torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
|
| 52 |
+
assessed by the model and, optionally, a `torch.FloatTensor` of shape `(batch_size, candidate_length,
|
| 53 |
+
vocabulary_size)` containing the logits associated to each candidate.
|
| 54 |
+
"""
|
| 55 |
+
raise NotImplementedError(
|
| 56 |
+
f"{self.__class__} is an abstract class. Only classes inheriting this class can call `get_candidates`."
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
|
| 60 |
+
"""
|
| 61 |
+
Updates the candidate generation strategy based on the outcomes.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 65 |
+
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
|
| 66 |
+
scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
|
| 67 |
+
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
|
| 68 |
+
beam search or log softmax for each vocabulary token when using beam search
|
| 69 |
+
num_matches (`int`):
|
| 70 |
+
The number of matches between the candidate sequences and the model predictions.
|
| 71 |
+
"""
|
| 72 |
+
raise NotImplementedError(
|
| 73 |
+
f"{self.__class__} is an abstract class. Only classes inheriting this class can call "
|
| 74 |
+
"`update_candidate_strategy`."
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class AssistedCandidateGenerator(CandidateGenerator):
|
| 79 |
+
"""
|
| 80 |
+
`CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates
|
| 81 |
+
candidates through the use of a smaller model. Read the following blog post for more information:
|
| 82 |
+
https://huggingface.co/blog/assisted-generation
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 86 |
+
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
|
| 87 |
+
assistant_model (`PreTrainedModel`):
|
| 88 |
+
The model to be used for generating candidates. This model should be smaller than the main model.
|
| 89 |
+
generation_config (`~generation.GenerationConfig`, *optional*):
|
| 90 |
+
The generation configuration to be used as base parametrization for the generation call.
|
| 91 |
+
logits_processor (`LogitsProcessorList`):
|
| 92 |
+
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
|
| 93 |
+
used to modify the prediction scores of the language modeling head applied at each generation step.
|
| 94 |
+
model_kwargs (`Dict`):
|
| 95 |
+
The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant
|
| 96 |
+
model as well.
|
| 97 |
+
inputs_tensor (`torch.Tensor`, *optional*):
|
| 98 |
+
The model input tensor. In encoder-decoder models, this is the encoder input.
|
| 99 |
+
"""
|
| 100 |
+
|
| 101 |
+
def __init__(
|
| 102 |
+
self,
|
| 103 |
+
input_ids: torch.LongTensor,
|
| 104 |
+
assistant_model: "PreTrainedModel",
|
| 105 |
+
generation_config: "GenerationConfig",
|
| 106 |
+
model_kwargs: Dict,
|
| 107 |
+
inputs_tensor: Optional[torch.Tensor] = None,
|
| 108 |
+
logits_processor: "LogitsProcessorList" = None,
|
| 109 |
+
):
|
| 110 |
+
# Make sure all data at the same device as assistant model
|
| 111 |
+
device = assistant_model.device
|
| 112 |
+
input_ids = input_ids.to(device)
|
| 113 |
+
if inputs_tensor is not None:
|
| 114 |
+
inputs_tensor = inputs_tensor.to(device)
|
| 115 |
+
|
| 116 |
+
# Prepare the assistant and the starting number of candidate tokens
|
| 117 |
+
self.assistant_model = assistant_model
|
| 118 |
+
self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
|
| 119 |
+
self.assistant_confidence_threshold = assistant_model.generation_config.assistant_confidence_threshold
|
| 120 |
+
|
| 121 |
+
# Set eos in assistant same as in target model
|
| 122 |
+
self.assistant_model.generation_config.eos_token_id = generation_config.eos_token_id
|
| 123 |
+
|
| 124 |
+
# Prepare the kwargs for the assistant model
|
| 125 |
+
assistant_kwargs = {}
|
| 126 |
+
for key, value in model_kwargs.items(): # deepcopy crashes if we attempt to copy encoder outputs with grads
|
| 127 |
+
if key not in ("encoder_outputs", "assistant_encoder_outputs", "past_key_values"):
|
| 128 |
+
assistant_kwargs[key] = (
|
| 129 |
+
value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
# Remove potential default "num_logits_to_keep" key
|
| 133 |
+
if "num_logits_to_keep" in assistant_kwargs.keys() and not assistant_model._supports_num_logits_to_keep():
|
| 134 |
+
del assistant_kwargs["num_logits_to_keep"]
|
| 135 |
+
|
| 136 |
+
if "assistant_encoder_outputs" in model_kwargs:
|
| 137 |
+
assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
|
| 138 |
+
elif assistant_model.config.is_encoder_decoder:
|
| 139 |
+
inputs_tensor, model_input_name, assistant_kwargs = assistant_model._prepare_model_inputs(
|
| 140 |
+
inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_kwargs
|
| 141 |
+
)
|
| 142 |
+
assistant_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation(
|
| 143 |
+
inputs_tensor, assistant_kwargs, model_input_name, assistant_model.generation_config
|
| 144 |
+
)
|
| 145 |
+
elif "encoder_outputs" in model_kwargs:
|
| 146 |
+
assistant_kwargs["encoder_outputs"] = model_kwargs["encoder_outputs"]
|
| 147 |
+
self.assistant_kwargs = assistant_kwargs
|
| 148 |
+
|
| 149 |
+
# Prepare assistant model's keys of inputs
|
| 150 |
+
if assistant_model.config.is_encoder_decoder:
|
| 151 |
+
# both are encoder-decoder
|
| 152 |
+
self.input_ids_key = "decoder_input_ids"
|
| 153 |
+
elif "encoder_outputs" in assistant_kwargs:
|
| 154 |
+
# special case for encoder-decoder with decoder-only assistant (like DistilWhisper)
|
| 155 |
+
self.input_ids_key = "input_ids"
|
| 156 |
+
self.assistant_kwargs["attention_mask"] = self.assistant_kwargs.get(
|
| 157 |
+
"decoder_attention_mask",
|
| 158 |
+
torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long),
|
| 159 |
+
)
|
| 160 |
+
else:
|
| 161 |
+
# both are decoder-only
|
| 162 |
+
self.input_ids_key = "input_ids"
|
| 163 |
+
|
| 164 |
+
# Prepare generation-related options.
|
| 165 |
+
self.logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
| 166 |
+
self.generation_config = copy.deepcopy(generation_config)
|
| 167 |
+
|
| 168 |
+
self.generation_config.return_dict_in_generate = True
|
| 169 |
+
self.generation_config.output_scores = True
|
| 170 |
+
self.generation_config.assistant_confidence_threshold = self.assistant_confidence_threshold
|
| 171 |
+
# this flag allow us set the confidence stopping criteria for assistant model generation.
|
| 172 |
+
self.generation_config.is_assistant = True
|
| 173 |
+
|
| 174 |
+
# avoid unnecessary warnings that min_length is larger than max_new_tokens
|
| 175 |
+
# remove the `MinLengthLogitsProcessor` if exists (NOTE: no need to check for `MinNewTokensLogitsProcessor`)
|
| 176 |
+
self.main_model_min_length = self.generation_config.min_length
|
| 177 |
+
self.generation_config.min_length = 0
|
| 178 |
+
self.generation_config.min_new_tokens = None
|
| 179 |
+
for processor in self.logits_processor:
|
| 180 |
+
if isinstance(processor, MinLengthLogitsProcessor):
|
| 181 |
+
raise ValueError(
|
| 182 |
+
"Passing `MinLengthLogitsProcessor` when using `assisted_generation is disabled. "
|
| 183 |
+
"Please pass in `min_length` into `.generate()` instead"
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
# We need to roll back the cache in assisted generation, only DynamicCache is supported
|
| 187 |
+
self.generation_config.cache_implementation = None
|
| 188 |
+
|
| 189 |
+
if (
|
| 190 |
+
is_sklearn_available()
|
| 191 |
+
and self.assistant_model.generation_config.assistant_confidence_threshold
|
| 192 |
+
and type(self) is AssistedCandidateGenerator
|
| 193 |
+
):
|
| 194 |
+
self.probs = []
|
| 195 |
+
self.matches = []
|
| 196 |
+
|
| 197 |
+
def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
|
| 198 |
+
"""
|
| 199 |
+
Fetches the candidates to be tried for the current input.
|
| 200 |
+
|
| 201 |
+
Args:
|
| 202 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 203 |
+
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
|
| 204 |
+
|
| 205 |
+
Return:
|
| 206 |
+
`torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
|
| 207 |
+
assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
|
| 208 |
+
vocabulary_size)` containing the logits associated to each candidate.
|
| 209 |
+
"""
|
| 210 |
+
input_ids = input_ids.to(self.assistant_model.device)
|
| 211 |
+
# Calculate new tokens to generate
|
| 212 |
+
min_new_tokens, max_new_tokens = self._calculate_new_tokens(input_ids)
|
| 213 |
+
if max_new_tokens == 0:
|
| 214 |
+
return input_ids, None
|
| 215 |
+
# Update past key values and masks
|
| 216 |
+
self._update_past_and_masks(input_ids)
|
| 217 |
+
# Generate candidates
|
| 218 |
+
generation_args = self._prepare_generation_args(input_ids, min_new_tokens, max_new_tokens)
|
| 219 |
+
candidate_ids, candidate_logits = self._generate_candidates(generation_args)
|
| 220 |
+
return candidate_ids, candidate_logits
|
| 221 |
+
|
| 222 |
+
def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
|
| 223 |
+
"""
|
| 224 |
+
Updates the candidate generation strategy based on the outcomes.
|
| 225 |
+
|
| 226 |
+
Args:
|
| 227 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 228 |
+
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
|
| 229 |
+
scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
|
| 230 |
+
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
|
| 231 |
+
beam search or log softmax for each vocabulary token when using beam search
|
| 232 |
+
num_matches (`int`):
|
| 233 |
+
The number of matches between the candidate sequences and the model predictions.
|
| 234 |
+
"""
|
| 235 |
+
# Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
|
| 236 |
+
# probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
|
| 237 |
+
# cost of forecasting incorrect assistant tokens.
|
| 238 |
+
if self.assistant_model.generation_config.num_assistant_tokens_schedule in {
|
| 239 |
+
"heuristic",
|
| 240 |
+
"heuristic_transient",
|
| 241 |
+
}:
|
| 242 |
+
# len(scores[0])-1 is the number of candidates according to the target tokenizer.
|
| 243 |
+
if num_matches == len(scores[0]) - 1:
|
| 244 |
+
self.num_assistant_tokens += 2.0
|
| 245 |
+
else:
|
| 246 |
+
self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)
|
| 247 |
+
|
| 248 |
+
# The assistant's confidence threshold is adjusted throughout the speculative iterations to reduce the number of unnecessary draft and target forward passes. The costs are estimated based on the ROC curve, which considers the probability of the draft token and its match with the target. A cost of 25% is assigned to false positives and 75% to false negatives.
|
| 249 |
+
# This adaptation is not compatible with UAG, as it relies on the number of matched tokens based on the draft vocabulary, which is unavailable in UAG.
|
| 250 |
+
if (
|
| 251 |
+
is_sklearn_available()
|
| 252 |
+
and self.assistant_model.generation_config.assistant_confidence_threshold
|
| 253 |
+
and type(self) is AssistedCandidateGenerator
|
| 254 |
+
):
|
| 255 |
+
# update self.matches
|
| 256 |
+
self.matches.extend([1] * num_matches)
|
| 257 |
+
if len(self.probs) > len(self.matches):
|
| 258 |
+
self.matches.append(0)
|
| 259 |
+
|
| 260 |
+
# update self.probs
|
| 261 |
+
excess_length = len(self.probs) - len(self.matches)
|
| 262 |
+
if excess_length > 0:
|
| 263 |
+
del self.probs[-excess_length:]
|
| 264 |
+
|
| 265 |
+
if (
|
| 266 |
+
len(self.probs) > 5 and {0, 1}.issubset(self.matches)
|
| 267 |
+
): # require at least 5 samples to calculate the ROC curve and at least one positive and one negative sample
|
| 268 |
+
fpr, tpr, thresholds = roc_curve(self.matches, self.probs)
|
| 269 |
+
fnr = 1 - tpr
|
| 270 |
+
|
| 271 |
+
# Calculate the cost for each threshold
|
| 272 |
+
costs = fpr + 3 * fnr
|
| 273 |
+
|
| 274 |
+
# Find the threshold that minimizes the cost
|
| 275 |
+
optimal_threshold_index = np.argmin(costs)
|
| 276 |
+
best_threshold = thresholds[optimal_threshold_index]
|
| 277 |
+
|
| 278 |
+
self.assistant_model.generation_config.assistant_confidence_threshold = best_threshold
|
| 279 |
+
|
| 280 |
+
def _calculate_new_tokens(self, input_ids: torch.LongTensor) -> Tuple[int, int]:
|
| 281 |
+
"""Calculate the minimum and maximum number of new tokens to generate."""
|
| 282 |
+
new_cur_len = input_ids.shape[-1]
|
| 283 |
+
max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
|
| 284 |
+
min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
|
| 285 |
+
return min_new_tokens, max_new_tokens
|
| 286 |
+
|
| 287 |
+
def _update_past_and_masks(self, input_ids: torch.LongTensor, remove_from_pkv: int = 0) -> bool:
|
| 288 |
+
"""Update past key values and attention masks for subsequent generation rounds."""
|
| 289 |
+
has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
|
| 290 |
+
if has_past_key_values:
|
| 291 |
+
new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv
|
| 292 |
+
self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
|
| 293 |
+
self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
|
| 294 |
+
)
|
| 295 |
+
self.assistant_kwargs = _prepare_attention_mask(
|
| 296 |
+
self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
|
| 297 |
+
)
|
| 298 |
+
self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, input_ids.shape[-1])
|
| 299 |
+
return has_past_key_values
|
| 300 |
+
|
| 301 |
+
def _prepare_generation_args(self, input_ids: torch.LongTensor, min_new_tokens: int, max_new_tokens: int) -> Dict:
|
| 302 |
+
"""Prepare arguments for the generation call."""
|
| 303 |
+
return {
|
| 304 |
+
self.input_ids_key: input_ids,
|
| 305 |
+
"min_new_tokens": min_new_tokens,
|
| 306 |
+
"max_new_tokens": max_new_tokens,
|
| 307 |
+
"generation_config": self.generation_config,
|
| 308 |
+
"logits_processor": self.logits_processor,
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
def _generate_candidates(self, generation_args: Dict) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
|
| 312 |
+
"""Generate candidate sequences using the assistant model."""
|
| 313 |
+
assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
|
| 314 |
+
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
|
| 315 |
+
if (
|
| 316 |
+
is_sklearn_available()
|
| 317 |
+
and self.assistant_model.generation_config.assistant_confidence_threshold
|
| 318 |
+
and type(self) is AssistedCandidateGenerator
|
| 319 |
+
):
|
| 320 |
+
scores_tensor = torch.cat(assistant_output.scores, dim=0)
|
| 321 |
+
scores_softmax = torch.softmax(scores_tensor, dim=-1)
|
| 322 |
+
ids = assistant_output.sequences[-1, -len(assistant_output.scores) :]
|
| 323 |
+
p = scores_softmax[range(len(ids)), ids]
|
| 324 |
+
self.probs.extend(p.tolist())
|
| 325 |
+
candidate_logits = torch.stack(assistant_output.scores, dim=1)
|
| 326 |
+
candidate_ids = assistant_output.sequences
|
| 327 |
+
return candidate_ids, candidate_logits
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
|
| 331 |
+
"""
|
| 332 |
+
`CandidateGenerator` class to be used for Universal Assisted Generation (UAD): assisted generation with different tokenizers
|
| 333 |
+
for the assistant and main models. This class generates candidates through the use of a smaller
|
| 334 |
+
model.
|
| 335 |
+
|
| 336 |
+
The main model input tokens are re-encoded into assistant model tokens, then candidate tokens are generated in the assistant encoding, which are
|
| 337 |
+
in turn re-encoded into main model candidate tokens. Validation then proceeds as explained above.
|
| 338 |
+
The re-encoding steps involve decoding token ids into text and then encoding the text using a different tokenizer.
|
| 339 |
+
Since re-encoding the tokens may result in tokenization discrepancies, UAD finds the longest common subsequence between the source and target encodings,
|
| 340 |
+
to ensure the new tokens include the correct prompt suffix.
|
| 341 |
+
|
| 342 |
+
Args:
|
| 343 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 344 |
+
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
|
| 345 |
+
assistant_model (`PreTrainedModel`):
|
| 346 |
+
The model to be used for generating candidates. This model should be smaller than the main model.
|
| 347 |
+
target_tokenizer (`PreTrainedTokenizerBase`):
|
| 348 |
+
The tokenizer used for the target model.
|
| 349 |
+
assistant_tokenizer (`PreTrainedTokenizerBase`):
|
| 350 |
+
The tokenizer used for the assistant model.
|
| 351 |
+
generation_config (`~generation.GenerationConfig`, *optional*):
|
| 352 |
+
The generation configuration to be used as base parametrization for the generation call.
|
| 353 |
+
logits_processor (`LogitsProcessorList`):
|
| 354 |
+
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
|
| 355 |
+
used to modify the prediction scores of the language modeling head applied at each generation step.
|
| 356 |
+
model_kwargs (`Dict`):
|
| 357 |
+
The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant
|
| 358 |
+
model as well.
|
| 359 |
+
inputs_tensor (`torch.Tensor`, *optional*):
|
| 360 |
+
The model input tensor. In encoder-decoder models, this is the encoder input.
|
| 361 |
+
"""
|
| 362 |
+
|
| 363 |
+
def __init__(
|
| 364 |
+
self,
|
| 365 |
+
input_ids: torch.LongTensor,
|
| 366 |
+
assistant_model: "PreTrainedModel",
|
| 367 |
+
target_tokenizer: "PreTrainedTokenizerBase",
|
| 368 |
+
assistant_tokenizer: "PreTrainedTokenizerBase",
|
| 369 |
+
generation_config: "GenerationConfig",
|
| 370 |
+
model_kwargs: Dict,
|
| 371 |
+
inputs_tensor: Optional[torch.Tensor] = None,
|
| 372 |
+
logits_processor: "LogitsProcessorList" = None,
|
| 373 |
+
):
|
| 374 |
+
super().__init__(input_ids, assistant_model, generation_config, model_kwargs, inputs_tensor, logits_processor)
|
| 375 |
+
|
| 376 |
+
self.target_tokenizer = target_tokenizer
|
| 377 |
+
self.assistant_tokenizer = assistant_tokenizer
|
| 378 |
+
self.prev_target_ids_len: Optional[int] = None
|
| 379 |
+
self.prev_assistant_ids = None
|
| 380 |
+
self.target_lookbehind = assistant_model.generation_config.target_lookbehind
|
| 381 |
+
self.assistant_lookbehind = assistant_model.generation_config.assistant_lookbehind
|
| 382 |
+
|
| 383 |
+
@staticmethod
|
| 384 |
+
def _get_longest_diag_dict(input_matrix, nonzero_idx):
|
| 385 |
+
"""
|
| 386 |
+
Calculates the length of the longest diagonal sequence in a given matrix.
|
| 387 |
+
Args:
|
| 388 |
+
input_matrix (torch.Tensor): The input matrix.
|
| 389 |
+
nonzero_idx (torch.Tensor): The indices of the non-zero elements in the matrix.
|
| 390 |
+
Returns:
|
| 391 |
+
dict: A dictionary where the keys are the indices of the non-zero elements and the values are the lengths of the longest diagonal sequences starting from those indices.
|
| 392 |
+
"""
|
| 393 |
+
|
| 394 |
+
visited = set()
|
| 395 |
+
diags = {}
|
| 396 |
+
for idx in nonzero_idx:
|
| 397 |
+
start_idx = torch.clone(idx)
|
| 398 |
+
tuple_start_idx = tuple(start_idx.tolist())
|
| 399 |
+
|
| 400 |
+
if tuple_start_idx in visited:
|
| 401 |
+
continue
|
| 402 |
+
|
| 403 |
+
visited.add(tuple_start_idx)
|
| 404 |
+
cur_diag_len = 1
|
| 405 |
+
start_idx += 1
|
| 406 |
+
while start_idx[0] < input_matrix.shape[0] and start_idx[1] < input_matrix.shape[1]:
|
| 407 |
+
tuple_start_idx = tuple(start_idx.tolist())
|
| 408 |
+
visited.add(tuple_start_idx)
|
| 409 |
+
|
| 410 |
+
if input_matrix[start_idx[0], start_idx[1]] == 1:
|
| 411 |
+
cur_diag_len += 1
|
| 412 |
+
start_idx += 1
|
| 413 |
+
else:
|
| 414 |
+
break
|
| 415 |
+
|
| 416 |
+
diags[idx] = cur_diag_len
|
| 417 |
+
return diags
|
| 418 |
+
|
| 419 |
+
@staticmethod
|
| 420 |
+
def _get_longest_diag_index(input_matrix):
|
| 421 |
+
"""
|
| 422 |
+
Returns the start index and length of the longest diagonal in the given input.
|
| 423 |
+
Args:
|
| 424 |
+
input_matrix (numpy.ndarray): The input matrix.
|
| 425 |
+
Returns:
|
| 426 |
+
tuple: A tuple containing the start index and length of the longest diagonal.
|
| 427 |
+
"""
|
| 428 |
+
|
| 429 |
+
diags = AssistedCandidateGeneratorDifferentTokenizers._get_longest_diag_dict(
|
| 430 |
+
input_matrix, input_matrix.nonzero()
|
| 431 |
+
)
|
| 432 |
+
diags_values = list(diags.values())
|
| 433 |
+
diags_keys = list(diags.keys())
|
| 434 |
+
best_diag = np.argmax(diags_values)
|
| 435 |
+
diag_start_index = diags_keys[best_diag]
|
| 436 |
+
diag_start_length = diags_values[best_diag]
|
| 437 |
+
return diag_start_index, diag_start_length
|
| 438 |
+
|
| 439 |
+
@staticmethod
|
| 440 |
+
def _get_tokens_diag(prompt, prompt_plus_new_tokens):
|
| 441 |
+
"""
|
| 442 |
+
Input:
|
| 443 |
+
prompt: 2D array of shape (batch_size, prompt_length), represents the original prompt tokens
|
| 444 |
+
prompt_plus_new_tokens: 2D array of shape (batch_size, prompt_length), represents the suffix of the original prompt, with additional new tokens.
|
| 445 |
+
Output:
|
| 446 |
+
discrepancy_length: int, represents the number of tokens that need to be replaced from prompt
|
| 447 |
+
new_tokens_only: 2D array of shape (batch_size, new_token_length), represents the new tokens that are not in prompt
|
| 448 |
+
discrepancy_only: 2D array of shape (batch_size, discrepancy_length), represents the new tokens that are in prompt but not in prompt_plus_new_tokens
|
| 449 |
+
"""
|
| 450 |
+
compare_mat = prompt_plus_new_tokens.T == prompt
|
| 451 |
+
if not torch.is_tensor(compare_mat):
|
| 452 |
+
compare_mat = torch.tensor(compare_mat)
|
| 453 |
+
|
| 454 |
+
compare_mat_int = compare_mat.to(int)
|
| 455 |
+
|
| 456 |
+
if not compare_mat_int.any().item():
|
| 457 |
+
# empty intersection between prompt and prompt_plus_new_tokens
|
| 458 |
+
return None, None, None
|
| 459 |
+
|
| 460 |
+
longest_location, longest_diag_length = AssistedCandidateGeneratorDifferentTokenizers._get_longest_diag_index(
|
| 461 |
+
compare_mat_int
|
| 462 |
+
)
|
| 463 |
+
new_token_start_index = longest_location[0] + longest_diag_length
|
| 464 |
+
discrepancy_with_old = longest_location[1] + longest_diag_length
|
| 465 |
+
discrepancy_length = (prompt.shape[1] - discrepancy_with_old).item()
|
| 466 |
+
new_tokens_only = prompt_plus_new_tokens[:, new_token_start_index + discrepancy_length :]
|
| 467 |
+
discrepancy_only = prompt_plus_new_tokens[
|
| 468 |
+
:, new_token_start_index : new_token_start_index + discrepancy_length
|
| 469 |
+
]
|
| 470 |
+
return discrepancy_length, new_tokens_only, discrepancy_only
|
| 471 |
+
|
| 472 |
+
def convert_source_tokens_to_target_tokens(
|
| 473 |
+
self,
|
| 474 |
+
input_ids,
|
| 475 |
+
source_tokenizer,
|
| 476 |
+
destination_tokenizer,
|
| 477 |
+
):
|
| 478 |
+
"""
|
| 479 |
+
Convert token IDs from one tokenizer to another.
|
| 480 |
+
Args:
|
| 481 |
+
input_ids: The input token IDs.
|
| 482 |
+
source_tokenizer: The source tokenizer.
|
| 483 |
+
destination_tokenizer: The destination tokenizer.
|
| 484 |
+
Returns:
|
| 485 |
+
The converted token IDs.
|
| 486 |
+
"""
|
| 487 |
+
text = source_tokenizer.batch_decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
| 488 |
+
dest_ids = destination_tokenizer(text, add_special_tokens=True, return_tensors="pt")["input_ids"]
|
| 489 |
+
return dest_ids.to(input_ids.device)
|
| 490 |
+
|
| 491 |
+
def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
|
| 492 |
+
"""
|
| 493 |
+
Fetches the candidates to be tried for the current input.
|
| 494 |
+
|
| 495 |
+
Args:
|
| 496 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 497 |
+
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
|
| 498 |
+
|
| 499 |
+
Return:
|
| 500 |
+
`torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
|
| 501 |
+
assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
|
| 502 |
+
vocabulary_size)` containing the logits associated to each candidate.
|
| 503 |
+
"""
|
| 504 |
+
max_new_tokens = int(self.num_assistant_tokens)
|
| 505 |
+
if max_new_tokens == 0:
|
| 506 |
+
return input_ids, None
|
| 507 |
+
|
| 508 |
+
input_ids = input_ids.to(self.assistant_model.device)
|
| 509 |
+
remove_from_pkv = 0
|
| 510 |
+
|
| 511 |
+
assistant_input_ids, remove_from_pkv = self._prepare_assistant_input_ids(input_ids)
|
| 512 |
+
self.prev_assistant_ids = assistant_input_ids
|
| 513 |
+
|
| 514 |
+
min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - assistant_input_ids.shape[-1]), 0)
|
| 515 |
+
|
| 516 |
+
self._update_past_and_masks(assistant_input_ids, remove_from_pkv)
|
| 517 |
+
generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens)
|
| 518 |
+
self.assistant_kwargs.pop("attention_mask", None)
|
| 519 |
+
|
| 520 |
+
assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
|
| 521 |
+
new_target_ids = self._process_assistant_outputs(input_ids, assistant_output.sequences, assistant_input_ids)
|
| 522 |
+
|
| 523 |
+
# Update state
|
| 524 |
+
self.prev_target_ids_len = input_ids.shape[1]
|
| 525 |
+
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
|
| 526 |
+
self.prev_assistant_ids = assistant_output.sequences
|
| 527 |
+
|
| 528 |
+
if self.prev_target_ids_len >= new_target_ids.shape[1]:
|
| 529 |
+
return input_ids, None
|
| 530 |
+
|
| 531 |
+
return new_target_ids, None
|
| 532 |
+
|
| 533 |
+
def _prepare_assistant_input_ids(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, int]:
|
| 534 |
+
"""Converts target input IDs to assistant input IDs, handling discrepancies."""
|
| 535 |
+
convert_kwargs = {
|
| 536 |
+
"source_tokenizer": self.target_tokenizer,
|
| 537 |
+
"destination_tokenizer": self.assistant_tokenizer,
|
| 538 |
+
}
|
| 539 |
+
remove_from_pkv = 0
|
| 540 |
+
|
| 541 |
+
if self.prev_assistant_ids is not None and self.prev_target_ids_len > self.target_lookbehind:
|
| 542 |
+
# input_ids contains all target prompt input ids and some new target input ids
|
| 543 |
+
start_index_in_target_window = self.prev_target_ids_len - self.target_lookbehind
|
| 544 |
+
|
| 545 |
+
new_assistant_ids = self.convert_source_tokens_to_target_tokens(
|
| 546 |
+
input_ids[:, start_index_in_target_window:], **convert_kwargs
|
| 547 |
+
)
|
| 548 |
+
prompt_use_length = new_assistant_ids.shape[1]
|
| 549 |
+
prompt_use = self.prev_assistant_ids[:, -prompt_use_length:]
|
| 550 |
+
|
| 551 |
+
discrepancy_length, new_tokens_only, discrepancy_only = self._get_tokens_diag(
|
| 552 |
+
prompt_use, new_assistant_ids
|
| 553 |
+
)
|
| 554 |
+
assistant_input_ids = self.prev_assistant_ids
|
| 555 |
+
|
| 556 |
+
if new_tokens_only is not None:
|
| 557 |
+
if discrepancy_length > 0 and discrepancy_only.shape[1] > 0:
|
| 558 |
+
if discrepancy_length == discrepancy_only.shape[1]:
|
| 559 |
+
assistant_input_ids[:, -discrepancy_length:] = discrepancy_only
|
| 560 |
+
|
| 561 |
+
elif discrepancy_length > discrepancy_only.shape[1]:
|
| 562 |
+
discrepancy_length_diff = discrepancy_length - discrepancy_only.shape[1]
|
| 563 |
+
assistant_input_ids = assistant_input_ids[:, :-discrepancy_length_diff]
|
| 564 |
+
assistant_input_ids[:, -discrepancy_only.shape[1] :] = discrepancy_only
|
| 565 |
+
|
| 566 |
+
remove_from_pkv = discrepancy_length
|
| 567 |
+
|
| 568 |
+
if new_tokens_only.shape[1] > 0:
|
| 569 |
+
assistant_input_ids = torch.cat([assistant_input_ids, new_tokens_only], dim=-1)
|
| 570 |
+
else:
|
| 571 |
+
# edge case: in case of no intersection between prompt and new_assistant_ids
|
| 572 |
+
assistant_input_ids = torch.cat([assistant_input_ids, new_assistant_ids], dim=-1)
|
| 573 |
+
else:
|
| 574 |
+
assistant_input_ids = self.convert_source_tokens_to_target_tokens(input_ids, **convert_kwargs)
|
| 575 |
+
self.prev_target_ids_len = input_ids.shape[1]
|
| 576 |
+
|
| 577 |
+
return assistant_input_ids, remove_from_pkv
|
| 578 |
+
|
| 579 |
+
def _process_assistant_outputs(
|
| 580 |
+
self, input_ids: torch.LongTensor, assistant_sequences: torch.LongTensor, assistant_input_ids: torch.LongTensor
|
| 581 |
+
) -> torch.LongTensor:
|
| 582 |
+
"""Processes assistant outputs to obtain target input IDs."""
|
| 583 |
+
num_prev_assistant = self.prev_assistant_ids.shape[1]
|
| 584 |
+
start_assistant_look_index = num_prev_assistant - self.assistant_lookbehind
|
| 585 |
+
|
| 586 |
+
new_target_ids_from_window = self.convert_source_tokens_to_target_tokens(
|
| 587 |
+
assistant_sequences[:, start_assistant_look_index:],
|
| 588 |
+
source_tokenizer=self.assistant_tokenizer,
|
| 589 |
+
destination_tokenizer=self.target_tokenizer,
|
| 590 |
+
)
|
| 591 |
+
target_prompt_use_length = new_target_ids_from_window.shape[1]
|
| 592 |
+
|
| 593 |
+
target_prompt_use = input_ids[:, -target_prompt_use_length:]
|
| 594 |
+
|
| 595 |
+
_, target_new_tokens_only, _ = self._get_tokens_diag(target_prompt_use, new_target_ids_from_window)
|
| 596 |
+
|
| 597 |
+
new_target_ids = input_ids
|
| 598 |
+
|
| 599 |
+
if target_new_tokens_only is not None:
|
| 600 |
+
if target_new_tokens_only.shape[1] > 0:
|
| 601 |
+
new_target_ids = torch.cat([new_target_ids, target_new_tokens_only], dim=-1)
|
| 602 |
+
else:
|
| 603 |
+
# edge case: in case of no intersection between prompt and new_target_ids
|
| 604 |
+
new_target_ids = torch.cat([new_target_ids, new_target_ids_from_window], dim=-1)
|
| 605 |
+
|
| 606 |
+
if hasattr(self.generation_config, "max_length"):
|
| 607 |
+
new_target_ids = new_target_ids[:, : self.generation_config.max_length]
|
| 608 |
+
|
| 609 |
+
return new_target_ids
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
class PromptLookupCandidateGenerator(CandidateGenerator):
|
| 613 |
+
"""
|
| 614 |
+
`CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking up
|
| 615 |
+
likely continuations in the provided prompt (input_ids) itself.
|
| 616 |
+
Read the following blog post for more information: https://github.com/apoorvumang/prompt-lookup-decoding
|
| 617 |
+
|
| 618 |
+
Args:
|
| 619 |
+
max_matching_ngram_size (`int`):
|
| 620 |
+
The maximum ngram size to be considered for matching in the prompt
|
| 621 |
+
num_output_tokens (`int`):
|
| 622 |
+
The number of tokens to be output as candidate tokens.
|
| 623 |
+
max_length (`int`):
|
| 624 |
+
The number of total maximum tokens that can be generated. For decoder-only models that includes the prompt length.
|
| 625 |
+
Defaults to 20, which is the max length used as default in generation config.
|
| 626 |
+
"""
|
| 627 |
+
|
| 628 |
+
def __init__(
|
| 629 |
+
self,
|
| 630 |
+
eos_token_id: torch.Tensor = None,
|
| 631 |
+
num_output_tokens: int = 10,
|
| 632 |
+
max_matching_ngram_size: int = None,
|
| 633 |
+
max_length: int = 20,
|
| 634 |
+
):
|
| 635 |
+
self.num_output_tokens = num_output_tokens
|
| 636 |
+
self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
|
| 637 |
+
self.max_length = max_length
|
| 638 |
+
self.eos_token_id = eos_token_id
|
| 639 |
+
|
| 640 |
+
if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
|
| 641 |
+
raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")
|
| 642 |
+
|
| 643 |
+
def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
|
| 644 |
+
"""
|
| 645 |
+
Fetches the candidates to be tried for the current input.
|
| 646 |
+
|
| 647 |
+
Args:
|
| 648 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 649 |
+
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
|
| 650 |
+
|
| 651 |
+
Return:
|
| 652 |
+
`torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
|
| 653 |
+
"""
|
| 654 |
+
input_length = input_ids.size(1)
|
| 655 |
+
|
| 656 |
+
# Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
|
| 657 |
+
if self.max_length == input_length + 1:
|
| 658 |
+
return input_ids, None
|
| 659 |
+
|
| 660 |
+
chosen_ids = None
|
| 661 |
+
match_found = False
|
| 662 |
+
for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
|
| 663 |
+
# Create sliding windows of size ngram_size
|
| 664 |
+
windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)
|
| 665 |
+
|
| 666 |
+
# Convert ngram to a tensor for comparison
|
| 667 |
+
ngram_tensor = input_ids[0, -ngram_size:]
|
| 668 |
+
|
| 669 |
+
# Find where the windows match the ngram
|
| 670 |
+
matches = (windows == ngram_tensor).all(dim=2)
|
| 671 |
+
|
| 672 |
+
# Get the indices of matches
|
| 673 |
+
match_indices = matches.nonzero(as_tuple=True)[1]
|
| 674 |
+
|
| 675 |
+
# Iterate through match indices to find a valid continuation
|
| 676 |
+
for idx in match_indices:
|
| 677 |
+
start_idx = idx + ngram_size
|
| 678 |
+
end_idx = start_idx + self.num_output_tokens
|
| 679 |
+
end_idx = min(end_idx, input_length, self.max_length)
|
| 680 |
+
|
| 681 |
+
if start_idx < end_idx:
|
| 682 |
+
chosen_ids = input_ids[0, start_idx:end_idx]
|
| 683 |
+
match_found = True
|
| 684 |
+
|
| 685 |
+
# remove remaining candidate ids if an "eos" token is found, otherwise the target model may
|
| 686 |
+
# accept eos and the rest as valid, thus not stopping generation after "eos"
|
| 687 |
+
# NOTE: below code is written based on the fact that assisted decoding supports only bs=1
|
| 688 |
+
mask = isin_mps_friendly(chosen_ids, self.eos_token_id)
|
| 689 |
+
match_indices_eos = torch.nonzero(mask)
|
| 690 |
+
if match_indices_eos.numel() > 0:
|
| 691 |
+
first_eos_index = match_indices_eos[0].item()
|
| 692 |
+
chosen_ids = chosen_ids[:first_eos_index]
|
| 693 |
+
break
|
| 694 |
+
if match_found:
|
| 695 |
+
break
|
| 696 |
+
|
| 697 |
+
if chosen_ids is None or len(chosen_ids) == 0:
|
| 698 |
+
# In case we didn't find a match return the input sequence unchanged, reverts back to autoregressive decoding
|
| 699 |
+
return input_ids, None
|
| 700 |
+
|
| 701 |
+
# Now need extend input_ids with chosen_ids
|
| 702 |
+
chosen_ids = chosen_ids.unsqueeze(0)
|
| 703 |
+
candidate_input_ids = torch.cat((input_ids, chosen_ids), dim=1)
|
| 704 |
+
# assisted_generation expects logits as well, but we don't have those here, so returning None
|
| 705 |
+
return candidate_input_ids, None
|
| 706 |
+
|
| 707 |
+
def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
|
| 708 |
+
"""
|
| 709 |
+
Updates the candidate generation strategy based on the outcomes.
|
| 710 |
+
|
| 711 |
+
Args:
|
| 712 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 713 |
+
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
|
| 714 |
+
scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
|
| 715 |
+
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
|
| 716 |
+
beam search or log softmax for each vocabulary token when using beam search
|
| 717 |
+
num_matches (`int`):
|
| 718 |
+
The number of matches between the candidate sequences and the model predictions.
|
| 719 |
+
"""
|
| 720 |
+
# Currently does nothing
|
| 721 |
+
return
|
| 722 |
+
|
| 723 |
+
|
| 724 |
+
class EarlyExitCandidateGenerator(AssistedCandidateGenerator):
|
| 725 |
+
"""
|
| 726 |
+
`CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates
|
| 727 |
+
candidates through the use of **the model itself**, exiting early. Can only be used with models that support early
|
| 728 |
+
exit, e.g., `facebook/layerskip-llama3.2-1B`.
|
| 729 |
+
|
| 730 |
+
Args:
|
| 731 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 732 |
+
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
|
| 733 |
+
assistant_model (`PreTrainedModel`):
|
| 734 |
+
The original model. This model must support early exit (i.e. is trained to compute logits in earlier
|
| 735 |
+
layers).
|
| 736 |
+
generation_config (`~generation.GenerationConfig`, *optional*):
|
| 737 |
+
The generation configuration to be used as base parametrization for the generation call.
|
| 738 |
+
logits_processor (`LogitsProcessorList`):
|
| 739 |
+
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
|
| 740 |
+
used to modify the prediction scores of the language modeling head applied at each generation step.
|
| 741 |
+
model_kwargs (`Dict`):
|
| 742 |
+
The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant
|
| 743 |
+
model as well.
|
| 744 |
+
inputs_tensor (`torch.Tensor`, *optional*):
|
| 745 |
+
The model input tensor. In encoder-decoder models, this is the encoder input.
|
| 746 |
+
"""
|
| 747 |
+
|
| 748 |
+
def __init__(
|
| 749 |
+
self,
|
| 750 |
+
input_ids: torch.LongTensor,
|
| 751 |
+
assistant_model: "PreTrainedModel",
|
| 752 |
+
generation_config: "GenerationConfig",
|
| 753 |
+
model_kwargs: Dict,
|
| 754 |
+
inputs_tensor: Optional[torch.Tensor] = None,
|
| 755 |
+
logits_processor: "LogitsProcessorList" = None,
|
| 756 |
+
):
|
| 757 |
+
super().__init__(
|
| 758 |
+
input_ids=input_ids,
|
| 759 |
+
assistant_model=assistant_model,
|
| 760 |
+
generation_config=generation_config,
|
| 761 |
+
model_kwargs=model_kwargs,
|
| 762 |
+
inputs_tensor=inputs_tensor,
|
| 763 |
+
logits_processor=logits_processor,
|
| 764 |
+
)
|
| 765 |
+
# We have to move early exit out of the generation config, otherwise the assistant will also call `generate`
|
| 766 |
+
# with early exit
|
| 767 |
+
self.assistant_early_exit = self.generation_config.assistant_early_exit
|
| 768 |
+
self.generation_config.assistant_early_exit = None
|
| 769 |
+
|
| 770 |
+
def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
|
| 771 |
+
# Temporarily sets the number of hidden layers to the early exit value
|
| 772 |
+
base_model = getattr(self.assistant_model, self.assistant_model.base_model_prefix)
|
| 773 |
+
original_num_hidden_layers = base_model.config.num_hidden_layers
|
| 774 |
+
base_model.config.num_hidden_layers = self.assistant_early_exit
|
| 775 |
+
candidate_ids, candidate_logits = super().get_candidates(input_ids)
|
| 776 |
+
base_model.config.num_hidden_layers = original_num_hidden_layers
|
| 777 |
+
return candidate_ids, candidate_logits
|
| 778 |
+
|
| 779 |
+
|
| 780 |
+
def _crop_past_key_values(model, past_key_values, max_length):
|
| 781 |
+
"""Crops the past key values up to a certain maximum length."""
|
| 782 |
+
new_past = []
|
| 783 |
+
if model.config.is_encoder_decoder:
|
| 784 |
+
for idx in range(len(past_key_values)):
|
| 785 |
+
new_past.append(
|
| 786 |
+
(
|
| 787 |
+
past_key_values[idx][0][:, :, :max_length, :],
|
| 788 |
+
past_key_values[idx][1][:, :, :max_length, :],
|
| 789 |
+
past_key_values[idx][2],
|
| 790 |
+
past_key_values[idx][3],
|
| 791 |
+
)
|
| 792 |
+
)
|
| 793 |
+
past_key_values = tuple(new_past)
|
| 794 |
+
# gptbigcode is special and stores kv in shape (batch_size, seq_len, dim), if it's a multi_query model
|
| 795 |
+
elif "gptbigcode" in model.__class__.__name__.lower() or (
|
| 796 |
+
model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower()
|
| 797 |
+
):
|
| 798 |
+
if model.config.multi_query:
|
| 799 |
+
for idx in range(len(past_key_values)):
|
| 800 |
+
past_key_values[idx] = past_key_values[idx][:, :max_length, :]
|
| 801 |
+
else:
|
| 802 |
+
for idx in range(len(past_key_values)):
|
| 803 |
+
past_key_values[idx] = past_key_values[idx][:, :, :max_length, :]
|
| 804 |
+
elif isinstance(past_key_values, DynamicCache):
|
| 805 |
+
past_key_values.crop(max_length)
|
| 806 |
+
elif past_key_values is not None:
|
| 807 |
+
for idx in range(len(past_key_values)):
|
| 808 |
+
if past_key_values[idx] != ([], []):
|
| 809 |
+
new_past.append(
|
| 810 |
+
(
|
| 811 |
+
past_key_values[idx][0][:, :, :max_length, :],
|
| 812 |
+
past_key_values[idx][1][:, :, :max_length, :],
|
| 813 |
+
)
|
| 814 |
+
)
|
| 815 |
+
else:
|
| 816 |
+
new_past.append((past_key_values[idx][0], past_key_values[idx][1]))
|
| 817 |
+
past_key_values = tuple(new_past)
|
| 818 |
+
return past_key_values
|
| 819 |
+
|
| 820 |
+
|
| 821 |
+
def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_encoder_decoder: bool) -> Dict[str, Any]:
|
| 822 |
+
"""Expands or crops the model's mask for decoding purposes, to the defined length"""
|
| 823 |
+
|
| 824 |
+
mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
|
| 825 |
+
if mask_key not in model_kwargs:
|
| 826 |
+
return model_kwargs
|
| 827 |
+
|
| 828 |
+
mask = model_kwargs[mask_key]
|
| 829 |
+
mask_length_diff = new_length - mask.shape[1]
|
| 830 |
+
|
| 831 |
+
if mask_length_diff < 0:
|
| 832 |
+
model_kwargs[mask_key] = mask[:, :mask_length_diff]
|
| 833 |
+
elif mask_length_diff > 0:
|
| 834 |
+
model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1)
|
| 835 |
+
|
| 836 |
+
# Handle cross attention models
|
| 837 |
+
if "cross_attention_mask" in model_kwargs:
|
| 838 |
+
# Mllama case
|
| 839 |
+
cross_mask = model_kwargs["cross_attention_mask"]
|
| 840 |
+
if mask_length_diff < 0:
|
| 841 |
+
model_kwargs["cross_attention_mask"] = cross_mask[:, :mask_length_diff]
|
| 842 |
+
elif mask_length_diff > 0:
|
| 843 |
+
new_mask = cross_mask[:, -1:, :, :].repeat(1, mask_length_diff, 1, 1)
|
| 844 |
+
model_kwargs["cross_attention_mask"] = torch.cat([cross_mask, new_mask], dim=1)
|
| 845 |
+
elif "image_attention_mask" in model_kwargs:
|
| 846 |
+
# IDEFICS case
|
| 847 |
+
cross_mask = model_kwargs["image_attention_mask"]
|
| 848 |
+
if mask_length_diff < 0:
|
| 849 |
+
model_kwargs["image_attention_mask"] = cross_mask[:, :mask_length_diff]
|
| 850 |
+
elif mask_length_diff > 0:
|
| 851 |
+
new_mask = cross_mask[:, -1:, :].repeat(1, mask_length_diff, 1)
|
| 852 |
+
model_kwargs["image_attention_mask"] = torch.cat([cross_mask, new_mask], dim=1)
|
| 853 |
+
|
| 854 |
+
return model_kwargs
|
| 855 |
+
|
| 856 |
+
|
| 857 |
+
def _prepare_token_type_ids(model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]:
|
| 858 |
+
"""Expands or crops the model's token_type_ids for decoding purposes, to the defined length"""
|
| 859 |
+
if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None:
|
| 860 |
+
return model_kwargs
|
| 861 |
+
|
| 862 |
+
token_type_ids = model_kwargs["token_type_ids"]
|
| 863 |
+
final_token_type = token_type_ids[:, -1].unsqueeze(-1)
|
| 864 |
+
type_length_diff = new_length - token_type_ids.shape[1]
|
| 865 |
+
|
| 866 |
+
if type_length_diff < 0:
|
| 867 |
+
token_type_ids = token_type_ids[:, :type_length_diff]
|
| 868 |
+
elif type_length_diff > 0:
|
| 869 |
+
token_type_copies = final_token_type.repeat(1, type_length_diff)
|
| 870 |
+
model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1)
|
| 871 |
+
return model_kwargs
|
.venv/lib/python3.11/site-packages/transformers/generation/configuration_utils.py
ADDED
|
@@ -0,0 +1,1628 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Generation configuration class and utilities."""
|
| 16 |
+
|
| 17 |
+
import copy
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
import warnings
|
| 21 |
+
from abc import ABC, abstractmethod
|
| 22 |
+
from dataclasses import dataclass, is_dataclass
|
| 23 |
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
|
| 24 |
+
|
| 25 |
+
from .. import __version__
|
| 26 |
+
from ..configuration_utils import PretrainedConfig
|
| 27 |
+
from ..utils import (
|
| 28 |
+
GENERATION_CONFIG_NAME,
|
| 29 |
+
ExplicitEnum,
|
| 30 |
+
PushToHubMixin,
|
| 31 |
+
cached_file,
|
| 32 |
+
download_url,
|
| 33 |
+
extract_commit_hash,
|
| 34 |
+
is_remote_url,
|
| 35 |
+
is_torch_available,
|
| 36 |
+
logging,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
if TYPE_CHECKING:
|
| 41 |
+
from ..modeling_utils import PreTrainedModel
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
logger = logging.get_logger(__name__)
|
| 45 |
+
METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version")
|
| 46 |
+
NEEDS_CACHE_CONFIG = {}
|
| 47 |
+
NEED_SETUP_CACHE_CLASSES_MAPPING = {}
|
| 48 |
+
QUANT_BACKEND_CLASSES_MAPPING = {}
|
| 49 |
+
ALL_CACHE_IMPLEMENTATIONS = []
|
| 50 |
+
|
| 51 |
+
if is_torch_available():
|
| 52 |
+
from ..cache_utils import (
|
| 53 |
+
HQQQuantizedCache,
|
| 54 |
+
HybridCache,
|
| 55 |
+
MambaCache,
|
| 56 |
+
OffloadedStaticCache,
|
| 57 |
+
QuantizedCacheConfig,
|
| 58 |
+
QuantoQuantizedCache,
|
| 59 |
+
SlidingWindowCache,
|
| 60 |
+
StaticCache,
|
| 61 |
+
StaticCacheConfig,
|
| 62 |
+
)
|
| 63 |
+
from .logits_process import SynthIDTextWatermarkLogitsProcessor, WatermarkLogitsProcessor
|
| 64 |
+
|
| 65 |
+
NEEDS_CACHE_CONFIG["quantized"] = QuantizedCacheConfig
|
| 66 |
+
NEEDS_CACHE_CONFIG["static"] = StaticCacheConfig
|
| 67 |
+
NEED_SETUP_CACHE_CLASSES_MAPPING = {
|
| 68 |
+
"static": StaticCache,
|
| 69 |
+
"offloaded_static": OffloadedStaticCache,
|
| 70 |
+
"sliding_window": SlidingWindowCache,
|
| 71 |
+
"hybrid": HybridCache,
|
| 72 |
+
"mamba": MambaCache,
|
| 73 |
+
}
|
| 74 |
+
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
|
| 75 |
+
ALL_CACHE_IMPLEMENTATIONS = (
|
| 76 |
+
list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(NEEDS_CACHE_CONFIG.keys()) + ["offloaded"]
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class GenerationMode(ExplicitEnum):
|
| 81 |
+
"""
|
| 82 |
+
Possible generation modes, downstream of the [`~generation.GenerationMixin.generate`] method.
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
# Non-beam methods
|
| 86 |
+
CONTRASTIVE_SEARCH = "contrastive_search"
|
| 87 |
+
GREEDY_SEARCH = "greedy_search"
|
| 88 |
+
SAMPLE = "sample"
|
| 89 |
+
ASSISTED_GENERATION = "assisted_generation"
|
| 90 |
+
DOLA_GENERATION = "dola_generation"
|
| 91 |
+
# Beam methods
|
| 92 |
+
BEAM_SEARCH = "beam_search"
|
| 93 |
+
BEAM_SAMPLE = "beam_sample"
|
| 94 |
+
CONSTRAINED_BEAM_SEARCH = "constrained_beam_search"
|
| 95 |
+
GROUP_BEAM_SEARCH = "group_beam_search"
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class GenerationConfig(PushToHubMixin):
|
| 99 |
+
# no-format
|
| 100 |
+
"""
|
| 101 |
+
Class that holds a configuration for a generation task. A `generate` call supports the following generation methods
|
| 102 |
+
for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
|
| 103 |
+
|
| 104 |
+
- *greedy decoding* if `num_beams=1` and `do_sample=False`
|
| 105 |
+
- *contrastive search* if `penalty_alpha>0.` and `top_k>1`
|
| 106 |
+
- *multinomial sampling* if `num_beams=1` and `do_sample=True`
|
| 107 |
+
- *beam-search decoding* if `num_beams>1` and `do_sample=False`
|
| 108 |
+
- *beam-search multinomial sampling* if `num_beams>1` and `do_sample=True`
|
| 109 |
+
- *diverse beam-search decoding* if `num_beams>1` and `num_beam_groups>1`
|
| 110 |
+
- *constrained beam-search decoding* if `constraints!=None` or `force_words_ids!=None`
|
| 111 |
+
- *assisted decoding* if `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
|
| 112 |
+
- *dola decoding* if `dola_layers` is passed to `.generate()`
|
| 113 |
+
|
| 114 |
+
To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
|
| 115 |
+
|
| 116 |
+
<Tip>
|
| 117 |
+
|
| 118 |
+
A large number of these flags control the logits or the stopping criteria of the generation. Make sure you check
|
| 119 |
+
the [generate-related classes](https://huggingface.co/docs/transformers/internal/generation_utils) for a full
|
| 120 |
+
description of the possible manipulations, as well as examples of their usage.
|
| 121 |
+
|
| 122 |
+
</Tip>
|
| 123 |
+
|
| 124 |
+
Arg:
|
| 125 |
+
> Parameters that control the length of the output
|
| 126 |
+
|
| 127 |
+
max_length (`int`, *optional*, defaults to 20):
|
| 128 |
+
The maximum length the generated tokens can have. Corresponds to the length of the input prompt +
|
| 129 |
+
`max_new_tokens`. Its effect is overridden by `max_new_tokens`, if also set.
|
| 130 |
+
max_new_tokens (`int`, *optional*):
|
| 131 |
+
The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
|
| 132 |
+
min_length (`int`, *optional*, defaults to 0):
|
| 133 |
+
The minimum length of the sequence to be generated. Corresponds to the length of the input prompt +
|
| 134 |
+
`min_new_tokens`. Its effect is overridden by `min_new_tokens`, if also set.
|
| 135 |
+
min_new_tokens (`int`, *optional*):
|
| 136 |
+
The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt.
|
| 137 |
+
early_stopping (`bool` or `str`, *optional*, defaults to `False`):
|
| 138 |
+
Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
|
| 139 |
+
`True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an
|
| 140 |
+
heuristic is applied and the generation stops when is it very unlikely to find better candidates;
|
| 141 |
+
`"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
|
| 142 |
+
beam search algorithm).
|
| 143 |
+
max_time (`float`, *optional*):
|
| 144 |
+
The maximum amount of time you allow the computation to run for in seconds. generation will still finish
|
| 145 |
+
the current pass after allocated time has been passed.
|
| 146 |
+
stop_strings (`str or List[str]`, *optional*):
|
| 147 |
+
A string or a list of strings that should terminate generation if the model outputs them.
|
| 148 |
+
|
| 149 |
+
> Parameters that control the generation strategy used
|
| 150 |
+
|
| 151 |
+
do_sample (`bool`, *optional*, defaults to `False`):
|
| 152 |
+
Whether or not to use sampling ; use greedy decoding otherwise.
|
| 153 |
+
num_beams (`int`, *optional*, defaults to 1):
|
| 154 |
+
Number of beams for beam search. 1 means no beam search.
|
| 155 |
+
num_beam_groups (`int`, *optional*, defaults to 1):
|
| 156 |
+
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
|
| 157 |
+
[this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
|
| 158 |
+
penalty_alpha (`float`, *optional*):
|
| 159 |
+
The values balance the model confidence and the degeneration penalty in contrastive search decoding.
|
| 160 |
+
dola_layers (`str` or `List[int]`, *optional*):
|
| 161 |
+
The layers to use for DoLa decoding. If `None`, DoLa decoding is not used. If a string, it must
|
| 162 |
+
be one of "low" or "high", which means using the lower part or higher part of the model layers, respectively.
|
| 163 |
+
"low" means the first half of the layers up to the first 20 layers, and "high" means the last half of the
|
| 164 |
+
layers up to the last 20 layers.
|
| 165 |
+
If a list of integers, it must contain the indices of the layers to use for candidate premature layers in DoLa.
|
| 166 |
+
The 0-th layer is the word embedding layer of the model. Set to `'low'` to improve long-answer reasoning tasks,
|
| 167 |
+
`'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md)
|
| 168 |
+
or [the paper](https://arxiv.org/abs/2309.03883) for more details.
|
| 169 |
+
|
| 170 |
+
> Parameters that control the cache
|
| 171 |
+
|
| 172 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 173 |
+
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
|
| 174 |
+
speed up decoding.
|
| 175 |
+
cache_implementation (`str`, *optional*, default to `None`):
|
| 176 |
+
Name of the cache class that will be instantiated in `generate`, for faster decoding. Possible values are:
|
| 177 |
+
|
| 178 |
+
- `"static"`: [`StaticCache`]
|
| 179 |
+
- `"offloaded_static"`: [`OffloadedStaticCache`]
|
| 180 |
+
- `"sliding_window"`: [`SlidingWindowCache`]
|
| 181 |
+
- `"hybrid"`: [`HybridCache`]
|
| 182 |
+
- `"mamba"`: [`MambaCache`]
|
| 183 |
+
- `"quantized"`: [`QuantizedCache`]
|
| 184 |
+
|
| 185 |
+
We support other cache types, but they must be manually instantiated and
|
| 186 |
+
passed to `generate` through the `past_key_values` argument. See our
|
| 187 |
+
[cache documentation](https://huggingface.co/docs/transformers/en/kv_cache) for further information.
|
| 188 |
+
cache_config (`CacheConfig` or `dict`, *optional*, default to `None`):
|
| 189 |
+
Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and
|
| 190 |
+
it will be converted to its repsective `CacheConfig` internally.
|
| 191 |
+
Otherwise can be passed as a `CacheConfig` class matching the indicated `cache_implementation`.
|
| 192 |
+
return_legacy_cache (`bool`, *optional*, default to `True`):
|
| 193 |
+
Whether to return the legacy or new format of the cache when `DynamicCache` is used by default.
|
| 194 |
+
|
| 195 |
+
> Parameters for manipulation of the model output logits
|
| 196 |
+
|
| 197 |
+
temperature (`float`, *optional*, defaults to 1.0):
|
| 198 |
+
The value used to module the next token probabilities. This value is set in a model's `generation_config.json` file. If it isn't set, the default value is 1.0
|
| 199 |
+
top_k (`int`, *optional*, defaults to 50):
|
| 200 |
+
The number of highest probability vocabulary tokens to keep for top-k-filtering. This value is set in a model's `generation_config.json` file. If it isn't set, the default value is 50.
|
| 201 |
+
top_p (`float`, *optional*, defaults to 1.0):
|
| 202 |
+
If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to
|
| 203 |
+
`top_p` or higher are kept for generation. This value is set in a model's `generation_config.json` file. If it isn't set, the default value is 1.0
|
| 204 |
+
min_p (`float`, *optional*):
|
| 205 |
+
Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
|
| 206 |
+
value between 0 and 1. Typical values are in the 0.01-0.2 range, comparably selective as setting `top_p` in
|
| 207 |
+
the 0.99-0.8 range (use the opposite of normal `top_p` values).
|
| 208 |
+
typical_p (`float`, *optional*, defaults to 1.0):
|
| 209 |
+
Local typicality measures how similar the conditional probability of predicting a target token next is to
|
| 210 |
+
the expected conditional probability of predicting a random token next, given the partial text already
|
| 211 |
+
generated. If set to float < 1, the smallest set of the most locally typical tokens with probabilities that
|
| 212 |
+
add up to `typical_p` or higher are kept for generation. See [this
|
| 213 |
+
paper](https://arxiv.org/pdf/2202.00666.pdf) for more details.
|
| 214 |
+
epsilon_cutoff (`float`, *optional*, defaults to 0.0):
|
| 215 |
+
If set to float strictly between 0 and 1, only tokens with a conditional probability greater than
|
| 216 |
+
`epsilon_cutoff` will be sampled. In the paper, suggested values range from 3e-4 to 9e-4, depending on the
|
| 217 |
+
size of the model. See [Truncation Sampling as Language Model
|
| 218 |
+
Desmoothing](https://arxiv.org/abs/2210.15191) for more details.
|
| 219 |
+
eta_cutoff (`float`, *optional*, defaults to 0.0):
|
| 220 |
+
Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to float strictly between
|
| 221 |
+
0 and 1, a token is only considered if it is greater than either `eta_cutoff` or `sqrt(eta_cutoff) *
|
| 222 |
+
exp(-entropy(softmax(next_token_logits)))`. The latter term is intuitively the expected next token
|
| 223 |
+
probability, scaled by `sqrt(eta_cutoff)`. In the paper, suggested values range from 3e-4 to 2e-3,
|
| 224 |
+
depending on the size of the model. See [Truncation Sampling as Language Model
|
| 225 |
+
Desmoothing](https://arxiv.org/abs/2210.15191) for more details.
|
| 226 |
+
diversity_penalty (`float`, *optional*, defaults to 0.0):
|
| 227 |
+
This value is subtracted from a beam's score if it generates a token same as any beam from other group at a
|
| 228 |
+
particular time. Note that `diversity_penalty` is only effective if `group beam search` is enabled.
|
| 229 |
+
repetition_penalty (`float`, *optional*, defaults to 1.0):
|
| 230 |
+
The parameter for repetition penalty. 1.0 means no penalty. See [this
|
| 231 |
+
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
| 232 |
+
encoder_repetition_penalty (`float`, *optional*, defaults to 1.0):
|
| 233 |
+
The paramater for encoder_repetition_penalty. An exponential penalty on sequences that are not in the
|
| 234 |
+
original input. 1.0 means no penalty.
|
| 235 |
+
length_penalty (`float`, *optional*, defaults to 1.0):
|
| 236 |
+
Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
|
| 237 |
+
the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
|
| 238 |
+
likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
|
| 239 |
+
`length_penalty` < 0.0 encourages shorter sequences.
|
| 240 |
+
no_repeat_ngram_size (`int`, *optional*, defaults to 0):
|
| 241 |
+
If set to int > 0, all ngrams of that size can only occur once.
|
| 242 |
+
bad_words_ids (`List[List[int]]`, *optional*):
|
| 243 |
+
List of list of token ids that are not allowed to be generated. Check
|
| 244 |
+
[`~generation.NoBadWordsLogitsProcessor`] for further documentation and examples.
|
| 245 |
+
force_words_ids (`List[List[int]]` or `List[List[List[int]]]`, *optional*):
|
| 246 |
+
List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple list of
|
| 247 |
+
words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, this
|
| 248 |
+
triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one
|
| 249 |
+
can allow different forms of each word.
|
| 250 |
+
renormalize_logits (`bool`, *optional*, defaults to `False`):
|
| 251 |
+
Whether to renormalize the logits after applying all the logits processors (including the custom
|
| 252 |
+
ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits
|
| 253 |
+
are normalized but some logit processors break the normalization.
|
| 254 |
+
constraints (`List[Constraint]`, *optional*):
|
| 255 |
+
Custom constraints that can be added to the generation to ensure that the output will contain the use of
|
| 256 |
+
certain tokens as defined by `Constraint` objects, in the most sensible way possible.
|
| 257 |
+
forced_bos_token_id (`int`, *optional*, defaults to `model.config.forced_bos_token_id`):
|
| 258 |
+
The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful for
|
| 259 |
+
multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be the target
|
| 260 |
+
language token.
|
| 261 |
+
forced_eos_token_id (`int` or List[int]`, *optional*, defaults to `model.config.forced_eos_token_id`):
|
| 262 |
+
The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a
|
| 263 |
+
list to set multiple *end-of-sequence* tokens.
|
| 264 |
+
remove_invalid_values (`bool`, *optional*, defaults to `model.config.remove_invalid_values`):
|
| 265 |
+
Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to crash.
|
| 266 |
+
Note that using `remove_invalid_values` can slow down generation.
|
| 267 |
+
exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
|
| 268 |
+
This Tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been
|
| 269 |
+
generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where
|
| 270 |
+
penalty starts and `decay_factor` represents the factor of exponential decay
|
| 271 |
+
suppress_tokens (`List[int]`, *optional*):
|
| 272 |
+
A list of tokens that will be suppressed at generation. The `SupressTokens` logit processor will set their
|
| 273 |
+
log probs to `-inf` so that they are not sampled.
|
| 274 |
+
begin_suppress_tokens (`List[int]`, *optional*):
|
| 275 |
+
A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens` logit
|
| 276 |
+
processor will set their log probs to `-inf` so that they are not sampled.
|
| 277 |
+
forced_decoder_ids (`List[List[int]]`, *optional*):
|
| 278 |
+
A list of pairs of integers which indicates a mapping from generation indices to token indices that will be
|
| 279 |
+
forced before sampling. For example, `[[1, 123]]` means the second generated token will always be a token
|
| 280 |
+
of index 123.
|
| 281 |
+
sequence_bias (`Dict[Tuple[int], float]`, *optional*)):
|
| 282 |
+
Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
|
| 283 |
+
sequence being selected, while negative biases do the opposite. Check
|
| 284 |
+
[`~generation.SequenceBiasLogitsProcessor`] for further documentation and examples.
|
| 285 |
+
token_healing (`bool`, *optional*, defaults to `False`):
|
| 286 |
+
Heal tail tokens of prompts by replacing them with their appropriate extensions.
|
| 287 |
+
This enhances the quality of completions for prompts affected by greedy tokenization bias.
|
| 288 |
+
guidance_scale (`float`, *optional*):
|
| 289 |
+
The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
|
| 290 |
+
Higher guidance scale encourages the model to generate samples that are more closely linked to the input
|
| 291 |
+
prompt, usually at the expense of poorer quality.
|
| 292 |
+
low_memory (`bool`, *optional*):
|
| 293 |
+
Switch to sequential beam search and sequential topk for contrastive search to reduce peak memory.
|
| 294 |
+
Used with beam search and contrastive search.
|
| 295 |
+
watermarking_config (`BaseWatermarkingConfig` or `dict`, *optional*):
|
| 296 |
+
Arguments used to watermark the model outputs by adding a small bias to randomly selected set of "green"
|
| 297 |
+
tokens. See the docs of [`SynthIDTextWatermarkingConfig`] and [`WatermarkingConfig`] for more
|
| 298 |
+
details. If passed as `Dict`, it will be converted to a `WatermarkingConfig` internally.
|
| 299 |
+
|
| 300 |
+
> Parameters that define the output variables of generate
|
| 301 |
+
|
| 302 |
+
num_return_sequences (`int`, *optional*, defaults to 1):
|
| 303 |
+
The number of independently computed returned sequences for each element in the batch.
|
| 304 |
+
output_attentions (`bool`, *optional*, defaults to `False`):
|
| 305 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 306 |
+
tensors for more details.
|
| 307 |
+
output_hidden_states (`bool`, *optional*, defaults to `False`):
|
| 308 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 309 |
+
more details.
|
| 310 |
+
output_scores (`bool`, *optional*, defaults to `False`):
|
| 311 |
+
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
|
| 312 |
+
output_logits (`bool`, *optional*):
|
| 313 |
+
Whether or not to return the unprocessed prediction logit scores. See `logits` under returned tensors for
|
| 314 |
+
more details.
|
| 315 |
+
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
|
| 316 |
+
Whether or not to return a [`~utils.ModelOutput`], as opposed to returning exclusively the generated
|
| 317 |
+
sequence. This flag must be set to `True` to return the generation cache (when `use_cache` is `True`)
|
| 318 |
+
or optional outputs (see flags starting with `output_`)
|
| 319 |
+
|
| 320 |
+
> Special tokens that can be used at generation time
|
| 321 |
+
|
| 322 |
+
pad_token_id (`int`, *optional*):
|
| 323 |
+
The id of the *padding* token.
|
| 324 |
+
bos_token_id (`int`, *optional*):
|
| 325 |
+
The id of the *beginning-of-sequence* token.
|
| 326 |
+
eos_token_id (`Union[int, List[int]]`, *optional*):
|
| 327 |
+
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
| 328 |
+
|
| 329 |
+
> Generation parameters exclusive to encoder-decoder models
|
| 330 |
+
|
| 331 |
+
encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0):
|
| 332 |
+
If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the
|
| 333 |
+
`decoder_input_ids`.
|
| 334 |
+
decoder_start_token_id (`int` or `List[int]`, *optional*):
|
| 335 |
+
If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token or a list of length
|
| 336 |
+
`batch_size`. Indicating a list enables different start ids for each element in the batch
|
| 337 |
+
(e.g. multilingual models with different target languages in one batch)
|
| 338 |
+
|
| 339 |
+
> Generation parameters exclusive to assistant generation
|
| 340 |
+
is_assistant (`bool`, *optional*, defaults to `False`):
|
| 341 |
+
Whether the model is an assistant (draft) model.
|
| 342 |
+
num_assistant_tokens (`int`, *optional*, defaults to 20):
|
| 343 |
+
Defines the number of _speculative tokens_ that shall be generated by the assistant model before being
|
| 344 |
+
checked by the target model at each iteration. Higher values for `num_assistant_tokens` make the generation
|
| 345 |
+
more _speculative_ : If the assistant model is performant larger speed-ups can be reached, if the assistant
|
| 346 |
+
model requires lots of corrections, lower speed-ups are reached.
|
| 347 |
+
num_assistant_tokens_schedule (`str`, *optional*, defaults to `"constant"`):
|
| 348 |
+
Defines the schedule at which max assistant tokens shall be changed during inference.
|
| 349 |
+
- `"heuristic"`: When all speculative tokens are correct, increase `num_assistant_tokens` by 2 else
|
| 350 |
+
reduce by 1. `num_assistant_tokens` value is persistent over multiple generation calls with the same assistant model.
|
| 351 |
+
- `"heuristic_transient"`: Same as `"heuristic"` but `num_assistant_tokens` is reset to its initial value after each generation call.
|
| 352 |
+
- `"constant"`: `num_assistant_tokens` stays unchanged during generation
|
| 353 |
+
assistant_confidence_threshold (`float`, *optional*, defaults to 0.4):
|
| 354 |
+
The confidence threshold for the assistant model. If the assistant model's confidence in its prediction for the current token is lower
|
| 355 |
+
than this threshold, the assistant model stops the current token generation iteration, even if the number of _speculative tokens_
|
| 356 |
+
(defined by `num_assistant_tokens`) is not yet reached. The assistant's confidence threshold is adjusted throughout the speculative iterations to reduce the number of unnecessary draft and target forward passes, biased towards avoiding false negatives.
|
| 357 |
+
`assistant_confidence_threshold` value is persistent over multiple generation calls with the same assistant model.
|
| 358 |
+
It is an unsupervised version of the dynamic speculation lookahead
|
| 359 |
+
from Dynamic Speculation Lookahead Accelerates Speculative Decoding of Large Language Models <https://arxiv.org/abs/2405.04304>.
|
| 360 |
+
prompt_lookup_num_tokens (`int`, *optional*):
|
| 361 |
+
The number of tokens to be output as candidate tokens.
|
| 362 |
+
max_matching_ngram_size (`int`, *optional*):
|
| 363 |
+
The maximum ngram size to be considered for matching in the prompt. Default to 2 if not provided.
|
| 364 |
+
assistant_early_exit(`int`, *optional*):
|
| 365 |
+
If set to a positive integer, early exit of the model will be used as an assistant. Can only be used with
|
| 366 |
+
models that support early exit (i.e. models where logits from intermediate layers can be interpreted by the LM head).
|
| 367 |
+
assistant_lookbehind(`int`, *optional*, defaults to 10):
|
| 368 |
+
If set to a positive integer, the re-encodeing process will additionally consider the last `assistant_lookbehind` assistant tokens
|
| 369 |
+
to correctly align tokens. Can only be used with different tokenizers in speculative decoding.
|
| 370 |
+
See this [blog](https://huggingface.co/blog/universal_assisted_generation) for more details.
|
| 371 |
+
target_lookbehind(`int`, *optional*, defaults to 10):
|
| 372 |
+
If set to a positive integer, the re-encodeing process will additionally consider the last `target_lookbehind` target tokens
|
| 373 |
+
to correctly align tokens. Can only be used with different tokenizers in speculative decoding.
|
| 374 |
+
See this [blog](https://huggingface.co/blog/universal_assisted_generation) for more details.
|
| 375 |
+
|
| 376 |
+
> Parameters related to performances and compilation
|
| 377 |
+
|
| 378 |
+
compile_config (CompileConfig, *optional*):
|
| 379 |
+
If using a static cache, this controls how `generate` will `compile` the forward pass for performance
|
| 380 |
+
gains.
|
| 381 |
+
|
| 382 |
+
> Wild card
|
| 383 |
+
|
| 384 |
+
generation_kwargs:
|
| 385 |
+
Additional generation kwargs will be forwarded to the `generate` function of the model. Kwargs that are not
|
| 386 |
+
present in `generate`'s signature will be used in the model forward pass.
|
| 387 |
+
"""
|
| 388 |
+
|
| 389 |
+
extra_output_flags = ("output_attentions", "output_hidden_states", "output_scores", "output_logits")
|
| 390 |
+
|
| 391 |
+
def __init__(self, **kwargs):
|
| 392 |
+
# Parameters that control the length of the output
|
| 393 |
+
self.max_length = kwargs.pop("max_length", 20)
|
| 394 |
+
self.max_new_tokens = kwargs.pop("max_new_tokens", None)
|
| 395 |
+
self.min_length = kwargs.pop("min_length", 0)
|
| 396 |
+
self.min_new_tokens = kwargs.pop("min_new_tokens", None)
|
| 397 |
+
self.early_stopping = kwargs.pop("early_stopping", False)
|
| 398 |
+
self.max_time = kwargs.pop("max_time", None)
|
| 399 |
+
self.stop_strings = kwargs.pop("stop_strings", None)
|
| 400 |
+
|
| 401 |
+
# Parameters that control the generation strategy used
|
| 402 |
+
self.do_sample = kwargs.pop("do_sample", False)
|
| 403 |
+
self.num_beams = kwargs.pop("num_beams", 1)
|
| 404 |
+
self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
|
| 405 |
+
self.penalty_alpha = kwargs.pop("penalty_alpha", None)
|
| 406 |
+
self.dola_layers = kwargs.pop("dola_layers", None)
|
| 407 |
+
|
| 408 |
+
# Parameters that control the cache
|
| 409 |
+
self.use_cache = kwargs.pop("use_cache", True)
|
| 410 |
+
self.cache_implementation = kwargs.pop("cache_implementation", None)
|
| 411 |
+
self.cache_config = kwargs.pop("cache_config", None)
|
| 412 |
+
if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG:
|
| 413 |
+
cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation]
|
| 414 |
+
if self.cache_config is None:
|
| 415 |
+
self.cache_config = cache_config_class()
|
| 416 |
+
elif isinstance(self.cache_config, dict):
|
| 417 |
+
self.cache_config = cache_config_class.from_dict(self.cache_config)
|
| 418 |
+
self.return_legacy_cache = kwargs.pop("return_legacy_cache", None)
|
| 419 |
+
|
| 420 |
+
# Parameters for manipulation of the model output logits
|
| 421 |
+
self.temperature = kwargs.pop("temperature", 1.0)
|
| 422 |
+
self.top_k = kwargs.pop("top_k", 50)
|
| 423 |
+
self.top_p = kwargs.pop("top_p", 1.0)
|
| 424 |
+
self.min_p = kwargs.pop("min_p", None)
|
| 425 |
+
self.typical_p = kwargs.pop("typical_p", 1.0)
|
| 426 |
+
self.epsilon_cutoff = kwargs.pop("epsilon_cutoff", 0.0)
|
| 427 |
+
self.eta_cutoff = kwargs.pop("eta_cutoff", 0.0)
|
| 428 |
+
self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
|
| 429 |
+
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
|
| 430 |
+
self.encoder_repetition_penalty = kwargs.pop("encoder_repetition_penalty", 1.0)
|
| 431 |
+
self.length_penalty = kwargs.pop("length_penalty", 1.0)
|
| 432 |
+
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
|
| 433 |
+
self.bad_words_ids = kwargs.pop("bad_words_ids", None)
|
| 434 |
+
self.force_words_ids = kwargs.pop("force_words_ids", None)
|
| 435 |
+
self.renormalize_logits = kwargs.pop("renormalize_logits", False)
|
| 436 |
+
self.constraints = kwargs.pop("constraints", None)
|
| 437 |
+
self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
|
| 438 |
+
self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
|
| 439 |
+
self.remove_invalid_values = kwargs.pop("remove_invalid_values", False)
|
| 440 |
+
self.exponential_decay_length_penalty = kwargs.pop("exponential_decay_length_penalty", None)
|
| 441 |
+
self.suppress_tokens = kwargs.pop("suppress_tokens", None)
|
| 442 |
+
self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None)
|
| 443 |
+
self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
|
| 444 |
+
self.sequence_bias = kwargs.pop("sequence_bias", None)
|
| 445 |
+
self.token_healing = kwargs.pop("token_healing", False)
|
| 446 |
+
self.guidance_scale = kwargs.pop("guidance_scale", None)
|
| 447 |
+
self.low_memory = kwargs.pop("low_memory", None)
|
| 448 |
+
watermarking_config = kwargs.pop("watermarking_config", None)
|
| 449 |
+
if watermarking_config is None:
|
| 450 |
+
self.watermarking_config = None
|
| 451 |
+
elif isinstance(watermarking_config, BaseWatermarkingConfig):
|
| 452 |
+
self.watermarking_config = watermarking_config
|
| 453 |
+
else:
|
| 454 |
+
self.watermarking_config = WatermarkingConfig.from_dict(watermarking_config)
|
| 455 |
+
|
| 456 |
+
# Parameters that define the output variables of `generate`
|
| 457 |
+
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
|
| 458 |
+
self.output_attentions = kwargs.pop("output_attentions", False)
|
| 459 |
+
self.output_hidden_states = kwargs.pop("output_hidden_states", False)
|
| 460 |
+
self.output_scores = kwargs.pop("output_scores", False)
|
| 461 |
+
self.output_logits = kwargs.pop("output_logits", None)
|
| 462 |
+
self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
|
| 463 |
+
|
| 464 |
+
# Special tokens that can be used at generation time
|
| 465 |
+
self.pad_token_id = kwargs.pop("pad_token_id", None)
|
| 466 |
+
self.bos_token_id = kwargs.pop("bos_token_id", None)
|
| 467 |
+
self.eos_token_id = kwargs.pop("eos_token_id", None)
|
| 468 |
+
|
| 469 |
+
# Generation parameters exclusive to encoder-decoder models
|
| 470 |
+
self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
|
| 471 |
+
self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
|
| 472 |
+
|
| 473 |
+
# Assistant generation
|
| 474 |
+
self.is_assistant = False
|
| 475 |
+
self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 20)
|
| 476 |
+
self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "constant")
|
| 477 |
+
self.assistant_confidence_threshold = kwargs.pop("assistant_confidence_threshold", 0.4)
|
| 478 |
+
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
|
| 479 |
+
self.max_matching_ngram_size = kwargs.pop("max_matching_ngram_size", None)
|
| 480 |
+
self.assistant_early_exit = kwargs.pop("assistant_early_exit", None)
|
| 481 |
+
## assistant generation for different tokenizers, the windows size for assistant/target model
|
| 482 |
+
self.assistant_lookbehind = kwargs.pop("assistant_lookbehind", 10)
|
| 483 |
+
self.target_lookbehind = kwargs.pop("target_lookbehind", 10)
|
| 484 |
+
|
| 485 |
+
# Performances
|
| 486 |
+
self.compile_config = kwargs.pop("compile_config", CompileConfig())
|
| 487 |
+
|
| 488 |
+
# Wild card
|
| 489 |
+
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
|
| 490 |
+
|
| 491 |
+
# The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
|
| 492 |
+
# interface.
|
| 493 |
+
self._from_model_config = kwargs.pop("_from_model_config", False)
|
| 494 |
+
self._commit_hash = kwargs.pop("_commit_hash", None)
|
| 495 |
+
self.transformers_version = kwargs.pop("transformers_version", __version__)
|
| 496 |
+
|
| 497 |
+
# Additional attributes without default values
|
| 498 |
+
if not self._from_model_config:
|
| 499 |
+
# we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a
|
| 500 |
+
# model's default configuration file
|
| 501 |
+
for key, value in kwargs.items():
|
| 502 |
+
try:
|
| 503 |
+
setattr(self, key, value)
|
| 504 |
+
except AttributeError as err:
|
| 505 |
+
logger.error(f"Can't set {key} with value {value} for {self}")
|
| 506 |
+
raise err
|
| 507 |
+
|
| 508 |
+
# Validate the values of the attributes
|
| 509 |
+
self.validate(is_init=True)
|
| 510 |
+
|
| 511 |
+
def __hash__(self):
|
| 512 |
+
return hash(self.to_json_string(ignore_metadata=True))
|
| 513 |
+
|
| 514 |
+
def __eq__(self, other):
|
| 515 |
+
if not isinstance(other, GenerationConfig):
|
| 516 |
+
return False
|
| 517 |
+
|
| 518 |
+
self_without_metadata = self.to_json_string(use_diff=False, ignore_metadata=True)
|
| 519 |
+
other_without_metadata = other.to_json_string(use_diff=False, ignore_metadata=True)
|
| 520 |
+
return self_without_metadata == other_without_metadata
|
| 521 |
+
|
| 522 |
+
def __repr__(self):
|
| 523 |
+
return f"{self.__class__.__name__} {self.to_json_string(ignore_metadata=True)}"
|
| 524 |
+
|
| 525 |
+
def get_generation_mode(self, assistant_model: Optional["PreTrainedModel"] = None) -> GenerationMode:
|
| 526 |
+
"""
|
| 527 |
+
Returns the generation mode triggered by the [`GenerationConfig`] instance.
|
| 528 |
+
|
| 529 |
+
Arg:
|
| 530 |
+
assistant_model (`PreTrainedModel`, *optional*):
|
| 531 |
+
The assistant model to be used for assisted generation. If set, the generation mode will be
|
| 532 |
+
assisted generation.
|
| 533 |
+
|
| 534 |
+
Returns:
|
| 535 |
+
`GenerationMode`: The generation mode triggered by the instance.
|
| 536 |
+
"""
|
| 537 |
+
# TODO joao: find out a way of not depending on external fields (e.g. `assistant_model`), then make this a
|
| 538 |
+
# property and part of the `__repr__`
|
| 539 |
+
if self.constraints is not None or self.force_words_ids is not None:
|
| 540 |
+
generation_mode = GenerationMode.CONSTRAINED_BEAM_SEARCH
|
| 541 |
+
elif self.num_beams == 1:
|
| 542 |
+
if self.do_sample is False:
|
| 543 |
+
if (
|
| 544 |
+
self.top_k is not None
|
| 545 |
+
and self.top_k > 1
|
| 546 |
+
and self.penalty_alpha is not None
|
| 547 |
+
and self.penalty_alpha > 0
|
| 548 |
+
):
|
| 549 |
+
generation_mode = GenerationMode.CONTRASTIVE_SEARCH
|
| 550 |
+
else:
|
| 551 |
+
generation_mode = GenerationMode.GREEDY_SEARCH
|
| 552 |
+
else:
|
| 553 |
+
generation_mode = GenerationMode.SAMPLE
|
| 554 |
+
else:
|
| 555 |
+
if self.num_beam_groups > 1:
|
| 556 |
+
generation_mode = GenerationMode.GROUP_BEAM_SEARCH
|
| 557 |
+
elif self.do_sample is True:
|
| 558 |
+
generation_mode = GenerationMode.BEAM_SAMPLE
|
| 559 |
+
else:
|
| 560 |
+
generation_mode = GenerationMode.BEAM_SEARCH
|
| 561 |
+
|
| 562 |
+
# Assisted generation may extend some generation modes
|
| 563 |
+
if (
|
| 564 |
+
assistant_model is not None
|
| 565 |
+
or self.prompt_lookup_num_tokens is not None
|
| 566 |
+
or self.assistant_early_exit is not None
|
| 567 |
+
):
|
| 568 |
+
if generation_mode in ("greedy_search", "sample"):
|
| 569 |
+
generation_mode = GenerationMode.ASSISTED_GENERATION
|
| 570 |
+
else:
|
| 571 |
+
raise ValueError(
|
| 572 |
+
"You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
|
| 573 |
+
"is only supported with Greedy Search and Sample."
|
| 574 |
+
)
|
| 575 |
+
|
| 576 |
+
# DoLa generation may extend some generation modes
|
| 577 |
+
if self.dola_layers is not None:
|
| 578 |
+
if generation_mode in ("greedy_search", "sample"):
|
| 579 |
+
generation_mode = GenerationMode.DOLA_GENERATION
|
| 580 |
+
else:
|
| 581 |
+
raise ValueError(
|
| 582 |
+
"You've set `dola_layers`, which triggers DoLa generate. Currently, DoLa generate "
|
| 583 |
+
"is only supported with Greedy Search and Sample."
|
| 584 |
+
)
|
| 585 |
+
return generation_mode
|
| 586 |
+
|
| 587 |
+
def validate(self, is_init=False):
|
| 588 |
+
"""
|
| 589 |
+
Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
|
| 590 |
+
of parameterization that can be detected as incorrect from the configuration instance alone.
|
| 591 |
+
|
| 592 |
+
Note that some parameters not validated here are best validated at generate runtime, as they may depend on
|
| 593 |
+
other inputs and/or the model, such as parameters related to the generation length.
|
| 594 |
+
|
| 595 |
+
Arg:
|
| 596 |
+
is_init (`bool`, *optional*, defaults to `False`):
|
| 597 |
+
Whether the validation is performed during the initialization of the instance.
|
| 598 |
+
"""
|
| 599 |
+
|
| 600 |
+
# Validation of individual attributes
|
| 601 |
+
if self.early_stopping not in {True, False, "never"}:
|
| 602 |
+
raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.")
|
| 603 |
+
if self.max_new_tokens is not None and self.max_new_tokens <= 0:
|
| 604 |
+
raise ValueError(f"`max_new_tokens` must be greater than 0, but is {self.max_new_tokens}.")
|
| 605 |
+
if self.pad_token_id is not None and self.pad_token_id < 0:
|
| 606 |
+
warnings.warn(
|
| 607 |
+
f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch "
|
| 608 |
+
"generating, if there is padding. Please set `pad_token_id` explicitly as "
|
| 609 |
+
"`model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation"
|
| 610 |
+
)
|
| 611 |
+
|
| 612 |
+
# Validation of attribute relations:
|
| 613 |
+
fix_location = ""
|
| 614 |
+
if is_init:
|
| 615 |
+
fix_location = (
|
| 616 |
+
" This was detected when initializing the generation config instance, which means the corresponding "
|
| 617 |
+
"file may hold incorrect parameterization and should be fixed."
|
| 618 |
+
)
|
| 619 |
+
|
| 620 |
+
# 1. detect sampling-only parameterization when not in sampling mode
|
| 621 |
+
if self.do_sample is False:
|
| 622 |
+
greedy_wrong_parameter_msg = (
|
| 623 |
+
"`do_sample` is set to `False`. However, `{flag_name}` is set to `{flag_value}` -- this flag is only "
|
| 624 |
+
"used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
|
| 625 |
+
+ fix_location
|
| 626 |
+
)
|
| 627 |
+
if self.temperature is not None and self.temperature != 1.0:
|
| 628 |
+
warnings.warn(
|
| 629 |
+
greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature),
|
| 630 |
+
UserWarning,
|
| 631 |
+
)
|
| 632 |
+
if self.top_p is not None and self.top_p != 1.0:
|
| 633 |
+
warnings.warn(
|
| 634 |
+
greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p),
|
| 635 |
+
UserWarning,
|
| 636 |
+
)
|
| 637 |
+
if self.min_p is not None:
|
| 638 |
+
warnings.warn(
|
| 639 |
+
greedy_wrong_parameter_msg.format(flag_name="min_p", flag_value=self.min_p),
|
| 640 |
+
UserWarning,
|
| 641 |
+
)
|
| 642 |
+
if self.typical_p is not None and self.typical_p != 1.0:
|
| 643 |
+
warnings.warn(
|
| 644 |
+
greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p),
|
| 645 |
+
UserWarning,
|
| 646 |
+
)
|
| 647 |
+
if (
|
| 648 |
+
self.top_k is not None and self.top_k != 50 and self.penalty_alpha is None
|
| 649 |
+
): # contrastive search uses top_k
|
| 650 |
+
warnings.warn(
|
| 651 |
+
greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k),
|
| 652 |
+
UserWarning,
|
| 653 |
+
)
|
| 654 |
+
if self.epsilon_cutoff is not None and self.epsilon_cutoff != 0.0:
|
| 655 |
+
warnings.warn(
|
| 656 |
+
greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff),
|
| 657 |
+
UserWarning,
|
| 658 |
+
)
|
| 659 |
+
if self.eta_cutoff is not None and self.eta_cutoff != 0.0:
|
| 660 |
+
warnings.warn(
|
| 661 |
+
greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff),
|
| 662 |
+
UserWarning,
|
| 663 |
+
)
|
| 664 |
+
|
| 665 |
+
# 2. detect beam-only parameterization when not in beam mode
|
| 666 |
+
if self.num_beams is None:
|
| 667 |
+
warnings.warn("`num_beams` is set to None - defaulting to 1.", UserWarning)
|
| 668 |
+
self.num_beams = 1
|
| 669 |
+
|
| 670 |
+
if self.num_beams == 1:
|
| 671 |
+
single_beam_wrong_parameter_msg = (
|
| 672 |
+
"`num_beams` is set to 1. However, `{flag_name}` is set to `{flag_value}` -- this flag is only used "
|
| 673 |
+
"in beam-based generation modes. You should set `num_beams>1` or unset `{flag_name}`." + fix_location
|
| 674 |
+
)
|
| 675 |
+
if self.early_stopping is not False:
|
| 676 |
+
warnings.warn(
|
| 677 |
+
single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping),
|
| 678 |
+
UserWarning,
|
| 679 |
+
)
|
| 680 |
+
if self.num_beam_groups is not None and self.num_beam_groups != 1:
|
| 681 |
+
warnings.warn(
|
| 682 |
+
single_beam_wrong_parameter_msg.format(
|
| 683 |
+
flag_name="num_beam_groups", flag_value=self.num_beam_groups
|
| 684 |
+
),
|
| 685 |
+
UserWarning,
|
| 686 |
+
)
|
| 687 |
+
if self.diversity_penalty is not None and self.diversity_penalty != 0.0:
|
| 688 |
+
warnings.warn(
|
| 689 |
+
single_beam_wrong_parameter_msg.format(
|
| 690 |
+
flag_name="diversity_penalty", flag_value=self.diversity_penalty
|
| 691 |
+
),
|
| 692 |
+
UserWarning,
|
| 693 |
+
)
|
| 694 |
+
if self.length_penalty is not None and self.length_penalty != 1.0:
|
| 695 |
+
warnings.warn(
|
| 696 |
+
single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty),
|
| 697 |
+
UserWarning,
|
| 698 |
+
)
|
| 699 |
+
if self.constraints is not None:
|
| 700 |
+
warnings.warn(
|
| 701 |
+
single_beam_wrong_parameter_msg.format(flag_name="constraints", flag_value=self.constraints),
|
| 702 |
+
UserWarning,
|
| 703 |
+
)
|
| 704 |
+
|
| 705 |
+
# 3. detect incorrect paramaterization specific to advanced beam modes
|
| 706 |
+
else:
|
| 707 |
+
# constrained beam search
|
| 708 |
+
if self.constraints is not None or self.force_words_ids is not None:
|
| 709 |
+
constrained_wrong_parameter_msg = (
|
| 710 |
+
"one of `constraints`, `force_words_ids` is not `None`, triggering constrained beam search. However, "
|
| 711 |
+
"`{flag_name}` is set to `{flag_value}`, which is incompatible with this generation mode. Set "
|
| 712 |
+
"`constraints` and `force_words_ids` to `None` or unset `{flag_name}` to continue." + fix_location
|
| 713 |
+
)
|
| 714 |
+
if self.do_sample is True:
|
| 715 |
+
raise ValueError(
|
| 716 |
+
constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
|
| 717 |
+
)
|
| 718 |
+
if self.num_beam_groups is not None and self.num_beam_groups != 1:
|
| 719 |
+
raise ValueError(
|
| 720 |
+
constrained_wrong_parameter_msg.format(
|
| 721 |
+
flag_name="num_beam_groups", flag_value=self.num_beam_groups
|
| 722 |
+
)
|
| 723 |
+
)
|
| 724 |
+
# group beam search
|
| 725 |
+
if self.diversity_penalty != 0.0 or self.num_beam_groups != 1:
|
| 726 |
+
group_error_prefix = (
|
| 727 |
+
"`diversity_penalty` is not 0.0 or `num_beam_groups` is not 1, triggering group beam search. In "
|
| 728 |
+
"this generation mode, "
|
| 729 |
+
)
|
| 730 |
+
if self.do_sample is True:
|
| 731 |
+
raise ValueError(group_error_prefix + "`do_sample` must be set to `False`")
|
| 732 |
+
if self.num_beams % self.num_beam_groups != 0:
|
| 733 |
+
raise ValueError(group_error_prefix + "`num_beams` should be divisible by `num_beam_groups`")
|
| 734 |
+
if self.diversity_penalty == 0.0:
|
| 735 |
+
raise ValueError(
|
| 736 |
+
group_error_prefix
|
| 737 |
+
+ "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
|
| 738 |
+
)
|
| 739 |
+
# DoLa generation
|
| 740 |
+
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
|
| 741 |
+
warnings.warn(
|
| 742 |
+
"`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of "
|
| 743 |
+
f"{self.repetition_penalty}, which could induce unwanted repetition. The recommended value for "
|
| 744 |
+
"DoLa decoding is `repetition_penalty>=1.2`.",
|
| 745 |
+
UserWarning,
|
| 746 |
+
)
|
| 747 |
+
|
| 748 |
+
# 4. check `num_return_sequences`
|
| 749 |
+
if self.num_return_sequences != 1:
|
| 750 |
+
if self.num_beams == 1:
|
| 751 |
+
if self.do_sample is False:
|
| 752 |
+
raise ValueError(
|
| 753 |
+
"Greedy methods without beam search do not support `num_return_sequences` different than 1 "
|
| 754 |
+
f"(got {self.num_return_sequences})."
|
| 755 |
+
)
|
| 756 |
+
elif self.num_return_sequences > self.num_beams:
|
| 757 |
+
raise ValueError(
|
| 758 |
+
f"`num_return_sequences` ({self.num_return_sequences}) has to be smaller or equal to `num_beams` "
|
| 759 |
+
f"({self.num_beams})."
|
| 760 |
+
)
|
| 761 |
+
|
| 762 |
+
# 5. check cache-related arguments
|
| 763 |
+
if self.cache_implementation is not None and self.cache_implementation not in ALL_CACHE_IMPLEMENTATIONS:
|
| 764 |
+
raise ValueError(
|
| 765 |
+
f"Invalid `cache_implementation` ({self.cache_implementation}). Choose one of: "
|
| 766 |
+
f"{ALL_CACHE_IMPLEMENTATIONS}"
|
| 767 |
+
)
|
| 768 |
+
if self.cache_config is not None:
|
| 769 |
+
cache_class = NEEDS_CACHE_CONFIG.get(self.cache_implementation)
|
| 770 |
+
if cache_class is None:
|
| 771 |
+
raise ValueError(
|
| 772 |
+
"You provided a `cache_config` but the cache implementation you are using "
|
| 773 |
+
f"({self.cache_implementation}) does not require any config. Make sure to use the "
|
| 774 |
+
"correct cache implementation matching your cache config."
|
| 775 |
+
)
|
| 776 |
+
if not isinstance(self.cache_config, cache_class):
|
| 777 |
+
self.cache_config = cache_class.from_dict(self.cache_config)
|
| 778 |
+
self.cache_config.validate()
|
| 779 |
+
if self.use_cache is False:
|
| 780 |
+
# In this case, all cache-related arguments should be unset. However, since `use_cache=False` is often used
|
| 781 |
+
# passed to `generate` directly to hot-fix cache issues, let's raise a warning instead of an error
|
| 782 |
+
# (otherwise a user might need to overwrite several parameters).
|
| 783 |
+
no_cache_warning = (
|
| 784 |
+
"You have set `use_cache` to `False`, but {cache_arg} is set to {cache_arg_value}. {cache_arg} will "
|
| 785 |
+
"have no effect."
|
| 786 |
+
)
|
| 787 |
+
for arg_name in ("cache_implementation", "cache_config", "return_legacy_cache"):
|
| 788 |
+
if getattr(self, arg_name) is not None:
|
| 789 |
+
logger.warning_once(
|
| 790 |
+
no_cache_warning.format(cache_arg=arg_name, cache_arg_value=getattr(self, arg_name)),
|
| 791 |
+
UserWarning,
|
| 792 |
+
)
|
| 793 |
+
|
| 794 |
+
# 6. check watermarking arguments
|
| 795 |
+
if self.watermarking_config is not None:
|
| 796 |
+
if not (
|
| 797 |
+
isinstance(self.watermarking_config, WatermarkingConfig)
|
| 798 |
+
or isinstance(self.watermarking_config, SynthIDTextWatermarkingConfig)
|
| 799 |
+
):
|
| 800 |
+
warnings.warn(
|
| 801 |
+
"`watermarking_config` as a dict is deprecated. Please construct `watermarking_config` object with "
|
| 802 |
+
"`WatermarkingConfig` or `SynthIDTextWatermarkingConfig` class.",
|
| 803 |
+
FutureWarning,
|
| 804 |
+
)
|
| 805 |
+
self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
|
| 806 |
+
self.watermarking_config.validate()
|
| 807 |
+
|
| 808 |
+
# 7. performances arguments
|
| 809 |
+
if not isinstance(self.compile_config, CompileConfig):
|
| 810 |
+
raise ValueError(
|
| 811 |
+
f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an instance of `CompileConfig`."
|
| 812 |
+
)
|
| 813 |
+
|
| 814 |
+
# 8. other incorrect combinations
|
| 815 |
+
if self.return_dict_in_generate is not True:
|
| 816 |
+
for extra_output_flag in self.extra_output_flags:
|
| 817 |
+
if getattr(self, extra_output_flag) is True:
|
| 818 |
+
warnings.warn(
|
| 819 |
+
f"`return_dict_in_generate` is NOT set to `True`, but `{extra_output_flag}` is. When "
|
| 820 |
+
f"`return_dict_in_generate` is not `True`, `{extra_output_flag}` is ignored.",
|
| 821 |
+
UserWarning,
|
| 822 |
+
)
|
| 823 |
+
|
| 824 |
+
# 8. check common issue: passing `generate` arguments inside the generation config
|
| 825 |
+
generate_arguments = (
|
| 826 |
+
"logits_processor",
|
| 827 |
+
"stopping_criteria",
|
| 828 |
+
"prefix_allowed_tokens_fn",
|
| 829 |
+
"synced_gpus",
|
| 830 |
+
"assistant_model",
|
| 831 |
+
"streamer",
|
| 832 |
+
"negative_prompt_ids",
|
| 833 |
+
"negative_prompt_attention_mask",
|
| 834 |
+
)
|
| 835 |
+
for arg in generate_arguments:
|
| 836 |
+
if hasattr(self, arg):
|
| 837 |
+
raise ValueError(
|
| 838 |
+
f"Argument `{arg}` is not a valid argument of `GenerationConfig`. It should be passed to "
|
| 839 |
+
"`generate()` (or a pipeline) directly."
|
| 840 |
+
)
|
| 841 |
+
|
| 842 |
+
def save_pretrained(
|
| 843 |
+
self,
|
| 844 |
+
save_directory: Union[str, os.PathLike],
|
| 845 |
+
config_file_name: Optional[Union[str, os.PathLike]] = None,
|
| 846 |
+
push_to_hub: bool = False,
|
| 847 |
+
**kwargs,
|
| 848 |
+
):
|
| 849 |
+
r"""
|
| 850 |
+
Save a generation configuration object to the directory `save_directory`, so that it can be re-loaded using the
|
| 851 |
+
[`~GenerationConfig.from_pretrained`] class method.
|
| 852 |
+
|
| 853 |
+
Args:
|
| 854 |
+
save_directory (`str` or `os.PathLike`):
|
| 855 |
+
Directory where the configuration JSON file will be saved (will be created if it does not exist).
|
| 856 |
+
config_file_name (`str` or `os.PathLike`, *optional*, defaults to `"generation_config.json"`):
|
| 857 |
+
Name of the generation configuration JSON file to be saved in `save_directory`.
|
| 858 |
+
push_to_hub (`bool`, *optional*, defaults to `False`):
|
| 859 |
+
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
|
| 860 |
+
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
| 861 |
+
namespace).
|
| 862 |
+
kwargs (`Dict[str, Any]`, *optional*):
|
| 863 |
+
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
| 864 |
+
"""
|
| 865 |
+
|
| 866 |
+
# At save time, validate the instance -- if any warning/exception is thrown, we refuse to save the instance.
|
| 867 |
+
# This strictness is enforced to prevent bad configurations from being saved and re-used.
|
| 868 |
+
try:
|
| 869 |
+
with warnings.catch_warnings(record=True) as caught_warnings:
|
| 870 |
+
self.validate()
|
| 871 |
+
if len(caught_warnings) > 0:
|
| 872 |
+
raise ValueError(str([w.message for w in caught_warnings]))
|
| 873 |
+
except ValueError as exc:
|
| 874 |
+
raise ValueError(
|
| 875 |
+
"The generation config instance is invalid -- `.validate()` throws warnings and/or exceptions. "
|
| 876 |
+
"Fix these issues to save the configuration.\n\nThrown during validation:\n" + str(exc)
|
| 877 |
+
)
|
| 878 |
+
|
| 879 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
| 880 |
+
|
| 881 |
+
if use_auth_token is not None:
|
| 882 |
+
warnings.warn(
|
| 883 |
+
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. "
|
| 884 |
+
"Please use `token` instead.",
|
| 885 |
+
FutureWarning,
|
| 886 |
+
)
|
| 887 |
+
if kwargs.get("token", None) is not None:
|
| 888 |
+
raise ValueError(
|
| 889 |
+
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
|
| 890 |
+
)
|
| 891 |
+
kwargs["token"] = use_auth_token
|
| 892 |
+
|
| 893 |
+
config_file_name = config_file_name if config_file_name is not None else GENERATION_CONFIG_NAME
|
| 894 |
+
|
| 895 |
+
if os.path.isfile(save_directory):
|
| 896 |
+
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
|
| 897 |
+
|
| 898 |
+
os.makedirs(save_directory, exist_ok=True)
|
| 899 |
+
|
| 900 |
+
if push_to_hub:
|
| 901 |
+
commit_message = kwargs.pop("commit_message", None)
|
| 902 |
+
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
| 903 |
+
repo_id = self._create_repo(repo_id, **kwargs)
|
| 904 |
+
files_timestamps = self._get_files_timestamps(save_directory)
|
| 905 |
+
|
| 906 |
+
output_config_file = os.path.join(save_directory, config_file_name)
|
| 907 |
+
|
| 908 |
+
self.to_json_file(output_config_file, use_diff=True)
|
| 909 |
+
logger.info(f"Configuration saved in {output_config_file}")
|
| 910 |
+
|
| 911 |
+
if push_to_hub:
|
| 912 |
+
self._upload_modified_files(
|
| 913 |
+
save_directory,
|
| 914 |
+
repo_id,
|
| 915 |
+
files_timestamps,
|
| 916 |
+
commit_message=commit_message,
|
| 917 |
+
token=kwargs.get("token"),
|
| 918 |
+
)
|
| 919 |
+
|
| 920 |
+
@classmethod
|
| 921 |
+
def from_pretrained(
|
| 922 |
+
cls,
|
| 923 |
+
pretrained_model_name: Union[str, os.PathLike],
|
| 924 |
+
config_file_name: Optional[Union[str, os.PathLike]] = None,
|
| 925 |
+
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
| 926 |
+
force_download: bool = False,
|
| 927 |
+
local_files_only: bool = False,
|
| 928 |
+
token: Optional[Union[str, bool]] = None,
|
| 929 |
+
revision: str = "main",
|
| 930 |
+
**kwargs,
|
| 931 |
+
) -> "GenerationConfig":
|
| 932 |
+
r"""
|
| 933 |
+
Instantiate a [`GenerationConfig`] from a generation configuration file.
|
| 934 |
+
|
| 935 |
+
Args:
|
| 936 |
+
pretrained_model_name (`str` or `os.PathLike`):
|
| 937 |
+
This can be either:
|
| 938 |
+
|
| 939 |
+
- a string, the *model id* of a pretrained model configuration hosted inside a model repo on
|
| 940 |
+
huggingface.co.
|
| 941 |
+
- a path to a *directory* containing a configuration file saved using the
|
| 942 |
+
[`~GenerationConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
|
| 943 |
+
config_file_name (`str` or `os.PathLike`, *optional*, defaults to `"generation_config.json"`):
|
| 944 |
+
Name of the generation configuration JSON file to be loaded from `pretrained_model_name`.
|
| 945 |
+
cache_dir (`str` or `os.PathLike`, *optional*):
|
| 946 |
+
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
| 947 |
+
standard cache should not be used.
|
| 948 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
| 949 |
+
Whether or not to force to (re-)download the configuration files and override the cached versions if
|
| 950 |
+
they exist.
|
| 951 |
+
resume_download:
|
| 952 |
+
Deprecated and ignored. All downloads are now resumed by default when possible.
|
| 953 |
+
Will be removed in v5 of Transformers.
|
| 954 |
+
proxies (`Dict[str, str]`, *optional*):
|
| 955 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
| 956 |
+
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
| 957 |
+
token (`str` or `bool`, *optional*):
|
| 958 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
|
| 959 |
+
the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
|
| 960 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
| 961 |
+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
| 962 |
+
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
| 963 |
+
identifier allowed by git.
|
| 964 |
+
|
| 965 |
+
<Tip>
|
| 966 |
+
|
| 967 |
+
To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
|
| 968 |
+
|
| 969 |
+
</Tip>
|
| 970 |
+
|
| 971 |
+
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
| 972 |
+
If `False`, then this function returns just the final configuration object.
|
| 973 |
+
|
| 974 |
+
If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
|
| 975 |
+
dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
|
| 976 |
+
part of `kwargs` which has not been used to update `config` and is otherwise ignored.
|
| 977 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
| 978 |
+
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
|
| 979 |
+
specify the folder name here.
|
| 980 |
+
kwargs (`Dict[str, Any]`, *optional*):
|
| 981 |
+
The values in kwargs of any keys which are configuration attributes will be used to override the loaded
|
| 982 |
+
values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
|
| 983 |
+
by the `return_unused_kwargs` keyword parameter.
|
| 984 |
+
|
| 985 |
+
Returns:
|
| 986 |
+
[`GenerationConfig`]: The configuration object instantiated from this pretrained model.
|
| 987 |
+
|
| 988 |
+
Examples:
|
| 989 |
+
|
| 990 |
+
```python
|
| 991 |
+
>>> from transformers import GenerationConfig
|
| 992 |
+
|
| 993 |
+
>>> # Download configuration from huggingface.co and cache.
|
| 994 |
+
>>> generation_config = GenerationConfig.from_pretrained("openai-community/gpt2")
|
| 995 |
+
|
| 996 |
+
>>> # E.g. config was saved using *save_pretrained('./test/saved_model/')*
|
| 997 |
+
>>> generation_config.save_pretrained("./test/saved_model/")
|
| 998 |
+
>>> generation_config = GenerationConfig.from_pretrained("./test/saved_model/")
|
| 999 |
+
|
| 1000 |
+
>>> # You can also specify configuration names to your generation configuration file
|
| 1001 |
+
>>> generation_config.save_pretrained("./test/saved_model/", config_file_name="my_configuration.json")
|
| 1002 |
+
>>> generation_config = GenerationConfig.from_pretrained("./test/saved_model/", "my_configuration.json")
|
| 1003 |
+
|
| 1004 |
+
>>> # If you'd like to try a minor variation to an existing configuration, you can also pass generation
|
| 1005 |
+
>>> # arguments to `.from_pretrained()`. Be mindful that typos and unused arguments will be ignored
|
| 1006 |
+
>>> generation_config, unused_kwargs = GenerationConfig.from_pretrained(
|
| 1007 |
+
... "openai-community/gpt2", top_k=1, foo=False, do_sample=True, return_unused_kwargs=True
|
| 1008 |
+
... )
|
| 1009 |
+
>>> generation_config.top_k
|
| 1010 |
+
1
|
| 1011 |
+
|
| 1012 |
+
>>> unused_kwargs
|
| 1013 |
+
{'foo': False}
|
| 1014 |
+
```"""
|
| 1015 |
+
config_file_name = config_file_name if config_file_name is not None else GENERATION_CONFIG_NAME
|
| 1016 |
+
|
| 1017 |
+
resume_download = kwargs.pop("resume_download", None)
|
| 1018 |
+
proxies = kwargs.pop("proxies", None)
|
| 1019 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
| 1020 |
+
subfolder = kwargs.pop("subfolder", "")
|
| 1021 |
+
from_pipeline = kwargs.pop("_from_pipeline", None)
|
| 1022 |
+
from_auto_class = kwargs.pop("_from_auto", False)
|
| 1023 |
+
commit_hash = kwargs.pop("_commit_hash", None)
|
| 1024 |
+
|
| 1025 |
+
if use_auth_token is not None:
|
| 1026 |
+
warnings.warn(
|
| 1027 |
+
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
|
| 1028 |
+
FutureWarning,
|
| 1029 |
+
)
|
| 1030 |
+
if token is not None:
|
| 1031 |
+
raise ValueError(
|
| 1032 |
+
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
|
| 1033 |
+
)
|
| 1034 |
+
token = use_auth_token
|
| 1035 |
+
|
| 1036 |
+
user_agent = {"file_type": "config", "from_auto_class": from_auto_class}
|
| 1037 |
+
if from_pipeline is not None:
|
| 1038 |
+
user_agent["using_pipeline"] = from_pipeline
|
| 1039 |
+
|
| 1040 |
+
config_path = os.path.join(pretrained_model_name, config_file_name)
|
| 1041 |
+
config_path = str(config_path)
|
| 1042 |
+
|
| 1043 |
+
is_local = os.path.exists(config_path)
|
| 1044 |
+
if os.path.isfile(os.path.join(subfolder, config_path)):
|
| 1045 |
+
# Special case when config_path is a local file
|
| 1046 |
+
resolved_config_file = config_path
|
| 1047 |
+
is_local = True
|
| 1048 |
+
elif is_remote_url(config_path):
|
| 1049 |
+
configuration_file = config_path
|
| 1050 |
+
resolved_config_file = download_url(config_path)
|
| 1051 |
+
else:
|
| 1052 |
+
configuration_file = config_file_name
|
| 1053 |
+
try:
|
| 1054 |
+
# Load from local folder or from cache or download from model Hub and cache
|
| 1055 |
+
resolved_config_file = cached_file(
|
| 1056 |
+
pretrained_model_name,
|
| 1057 |
+
configuration_file,
|
| 1058 |
+
cache_dir=cache_dir,
|
| 1059 |
+
force_download=force_download,
|
| 1060 |
+
proxies=proxies,
|
| 1061 |
+
resume_download=resume_download,
|
| 1062 |
+
local_files_only=local_files_only,
|
| 1063 |
+
token=token,
|
| 1064 |
+
user_agent=user_agent,
|
| 1065 |
+
revision=revision,
|
| 1066 |
+
subfolder=subfolder,
|
| 1067 |
+
_commit_hash=commit_hash,
|
| 1068 |
+
)
|
| 1069 |
+
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
|
| 1070 |
+
except EnvironmentError:
|
| 1071 |
+
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
|
| 1072 |
+
# the original exception.
|
| 1073 |
+
raise
|
| 1074 |
+
except Exception:
|
| 1075 |
+
# For any other exception, we throw a generic error.
|
| 1076 |
+
raise EnvironmentError(
|
| 1077 |
+
f"Can't load the configuration of '{pretrained_model_name}'. If you were trying to load it"
|
| 1078 |
+
" from 'https://huggingface.co/models', make sure you don't have a local directory with the same"
|
| 1079 |
+
f" name. Otherwise, make sure '{pretrained_model_name}' is the correct path to a directory"
|
| 1080 |
+
f" containing a {configuration_file} file"
|
| 1081 |
+
)
|
| 1082 |
+
|
| 1083 |
+
try:
|
| 1084 |
+
# Load config dict
|
| 1085 |
+
config_dict = cls._dict_from_json_file(resolved_config_file)
|
| 1086 |
+
config_dict["_commit_hash"] = commit_hash
|
| 1087 |
+
except (json.JSONDecodeError, UnicodeDecodeError):
|
| 1088 |
+
raise EnvironmentError(
|
| 1089 |
+
f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
|
| 1090 |
+
)
|
| 1091 |
+
|
| 1092 |
+
if is_local:
|
| 1093 |
+
logger.info(f"loading configuration file {resolved_config_file}")
|
| 1094 |
+
else:
|
| 1095 |
+
logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
|
| 1096 |
+
|
| 1097 |
+
if kwargs.get("return_unused_kwargs") is True:
|
| 1098 |
+
config, unused_kwargs = cls.from_dict(config_dict, **kwargs)
|
| 1099 |
+
config._original_object_hash = hash(config) # Hash to detect whether the instance was modified
|
| 1100 |
+
return config, unused_kwargs
|
| 1101 |
+
else:
|
| 1102 |
+
config = cls.from_dict(config_dict, **kwargs)
|
| 1103 |
+
config._original_object_hash = hash(config) # Hash to detect whether the instance was modified
|
| 1104 |
+
return config
|
| 1105 |
+
|
| 1106 |
+
@classmethod
|
| 1107 |
+
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
|
| 1108 |
+
with open(json_file, "r", encoding="utf-8") as reader:
|
| 1109 |
+
text = reader.read()
|
| 1110 |
+
return json.loads(text)
|
| 1111 |
+
|
| 1112 |
+
@classmethod
|
| 1113 |
+
def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "GenerationConfig":
|
| 1114 |
+
"""
|
| 1115 |
+
Instantiates a [`GenerationConfig`] from a Python dictionary of parameters.
|
| 1116 |
+
|
| 1117 |
+
Args:
|
| 1118 |
+
config_dict (`Dict[str, Any]`):
|
| 1119 |
+
Dictionary that will be used to instantiate the configuration object.
|
| 1120 |
+
kwargs (`Dict[str, Any]`):
|
| 1121 |
+
Additional parameters from which to initialize the configuration object.
|
| 1122 |
+
|
| 1123 |
+
Returns:
|
| 1124 |
+
[`GenerationConfig`]: The configuration object instantiated from those parameters.
|
| 1125 |
+
"""
|
| 1126 |
+
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
|
| 1127 |
+
# Those arguments may be passed along for our internal telemetry.
|
| 1128 |
+
# We remove them so they don't appear in `return_unused_kwargs`.
|
| 1129 |
+
kwargs.pop("_from_auto", None)
|
| 1130 |
+
kwargs.pop("_from_pipeline", None)
|
| 1131 |
+
# The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
|
| 1132 |
+
if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
|
| 1133 |
+
kwargs["_commit_hash"] = config_dict["_commit_hash"]
|
| 1134 |
+
|
| 1135 |
+
# The line below allows model-specific config to be loaded as well through kwargs, with safety checks.
|
| 1136 |
+
# See https://github.com/huggingface/transformers/pull/21269
|
| 1137 |
+
config = cls(**{**config_dict, **kwargs})
|
| 1138 |
+
unused_kwargs = config.update(**kwargs)
|
| 1139 |
+
|
| 1140 |
+
logger.info(f"Generate config {config}")
|
| 1141 |
+
if return_unused_kwargs:
|
| 1142 |
+
return config, unused_kwargs
|
| 1143 |
+
else:
|
| 1144 |
+
return config
|
| 1145 |
+
|
| 1146 |
+
def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
|
| 1147 |
+
"""
|
| 1148 |
+
Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None,
|
| 1149 |
+
converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
|
| 1150 |
+
string, which can then be stored in the json format.
|
| 1151 |
+
"""
|
| 1152 |
+
if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
|
| 1153 |
+
d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
|
| 1154 |
+
for value in d.values():
|
| 1155 |
+
if isinstance(value, dict):
|
| 1156 |
+
self.dict_torch_dtype_to_str(value)
|
| 1157 |
+
|
| 1158 |
+
def to_diff_dict(self) -> Dict[str, Any]:
|
| 1159 |
+
"""
|
| 1160 |
+
Removes all attributes from config which correspond to the default config attributes for better readability and
|
| 1161 |
+
serializes to a Python dictionary.
|
| 1162 |
+
|
| 1163 |
+
Returns:
|
| 1164 |
+
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
|
| 1165 |
+
"""
|
| 1166 |
+
config_dict = self.to_dict()
|
| 1167 |
+
|
| 1168 |
+
# get the default config dict
|
| 1169 |
+
default_config_dict = GenerationConfig().to_dict()
|
| 1170 |
+
|
| 1171 |
+
serializable_config_dict = {}
|
| 1172 |
+
|
| 1173 |
+
# only serialize values that differ from the default config
|
| 1174 |
+
for key, value in config_dict.items():
|
| 1175 |
+
if key not in default_config_dict or key == "transformers_version" or value != default_config_dict[key]:
|
| 1176 |
+
serializable_config_dict[key] = value
|
| 1177 |
+
|
| 1178 |
+
self.dict_torch_dtype_to_str(serializable_config_dict)
|
| 1179 |
+
return serializable_config_dict
|
| 1180 |
+
|
| 1181 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 1182 |
+
"""
|
| 1183 |
+
Serializes this instance to a Python dictionary.
|
| 1184 |
+
|
| 1185 |
+
Returns:
|
| 1186 |
+
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
|
| 1187 |
+
"""
|
| 1188 |
+
output = copy.deepcopy(self.__dict__)
|
| 1189 |
+
|
| 1190 |
+
# Fields to ignore at serialization time
|
| 1191 |
+
if "_commit_hash" in output:
|
| 1192 |
+
del output["_commit_hash"]
|
| 1193 |
+
if "_original_object_hash" in output:
|
| 1194 |
+
del output["_original_object_hash"]
|
| 1195 |
+
if "compile_config" in output:
|
| 1196 |
+
del output["compile_config"]
|
| 1197 |
+
|
| 1198 |
+
# Transformers version when serializing this file
|
| 1199 |
+
output["transformers_version"] = __version__
|
| 1200 |
+
|
| 1201 |
+
self.dict_torch_dtype_to_str(output)
|
| 1202 |
+
return output
|
| 1203 |
+
|
| 1204 |
+
def to_json_string(self, use_diff: bool = True, ignore_metadata: bool = False) -> str:
|
| 1205 |
+
"""
|
| 1206 |
+
Serializes this instance to a JSON string.
|
| 1207 |
+
|
| 1208 |
+
Args:
|
| 1209 |
+
use_diff (`bool`, *optional*, defaults to `True`):
|
| 1210 |
+
If set to `True`, only the difference between the config instance and the default `GenerationConfig()`
|
| 1211 |
+
is serialized to JSON string.
|
| 1212 |
+
ignore_metadata (`bool`, *optional*, defaults to `False`):
|
| 1213 |
+
Whether to ignore the metadata fields present in the instance
|
| 1214 |
+
|
| 1215 |
+
Returns:
|
| 1216 |
+
`str`: String containing all the attributes that make up this configuration instance in JSON format.
|
| 1217 |
+
"""
|
| 1218 |
+
if use_diff is True:
|
| 1219 |
+
config_dict = self.to_diff_dict()
|
| 1220 |
+
else:
|
| 1221 |
+
config_dict = self.to_dict()
|
| 1222 |
+
|
| 1223 |
+
if ignore_metadata:
|
| 1224 |
+
for metadata_field in METADATA_FIELDS:
|
| 1225 |
+
config_dict.pop(metadata_field, None)
|
| 1226 |
+
|
| 1227 |
+
def convert_keys_to_string(obj):
|
| 1228 |
+
if isinstance(obj, dict):
|
| 1229 |
+
return {str(key): convert_keys_to_string(value) for key, value in obj.items()}
|
| 1230 |
+
elif isinstance(obj, list):
|
| 1231 |
+
return [convert_keys_to_string(item) for item in obj]
|
| 1232 |
+
else:
|
| 1233 |
+
return obj
|
| 1234 |
+
|
| 1235 |
+
def convert_dataclass_to_dict(obj):
|
| 1236 |
+
if isinstance(obj, dict):
|
| 1237 |
+
return {key: convert_dataclass_to_dict(value) for key, value in obj.items()}
|
| 1238 |
+
elif is_dataclass(obj):
|
| 1239 |
+
return obj.to_dict()
|
| 1240 |
+
else:
|
| 1241 |
+
return obj
|
| 1242 |
+
|
| 1243 |
+
config_dict = convert_keys_to_string(config_dict)
|
| 1244 |
+
config_dict = convert_dataclass_to_dict(config_dict)
|
| 1245 |
+
|
| 1246 |
+
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
| 1247 |
+
|
| 1248 |
+
def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
|
| 1249 |
+
"""
|
| 1250 |
+
Save this instance to a JSON file.
|
| 1251 |
+
|
| 1252 |
+
Args:
|
| 1253 |
+
json_file_path (`str` or `os.PathLike`):
|
| 1254 |
+
Path to the JSON file in which this configuration instance's parameters will be saved.
|
| 1255 |
+
use_diff (`bool`, *optional*, defaults to `True`):
|
| 1256 |
+
If set to `True`, only the difference between the config instance and the default `GenerationConfig()`
|
| 1257 |
+
is serialized to JSON file.
|
| 1258 |
+
"""
|
| 1259 |
+
with open(json_file_path, "w", encoding="utf-8") as writer:
|
| 1260 |
+
writer.write(self.to_json_string(use_diff=use_diff))
|
| 1261 |
+
|
| 1262 |
+
@classmethod
|
| 1263 |
+
def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig":
|
| 1264 |
+
"""
|
| 1265 |
+
Instantiates a [`GenerationConfig`] from a [`PretrainedConfig`]. This function is useful to convert legacy
|
| 1266 |
+
[`PretrainedConfig`] objects, which may contain generation parameters, into a stand-alone [`GenerationConfig`].
|
| 1267 |
+
|
| 1268 |
+
Args:
|
| 1269 |
+
model_config (`PretrainedConfig`):
|
| 1270 |
+
The model config that will be used to instantiate the generation config.
|
| 1271 |
+
|
| 1272 |
+
Returns:
|
| 1273 |
+
[`GenerationConfig`]: The configuration object instantiated from those parameters.
|
| 1274 |
+
"""
|
| 1275 |
+
config_dict = model_config.to_dict()
|
| 1276 |
+
config_dict.pop("_from_model_config", None)
|
| 1277 |
+
|
| 1278 |
+
# Removes all `None` from the model config dict -- this lets the generation config defaults to take hold
|
| 1279 |
+
config_dict = {key: value for key, value in config_dict.items() if value is not None}
|
| 1280 |
+
|
| 1281 |
+
generation_config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)
|
| 1282 |
+
|
| 1283 |
+
# Special case: some models have generation attributes set in the decoder. Use them if still unset in the
|
| 1284 |
+
# generation config (which in turn is defined from the outer attributes of model config).
|
| 1285 |
+
decoder_config = model_config.get_text_config(decoder=True)
|
| 1286 |
+
if decoder_config is not model_config:
|
| 1287 |
+
default_generation_config = GenerationConfig()
|
| 1288 |
+
decoder_config_dict = decoder_config.to_dict()
|
| 1289 |
+
for attr in generation_config.to_dict().keys():
|
| 1290 |
+
is_unset = getattr(generation_config, attr) == getattr(default_generation_config, attr)
|
| 1291 |
+
if attr in decoder_config_dict and is_unset:
|
| 1292 |
+
setattr(generation_config, attr, decoder_config_dict[attr])
|
| 1293 |
+
|
| 1294 |
+
# If any `output_...` flag is set to `True`, we ensure `return_dict_in_generate` is set to `True`.
|
| 1295 |
+
if generation_config.return_dict_in_generate is False:
|
| 1296 |
+
if any(
|
| 1297 |
+
getattr(generation_config, extra_output_flag, False)
|
| 1298 |
+
for extra_output_flag in generation_config.extra_output_flags
|
| 1299 |
+
):
|
| 1300 |
+
generation_config.return_dict_in_generate = True
|
| 1301 |
+
|
| 1302 |
+
# Hash to detect whether the instance was modified
|
| 1303 |
+
generation_config._original_object_hash = hash(generation_config)
|
| 1304 |
+
return generation_config
|
| 1305 |
+
|
| 1306 |
+
def update(self, **kwargs):
|
| 1307 |
+
"""
|
| 1308 |
+
Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
|
| 1309 |
+
returning all the unused kwargs.
|
| 1310 |
+
|
| 1311 |
+
Args:
|
| 1312 |
+
kwargs (`Dict[str, Any]`):
|
| 1313 |
+
Dictionary of attributes to tentatively update this class.
|
| 1314 |
+
|
| 1315 |
+
Returns:
|
| 1316 |
+
`Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
|
| 1317 |
+
"""
|
| 1318 |
+
to_remove = []
|
| 1319 |
+
for key, value in kwargs.items():
|
| 1320 |
+
if hasattr(self, key):
|
| 1321 |
+
setattr(self, key, value)
|
| 1322 |
+
to_remove.append(key)
|
| 1323 |
+
|
| 1324 |
+
# Confirm that the updated instance is still valid
|
| 1325 |
+
self.validate()
|
| 1326 |
+
|
| 1327 |
+
# Remove all the attributes that were updated, without modifying the input dict
|
| 1328 |
+
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
|
| 1329 |
+
return unused_kwargs
|
| 1330 |
+
|
| 1331 |
+
|
| 1332 |
+
@dataclass
|
| 1333 |
+
class BaseWatermarkingConfig(ABC):
|
| 1334 |
+
"""Generic watermarking config"""
|
| 1335 |
+
|
| 1336 |
+
@classmethod
|
| 1337 |
+
def from_dict(cls, config_dict, **kwargs):
|
| 1338 |
+
"""
|
| 1339 |
+
Constructs a BaseWatermarkingConfig instance from a dictionary of parameters.
|
| 1340 |
+
|
| 1341 |
+
Args:
|
| 1342 |
+
config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
|
| 1343 |
+
**kwargs: Additional keyword arguments to override dictionary values.
|
| 1344 |
+
|
| 1345 |
+
Returns:
|
| 1346 |
+
BaseWatermarkingConfig: Instance of BaseWatermarkingConfig constructed from the dictionary.
|
| 1347 |
+
"""
|
| 1348 |
+
config = cls(**config_dict)
|
| 1349 |
+
to_remove = []
|
| 1350 |
+
for key, value in kwargs.items():
|
| 1351 |
+
if hasattr(config, key):
|
| 1352 |
+
setattr(config, key, value)
|
| 1353 |
+
to_remove.append(key)
|
| 1354 |
+
for key in to_remove:
|
| 1355 |
+
kwargs.pop(key, None)
|
| 1356 |
+
return config
|
| 1357 |
+
|
| 1358 |
+
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
|
| 1359 |
+
"""
|
| 1360 |
+
Save this instance to a JSON file.
|
| 1361 |
+
|
| 1362 |
+
Args:
|
| 1363 |
+
json_file_path (Union[str, os.PathLike]): Path to the JSON file in which this configuration instance's parameters will be saved.
|
| 1364 |
+
"""
|
| 1365 |
+
with open(json_file_path, "w", encoding="utf-8") as writer:
|
| 1366 |
+
config_dict = self.to_dict()
|
| 1367 |
+
json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
| 1368 |
+
|
| 1369 |
+
writer.write(json_string)
|
| 1370 |
+
|
| 1371 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 1372 |
+
"""
|
| 1373 |
+
Serializes this instance to a Python dictionary.
|
| 1374 |
+
|
| 1375 |
+
Returns:
|
| 1376 |
+
Dict[str, Any]: Dictionary of all the attributes that make up this configuration instance.
|
| 1377 |
+
"""
|
| 1378 |
+
output = copy.deepcopy(self.__dict__)
|
| 1379 |
+
return output
|
| 1380 |
+
|
| 1381 |
+
def __iter__(self):
|
| 1382 |
+
for attr, value in copy.deepcopy(self.__dict__).items():
|
| 1383 |
+
yield attr, value
|
| 1384 |
+
|
| 1385 |
+
def __repr__(self):
|
| 1386 |
+
return f"{self.__class__.__name__} {self.to_json_string()}"
|
| 1387 |
+
|
| 1388 |
+
def to_json_string(self):
|
| 1389 |
+
"""
|
| 1390 |
+
Serializes this instance to a JSON formatted string.
|
| 1391 |
+
|
| 1392 |
+
Returns:
|
| 1393 |
+
str: JSON formatted string representing the configuration instance.
|
| 1394 |
+
"""
|
| 1395 |
+
return json.dumps(self.__dict__, indent=2) + "\n"
|
| 1396 |
+
|
| 1397 |
+
def update(self, **kwargs):
|
| 1398 |
+
"""
|
| 1399 |
+
Update the configuration attributes with new values.
|
| 1400 |
+
|
| 1401 |
+
Args:
|
| 1402 |
+
**kwargs: Keyword arguments representing configuration attributes and their new values.
|
| 1403 |
+
"""
|
| 1404 |
+
for key, value in kwargs.items():
|
| 1405 |
+
if hasattr(self, key):
|
| 1406 |
+
setattr(self, key, value)
|
| 1407 |
+
|
| 1408 |
+
@abstractmethod
|
| 1409 |
+
def validate(self): ...
|
| 1410 |
+
|
| 1411 |
+
@abstractmethod
|
| 1412 |
+
def construct_processor(self, vocab_size): ...
|
| 1413 |
+
|
| 1414 |
+
|
| 1415 |
+
@dataclass
|
| 1416 |
+
class WatermarkingConfig(BaseWatermarkingConfig):
|
| 1417 |
+
"""
|
| 1418 |
+
Class that holds arguments for watermark generation and should be passed into `GenerationConfig` during `generate`.
|
| 1419 |
+
See [this paper](https://arxiv.org/abs/2306.04634) for more details on the arguments.
|
| 1420 |
+
|
| 1421 |
+
Accepts the following keys:
|
| 1422 |
+
- greenlist_ratio (`float`):
|
| 1423 |
+
Used for watermarking. The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25.
|
| 1424 |
+
- bias (`float`):
|
| 1425 |
+
Used with watermarking. The bias added to the selected "green" tokens' logits. Defaults to 2.0.
|
| 1426 |
+
- hashing_key (`int`):
|
| 1427 |
+
Hashing key used for watermarking. Defaults to 15485863 (the millionth prime).
|
| 1428 |
+
- seeding_scheme (`str`):
|
| 1429 |
+
Algorithm to use for watermarking. Accepts values:
|
| 1430 |
+
- "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from the paper)
|
| 1431 |
+
- "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from the paper)
|
| 1432 |
+
The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".
|
| 1433 |
+
- context_width(`int`):
|
| 1434 |
+
The context length of previous tokens to use in seeding. Higher context length makes watermarking more robust.
|
| 1435 |
+
"""
|
| 1436 |
+
|
| 1437 |
+
def __init__(
|
| 1438 |
+
self,
|
| 1439 |
+
greenlist_ratio: Optional[float] = 0.25,
|
| 1440 |
+
bias: Optional[float] = 2.0,
|
| 1441 |
+
hashing_key: Optional[int] = 15485863,
|
| 1442 |
+
seeding_scheme: Optional[str] = "lefthash",
|
| 1443 |
+
context_width: Optional[int] = 1,
|
| 1444 |
+
):
|
| 1445 |
+
self.greenlist_ratio = greenlist_ratio
|
| 1446 |
+
self.bias = bias
|
| 1447 |
+
self.hashing_key = hashing_key
|
| 1448 |
+
self.seeding_scheme = seeding_scheme
|
| 1449 |
+
self.context_width = context_width
|
| 1450 |
+
|
| 1451 |
+
def validate(self):
|
| 1452 |
+
watermark_missing_arg_msg = (
|
| 1453 |
+
"Some of the keys in `watermarking_config` are defined incorrectly. `{key}` should be {correct_value}` "
|
| 1454 |
+
"but found {found_value}"
|
| 1455 |
+
)
|
| 1456 |
+
if self.seeding_scheme not in ["selfhash", "lefthash"]:
|
| 1457 |
+
raise ValueError(
|
| 1458 |
+
watermark_missing_arg_msg.format(
|
| 1459 |
+
key="seeding_scheme",
|
| 1460 |
+
correct_value="[`selfhash`, `lefthash`]",
|
| 1461 |
+
found_value=self.seeding_scheme,
|
| 1462 |
+
),
|
| 1463 |
+
)
|
| 1464 |
+
if not 0.0 <= self.greenlist_ratio <= 1.0:
|
| 1465 |
+
raise ValueError(
|
| 1466 |
+
watermark_missing_arg_msg.format(
|
| 1467 |
+
key="greenlist_ratio",
|
| 1468 |
+
correct_value="in range between 0.0 and 1.0",
|
| 1469 |
+
found_value=self.seeding_scheme,
|
| 1470 |
+
),
|
| 1471 |
+
)
|
| 1472 |
+
if not self.context_width >= 1:
|
| 1473 |
+
raise ValueError(
|
| 1474 |
+
watermark_missing_arg_msg.format(
|
| 1475 |
+
key="context_width",
|
| 1476 |
+
correct_value="a positive integer",
|
| 1477 |
+
found_value=self.context_width,
|
| 1478 |
+
),
|
| 1479 |
+
)
|
| 1480 |
+
|
| 1481 |
+
def construct_processor(self, vocab_size: int, device) -> "WatermarkLogitsProcessor":
|
| 1482 |
+
return WatermarkLogitsProcessor(
|
| 1483 |
+
vocab_size=vocab_size,
|
| 1484 |
+
device=device,
|
| 1485 |
+
greenlist_ratio=self.greenlist_ratio,
|
| 1486 |
+
bias=self.bias,
|
| 1487 |
+
hashing_key=self.hashing_key,
|
| 1488 |
+
seeding_scheme=self.seeding_scheme,
|
| 1489 |
+
context_width=self.context_width,
|
| 1490 |
+
)
|
| 1491 |
+
|
| 1492 |
+
|
| 1493 |
+
@dataclass
|
| 1494 |
+
class SynthIDTextWatermarkingConfig(BaseWatermarkingConfig):
|
| 1495 |
+
"""
|
| 1496 |
+
Class that holds arguments for watermark generation and should be passed into `GenerationConfig` during `generate`.
|
| 1497 |
+
See [this paper](https://www.nature.com/articles/s41586-024-08025-4) for more details on the arguments.
|
| 1498 |
+
|
| 1499 |
+
Args:
|
| 1500 |
+
ngram_len (`int`):
|
| 1501 |
+
Ngram length.
|
| 1502 |
+
keys (`List[int]`):
|
| 1503 |
+
A sequence of watermarking keys, one for each depth.
|
| 1504 |
+
context_history_size (`int`, *optional*, defaults to 1024):
|
| 1505 |
+
Size of the tensor to keep track of seen contexts.
|
| 1506 |
+
sampling_table_seed (`int`, *optional*, defaults to 0):
|
| 1507 |
+
Random seed to generate the sampling table.
|
| 1508 |
+
sampling_table_size (`int`, *optional*, defaults to 65536):
|
| 1509 |
+
Size of the sampling table.
|
| 1510 |
+
skip_first_ngram_calls (`bool`, *optional*, defaults to `False`):
|
| 1511 |
+
Whether to skip first ngram calls.
|
| 1512 |
+
debug_mode (`bool`, optional, *optional*, defaults to `False`):
|
| 1513 |
+
Logits are modified to uniform one got before watermarking modification is applied. This is to test the
|
| 1514 |
+
implementation.
|
| 1515 |
+
|
| 1516 |
+
Examples:
|
| 1517 |
+
```python
|
| 1518 |
+
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, SynthIDTextWatermarkingConfig
|
| 1519 |
+
|
| 1520 |
+
>>> tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b', padding_side="left")
|
| 1521 |
+
>>> model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b')
|
| 1522 |
+
|
| 1523 |
+
>>> # SynthID Text configuration
|
| 1524 |
+
>>> watermarking_config = SynthIDTextWatermarkingConfig(
|
| 1525 |
+
... keys=[654, 400, 836, 123, 340, 443, 597, 160, 57],
|
| 1526 |
+
... ngram_len=5,
|
| 1527 |
+
... )
|
| 1528 |
+
|
| 1529 |
+
>>> # Generation with watermarking
|
| 1530 |
+
>>> tokenized_prompts = tokenizer(["Once upon a time, "], return_tensors="pt", padding=True)
|
| 1531 |
+
>>> output_sequences = model.generate(
|
| 1532 |
+
... **tokenized_prompts, watermarking_config=watermarking_config, do_sample=True, max_new_tokens=10
|
| 1533 |
+
... )
|
| 1534 |
+
>>> watermarked_text = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
|
| 1535 |
+
```
|
| 1536 |
+
"""
|
| 1537 |
+
|
| 1538 |
+
def __init__(
|
| 1539 |
+
self,
|
| 1540 |
+
ngram_len: int,
|
| 1541 |
+
keys: List[int],
|
| 1542 |
+
context_history_size: int = 1024,
|
| 1543 |
+
sampling_table_seed: int = 0,
|
| 1544 |
+
sampling_table_size: int = 2**16,
|
| 1545 |
+
skip_first_ngram_calls: bool = False,
|
| 1546 |
+
debug_mode: bool = False,
|
| 1547 |
+
):
|
| 1548 |
+
self.ngram_len = ngram_len
|
| 1549 |
+
self.keys = keys
|
| 1550 |
+
self.sampling_table_size = sampling_table_size
|
| 1551 |
+
self.sampling_table_seed = sampling_table_seed
|
| 1552 |
+
self.context_history_size = context_history_size
|
| 1553 |
+
self.skip_first_ngram_calls = skip_first_ngram_calls
|
| 1554 |
+
self.debug_mode = debug_mode
|
| 1555 |
+
|
| 1556 |
+
def validate(self):
|
| 1557 |
+
watermark_missing_arg_msg = (
|
| 1558 |
+
"Some of the keys in `watermarking_config` are defined incorrectly. `{key}` should be {correct_value}` "
|
| 1559 |
+
"but found {found_value}"
|
| 1560 |
+
)
|
| 1561 |
+
if self.sampling_table_size > 2**24:
|
| 1562 |
+
raise ValueError(
|
| 1563 |
+
watermark_missing_arg_msg.format(
|
| 1564 |
+
key="sampling_table_size",
|
| 1565 |
+
correct_value="< 2**24",
|
| 1566 |
+
found_value=self.sampling_table_size,
|
| 1567 |
+
),
|
| 1568 |
+
)
|
| 1569 |
+
|
| 1570 |
+
def construct_processor(self, vocab_size: int, device) -> "WatermarkLogitsProcessor":
|
| 1571 |
+
return SynthIDTextWatermarkLogitsProcessor(
|
| 1572 |
+
ngram_len=self.ngram_len,
|
| 1573 |
+
keys=self.keys,
|
| 1574 |
+
sampling_table_size=self.sampling_table_size,
|
| 1575 |
+
sampling_table_seed=self.sampling_table_seed,
|
| 1576 |
+
context_history_size=self.context_history_size,
|
| 1577 |
+
device=device,
|
| 1578 |
+
skip_first_ngram_calls=self.skip_first_ngram_calls,
|
| 1579 |
+
debug_mode=self.debug_mode,
|
| 1580 |
+
)
|
| 1581 |
+
|
| 1582 |
+
|
| 1583 |
+
@dataclass
|
| 1584 |
+
class CompileConfig(object):
|
| 1585 |
+
"""
|
| 1586 |
+
Class that holds arguments relative to `torch.compile` behavior, when using automatic compilation in `generate`.
|
| 1587 |
+
See [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) for more details on the arguments.
|
| 1588 |
+
|
| 1589 |
+
Args:
|
| 1590 |
+
fullgraph (`bool`, *optional*, defaults to `True`):
|
| 1591 |
+
If `True`, requires that the whole forward be capturable in a single graph.
|
| 1592 |
+
dynamic (`bool` or `None`, *optional*):
|
| 1593 |
+
Whether to try to use dynamic shape graphs.
|
| 1594 |
+
backend (`str` or `Callable`, *optional*, defaults to `"inductor"`):
|
| 1595 |
+
Backend to be used.
|
| 1596 |
+
mode (`str`, *optional*, defaults to `"reduce-overhead"`):
|
| 1597 |
+
Controls balance between performance and overhead.
|
| 1598 |
+
options (`dict`, *optional*):
|
| 1599 |
+
A dictionary of options to pass to the backend.
|
| 1600 |
+
|
| 1601 |
+
Examples:
|
| 1602 |
+
```python
|
| 1603 |
+
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig
|
| 1604 |
+
|
| 1605 |
+
>>> tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b')
|
| 1606 |
+
>>> model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b').cuda()
|
| 1607 |
+
|
| 1608 |
+
>>> # Automatic compile configuration, used with static cache
|
| 1609 |
+
>>> compile_config = CompileConfig(dynamic=True)
|
| 1610 |
+
|
| 1611 |
+
>>> # Generation with static cache and compile config
|
| 1612 |
+
>>> input = tokenizer.encode("Hello there, how", return_tensors="pt").cuda()
|
| 1613 |
+
>>> output = model.generate(
|
| 1614 |
+
... input, do_sample=False, max_new_tokens=300, cache_implementation="static", compile_config=compile_config
|
| 1615 |
+
... )
|
| 1616 |
+
>>> output_text = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
|
| 1617 |
+
```
|
| 1618 |
+
"""
|
| 1619 |
+
|
| 1620 |
+
fullgraph: bool = True
|
| 1621 |
+
dynamic: Optional[bool] = None
|
| 1622 |
+
backend: Union[str, Callable] = "inductor"
|
| 1623 |
+
mode: str = "reduce-overhead"
|
| 1624 |
+
options: Optional[dict] = None
|
| 1625 |
+
|
| 1626 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 1627 |
+
"""Serializes this instance to a Python dictionary."""
|
| 1628 |
+
return copy.deepcopy(self.__dict__)
|
.venv/lib/python3.11/site-packages/transformers/generation/flax_logits_process.py
ADDED
|
@@ -0,0 +1,544 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 The HuggingFace Inc. team
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
|
| 18 |
+
import jax
|
| 19 |
+
import jax.lax as lax
|
| 20 |
+
import jax.numpy as jnp
|
| 21 |
+
from jax.experimental import sparse
|
| 22 |
+
|
| 23 |
+
from ..utils import add_start_docstrings
|
| 24 |
+
from ..utils.logging import get_logger
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
logger = get_logger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
|
| 31 |
+
Args:
|
| 32 |
+
input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
|
| 33 |
+
Indices of input sequence tokens in the vocabulary.
|
| 34 |
+
|
| 35 |
+
Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 36 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 37 |
+
|
| 38 |
+
[What are input IDs?](../glossary#input-ids)
|
| 39 |
+
scores (`jnp.ndarray` of shape `(batch_size, config.vocab_size)`):
|
| 40 |
+
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
|
| 41 |
+
search or log softmax for each vocabulary token when using beam search
|
| 42 |
+
kwargs (`Dict[str, Any]`, *optional*):
|
| 43 |
+
Additional logits processor specific kwargs.
|
| 44 |
+
|
| 45 |
+
Return:
|
| 46 |
+
`jnp.ndarray` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.
|
| 47 |
+
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class FlaxLogitsProcessor:
|
| 52 |
+
"""Abstract base class for all logit processors that can be applied during generation."""
|
| 53 |
+
|
| 54 |
+
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
| 55 |
+
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray) -> jnp.ndarray:
|
| 56 |
+
"""Flax method for processing logits."""
|
| 57 |
+
raise NotImplementedError(
|
| 58 |
+
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class FlaxLogitsWarper:
|
| 63 |
+
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
|
| 64 |
+
|
| 65 |
+
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
| 66 |
+
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray) -> jnp.ndarray:
|
| 67 |
+
"""Flax method for warping logits."""
|
| 68 |
+
raise NotImplementedError(
|
| 69 |
+
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class FlaxLogitsProcessorList(list):
|
| 74 |
+
"""
|
| 75 |
+
This class can be used to create a list of [`FlaxLogitsProcessor`] or [`FlaxLogitsWarper`] to subsequently process
|
| 76 |
+
a `scores` input tensor. This class inherits from list and adds a specific *__call__* method to apply each
|
| 77 |
+
[`FlaxLogitsProcessor`] or [`FlaxLogitsWarper`] to the inputs.
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
| 81 |
+
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int, **kwargs) -> jnp.ndarray:
|
| 82 |
+
for processor in self:
|
| 83 |
+
function_args = inspect.signature(processor.__call__).parameters
|
| 84 |
+
if len(function_args) > 3:
|
| 85 |
+
if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
|
| 86 |
+
raise ValueError(
|
| 87 |
+
f"Make sure that all the required parameters: {list(function_args.keys())} for "
|
| 88 |
+
f"{processor.__class__} are passed to the logits processor."
|
| 89 |
+
)
|
| 90 |
+
scores = processor(input_ids, scores, cur_len, **kwargs)
|
| 91 |
+
else:
|
| 92 |
+
scores = processor(input_ids, scores, cur_len)
|
| 93 |
+
return scores
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
|
| 97 |
+
r"""
|
| 98 |
+
[`FlaxLogitsWarper`] for temperature (exponential scaling output probability distribution).
|
| 99 |
+
|
| 100 |
+
Args:
|
| 101 |
+
temperature (`float`):
|
| 102 |
+
The value used to module the logits distribution.
|
| 103 |
+
"""
|
| 104 |
+
|
| 105 |
+
def __init__(self, temperature: float):
|
| 106 |
+
if not isinstance(temperature, float) or not (temperature > 0):
|
| 107 |
+
raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
|
| 108 |
+
|
| 109 |
+
self.temperature = temperature
|
| 110 |
+
|
| 111 |
+
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
|
| 112 |
+
scores = scores / self.temperature
|
| 113 |
+
return scores
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class FlaxTopPLogitsWarper(FlaxLogitsWarper):
|
| 117 |
+
"""
|
| 118 |
+
[`FlaxLogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
|
| 119 |
+
|
| 120 |
+
Args:
|
| 121 |
+
top_p (`float`):
|
| 122 |
+
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
| 123 |
+
higher are kept for generation.
|
| 124 |
+
filter_value (`float`, *optional*, defaults to -inf):
|
| 125 |
+
All filtered values will be set to this float value.
|
| 126 |
+
min_tokens_to_keep (`int`, *optional*, defaults to 1):
|
| 127 |
+
Minimum number of tokens that cannot be filtered.
|
| 128 |
+
"""
|
| 129 |
+
|
| 130 |
+
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
| 131 |
+
if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
|
| 132 |
+
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
|
| 133 |
+
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
|
| 134 |
+
raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")
|
| 135 |
+
|
| 136 |
+
self.top_p = top_p
|
| 137 |
+
self.filter_value = filter_value
|
| 138 |
+
self.min_tokens_to_keep = min_tokens_to_keep
|
| 139 |
+
|
| 140 |
+
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
|
| 141 |
+
topk_scores, topk_indices = lax.top_k(scores, scores.shape[-1])
|
| 142 |
+
|
| 143 |
+
mask_scores = jnp.full_like(scores, self.filter_value)
|
| 144 |
+
cumulative_probs = jax.nn.softmax(topk_scores, axis=-1).cumsum(axis=-1)
|
| 145 |
+
score_mask = cumulative_probs < self.top_p
|
| 146 |
+
|
| 147 |
+
# include the token that is higher than top_p as well
|
| 148 |
+
score_mask = jnp.roll(score_mask, 1)
|
| 149 |
+
score_mask |= score_mask.at[:, 0].set(True)
|
| 150 |
+
|
| 151 |
+
# min tokens to keep
|
| 152 |
+
score_mask = score_mask.at[:, : self.min_tokens_to_keep].set(True)
|
| 153 |
+
|
| 154 |
+
topk_next_scores = jnp.where(score_mask, topk_scores, mask_scores)
|
| 155 |
+
next_scores = jax.lax.sort_key_val(topk_indices, topk_next_scores)[-1]
|
| 156 |
+
|
| 157 |
+
return next_scores
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class FlaxTopKLogitsWarper(FlaxLogitsWarper):
|
| 161 |
+
r"""
|
| 162 |
+
[`FlaxLogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
|
| 163 |
+
|
| 164 |
+
Args:
|
| 165 |
+
top_k (`int`):
|
| 166 |
+
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
| 167 |
+
filter_value (`float`, *optional*, defaults to -inf):
|
| 168 |
+
All filtered values will be set to this float value.
|
| 169 |
+
min_tokens_to_keep (`int`, *optional*, defaults to 1):
|
| 170 |
+
Minimum number of tokens that cannot be filtered.
|
| 171 |
+
"""
|
| 172 |
+
|
| 173 |
+
def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
| 174 |
+
if not isinstance(top_k, int) or top_k <= 0:
|
| 175 |
+
raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
|
| 176 |
+
|
| 177 |
+
self.top_k = max(top_k, min_tokens_to_keep)
|
| 178 |
+
self.filter_value = filter_value
|
| 179 |
+
|
| 180 |
+
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
|
| 181 |
+
batch_size, vocab_size = scores.shape
|
| 182 |
+
next_scores_flat = jnp.full(batch_size * vocab_size, self.filter_value)
|
| 183 |
+
|
| 184 |
+
topk = min(self.top_k, scores.shape[-1]) # Safety check
|
| 185 |
+
topk_scores, topk_indices = lax.top_k(scores, topk)
|
| 186 |
+
shift = jnp.broadcast_to((jnp.arange(batch_size) * vocab_size)[:, None], (batch_size, topk)).flatten()
|
| 187 |
+
topk_scores_flat = topk_scores.flatten()
|
| 188 |
+
topk_indices_flat = topk_indices.flatten() + shift
|
| 189 |
+
|
| 190 |
+
next_scores_flat = next_scores_flat.at[topk_indices_flat].set(topk_scores_flat)
|
| 191 |
+
next_scores = next_scores_flat.reshape(batch_size, vocab_size)
|
| 192 |
+
return next_scores
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class FlaxForcedBOSTokenLogitsProcessor(FlaxLogitsProcessor):
|
| 196 |
+
r"""
|
| 197 |
+
[`FlaxLogitsProcessor`] that enforces the specified token as the first generated token.
|
| 198 |
+
|
| 199 |
+
Args:
|
| 200 |
+
bos_token_id (`int`):
|
| 201 |
+
The id of the token to force as the first generated token.
|
| 202 |
+
"""
|
| 203 |
+
|
| 204 |
+
def __init__(self, bos_token_id: int):
|
| 205 |
+
self.bos_token_id = bos_token_id
|
| 206 |
+
|
| 207 |
+
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
|
| 208 |
+
new_scores = jnp.full(scores.shape, -float("inf"))
|
| 209 |
+
|
| 210 |
+
apply_penalty = 1 - jnp.bool_(cur_len - 1)
|
| 211 |
+
|
| 212 |
+
scores = jnp.where(apply_penalty, new_scores.at[:, self.bos_token_id].set(0), scores)
|
| 213 |
+
|
| 214 |
+
return scores
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class FlaxForcedEOSTokenLogitsProcessor(FlaxLogitsProcessor):
|
| 218 |
+
r"""
|
| 219 |
+
[`FlaxLogitsProcessor`] that enforces the specified token as the last generated token when `max_length` is reached.
|
| 220 |
+
|
| 221 |
+
Args:
|
| 222 |
+
max_length (`int`):
|
| 223 |
+
The maximum length of the sequence to be generated.
|
| 224 |
+
eos_token_id (`int`):
|
| 225 |
+
The id of the token to force as the last generated token when `max_length` is reached.
|
| 226 |
+
"""
|
| 227 |
+
|
| 228 |
+
def __init__(self, max_length: int, eos_token_id: int):
|
| 229 |
+
self.max_length = max_length
|
| 230 |
+
self.eos_token_id = eos_token_id
|
| 231 |
+
|
| 232 |
+
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
|
| 233 |
+
new_scores = jnp.full(scores.shape, -float("inf"))
|
| 234 |
+
|
| 235 |
+
apply_penalty = 1 - jnp.bool_(cur_len - self.max_length + 1)
|
| 236 |
+
|
| 237 |
+
scores = jnp.where(apply_penalty, new_scores.at[:, self.eos_token_id].set(0), scores)
|
| 238 |
+
|
| 239 |
+
return scores
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
class FlaxMinLengthLogitsProcessor(FlaxLogitsProcessor):
|
| 243 |
+
r"""
|
| 244 |
+
[`FlaxLogitsProcessor`] enforcing a min-length by setting EOS probability to 0.
|
| 245 |
+
|
| 246 |
+
Args:
|
| 247 |
+
min_length (`int`):
|
| 248 |
+
The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
|
| 249 |
+
eos_token_id (`int`):
|
| 250 |
+
The id of the *end-of-sequence* token.
|
| 251 |
+
"""
|
| 252 |
+
|
| 253 |
+
def __init__(self, min_length: int, eos_token_id: int):
|
| 254 |
+
if not isinstance(min_length, int) or min_length < 0:
|
| 255 |
+
raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")
|
| 256 |
+
|
| 257 |
+
if not isinstance(eos_token_id, int) or eos_token_id < 0:
|
| 258 |
+
raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
|
| 259 |
+
|
| 260 |
+
self.min_length = min_length
|
| 261 |
+
self.eos_token_id = eos_token_id
|
| 262 |
+
|
| 263 |
+
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
|
| 264 |
+
# create boolean flag to decide if min length penalty should be applied
|
| 265 |
+
apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1)
|
| 266 |
+
|
| 267 |
+
scores = jnp.where(apply_penalty, scores.at[:, self.eos_token_id].set(-float("inf")), scores)
|
| 268 |
+
|
| 269 |
+
return scores
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
class FlaxSuppressTokensAtBeginLogitsProcessor(FlaxLogitsProcessor):
|
| 273 |
+
r"""
|
| 274 |
+
[`FlaxLogitsProcessor`] supressing a list of tokens as soon as the `generate` function starts generating using
|
| 275 |
+
`begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` are not sampled at the
|
| 276 |
+
beginning of the generation.
|
| 277 |
+
|
| 278 |
+
Args:
|
| 279 |
+
begin_suppress_tokens (`List[int]`):
|
| 280 |
+
Tokens to not sample.
|
| 281 |
+
begin_index (`int`):
|
| 282 |
+
Index where the tokens are suppressed.
|
| 283 |
+
"""
|
| 284 |
+
|
| 285 |
+
def __init__(self, begin_suppress_tokens, begin_index):
|
| 286 |
+
self.begin_suppress_tokens = list(begin_suppress_tokens)
|
| 287 |
+
self.begin_index = begin_index
|
| 288 |
+
|
| 289 |
+
def __call__(self, input_ids, scores, cur_len: int):
|
| 290 |
+
apply_penalty = 1 - jnp.bool_(cur_len - self.begin_index)
|
| 291 |
+
|
| 292 |
+
scores = jnp.where(apply_penalty, scores.at[:, self.begin_suppress_tokens].set(-float("inf")), scores)
|
| 293 |
+
|
| 294 |
+
return scores
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
class FlaxSuppressTokensLogitsProcessor(FlaxLogitsProcessor):
|
| 298 |
+
r"""
|
| 299 |
+
[`FlaxLogitsProcessor`] suppressing a list of tokens at each decoding step. The processor will set their log probs
|
| 300 |
+
to be `-inf` so they are not sampled.
|
| 301 |
+
|
| 302 |
+
Args:
|
| 303 |
+
suppress_tokens (`list`):
|
| 304 |
+
Tokens to not sample.
|
| 305 |
+
"""
|
| 306 |
+
|
| 307 |
+
def __init__(self, suppress_tokens: list):
|
| 308 |
+
self.suppress_tokens = list(suppress_tokens)
|
| 309 |
+
|
| 310 |
+
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
|
| 311 |
+
scores = scores.at[..., self.suppress_tokens].set(-float("inf"))
|
| 312 |
+
|
| 313 |
+
return scores
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
class FlaxForceTokensLogitsProcessor(FlaxLogitsProcessor):
|
| 317 |
+
r"""
|
| 318 |
+
[`FlaxLogitsProcessor`] that takes a list of pairs of integers which indicates a mapping from generation indices to
|
| 319 |
+
token indices that will be forced before sampling. The processor will set their log probs to 0 and all other tokens
|
| 320 |
+
to `-inf` so that they are sampled at their corresponding index.
|
| 321 |
+
|
| 322 |
+
Args:
|
| 323 |
+
force_token_map (`list`):
|
| 324 |
+
Map giving token ids and indices where they will be forced to be sampled.
|
| 325 |
+
"""
|
| 326 |
+
|
| 327 |
+
def __init__(self, force_token_map):
|
| 328 |
+
force_token_map = dict(force_token_map)
|
| 329 |
+
# Converts the dictionary of format {index: token} containing the tokens to be forced to an array, where the
|
| 330 |
+
# index of the array corresponds to the index of the token to be forced, for XLA compatibility.
|
| 331 |
+
# Indexes without forced tokens will have a negative value.
|
| 332 |
+
force_token_array = jnp.ones((max(force_token_map.keys()) + 1), dtype=jnp.int32) * -1
|
| 333 |
+
for index, token in force_token_map.items():
|
| 334 |
+
if token is not None:
|
| 335 |
+
force_token_array = force_token_array.at[index].set(token)
|
| 336 |
+
self.force_token_array = jnp.int32(force_token_array)
|
| 337 |
+
|
| 338 |
+
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
|
| 339 |
+
def _force_token(generation_idx):
|
| 340 |
+
batch_size = scores.shape[0]
|
| 341 |
+
current_token = self.force_token_array[generation_idx]
|
| 342 |
+
|
| 343 |
+
new_scores = jnp.ones_like(scores, dtype=scores.dtype) * -float("inf")
|
| 344 |
+
updates = jnp.zeros((batch_size, 1), dtype=scores.dtype)
|
| 345 |
+
new_scores = lax.dynamic_update_slice(new_scores, updates, (0, current_token))
|
| 346 |
+
return new_scores
|
| 347 |
+
|
| 348 |
+
scores = lax.cond(
|
| 349 |
+
cur_len >= self.force_token_array.shape[0],
|
| 350 |
+
# If the current length is geq than the length of force_token_array, the processor does nothing.
|
| 351 |
+
lambda: scores,
|
| 352 |
+
# Otherwise, it may force a certain token.
|
| 353 |
+
lambda: lax.cond(
|
| 354 |
+
self.force_token_array[cur_len] >= 0,
|
| 355 |
+
# Only valid (positive) tokens are forced
|
| 356 |
+
lambda: _force_token(cur_len),
|
| 357 |
+
# Otherwise, the processor does nothing.
|
| 358 |
+
lambda: scores,
|
| 359 |
+
),
|
| 360 |
+
)
|
| 361 |
+
return scores
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
class FlaxWhisperTimeStampLogitsProcessor(FlaxLogitsProcessor):
|
| 365 |
+
r"""
|
| 366 |
+
Whisper specific Processor. This processor can be used to force a list of tokens. The processor will set their log
|
| 367 |
+
probs to `inf` so that they are sampled at their corresponding index.
|
| 368 |
+
|
| 369 |
+
Args:
|
| 370 |
+
generate_config (`GenerateConfig`):
|
| 371 |
+
The generate config used to generate the output. The following parameters are required:
|
| 372 |
+
eos_token_id (`int`, *optional*, defaults to 50257):
|
| 373 |
+
The id of the *end-of-sequence* token.
|
| 374 |
+
no_timestamps_token_id (`int`, *optional*, defaults to 50363):
|
| 375 |
+
The id of the `"<|notimestamps|>"` token.
|
| 376 |
+
max_initial_timestamp_index (`int`, *optional*, defaults to 1):
|
| 377 |
+
Used to set the maximum value of the initial timestamp. This is used to prevent the model from
|
| 378 |
+
predicting timestamps that are too far in the future.
|
| 379 |
+
"""
|
| 380 |
+
|
| 381 |
+
def __init__(self, generate_config, model_config, decoder_input_length):
|
| 382 |
+
self.eos_token_id = generate_config.eos_token_id
|
| 383 |
+
self.no_timestamps_token_id = generate_config.no_timestamps_token_id
|
| 384 |
+
self.timestamp_begin = generate_config.no_timestamps_token_id + 1
|
| 385 |
+
|
| 386 |
+
self.begin_index = decoder_input_length + 1
|
| 387 |
+
|
| 388 |
+
if generate_config.is_multilingual:
|
| 389 |
+
# room for language token and task token
|
| 390 |
+
self.begin_index += 2
|
| 391 |
+
if hasattr(generate_config, "max_initial_timestamp_index"):
|
| 392 |
+
self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index
|
| 393 |
+
else:
|
| 394 |
+
self.max_initial_timestamp_index = model_config.vocab_size
|
| 395 |
+
if self.max_initial_timestamp_index is None:
|
| 396 |
+
self.max_initial_timestamp_index = model_config.vocab_size
|
| 397 |
+
|
| 398 |
+
def __call__(self, input_ids, scores, cur_len):
|
| 399 |
+
# suppress <|notimestamps|> which is handled by without_timestamps
|
| 400 |
+
scores = scores.at[:, self.no_timestamps_token_id].set(-float("inf"))
|
| 401 |
+
|
| 402 |
+
def handle_pairs(input_ids_k, scores_k):
|
| 403 |
+
last_was_timestamp = jnp.where((cur_len - self.begin_index) >= 1, True, False)
|
| 404 |
+
last_was_timestamp = jnp.where(
|
| 405 |
+
input_ids_k[cur_len - 1] >= self.timestamp_begin,
|
| 406 |
+
True and last_was_timestamp,
|
| 407 |
+
False,
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
penultimate_was_timestamp = jnp.where((cur_len - self.begin_index) < 2, True, False)
|
| 411 |
+
penultimate_was_timestamp = jnp.where(
|
| 412 |
+
input_ids_k[cur_len - 2] >= self.timestamp_begin,
|
| 413 |
+
True,
|
| 414 |
+
penultimate_was_timestamp,
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
return jnp.where(
|
| 418 |
+
last_was_timestamp,
|
| 419 |
+
jnp.where(
|
| 420 |
+
penultimate_was_timestamp > 0,
|
| 421 |
+
scores_k.at[self.timestamp_begin :].set(-float("inf")),
|
| 422 |
+
scores_k.at[: self.eos_token_id].set(-float("inf")),
|
| 423 |
+
),
|
| 424 |
+
scores_k,
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
scores = jax.vmap(handle_pairs)(input_ids, scores)
|
| 428 |
+
|
| 429 |
+
apply_max_initial_timestamp = jnp.where(cur_len == self.begin_index, True, False)
|
| 430 |
+
apply_max_initial_timestamp = jnp.where(
|
| 431 |
+
self.max_initial_timestamp_index is not None,
|
| 432 |
+
True and apply_max_initial_timestamp,
|
| 433 |
+
False,
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
last_allowed = self.timestamp_begin + self.max_initial_timestamp_index
|
| 437 |
+
|
| 438 |
+
scores = jnp.where(
|
| 439 |
+
apply_max_initial_timestamp,
|
| 440 |
+
scores.at[:, last_allowed + 1 :].set(-float("inf")),
|
| 441 |
+
scores,
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
# if sum of probability over timestamps is above any other token, sample timestamp
|
| 445 |
+
logprobs = jax.nn.log_softmax(scores, axis=-1)
|
| 446 |
+
|
| 447 |
+
def handle_cumulative_probs(logprobs_k, scores_k):
|
| 448 |
+
timestamp_logprob = jax.nn.logsumexp(logprobs_k[self.timestamp_begin :], axis=-1)
|
| 449 |
+
max_text_token_logprob = jnp.max(logprobs_k[: self.timestamp_begin])
|
| 450 |
+
return jnp.where(
|
| 451 |
+
timestamp_logprob > max_text_token_logprob,
|
| 452 |
+
scores_k.at[: self.timestamp_begin].set(-float("inf")),
|
| 453 |
+
scores_k,
|
| 454 |
+
)
|
| 455 |
+
|
| 456 |
+
scores = jax.vmap(handle_cumulative_probs)(logprobs, scores)
|
| 457 |
+
|
| 458 |
+
return scores
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
class FlaxNoRepeatNGramLogitsProcessor(FlaxLogitsProcessor):
|
| 462 |
+
r"""
|
| 463 |
+
[`FlaxLogitsProcessor`] that enforces no repetition of n-grams. See
|
| 464 |
+
[Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).
|
| 465 |
+
|
| 466 |
+
Args:
|
| 467 |
+
ngram_size (`int`):
|
| 468 |
+
All ngrams of size `ngram_size` can only occur once.
|
| 469 |
+
"""
|
| 470 |
+
|
| 471 |
+
def __init__(self, ngram_size: int):
|
| 472 |
+
if not isinstance(ngram_size, int) or ngram_size <= 0:
|
| 473 |
+
raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
|
| 474 |
+
self.ngram_size = ngram_size
|
| 475 |
+
|
| 476 |
+
def get_previous_ngrams(self, input_ids: jnp.ndarray, vocab_size: int, cur_len: int):
|
| 477 |
+
"""
|
| 478 |
+
get a matrix of size (batch_size,) + (vocab_size,)*n (for n-grams) that
|
| 479 |
+
represent the n-grams that occurred previously.
|
| 480 |
+
The BCOO representation allow to store only the few non-zero entries, instead of the full (huge) matrix
|
| 481 |
+
"""
|
| 482 |
+
batch_size, seq_len = input_ids.shape
|
| 483 |
+
# number of n-grams in the whole sequence
|
| 484 |
+
seq_ngrams = seq_len - (self.ngram_size - 1)
|
| 485 |
+
# number of n-grams in the currently generated sequence
|
| 486 |
+
cur_ngrams = cur_len - (self.ngram_size - 1)
|
| 487 |
+
|
| 488 |
+
def body_fun(i, val):
|
| 489 |
+
b = i % batch_size
|
| 490 |
+
pos = i // batch_size
|
| 491 |
+
return val.at[i].set(
|
| 492 |
+
jnp.array(
|
| 493 |
+
[
|
| 494 |
+
b,
|
| 495 |
+
]
|
| 496 |
+
+ [jnp.array(input_ids)[b, pos + j] for j in range(self.ngram_size)]
|
| 497 |
+
)
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
shape = (batch_size * seq_ngrams, self.ngram_size + 1)
|
| 501 |
+
all_update_indices = jax.lax.fori_loop(
|
| 502 |
+
0, batch_size * cur_ngrams, body_fun, jnp.zeros(shape, dtype=input_ids.dtype)
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
# ignore the n-grams not yet generated
|
| 506 |
+
data = (jnp.arange(batch_size * seq_ngrams) < batch_size * cur_ngrams).astype("float32")
|
| 507 |
+
|
| 508 |
+
return sparse.BCOO((data, all_update_indices), shape=(batch_size,) + (vocab_size,) * self.ngram_size)
|
| 509 |
+
|
| 510 |
+
def get_banned_tokens_mask(self, latest_tokens: jnp.ndarray, previous_ngrams) -> jnp.ndarray:
|
| 511 |
+
"""
|
| 512 |
+
Determines which tokens must be banned given latest tokens and the previously seen
|
| 513 |
+
ngrams.
|
| 514 |
+
"""
|
| 515 |
+
|
| 516 |
+
@sparse.sparsify
|
| 517 |
+
@jax.vmap
|
| 518 |
+
def inner_fn(latest_tokens, previous_ngrams):
|
| 519 |
+
return previous_ngrams[tuple(latest_tokens)]
|
| 520 |
+
|
| 521 |
+
return sparse.bcoo_todense(inner_fn(latest_tokens, previous_ngrams))
|
| 522 |
+
|
| 523 |
+
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
|
| 524 |
+
def true_fn():
|
| 525 |
+
_, vocab_size = scores.shape
|
| 526 |
+
# store the previously seen n-grams
|
| 527 |
+
previous_ngrams = self.get_previous_ngrams(input_ids, vocab_size, cur_len)
|
| 528 |
+
|
| 529 |
+
# get the n-1 last tokens that prefix the n-gram being generated
|
| 530 |
+
latest_tokens = jnp.zeros((input_ids.shape[0], self.ngram_size - 1), dtype=input_ids.dtype)
|
| 531 |
+
latest_tokens = jax.lax.dynamic_update_slice(
|
| 532 |
+
latest_tokens,
|
| 533 |
+
jax.lax.dynamic_slice(
|
| 534 |
+
input_ids, (0, cur_len - (self.ngram_size - 1)), (input_ids.shape[0], (self.ngram_size - 1))
|
| 535 |
+
),
|
| 536 |
+
(0, 0),
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
# compute the banned tokens, ie all the tokens that when added to the latest tokens lead to a n-gram that was previously generated
|
| 540 |
+
banned_tokens_indices_mask = self.get_banned_tokens_mask(latest_tokens, previous_ngrams).astype("bool")
|
| 541 |
+
return jnp.where(banned_tokens_indices_mask, -float("inf"), scores)
|
| 542 |
+
|
| 543 |
+
output = jax.lax.cond((cur_len >= self.ngram_size - 1), true_fn, lambda: scores)
|
| 544 |
+
return output
|
.venv/lib/python3.11/site-packages/transformers/generation/flax_utils.py
ADDED
|
@@ -0,0 +1,1027 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 The Google AI Flax Team Authors, and The HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
import copy
|
| 19 |
+
import inspect
|
| 20 |
+
import warnings
|
| 21 |
+
from functools import partial
|
| 22 |
+
from typing import Any, Dict, Optional, Union
|
| 23 |
+
|
| 24 |
+
import flax
|
| 25 |
+
import jax
|
| 26 |
+
import jax.numpy as jnp
|
| 27 |
+
import numpy as np
|
| 28 |
+
from jax import lax
|
| 29 |
+
|
| 30 |
+
from ..models.auto import (
|
| 31 |
+
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
|
| 32 |
+
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
| 33 |
+
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
|
| 34 |
+
)
|
| 35 |
+
from ..utils import ModelOutput, logging
|
| 36 |
+
from .configuration_utils import GenerationConfig
|
| 37 |
+
from .flax_logits_process import (
|
| 38 |
+
FlaxForcedBOSTokenLogitsProcessor,
|
| 39 |
+
FlaxForcedEOSTokenLogitsProcessor,
|
| 40 |
+
FlaxForceTokensLogitsProcessor,
|
| 41 |
+
FlaxLogitsProcessorList,
|
| 42 |
+
FlaxMinLengthLogitsProcessor,
|
| 43 |
+
FlaxNoRepeatNGramLogitsProcessor,
|
| 44 |
+
FlaxSuppressTokensAtBeginLogitsProcessor,
|
| 45 |
+
FlaxSuppressTokensLogitsProcessor,
|
| 46 |
+
FlaxTemperatureLogitsWarper,
|
| 47 |
+
FlaxTopKLogitsWarper,
|
| 48 |
+
FlaxTopPLogitsWarper,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
logger = logging.get_logger(__name__)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@flax.struct.dataclass
|
| 56 |
+
class FlaxGreedySearchOutput(ModelOutput):
|
| 57 |
+
"""
|
| 58 |
+
Flax Base class for outputs of decoder-only generation models using greedy search.
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
|
| 63 |
+
The generated sequences.
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
sequences: jnp.ndarray = None
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@flax.struct.dataclass
|
| 70 |
+
class FlaxSampleOutput(ModelOutput):
|
| 71 |
+
"""
|
| 72 |
+
Flax Base class for outputs of decoder-only generation models using sampling.
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
|
| 77 |
+
The generated sequences.
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
sequences: jnp.ndarray = None
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@flax.struct.dataclass
|
| 84 |
+
class FlaxBeamSearchOutput(ModelOutput):
|
| 85 |
+
"""
|
| 86 |
+
Flax Base class for outputs of decoder-only generation models using greedy search.
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
|
| 91 |
+
The generated sequences.
|
| 92 |
+
scores (`jnp.ndarray` of shape `(batch_size,)`):
|
| 93 |
+
The scores (log probabilities) of the generated sequences.
|
| 94 |
+
"""
|
| 95 |
+
|
| 96 |
+
sequences: jnp.ndarray = None
|
| 97 |
+
scores: jnp.ndarray = None
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
@flax.struct.dataclass
|
| 101 |
+
class GreedyState:
|
| 102 |
+
cur_len: jnp.ndarray
|
| 103 |
+
sequences: jnp.ndarray
|
| 104 |
+
running_token: jnp.ndarray
|
| 105 |
+
is_sent_finished: jnp.ndarray
|
| 106 |
+
model_kwargs: Dict[str, jnp.ndarray]
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
@flax.struct.dataclass
|
| 110 |
+
class SampleState:
|
| 111 |
+
cur_len: jnp.ndarray
|
| 112 |
+
sequences: jnp.ndarray
|
| 113 |
+
running_token: jnp.ndarray
|
| 114 |
+
is_sent_finished: jnp.ndarray
|
| 115 |
+
prng_key: jnp.ndarray
|
| 116 |
+
model_kwargs: Dict[str, jnp.ndarray]
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
@flax.struct.dataclass
|
| 120 |
+
class BeamSearchState:
|
| 121 |
+
cur_len: jnp.ndarray
|
| 122 |
+
running_sequences: jnp.ndarray
|
| 123 |
+
running_scores: jnp.ndarray
|
| 124 |
+
sequences: jnp.ndarray
|
| 125 |
+
scores: jnp.ndarray
|
| 126 |
+
is_sent_finished: jnp.ndarray
|
| 127 |
+
model_kwargs: Dict[str, jnp.ndarray]
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class FlaxGenerationMixin:
|
| 131 |
+
"""
|
| 132 |
+
A class containing all functions for auto-regressive text generation, to be used as a mixin in
|
| 133 |
+
[`FlaxPreTrainedModel`].
|
| 134 |
+
|
| 135 |
+
The class exposes [`~generation.FlaxGenerationMixin.generate`], which can be used for:
|
| 136 |
+
- *greedy decoding* by calling [`~generation.FlaxGenerationMixin._greedy_search`] if `num_beams=1` and
|
| 137 |
+
`do_sample=False`
|
| 138 |
+
- *multinomial sampling* by calling [`~generation.FlaxGenerationMixin._sample`] if `num_beams=1` and
|
| 139 |
+
`do_sample=True`
|
| 140 |
+
- *beam-search decoding* by calling [`~generation.FlaxGenerationMixin._beam_search`] if `num_beams>1` and
|
| 141 |
+
`do_sample=False`
|
| 142 |
+
|
| 143 |
+
You do not need to call any of the above methods directly. Pass custom parameter values to 'generate' instead. To
|
| 144 |
+
learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
|
| 145 |
+
"""
|
| 146 |
+
|
| 147 |
+
def prepare_inputs_for_generation(self, *args, **kwargs):
|
| 148 |
+
raise NotImplementedError(
|
| 149 |
+
"A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`."
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
@staticmethod
|
| 153 |
+
def _run_loop_in_debug(cond_fn, body_fn, init_state):
|
| 154 |
+
"""
|
| 155 |
+
Run generation in untraced mode. This should only be used for debugging purposes.
|
| 156 |
+
"""
|
| 157 |
+
state = init_state
|
| 158 |
+
while cond_fn(state):
|
| 159 |
+
state = body_fn(state)
|
| 160 |
+
return state
|
| 161 |
+
|
| 162 |
+
def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, params, model_kwargs):
|
| 163 |
+
encoder_kwargs = {
|
| 164 |
+
argument: value
|
| 165 |
+
for argument, value in model_kwargs.items()
|
| 166 |
+
if not (argument.startswith("decoder_") or argument.startswith("cross_attn"))
|
| 167 |
+
}
|
| 168 |
+
model_kwargs["encoder_outputs"] = self.encode(input_ids, params=params, return_dict=True, **encoder_kwargs)
|
| 169 |
+
return model_kwargs
|
| 170 |
+
|
| 171 |
+
def _prepare_decoder_input_ids_for_generation(
|
| 172 |
+
self,
|
| 173 |
+
batch_size: int,
|
| 174 |
+
decoder_start_token_id: int = None,
|
| 175 |
+
bos_token_id: int = None,
|
| 176 |
+
model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
|
| 177 |
+
) -> jnp.ndarray:
|
| 178 |
+
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
|
| 179 |
+
# Only use this arg if not None, otherwise just remove from model_kwargs
|
| 180 |
+
decoder_input_ids = model_kwargs.pop("decoder_input_ids")
|
| 181 |
+
if decoder_input_ids is not None:
|
| 182 |
+
return decoder_input_ids
|
| 183 |
+
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
|
| 184 |
+
return jnp.array(decoder_start_token_id, dtype="i4").reshape(1, -1).repeat(batch_size, axis=0)
|
| 185 |
+
|
| 186 |
+
def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
|
| 187 |
+
# retrieve decoder_start_token_id for encoder-decoder models
|
| 188 |
+
# fall back to bos_token_id if necessary
|
| 189 |
+
decoder_start_token_id = (
|
| 190 |
+
decoder_start_token_id
|
| 191 |
+
if decoder_start_token_id is not None
|
| 192 |
+
else self.generation_config.decoder_start_token_id
|
| 193 |
+
)
|
| 194 |
+
bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
|
| 195 |
+
if decoder_start_token_id is not None:
|
| 196 |
+
return decoder_start_token_id
|
| 197 |
+
elif (
|
| 198 |
+
hasattr(self.config, "decoder")
|
| 199 |
+
and hasattr(self.config.decoder, "decoder_start_token_id")
|
| 200 |
+
and self.config.decoder.decoder_start_token_id is not None
|
| 201 |
+
):
|
| 202 |
+
return self.config.decoder.decoder_start_token_id
|
| 203 |
+
elif bos_token_id is not None:
|
| 204 |
+
return bos_token_id
|
| 205 |
+
elif (
|
| 206 |
+
hasattr(self.config, "decoder")
|
| 207 |
+
and hasattr(self.config.decoder, "bos_token_id")
|
| 208 |
+
and self.config.decoder.bos_token_id is not None
|
| 209 |
+
):
|
| 210 |
+
return self.config.decoder.bos_token_id
|
| 211 |
+
raise ValueError(
|
| 212 |
+
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
@staticmethod
|
| 216 |
+
def _expand_to_num_beams(tensor, num_beams):
|
| 217 |
+
return jnp.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:])
|
| 218 |
+
|
| 219 |
+
def _adapt_logits_for_beam_search(self, logits):
|
| 220 |
+
"""
|
| 221 |
+
This function can be overwritten in the specific modeling_flax_<model-name>.py classes to allow for custom beam
|
| 222 |
+
search behavior. Note that the only model that overwrites this method is [`~transformes.FlaxMarianMTModel`].
|
| 223 |
+
"""
|
| 224 |
+
return logits
|
| 225 |
+
|
| 226 |
+
def _validate_model_class(self):
|
| 227 |
+
"""
|
| 228 |
+
Confirms that the model class is compatible with generation. If not, raises an exception that points to the
|
| 229 |
+
right class to use.
|
| 230 |
+
"""
|
| 231 |
+
if not self.can_generate():
|
| 232 |
+
generate_compatible_mappings = [
|
| 233 |
+
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
|
| 234 |
+
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
|
| 235 |
+
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
| 236 |
+
]
|
| 237 |
+
generate_compatible_classes = set()
|
| 238 |
+
for model_mapping in generate_compatible_mappings:
|
| 239 |
+
supported_models = model_mapping.get(type(self.config), default=None)
|
| 240 |
+
if supported_models is not None:
|
| 241 |
+
generate_compatible_classes.add(supported_models.__name__)
|
| 242 |
+
exception_message = (
|
| 243 |
+
f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
|
| 244 |
+
"it doesn't have a language model head."
|
| 245 |
+
)
|
| 246 |
+
if generate_compatible_classes:
|
| 247 |
+
exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
|
| 248 |
+
raise TypeError(exception_message)
|
| 249 |
+
|
| 250 |
+
def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
|
| 251 |
+
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
|
| 252 |
+
unused_model_args = []
|
| 253 |
+
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
|
| 254 |
+
# `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
|
| 255 |
+
# `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
|
| 256 |
+
if "kwargs" in model_args or "model_kwargs" in model_args:
|
| 257 |
+
model_args |= set(inspect.signature(self.__call__).parameters)
|
| 258 |
+
for key, value in model_kwargs.items():
|
| 259 |
+
if value is not None and key not in model_args:
|
| 260 |
+
unused_model_args.append(key)
|
| 261 |
+
|
| 262 |
+
if unused_model_args:
|
| 263 |
+
raise ValueError(
|
| 264 |
+
f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
|
| 265 |
+
" generate arguments will also show up in this list)"
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
def generate(
|
| 269 |
+
self,
|
| 270 |
+
input_ids: jnp.ndarray,
|
| 271 |
+
generation_config: Optional[GenerationConfig] = None,
|
| 272 |
+
prng_key: Optional[jnp.ndarray] = None,
|
| 273 |
+
trace: bool = True,
|
| 274 |
+
params: Optional[Dict[str, jnp.ndarray]] = None,
|
| 275 |
+
logits_processor: Optional[FlaxLogitsProcessorList] = None,
|
| 276 |
+
**kwargs,
|
| 277 |
+
):
|
| 278 |
+
r"""
|
| 279 |
+
Generates sequences of token ids for models with a language modeling head.
|
| 280 |
+
|
| 281 |
+
Parameters:
|
| 282 |
+
input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
|
| 283 |
+
The sequence used as a prompt for the generation.
|
| 284 |
+
generation_config (`~generation.GenerationConfig`, *optional*):
|
| 285 |
+
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
|
| 286 |
+
passed to generate matching the attributes of `generation_config` will override them. If
|
| 287 |
+
`generation_config` is not provided, the default will be used, which had the following loading
|
| 288 |
+
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
|
| 289 |
+
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
|
| 290 |
+
default values, whose documentation should be checked to parameterize generation.
|
| 291 |
+
trace (`bool`, *optional*, defaults to `True`):
|
| 292 |
+
Whether to trace generation. Setting `trace=False` should only be used for debugging and will lead to a
|
| 293 |
+
considerably slower runtime.
|
| 294 |
+
params (`Dict[str, jnp.ndarray]`, *optional*):
|
| 295 |
+
Optionally the model parameters can be passed. Can be useful for parallelized generation.
|
| 296 |
+
logits_processor (`FlaxLogitsProcessorList `, *optional*):
|
| 297 |
+
Custom logits processors that complement the default logits processors built from arguments and
|
| 298 |
+
generation config. If a logit processor is passed that is already created with the arguments or a
|
| 299 |
+
generation config an error is thrown. This feature is intended for advanced users.
|
| 300 |
+
kwargs (`Dict[str, Any]`, *optional*):
|
| 301 |
+
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
|
| 302 |
+
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
|
| 303 |
+
specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
|
| 304 |
+
|
| 305 |
+
Return:
|
| 306 |
+
[`~utils.ModelOutput`].
|
| 307 |
+
|
| 308 |
+
"""
|
| 309 |
+
# Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
| 310 |
+
self._validate_model_class()
|
| 311 |
+
|
| 312 |
+
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
|
| 313 |
+
if generation_config is None:
|
| 314 |
+
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
|
| 315 |
+
# two conditions must be met
|
| 316 |
+
# 1) the generation config must have been created from the model config (`_from_model_config` field);
|
| 317 |
+
# 2) the generation config must have seen no modification since its creation (the hash is the same).
|
| 318 |
+
if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(
|
| 319 |
+
self.generation_config
|
| 320 |
+
):
|
| 321 |
+
new_generation_config = GenerationConfig.from_model_config(self.config)
|
| 322 |
+
if new_generation_config != self.generation_config:
|
| 323 |
+
warnings.warn(
|
| 324 |
+
"You have modified the pretrained model configuration to control generation. This is a"
|
| 325 |
+
" deprecated strategy to control generation and will be removed soon, in a future version."
|
| 326 |
+
" Please use and modify the model generation configuration (see"
|
| 327 |
+
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
|
| 328 |
+
)
|
| 329 |
+
self.generation_config = new_generation_config
|
| 330 |
+
generation_config = self.generation_config
|
| 331 |
+
|
| 332 |
+
generation_config = copy.deepcopy(generation_config)
|
| 333 |
+
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
|
| 334 |
+
self._validate_model_kwargs(model_kwargs.copy())
|
| 335 |
+
|
| 336 |
+
logits_processor = logits_processor if logits_processor is not None else FlaxLogitsProcessorList()
|
| 337 |
+
|
| 338 |
+
# set init values
|
| 339 |
+
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
|
| 340 |
+
|
| 341 |
+
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
|
| 342 |
+
if model_kwargs.get("attention_mask") is None:
|
| 343 |
+
logger.warning(
|
| 344 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe "
|
| 345 |
+
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
|
| 346 |
+
)
|
| 347 |
+
eos_token_id = generation_config.eos_token_id
|
| 348 |
+
if isinstance(eos_token_id, list):
|
| 349 |
+
eos_token_id = eos_token_id[0]
|
| 350 |
+
generation_config.pad_token_id = eos_token_id
|
| 351 |
+
|
| 352 |
+
if generation_config.decoder_start_token_id is None and self.config.is_encoder_decoder:
|
| 353 |
+
raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
|
| 354 |
+
|
| 355 |
+
# decoder-only models should use left-padding for generation (can't be checked with `trace=True`)
|
| 356 |
+
if not self.config.is_encoder_decoder and not trace:
|
| 357 |
+
if (
|
| 358 |
+
generation_config.pad_token_id is not None
|
| 359 |
+
and jnp.sum(input_ids[:, -1] == generation_config.pad_token_id) > 0
|
| 360 |
+
):
|
| 361 |
+
logger.warning(
|
| 362 |
+
"A decoder-only architecture is being used, but right-padding was detected! For correct "
|
| 363 |
+
"generation results, please set `padding_side='left'` when initializing the tokenizer."
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
batch_size = input_ids.shape[0]
|
| 367 |
+
|
| 368 |
+
if self.config.is_encoder_decoder:
|
| 369 |
+
# add encoder_outputs to model_kwargs
|
| 370 |
+
if model_kwargs.get("encoder_outputs") is None:
|
| 371 |
+
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs)
|
| 372 |
+
# prepare decoder_input_ids for generation
|
| 373 |
+
input_ids = self._prepare_decoder_input_ids_for_generation(
|
| 374 |
+
batch_size,
|
| 375 |
+
decoder_start_token_id=generation_config.decoder_start_token_id,
|
| 376 |
+
bos_token_id=generation_config.bos_token_id,
|
| 377 |
+
model_kwargs=model_kwargs,
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
# Prepare `max_length` depending on other stopping criteria.
|
| 381 |
+
input_ids_seq_length = input_ids.shape[-1]
|
| 382 |
+
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
| 383 |
+
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
|
| 384 |
+
# 20 is the default max_length of the generation config
|
| 385 |
+
warnings.warn(
|
| 386 |
+
f"Using the model-agnostic default `max_length` (={generation_config.max_length}) "
|
| 387 |
+
"to control the generation length. recommend setting `max_new_tokens` to control the maximum length of the generation.",
|
| 388 |
+
UserWarning,
|
| 389 |
+
)
|
| 390 |
+
elif generation_config.max_new_tokens is not None:
|
| 391 |
+
if not has_default_max_length and generation_config.max_length is not None:
|
| 392 |
+
logger.warning(
|
| 393 |
+
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
| 394 |
+
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
|
| 395 |
+
"Please refer to the documentation for more information. "
|
| 396 |
+
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
|
| 397 |
+
)
|
| 398 |
+
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
| 399 |
+
else: # by default let's always generate 10 new tokens
|
| 400 |
+
if generation_config.max_length == GenerationConfig().max_length:
|
| 401 |
+
generation_config.max_length = generation_config.max_length + input_ids_seq_length
|
| 402 |
+
max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
|
| 403 |
+
if max_position_embeddings is not None:
|
| 404 |
+
generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
|
| 405 |
+
|
| 406 |
+
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
|
| 407 |
+
raise ValueError(
|
| 408 |
+
f"Unfeasable length constraints: the minimum length ({generation_config.min_length}) is larger than"
|
| 409 |
+
f" the maximum length ({generation_config.max_length})"
|
| 410 |
+
)
|
| 411 |
+
if input_ids_seq_length >= generation_config.max_length:
|
| 412 |
+
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
|
| 413 |
+
logger.warning(
|
| 414 |
+
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
|
| 415 |
+
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
|
| 416 |
+
" increasing`max_new_tokens`."
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
logits_processor = self._get_logits_processor(
|
| 420 |
+
generation_config=generation_config,
|
| 421 |
+
input_ids_seq_length=input_ids_seq_length,
|
| 422 |
+
logits_processor=logits_processor,
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
if not generation_config.do_sample and generation_config.num_beams == 1:
|
| 426 |
+
return self._greedy_search(
|
| 427 |
+
input_ids,
|
| 428 |
+
generation_config.max_length,
|
| 429 |
+
generation_config.pad_token_id,
|
| 430 |
+
generation_config.eos_token_id,
|
| 431 |
+
logits_processor=logits_processor,
|
| 432 |
+
trace=trace,
|
| 433 |
+
params=params,
|
| 434 |
+
model_kwargs=model_kwargs,
|
| 435 |
+
)
|
| 436 |
+
elif generation_config.do_sample and generation_config.num_beams == 1:
|
| 437 |
+
logits_warper = self._get_logits_warper(generation_config=generation_config)
|
| 438 |
+
return self._sample(
|
| 439 |
+
input_ids,
|
| 440 |
+
generation_config.max_length,
|
| 441 |
+
generation_config.pad_token_id,
|
| 442 |
+
generation_config.eos_token_id,
|
| 443 |
+
prng_key,
|
| 444 |
+
logits_warper=logits_warper,
|
| 445 |
+
logits_processor=logits_processor,
|
| 446 |
+
trace=trace,
|
| 447 |
+
params=params,
|
| 448 |
+
model_kwargs=model_kwargs,
|
| 449 |
+
)
|
| 450 |
+
elif not generation_config.do_sample and generation_config.num_beams > 1:
|
| 451 |
+
# broadcast input_ids & encoder_outputs
|
| 452 |
+
input_ids = self._expand_to_num_beams(input_ids, num_beams=generation_config.num_beams)
|
| 453 |
+
|
| 454 |
+
if "encoder_outputs" in model_kwargs:
|
| 455 |
+
model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
|
| 456 |
+
model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=generation_config.num_beams
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
for kwarg in ["attention_mask", "decoder_attention_mask"]:
|
| 460 |
+
if kwarg in model_kwargs:
|
| 461 |
+
model_kwargs[kwarg] = self._expand_to_num_beams(
|
| 462 |
+
model_kwargs[kwarg], num_beams=generation_config.num_beams
|
| 463 |
+
)
|
| 464 |
+
|
| 465 |
+
return self._beam_search(
|
| 466 |
+
input_ids,
|
| 467 |
+
generation_config.max_length,
|
| 468 |
+
generation_config.pad_token_id,
|
| 469 |
+
generation_config.eos_token_id,
|
| 470 |
+
length_penalty=generation_config.length_penalty,
|
| 471 |
+
early_stopping=generation_config.early_stopping,
|
| 472 |
+
logits_processor=logits_processor,
|
| 473 |
+
trace=trace,
|
| 474 |
+
params=params,
|
| 475 |
+
num_return_sequences=generation_config.num_return_sequences,
|
| 476 |
+
model_kwargs=model_kwargs,
|
| 477 |
+
)
|
| 478 |
+
else:
|
| 479 |
+
raise NotImplementedError("`Beam sampling is currently not implemented.")
|
| 480 |
+
|
| 481 |
+
def _get_logits_warper(self, generation_config: GenerationConfig) -> FlaxLogitsProcessorList:
|
| 482 |
+
"""
|
| 483 |
+
This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsWarper`]
|
| 484 |
+
instances used for multinomial sampling.
|
| 485 |
+
"""
|
| 486 |
+
warpers = FlaxLogitsProcessorList()
|
| 487 |
+
|
| 488 |
+
if generation_config.temperature is not None and generation_config.temperature != 1.0:
|
| 489 |
+
warpers.append(FlaxTemperatureLogitsWarper(generation_config.temperature))
|
| 490 |
+
if generation_config.top_k is not None and generation_config.top_k != 0:
|
| 491 |
+
warpers.append(FlaxTopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=1))
|
| 492 |
+
if generation_config.top_p is not None and generation_config.top_p < 1.0:
|
| 493 |
+
warpers.append(FlaxTopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=1))
|
| 494 |
+
|
| 495 |
+
return warpers
|
| 496 |
+
|
| 497 |
+
def _get_logits_processor(
|
| 498 |
+
self,
|
| 499 |
+
generation_config: GenerationConfig,
|
| 500 |
+
input_ids_seq_length: int,
|
| 501 |
+
logits_processor: Optional[FlaxLogitsProcessorList],
|
| 502 |
+
) -> FlaxLogitsProcessorList:
|
| 503 |
+
"""
|
| 504 |
+
This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsProcessor`]
|
| 505 |
+
instances used to modify the scores of the language model head.
|
| 506 |
+
"""
|
| 507 |
+
processors = FlaxLogitsProcessorList()
|
| 508 |
+
|
| 509 |
+
if (
|
| 510 |
+
generation_config.min_length is not None
|
| 511 |
+
and generation_config.eos_token_id is not None
|
| 512 |
+
and generation_config.min_length > -1
|
| 513 |
+
):
|
| 514 |
+
processors.append(
|
| 515 |
+
FlaxMinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id)
|
| 516 |
+
)
|
| 517 |
+
if generation_config.forced_bos_token_id is not None:
|
| 518 |
+
processors.append(FlaxForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id))
|
| 519 |
+
if generation_config.forced_eos_token_id is not None:
|
| 520 |
+
processors.append(
|
| 521 |
+
FlaxForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id)
|
| 522 |
+
)
|
| 523 |
+
if generation_config.suppress_tokens is not None:
|
| 524 |
+
processors.append(FlaxSuppressTokensLogitsProcessor(generation_config.suppress_tokens))
|
| 525 |
+
if generation_config.begin_suppress_tokens is not None:
|
| 526 |
+
begin_index = input_ids_seq_length
|
| 527 |
+
begin_index = (
|
| 528 |
+
begin_index
|
| 529 |
+
if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
|
| 530 |
+
else begin_index + 1
|
| 531 |
+
)
|
| 532 |
+
if generation_config.forced_decoder_ids is not None and len(generation_config.forced_decoder_ids) > 0:
|
| 533 |
+
# generation starts after the last token that is forced
|
| 534 |
+
begin_index += generation_config.forced_decoder_ids[-1][0]
|
| 535 |
+
processors.append(
|
| 536 |
+
FlaxSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
|
| 537 |
+
)
|
| 538 |
+
if generation_config.forced_decoder_ids is not None:
|
| 539 |
+
forced_decoder_ids = [
|
| 540 |
+
[input_ids_seq_length + i[0] - 1, i[1]] for i in generation_config.forced_decoder_ids
|
| 541 |
+
]
|
| 542 |
+
processors.append(FlaxForceTokensLogitsProcessor(forced_decoder_ids))
|
| 543 |
+
if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0:
|
| 544 |
+
processors.append(FlaxNoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size))
|
| 545 |
+
processors = self._merge_criteria_processor_list(processors, logits_processor)
|
| 546 |
+
|
| 547 |
+
return processors
|
| 548 |
+
|
| 549 |
+
def _merge_criteria_processor_list(
|
| 550 |
+
self,
|
| 551 |
+
default_list: FlaxLogitsProcessorList,
|
| 552 |
+
custom_list: FlaxLogitsProcessorList,
|
| 553 |
+
) -> FlaxLogitsProcessorList:
|
| 554 |
+
if len(custom_list) == 0:
|
| 555 |
+
return default_list
|
| 556 |
+
for default in default_list:
|
| 557 |
+
for custom in custom_list:
|
| 558 |
+
if type(custom) is type(default):
|
| 559 |
+
object_type = "logits processor"
|
| 560 |
+
raise ValueError(
|
| 561 |
+
f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
|
| 562 |
+
f" `generate`, but it has already been created with the values {default}. {default} has been"
|
| 563 |
+
" created by passing the corresponding arguments to generate or by the model's config default"
|
| 564 |
+
f" values. If you just want to change the default values of {object_type} consider passing"
|
| 565 |
+
f" them as arguments to `generate` instead of using a custom {object_type}."
|
| 566 |
+
)
|
| 567 |
+
default_list.extend(custom_list)
|
| 568 |
+
return default_list
|
| 569 |
+
|
| 570 |
+
def _greedy_search(
|
| 571 |
+
self,
|
| 572 |
+
input_ids: None,
|
| 573 |
+
max_length: Optional[int] = None,
|
| 574 |
+
pad_token_id: Optional[int] = None,
|
| 575 |
+
eos_token_id: Optional[int] = None,
|
| 576 |
+
logits_processor: Optional[FlaxLogitsProcessorList] = None,
|
| 577 |
+
trace: bool = True,
|
| 578 |
+
params: Optional[Dict[str, jnp.ndarray]] = None,
|
| 579 |
+
model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
|
| 580 |
+
):
|
| 581 |
+
# init values
|
| 582 |
+
max_length = max_length if max_length is not None else self.generation_config.max_length
|
| 583 |
+
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
| 584 |
+
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
| 585 |
+
|
| 586 |
+
batch_size, cur_len = input_ids.shape
|
| 587 |
+
|
| 588 |
+
eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
|
| 589 |
+
pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
|
| 590 |
+
cur_len = jnp.array(cur_len)
|
| 591 |
+
|
| 592 |
+
# per batch-item holding current token in loop.
|
| 593 |
+
sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
|
| 594 |
+
sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))
|
| 595 |
+
|
| 596 |
+
# per batch-item state bit indicating if sentence has finished.
|
| 597 |
+
is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
|
| 598 |
+
|
| 599 |
+
# For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
|
| 600 |
+
# and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
|
| 601 |
+
model = self.decode if self.config.is_encoder_decoder else self
|
| 602 |
+
# initialize model specific kwargs
|
| 603 |
+
model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
|
| 604 |
+
|
| 605 |
+
# initialize state
|
| 606 |
+
state = GreedyState(
|
| 607 |
+
cur_len=cur_len,
|
| 608 |
+
sequences=sequences,
|
| 609 |
+
running_token=input_ids,
|
| 610 |
+
is_sent_finished=is_sent_finished,
|
| 611 |
+
model_kwargs=model_kwargs,
|
| 612 |
+
)
|
| 613 |
+
|
| 614 |
+
def greedy_search_cond_fn(state):
|
| 615 |
+
"""state termination condition fn."""
|
| 616 |
+
has_reached_max_length = state.cur_len == max_length
|
| 617 |
+
all_sequence_finished = jnp.all(state.is_sent_finished)
|
| 618 |
+
finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)
|
| 619 |
+
return ~finish_generation
|
| 620 |
+
|
| 621 |
+
def greedy_search_body_fn(state):
|
| 622 |
+
"""state update fn."""
|
| 623 |
+
model_outputs = model(state.running_token, params=params, **state.model_kwargs)
|
| 624 |
+
logits = model_outputs.logits[:, -1]
|
| 625 |
+
|
| 626 |
+
# apply min_length, ...
|
| 627 |
+
logits = logits_processor(state.sequences, logits, state.cur_len)
|
| 628 |
+
|
| 629 |
+
next_token = jnp.argmax(logits, axis=-1)
|
| 630 |
+
|
| 631 |
+
next_token = next_token * ~state.is_sent_finished + pad_token_id * state.is_sent_finished
|
| 632 |
+
next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
|
| 633 |
+
next_token = next_token[:, None]
|
| 634 |
+
|
| 635 |
+
next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
|
| 636 |
+
next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
|
| 637 |
+
return GreedyState(
|
| 638 |
+
cur_len=state.cur_len + 1,
|
| 639 |
+
sequences=next_sequences,
|
| 640 |
+
running_token=next_token,
|
| 641 |
+
is_sent_finished=next_is_sent_finished,
|
| 642 |
+
model_kwargs=next_model_kwargs,
|
| 643 |
+
)
|
| 644 |
+
|
| 645 |
+
# The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
|
| 646 |
+
if input_ids.shape[1] > 1:
|
| 647 |
+
state = greedy_search_body_fn(state)
|
| 648 |
+
|
| 649 |
+
if not trace:
|
| 650 |
+
state = self._run_loop_in_debug(greedy_search_cond_fn, greedy_search_body_fn, state)
|
| 651 |
+
else:
|
| 652 |
+
state = lax.while_loop(greedy_search_cond_fn, greedy_search_body_fn, state)
|
| 653 |
+
|
| 654 |
+
return FlaxGreedySearchOutput(sequences=state.sequences)
|
| 655 |
+
|
| 656 |
+
def _sample(
|
| 657 |
+
self,
|
| 658 |
+
input_ids: None,
|
| 659 |
+
max_length: Optional[int] = None,
|
| 660 |
+
pad_token_id: Optional[int] = None,
|
| 661 |
+
eos_token_id: Optional[int] = None,
|
| 662 |
+
prng_key: Optional[jnp.ndarray] = None,
|
| 663 |
+
logits_processor: Optional[FlaxLogitsProcessorList] = None,
|
| 664 |
+
logits_warper: Optional[FlaxLogitsProcessorList] = None,
|
| 665 |
+
trace: bool = True,
|
| 666 |
+
params: Optional[Dict[str, jnp.ndarray]] = None,
|
| 667 |
+
model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
|
| 668 |
+
):
|
| 669 |
+
# init values
|
| 670 |
+
max_length = max_length if max_length is not None else self.generation_config.max_length
|
| 671 |
+
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
| 672 |
+
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
| 673 |
+
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
|
| 674 |
+
|
| 675 |
+
batch_size, cur_len = input_ids.shape
|
| 676 |
+
|
| 677 |
+
eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
|
| 678 |
+
pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
|
| 679 |
+
cur_len = jnp.array(cur_len)
|
| 680 |
+
|
| 681 |
+
# per batch-item holding current token in loop.
|
| 682 |
+
sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
|
| 683 |
+
sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))
|
| 684 |
+
|
| 685 |
+
# per batch-item state bit indicating if sentence has finished.
|
| 686 |
+
is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
|
| 687 |
+
|
| 688 |
+
# For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
|
| 689 |
+
# and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
|
| 690 |
+
model = self.decode if self.config.is_encoder_decoder else self
|
| 691 |
+
|
| 692 |
+
# initialize model specific kwargs
|
| 693 |
+
model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
|
| 694 |
+
|
| 695 |
+
# initialize state
|
| 696 |
+
state = SampleState(
|
| 697 |
+
cur_len=cur_len,
|
| 698 |
+
sequences=sequences,
|
| 699 |
+
running_token=input_ids,
|
| 700 |
+
is_sent_finished=is_sent_finished,
|
| 701 |
+
prng_key=prng_key,
|
| 702 |
+
model_kwargs=model_kwargs,
|
| 703 |
+
)
|
| 704 |
+
|
| 705 |
+
def sample_search_cond_fn(state):
|
| 706 |
+
"""state termination condition fn."""
|
| 707 |
+
has_reached_max_length = state.cur_len == max_length
|
| 708 |
+
all_sequence_finished = jnp.all(state.is_sent_finished)
|
| 709 |
+
finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)
|
| 710 |
+
return ~finish_generation
|
| 711 |
+
|
| 712 |
+
def sample_search_body_fn(state):
|
| 713 |
+
"""state update fn."""
|
| 714 |
+
prng_key, prng_key_next = jax.random.split(state.prng_key)
|
| 715 |
+
model_outputs = model(state.running_token, params=params, **state.model_kwargs)
|
| 716 |
+
|
| 717 |
+
logits = model_outputs.logits[:, -1]
|
| 718 |
+
|
| 719 |
+
# apply min_length, ...
|
| 720 |
+
logits = logits_processor(state.sequences, logits, state.cur_len)
|
| 721 |
+
# apply top_p, top_k, temperature
|
| 722 |
+
logits = logits_warper(logits, logits, state.cur_len)
|
| 723 |
+
|
| 724 |
+
next_token = jax.random.categorical(prng_key, logits, axis=-1)
|
| 725 |
+
|
| 726 |
+
next_token = next_token * ~state.is_sent_finished + pad_token_id * state.is_sent_finished
|
| 727 |
+
next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
|
| 728 |
+
next_token = next_token[:, None]
|
| 729 |
+
|
| 730 |
+
next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
|
| 731 |
+
next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
|
| 732 |
+
|
| 733 |
+
return SampleState(
|
| 734 |
+
cur_len=state.cur_len + 1,
|
| 735 |
+
sequences=next_sequences,
|
| 736 |
+
running_token=next_token,
|
| 737 |
+
is_sent_finished=next_is_sent_finished,
|
| 738 |
+
model_kwargs=next_model_kwargs,
|
| 739 |
+
prng_key=prng_key_next,
|
| 740 |
+
)
|
| 741 |
+
|
| 742 |
+
# The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
|
| 743 |
+
if input_ids.shape[1] > 1:
|
| 744 |
+
state = sample_search_body_fn(state)
|
| 745 |
+
|
| 746 |
+
if not trace:
|
| 747 |
+
state = self._run_loop_in_debug(sample_search_cond_fn, sample_search_body_fn, state)
|
| 748 |
+
else:
|
| 749 |
+
state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)
|
| 750 |
+
|
| 751 |
+
return FlaxSampleOutput(sequences=state.sequences)
|
| 752 |
+
|
| 753 |
+
def _beam_search(
|
| 754 |
+
self,
|
| 755 |
+
input_ids: None,
|
| 756 |
+
max_length: Optional[int] = None,
|
| 757 |
+
pad_token_id: Optional[int] = None,
|
| 758 |
+
eos_token_id: Optional[int] = None,
|
| 759 |
+
length_penalty: Optional[float] = None,
|
| 760 |
+
early_stopping: Optional[Union[bool, str]] = None,
|
| 761 |
+
logits_processor: Optional[FlaxLogitsProcessorList] = None,
|
| 762 |
+
trace: bool = True,
|
| 763 |
+
params: Optional[Dict[str, jnp.ndarray]] = None,
|
| 764 |
+
num_return_sequences: Optional[int] = None,
|
| 765 |
+
model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
|
| 766 |
+
):
|
| 767 |
+
"""
|
| 768 |
+
This beam search function is heavily inspired by Flax's official example:
|
| 769 |
+
https://github.com/google/flax/blob/main/examples/wmt/decode.py
|
| 770 |
+
"""
|
| 771 |
+
|
| 772 |
+
def flatten_beam_dim(tensor):
|
| 773 |
+
"""Flattens the first two dimensions of a non-scalar array."""
|
| 774 |
+
# ignore scalars (e.g. cache index)
|
| 775 |
+
if tensor.ndim == 0:
|
| 776 |
+
return tensor
|
| 777 |
+
return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])
|
| 778 |
+
|
| 779 |
+
def unflatten_beam_dim(tensor, batch_size, num_beams):
|
| 780 |
+
"""Unflattens the first, flat batch*beam dimension of a non-scalar array."""
|
| 781 |
+
# ignore scalars (e.g. cache index)
|
| 782 |
+
if tensor.ndim == 0:
|
| 783 |
+
return tensor
|
| 784 |
+
return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])
|
| 785 |
+
|
| 786 |
+
def gather_beams(nested, beam_indices, batch_size, new_num_beams):
|
| 787 |
+
"""
|
| 788 |
+
Gathers the beam slices indexed by beam_indices into new beam array.
|
| 789 |
+
"""
|
| 790 |
+
batch_indices = jnp.reshape(
|
| 791 |
+
jnp.arange(batch_size * new_num_beams) // new_num_beams, (batch_size, new_num_beams)
|
| 792 |
+
)
|
| 793 |
+
|
| 794 |
+
def gather_fn(tensor):
|
| 795 |
+
# ignore scalars (e.g. cache index)
|
| 796 |
+
if tensor.ndim == 0:
|
| 797 |
+
return tensor
|
| 798 |
+
else:
|
| 799 |
+
return tensor[batch_indices, beam_indices]
|
| 800 |
+
|
| 801 |
+
return jax.tree_util.tree_map(gather_fn, nested)
|
| 802 |
+
|
| 803 |
+
# init values
|
| 804 |
+
max_length = max_length if max_length is not None else self.generation_config.max_length
|
| 805 |
+
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
| 806 |
+
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
| 807 |
+
length_penalty = length_penalty if length_penalty is not None else self.generation_config.length_penalty
|
| 808 |
+
early_stopping = early_stopping if early_stopping is not None else self.generation_config.early_stopping
|
| 809 |
+
num_return_sequences = (
|
| 810 |
+
num_return_sequences if num_return_sequences is not None else self.generation_config.num_return_sequences
|
| 811 |
+
)
|
| 812 |
+
|
| 813 |
+
batch_size, num_beams, cur_len = input_ids.shape
|
| 814 |
+
|
| 815 |
+
eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
|
| 816 |
+
pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
|
| 817 |
+
cur_len = jnp.array(cur_len)
|
| 818 |
+
|
| 819 |
+
# record the prompt length of decoder
|
| 820 |
+
decoder_prompt_len = input_ids.shape[-1]
|
| 821 |
+
|
| 822 |
+
# per batch,beam-item holding current token in loop.
|
| 823 |
+
sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
|
| 824 |
+
running_sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
|
| 825 |
+
running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0))
|
| 826 |
+
|
| 827 |
+
# per batch,beam-item state bit indicating if sentence has finished.
|
| 828 |
+
is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)
|
| 829 |
+
|
| 830 |
+
# per batch,beam-item score, logprobs
|
| 831 |
+
running_scores = jnp.tile(jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1])
|
| 832 |
+
scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)
|
| 833 |
+
|
| 834 |
+
# For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
|
| 835 |
+
# and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
|
| 836 |
+
model = self.decode if self.config.is_encoder_decoder else self
|
| 837 |
+
|
| 838 |
+
# flatten beam dim
|
| 839 |
+
if "encoder_outputs" in model_kwargs:
|
| 840 |
+
model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
|
| 841 |
+
model_kwargs["encoder_outputs"]["last_hidden_state"]
|
| 842 |
+
)
|
| 843 |
+
for kwarg in ["attention_mask", "decoder_attention_mask"]:
|
| 844 |
+
if kwarg in model_kwargs:
|
| 845 |
+
model_kwargs[kwarg] = flatten_beam_dim(model_kwargs[kwarg])
|
| 846 |
+
|
| 847 |
+
# initialize model specific kwargs
|
| 848 |
+
model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs)
|
| 849 |
+
|
| 850 |
+
# initialize state
|
| 851 |
+
state = BeamSearchState(
|
| 852 |
+
cur_len=cur_len,
|
| 853 |
+
running_sequences=running_sequences,
|
| 854 |
+
running_scores=running_scores,
|
| 855 |
+
sequences=sequences,
|
| 856 |
+
scores=scores,
|
| 857 |
+
is_sent_finished=is_sent_finished,
|
| 858 |
+
model_kwargs=model_kwargs,
|
| 859 |
+
)
|
| 860 |
+
|
| 861 |
+
def beam_search_cond_fn(state):
|
| 862 |
+
"""beam search state termination condition fn."""
|
| 863 |
+
|
| 864 |
+
# 1. is less than max length?
|
| 865 |
+
not_max_length_yet = state.cur_len < max_length
|
| 866 |
+
|
| 867 |
+
# 2. can the new beams still improve?
|
| 868 |
+
# early_stopping == False -> apply heuristic = always get the best score from `cur_len`. See the discussion
|
| 869 |
+
# below for more details.
|
| 870 |
+
# https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
|
| 871 |
+
# early_stopping == "never" -> compute the best score from max_length or cur_len, depending on the sign of
|
| 872 |
+
# length_penalty. Positive length_penalty favors longer sequences, thus we use max_length there.
|
| 873 |
+
if early_stopping == "never" and length_penalty > 0.0:
|
| 874 |
+
best_running_score = state.running_scores[:, :1] / (
|
| 875 |
+
(max_length - decoder_prompt_len) ** length_penalty
|
| 876 |
+
)
|
| 877 |
+
else:
|
| 878 |
+
best_running_score = state.running_scores[:, :1] / (
|
| 879 |
+
(state.cur_len - decoder_prompt_len) ** length_penalty
|
| 880 |
+
)
|
| 881 |
+
worst_finished_score = jnp.where(
|
| 882 |
+
state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
|
| 883 |
+
)
|
| 884 |
+
improvement_still_possible = jnp.any(best_running_score > worst_finished_score)
|
| 885 |
+
|
| 886 |
+
# 3. is there still a beam that has not finished?
|
| 887 |
+
still_open_beam = ~(jnp.all(state.is_sent_finished) & (early_stopping is True))
|
| 888 |
+
|
| 889 |
+
return not_max_length_yet & still_open_beam & improvement_still_possible
|
| 890 |
+
|
| 891 |
+
def beam_search_body_fn(state, input_ids_length=1):
|
| 892 |
+
"""beam search state update fn."""
|
| 893 |
+
# 1. Forward current tokens
|
| 894 |
+
# Collect the current position slice along length to feed the fast
|
| 895 |
+
# autoregressive decoder model. Flatten the beam dimension into batch
|
| 896 |
+
# dimension for feeding into the model.
|
| 897 |
+
# unflatten beam dimension
|
| 898 |
+
# Unflatten beam dimension in attention cache arrays
|
| 899 |
+
input_token = flatten_beam_dim(
|
| 900 |
+
lax.dynamic_slice(
|
| 901 |
+
state.running_sequences,
|
| 902 |
+
(0, 0, state.cur_len - input_ids_length),
|
| 903 |
+
(batch_size, num_beams, input_ids_length),
|
| 904 |
+
)
|
| 905 |
+
)
|
| 906 |
+
model_outputs = model(input_token, params=params, **state.model_kwargs)
|
| 907 |
+
|
| 908 |
+
logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams)
|
| 909 |
+
cache = jax.tree_util.tree_map(
|
| 910 |
+
lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values
|
| 911 |
+
)
|
| 912 |
+
|
| 913 |
+
# adapt logits for FlaxMarianMTModel
|
| 914 |
+
logits = self._adapt_logits_for_beam_search(logits)
|
| 915 |
+
|
| 916 |
+
# 2. Compute log probs
|
| 917 |
+
# get log probabilities from logits,
|
| 918 |
+
# process logits with processors (*e.g.* min_length, ...), and
|
| 919 |
+
# add new logprobs to existing running logprobs scores.
|
| 920 |
+
log_probs = jax.nn.log_softmax(logits)
|
| 921 |
+
log_probs = logits_processor(
|
| 922 |
+
flatten_beam_dim(state.running_sequences), flatten_beam_dim(log_probs), state.cur_len
|
| 923 |
+
)
|
| 924 |
+
log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
|
| 925 |
+
log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
|
| 926 |
+
vocab_size = log_probs.shape[2]
|
| 927 |
+
log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))
|
| 928 |
+
|
| 929 |
+
# 3. Retrieve top-K
|
| 930 |
+
# Each item in batch has num_beams * vocab_size candidate sequences.
|
| 931 |
+
# For each item, get the top 2*k candidates with the highest log-
|
| 932 |
+
# probabilities. We gather the top 2*K beams here so that even if the best
|
| 933 |
+
# K sequences reach EOS simultaneously, we have another K sequences
|
| 934 |
+
# remaining to continue the live beam search.
|
| 935 |
+
# Gather the top 2*K scores from _all_ beams.
|
| 936 |
+
# Gather 2*k top beams.
|
| 937 |
+
# Recover the beam index by floor division.
|
| 938 |
+
# Recover token id by modulo division and expand Id array for broadcasting.
|
| 939 |
+
# Update sequences for the 2*K top-k new sequences.
|
| 940 |
+
beams_to_keep = 2 * num_beams
|
| 941 |
+
topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
|
| 942 |
+
topk_beam_indices = topk_indices // vocab_size
|
| 943 |
+
topk_running_sequences = gather_beams(
|
| 944 |
+
state.running_sequences, topk_beam_indices, batch_size, beams_to_keep
|
| 945 |
+
)
|
| 946 |
+
topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
|
| 947 |
+
topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len))
|
| 948 |
+
|
| 949 |
+
# 4. Check which sequences have ended
|
| 950 |
+
# Update current sequences:
|
| 951 |
+
# Did any of these sequences reach an end marker?
|
| 952 |
+
# To prevent these just finished sequences from being added to the current sequences
|
| 953 |
+
# set of active beam search sequences, set their log probs to a very large
|
| 954 |
+
# negative value.
|
| 955 |
+
did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
|
| 956 |
+
running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)
|
| 957 |
+
# 5. Get running sequences scores for next
|
| 958 |
+
# Determine the top k beam indices (from top 2*k beams) from log probs
|
| 959 |
+
# and gather top k beams (from top 2*k beams).
|
| 960 |
+
next_topk_indices = lax.top_k(running_topk_log_probs, k=num_beams)[1]
|
| 961 |
+
next_running_sequences, next_running_scores = gather_beams(
|
| 962 |
+
[topk_sequences, running_topk_log_probs], next_topk_indices, batch_size, num_beams
|
| 963 |
+
)
|
| 964 |
+
|
| 965 |
+
# 6. Process topk logits
|
| 966 |
+
# Further process log probs:
|
| 967 |
+
# - add length penalty
|
| 968 |
+
# - make sure no scores can be added anymore if beam is full
|
| 969 |
+
# - make sure still running sequences cannot be chosen as finalized beam
|
| 970 |
+
topk_log_probs = topk_log_probs / ((state.cur_len + 1 - decoder_prompt_len) ** length_penalty)
|
| 971 |
+
beams_in_batch_are_full = jnp.broadcast_to(
|
| 972 |
+
state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape
|
| 973 |
+
) & (early_stopping is True)
|
| 974 |
+
add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
|
| 975 |
+
topk_log_probs += add_penalty * np.array(-1.0e7)
|
| 976 |
+
|
| 977 |
+
# 7. Get scores, sequences, is sentence finished for next.
|
| 978 |
+
# Combine sequences, scores, and flags along the beam dimension and compare
|
| 979 |
+
# new finished sequence scores to existing finished scores and select the
|
| 980 |
+
# best from the new set of beams
|
| 981 |
+
merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
|
| 982 |
+
merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
|
| 983 |
+
merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
|
| 984 |
+
topk_merged_indices = lax.top_k(merged_scores, k=num_beams)[1]
|
| 985 |
+
next_sequences, next_scores, next_is_sent_finished = gather_beams(
|
| 986 |
+
[merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
|
| 987 |
+
)
|
| 988 |
+
|
| 989 |
+
# 8. Update model kwargs.
|
| 990 |
+
# Determine the top k beam indices from the original set of all beams.
|
| 991 |
+
# With these, gather the top k beam-associated caches.
|
| 992 |
+
next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams)
|
| 993 |
+
next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
|
| 994 |
+
model_outputs["past_key_values"] = jax.tree_util.tree_map(lambda x: flatten_beam_dim(x), next_cache)
|
| 995 |
+
next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
|
| 996 |
+
|
| 997 |
+
return BeamSearchState(
|
| 998 |
+
cur_len=state.cur_len + 1,
|
| 999 |
+
running_scores=next_running_scores,
|
| 1000 |
+
running_sequences=next_running_sequences,
|
| 1001 |
+
scores=next_scores,
|
| 1002 |
+
sequences=next_sequences,
|
| 1003 |
+
is_sent_finished=next_is_sent_finished,
|
| 1004 |
+
model_kwargs=next_model_kwargs,
|
| 1005 |
+
)
|
| 1006 |
+
|
| 1007 |
+
# Always run first iteration outside of `lax.while_loop` to avoid calling `beam_search_cond_fn`
|
| 1008 |
+
# when `state.cur_len` equals `decoder_prompt_len`. This also helps to comply with TPU when
|
| 1009 |
+
# the very first prompt has sequence length > 1.
|
| 1010 |
+
state = partial(beam_search_body_fn, input_ids_length=input_ids.shape[-1])(state)
|
| 1011 |
+
|
| 1012 |
+
if not trace:
|
| 1013 |
+
state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state)
|
| 1014 |
+
else:
|
| 1015 |
+
state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state)
|
| 1016 |
+
|
| 1017 |
+
# Account for the edge-case where there are no finished sequences for a
|
| 1018 |
+
# particular batch item. If so, return running sequences for that batch item.
|
| 1019 |
+
none_finished = jnp.any(state.is_sent_finished, axis=1)
|
| 1020 |
+
sequences = jnp.where(none_finished[:, None, None], state.sequences, state.running_sequences)
|
| 1021 |
+
scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)
|
| 1022 |
+
|
| 1023 |
+
# Take best beams for each batch (the score is sorted in descending order)
|
| 1024 |
+
sequences = flatten_beam_dim(sequences[:, :num_return_sequences, :])
|
| 1025 |
+
scores = flatten_beam_dim(scores[:, :num_return_sequences])
|
| 1026 |
+
|
| 1027 |
+
return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
|
.venv/lib/python3.11/site-packages/transformers/generation/logits_process.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/transformers/generation/stopping_criteria.py
ADDED
|
@@ -0,0 +1,514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import warnings
|
| 3 |
+
from abc import ABC
|
| 4 |
+
from collections import OrderedDict
|
| 5 |
+
from copy import deepcopy
|
| 6 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from torch.nn import functional as F
|
| 11 |
+
|
| 12 |
+
from ..pytorch_utils import isin_mps_friendly
|
| 13 |
+
from ..tokenization_utils_base import PreTrainedTokenizerBase
|
| 14 |
+
from ..utils import add_start_docstrings, logging
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
logger = logging.get_logger(__name__)
|
| 18 |
+
# We maintain a module-level cache of the embedding vectors for the stop string criterion
|
| 19 |
+
# because they are slow to compute
|
| 20 |
+
STOP_STRING_EMBEDDING_CACHE = OrderedDict()
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
|
| 24 |
+
Args:
|
| 25 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 26 |
+
Indices of input sequence tokens in the vocabulary.
|
| 27 |
+
|
| 28 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 29 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 30 |
+
|
| 31 |
+
[What are input IDs?](../glossary#input-ids)
|
| 32 |
+
scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
|
| 33 |
+
Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
|
| 34 |
+
or scores for each vocabulary token after SoftMax. If this stopping criteria depends on the `scores` input,
|
| 35 |
+
make sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`.
|
| 36 |
+
kwargs (`Dict[str, Any]`, *optional*):
|
| 37 |
+
Additional stopping criteria specific kwargs.
|
| 38 |
+
|
| 39 |
+
Return:
|
| 40 |
+
`torch.BoolTensor`. (`torch.BoolTensor` of shape `(batch_size, 1)`), where `True` indicates we stop generation
|
| 41 |
+
for a particular row, `True` indicates we should continue.
|
| 42 |
+
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class StoppingCriteria(ABC):
|
| 47 |
+
"""Abstract base class for all stopping criteria that can be applied during generation.
|
| 48 |
+
|
| 49 |
+
If your stopping criteria depends on the `scores` input, make sure you pass `return_dict_in_generate=True,
|
| 50 |
+
output_scores=True` to `generate`.
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
| 54 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
| 55 |
+
raise NotImplementedError("StoppingCriteria needs to be subclassed")
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class MaxLengthCriteria(StoppingCriteria):
|
| 59 |
+
"""
|
| 60 |
+
This class can be used to stop generation whenever the full generated number of tokens exceeds `max_length`. Keep
|
| 61 |
+
in mind for decoder-only type of transformers, this will include the initial prompted tokens.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
max_length (`int`):
|
| 65 |
+
The maximum length that the output sequence can have in number of tokens.
|
| 66 |
+
max_position_embeddings (`int`, *optional*):
|
| 67 |
+
The maximum model length, as defined by the model's `config.max_position_embeddings` attribute.
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
def __init__(self, max_length: int, max_position_embeddings: Optional[int] = None):
|
| 71 |
+
self.max_length = max_length
|
| 72 |
+
self.max_position_embeddings = max_position_embeddings
|
| 73 |
+
|
| 74 |
+
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
| 75 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
| 76 |
+
cur_len = input_ids.shape[-1]
|
| 77 |
+
is_done = cur_len >= self.max_length
|
| 78 |
+
if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings:
|
| 79 |
+
logger.warning_once(
|
| 80 |
+
"This is a friendly reminder - the current text generation call will exceed the model's predefined "
|
| 81 |
+
f"maximum length ({self.max_position_embeddings}). Depending on the model, you may observe "
|
| 82 |
+
"exceptions, performance degradation, or nothing at all."
|
| 83 |
+
)
|
| 84 |
+
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class MaxTimeCriteria(StoppingCriteria):
|
| 88 |
+
"""
|
| 89 |
+
This class can be used to stop generation whenever the full generation exceeds some amount of time. By default, the
|
| 90 |
+
time will start being counted when you initialize this function. You can override this by passing an
|
| 91 |
+
`initial_time`.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
max_time (`float`):
|
| 95 |
+
The maximum allowed time in seconds for the generation.
|
| 96 |
+
initial_time (`float`, *optional*, defaults to `time.time()`):
|
| 97 |
+
The start of the generation allowed time.
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
def __init__(self, max_time: float, initial_timestamp: Optional[float] = None):
|
| 101 |
+
self.max_time = max_time
|
| 102 |
+
self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp
|
| 103 |
+
|
| 104 |
+
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
| 105 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
| 106 |
+
is_done = time.time() - self.initial_timestamp > self.max_time
|
| 107 |
+
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class StopStringCriteria(StoppingCriteria):
|
| 111 |
+
"""
|
| 112 |
+
This class can be used to stop generation whenever specific string sequences are generated. It preprocesses
|
| 113 |
+
the strings together with the tokenizer vocab to find positions where tokens can validly complete the stop strings.
|
| 114 |
+
|
| 115 |
+
Generation is stopped as soon as a token is generated that completes any of the stop strings.
|
| 116 |
+
We want to catch any instance in which the stop string would be present in the decoded output, which means
|
| 117 |
+
we must also catch cases with "overhangs" off one or both ends. To make this more concrete, for the stop string
|
| 118 |
+
"stop", any of the following token sequences would trigger the match:
|
| 119 |
+
|
| 120 |
+
- ["st", "op"]
|
| 121 |
+
- ["stop"]
|
| 122 |
+
- ["st", "opera"]
|
| 123 |
+
- ["sto", "pper"]
|
| 124 |
+
- ["las", "topper"]
|
| 125 |
+
- ["s", "to", "pped"]
|
| 126 |
+
|
| 127 |
+
Note that a match will only be triggered if the stop string is at the end of the generated sequence. In other
|
| 128 |
+
words, these sequences will not trigger a match:
|
| 129 |
+
|
| 130 |
+
- ["stop", "at"]
|
| 131 |
+
- ["st", "op", "at"]
|
| 132 |
+
- ["st", "opera", "tion"]
|
| 133 |
+
|
| 134 |
+
The reason these are not a match is that the stop string does not overlap with the final token. If you can remove
|
| 135 |
+
one or more tokens from the end of the sequence without destroying the stop string, then this criterion will not
|
| 136 |
+
match that stop string. This is by design; because this check is run after each token is generated, we can't miss a
|
| 137 |
+
valid stop string if one is generated, but we don't want to halt generation just because the stop string exists
|
| 138 |
+
somewhere in the past input_ids.
|
| 139 |
+
|
| 140 |
+
How is the match actually performed, though? We do it in quite a confusing way, because we want the entire match
|
| 141 |
+
process to be compilable with Torch or XLA, which means we cannot use standard string methods. However, it is possible,
|
| 142 |
+
with some work, to do string matching with pure tensor operations. We'll begin by describing the algorithm we use
|
| 143 |
+
with standard string operations, and then at the end we'll explain how this is converted to pure tensor operations.
|
| 144 |
+
|
| 145 |
+
The key to the algorithm is an observation: Because the stop string must overlap with the end of the token sequence, we can start at
|
| 146 |
+
the end of the sequence and work backwards. Specifically, we check that there is an overlap between the start of
|
| 147 |
+
the final token and the end of the stop_string, or to put it another way, stop_string[-i:] == token[:i] for
|
| 148 |
+
some i > 0. If you look at the positive examples above, you'll see the last token in all of them fulfills this
|
| 149 |
+
property:
|
| 150 |
+
|
| 151 |
+
- ["st", "op"] (overlap is "op", overlap length == 2)
|
| 152 |
+
- ["stop"] (overlap is "stop", overlap length == 4)
|
| 153 |
+
- ["st", "opera"] (overlap is "op", overlap length == 2)
|
| 154 |
+
- ["sto", "pper"] (overlap is "p", overlap length == 1)
|
| 155 |
+
- ["las", "topper"] (overlap is "top", overlap length == 3)
|
| 156 |
+
- ["s", "to", "pped"] (overlap is "p", overlap length == 1)
|
| 157 |
+
|
| 158 |
+
It's impossible to construct a matching sequence that does not have this property (feel free to verify this
|
| 159 |
+
yourself). However, although this overlap between the start of the final token and the end of the stop string is
|
| 160 |
+
necessary for a match, it is not sufficient. We also need to check that the rest of the token sequence is
|
| 161 |
+
consistent with the stop string.
|
| 162 |
+
|
| 163 |
+
How do we do that? Let's use ["s", "to", "pped"] as an example. We know that the final token, "pped", has an
|
| 164 |
+
overlap of 1 with the stop string, "stop". We then go back to the previous token, "to". Since we have already
|
| 165 |
+
matched 1 character from the stop string, the remainder to check is "sto". We check that the next token "to"
|
| 166 |
+
matches the end of the remainder, which it does. We have now matched 3 characters from the stop string, and the
|
| 167 |
+
remainder to match is "s". We go back to the previous token again, which is also "s". This is a match, and so
|
| 168 |
+
we have matched the entire stop string.
|
| 169 |
+
|
| 170 |
+
How does it work when the tokens run off the start of the stop string, though? Let's consider the example of
|
| 171 |
+
["las", "topper"]. The final token, "topper", has an overlap of 3 with the stop string, "stop". Therefore,
|
| 172 |
+
the remaining stop string to match is "s". We go back to the previous token, "las". Because the remainder to
|
| 173 |
+
match is just "s", with length 1, we consider only the final 1 character from the token, which is "s". This
|
| 174 |
+
matches the stop string, and so the entire string is matched.
|
| 175 |
+
|
| 176 |
+
How do we compute these matches with tensor operations, though? Simply: we efficiently precompute the necessary
|
| 177 |
+
information for all tokens! For every token, we compute:
|
| 178 |
+
- Its overlap with the end of the stop string, if any
|
| 179 |
+
- The positions inside the stop string where the token matches, including matches that run off the start.
|
| 180 |
+
- The total length of the token
|
| 181 |
+
|
| 182 |
+
For example, for the token "pped", we would compute an end overlap of 1, no internal matching positions,
|
| 183 |
+
and a length of 4. For the token "to", we would compute no end overlap, a single internal matching position
|
| 184 |
+
of 1 (counting from the end), and a length of 2. For the token "s", we would compute no end overlap,
|
| 185 |
+
a single internal matching position of 3 (again counting from the end) and a length of 1.
|
| 186 |
+
|
| 187 |
+
As long as we have this information, we can execute the algorithm above without any string comparison
|
| 188 |
+
operations. We simply perform the following steps:
|
| 189 |
+
- Check if the final token has an end-overlap with the start string
|
| 190 |
+
- Continue backwards, keeping track of how much of the stop string we've matched so far
|
| 191 |
+
- At each point, check if the next token has the current position as one of its valid positions
|
| 192 |
+
- Continue until either a match fails, or we completely match the whole stop string
|
| 193 |
+
|
| 194 |
+
Again, consider ["s", "to", "pped"] as an example. "pped" has an end overlap of 1, so we can begin a match.
|
| 195 |
+
We have matched 1 character so far, so we check that the next token "to", has 1 as a valid position (again,
|
| 196 |
+
counting from the end). It does, so we add the length of "to" to our position tracker. We have now matched
|
| 197 |
+
3 characters, so we check that the next token "s" has 3 as a valid position. It does, so we add its length
|
| 198 |
+
to the position tracker. The position tracker is now 4, which is the length of the stop string. We have matched the
|
| 199 |
+
entire stop string.
|
| 200 |
+
|
| 201 |
+
In the second case, ["las", "topper"], "topper" has an end overlap of 3, so we can begin a match. We have
|
| 202 |
+
matched 3 characters so far, so we check that the next token "las" has 3 as a valid position. It does, because we
|
| 203 |
+
allow tokens to match positions that run off the start of the stop string. We add its length to the position
|
| 204 |
+
tracker. The position tracker is now 6, which is greater than the length of the stop string! Don't panic, though -
|
| 205 |
+
this also counts as a match of the stop string. We have matched the entire stop string.
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
Args:
|
| 209 |
+
tokenizer (`PreTrainedTokenizer`):
|
| 210 |
+
The model's associated tokenizer (necessary to extract vocab and tokenize the termination sequences)
|
| 211 |
+
stop_strings (`Union[str, List[str]]`):
|
| 212 |
+
A list of strings that should end generation. If a string is passed, it will be treated like a
|
| 213 |
+
list with a single element.
|
| 214 |
+
|
| 215 |
+
Examples:
|
| 216 |
+
|
| 217 |
+
```python
|
| 218 |
+
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 219 |
+
|
| 220 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
|
| 221 |
+
>>> model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")
|
| 222 |
+
>>> inputs = tokenizer("The biggest states in the USA by land area:", return_tensors="pt")
|
| 223 |
+
|
| 224 |
+
>>> gen_out = model.generate(**inputs)
|
| 225 |
+
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
|
| 226 |
+
The biggest states in the USA by land area:
|
| 227 |
+
- Alaska
|
| 228 |
+
- Texas
|
| 229 |
+
- California
|
| 230 |
+
|
| 231 |
+
>>> # Passing one or more stop strings will halt generation after those strings are emitted
|
| 232 |
+
>>> # Note that generating with stop strings requires you to pass the tokenizer too
|
| 233 |
+
>>> gen_out = model.generate(**inputs, stop_strings=["Texas"], tokenizer=tokenizer)
|
| 234 |
+
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
|
| 235 |
+
The biggest states in the USA by land area:
|
| 236 |
+
- Alaska
|
| 237 |
+
- Texas
|
| 238 |
+
```
|
| 239 |
+
"""
|
| 240 |
+
|
| 241 |
+
def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, List[str]]):
|
| 242 |
+
if isinstance(stop_strings, str):
|
| 243 |
+
stop_strings = [stop_strings]
|
| 244 |
+
self.stop_strings: Tuple[str, ...] = tuple(stop_strings)
|
| 245 |
+
vocab = tokenizer.get_vocab()
|
| 246 |
+
token_list, token_indices = tuple(vocab.keys()), tuple(vocab.values())
|
| 247 |
+
self.embedding_vec, self.max_valid_positions, self.max_valid_end_lens = self.clean_and_embed_tokens_with_cache(
|
| 248 |
+
token_list, token_indices, self.stop_strings, tokenizer
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
self.maximum_token_len = max([len(stop_string) for stop_string in self.stop_strings])
|
| 252 |
+
self.num_stop_strings = len(self.stop_strings)
|
| 253 |
+
self.target_lens = torch.tensor([len(stop_string) for stop_string in stop_strings], dtype=torch.int32)
|
| 254 |
+
|
| 255 |
+
def clean_and_embed_tokens_with_cache(self, token_list, token_indices, stop_strings, tokenizer):
|
| 256 |
+
# We don't use the tokenizer in the cache key, because I don't trust it to have well-behaved equality
|
| 257 |
+
if (token_list, token_indices, stop_strings) in STOP_STRING_EMBEDDING_CACHE:
|
| 258 |
+
embedding_vec, max_valid_positions, max_valid_end_lens = STOP_STRING_EMBEDDING_CACHE[
|
| 259 |
+
(token_list, token_indices, self.stop_strings)
|
| 260 |
+
]
|
| 261 |
+
STOP_STRING_EMBEDDING_CACHE.move_to_end((token_list, token_indices, stop_strings))
|
| 262 |
+
else:
|
| 263 |
+
clean_token_list, clean_token_indices = self.clean_tokenizer_vocab(tokenizer)
|
| 264 |
+
embedding_vec, max_valid_positions, max_valid_end_lens = self._stop_string_create_embedding_vec(
|
| 265 |
+
clean_token_list, clean_token_indices, stop_strings
|
| 266 |
+
)
|
| 267 |
+
STOP_STRING_EMBEDDING_CACHE[(token_list, token_indices, stop_strings)] = (
|
| 268 |
+
embedding_vec,
|
| 269 |
+
max_valid_positions,
|
| 270 |
+
max_valid_end_lens,
|
| 271 |
+
)
|
| 272 |
+
if len(STOP_STRING_EMBEDDING_CACHE) > 8:
|
| 273 |
+
STOP_STRING_EMBEDDING_CACHE.popitem(last=False) # Pop from the start, the least recently used item
|
| 274 |
+
return embedding_vec, max_valid_positions, max_valid_end_lens
|
| 275 |
+
|
| 276 |
+
@staticmethod
|
| 277 |
+
def clean_tokenizer_vocab(tokenizer, static_prefix="abcdef"):
|
| 278 |
+
"""
|
| 279 |
+
This method turns a tokenizer vocab into a "clean" vocab where each token represents the actual string
|
| 280 |
+
it will yield, without any special prefixes like "##" or "Ġ". This is trickier than it looks - the method
|
| 281 |
+
tokenizer.convert_tokens_to_string() does not always return the correct string because of issues with prefix
|
| 282 |
+
space addition/removal. To work around this, we add a static prefix to the start of the token, then remove
|
| 283 |
+
it (and any prefix that may have been introduced with it) after calling convert_tokens_to_string().
|
| 284 |
+
"""
|
| 285 |
+
vocab = tokenizer.get_vocab()
|
| 286 |
+
clean_token_list = []
|
| 287 |
+
clean_token_indices = []
|
| 288 |
+
sentence_base = tokenizer(static_prefix, add_special_tokens=False)["input_ids"]
|
| 289 |
+
tokens_base = [tokenizer._convert_id_to_token(tok) for tok in sentence_base]
|
| 290 |
+
for token, token_idx in vocab.items():
|
| 291 |
+
token_string = tokenizer.convert_tokens_to_string(tokens_base + [token])
|
| 292 |
+
token_string = token_string[token_string.index(static_prefix) + len(static_prefix) :]
|
| 293 |
+
clean_token_list.append(token_string)
|
| 294 |
+
clean_token_indices.append(token_idx)
|
| 295 |
+
return tuple(clean_token_list), tuple(clean_token_indices)
|
| 296 |
+
|
| 297 |
+
@staticmethod
|
| 298 |
+
def _stop_string_get_matching_positions(
|
| 299 |
+
token_list, token_indices, stop_strings
|
| 300 |
+
) -> Tuple[Dict[str, Dict[str, List[int]]], Dict[str, Dict[str, List[int]]]]:
|
| 301 |
+
"""This function preprocesses stop strings and the tokenizer vocabulary to determine where tokens can
|
| 302 |
+
validly appear in the stop strings. For each token, it computes a list of positions in the stop string where the
|
| 303 |
+
token appears, as well as a list of the possible "end overlaps" for that token - that is, the number of characters
|
| 304 |
+
from the end of the stop string that overlap with the start of the token, which can have more than one value.
|
| 305 |
+
|
| 306 |
+
The reason for computing these may seem a bit cryptic - please see the docstring for StopStringCriteria for a full
|
| 307 |
+
explanation of what these values are for!"""
|
| 308 |
+
|
| 309 |
+
token_valid_positions = {}
|
| 310 |
+
token_end_overlaps = {}
|
| 311 |
+
for stop_string in stop_strings:
|
| 312 |
+
reversed_stop_string = stop_string[::-1]
|
| 313 |
+
token_valid_positions[stop_string] = {}
|
| 314 |
+
token_end_overlaps[stop_string] = {}
|
| 315 |
+
for token, tok_idx in zip(token_list, token_indices):
|
| 316 |
+
reversed_token = token[::-1]
|
| 317 |
+
matching_positions = []
|
| 318 |
+
possible_end_lengths = []
|
| 319 |
+
for i in range(1 - len(token), len(stop_string)):
|
| 320 |
+
if i < 0:
|
| 321 |
+
tok = reversed_token[-i:]
|
| 322 |
+
i = 0
|
| 323 |
+
else:
|
| 324 |
+
tok = reversed_token
|
| 325 |
+
stop = reversed_stop_string[i : i + len(tok)]
|
| 326 |
+
if tok.startswith(stop):
|
| 327 |
+
if i == 0:
|
| 328 |
+
possible_end_lengths.append(min(len(tok), len(stop)))
|
| 329 |
+
else:
|
| 330 |
+
matching_positions.append(i)
|
| 331 |
+
|
| 332 |
+
if matching_positions:
|
| 333 |
+
token_valid_positions[stop_string][tok_idx] = matching_positions
|
| 334 |
+
if possible_end_lengths:
|
| 335 |
+
token_end_overlaps[stop_string][tok_idx] = possible_end_lengths
|
| 336 |
+
return token_valid_positions, token_end_overlaps
|
| 337 |
+
|
| 338 |
+
@staticmethod
|
| 339 |
+
def _stop_string_create_embedding_vec(token_list, token_indices, stop_strings) -> Dict[str, torch.tensor]:
|
| 340 |
+
"""This function precomputes everything needed for the run-time checks in StopStringCriteria, and packs
|
| 341 |
+
them into an embedding tensor that can be accessed with pure tensor operations. For the specifics of the values
|
| 342 |
+
that are precomputed and what they are used for, please refer to the StopStringCriteria docstring!"""
|
| 343 |
+
token_valid_positions, token_end_overlaps = StopStringCriteria._stop_string_get_matching_positions(
|
| 344 |
+
token_list, token_indices, stop_strings
|
| 345 |
+
)
|
| 346 |
+
all_valid_positions = [len(val) for positions in token_valid_positions.values() for val in positions.values()]
|
| 347 |
+
# In some cases, tokens may have no valid internal positions (such as single-character stop strings), so
|
| 348 |
+
# we need a fallback to handle this case
|
| 349 |
+
max_valid_positions = max(all_valid_positions) if all_valid_positions else 1
|
| 350 |
+
# There should always be at least one valid end_len, however, so no fallback needed here
|
| 351 |
+
valid_end_lens = [len(val) for positions in token_end_overlaps.values() for val in positions.values()]
|
| 352 |
+
if not valid_end_lens:
|
| 353 |
+
raise ValueError(
|
| 354 |
+
"Stop string preprocessing was unable to identify tokens matching one or more of the "
|
| 355 |
+
"supplied stop string(s). This is most often caused by the stop "
|
| 356 |
+
"strings containing unusual characters that are not in the tokenizer vocabulary."
|
| 357 |
+
)
|
| 358 |
+
max_valid_end_lens = max(valid_end_lens)
|
| 359 |
+
vec_size = len(stop_strings) * (max_valid_positions + max_valid_end_lens) + 1
|
| 360 |
+
gather_vec = np.full((len(token_list), vec_size), dtype=np.int32, fill_value=-1)
|
| 361 |
+
|
| 362 |
+
for i, stop_string in enumerate(stop_strings):
|
| 363 |
+
positions = token_valid_positions[stop_string]
|
| 364 |
+
end_lens = token_end_overlaps[stop_string]
|
| 365 |
+
|
| 366 |
+
# Since this is lots of very small assignments of lists, we build it with numpy rather
|
| 367 |
+
# than torch for speed + simplicity, then convert to torch at the end
|
| 368 |
+
for token_idx, valid_positions in positions.items():
|
| 369 |
+
gather_vec[token_idx, max_valid_positions * i : max_valid_positions * i + len(valid_positions)] = (
|
| 370 |
+
valid_positions
|
| 371 |
+
)
|
| 372 |
+
for token_idx, possible_end_lens in end_lens.items():
|
| 373 |
+
gather_vec[
|
| 374 |
+
token_idx,
|
| 375 |
+
max_valid_positions * len(stop_strings) + max_valid_end_lens * i : max_valid_positions
|
| 376 |
+
* len(stop_strings)
|
| 377 |
+
+ max_valid_end_lens * i
|
| 378 |
+
+ len(possible_end_lens),
|
| 379 |
+
] = possible_end_lens
|
| 380 |
+
for token, token_idx in zip(token_list, token_indices):
|
| 381 |
+
gather_vec[token_idx, -1] = len(token)
|
| 382 |
+
|
| 383 |
+
gather_vec = torch.tensor(gather_vec, dtype=torch.int32)
|
| 384 |
+
|
| 385 |
+
return gather_vec, max_valid_positions, max_valid_end_lens
|
| 386 |
+
|
| 387 |
+
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
| 388 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.Tensor:
|
| 389 |
+
self.embedding_vec = self.embedding_vec.to(input_ids.device)
|
| 390 |
+
self.target_lens = self.target_lens.to(input_ids.device)
|
| 391 |
+
# The maximum length we need to consider is 1 token per character. Note that input_ids can also be
|
| 392 |
+
# *shorter* than the global max, and the code below should be ready for that
|
| 393 |
+
input_ids = input_ids[:, -self.maximum_token_len :]
|
| 394 |
+
|
| 395 |
+
# Flip input_ids because we're only matching strings at the end of the generated sequence
|
| 396 |
+
flipped_ids = torch.flip(input_ids, (1,))
|
| 397 |
+
|
| 398 |
+
# Size of the vector of positions a single token can match
|
| 399 |
+
max_valid_positions = self.max_valid_positions
|
| 400 |
+
|
| 401 |
+
# The embedding vec contains the valid positions, end_lengths and total lengths for each token
|
| 402 |
+
embedded = F.embedding(flipped_ids, self.embedding_vec)
|
| 403 |
+
|
| 404 |
+
# Now we split the embedding vector. valid_positions is the positions in the stop string the token can fit
|
| 405 |
+
valid_positions = embedded[:, 1:, : max_valid_positions * self.num_stop_strings].unflatten(
|
| 406 |
+
-1, (self.num_stop_strings, -1)
|
| 407 |
+
)
|
| 408 |
+
# end_lengths is the number of characters from the string, counting from the end, that the token
|
| 409 |
+
# contains. It can have multiple values if the same token can overlap different end lengths
|
| 410 |
+
end_lengths = embedded[:, :1, max_valid_positions * self.num_stop_strings : -1].unflatten(
|
| 411 |
+
-1, (self.num_stop_strings, -1)
|
| 412 |
+
)
|
| 413 |
+
# Lengths is the total length of each token. Unlike the others, it always has a single value
|
| 414 |
+
lengths = embedded[:, 1:, None, -1:] # Insert a dummy dimension for stop_strings even though lengths are const
|
| 415 |
+
|
| 416 |
+
# Concatenate lengths onto each possible end_lengths value
|
| 417 |
+
lengths = lengths.expand((-1, -1, end_lengths.shape[-2], end_lengths.shape[-1]))
|
| 418 |
+
lengths_with_ends = torch.cat([end_lengths, lengths], dim=1)
|
| 419 |
+
|
| 420 |
+
# cumsum() to get the number of matched characters in the stop string after each token
|
| 421 |
+
cumsum = lengths_with_ends.cumsum(dim=1) # B x maximum_token_len x num_stop_strings x max_valid_end_lens
|
| 422 |
+
|
| 423 |
+
# The calculation above assumes that all tokens are in valid positions. Now we mask the ones that are not.
|
| 424 |
+
# First, tokens match the start of the string if they have a positive value in the end_lengths vector
|
| 425 |
+
initial_match = end_lengths > 0
|
| 426 |
+
|
| 427 |
+
# Tokens continue the string if the cumsum() so far is one of the valid positions for that token
|
| 428 |
+
# Note that we're actually tracking one cumsum() for for each possible end_length
|
| 429 |
+
later_match = torch.any(cumsum[:, :-1, :, None] == valid_positions[:, :, :, :, None], axis=-2)
|
| 430 |
+
|
| 431 |
+
# The match vector is a boolean vector that indicates which positions have valid tokens
|
| 432 |
+
match = torch.cat([initial_match, later_match], dim=1)
|
| 433 |
+
|
| 434 |
+
# Once a single position does not match, all positions following that position are masked
|
| 435 |
+
mask = (~match).cumsum(dim=1, dtype=torch.int32)
|
| 436 |
+
mask = mask == 0
|
| 437 |
+
|
| 438 |
+
# The string is matched if we reached a cumsum equal to or greater than the length of the string
|
| 439 |
+
# before hitting the mask
|
| 440 |
+
string_matches = torch.amax(cumsum * mask, dim=(1, -1)) >= self.target_lens[None, :]
|
| 441 |
+
|
| 442 |
+
# We return a per-sample vector that is True if any stop string is matched for that sample
|
| 443 |
+
return torch.any(string_matches, dim=-1)
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
class EosTokenCriteria(StoppingCriteria):
|
| 447 |
+
"""
|
| 448 |
+
This class can be used to stop generation whenever the "end-of-sequence" token is generated.
|
| 449 |
+
By default, it uses the `model.generation_config.eos_token_id`.
|
| 450 |
+
|
| 451 |
+
Args:
|
| 452 |
+
eos_token_id (`Union[int, List[int], torch.Tensor]`):
|
| 453 |
+
The id(s) of the *end-of-sequence* token.
|
| 454 |
+
"""
|
| 455 |
+
|
| 456 |
+
def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor]):
|
| 457 |
+
if not isinstance(eos_token_id, torch.Tensor):
|
| 458 |
+
if isinstance(eos_token_id, int):
|
| 459 |
+
eos_token_id = [eos_token_id]
|
| 460 |
+
eos_token_id = torch.tensor(eos_token_id)
|
| 461 |
+
self.eos_token_id = eos_token_id
|
| 462 |
+
|
| 463 |
+
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
| 464 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
| 465 |
+
self.eos_token_id = self.eos_token_id.to(input_ids.device)
|
| 466 |
+
is_done = isin_mps_friendly(input_ids[:, -1], self.eos_token_id)
|
| 467 |
+
return is_done
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
class ConfidenceCriteria(StoppingCriteria):
|
| 471 |
+
"""
|
| 472 |
+
This class can be used to stop generation whenever assistant model's confidence in its prediction for the current token is lower than the threshold
|
| 473 |
+
`model.generation_config.assistant_confidence_threshold` even if the number of speculative tokens (defined by `num_assistant_tokens`) is not yet reached.
|
| 474 |
+
|
| 475 |
+
Args:
|
| 476 |
+
assistant_confidence_threshold (`float`):
|
| 477 |
+
The value of the threshold.
|
| 478 |
+
"""
|
| 479 |
+
|
| 480 |
+
def __init__(self, assistant_confidence_threshold):
|
| 481 |
+
self.assistant_confidence_threshold = assistant_confidence_threshold
|
| 482 |
+
|
| 483 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
| 484 |
+
probs = scores[-1].softmax(-1)
|
| 485 |
+
p = probs[0, input_ids[0, -1]].item()
|
| 486 |
+
if p < self.assistant_confidence_threshold:
|
| 487 |
+
return True
|
| 488 |
+
return False
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
class StoppingCriteriaList(list):
|
| 492 |
+
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
| 493 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
| 494 |
+
is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device, dtype=torch.bool)
|
| 495 |
+
for criteria in self:
|
| 496 |
+
is_done = is_done | criteria(input_ids, scores, **kwargs)
|
| 497 |
+
return is_done
|
| 498 |
+
|
| 499 |
+
@property
|
| 500 |
+
def max_length(self) -> Optional[int]:
|
| 501 |
+
for stopping_criterium in self:
|
| 502 |
+
if isinstance(stopping_criterium, MaxLengthCriteria):
|
| 503 |
+
return stopping_criterium.max_length
|
| 504 |
+
return None
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int) -> StoppingCriteriaList:
|
| 508 |
+
stopping_max_length = stopping_criteria.max_length
|
| 509 |
+
new_stopping_criteria = deepcopy(stopping_criteria)
|
| 510 |
+
if stopping_max_length is not None and stopping_max_length != max_length:
|
| 511 |
+
warnings.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning)
|
| 512 |
+
elif stopping_max_length is None:
|
| 513 |
+
new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
|
| 514 |
+
return new_stopping_criteria
|
.venv/lib/python3.11/site-packages/transformers/generation/streamers.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import asyncio
|
| 19 |
+
from queue import Queue
|
| 20 |
+
from typing import TYPE_CHECKING, Optional
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
if TYPE_CHECKING:
|
| 24 |
+
from ..models.auto import AutoTokenizer
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class BaseStreamer:
|
| 28 |
+
"""
|
| 29 |
+
Base class from which `.generate()` streamers should inherit.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def put(self, value):
|
| 33 |
+
"""Function that is called by `.generate()` to push new tokens"""
|
| 34 |
+
raise NotImplementedError()
|
| 35 |
+
|
| 36 |
+
def end(self):
|
| 37 |
+
"""Function that is called by `.generate()` to signal the end of generation"""
|
| 38 |
+
raise NotImplementedError()
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class TextStreamer(BaseStreamer):
|
| 42 |
+
"""
|
| 43 |
+
Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.
|
| 44 |
+
|
| 45 |
+
<Tip warning={true}>
|
| 46 |
+
|
| 47 |
+
The API for the streamer classes is still under development and may change in the future.
|
| 48 |
+
|
| 49 |
+
</Tip>
|
| 50 |
+
|
| 51 |
+
Parameters:
|
| 52 |
+
tokenizer (`AutoTokenizer`):
|
| 53 |
+
The tokenized used to decode the tokens.
|
| 54 |
+
skip_prompt (`bool`, *optional*, defaults to `False`):
|
| 55 |
+
Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
|
| 56 |
+
decode_kwargs (`dict`, *optional*):
|
| 57 |
+
Additional keyword arguments to pass to the tokenizer's `decode` method.
|
| 58 |
+
|
| 59 |
+
Examples:
|
| 60 |
+
|
| 61 |
+
```python
|
| 62 |
+
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
|
| 63 |
+
|
| 64 |
+
>>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
| 65 |
+
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
| 66 |
+
>>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
|
| 67 |
+
>>> streamer = TextStreamer(tok)
|
| 68 |
+
|
| 69 |
+
>>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
|
| 70 |
+
>>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
|
| 71 |
+
An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
|
| 72 |
+
```
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
|
| 76 |
+
self.tokenizer = tokenizer
|
| 77 |
+
self.skip_prompt = skip_prompt
|
| 78 |
+
self.decode_kwargs = decode_kwargs
|
| 79 |
+
|
| 80 |
+
# variables used in the streaming process
|
| 81 |
+
self.token_cache = []
|
| 82 |
+
self.print_len = 0
|
| 83 |
+
self.next_tokens_are_prompt = True
|
| 84 |
+
|
| 85 |
+
def put(self, value):
|
| 86 |
+
"""
|
| 87 |
+
Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
|
| 88 |
+
"""
|
| 89 |
+
if len(value.shape) > 1 and value.shape[0] > 1:
|
| 90 |
+
raise ValueError("TextStreamer only supports batch size 1")
|
| 91 |
+
elif len(value.shape) > 1:
|
| 92 |
+
value = value[0]
|
| 93 |
+
|
| 94 |
+
if self.skip_prompt and self.next_tokens_are_prompt:
|
| 95 |
+
self.next_tokens_are_prompt = False
|
| 96 |
+
return
|
| 97 |
+
|
| 98 |
+
# Add the new token to the cache and decodes the entire thing.
|
| 99 |
+
self.token_cache.extend(value.tolist())
|
| 100 |
+
text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
|
| 101 |
+
|
| 102 |
+
# After the symbol for a new line, we flush the cache.
|
| 103 |
+
if text.endswith("\n"):
|
| 104 |
+
printable_text = text[self.print_len :]
|
| 105 |
+
self.token_cache = []
|
| 106 |
+
self.print_len = 0
|
| 107 |
+
# If the last token is a CJK character, we print the characters.
|
| 108 |
+
elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
|
| 109 |
+
printable_text = text[self.print_len :]
|
| 110 |
+
self.print_len += len(printable_text)
|
| 111 |
+
# Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
|
| 112 |
+
# which may change with the subsequent token -- there are probably smarter ways to do this!)
|
| 113 |
+
else:
|
| 114 |
+
printable_text = text[self.print_len : text.rfind(" ") + 1]
|
| 115 |
+
self.print_len += len(printable_text)
|
| 116 |
+
|
| 117 |
+
self.on_finalized_text(printable_text)
|
| 118 |
+
|
| 119 |
+
def end(self):
|
| 120 |
+
"""Flushes any remaining cache and prints a newline to stdout."""
|
| 121 |
+
# Flush the cache, if it exists
|
| 122 |
+
if len(self.token_cache) > 0:
|
| 123 |
+
text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
|
| 124 |
+
printable_text = text[self.print_len :]
|
| 125 |
+
self.token_cache = []
|
| 126 |
+
self.print_len = 0
|
| 127 |
+
else:
|
| 128 |
+
printable_text = ""
|
| 129 |
+
|
| 130 |
+
self.next_tokens_are_prompt = True
|
| 131 |
+
self.on_finalized_text(printable_text, stream_end=True)
|
| 132 |
+
|
| 133 |
+
def on_finalized_text(self, text: str, stream_end: bool = False):
|
| 134 |
+
"""Prints the new text to stdout. If the stream is ending, also prints a newline."""
|
| 135 |
+
print(text, flush=True, end="" if not stream_end else None)
|
| 136 |
+
|
| 137 |
+
def _is_chinese_char(self, cp):
|
| 138 |
+
"""Checks whether CP is the codepoint of a CJK character."""
|
| 139 |
+
# This defines a "chinese character" as anything in the CJK Unicode block:
|
| 140 |
+
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
| 141 |
+
#
|
| 142 |
+
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
|
| 143 |
+
# despite its name. The modern Korean Hangul alphabet is a different block,
|
| 144 |
+
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
| 145 |
+
# space-separated words, so they are not treated specially and handled
|
| 146 |
+
# like the all of the other languages.
|
| 147 |
+
if (
|
| 148 |
+
(cp >= 0x4E00 and cp <= 0x9FFF)
|
| 149 |
+
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
| 150 |
+
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
| 151 |
+
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
| 152 |
+
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
| 153 |
+
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
| 154 |
+
or (cp >= 0xF900 and cp <= 0xFAFF)
|
| 155 |
+
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
| 156 |
+
): #
|
| 157 |
+
return True
|
| 158 |
+
|
| 159 |
+
return False
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class TextIteratorStreamer(TextStreamer):
|
| 163 |
+
"""
|
| 164 |
+
Streamer that stores print-ready text in a queue, to be used by a downstream application as an iterator. This is
|
| 165 |
+
useful for applications that benefit from acessing the generated text in a non-blocking way (e.g. in an interactive
|
| 166 |
+
Gradio demo).
|
| 167 |
+
|
| 168 |
+
<Tip warning={true}>
|
| 169 |
+
|
| 170 |
+
The API for the streamer classes is still under development and may change in the future.
|
| 171 |
+
|
| 172 |
+
</Tip>
|
| 173 |
+
|
| 174 |
+
Parameters:
|
| 175 |
+
tokenizer (`AutoTokenizer`):
|
| 176 |
+
The tokenized used to decode the tokens.
|
| 177 |
+
skip_prompt (`bool`, *optional*, defaults to `False`):
|
| 178 |
+
Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
|
| 179 |
+
timeout (`float`, *optional*):
|
| 180 |
+
The timeout for the text queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
|
| 181 |
+
in `.generate()`, when it is called in a separate thread.
|
| 182 |
+
decode_kwargs (`dict`, *optional*):
|
| 183 |
+
Additional keyword arguments to pass to the tokenizer's `decode` method.
|
| 184 |
+
|
| 185 |
+
Examples:
|
| 186 |
+
|
| 187 |
+
```python
|
| 188 |
+
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
| 189 |
+
>>> from threading import Thread
|
| 190 |
+
|
| 191 |
+
>>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
| 192 |
+
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
| 193 |
+
>>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
|
| 194 |
+
>>> streamer = TextIteratorStreamer(tok)
|
| 195 |
+
|
| 196 |
+
>>> # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
|
| 197 |
+
>>> generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
|
| 198 |
+
>>> thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
| 199 |
+
>>> thread.start()
|
| 200 |
+
>>> generated_text = ""
|
| 201 |
+
>>> for new_text in streamer:
|
| 202 |
+
... generated_text += new_text
|
| 203 |
+
>>> generated_text
|
| 204 |
+
'An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,'
|
| 205 |
+
```
|
| 206 |
+
"""
|
| 207 |
+
|
| 208 |
+
def __init__(
|
| 209 |
+
self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, timeout: Optional[float] = None, **decode_kwargs
|
| 210 |
+
):
|
| 211 |
+
super().__init__(tokenizer, skip_prompt, **decode_kwargs)
|
| 212 |
+
self.text_queue = Queue()
|
| 213 |
+
self.stop_signal = None
|
| 214 |
+
self.timeout = timeout
|
| 215 |
+
|
| 216 |
+
def on_finalized_text(self, text: str, stream_end: bool = False):
|
| 217 |
+
"""Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
|
| 218 |
+
self.text_queue.put(text, timeout=self.timeout)
|
| 219 |
+
if stream_end:
|
| 220 |
+
self.text_queue.put(self.stop_signal, timeout=self.timeout)
|
| 221 |
+
|
| 222 |
+
def __iter__(self):
|
| 223 |
+
return self
|
| 224 |
+
|
| 225 |
+
def __next__(self):
|
| 226 |
+
value = self.text_queue.get(timeout=self.timeout)
|
| 227 |
+
if value == self.stop_signal:
|
| 228 |
+
raise StopIteration()
|
| 229 |
+
else:
|
| 230 |
+
return value
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class AsyncTextIteratorStreamer(TextStreamer):
|
| 234 |
+
"""
|
| 235 |
+
Streamer that stores print-ready text in a queue, to be used by a downstream application as an async iterator.
|
| 236 |
+
This is useful for applications that benefit from acessing the generated text asynchronously (e.g. in an
|
| 237 |
+
interactive Gradio demo).
|
| 238 |
+
|
| 239 |
+
<Tip warning={true}>
|
| 240 |
+
|
| 241 |
+
The API for the streamer classes is still under development and may change in the future.
|
| 242 |
+
|
| 243 |
+
</Tip>
|
| 244 |
+
|
| 245 |
+
Parameters:
|
| 246 |
+
tokenizer (`AutoTokenizer`):
|
| 247 |
+
The tokenized used to decode the tokens.
|
| 248 |
+
skip_prompt (`bool`, *optional*, defaults to `False`):
|
| 249 |
+
Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
|
| 250 |
+
timeout (`float`, *optional*):
|
| 251 |
+
The timeout for the text queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
|
| 252 |
+
in `.generate()`, when it is called in a separate thread.
|
| 253 |
+
decode_kwargs (`dict`, *optional*):
|
| 254 |
+
Additional keyword arguments to pass to the tokenizer's `decode` method.
|
| 255 |
+
|
| 256 |
+
Raises:
|
| 257 |
+
TimeoutError: If token generation time exceeds timeout value.
|
| 258 |
+
|
| 259 |
+
Examples:
|
| 260 |
+
|
| 261 |
+
```python
|
| 262 |
+
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, AsyncTextIteratorStreamer
|
| 263 |
+
>>> from threading import Thread
|
| 264 |
+
>>> import asyncio
|
| 265 |
+
|
| 266 |
+
>>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
| 267 |
+
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
| 268 |
+
>>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
|
| 269 |
+
|
| 270 |
+
>>> # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
|
| 271 |
+
>>> async def main():
|
| 272 |
+
... # Important: AsyncTextIteratorStreamer must be initialized inside a coroutine!
|
| 273 |
+
... streamer = AsyncTextIteratorStreamer(tok)
|
| 274 |
+
... generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
|
| 275 |
+
... thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
| 276 |
+
... thread.start()
|
| 277 |
+
... generated_text = ""
|
| 278 |
+
... async for new_text in streamer:
|
| 279 |
+
... generated_text += new_text
|
| 280 |
+
>>> print(generated_text)
|
| 281 |
+
>>> asyncio.run(main())
|
| 282 |
+
An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
|
| 283 |
+
```
|
| 284 |
+
"""
|
| 285 |
+
|
| 286 |
+
def __init__(
|
| 287 |
+
self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, timeout: float | None = None, **decode_kwargs
|
| 288 |
+
):
|
| 289 |
+
super().__init__(tokenizer, skip_prompt, **decode_kwargs)
|
| 290 |
+
self.text_queue = asyncio.Queue()
|
| 291 |
+
self.stop_signal = None
|
| 292 |
+
self.timeout = timeout
|
| 293 |
+
self.loop = asyncio.get_running_loop()
|
| 294 |
+
self.has_asyncio_timeout = hasattr(asyncio, "timeout")
|
| 295 |
+
|
| 296 |
+
def on_finalized_text(self, text: str, stream_end: bool = False):
|
| 297 |
+
"""Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
|
| 298 |
+
self.loop.call_soon_threadsafe(self.text_queue.put_nowait, text)
|
| 299 |
+
if stream_end:
|
| 300 |
+
self.loop.call_soon_threadsafe(self.text_queue.put_nowait, self.stop_signal)
|
| 301 |
+
|
| 302 |
+
def __aiter__(self):
|
| 303 |
+
return self
|
| 304 |
+
|
| 305 |
+
async def __anext__(self):
|
| 306 |
+
try:
|
| 307 |
+
if self.has_asyncio_timeout:
|
| 308 |
+
async with asyncio.timeout(self.timeout):
|
| 309 |
+
value = await self.text_queue.get()
|
| 310 |
+
else:
|
| 311 |
+
value = await asyncio.wait_for(self.text_queue.get(), timeout=self.timeout)
|
| 312 |
+
except asyncio.TimeoutError:
|
| 313 |
+
raise TimeoutError()
|
| 314 |
+
else:
|
| 315 |
+
if value == self.stop_signal:
|
| 316 |
+
raise StopAsyncIteration()
|
| 317 |
+
else:
|
| 318 |
+
return value
|
.venv/lib/python3.11/site-packages/transformers/generation/tf_logits_process.py
ADDED
|
@@ -0,0 +1,603 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The HuggingFace Inc. team
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
from typing import List, Tuple
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import tensorflow as tf
|
| 21 |
+
|
| 22 |
+
from ..tf_utils import stable_softmax
|
| 23 |
+
from ..utils import add_start_docstrings
|
| 24 |
+
from ..utils.logging import get_logger
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
logger = get_logger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
|
| 31 |
+
Args:
|
| 32 |
+
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
|
| 33 |
+
Indices of input sequence tokens in the vocabulary.
|
| 34 |
+
|
| 35 |
+
Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 36 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 37 |
+
|
| 38 |
+
[What are input IDs?](../glossary#input-ids)
|
| 39 |
+
scores (`tf.Tensor` of shape `(batch_size, config.vocab_size)`):
|
| 40 |
+
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
|
| 41 |
+
search or log softmax for each vocabulary token when using beam search.
|
| 42 |
+
cur_len (`int`):
|
| 43 |
+
The current length of valid input sequence tokens. In the TF implementation, the input_ids' sequence length
|
| 44 |
+
is the maximum length generate can produce, and we need to know which of its tokens are valid.
|
| 45 |
+
kwargs (`Dict[str, Any]`, *optional*):
|
| 46 |
+
Additional logits processor specific kwargs.
|
| 47 |
+
|
| 48 |
+
Return:
|
| 49 |
+
`tf.Tensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class TFLogitsProcessor:
|
| 54 |
+
"""Abstract base class for all logit processors that can be applied during generation."""
|
| 55 |
+
|
| 56 |
+
@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
| 57 |
+
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
| 58 |
+
"""TF method for processing logits."""
|
| 59 |
+
raise NotImplementedError(
|
| 60 |
+
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class TFLogitsWarper:
|
| 65 |
+
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
|
| 66 |
+
|
| 67 |
+
@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
| 68 |
+
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
| 69 |
+
"""TF method for warping logits."""
|
| 70 |
+
raise NotImplementedError(
|
| 71 |
+
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class TFLogitsProcessorList(list):
|
| 76 |
+
"""
|
| 77 |
+
This class can be used to create a list of [`TFLogitsProcessor`] to subsequently process a `scores` input tensor.
|
| 78 |
+
This class inherits from list and adds a specific *__call__* method to apply each [`TFLogitsProcessor`] to the
|
| 79 |
+
inputs.
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
| 83 |
+
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int, **kwargs) -> tf.Tensor:
|
| 84 |
+
for processor in self:
|
| 85 |
+
function_args = inspect.signature(processor.__call__).parameters
|
| 86 |
+
if len(function_args) > 3:
|
| 87 |
+
if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
|
| 88 |
+
raise ValueError(
|
| 89 |
+
f"Make sure that all the required parameters: {list(function_args.keys())} for "
|
| 90 |
+
f"{processor.__class__} are passed to the logits processor."
|
| 91 |
+
)
|
| 92 |
+
scores = processor(input_ids, scores, cur_len, **kwargs)
|
| 93 |
+
else:
|
| 94 |
+
scores = processor(input_ids, scores, cur_len)
|
| 95 |
+
return scores
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class TFTemperatureLogitsWarper(TFLogitsWarper):
|
| 99 |
+
r"""
|
| 100 |
+
[`TFLogitsWarper`] for temperature (exponential scaling output probability distribution).
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
temperature (`float`):
|
| 104 |
+
The value used to module the logits distribution.
|
| 105 |
+
"""
|
| 106 |
+
|
| 107 |
+
def __init__(self, temperature: float):
|
| 108 |
+
if not isinstance(temperature, float) or not (temperature > 0):
|
| 109 |
+
raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
|
| 110 |
+
|
| 111 |
+
self.temperature = temperature
|
| 112 |
+
|
| 113 |
+
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
| 114 |
+
scores = scores / self.temperature
|
| 115 |
+
return scores
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class TFTopKLogitsWarper(TFLogitsWarper):
|
| 119 |
+
r"""
|
| 120 |
+
[`TFLogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
|
| 121 |
+
|
| 122 |
+
Args:
|
| 123 |
+
top_k (`int`):
|
| 124 |
+
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
| 125 |
+
filter_value (`float`, *optional*, defaults to -inf):
|
| 126 |
+
All filtered values will be set to this float value.
|
| 127 |
+
min_tokens_to_keep (`int`, *optional*, defaults to 1):
|
| 128 |
+
Minimum number of tokens that cannot be filtered.
|
| 129 |
+
"""
|
| 130 |
+
|
| 131 |
+
def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
| 132 |
+
if not isinstance(top_k, int) or top_k <= 0:
|
| 133 |
+
raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
|
| 134 |
+
|
| 135 |
+
self.top_k = max(top_k, min_tokens_to_keep)
|
| 136 |
+
self.filter_value = filter_value
|
| 137 |
+
|
| 138 |
+
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
| 139 |
+
top_k = min(self.top_k, scores.shape[-1]) # Safety check
|
| 140 |
+
# Boolean mask containing all tokens with a probability less than the last token of the top-k
|
| 141 |
+
indices_to_remove = scores < tf.math.top_k(scores, k=top_k)[0][..., -1:]
|
| 142 |
+
next_scores = tf.where(indices_to_remove, self.filter_value, scores)
|
| 143 |
+
return next_scores
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class TFTopPLogitsWarper(TFLogitsWarper):
|
| 147 |
+
"""
|
| 148 |
+
[`TFLogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to <= prob_cut_off.
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
top_p (`float`):
|
| 152 |
+
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
| 153 |
+
higher are kept for generation.
|
| 154 |
+
filter_value (`float`, *optional*, defaults to -inf):
|
| 155 |
+
All filtered values will be set to this float value.
|
| 156 |
+
min_tokens_to_keep (`int`, *optional*, defaults to 1):
|
| 157 |
+
Minimum number of tokens that cannot be filtered.
|
| 158 |
+
"""
|
| 159 |
+
|
| 160 |
+
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
| 161 |
+
if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
|
| 162 |
+
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
|
| 163 |
+
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
|
| 164 |
+
raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")
|
| 165 |
+
|
| 166 |
+
self.top_p = top_p
|
| 167 |
+
self.filter_value = filter_value
|
| 168 |
+
self.min_tokens_to_keep = min_tokens_to_keep
|
| 169 |
+
|
| 170 |
+
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
| 171 |
+
topk_scores, topk_indices = tf.math.top_k(scores, scores.shape[-1])
|
| 172 |
+
|
| 173 |
+
mask_scores = tf.fill(scores.shape, self.filter_value)
|
| 174 |
+
cumulative_probs = tf.math.cumsum(stable_softmax(topk_scores, axis=-1), axis=-1)
|
| 175 |
+
score_mask = cumulative_probs < self.top_p
|
| 176 |
+
|
| 177 |
+
# Also include the token that is higher than top_p (the first false = shift and insert a True on the left)
|
| 178 |
+
score_mask = tf.concat((tf.ones([score_mask.shape[0], 1], dtype=tf.bool), score_mask[:, :-1]), axis=-1)
|
| 179 |
+
|
| 180 |
+
# Ensure min tokens to keep
|
| 181 |
+
score_mask = tf.concat(
|
| 182 |
+
(
|
| 183 |
+
tf.ones([score_mask.shape[0], self.min_tokens_to_keep], dtype=tf.bool),
|
| 184 |
+
score_mask[:, self.min_tokens_to_keep :],
|
| 185 |
+
),
|
| 186 |
+
axis=-1,
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
# Mask the values that do not fit the criteria
|
| 190 |
+
topk_next_scores = tf.where(score_mask, topk_scores, mask_scores)
|
| 191 |
+
|
| 192 |
+
# Undo the topk sorting: converts the 2D matrix of per-row original indices of shape (batch_size, vocab_size)
|
| 193 |
+
# to a 3D tensor of shape (batch_size, vocab_size, 2) containing the original score coordinate, from which we
|
| 194 |
+
# can scatter (i.e. `scatter_indices[row, col, :]` is a tensor containing `[row, topk_indices[row, col]]`)
|
| 195 |
+
scatter_rows = tf.tile(tf.expand_dims(tf.range(topk_indices.shape[0]), axis=-1), [1, topk_indices.shape[-1]])
|
| 196 |
+
scatter_indices = tf.stack((scatter_rows, topk_indices), axis=-1)
|
| 197 |
+
next_scores = tf.scatter_nd(scatter_indices, topk_next_scores, shape=topk_next_scores.shape)
|
| 198 |
+
|
| 199 |
+
return next_scores
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class TFMinLengthLogitsProcessor(TFLogitsProcessor):
|
| 203 |
+
r"""
|
| 204 |
+
[`TFLogitsProcessor`] enforcing a min-length by setting EOS probability to 0.
|
| 205 |
+
|
| 206 |
+
Args:
|
| 207 |
+
min_length (`int`):
|
| 208 |
+
The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
|
| 209 |
+
eos_token_id (`int`):
|
| 210 |
+
The id of the *end-of-sequence* token.
|
| 211 |
+
"""
|
| 212 |
+
|
| 213 |
+
def __init__(self, min_length: int, eos_token_id: int):
|
| 214 |
+
if not isinstance(min_length, int) or min_length < 0:
|
| 215 |
+
raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")
|
| 216 |
+
|
| 217 |
+
if not isinstance(eos_token_id, int) or eos_token_id < 0:
|
| 218 |
+
raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
|
| 219 |
+
|
| 220 |
+
self.min_length = min_length
|
| 221 |
+
self.eos_token_id = eos_token_id
|
| 222 |
+
|
| 223 |
+
def _apply_eos_token_mask(self, scores: tf.Tensor) -> tf.Tensor:
|
| 224 |
+
eos_token_id_mask = tf.range(scores.shape[-1]) == self.eos_token_id
|
| 225 |
+
scores = tf.where(eos_token_id_mask, float("-inf"), scores)
|
| 226 |
+
return scores
|
| 227 |
+
|
| 228 |
+
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
| 229 |
+
# applies eos token masking if the first argument is true
|
| 230 |
+
scores = tf.cond(
|
| 231 |
+
tf.less(cur_len, self.min_length),
|
| 232 |
+
lambda: self._apply_eos_token_mask(scores),
|
| 233 |
+
lambda: tf.identity(scores),
|
| 234 |
+
)
|
| 235 |
+
return scores
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor):
|
| 239 |
+
r"""
|
| 240 |
+
[`TFLogitsProcessor`] enforcing an exponential penalty on repeated sequences.
|
| 241 |
+
|
| 242 |
+
Args:
|
| 243 |
+
repetition_penalty (`float`):
|
| 244 |
+
The parameter for repetition penalty. 1.0 means no penalty. See [this
|
| 245 |
+
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
| 246 |
+
"""
|
| 247 |
+
|
| 248 |
+
def __init__(self, penalty: float):
|
| 249 |
+
if not isinstance(penalty, float) or not (penalty > 0):
|
| 250 |
+
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
|
| 251 |
+
|
| 252 |
+
self.penalty = penalty
|
| 253 |
+
|
| 254 |
+
def _create_score_penalties(self, input_ids: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
|
| 255 |
+
# We want to populate the penalties in the positions of `input_ids`. Since XLA can't handle shapes unknown
|
| 256 |
+
# before runtime, `tf.unique` can't be used. Therefore, we may have redundant updates, when a given row has
|
| 257 |
+
# the same token multiple times.
|
| 258 |
+
|
| 259 |
+
# Gathers the penalties to apply
|
| 260 |
+
logit_penalties = tf.gather(logits, input_ids, axis=1, batch_dims=1)
|
| 261 |
+
logit_penalties = tf.where(logit_penalties > 0, 1 / self.penalty, logit_penalties)
|
| 262 |
+
logit_penalties = tf.where(logit_penalties < 0, self.penalty, logit_penalties)
|
| 263 |
+
|
| 264 |
+
# Scatters the penalties
|
| 265 |
+
token_penalties = tf.ones(logits.shape)
|
| 266 |
+
batch_size = input_ids.shape[0]
|
| 267 |
+
seq_len = tf.shape(input_ids)[1] # the sequence length has dynamic size, hence the dynamic shape
|
| 268 |
+
indexable_prev_input_ids = tf.concat(
|
| 269 |
+
(
|
| 270 |
+
tf.expand_dims(tf.repeat(tf.range(batch_size), seq_len), axis=-1),
|
| 271 |
+
tf.expand_dims(tf.reshape(input_ids, [-1]), axis=-1),
|
| 272 |
+
),
|
| 273 |
+
axis=1,
|
| 274 |
+
)
|
| 275 |
+
token_penalties = tf.tensor_scatter_nd_update(
|
| 276 |
+
token_penalties, indices=indexable_prev_input_ids, updates=tf.reshape(logit_penalties, [-1])
|
| 277 |
+
)
|
| 278 |
+
return token_penalties
|
| 279 |
+
|
| 280 |
+
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
| 281 |
+
score_penalties = self._create_score_penalties(input_ids[:, :cur_len], scores)
|
| 282 |
+
|
| 283 |
+
scores = tf.math.multiply(scores, score_penalties)
|
| 284 |
+
|
| 285 |
+
return scores
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
class TFNoBadWordsLogitsProcessor(TFLogitsProcessor):
|
| 289 |
+
"""
|
| 290 |
+
[`TFLogitsProcessor`] that enforces that specified sequences will never be sampled.
|
| 291 |
+
|
| 292 |
+
Args:
|
| 293 |
+
bad_words_ids (`List[List[int]]`):
|
| 294 |
+
List of list of token ids that are not allowed to be generated. In order to get the tokens of the words
|
| 295 |
+
that should not appear in the generated text, make sure to set `add_prefix_space=True` when initializing
|
| 296 |
+
the tokenizer, and use `tokenizer(bad_words, add_special_tokens=False).input_ids`. The `add_prefix_space`
|
| 297 |
+
argument is only supported for some slow tokenizers, as fast tokenizers' prefixing behaviours come from
|
| 298 |
+
`pre tokenizers`. Read more [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).
|
| 299 |
+
eos_token_id (`int`):
|
| 300 |
+
The id of the *end-of-sequence* token.
|
| 301 |
+
"""
|
| 302 |
+
|
| 303 |
+
def __init__(self, bad_words_ids: List[List[int]], eos_token_id: int):
|
| 304 |
+
if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0:
|
| 305 |
+
raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
|
| 306 |
+
if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
|
| 307 |
+
raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.")
|
| 308 |
+
if any(
|
| 309 |
+
any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids)
|
| 310 |
+
for bad_word_ids in bad_words_ids
|
| 311 |
+
):
|
| 312 |
+
raise ValueError(
|
| 313 |
+
f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
# stores the information about bad words in three tensors:
|
| 317 |
+
# 1. a rectangular tensor with the forbidden sequences (padded with `-1`), for full data comparisons
|
| 318 |
+
self.bad_word_seqs_ids = tf.ragged.constant(bad_words_ids).to_tensor(default_value=-1)
|
| 319 |
+
# 2. a tensor with the unpadded length of each forbidden sequence, for quick length comparisons
|
| 320 |
+
bad_word_seqs_len = [len(bad_words) for bad_words in bad_words_ids]
|
| 321 |
+
if any(word_len == 0 for word_len in bad_word_seqs_len):
|
| 322 |
+
raise ValueError(f"Banned words token sequences {bad_words_ids} cannot have an empty list")
|
| 323 |
+
self.bad_word_seqs_len = tf.convert_to_tensor(bad_word_seqs_len, dtype=tf.int32)
|
| 324 |
+
# 3. a tensor containing the last token for each sequence, for easy access to the tokens that may be banned
|
| 325 |
+
self.seq_forbidden_tokens = tf.convert_to_tensor([bad_words[-1] for bad_words in bad_words_ids])
|
| 326 |
+
|
| 327 |
+
def _calc_row_banned_bad_tokens(self, row_input_ids: tf.Tensor) -> tf.Tensor:
|
| 328 |
+
def _tokens_match(bad_word_seq_number):
|
| 329 |
+
def _len_one():
|
| 330 |
+
# If the bad sequence only has one token, always mask it
|
| 331 |
+
return tf.cond(
|
| 332 |
+
tf.math.equal(self.bad_word_seqs_len[bad_word_seq_number], 1),
|
| 333 |
+
lambda: tf.ones((), dtype=tf.bool),
|
| 334 |
+
_len_greater_than_cur_len,
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
def _len_greater_than_cur_len():
|
| 338 |
+
# Otherwise, if the bad sequence is longer than the current length they can't ever match
|
| 339 |
+
return tf.cond(
|
| 340 |
+
tf.math.greater(self.bad_word_seqs_len[bad_word_seq_number], tf.shape(row_input_ids)[0]),
|
| 341 |
+
lambda: tf.zeros((), dtype=tf.bool),
|
| 342 |
+
_match_found,
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
def _match_found():
|
| 346 |
+
# Finaly, runs the actual comparison. Can only be called if the previous comparisons do not yield
|
| 347 |
+
# an answer (otherwise we get indexing exceptions)
|
| 348 |
+
compare_len = self.bad_word_seqs_len[bad_word_seq_number] - 1
|
| 349 |
+
return tf.cond(
|
| 350 |
+
tf.math.reduce_all(
|
| 351 |
+
tf.math.equal(
|
| 352 |
+
row_input_ids[-compare_len:], self.bad_word_seqs_ids[bad_word_seq_number, :compare_len]
|
| 353 |
+
)
|
| 354 |
+
),
|
| 355 |
+
lambda: tf.ones((), dtype=tf.bool),
|
| 356 |
+
lambda: tf.zeros((), dtype=tf.bool),
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
match = _len_one()
|
| 360 |
+
return match
|
| 361 |
+
|
| 362 |
+
# Compares the current row against all bad word sequences, obtaining a mask with the matches.
|
| 363 |
+
match_mask = tf.map_fn(_tokens_match, tf.range(self.bad_word_seqs_ids.shape[0]), fn_output_signature=tf.bool)
|
| 364 |
+
row_banned_tokens = self.seq_forbidden_tokens[match_mask]
|
| 365 |
+
return row_banned_tokens
|
| 366 |
+
|
| 367 |
+
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
| 368 |
+
# We want to mask some banned tokens, at a score level. Since the banned tokens depend on the previous
|
| 369 |
+
# `input_ids`, they may have a different length for each row, and they may even be empty for some rows.
|
| 370 |
+
# To remain simple and XLA-compatible, we work on a per-row fashion.
|
| 371 |
+
# TODO (Joao): this function might trigger XLA retracing as `cur_len` increases. Fix it if it becomes
|
| 372 |
+
# a frequent choke point. (make `cur_len` a tensor?)
|
| 373 |
+
def _get_row_updated_score(row_inputs: Tuple[tf.Tensor]) -> tf.Tensor:
|
| 374 |
+
row_input_ids, row_score = row_inputs
|
| 375 |
+
banned_tokens = self._calc_row_banned_bad_tokens(row_input_ids[:cur_len])
|
| 376 |
+
banned_tokens_mask = tf.scatter_nd(
|
| 377 |
+
indices=tf.expand_dims(banned_tokens, axis=-1),
|
| 378 |
+
updates=tf.ones_like(banned_tokens, dtype=tf.bool),
|
| 379 |
+
shape=row_score.shape,
|
| 380 |
+
)
|
| 381 |
+
row_score = tf.where(banned_tokens_mask, -float("inf"), row_score)
|
| 382 |
+
return row_score
|
| 383 |
+
|
| 384 |
+
scores = tf.map_fn(_get_row_updated_score, (input_ids, scores), fn_output_signature=tf.float32)
|
| 385 |
+
return scores
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor):
|
| 389 |
+
r"""
|
| 390 |
+
[`TFLogitsProcessor`] that enforces no repetition of n-grams. See
|
| 391 |
+
[Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).
|
| 392 |
+
|
| 393 |
+
Args:
|
| 394 |
+
ngram_size (`int`):
|
| 395 |
+
All ngrams of size `ngram_size` can only occur once.
|
| 396 |
+
"""
|
| 397 |
+
|
| 398 |
+
def __init__(self, ngram_size: int):
|
| 399 |
+
if not isinstance(ngram_size, int) or ngram_size <= 0:
|
| 400 |
+
raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
|
| 401 |
+
self.ngram_size = ngram_size
|
| 402 |
+
|
| 403 |
+
def calc_banned_ngram_tokens(self, input_ids, num_hypos, cur_len):
|
| 404 |
+
# Copied from fairseq for no_repeat_ngram in beam_search
|
| 405 |
+
if cur_len + 1 < self.ngram_size:
|
| 406 |
+
# return no banned tokens if we haven't generated ngram_size tokens yet
|
| 407 |
+
return [[] for _ in range(num_hypos)]
|
| 408 |
+
generated_ngrams = [{} for _ in range(num_hypos)]
|
| 409 |
+
prev_input_ids = input_ids[:, :cur_len]
|
| 410 |
+
for idx in range(num_hypos):
|
| 411 |
+
gen_tokens = prev_input_ids[idx].numpy().tolist()
|
| 412 |
+
generated_ngram = generated_ngrams[idx]
|
| 413 |
+
for ngram in zip(*[gen_tokens[i:] for i in range(self.ngram_size)]):
|
| 414 |
+
prev_ngram_tuple = tuple(ngram[:-1])
|
| 415 |
+
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
|
| 416 |
+
|
| 417 |
+
def _get_generated_ngrams(hypo_idx):
|
| 418 |
+
# Before decoding the next token, prevent decoding of ngrams that have already appeared
|
| 419 |
+
start_idx = cur_len + 1 - self.ngram_size
|
| 420 |
+
ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].numpy().tolist())
|
| 421 |
+
return generated_ngrams[hypo_idx].get(ngram_idx, [])
|
| 422 |
+
|
| 423 |
+
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
|
| 424 |
+
|
| 425 |
+
return banned_tokens
|
| 426 |
+
|
| 427 |
+
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
| 428 |
+
# TODO (joao): enable XLA on this logits processor. See discussion and attempts in
|
| 429 |
+
# https://github.com/huggingface/transformers/pull/16974
|
| 430 |
+
if not tf.executing_eagerly():
|
| 431 |
+
raise NotImplementedError("TFNoRepeatNGramLogitsProcessor is only implemented for eager execution.")
|
| 432 |
+
|
| 433 |
+
batch_size, vocab_size = scores.shape
|
| 434 |
+
banned_tokens = self.calc_banned_ngram_tokens(input_ids, batch_size, cur_len)
|
| 435 |
+
|
| 436 |
+
# create banned_tokens boolean mask
|
| 437 |
+
banned_tokens_indices_mask = []
|
| 438 |
+
for banned_tokens_slice in banned_tokens:
|
| 439 |
+
banned_tokens_indices_mask.append(
|
| 440 |
+
[True if token in banned_tokens_slice else False for token in range(vocab_size)]
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
scores = tf.where(tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores)
|
| 444 |
+
|
| 445 |
+
return scores
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
class TFForcedBOSTokenLogitsProcessor(TFLogitsProcessor):
|
| 449 |
+
r"""
|
| 450 |
+
[`TFLogitsProcessor`] that enforces the specified token as the first generated token.
|
| 451 |
+
|
| 452 |
+
Args:
|
| 453 |
+
bos_token_id (`int`):
|
| 454 |
+
The id of the token to force as the first generated token.
|
| 455 |
+
"""
|
| 456 |
+
|
| 457 |
+
def __init__(self, bos_token_id: int):
|
| 458 |
+
if bos_token_id < 0:
|
| 459 |
+
raise ValueError(f"The forced bos token id must be a non-negative integer, got {bos_token_id}")
|
| 460 |
+
self.bos_token_id = bos_token_id
|
| 461 |
+
|
| 462 |
+
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
| 463 |
+
if cur_len == 1:
|
| 464 |
+
batch_size, num_tokens = scores.shape
|
| 465 |
+
# sets the score to 0 in the bos_token_id column
|
| 466 |
+
scores = tf.zeros((batch_size, 1))
|
| 467 |
+
# sets the score to -inf everywhere else
|
| 468 |
+
if self.bos_token_id > 0:
|
| 469 |
+
scores = tf.concat((tf.broadcast_to(-float("inf"), (batch_size, self.bos_token_id)), scores), axis=-1)
|
| 470 |
+
if self.bos_token_id < (num_tokens - 1):
|
| 471 |
+
scores = tf.concat(
|
| 472 |
+
(scores, tf.broadcast_to(-float("inf"), (batch_size, (num_tokens - 1) - self.bos_token_id))),
|
| 473 |
+
axis=-1,
|
| 474 |
+
)
|
| 475 |
+
return scores
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
class TFForcedEOSTokenLogitsProcessor(TFLogitsProcessor):
|
| 479 |
+
r"""
|
| 480 |
+
[`TFLogitsProcessor`] that enforces the specified token as the last generated token when `max_length` is reached.
|
| 481 |
+
|
| 482 |
+
Args:
|
| 483 |
+
max_length (`int`):
|
| 484 |
+
The maximum length of the sequence to be generated.
|
| 485 |
+
eos_token_id (`int`):
|
| 486 |
+
The id of the token to force as the last generated token when `max_length` is reached.
|
| 487 |
+
"""
|
| 488 |
+
|
| 489 |
+
def __init__(self, max_length: int, eos_token_id: int):
|
| 490 |
+
self.max_length = max_length
|
| 491 |
+
if eos_token_id < 0:
|
| 492 |
+
raise ValueError(f"The forced eos token id must be a non-negative integer, got {eos_token_id}")
|
| 493 |
+
self.eos_token_id = eos_token_id
|
| 494 |
+
|
| 495 |
+
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
| 496 |
+
if cur_len == self.max_length - 1:
|
| 497 |
+
batch_size, num_tokens = scores.shape
|
| 498 |
+
# sets the score to 0 in the eos_token_id column
|
| 499 |
+
scores = tf.zeros((batch_size, 1))
|
| 500 |
+
# sets the score to -inf everywhere else
|
| 501 |
+
if self.eos_token_id > 0:
|
| 502 |
+
scores = tf.concat((tf.broadcast_to(-float("inf"), (batch_size, self.eos_token_id)), scores), axis=-1)
|
| 503 |
+
if self.eos_token_id < (num_tokens - 1):
|
| 504 |
+
scores = tf.concat(
|
| 505 |
+
(scores, tf.broadcast_to(-float("inf"), (batch_size, (num_tokens - 1) - self.eos_token_id))),
|
| 506 |
+
axis=-1,
|
| 507 |
+
)
|
| 508 |
+
return scores
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
class TFSuppressTokensAtBeginLogitsProcessor(TFLogitsProcessor):
|
| 512 |
+
r"""
|
| 513 |
+
[`TFSuppressTokensAtBeginLogitsProcessor`] suppresses a list of tokens as soon as the `generate` function starts
|
| 514 |
+
generating using `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` at not
|
| 515 |
+
sampled at the beginning of the generation.
|
| 516 |
+
"""
|
| 517 |
+
|
| 518 |
+
def __init__(self, begin_suppress_tokens, begin_index):
|
| 519 |
+
self.begin_suppress_tokens = list(begin_suppress_tokens)
|
| 520 |
+
self.begin_index = begin_index
|
| 521 |
+
|
| 522 |
+
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
| 523 |
+
suppressed_indices = []
|
| 524 |
+
for token in self.begin_suppress_tokens:
|
| 525 |
+
if token < scores.shape[-1]: # to ensure we don't go beyond the vocab size
|
| 526 |
+
suppressed_indices.extend([[i, token] for i in range(scores.shape[0])])
|
| 527 |
+
|
| 528 |
+
if len(suppressed_indices) > 0:
|
| 529 |
+
scores = tf.cond(
|
| 530 |
+
tf.equal(cur_len, self.begin_index),
|
| 531 |
+
lambda: tf.tensor_scatter_nd_update(
|
| 532 |
+
scores,
|
| 533 |
+
indices=suppressed_indices,
|
| 534 |
+
updates=[-float("inf") for _ in range(scores.shape[0] * len(self.begin_suppress_tokens))],
|
| 535 |
+
),
|
| 536 |
+
lambda: scores,
|
| 537 |
+
)
|
| 538 |
+
return scores
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
class TFSuppressTokensLogitsProcessor(TFLogitsProcessor):
|
| 542 |
+
r"""This processor can be used to suppress a list of tokens. The processor will set their log probs to `-inf` so that they
|
| 543 |
+
are not sampled."""
|
| 544 |
+
|
| 545 |
+
def __init__(self, suppress_tokens):
|
| 546 |
+
self.suppress_tokens = list(suppress_tokens)
|
| 547 |
+
|
| 548 |
+
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
| 549 |
+
suppressed_indices = []
|
| 550 |
+
for token in self.suppress_tokens:
|
| 551 |
+
if token < scores.shape[-1]: # to ensure we don't go beyond the vocab size
|
| 552 |
+
suppressed_indices.extend([[i, token] for i in range(scores.shape[0])])
|
| 553 |
+
|
| 554 |
+
if len(suppressed_indices) > 0:
|
| 555 |
+
scores = tf.tensor_scatter_nd_update(
|
| 556 |
+
scores,
|
| 557 |
+
indices=[[i, token] for i in range(scores.shape[0]) for token in self.suppress_tokens],
|
| 558 |
+
updates=[-float("inf") for _ in range(scores.shape[0] * len(self.suppress_tokens))],
|
| 559 |
+
)
|
| 560 |
+
return scores
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
class TFForceTokensLogitsProcessor(TFLogitsProcessor):
|
| 564 |
+
r"""This processor takes a list of pairs of integers which indicates a mapping from generation indices to token
|
| 565 |
+
indices that will be forced before sampling. The processor will set their log probs to `0` and all other tokens to
|
| 566 |
+
`-inf` so that they are sampled at their corresponding index."""
|
| 567 |
+
|
| 568 |
+
def __init__(self, force_token_map: List[List[int]]):
|
| 569 |
+
force_token_map = dict(force_token_map)
|
| 570 |
+
# Converts the dictionary of format {index: token} containing the tokens to be forced to an array, where the
|
| 571 |
+
# index of the array corresponds to the index of the token to be forced, for XLA compatibility.
|
| 572 |
+
# Indexes without forced tokens will have an negative value.
|
| 573 |
+
force_token_array = np.ones((max(force_token_map.keys()) + 1), dtype=np.int32) * -1
|
| 574 |
+
for index, token in force_token_map.items():
|
| 575 |
+
if token is not None:
|
| 576 |
+
force_token_array[index] = token
|
| 577 |
+
self.force_token_array = tf.convert_to_tensor(force_token_array, dtype=tf.int32)
|
| 578 |
+
|
| 579 |
+
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
| 580 |
+
def _force_token(generation_idx):
|
| 581 |
+
batch_size = scores.shape[0]
|
| 582 |
+
current_token = self.force_token_array[generation_idx]
|
| 583 |
+
|
| 584 |
+
new_scores = tf.zeros_like(scores, dtype=scores.dtype) + tf.constant([scores.dtype.min])
|
| 585 |
+
indices = tf.stack((tf.range(batch_size), tf.tile([current_token], [batch_size])), axis=1)
|
| 586 |
+
updates = tf.zeros((batch_size,), dtype=scores.dtype)
|
| 587 |
+
new_scores = tf.tensor_scatter_nd_update(new_scores, indices, updates)
|
| 588 |
+
return new_scores
|
| 589 |
+
|
| 590 |
+
scores = tf.cond(
|
| 591 |
+
tf.greater_equal(cur_len, tf.shape(self.force_token_array)[0]),
|
| 592 |
+
# If the current length is geq than the length of force_token_array, the processor does nothing.
|
| 593 |
+
lambda: tf.identity(scores),
|
| 594 |
+
# Otherwise, it may force a certain token.
|
| 595 |
+
lambda: tf.cond(
|
| 596 |
+
tf.greater_equal(self.force_token_array[cur_len], 0),
|
| 597 |
+
# Only valid (positive) tokens are forced
|
| 598 |
+
lambda: _force_token(cur_len),
|
| 599 |
+
# Otherwise, the processor does nothing.
|
| 600 |
+
lambda: scores,
|
| 601 |
+
),
|
| 602 |
+
)
|
| 603 |
+
return scores
|
.venv/lib/python3.11/site-packages/transformers/generation/tf_utils.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/transformers/generation/utils.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/transformers/generation/watermarking.py
ADDED
|
@@ -0,0 +1,549 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The HuggingFace Inc. team and Google DeepMind.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import collections
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from functools import lru_cache
|
| 19 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
from torch import nn
|
| 24 |
+
from torch.nn import BCELoss
|
| 25 |
+
|
| 26 |
+
from ..modeling_utils import PreTrainedModel
|
| 27 |
+
from ..utils import ModelOutput, is_torch_available, logging
|
| 28 |
+
from .configuration_utils import PretrainedConfig, WatermarkingConfig
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
if is_torch_available():
|
| 32 |
+
import torch
|
| 33 |
+
|
| 34 |
+
from .logits_process import SynthIDTextWatermarkLogitsProcessor, WatermarkLogitsProcessor
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
logger = logging.get_logger(__name__)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@dataclass
|
| 41 |
+
class WatermarkDetectorOutput:
|
| 42 |
+
"""
|
| 43 |
+
Outputs of a watermark detector.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
num_tokens_scored (np.array of shape (batch_size)):
|
| 47 |
+
Array containing the number of tokens scored for each element in the batch.
|
| 48 |
+
num_green_tokens (np.array of shape (batch_size)):
|
| 49 |
+
Array containing the number of green tokens for each element in the batch.
|
| 50 |
+
green_fraction (np.array of shape (batch_size)):
|
| 51 |
+
Array containing the fraction of green tokens for each element in the batch.
|
| 52 |
+
z_score (np.array of shape (batch_size)):
|
| 53 |
+
Array containing the z-score for each element in the batch. Z-score here shows
|
| 54 |
+
how many standard deviations away is the green token count in the input text
|
| 55 |
+
from the expected green token count for machine-generated text.
|
| 56 |
+
p_value (np.array of shape (batch_size)):
|
| 57 |
+
Array containing the p-value for each batch obtained from z-scores.
|
| 58 |
+
prediction (np.array of shape (batch_size)), *optional*:
|
| 59 |
+
Array containing boolean predictions whether a text is machine-generated for each element in the batch.
|
| 60 |
+
confidence (np.array of shape (batch_size)), *optional*:
|
| 61 |
+
Array containing confidence scores of a text being machine-generated for each element in the batch.
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
num_tokens_scored: np.array = None
|
| 65 |
+
num_green_tokens: np.array = None
|
| 66 |
+
green_fraction: np.array = None
|
| 67 |
+
z_score: np.array = None
|
| 68 |
+
p_value: np.array = None
|
| 69 |
+
prediction: Optional[np.array] = None
|
| 70 |
+
confidence: Optional[np.array] = None
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class WatermarkDetector:
|
| 74 |
+
r"""
|
| 75 |
+
Detector for detection of watermark generated text. The detector needs to be given the exact same settings that were
|
| 76 |
+
given during text generation to replicate the watermark greenlist generation and so detect the watermark. This includes
|
| 77 |
+
the correct device that was used during text generation, the correct watermarking arguments and the correct tokenizer vocab size.
|
| 78 |
+
The code was based on the [original repo](https://github.com/jwkirchenbauer/lm-watermarking/tree/main).
|
| 79 |
+
|
| 80 |
+
See [the paper](https://arxiv.org/abs/2306.04634) for more information.
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
model_config (`PretrainedConfig`):
|
| 84 |
+
The model config that will be used to get model specific arguments used when generating.
|
| 85 |
+
device (`str`):
|
| 86 |
+
The device which was used during watermarked text generation.
|
| 87 |
+
watermarking_config (Union[`WatermarkingConfig`, `Dict`]):
|
| 88 |
+
The exact same watermarking config and arguments used when generating text.
|
| 89 |
+
ignore_repeated_ngrams (`bool`, *optional*, defaults to `False`):
|
| 90 |
+
Whether to count every unique ngram only once or not.
|
| 91 |
+
max_cache_size (`int`, *optional*, defaults to 128):
|
| 92 |
+
The max size to be used for LRU caching of seeding/sampling algorithms called for every token.
|
| 93 |
+
|
| 94 |
+
Examples:
|
| 95 |
+
|
| 96 |
+
```python
|
| 97 |
+
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, WatermarkDetector, WatermarkingConfig
|
| 98 |
+
|
| 99 |
+
>>> model_id = "openai-community/gpt2"
|
| 100 |
+
>>> model = AutoModelForCausalLM.from_pretrained(model_id)
|
| 101 |
+
>>> tok = AutoTokenizer.from_pretrained(model_id)
|
| 102 |
+
>>> tok.pad_token_id = tok.eos_token_id
|
| 103 |
+
>>> tok.padding_side = "left"
|
| 104 |
+
|
| 105 |
+
>>> inputs = tok(["This is the beginning of a long story", "Alice and Bob are"], padding=True, return_tensors="pt")
|
| 106 |
+
>>> input_len = inputs["input_ids"].shape[-1]
|
| 107 |
+
|
| 108 |
+
>>> # first generate text with watermark and without
|
| 109 |
+
>>> watermarking_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash")
|
| 110 |
+
>>> out_watermarked = model.generate(**inputs, watermarking_config=watermarking_config, do_sample=False, max_length=20)
|
| 111 |
+
>>> out = model.generate(**inputs, do_sample=False, max_length=20)
|
| 112 |
+
|
| 113 |
+
>>> # now we can instantiate the detector and check the generated text
|
| 114 |
+
>>> detector = WatermarkDetector(model_config=model.config, device="cpu", watermarking_config=watermarking_config)
|
| 115 |
+
>>> detection_out_watermarked = detector(out_watermarked, return_dict=True)
|
| 116 |
+
>>> detection_out = detector(out, return_dict=True)
|
| 117 |
+
>>> detection_out_watermarked.prediction
|
| 118 |
+
array([ True, True])
|
| 119 |
+
|
| 120 |
+
>>> detection_out.prediction
|
| 121 |
+
array([False, False])
|
| 122 |
+
```
|
| 123 |
+
"""
|
| 124 |
+
|
| 125 |
+
def __init__(
|
| 126 |
+
self,
|
| 127 |
+
model_config: PretrainedConfig,
|
| 128 |
+
device: str,
|
| 129 |
+
watermarking_config: Union[WatermarkingConfig, Dict],
|
| 130 |
+
ignore_repeated_ngrams: bool = False,
|
| 131 |
+
max_cache_size: int = 128,
|
| 132 |
+
):
|
| 133 |
+
if isinstance(watermarking_config, WatermarkingConfig):
|
| 134 |
+
watermarking_config = watermarking_config.to_dict()
|
| 135 |
+
|
| 136 |
+
self.bos_token_id = (
|
| 137 |
+
model_config.bos_token_id if not model_config.is_encoder_decoder else model_config.decoder_start_token_id
|
| 138 |
+
)
|
| 139 |
+
self.greenlist_ratio = watermarking_config["greenlist_ratio"]
|
| 140 |
+
self.ignore_repeated_ngrams = ignore_repeated_ngrams
|
| 141 |
+
self.processor = WatermarkLogitsProcessor(
|
| 142 |
+
vocab_size=model_config.vocab_size, device=device, **watermarking_config
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
# Expensive re-seeding and sampling is cached.
|
| 146 |
+
self._get_ngram_score_cached = lru_cache(maxsize=max_cache_size)(self._get_ngram_score)
|
| 147 |
+
|
| 148 |
+
def _get_ngram_score(self, prefix: torch.LongTensor, target: int):
|
| 149 |
+
greenlist_ids = self.processor._get_greenlist_ids(prefix)
|
| 150 |
+
return target in greenlist_ids
|
| 151 |
+
|
| 152 |
+
def _score_ngrams_in_passage(self, input_ids: torch.LongTensor):
|
| 153 |
+
batch_size, seq_length = input_ids.shape
|
| 154 |
+
selfhash = int(self.processor.seeding_scheme == "selfhash")
|
| 155 |
+
n = self.processor.context_width + 1 - selfhash
|
| 156 |
+
indices = torch.arange(n).unsqueeze(0) + torch.arange(seq_length - n + 1).unsqueeze(1)
|
| 157 |
+
ngram_tensors = input_ids[:, indices]
|
| 158 |
+
|
| 159 |
+
num_tokens_scored_batch = np.zeros(batch_size)
|
| 160 |
+
green_token_count_batch = np.zeros(batch_size)
|
| 161 |
+
for batch_idx in range(ngram_tensors.shape[0]):
|
| 162 |
+
frequencies_table = collections.Counter(ngram_tensors[batch_idx])
|
| 163 |
+
ngram_to_watermark_lookup = {}
|
| 164 |
+
for ngram_example in frequencies_table.keys():
|
| 165 |
+
prefix = ngram_example if selfhash else ngram_example[:-1]
|
| 166 |
+
target = ngram_example[-1]
|
| 167 |
+
ngram_to_watermark_lookup[ngram_example] = self._get_ngram_score_cached(prefix, target)
|
| 168 |
+
|
| 169 |
+
if self.ignore_repeated_ngrams:
|
| 170 |
+
# counts a green/red hit once per unique ngram.
|
| 171 |
+
# num total tokens scored becomes the number unique ngrams.
|
| 172 |
+
num_tokens_scored_batch[batch_idx] = len(frequencies_table.keys())
|
| 173 |
+
green_token_count_batch[batch_idx] = sum(ngram_to_watermark_lookup.values())
|
| 174 |
+
else:
|
| 175 |
+
num_tokens_scored_batch[batch_idx] = sum(frequencies_table.values())
|
| 176 |
+
green_token_count_batch[batch_idx] = sum(
|
| 177 |
+
freq * outcome
|
| 178 |
+
for freq, outcome in zip(frequencies_table.values(), ngram_to_watermark_lookup.values())
|
| 179 |
+
)
|
| 180 |
+
return num_tokens_scored_batch, green_token_count_batch
|
| 181 |
+
|
| 182 |
+
def _compute_z_score(self, green_token_count: np.array, total_num_tokens: np.array) -> np.array:
|
| 183 |
+
expected_count = self.greenlist_ratio
|
| 184 |
+
numer = green_token_count - expected_count * total_num_tokens
|
| 185 |
+
denom = np.sqrt(total_num_tokens * expected_count * (1 - expected_count))
|
| 186 |
+
z = numer / denom
|
| 187 |
+
return z
|
| 188 |
+
|
| 189 |
+
def _compute_pval(self, x, loc=0, scale=1):
|
| 190 |
+
z = (x - loc) / scale
|
| 191 |
+
return 1 - (0.5 * (1 + np.sign(z) * (1 - np.exp(-2 * z**2 / np.pi))))
|
| 192 |
+
|
| 193 |
+
def __call__(
|
| 194 |
+
self,
|
| 195 |
+
input_ids: torch.LongTensor,
|
| 196 |
+
z_threshold: float = 3.0,
|
| 197 |
+
return_dict: bool = False,
|
| 198 |
+
) -> Union[WatermarkDetectorOutput, np.array]:
|
| 199 |
+
"""
|
| 200 |
+
Args:
|
| 201 |
+
input_ids (`torch.LongTensor`):
|
| 202 |
+
The watermark generated text. It is advised to remove the prompt, which can affect the detection.
|
| 203 |
+
z_threshold (`Dict`, *optional*, defaults to `3.0`):
|
| 204 |
+
Changing this threshold will change the sensitivity of the detector. Higher z threshold gives less
|
| 205 |
+
sensitivity and vice versa for lower z threshold.
|
| 206 |
+
return_dict (`bool`, *optional*, defaults to `False`):
|
| 207 |
+
Whether to return `~generation.WatermarkDetectorOutput` or not. If not it will return boolean predictions,
|
| 208 |
+
ma
|
| 209 |
+
Return:
|
| 210 |
+
[`~generation.WatermarkDetectorOutput`] or `np.array`: A [`~generation.WatermarkDetectorOutput`]
|
| 211 |
+
if `return_dict=True` otherwise a `np.array`.
|
| 212 |
+
|
| 213 |
+
"""
|
| 214 |
+
|
| 215 |
+
# Let's assume that if one batch start with `bos`, all batched also do
|
| 216 |
+
if input_ids[0, 0] == self.bos_token_id:
|
| 217 |
+
input_ids = input_ids[:, 1:]
|
| 218 |
+
|
| 219 |
+
if input_ids.shape[-1] - self.processor.context_width < 1:
|
| 220 |
+
raise ValueError(
|
| 221 |
+
f"Must have at least `1` token to score after the first "
|
| 222 |
+
f"min_prefix_len={self.processor.context_width} tokens required by the seeding scheme."
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
num_tokens_scored, green_token_count = self._score_ngrams_in_passage(input_ids)
|
| 226 |
+
z_score = self._compute_z_score(green_token_count, num_tokens_scored)
|
| 227 |
+
prediction = z_score > z_threshold
|
| 228 |
+
|
| 229 |
+
if return_dict:
|
| 230 |
+
p_value = self._compute_pval(z_score)
|
| 231 |
+
confidence = 1 - p_value
|
| 232 |
+
|
| 233 |
+
return WatermarkDetectorOutput(
|
| 234 |
+
num_tokens_scored=num_tokens_scored,
|
| 235 |
+
num_green_tokens=green_token_count,
|
| 236 |
+
green_fraction=green_token_count / num_tokens_scored,
|
| 237 |
+
z_score=z_score,
|
| 238 |
+
p_value=p_value,
|
| 239 |
+
prediction=prediction,
|
| 240 |
+
confidence=confidence,
|
| 241 |
+
)
|
| 242 |
+
return prediction
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
class BayesianDetectorConfig(PretrainedConfig):
|
| 246 |
+
"""
|
| 247 |
+
This is the configuration class to store the configuration of a [`BayesianDetectorModel`]. It is used to
|
| 248 |
+
instantiate a Bayesian Detector model according to the specified arguments.
|
| 249 |
+
|
| 250 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 251 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 252 |
+
|
| 253 |
+
Args:
|
| 254 |
+
watermarking_depth (`int`, *optional*):
|
| 255 |
+
The number of tournament layers.
|
| 256 |
+
base_rate (`float1`, *optional*, defaults to 0.5):
|
| 257 |
+
Prior probability P(w) that a text is watermarked.
|
| 258 |
+
"""
|
| 259 |
+
|
| 260 |
+
def __init__(self, watermarking_depth: int = None, base_rate: float = 0.5, **kwargs):
|
| 261 |
+
self.watermarking_depth = watermarking_depth
|
| 262 |
+
self.base_rate = base_rate
|
| 263 |
+
# These can be set later to store information about this detector.
|
| 264 |
+
self.model_name = None
|
| 265 |
+
self.watermarking_config = None
|
| 266 |
+
|
| 267 |
+
super().__init__(**kwargs)
|
| 268 |
+
|
| 269 |
+
def set_detector_information(self, model_name, watermarking_config):
|
| 270 |
+
self.model_name = model_name
|
| 271 |
+
self.watermarking_config = watermarking_config
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
@dataclass
|
| 275 |
+
class BayesianWatermarkDetectorModelOutput(ModelOutput):
|
| 276 |
+
"""
|
| 277 |
+
Base class for outputs of models predicting if the text is watermarked.
|
| 278 |
+
|
| 279 |
+
Args:
|
| 280 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
| 281 |
+
Language modeling loss.
|
| 282 |
+
posterior_probabilities (`torch.FloatTensor` of shape `(1,)`):
|
| 283 |
+
Multiple choice classification loss.
|
| 284 |
+
"""
|
| 285 |
+
|
| 286 |
+
loss: Optional[torch.FloatTensor] = None
|
| 287 |
+
posterior_probabilities: Optional[torch.FloatTensor] = None
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
class BayesianDetectorWatermarkedLikelihood(nn.Module):
|
| 291 |
+
"""Watermarked likelihood model for binary-valued g-values.
|
| 292 |
+
|
| 293 |
+
This takes in g-values and returns p(g_values|watermarked).
|
| 294 |
+
"""
|
| 295 |
+
|
| 296 |
+
def __init__(self, watermarking_depth: int):
|
| 297 |
+
"""Initializes the model parameters."""
|
| 298 |
+
super().__init__()
|
| 299 |
+
self.watermarking_depth = watermarking_depth
|
| 300 |
+
self.beta = torch.nn.Parameter(-2.5 + 0.001 * torch.randn(1, 1, watermarking_depth))
|
| 301 |
+
self.delta = torch.nn.Parameter(0.001 * torch.randn(1, 1, self.watermarking_depth, watermarking_depth))
|
| 302 |
+
|
| 303 |
+
def _compute_latents(self, g_values: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 304 |
+
"""Computes the unique token probability distribution given g-values.
|
| 305 |
+
|
| 306 |
+
Args:
|
| 307 |
+
g_values (`torch.Tensor` of shape `(batch_size, seq_len, watermarking_depth)`):
|
| 308 |
+
PRF values.
|
| 309 |
+
|
| 310 |
+
Returns:
|
| 311 |
+
p_one_unique_token and p_two_unique_tokens, both of shape
|
| 312 |
+
[batch_size, seq_len, watermarking_depth]. p_one_unique_token[i,t,l]
|
| 313 |
+
gives the probability of there being one unique token in a tournament
|
| 314 |
+
match on layer l, on timestep t, for batch item i.
|
| 315 |
+
p_one_unique_token[i,t,l] + p_two_unique_token[i,t,l] = 1.
|
| 316 |
+
"""
|
| 317 |
+
# Tile g-values to produce feature vectors for predicting the latents
|
| 318 |
+
# for each layer in the tournament; our model for the latents psi is a
|
| 319 |
+
# logistic regression model psi = sigmoid(delta * x + beta).
|
| 320 |
+
|
| 321 |
+
# [batch_size, seq_len, watermarking_depth, watermarking_depth]
|
| 322 |
+
x = torch.repeat_interleave(torch.unsqueeze(g_values, dim=-2), self.watermarking_depth, axis=-2)
|
| 323 |
+
|
| 324 |
+
# mask all elements above -1 diagonal for autoregressive factorization
|
| 325 |
+
x = torch.tril(x, diagonal=-1)
|
| 326 |
+
|
| 327 |
+
# [batch_size, seq_len, watermarking_depth]
|
| 328 |
+
# (i, j, k, l) x (i, j, k, l) -> (i, j, k) einsum equivalent
|
| 329 |
+
logits = (self.delta[..., None, :] @ x.type(self.delta.dtype)[..., None]).squeeze() + self.beta
|
| 330 |
+
|
| 331 |
+
p_two_unique_tokens = torch.sigmoid(logits)
|
| 332 |
+
p_one_unique_token = 1 - p_two_unique_tokens
|
| 333 |
+
return p_one_unique_token, p_two_unique_tokens
|
| 334 |
+
|
| 335 |
+
def forward(self, g_values: torch.Tensor) -> torch.Tensor:
|
| 336 |
+
"""Computes the likelihoods P(g_values|watermarked).
|
| 337 |
+
|
| 338 |
+
Args:
|
| 339 |
+
g_values (`torch.Tensor` of shape `(batch_size, seq_len, watermarking_depth)`):
|
| 340 |
+
g-values (values 0 or 1)
|
| 341 |
+
|
| 342 |
+
Returns:
|
| 343 |
+
p(g_values|watermarked) of shape [batch_size, seq_len, watermarking_depth].
|
| 344 |
+
"""
|
| 345 |
+
p_one_unique_token, p_two_unique_tokens = self._compute_latents(g_values)
|
| 346 |
+
|
| 347 |
+
# P(g_tl | watermarked) is equal to
|
| 348 |
+
# 0.5 * [ (g_tl+0.5) * p_two_unique_tokens + p_one_unique_token].
|
| 349 |
+
return 0.5 * ((g_values + 0.5) * p_two_unique_tokens + p_one_unique_token)
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
class BayesianDetectorModel(PreTrainedModel):
|
| 353 |
+
r"""
|
| 354 |
+
Bayesian classifier for watermark detection.
|
| 355 |
+
|
| 356 |
+
This detector uses Bayes' rule to compute a watermarking score, which is the sigmoid of the log of ratio of the
|
| 357 |
+
posterior probabilities P(watermarked|g_values) and P(unwatermarked|g_values). Please see the section on
|
| 358 |
+
BayesianScore in the paper for further details.
|
| 359 |
+
Paper URL: https://www.nature.com/articles/s41586-024-08025-4
|
| 360 |
+
|
| 361 |
+
Note that this detector only works with non-distortionary Tournament-based watermarking using the Bernoulli(0.5)
|
| 362 |
+
g-value distribution.
|
| 363 |
+
|
| 364 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 365 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 366 |
+
etc.)
|
| 367 |
+
|
| 368 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 369 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 370 |
+
and behavior.
|
| 371 |
+
|
| 372 |
+
Parameters:
|
| 373 |
+
config ([`BayesianDetectorConfig`]): Model configuration class with all the parameters of the model.
|
| 374 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 375 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 376 |
+
"""
|
| 377 |
+
|
| 378 |
+
config_class = BayesianDetectorConfig
|
| 379 |
+
base_model_prefix = "model"
|
| 380 |
+
|
| 381 |
+
def __init__(self, config):
|
| 382 |
+
super().__init__(config)
|
| 383 |
+
|
| 384 |
+
self.watermarking_depth = config.watermarking_depth
|
| 385 |
+
self.base_rate = config.base_rate
|
| 386 |
+
self.likelihood_model_watermarked = BayesianDetectorWatermarkedLikelihood(
|
| 387 |
+
watermarking_depth=self.watermarking_depth
|
| 388 |
+
)
|
| 389 |
+
self.prior = torch.nn.Parameter(torch.tensor([self.base_rate]))
|
| 390 |
+
|
| 391 |
+
def _init_weights(self, module):
|
| 392 |
+
"""Initialize the weights."""
|
| 393 |
+
if isinstance(module, nn.Parameter):
|
| 394 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
| 395 |
+
|
| 396 |
+
def _compute_posterior(
|
| 397 |
+
self,
|
| 398 |
+
likelihoods_watermarked: torch.Tensor,
|
| 399 |
+
likelihoods_unwatermarked: torch.Tensor,
|
| 400 |
+
mask: torch.Tensor,
|
| 401 |
+
prior: float,
|
| 402 |
+
) -> torch.Tensor:
|
| 403 |
+
"""
|
| 404 |
+
Compute posterior P(w|g) given likelihoods, mask and prior.
|
| 405 |
+
|
| 406 |
+
Args:
|
| 407 |
+
likelihoods_watermarked (`torch.Tensor` of shape `(batch, length, depth)`):
|
| 408 |
+
Likelihoods P(g_values|watermarked) of g-values under watermarked model.
|
| 409 |
+
likelihoods_unwatermarked (`torch.Tensor` of shape `(batch, length, depth)`):
|
| 410 |
+
Likelihoods P(g_values|unwatermarked) of g-values under unwatermarked model.
|
| 411 |
+
mask (`torch.Tensor` of shape `(batch, length)`):
|
| 412 |
+
A binary array indicating which g-values should be used. g-values with mask value 0 are discarded.
|
| 413 |
+
prior (`float`):
|
| 414 |
+
the prior probability P(w) that the text is watermarked.
|
| 415 |
+
|
| 416 |
+
Returns:
|
| 417 |
+
Posterior probability P(watermarked|g_values), shape [batch].
|
| 418 |
+
"""
|
| 419 |
+
mask = torch.unsqueeze(mask, dim=-1)
|
| 420 |
+
prior = torch.clamp(prior, min=1e-5, max=1 - 1e-5)
|
| 421 |
+
log_likelihoods_watermarked = torch.log(torch.clamp(likelihoods_watermarked, min=1e-30, max=float("inf")))
|
| 422 |
+
log_likelihoods_unwatermarked = torch.log(torch.clamp(likelihoods_unwatermarked, min=1e-30, max=float("inf")))
|
| 423 |
+
log_odds = log_likelihoods_watermarked - log_likelihoods_unwatermarked
|
| 424 |
+
|
| 425 |
+
# Sum relative surprisals (log odds) across all token positions and layers.
|
| 426 |
+
relative_surprisal_likelihood = torch.einsum("i...->i", log_odds * mask)
|
| 427 |
+
|
| 428 |
+
# Compute the relative surprisal prior
|
| 429 |
+
relative_surprisal_prior = torch.log(prior) - torch.log(1 - prior)
|
| 430 |
+
|
| 431 |
+
# Combine prior and likelihood.
|
| 432 |
+
# [batch_size]
|
| 433 |
+
relative_surprisal = relative_surprisal_prior + relative_surprisal_likelihood
|
| 434 |
+
|
| 435 |
+
# Compute the posterior probability P(w|g) = sigmoid(relative_surprisal).
|
| 436 |
+
return torch.sigmoid(relative_surprisal)
|
| 437 |
+
|
| 438 |
+
def forward(
|
| 439 |
+
self,
|
| 440 |
+
g_values: torch.Tensor,
|
| 441 |
+
mask: torch.Tensor,
|
| 442 |
+
labels: Optional[torch.Tensor] = None,
|
| 443 |
+
loss_batch_weight=1,
|
| 444 |
+
return_dict=False,
|
| 445 |
+
) -> BayesianWatermarkDetectorModelOutput:
|
| 446 |
+
"""
|
| 447 |
+
Computes the watermarked posterior P(watermarked|g_values).
|
| 448 |
+
|
| 449 |
+
Args:
|
| 450 |
+
g_values (`torch.Tensor` of shape `(batch_size, seq_len, watermarking_depth, ...)`):
|
| 451 |
+
g-values (with values 0 or 1)
|
| 452 |
+
mask:
|
| 453 |
+
A binary array shape [batch_size, seq_len] indicating which g-values should be used. g-values with mask
|
| 454 |
+
value 0 are discarded.
|
| 455 |
+
|
| 456 |
+
Returns:
|
| 457 |
+
p(watermarked | g_values), of shape [batch_size].
|
| 458 |
+
"""
|
| 459 |
+
|
| 460 |
+
likelihoods_watermarked = self.likelihood_model_watermarked(g_values)
|
| 461 |
+
likelihoods_unwatermarked = 0.5 * torch.ones_like(g_values)
|
| 462 |
+
out = self._compute_posterior(
|
| 463 |
+
likelihoods_watermarked=likelihoods_watermarked,
|
| 464 |
+
likelihoods_unwatermarked=likelihoods_unwatermarked,
|
| 465 |
+
mask=mask,
|
| 466 |
+
prior=self.prior,
|
| 467 |
+
)
|
| 468 |
+
|
| 469 |
+
loss = None
|
| 470 |
+
if labels is not None:
|
| 471 |
+
loss_fct = BCELoss()
|
| 472 |
+
loss_unwweight = torch.sum(self.likelihood_model_watermarked.delta**2)
|
| 473 |
+
loss_weight = loss_unwweight * loss_batch_weight
|
| 474 |
+
loss = loss_fct(torch.clamp(out, 1e-5, 1 - 1e-5), labels) + loss_weight
|
| 475 |
+
|
| 476 |
+
if not return_dict:
|
| 477 |
+
return (out,) if loss is None else (out, loss)
|
| 478 |
+
|
| 479 |
+
return BayesianWatermarkDetectorModelOutput(loss=loss, posterior_probabilities=out)
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
class SynthIDTextWatermarkDetector:
|
| 483 |
+
r"""
|
| 484 |
+
SynthID text watermark detector class.
|
| 485 |
+
|
| 486 |
+
This class has to be initialized with the trained bayesian detector module check script
|
| 487 |
+
in examples/synthid_text/detector_training.py for example in training/saving/loading this
|
| 488 |
+
detector module. The folder also showcases example use case of this detector.
|
| 489 |
+
|
| 490 |
+
Parameters:
|
| 491 |
+
detector_module ([`BayesianDetectorModel`]):
|
| 492 |
+
Bayesian detector module object initialized with parameters.
|
| 493 |
+
Check examples/research_projects/synthid_text/detector_training.py for usage.
|
| 494 |
+
logits_processor (`SynthIDTextWatermarkLogitsProcessor`):
|
| 495 |
+
The logits processor used for watermarking.
|
| 496 |
+
tokenizer (`Any`):
|
| 497 |
+
The tokenizer used for the model.
|
| 498 |
+
|
| 499 |
+
Examples:
|
| 500 |
+
```python
|
| 501 |
+
>>> from transformers import (
|
| 502 |
+
... AutoTokenizer, BayesianDetectorModel, SynthIDTextWatermarkLogitsProcessor, SynthIDTextWatermarkDetector
|
| 503 |
+
... )
|
| 504 |
+
|
| 505 |
+
>>> # Load the detector. See examples/research_projects/synthid_text for training a detector.
|
| 506 |
+
>>> detector_model = BayesianDetectorModel.from_pretrained("joaogante/dummy_synthid_detector")
|
| 507 |
+
>>> logits_processor = SynthIDTextWatermarkLogitsProcessor(
|
| 508 |
+
... **detector_model.config.watermarking_config, device="cpu"
|
| 509 |
+
... )
|
| 510 |
+
>>> tokenizer = AutoTokenizer.from_pretrained(detector_model.config.model_name)
|
| 511 |
+
>>> detector = SynthIDTextWatermarkDetector(detector_model, logits_processor, tokenizer)
|
| 512 |
+
|
| 513 |
+
>>> # Test whether a certain string is watermarked
|
| 514 |
+
>>> test_input = tokenizer(["This is a test input"], return_tensors="pt")
|
| 515 |
+
>>> is_watermarked = detector(test_input.input_ids)
|
| 516 |
+
```
|
| 517 |
+
"""
|
| 518 |
+
|
| 519 |
+
def __init__(
|
| 520 |
+
self,
|
| 521 |
+
detector_module: BayesianDetectorModel,
|
| 522 |
+
logits_processor: SynthIDTextWatermarkLogitsProcessor,
|
| 523 |
+
tokenizer: Any,
|
| 524 |
+
):
|
| 525 |
+
self.detector_module = detector_module
|
| 526 |
+
self.logits_processor = logits_processor
|
| 527 |
+
self.tokenizer = tokenizer
|
| 528 |
+
|
| 529 |
+
def __call__(self, tokenized_outputs: torch.Tensor):
|
| 530 |
+
# eos mask is computed, skip first ngram_len - 1 tokens
|
| 531 |
+
# eos_mask will be of shape [batch_size, output_len]
|
| 532 |
+
eos_token_mask = self.logits_processor.compute_eos_token_mask(
|
| 533 |
+
input_ids=tokenized_outputs,
|
| 534 |
+
eos_token_id=self.tokenizer.eos_token_id,
|
| 535 |
+
)[:, self.logits_processor.ngram_len - 1 :]
|
| 536 |
+
|
| 537 |
+
# context repetition mask is computed
|
| 538 |
+
context_repetition_mask = self.logits_processor.compute_context_repetition_mask(
|
| 539 |
+
input_ids=tokenized_outputs,
|
| 540 |
+
)
|
| 541 |
+
# context repitition mask shape [batch_size, output_len - (ngram_len - 1)]
|
| 542 |
+
|
| 543 |
+
combined_mask = context_repetition_mask * eos_token_mask
|
| 544 |
+
|
| 545 |
+
g_values = self.logits_processor.compute_g_values(
|
| 546 |
+
input_ids=tokenized_outputs,
|
| 547 |
+
)
|
| 548 |
+
# g values shape [batch_size, output_len - (ngram_len - 1), depth]
|
| 549 |
+
return self.detector_module(g_values, combined_mask)
|