Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitattributes +2 -0
- .gitignore +151 -0
- .pre-commit-config.yaml +52 -0
- .pre-commit-config_local.yaml +52 -0
- 4JOB_train.tar +3 -0
- CODE_OF_CONDUCT.md +132 -0
- CONTRIBUTING.md +63 -0
- CONTRIBUTING_CN.md +81 -0
- GRPO_TEST.jsonl +0 -0
- GRPOtrain.sh +38 -0
- HH_TEST.jsonl +53 -0
- HM_TEST.jsonl +44 -0
- LICENSE +201 -0
- VLLM.sh +7 -0
- add_errorType.py +40 -0
- analyze_dialogue_lengths.py +112 -0
- compare_scores.py +96 -0
- docs/resources/kto_data.png +3 -0
- docs/resources/web-ui-en.jpg +3 -0
- docs/transformers/build/lib/transformers/models/clip/configuration_clip.py +422 -0
- docs/transformers/build/lib/transformers/models/clip/feature_extraction_clip.py +38 -0
- docs/transformers/build/lib/transformers/models/clip/image_processing_clip.py +350 -0
- docs/transformers/build/lib/transformers/models/clip/image_processing_clip_fast.py +42 -0
- docs/transformers/build/lib/transformers/models/clip/modeling_clip.py +1473 -0
- docs/transformers/build/lib/transformers/models/clip/modeling_flax_clip.py +1306 -0
- docs/transformers/build/lib/transformers/models/clip/modeling_tf_clip.py +1460 -0
- docs/transformers/build/lib/transformers/models/clip/processing_clip.py +156 -0
- docs/transformers/build/lib/transformers/models/clip/tokenization_clip.py +519 -0
- docs/transformers/build/lib/transformers/models/clip/tokenization_clip_fast.py +164 -0
- docs/transformers/build/lib/transformers/models/clipseg/__init__.py +28 -0
- docs/transformers/build/lib/transformers/models/clipseg/configuration_clipseg.py +396 -0
- docs/transformers/build/lib/transformers/models/clipseg/convert_clipseg_original_pytorch_to_hf.py +264 -0
- docs/transformers/build/lib/transformers/models/clipseg/modeling_clipseg.py +1520 -0
- docs/transformers/build/lib/transformers/models/clipseg/processing_clipseg.py +164 -0
- docs/transformers/build/lib/transformers/models/clvp/__init__.py +30 -0
- docs/transformers/build/lib/transformers/models/clvp/configuration_clvp.py +443 -0
- docs/transformers/build/lib/transformers/models/clvp/convert_clvp_to_hf.py +234 -0
- docs/transformers/build/lib/transformers/models/clvp/feature_extraction_clvp.py +241 -0
- docs/transformers/build/lib/transformers/models/clvp/modeling_clvp.py +2131 -0
- docs/transformers/build/lib/transformers/models/clvp/number_normalizer.py +237 -0
- docs/transformers/build/lib/transformers/models/clvp/processing_clvp.py +93 -0
- docs/transformers/build/lib/transformers/models/clvp/tokenization_clvp.py +367 -0
- docs/transformers/build/lib/transformers/models/code_llama/__init__.py +27 -0
- docs/transformers/build/lib/transformers/models/code_llama/tokenization_code_llama.py +454 -0
- docs/transformers/build/lib/transformers/models/code_llama/tokenization_code_llama_fast.py +378 -0
- docs/transformers/build/lib/transformers/models/codegen/__init__.py +29 -0
- docs/transformers/build/lib/transformers/models/codegen/configuration_codegen.py +230 -0
- docs/transformers/build/lib/transformers/models/codegen/modeling_codegen.py +834 -0
- docs/transformers/build/lib/transformers/models/codegen/tokenization_codegen.py +419 -0
- docs/transformers/build/lib/transformers/models/codegen/tokenization_codegen_fast.py +265 -0
.gitattributes
CHANGED
@@ -51,3 +51,5 @@ seamless_interaction/assets/banner.gif filter=lfs diff=lfs merge=lfs -text
 docs/resources/grpo_countdown.png filter=lfs diff=lfs merge=lfs -text
 docs/resources/grpo_geoqa.png filter=lfs diff=lfs merge=lfs -text
 docs/resources/grpo_openr1_multimodal.png filter=lfs diff=lfs merge=lfs -text
+docs/resources/web-ui-en.jpg filter=lfs diff=lfs merge=lfs -text
+docs/resources/kto_data.png filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,151 @@
# Byte-compiled / optimized / DLL files
tmp
*.ttf
__pycache__/
*.py[cod]
*$py.class
test.py
# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
/package
/temp
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

.vscode
.idea
.run

# custom
*.pkl
*.pkl.json
*.log.json
*.whl
*.tar.gz
*.swp
*.log
*.tar.gz
source.sh
tensorboard.sh
.DS_Store
replace.sh
result.png
result.jpg
result.mp4
output/
outputs/
wandb/
*.out
benchmarks/
eval_output/
eval_outputs/
transformers/
vlmeval/
my_model/
/data
result/
images
/custom/
megatron_output/

# Pytorch
*.pth
*.pt

# ast template
ast_index_file.py
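This `.gitignore` combines the standard Python template with project-specific entries (training outputs, wandb runs, `*.pth`/`*.pt` weights); note that `*.log` and `*.tar.gz` each appear twice, which git tolerates harmlessly. As a quick illustration of what the rules exclude, here is a minimal sketch using the third-party `pathspec` library (`pip install pathspec` — an assumption, not a dependency of this repo); the candidate paths are made up:

```python
# Minimal sketch: test which paths the new .gitignore would exclude.
# Uses the third-party `pathspec` library; candidate paths are illustrative.
import pathspec

with open(".gitignore", encoding="utf-8") as f:
    spec = pathspec.PathSpec.from_lines("gitwildmatch", f)

candidates = [
    "output/checkpoint-330/model.pt",  # caught by output/ (and *.pt)
    "wandb/run-1/run.log",             # caught by wandb/ (and *.log)
    "GRPOtrain.sh",                    # no matching pattern, stays tracked
]
for path in candidates:
    print(f"{path}: {'ignored' if spec.match_file(path) else 'kept'}")
```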
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,52 @@
repos:
  - repo: https://github.com/pycqa/flake8.git
    rev: 4.0.0
    hooks:
      - id: flake8
        exclude: |
          (?x)^(
            thirdparty/|
            examples/|
            tests/run.py
          )$
  - repo: https://github.com/PyCQA/isort.git
    rev: 4.3.21
    hooks:
      - id: isort
        exclude: |
          (?x)^(
            examples/|
            tests/run.py|
            swift/cli/sft.py
          )$
  - repo: https://github.com/pre-commit/mirrors-yapf.git
    rev: v0.30.0
    hooks:
      - id: yapf
        exclude: |
          (?x)^(
            thirdparty/|
            examples/|
            tests/run.py
          )$
  - repo: https://github.com/pre-commit/pre-commit-hooks.git
    rev: v3.1.0
    hooks:
      - id: trailing-whitespace
        exclude: thirdparty/|tests/run.py
      - id: check-yaml
        exclude: thirdparty/|tests/run.py
      - id: end-of-file-fixer
        exclude: thirdparty/|tests/run.py
      - id: requirements-txt-fixer
        exclude: thirdparty/|tests/run.py
      - id: double-quote-string-fixer
        exclude: thirdparty/|tests/run.py
      - id: check-merge-conflict
        exclude: thirdparty/|tests/run.py
      - id: fix-encoding-pragma
        exclude: thirdparty/|tests/run.py
        args: ["--remove"]
      - id: mixed-line-ending
        exclude: thirdparty/|tests/run.py
        args: ["--fix=lf"]
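The config pins four hook repositories (flake8, isort, yapf, and the shared pre-commit-hooks) at fixed revisions, with `exclude` patterns keeping `thirdparty/`, `examples/`, and a few generated files out of scope. A minimal sketch for sanity-checking that the file parses and listing the configured hooks, assuming PyYAML is available (illustrative only, not repository code):

```python
# Minimal sketch: parse the pre-commit config and list its hooks.
import yaml  # third-party PyYAML; pip install pyyaml

with open(".pre-commit-config.yaml", encoding="utf-8") as f:
    config = yaml.safe_load(f)

for repo in config["repos"]:
    hook_ids = ", ".join(hook["id"] for hook in repo["hooks"])
    print(f"{repo['repo']} @ {repo['rev']}: {hook_ids}")
```

The hooks themselves run via `pre-commit run --all-files`, as described in CONTRIBUTING.md below.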
.pre-commit-config_local.yaml
ADDED
@@ -0,0 +1,52 @@
repos:
  - repo: /home/admin/pre-commit/flake8
    rev: 4.0.0
    hooks:
      - id: flake8
        exclude: |
          (?x)^(
            thirdparty/|
            examples/|
            tests/run.py
          )$
  - repo: /home/admin/pre-commit/isort
    rev: 4.3.21
    hooks:
      - id: isort
        exclude: |
          (?x)^(
            examples/|
            tests/run.py|
            swift/cli/sft.py
          )$
  - repo: /home/admin/pre-commit/mirrors-yapf
    rev: v0.30.0
    hooks:
      - id: yapf
        exclude: |
          (?x)^(
            thirdparty/|
            examples/|
            tests/run.py
          )$
  - repo: /home/admin/pre-commit/pre-commit-hooks
    rev: v3.1.0
    hooks:
      - id: trailing-whitespace
        exclude: thirdparty/|tests/run.py
      - id: check-yaml
        exclude: thirdparty/|tests/run.py
      - id: end-of-file-fixer
        exclude: thirdparty/
      - id: requirements-txt-fixer
        exclude: thirdparty/|tests/run.py
      - id: double-quote-string-fixer
        exclude: thirdparty/|tests/run.py
      - id: check-merge-conflict
        exclude: thirdparty/|tests/run.py
      - id: fix-encoding-pragma
        exclude: thirdparty/|tests/run.py
        args: ["--remove"]
      - id: mixed-line-ending
        exclude: thirdparty/|tests/run.py
        args: ["--fix=lf"]
4JOB_train.tar
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:69f7465b4776100721f926c3f4221d72752dfe6f124d6f45586e2c2eadc55b7e
size 6600263680
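Only this three-line Git LFS pointer is committed; the tarball itself (about 6.6 GB, per the `size` field) lives in LFS storage, addressed by the sha256 OID. A minimal sketch of reading such a pointer file (a hypothetical helper, not part of this commit):

```python
# Minimal sketch: parse a Git LFS pointer file into its key/value fields.
# The "key value" line format (version/oid/size) follows the LFS spec shown above.
def parse_lfs_pointer(path: str) -> dict:
    fields = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            if key:
                fields[key] = value
    return fields

if __name__ == "__main__":
    ptr = parse_lfs_pointer("4JOB_train.tar")
    print(ptr["oid"])        # sha256:69f7465b...
    print(int(ptr["size"]))  # 6600263680 bytes, i.e. ~6.6 GB
```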
CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,132 @@
# Contributor Covenant Code of Conduct

## Our Pledge

We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, caste, color, religion, or sexual
identity and orientation.

We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.

## Our Standards

Examples of behavior that contributes to a positive environment for our
community include:

* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
  and learning from the experience
* Focusing on what is best not just for us as individuals, but for the overall
  community

Examples of unacceptable behavior include:

* The use of sexualized language or imagery, and sexual attention or advances of
  any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email address,
  without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Enforcement Responsibilities

Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.

Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.

## Scope

This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
contact@modelscope.cn.
All complaints will be reviewed and investigated promptly and fairly.

All community leaders are obligated to respect the privacy and security of the
reporter of any incident.

## Enforcement Guidelines

Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:

### 1. Correction

**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.

**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.

### 2. Warning

**Community Impact**: A violation through a single incident or series of
actions.

**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or permanent
ban.

### 3. Temporary Ban

**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.

**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.

### 4. Permanent Ban

**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.

**Consequence**: A permanent ban from any sort of public interaction within the
community.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.1, available at
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].

Community Impact Guidelines were inspired by
[Mozilla's code of conduct enforcement ladder][Mozilla CoC].

For answers to common questions about this code of conduct, see the FAQ at
[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
[https://www.contributor-covenant.org/translations][translations].

[homepage]: https://www.contributor-covenant.org
[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
[Mozilla CoC]: https://github.com/mozilla/diversity
[FAQ]: https://www.contributor-covenant.org/faq
[translations]: https://www.contributor-covenant.org/translations
CONTRIBUTING.md
ADDED
@@ -0,0 +1,63 @@
# Contributor Guide

_We welcome PRs, bug reports, documentation additions, and other kinds of contributions to SWIFT!_

## Table of Contents
- [Code of Conduct](#-code-of-conduct)
- [Contribution Process](#-contribution-process)
- [Hardware support](#-Hardware-support)

## 📖 Code of Conduct
Please refer to our [Code of Conduct documentation](./CODE_OF_CONDUCT.md).

## 🔁 Contribution Process
### What We Need
- New technologies and new models: SWIFT needs to support more open-source models and datasets, as well as new techniques we have not yet covered. If you are interested, please submit a PR to us.
- Technical evangelism: If you are interested in spreading the word, you are welcome to write tutorials, documents, or videos about SWIFT on any website and send us the link.
- Community contributions: You can write technical articles related to SWIFT and submit them to us. After review and approval, we will publish them on the official ModelScope accounts (Zhihu, WeChat, etc.) with your name attributed.

### Incentives
- We will issue electronic certificates to contributors on behalf of the ModelScope community to recognize your generous contributions.
- We will offer small souvenirs related to the ModelScope community.
- We will provide free A10 compute during the development period. For details, see the [Hardware support](#-Hardware-support) section.

### Submitting a PR (Pull Request)

All feature development is done on GitHub by forking first and then opening a PR.
1. Fork: Go to the [SWIFT](https://github.com/modelscope/swift) page and click the **Fork button**. A clone of the SWIFT repository will be created under your personal organization.
2. Clone: Clone the repository created in the first step to your local machine and **create a new branch** for development. During development, click the **Sync Fork button** regularly to stay in sync with the `main` branch and avoid stale code and conflicts.
3. Submit a PR: After development and testing, push the code to your remote branch. On GitHub, go to the **Pull Requests page**, create a new PR, select your branch as the source branch and `modelscope/swift:main` as the target branch.
4. Write a description: A good feature description in the PR is essential so reviewers know what you changed.
5. Review: We want merged code to be concise and efficient, so we may raise questions and discuss them. Please note that any issue raised in review is aimed at the code itself, never at you personally. Once all issues are discussed and resolved, your code will be approved.

### Code Standards and Development Approach
SWIFT has established naming conventions and development practices. Please follow them as much as possible during development.
1. Variable names use underscores as separators; class names capitalize the first letter of each word.
2. All Python indentation uses four spaces instead of a tab.
3. Choose well-known open-source libraries, avoid closed-source or unstable open-source libraries, and avoid duplicating existing code.

After a PR is submitted, SWIFT runs two types of tests:
- Code lint test: a static code-compliance check. Please make sure you have run code lint locally in advance.
```shell
pip install pre-commit  # In the swift folder
pre-commit run --all-files  # Fix the errors reported by pre-commit until all checks pass
```
- CI tests: smoke tests and unit tests; please refer to the next section.

### Running CI Tests
Before submitting a PR, please ensure your code is protected by test cases, such as smoke tests for new features or unit tests for edge cases. Reviewers will pay attention to this during code review. Dedicated services run the CI tests, executing all test cases, and the code can only be merged after they pass.

Additionally, some important tests are skipped because of their long running time; to make sure your logic is correct, you can run them locally:
```shell
python tests/llm/test_run.py
```
Please make sure this test passes.

## ✅ Hardware support

SWIFT provides hardware support for developers, including free GPUs. If needed, please email us ([contact@modelscope.cn](mailto:contact@modelscope.cn)) or join our WeChat group:

<p align="left">
<img src="asset/wechat.png" width="250" style="display: inline-block;">
</p>
CONTRIBUTING_CN.md
ADDED
@@ -0,0 +1,81 @@
# Contributor Guide

*We welcome feature PRs, bug reports, documentation additions, and other kinds of contributions to SWIFT!*

## Table of Contents

- [Code of Conduct](#-code-of-conduct)
- [Contribution Process](#-contribution-process)
- [Resource Support](#-resource-support)

## 📖 Code of Conduct

Please see our [Code of Conduct documentation](./CODE_OF_CONDUCT.md).

## 🔁 Contribution Process

### What We Need
- New technologies and new models: SWIFT needs to support more open-source models and datasets, as well as new techniques we have not yet covered. If you are interested, you can submit a PR to us.
- Technical evangelism: If you are interested in spreading the word, you are welcome to write tutorial documents or videos about SWIFT on any website and send us the link.
- Community contributions: You can write technical articles related to SWIFT and submit them to us. After review and approval, we will publish them on the official ModelScope accounts (Zhihu, WeChat, etc.) with your name attributed.

### Incentives

- We will issue electronic certificates to contributors on behalf of the ModelScope community to recognize your generous contributions.
- We will give away small souvenirs related to the ModelScope community.
- We will provide free A10 compute during the development period; see the [Resource Support](#-resource-support) section for details.

### Submitting a PR (Pull Request)

All feature development is done on GitHub by forking first and then opening a PR.

1. Fork: Go to the [SWIFT](https://github.com/modelscope/swift) page and click the **Fork button**. A clone of the SWIFT repository will be created under your personal organization.

2. Clone: Clone the repository created in the first step to your local machine and **create a new branch** for development. During development, click the **Sync Fork button** regularly to stay in sync with the `main` branch and avoid stale code and conflicts.

3. Submit a PR: After development and testing, push the code to your remote branch. On GitHub, open the **Pull Requests page**, create a new PR, select your branch as the source branch and `modelscope/swift:main` as the target branch.

4. Write a description: A good feature description in the PR is essential so reviewers know what you changed.

5. Review: We want merged code to be concise and efficient, so we may raise questions and discuss them. Please note that any issue raised in review is aimed at the code itself, never at you personally. Once all issues are discussed and resolved, your code will be approved.

### Code Standards and Development Approach

SWIFT has established naming conventions and development practices. Please follow them as much as possible during development.

1. Variable names use underscores as separators; class names capitalize the first letter of each word.
2. All Python indentation uses four spaces instead of a tab.
3. Choose well-known open-source libraries, avoid closed-source or unstable open-source libraries, and avoid reinventing the wheel.

After a PR is submitted, SWIFT runs two types of tests:

- Code lint test: a static code-compliance check. To make sure it passes, please run code lint locally first:

```shell
pip install pre-commit
# In the swift folder
pre-commit run --all-files
# Fix the errors reported by pre-commit until all checks pass
```

- CI tests: smoke tests and unit tests; see the next section.

### Running CI Tests

Before submitting a PR, please ensure your code is protected by test cases, such as smoke tests for new features or unit tests for edge cases. Reviewers will pay attention to this during code review. Dedicated services run the CI tests, executing all test cases, and the code can only be merged after they pass.

Additionally, some important tests are skipped because of their long running time; to make sure your logic is correct, you can run them locally:

```shell
python tests/llm/test_run.py
```

Please make sure this test passes.

## ✅ Resource Support

SWIFT provides resource support for developers, including free GPU compute. If needed, please email us ([contact@modelscope.cn](mailto:contact@modelscope.cn)) or join our WeChat group:

<p align="left">
<img src="asset/wechat.png" width="250" style="display: inline-block;">
</p>
GRPO_TEST.jsonl
ADDED
The diff for this file is too large to render. See the raw diff.
GRPOtrain.sh
ADDED
@@ -0,0 +1,38 @@
WANDB_API_KEY="a7ab128385681b17ad156ad0d8c81ba3e2296164" \
CUDA_VISIBLE_DEVICES=0,1 \
NPROC_PER_NODE=2 \
swift rlhf \
    --rlhf_type grpo \
    --model /root/autodl-tmp/output_7B_FULL_cotSFT/v11-20250721-183605/checkpoint-330 \
    --external_plugins GRPO/Reward.py \
    --reward_funcs external_r1v_acc external_r1v_format_acc \
    --use_vllm false \
    --train_type full \
    --torch_dtype bfloat16 \
    --dataset 'all_dataset_train_resampled_16000.jsonl' \
    --max_completion_length 512 \
    --num_train_epochs 2 \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 2 \
    --learning_rate 1e-6 \
    --gradient_accumulation_steps 2 \
    --save_strategy 'steps' \
    --eval_strategy 'steps' \
    --eval_steps 290 \
    --save_steps 290 \
    --save_total_limit 5 \
    --logging_steps 5 \
    --output_dir /root/autodl-tmp/output_7B_GRPO \
    --warmup_ratio 0.01 \
    --dataloader_num_workers 1 \
    --num_generations 2 \
    --temperature 1.0 \
    --log_completions true \
    --num_iterations 1 \
    --async_generate false \
    --beta 0.01 \
    --deepspeed zero3_offload \
    --report_to wandb \
    # --vllm_mode server \
    # --vllm_server_host 127.0.0.1 \
    # --vllm_server_port 8000 \
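The script loads two custom reward functions, `external_r1v_acc` and `external_r1v_format_acc`, from `GRPO/Reward.py` via `--external_plugins`. That plugin file is not rendered in this 50-file view, so the following is only a hedged sketch of what an accuracy reward and a format reward for the `<response think>` / `<fluency think>` / `<overall score>` template used by this repo's datasets might compute; all names and signatures here are assumptions, not the actual plugin code:

```python
# Hedged sketch of accuracy/format rewards for the dialogue-evaluation template.
# Not the repository's GRPO/Reward.py; function names and signatures are assumed.
import re
from typing import List, Optional

SCORE_RE = re.compile(r"<overall score>\s*(\d+)\s*</overall score>")

def extract_score(completion: str) -> Optional[int]:
    """Pull the integer rating out of the <overall score>X</overall score> tag."""
    match = SCORE_RE.search(completion)
    return int(match.group(1)) if match else None

def acc_reward(completions: List[str], solutions: List[int]) -> List[float]:
    """1.0 when the predicted score matches the dataset's `solution` label."""
    return [
        1.0 if extract_score(c) == s else 0.0
        for c, s in zip(completions, solutions)
    ]

FORMAT_RE = re.compile(
    r"<response think>.*?</response think>\s*"
    r"<fluency think>.*?</fluency think>\s*"
    r"<overall score>\d+</overall score>",
    re.DOTALL,
)

def format_reward(completions: List[str]) -> List[float]:
    """1.0 when the completion follows the required think/score template."""
    return [1.0 if FORMAT_RE.search(c) else 0.0 for c in completions]
```

Each reward function scores every sampled completion (`--num_generations 2` per prompt); GRPO then normalizes the rewards within each group into advantages, with `--beta 0.01` weighting the KL term against the reference model.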
HH_TEST.jsonl
ADDED
@@ -0,0 +1,53 @@
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0011/Classify/V00_S0549_I00000377/V00_S0549_I00000377_P0377__V00_S0549_I00000377_P0682A_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0011/Classify/V00_S0544_I00000770/V00_S0544_I00000770_P0658A__V00_S0544_I00000770_P0691_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0011/Classify/V00_S0540_I00000545/V00_S0540_I00000545_P0658A__V00_S0540_I00000545_P0687_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0011/Classify/V00_S0540_I00000542/V00_S0540_I00000542_P0658A__V00_S0540_I00000542_P0687_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0013/Classify/V00_S0688_I00000126/V00_S0688_I00000126_P0658A__V00_S0688_I00000126_P0737_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0013/Classify/V00_S0697_I00000384/V00_S0697_I00000384_P0383A__V00_S0697_I00000384_P0851_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0013/Classify/V00_S0684_I00000135/V00_S0684_I00000135_P0179__V00_S0684_I00000135_P0658A_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0013/Classify/V00_S0668_I00000770/V00_S0668_I00000770_P0230A__V00_S0668_I00000770_P0827_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0013/Classify/V00_S0670_I00000545/V00_S0670_I00000545_P0460__V00_S0670_I00000545_P0825A_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0013/Classify/V00_S0668_I00000582/V00_S0668_I00000582_P0230A__V00_S0668_I00000582_P0827_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0013/Classify/V00_S0668_I00000581/V00_S0668_I00000581_P0230A__V00_S0668_I00000581_P0827_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0013/Classify/V00_S0683_I00000582/V00_S0683_I00000582_P0027A__V00_S0683_I00000582_P0193_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0013/Classify/V00_S0673_I00000371/V00_S0673_I00000371_P0012A__V00_S0673_I00000371_P0830_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0013/Classify/V00_S0680_I00000484/V00_S0680_I00000484_P0658A__V00_S0680_I00000484_P0833_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0013/Classify/V00_S0688_I00000135/V00_S0688_I00000135_P0658A__V00_S0688_I00000135_P0737_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0013/Classify/V00_S0670_I00000542/V00_S0670_I00000542_P0460__V00_S0670_I00000542_P0825A_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0013/Classify/V00_S0680_I00000135/V00_S0680_I00000135_P0658A__V00_S0680_I00000135_P0833_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0013/Classify/V00_S0672_I00000375/V00_S0672_I00000375_P0230A__V00_S0672_I00000375_P0829_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0013/Classify/V00_S0672_I00000135/V00_S0672_I00000135_P0230A__V00_S0672_I00000135_P0829_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0013/Classify/V00_S0669_I00000138/V00_S0669_I00000138_P0383A__V00_S0669_I00000138_P0826_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0013/Classify/V00_S0677_I00000371/V00_S0677_I00000371_P0122__V00_S0677_I00000371_P0383A_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0660_I00000581/V00_S0660_I00000581_P0185A__V00_S0660_I00000581_P0817_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0645_I00000495/V00_S0645_I00000495_P0262A__V00_S0645_I00000495_P0800_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0640_I00000535/V00_S0640_I00000535_P0229A__V00_S0640_I00000535_P0798_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0609_I00000138/V00_S0609_I00000138_P0229A__V00_S0609_I00000138_P0764_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0636_I00000544/V00_S0636_I00000544_P0012A__V00_S0636_I00000544_P0794_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0668_I00000132/V00_S0668_I00000132_P0230A__V00_S0668_I00000132_P0827_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0647_I00000481/V00_S0647_I00000481_P0323A__V00_S0647_I00000481_P0421_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0611_I00000960/V00_S0611_I00000960_P0262A__V00_S0611_I00000960_P0323A_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0636_I00000376/V00_S0636_I00000376_P0012A__V00_S0636_I00000376_P0794_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0636_I00000504/V00_S0636_I00000504_P0012A__V00_S0636_I00000504_P0794_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0660_I00000135/V00_S0660_I00000135_P0185A__V00_S0660_I00000135_P0817_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0649_I00000542/V00_S0649_I00000542_P0262A__V00_S0649_I00000542_P0802_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0666_I00000135/V00_S0666_I00000135_P0823__V00_S0666_I00000135_P0825A_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0636_I00000542/V00_S0636_I00000542_P0012A__V00_S0636_I00000542_P0794_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0640_I00000483/V00_S0640_I00000483_P0229A__V00_S0640_I00000483_P0798_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0666_I00000582/V00_S0666_I00000582_P0823__V00_S0666_I00000582_P0825A_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0645_I00000535/V00_S0645_I00000535_P0262A__V00_S0645_I00000535_P0800_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0666_I00000138/V00_S0666_I00000138_P0823__V00_S0666_I00000138_P0825A_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0640_I00000539/V00_S0640_I00000539_P0229A__V00_S0640_I00000539_P0798_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0645_I00000135/V00_S0645_I00000135_P0262A__V00_S0645_I00000135_P0800_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0607_I00001286/V00_S0607_I00001286_P0262A__V00_S0607_I00001286_P0323A_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0613_I00000384/V00_S0613_I00000384_P0229A__V00_S0613_I00000384_P0765_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0651_I00000129/V00_S0651_I00000129_P0323A__V00_S0651_I00000129_P0506_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0606_I00000800/V00_S0606_I00000800_P0005A__V00_S0606_I00000800_P0658A_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0660_I00000131/V00_S0660_I00000131_P0185A__V00_S0660_I00000131_P0817_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0664_I00000542/V00_S0664_I00000542_P0185A__V00_S0664_I00000542_P0822_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and an `<overall score>` rating.\nListen to a two-person interactional dialogue (dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:**\n**Logical consistency, topic coherence.**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses of more than 3s between speaker turns.**\n\n**Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n## Evaluation Output Format\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0640_I00000487/V00_S0640_I00000487_P0229A__V00_S0640_I00000487_P0798_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0666_I00000125/V00_S0666_I00000125_P0823__V00_S0666_I00000125_P0825A_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0640_I00000575/V00_S0640_I00000575_P0229A__V00_S0640_I00000575_P0798_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0613_I00000544/V00_S0613_I00000544_P0229A__V00_S0613_I00000544_P0765_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0666_I00000770/V00_S0666_I00000770_P0823__V00_S0666_I00000770_P0825A_stereo.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/seamless-interaction/improvised/train/0012/Classify/V00_S0660_I00000138/V00_S0660_I00000138_P0185A__V00_S0660_I00000138_P0817_stereo.wav"], "solution": 2}
HM_TEST.jsonl
ADDED
@@ -0,0 +1,44 @@
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/第16开始txt不规范/001.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/第16开始txt不规范/002.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/第16开始txt不规范/003.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/第16开始txt不规范/004.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/第16开始txt不规范/005.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/第16开始txt不规范/006.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/第16开始txt不规范/007.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/第16开始txt不规范/008.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/第16开始txt不规范/009.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/第16开始txt不规范/010.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/第16开始txt不规范/011.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/第16开始txt不规范/012.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/第16开始txt不规范/013.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/第16开始txt不规范/014.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/第16开始txt不规范/015.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/第16开始txt不规范/016.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/第16开始txt不规范/017.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/第16开始txt不规范/018.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/第16开始txt不规范/019.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/第16开始txt不规范/020.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/第16开始txt不规范/021.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/第16开始txt不规范/022.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/xiaoyuaudios/xiaoyu1.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/xiaoyuaudios/xiaoyu2.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/xiaoyuaudios/xiaoyu3.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/xiaoyuaudios/xiaoyu4.wav"], "solution": 1}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/xiaoyuaudios/xiaoyu5.wav"], "solution": 1}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/duihua/duihua/001.wav"], "solution": 1}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/duihua/duihua/002.wav"], "solution": 1}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/duihua/duihua/003.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/duihua/duihua/004.wav"], "solution": 1}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/duihua/duihua/005.wav"], "solution": 1}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/duihua/duihua/006.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/duihua/duihua/007.wav"], "solution": 1}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/duihua/duihua/008.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/duihua/duihua/009.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/duihua/duihua/010.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/duihua/duihua/011.wav"], "solution": 1}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/duihua/duihua/012.wav"], "solution": 2}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/duihua/duihua/013.wav"], "solution": 1}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/duihua/duihua/014.wav"], "solution": 1}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/duihua/duihua/015.wav"], "solution": 1}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/duihua/duihua/016.wav"], "solution": 1}
{"messages": [{"role": "user", "content": "<audio># Interactional Dialogue Evaluation\n\n**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\nListen to a two-person interactional dialogue speech (Dual-channel audio, with each channel representing one speaker), labeled as speakers A and B. Evaluate the quality of the interaction, focusing on:\n**Response Relevance:** \n**logical consistency, topic coherence**\n**Interactional Fluency:**\n**Detect and evaluate extended vocal overlaps, e.g., cross-channel overlap.**\n**Detect and evaluate long pauses, e.g., pauses more than 3s between speaker turns.\n\n****Note**: Small pauses and brief overlaps in audio are acceptable, while prolonged pauses and overlapping audio are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n## Scoring Criteria\nAssign a single holistic score based on the combined evaluation:\n`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n## Evaluation Output Format:\nStrictly follow this template:\n<response think>\n[Analysing Response Relevance and giving reasons for scoring...]\n</response think>\n<fluency think>\n[Analysing Interactional Fluency and giving reasons for scoring.]\n</fluency think>\n<overall score>X</overall score>\n"}], "audios": ["/root/autodl-tmp/wavrewardDataset/conversations/data/testdata/predict_result_mission4/audios/duihua/duihua/017.wav"], "solution": 2}
LICENSE
ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
VLLM.sh
ADDED
@@ -0,0 +1,7 @@
CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server \
    --model /root/autodl-tmp/output_7B_FULL_cotSFT/v0-20250621-230827/Qwen2.5-Omni-7B \
    --tokenizer /root/autodl-tmp/output_7B_FULL_cotSFT/v0-20250621-230827/Qwen2.5-Omni-7B \
    --dtype bfloat16 \
    --host 127.0.0.1 \
    --port 8000 \
    --gpu-memory-utilization 0.9
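VLLM.sh serves the cotSFT Qwen2.5-Omni-7B checkpoint through vLLM's OpenAI-compatible API on 127.0.0.1:8000. A minimal smoke-test client, assuming the server is already up; the text-only payload is a hypothetical liveness probe (the evaluation data itself pairs prompts with audio, which would need the multimodal message format):

import requests

MODEL = "/root/autodl-tmp/output_7B_FULL_cotSFT/v0-20250621-230827/Qwen2.5-Omni-7B"

resp = requests.post(
    "http://127.0.0.1:8000/v1/chat/completions",
    json={
        "model": MODEL,
        "messages": [{"role": "user", "content": "ping"}],  # hypothetical text-only probe
        "max_tokens": 16,
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])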
add_errorType.py
ADDED
@@ -0,0 +1,40 @@
import json

# Read the matched-scores file (the original allcorrect.json data)
with open('ms-swift/matched_scores_2_1.json', 'r', encoding='utf-8') as f:
    allcorrect_data = json.load(f)

# Read the merged_shuffled_train.json file
with open('/root/autodl-tmp/600_train/merged_shuffled_train.json', 'r', encoding='utf-8') as f:
    merged_data = json.load(f)

# Iterate over every entry in the allcorrect data
for entry in allcorrect_data:
    # Get the key
    key = entry.get('key')
    if key:
        # Look up the corresponding entry in merged_data
        if key in merged_data:
            # Fetch the error_type and attach it to the entry
            error_type = merged_data[key].get('error_type')
            entry['error_type'] = error_type

# Write the updated data back to a file
output_file = 'ms-swift/allcorrect_with_error_type.json'
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(allcorrect_data, f, ensure_ascii=False, indent=2)

print(f"Processing complete; results saved to {output_file}")

# Tally the error_type distribution
error_type_stats = {}
for entry in allcorrect_data:
    error_type = entry.get('error_type')
    if error_type:
        error_type_stats[error_type] = error_type_stats.get(error_type, 0) + 1
    else:
        error_type_stats['no_error_type'] = error_type_stats.get('no_error_type', 0) + 1

print("\nError Type statistics:")
for error_type, count in error_type_stats.items():
    print(f"{error_type}: {count}")
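add_errorType.py relies on two input shapes that the repo never documents: matched_scores_2_1.json must be a list of dicts that each carry a 'key', and merged_shuffled_train.json must be a dict keyed by those same keys whose values carry an 'error_type'. A hypothetical fixture illustrating that contract; all names and values here are invented:

# Hypothetical fixture matching the shapes add_errorType.py expects.
allcorrect_data = [{"key": "sample_001"}, {"key": "sample_002"}]
merged_data = {
    "sample_001": {"error_type": "long_pause"},
    "sample_002": {"error_type": "topic_drift"},
}
for entry in allcorrect_data:
    entry["error_type"] = merged_data[entry["key"]].get("error_type")
print(allcorrect_data)  # each entry now carries its error_type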
analyze_dialogue_lengths.py
ADDED
@@ -0,0 +1,112 @@
import json
import numpy as np
from typing import Dict
import matplotlib.pyplot as plt

def analyze_dialogue_lengths(file_path: str) -> Dict:
    # Read the JSONL file
    lengths = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                item = json.loads(line.strip())
                for message in item['messages']:
                    if message['role'] == 'assistant':
                        content = message['content']
                        length = len(content)
                        lengths.append(length)
            except json.JSONDecodeError as e:
                print(f"Error parsing line: {e}")
                continue

    if not lengths:
        print(f"No valid assistant responses found in {file_path}")
        return {}

    # Calculate statistics
    max_length = max(lengths)
    avg_length = np.mean(lengths)
    median_length = np.median(lengths)

    # Calculate length distribution with more detailed ranges
    length_ranges = {
        '0-100': 0,
        '101-500': 0,
        '501-1000': 0,
        '1001-2000': 0,
        '2001-3000': 0,
        '3001-4000': 0,
        '4001-5000': 0,
        '5001-6000': 0,
        '6000+': 0
    }

    for length in lengths:
        if length <= 100:
            length_ranges['0-100'] += 1
        elif length <= 500:
            length_ranges['101-500'] += 1
        elif length <= 1000:
            length_ranges['501-1000'] += 1
        elif length <= 2000:
            length_ranges['1001-2000'] += 1
        elif length <= 3000:
            length_ranges['2001-3000'] += 1
        elif length <= 4000:
            length_ranges['3001-4000'] += 1
        elif length <= 5000:
            length_ranges['4001-5000'] += 1
        elif length <= 6000:
            length_ranges['5001-6000'] += 1
        else:
            length_ranges['6000+'] += 1

    # Calculate percentages
    total = len(lengths)
    percentages = {k: (v/total)*100 for k, v in length_ranges.items()}

    # Print results
    print(f"\nAnalysis Results for {file_path}:")
    print(f"Total number of assistant responses: {total}")
    print(f"Maximum length: {max_length} characters")
    print(f"Average length: {avg_length:.2f} characters")
    print(f"Median length: {median_length:.2f} characters")
    print("\nLength Distribution:")
    for range_name, percentage in percentages.items():
        print(f"{range_name}: {percentage:.2f}%")

    # Create a histogram with more bins for better visualization
    plt.figure(figsize=(12, 6))
    plt.hist(lengths, bins=100, edgecolor='black')
    plt.title('Distribution of Assistant Response Lengths')
    plt.xlabel('Length (characters)')
    plt.ylabel('Frequency')
    plt.savefig('dialogue_length_distribution.png')
    plt.close()

    # Create a bar chart for the ranges
    plt.figure(figsize=(12, 6))
    ranges = list(length_ranges.keys())
    counts = list(length_ranges.values())
    plt.bar(ranges, counts)
    plt.title('Distribution of Response Lengths by Range')
    plt.xlabel('Length Range')
    plt.ylabel('Count')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig('dialogue_length_ranges.png')
    plt.close()

    return {
        'total_responses': total,
        'max_length': max_length,
        'avg_length': avg_length,
        'median_length': median_length,
        'distribution': percentages
    }

if __name__ == "__main__":
    # Analyze both train and test datasets
    train_results = analyze_dialogue_lengths('dataset_cotSFTtrain.json')
    test_results = analyze_dialogue_lengths('dataset_cotSFTtest.json')
compare_scores.py
ADDED
@@ -0,0 +1,96 @@
import json
import re
from collections import defaultdict

infer_result_path = '/root/autodl-tmp/output_7B_GRPO/v28-20250722-002940/checkpoint-870/infer_result/53_HH.jsonl'
test_path = '/root/autodl-tmp/ms-swift/all_audio_test_50.jsonl'
output_path = 'inference_comparison_result.json'

def extract_overall_score(response_text):
    match = re.search(r'<overall score>(\d+)</overall score>', response_text)
    if match:
        return int(match.group(1))
    return None

def main():
    # Read the infer_result file and build an audios -> score mapping
    infer_audio2score = {}
    with open(infer_result_path, 'r', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line)
            score = extract_overall_score(data['response'])
            audios = tuple(data.get('audios', []))
            infer_audio2score[audios] = {
                'score': score,
                'raw_response': data['response']
            }

    # Read the test file and build an audios -> solution mapping
    test_audio2solution = {}
    with open(test_path, 'r', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line)
            solution = data['solution']
            audios = tuple(data.get('audios', []))
            test_audio2solution[audios] = solution

    # Tally per-class stats; collect incorrect samples and all inference results
    stats_per_class = defaultdict(lambda: {'correct': 0, 'incorrect': 0})
    incorrect_samples_solution1 = []
    all_results = []

    total = 0
    correct = 0

    for audios, solution in test_audio2solution.items():
        infer_entry = infer_audio2score.get(audios, None)
        infer_score = infer_entry['score'] if infer_entry else None
        raw_response = infer_entry['raw_response'] if infer_entry else None
        match = infer_score == solution

        # Collect every result
        all_results.append({
            'audios': audios,
            'gt_solution': solution,
            'predicted_score': infer_score,
            'match': match,
            'response': raw_response
        })

        if match:
            correct += 1
            stats_per_class[solution]['correct'] += 1
        else:
            stats_per_class[solution]['incorrect'] += 1
            if solution == 1:
                incorrect_samples_solution1.append({
                    'audios': audios,
                    'gt_solution': solution,
                    'predicted_score': infer_score,
                    'response': raw_response
                })

        total += 1

    # Overall accuracy
    print(f'\nOverall Accuracy: {correct}/{total} = {correct/total:.2%}\n')

    # Per-class accuracy
    print("Per-Class Accuracy:")
    for solution, stats in sorted(stats_per_class.items()):
        total_class = stats['correct'] + stats['incorrect']
        accuracy = stats['correct'] / total_class if total_class > 0 else 0.0
        print(f'Class {solution}: Correct={stats["correct"]}, Incorrect={stats["incorrect"]}, Accuracy={accuracy:.2%}')

    # List samples with solution = 1 that were predicted incorrectly
    print("\nIncorrect Samples for solution = 1:")
    for sample in incorrect_samples_solution1:
        print(json.dumps(sample, indent=2, ensure_ascii=False))

    # Write all results to a JSON file
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2, ensure_ascii=False)
    print(f"\nAll inference comparison results saved to: {output_path}")

if __name__ == '__main__':
    main()
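Note that compare_scores.py joins the inference file and the test file on the tuple of audio paths, so a prediction only counts when both lines list identical 'audios'. A hypothetical matching pair of JSONL records, showing only the fields the script reads (all values invented):

# Hypothetical pair joined on the audios tuple by compare_scores.py.
infer_line = {
    "audios": ["/path/to/clip.wav"],
    "response": "<response think>...</response think>\n<fluency think>...</fluency think>\n<overall score>2</overall score>",
}
test_line = {
    "audios": ["/path/to/clip.wav"],
    "solution": 2,
}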
docs/resources/kto_data.png
ADDED
Git LFS Details
docs/resources/web-ui-en.jpg
ADDED
Git LFS Details
docs/transformers/build/lib/transformers/models/clip/configuration_clip.py
ADDED
@@ -0,0 +1,422 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CLIP model configuration"""

from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Mapping, Optional


if TYPE_CHECKING:
    from ...processing_utils import ProcessorMixin
    from ...utils import TensorType

from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class CLIPTextConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`CLIPTextModel`]. It is used to instantiate a CLIP
    text encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the text encoder of the CLIP
    [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 49408):
            Vocabulary size of the CLIP text model. Defines the number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`CLIPModel`].
        hidden_size (`int`, *optional*, defaults to 512):
            Dimensionality of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 2048):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        projection_dim (`int`, *optional*, defaults to 512):
            Dimensionality of text and vision projection layers.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer encoder.
        max_position_embeddings (`int`, *optional*, defaults to 77):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the layer normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        initializer_factor (`float`, *optional*, defaults to 1.0):
            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
            testing).
        pad_token_id (`int`, *optional*, defaults to 1):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 49406):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 49407):
            End of stream token id.

    Example:

    ```python
    >>> from transformers import CLIPTextConfig, CLIPTextModel

    >>> # Initializing a CLIPTextConfig with openai/clip-vit-base-patch32 style configuration
    >>> configuration = CLIPTextConfig()

    >>> # Initializing a CLIPTextModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
    >>> model = CLIPTextModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "clip_text_model"
    base_config_key = "text_config"

    def __init__(
        self,
        vocab_size=49408,
        hidden_size=512,
        intermediate_size=2048,
        projection_dim=512,
        num_hidden_layers=12,
        num_attention_heads=8,
        max_position_embeddings=77,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-5,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        # This differs from `CLIPTokenizer`'s default and from openai/clip
        # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
        pad_token_id=1,
        bos_token_id=49406,
        eos_token_id=49407,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.max_position_embeddings = max_position_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
        self.attention_dropout = attention_dropout


class CLIPVisionConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`CLIPVisionModel`]. It is used to instantiate a
|
| 137 |
+
CLIP vision encoder according to the specified arguments, defining the model architecture. Instantiating a
|
| 138 |
+
configuration with the defaults will yield a similar configuration to that of the vision encoder of the CLIP
|
| 139 |
+
[openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
|
| 140 |
+
|
| 141 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 142 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
| 146 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 147 |
+
intermediate_size (`int`, *optional*, defaults to 3072):
|
| 148 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
| 149 |
+
projection_dim (`int`, *optional*, defaults to 512):
|
| 150 |
+
Dimensionality of text and vision projection layers.
|
| 151 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
| 152 |
+
Number of hidden layers in the Transformer encoder.
|
| 153 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
| 154 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 155 |
+
num_channels (`int`, *optional*, defaults to 3):
|
| 156 |
+
The number of input channels.
|
| 157 |
+
image_size (`int`, *optional*, defaults to 224):
|
| 158 |
+
The size (resolution) of each image.
|
| 159 |
+
patch_size (`int`, *optional*, defaults to 32):
|
| 160 |
+
The size (resolution) of each patch.
|
| 161 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
|
| 162 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 163 |
+
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
|
| 164 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
|
| 165 |
+
The epsilon used by the layer normalization layers.
|
| 166 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 167 |
+
The dropout ratio for the attention probabilities.
|
| 168 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 169 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 170 |
+
initializer_factor (`float`, *optional*, defaults to 1.0):
|
| 171 |
+
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
|
| 172 |
+
testing).
|
| 173 |
+
|
| 174 |
+
Example:
|
| 175 |
+
|
| 176 |
+
```python
|
| 177 |
+
>>> from transformers import CLIPVisionConfig, CLIPVisionModel
|
| 178 |
+
|
| 179 |
+
>>> # Initializing a CLIPVisionConfig with openai/clip-vit-base-patch32 style configuration
|
| 180 |
+
>>> configuration = CLIPVisionConfig()
|
| 181 |
+
|
| 182 |
+
>>> # Initializing a CLIPVisionModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
|
| 183 |
+
>>> model = CLIPVisionModel(configuration)
|
| 184 |
+
|
| 185 |
+
>>> # Accessing the model configuration
|
| 186 |
+
>>> configuration = model.config
|
| 187 |
+
```"""
|
| 188 |
+
|
| 189 |
+
model_type = "clip_vision_model"
|
| 190 |
+
base_config_key = "vision_config"
|
| 191 |
+
|
| 192 |
+
def __init__(
|
| 193 |
+
self,
|
| 194 |
+
hidden_size=768,
|
| 195 |
+
intermediate_size=3072,
|
| 196 |
+
projection_dim=512,
|
| 197 |
+
num_hidden_layers=12,
|
| 198 |
+
num_attention_heads=12,
|
| 199 |
+
num_channels=3,
|
| 200 |
+
image_size=224,
|
| 201 |
+
patch_size=32,
|
| 202 |
+
hidden_act="quick_gelu",
|
| 203 |
+
layer_norm_eps=1e-5,
|
| 204 |
+
attention_dropout=0.0,
|
| 205 |
+
initializer_range=0.02,
|
| 206 |
+
initializer_factor=1.0,
|
| 207 |
+
**kwargs,
|
| 208 |
+
):
|
| 209 |
+
super().__init__(**kwargs)
|
| 210 |
+
|
| 211 |
+
self.hidden_size = hidden_size
|
| 212 |
+
self.intermediate_size = intermediate_size
|
| 213 |
+
self.projection_dim = projection_dim
|
| 214 |
+
self.num_hidden_layers = num_hidden_layers
|
| 215 |
+
self.num_attention_heads = num_attention_heads
|
| 216 |
+
self.num_channels = num_channels
|
| 217 |
+
self.patch_size = patch_size
|
| 218 |
+
self.image_size = image_size
|
| 219 |
+
self.initializer_range = initializer_range
|
| 220 |
+
self.initializer_factor = initializer_factor
|
| 221 |
+
self.attention_dropout = attention_dropout
|
| 222 |
+
self.layer_norm_eps = layer_norm_eps
|
| 223 |
+
self.hidden_act = hidden_act
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class CLIPConfig(PretrainedConfig):
|
| 227 |
+
r"""
|
| 228 |
+
[`CLIPConfig`] is the configuration class to store the configuration of a [`CLIPModel`]. It is used to instantiate
|
| 229 |
+
a CLIP model according to the specified arguments, defining the text model and vision model configs. Instantiating
|
| 230 |
+
a configuration with the defaults will yield a similar configuration to that of the CLIP
|
| 231 |
+
[openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
|
| 232 |
+
|
| 233 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 234 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 235 |
+
|
| 236 |
+
Args:
|
| 237 |
+
text_config (`dict`, *optional*):
|
| 238 |
+
Dictionary of configuration options used to initialize [`CLIPTextConfig`].
|
| 239 |
+
vision_config (`dict`, *optional*):
|
| 240 |
+
Dictionary of configuration options used to initialize [`CLIPVisionConfig`].
|
| 241 |
+
projection_dim (`int`, *optional*, defaults to 512):
|
| 242 |
+
Dimensionality of text and vision projection layers.
|
| 243 |
+
logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
|
| 244 |
+
The initial value of the *logit_scale* parameter. Default is used as per the original CLIP implementation.
|
| 245 |
+
kwargs (*optional*):
|
| 246 |
+
Dictionary of keyword arguments.
|
| 247 |
+
|
| 248 |
+
Example:
|
| 249 |
+
|
| 250 |
+
```python
|
| 251 |
+
>>> from transformers import CLIPConfig, CLIPModel
|
| 252 |
+
|
| 253 |
+
>>> # Initializing a CLIPConfig with openai/clip-vit-base-patch32 style configuration
|
| 254 |
+
>>> configuration = CLIPConfig()
|
| 255 |
+
|
| 256 |
+
>>> # Initializing a CLIPModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
|
| 257 |
+
>>> model = CLIPModel(configuration)
|
| 258 |
+
|
| 259 |
+
>>> # Accessing the model configuration
|
| 260 |
+
>>> configuration = model.config
|
| 261 |
+
|
| 262 |
+
>>> # We can also initialize a CLIPConfig from a CLIPTextConfig and a CLIPVisionConfig
|
| 263 |
+
>>> from transformers import CLIPTextConfig, CLIPVisionConfig
|
| 264 |
+
|
| 265 |
+
>>> # Initializing a CLIPText and CLIPVision configuration
|
| 266 |
+
>>> config_text = CLIPTextConfig()
|
| 267 |
+
>>> config_vision = CLIPVisionConfig()
|
| 268 |
+
|
| 269 |
+
>>> config = CLIPConfig.from_text_vision_configs(config_text, config_vision)
|
| 270 |
+
```"""
|
| 271 |
+
|
| 272 |
+
model_type = "clip"
|
| 273 |
+
sub_configs = {"text_config": CLIPTextConfig, "vision_config": CLIPVisionConfig}
|
| 274 |
+
|
| 275 |
+
def __init__(
|
| 276 |
+
self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
|
| 277 |
+
):
|
| 278 |
+
# If `_config_dict` exist, we use them for the backward compatibility.
|
| 279 |
+
# We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot
|
| 280 |
+
# of confusion!).
|
| 281 |
+
text_config_dict = kwargs.pop("text_config_dict", None)
|
| 282 |
+
vision_config_dict = kwargs.pop("vision_config_dict", None)
|
| 283 |
+
|
| 284 |
+
super().__init__(**kwargs)
|
| 285 |
+
|
| 286 |
+
# Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in
|
| 287 |
+
# `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be same in most
|
| 288 |
+
# cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`.
|
| 289 |
+
if text_config_dict is not None:
|
| 290 |
+
if text_config is None:
|
| 291 |
+
text_config = {}
|
| 292 |
+
|
| 293 |
+
# This is the complete result when using `text_config_dict`.
|
| 294 |
+
_text_config_dict = CLIPTextConfig(**text_config_dict).to_dict()
|
| 295 |
+
|
| 296 |
+
# Give a warning if the values exist in both `_text_config_dict` and `text_config` but being different.
|
| 297 |
+
for key, value in _text_config_dict.items():
|
| 298 |
+
if key in text_config and value != text_config[key] and key not in ["transformers_version"]:
|
| 299 |
+
# If specified in `text_config_dict`
|
| 300 |
+
if key in text_config_dict:
|
| 301 |
+
message = (
|
| 302 |
+
f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. "
|
| 303 |
+
f'The value `text_config_dict["{key}"]` will be used instead.'
|
| 304 |
+
)
|
| 305 |
+
# If inferred from default argument values (just to be super careful)
|
| 306 |
+
else:
|
| 307 |
+
message = (
|
| 308 |
+
f"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The "
|
| 309 |
+
f'value `text_config["{key}"]` will be overridden.'
|
| 310 |
+
)
|
| 311 |
+
logger.info(message)
|
| 312 |
+
|
| 313 |
+
# Update all values in `text_config` with the ones in `_text_config_dict`.
|
| 314 |
+
text_config.update(_text_config_dict)
|
| 315 |
+
|
| 316 |
+
if vision_config_dict is not None:
|
| 317 |
+
if vision_config is None:
|
| 318 |
+
vision_config = {}
|
| 319 |
+
|
| 320 |
+
# This is the complete result when using `vision_config_dict`.
|
| 321 |
+
_vision_config_dict = CLIPVisionConfig(**vision_config_dict).to_dict()
|
| 322 |
+
# convert keys to string instead of integer
|
| 323 |
+
if "id2label" in _vision_config_dict:
|
| 324 |
+
_vision_config_dict["id2label"] = {
|
| 325 |
+
str(key): value for key, value in _vision_config_dict["id2label"].items()
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
# Give a warning if the values exist in both `_vision_config_dict` and `vision_config` but being different.
|
| 329 |
+
for key, value in _vision_config_dict.items():
|
| 330 |
+
if key in vision_config and value != vision_config[key] and key not in ["transformers_version"]:
|
| 331 |
+
# If specified in `vision_config_dict`
|
| 332 |
+
if key in vision_config_dict:
|
| 333 |
+
message = (
|
| 334 |
+
f"`{key}` is found in both `vision_config_dict` and `vision_config` but with different "
|
| 335 |
+
f'values. The value `vision_config_dict["{key}"]` will be used instead.'
|
| 336 |
+
)
|
| 337 |
+
# If inferred from default argument values (just to be super careful)
|
| 338 |
+
else:
|
| 339 |
+
message = (
|
| 340 |
+
f"`vision_config_dict` is provided which will be used to initialize `CLIPVisionConfig`. "
|
| 341 |
+
f'The value `vision_config["{key}"]` will be overridden.'
|
| 342 |
+
)
|
| 343 |
+
logger.info(message)
|
| 344 |
+
|
| 345 |
+
# Update all values in `vision_config` with the ones in `_vision_config_dict`.
|
| 346 |
+
vision_config.update(_vision_config_dict)
|
| 347 |
+
|
| 348 |
+
if text_config is None:
|
| 349 |
+
text_config = {}
|
| 350 |
+
logger.info("`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.")
|
| 351 |
+
|
| 352 |
+
if vision_config is None:
|
| 353 |
+
vision_config = {}
|
| 354 |
+
logger.info("`vision_config` is `None`. initializing the `CLIPVisionConfig` with default values.")
|
| 355 |
+
|
| 356 |
+
self.text_config = CLIPTextConfig(**text_config)
|
| 357 |
+
self.vision_config = CLIPVisionConfig(**vision_config)
|
| 358 |
+
|
| 359 |
+
self.projection_dim = projection_dim
|
| 360 |
+
self.logit_scale_init_value = logit_scale_init_value
|
| 361 |
+
self.initializer_factor = 1.0
|
| 362 |
+
|
| 363 |
+
@classmethod
|
| 364 |
+
def from_text_vision_configs(cls, text_config: CLIPTextConfig, vision_config: CLIPVisionConfig, **kwargs):
|
| 365 |
+
r"""
|
| 366 |
+
Instantiate a [`CLIPConfig`] (or a derived class) from clip text model configuration and clip vision model
|
| 367 |
+
configuration.
|
| 368 |
+
|
| 369 |
+
Returns:
|
| 370 |
+
[`CLIPConfig`]: An instance of a configuration object
|
| 371 |
+
"""
|
| 372 |
+
|
| 373 |
+
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
class CLIPOnnxConfig(OnnxConfig):
|
| 377 |
+
@property
|
| 378 |
+
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
| 379 |
+
return OrderedDict(
|
| 380 |
+
[
|
| 381 |
+
("input_ids", {0: "batch", 1: "sequence"}),
|
| 382 |
+
("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
|
| 383 |
+
("attention_mask", {0: "batch", 1: "sequence"}),
|
| 384 |
+
]
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
@property
|
| 388 |
+
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
| 389 |
+
return OrderedDict(
|
| 390 |
+
[
|
| 391 |
+
("logits_per_image", {0: "batch"}),
|
| 392 |
+
("logits_per_text", {0: "batch"}),
|
| 393 |
+
("text_embeds", {0: "batch"}),
|
| 394 |
+
("image_embeds", {0: "batch"}),
|
| 395 |
+
]
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
@property
|
| 399 |
+
def atol_for_validation(self) -> float:
|
| 400 |
+
return 1e-4
|
| 401 |
+
|
| 402 |
+
def generate_dummy_inputs(
|
| 403 |
+
self,
|
| 404 |
+
processor: "ProcessorMixin",
|
| 405 |
+
batch_size: int = -1,
|
| 406 |
+
seq_length: int = -1,
|
| 407 |
+
framework: Optional["TensorType"] = None,
|
| 408 |
+
) -> Mapping[str, Any]:
|
| 409 |
+
text_input_dict = super().generate_dummy_inputs(
|
| 410 |
+
processor.tokenizer, batch_size=batch_size, seq_length=seq_length, framework=framework
|
| 411 |
+
)
|
| 412 |
+
image_input_dict = super().generate_dummy_inputs(
|
| 413 |
+
processor.image_processor, batch_size=batch_size, framework=framework
|
| 414 |
+
)
|
| 415 |
+
return {**text_input_dict, **image_input_dict}
|
| 416 |
+
|
| 417 |
+
@property
|
| 418 |
+
def default_onnx_opset(self) -> int:
|
| 419 |
+
return 14
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
__all__ = ["CLIPConfig", "CLIPOnnxConfig", "CLIPTextConfig", "CLIPVisionConfig"]
|
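`CLIPOnnxConfig` is the one class in this file without a doctest, so here is a minimal usage sketch. It assumes the standard `OnnxConfig(config)` constructor from `transformers.onnx`; the printed axes follow the `inputs`/`outputs` properties defined above.

```python
from transformers import CLIPConfig
from transformers.models.clip.configuration_clip import CLIPOnnxConfig

# minimal sketch, assuming the usual OnnxConfig(config) constructor
onnx_config = CLIPOnnxConfig(CLIPConfig())
print(dict(onnx_config.inputs))        # input_ids / pixel_values / attention_mask with dynamic axes
print(dict(onnx_config.outputs))       # logits_per_image / logits_per_text / text_embeds / image_embeds
print(onnx_config.default_onnx_opset)  # 14
```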
docs/transformers/build/lib/transformers/models/clip/feature_extraction_clip.py
ADDED
@@ -0,0 +1,38 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature extractor class for CLIP."""

import warnings

from ...utils import logging
from ...utils.import_utils import requires
from .image_processing_clip import CLIPImageProcessor


logger = logging.get_logger(__name__)


@requires(backends=("vision",))
class CLIPFeatureExtractor(CLIPImageProcessor):
    def __init__(self, *args, **kwargs) -> None:
        warnings.warn(
            "The class CLIPFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please"
            " use CLIPImageProcessor instead.",
            FutureWarning,
        )
        super().__init__(*args, **kwargs)


__all__ = ["CLIPFeatureExtractor"]
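As the class body shows, `CLIPFeatureExtractor` is a pure deprecation shim over `CLIPImageProcessor`: instantiating it only emits a `FutureWarning` and then defers everything to the parent. A minimal sketch of observing that behavior:

```python
import warnings

from transformers import CLIPFeatureExtractor, CLIPImageProcessor

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    extractor = CLIPFeatureExtractor()

assert isinstance(extractor, CLIPImageProcessor)  # same behavior and attributes as the parent
assert caught[0].category is FutureWarning
```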
docs/transformers/build/lib/transformers/models/clip/image_processing_clip.py
ADDED
@@ -0,0 +1,350 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for CLIP."""

from typing import Dict, List, Optional, Union

import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
    convert_to_rgb,
    get_resize_output_image_size,
    resize,
    to_channel_dimension_format,
)
from ...image_utils import (
    OPENAI_CLIP_MEAN,
    OPENAI_CLIP_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    infer_channel_dimension_format,
    is_scaled_image,
    make_flat_list_of_images,
    to_numpy_array,
    valid_images,
    validate_kwargs,
    validate_preprocess_arguments,
)
from ...utils import TensorType, is_vision_available, logging
from ...utils.import_utils import requires


logger = logging.get_logger(__name__)


if is_vision_available():
    import PIL


@requires(backends=("vision",))
class CLIPImageProcessor(BaseImageProcessor):
    r"""
    Constructs a CLIP image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
            `do_resize` in the `preprocess` method.
        size (`Dict[str, int]`, *optional*, defaults to `{"shortest_edge": 224}`):
            Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
            the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
            method.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
            Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
        do_center_crop (`bool`, *optional*, defaults to `True`):
            Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
            `preprocess` method.
        crop_size (`Dict[str, int]`, *optional*, defaults to 224):
            Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
            method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
            the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
            method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_center_crop: bool = True,
        crop_size: Dict[str, int] = None,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"shortest_edge": 224}
        size = get_size_dict(size, default_to_square=False)
        crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
        crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")

        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
        self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
        self.do_convert_rgb = do_convert_rgb
        self._valid_processor_keys = [
            "images",
            "do_resize",
            "size",
            "resample",
            "do_center_crop",
            "crop_size",
            "do_rescale",
            "rescale_factor",
            "do_normalize",
            "image_mean",
            "image_std",
            "do_convert_rgb",
            "return_tensors",
            "data_format",
            "input_data_format",
        ]

        # for backwards compatibility of KOSMOS-2
        if "use_square_size" in kwargs and kwargs["use_square_size"]:
            self.size = {"height": size["shortest_edge"], "width": size["shortest_edge"]}
            # Let's remove `use_square_size` (as it is removed from #27690), so the future Kosmos-2 image processors
            # won't have this attr. being saved. (Otherwise, it will enter this if branch while there is no more
            # `shortest_edge` key.)
            delattr(self, "use_square_size")

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
        resized to keep the input aspect ratio.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Size of the output image.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                Resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        default_to_square = True
        if "shortest_edge" in size:
            size = size["shortest_edge"]
            default_to_square = False
        elif "height" in size and "width" in size:
            size = (size["height"], size["width"])
        else:
            raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")

        output_size = get_resize_output_image_size(
            image,
            size=size,
            default_to_square=default_to_square,
            input_data_format=input_data_format,
        )
        return resize(
            image,
            size=output_size,
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )

    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_center_crop: Optional[bool] = None,
        crop_size: Optional[int] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
                the longest edge resized to keep the input aspect ratio.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
                Whether to center crop the image.
            crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
                Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
                `True`.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        size = get_size_dict(size, param_name="size", default_to_square=False)
        resample = resample if resample is not None else self.resample
        do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
        crop_size = crop_size if crop_size is not None else self.crop_size
        crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

        validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)

        images = make_flat_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )
        validate_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_center_crop=do_center_crop,
            crop_size=crop_size,
            do_resize=do_resize,
            size=size,
            resample=resample,
        )

        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if do_rescale and is_scaled_image(images[0]):
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        all_images = []
        for image in images:
            if do_resize:
                image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)

            if do_center_crop:
                image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)

            if do_rescale:
                image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

            if do_normalize:
                image = self.normalize(
                    image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
                )

            all_images.append(image)
        images = [
            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
            for image in all_images
        ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)


__all__ = ["CLIPImageProcessor"]
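To make the `preprocess` pipeline above concrete, a minimal sketch of the default path (resize shortest edge to 224, center-crop 224x224, rescale by 1/255, normalize with the OpenAI CLIP mean/std), assuming a NumPy HWC input:

```python
import numpy as np

from transformers import CLIPImageProcessor

processor = CLIPImageProcessor()  # defaults: shortest_edge=224 resize, 224x224 center crop
image = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)  # HWC uint8 stand-in

batch = processor(images=image, return_tensors="np")
print(batch["pixel_values"].shape)  # (1, 3, 224, 224), channels-first output
```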
docs/transformers/build/lib/transformers/models/clip/image_processing_clip_fast.py
ADDED
@@ -0,0 +1,42 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for CLIP."""

from ...image_processing_utils_fast import BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, BaseImageProcessorFast
from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling
from ...utils import add_start_docstrings


@add_start_docstrings(
    "Constructs a fast CLIP image processor.",
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
)
class CLIPImageProcessorFast(BaseImageProcessorFast):
    # To be checked against the slow image processor
    # None values left after checking can be removed
    resample = PILImageResampling.BICUBIC
    image_mean = OPENAI_CLIP_MEAN
    image_std = OPENAI_CLIP_STD
    size = {"shortest_edge": 224}
    default_to_square = False
    crop_size = {"height": 224, "width": 224}
    do_resize = True
    do_center_crop = True
    do_rescale = True
    do_normalize = True
    do_convert_rgb = True


__all__ = ["CLIPImageProcessorFast"]
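The fast processor declares only class-level defaults and inherits its whole pipeline from `BaseImageProcessorFast` (torch/torchvision-backed). A minimal sketch, assuming the defaults are exposed as attributes on both classes, showing that they intentionally mirror `CLIPImageProcessor`'s constructor defaults:

```python
from transformers import CLIPImageProcessor, CLIPImageProcessorFast  # fast variant needs torch/torchvision

slow = CLIPImageProcessor()
fast = CLIPImageProcessorFast()

assert fast.size == slow.size == {"shortest_edge": 224}
assert fast.crop_size == slow.crop_size == {"height": 224, "width": 224}
assert fast.image_mean == slow.image_mean  # OPENAI_CLIP_MEAN
```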
docs/transformers/build/lib/transformers/models/clip/modeling_clip.py
ADDED
@@ -0,0 +1,1473 @@
# coding=utf-8
# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch CLIP model."""

from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    can_return_tuple,
    logging,
    replace_return_docstrings,
    torch_int,
)
from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig


logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "CLIPConfig"
_CHECKPOINT_FOR_DOC = "openai/clip-vit-base-patch32"

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "openai/clip-vit-base-patch32"
_IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_0"


# contrastive loss function, adapted from
# https://sachinruk.github.io/blog/2021-03-07-clip.html
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))


def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
    caption_loss = contrastive_loss(similarity)
    image_loss = contrastive_loss(similarity.t())
    return (caption_loss + image_loss) / 2.0


def _get_vector_norm(tensor: torch.Tensor) -> torch.Tensor:
    """
    This method is equivalent to tensor.norm(p=2, dim=-1, keepdim=True) and is used to make the
    model `executorch` exportable. See issue https://github.com/pytorch/executorch/issues/3566
    """
    square_tensor = torch.pow(tensor, 2)
    sum_tensor = torch.sum(square_tensor, dim=-1, keepdim=True)
    normed_tensor = torch.pow(sum_tensor, 0.5)
    return normed_tensor

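
# Illustrative sketch (hypothetical demo, not part of the upstream module): how the
# loss helpers above fit together. The embeddings are random stand-ins, and 2.6592
# is the `logit_scale_init_value` default from `CLIPConfig`.
def _clip_loss_demo() -> torch.Tensor:
    image_embeds = torch.randn(4, 512)
    text_embeds = torch.randn(4, 512)
    image_embeds = image_embeds / _get_vector_norm(image_embeds)  # L2-normalize rows
    text_embeds = text_embeds / _get_vector_norm(text_embeds)
    logit_scale = torch.tensor(2.6592).exp()  # temperature, a learned parameter in CLIPModel
    logits_per_text = logit_scale * text_embeds @ image_embeds.t()  # (text, image) similarities
    return clip_loss(logits_per_text)  # symmetric cross-entropy over the matched diagonal
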
@dataclass
class CLIPVisionModelOutput(ModelOutput):
    """
    Base class for the vision model's outputs that also contains image embeddings obtained by pooling the last hidden
    states.

    Args:
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
            The image embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    image_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
class CLIPTextModelOutput(ModelOutput):
    """
    Base class for the text model's outputs that also contains a pooling of the last hidden states.

    Args:
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
            The text embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    text_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
class CLIPOutput(ModelOutput):
    """
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
            Contrastive loss for image-text similarity.
        logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
            similarity scores.
        logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
            The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
            similarity scores.
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
            The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPTextModel`].
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
            The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`].
        text_model_output (`BaseModelOutputWithPooling`):
            The output of the [`CLIPTextModel`].
        vision_model_output (`BaseModelOutputWithPooling`):
            The output of the [`CLIPVisionModel`].
    """

    loss: Optional[torch.FloatTensor] = None
    logits_per_image: Optional[torch.FloatTensor] = None
    logits_per_text: Optional[torch.FloatTensor] = None
    text_embeds: Optional[torch.FloatTensor] = None
    image_embeds: Optional[torch.FloatTensor] = None
    text_model_output: BaseModelOutputWithPooling = None
    vision_model_output: BaseModelOutputWithPooling = None
|
| 162 |
+
|
| 163 |
+
def to_tuple(self) -> Tuple[Any]:
|
| 164 |
+
return tuple(
|
| 165 |
+
self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
|
| 166 |
+
for k in self.keys()
|
| 167 |
+
)


class CLIPVisionEmbeddings(nn.Module):
    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """

        num_patches = embeddings.shape[1] - 1
        position_embedding = self.position_embedding.weight.unsqueeze(0)
        num_positions = position_embedding.shape[1] - 1

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embedding(self.position_ids)

        class_pos_embed = position_embedding[:, :1]
        patch_pos_embed = position_embedding[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
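
    # The resize above treats the (num_positions, dim) patch position table as a square
    # (sqrt, sqrt, dim) grid, bicubically interpolates it to the new
    # (height // patch_size, width // patch_size) grid, and re-prepends the class token's
    # position embedding, so a pretrained checkpoint can run at resolutions other than
    # `config.image_size`.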

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
        batch_size, _, height, width = pixel_values.shape
        if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
            )
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embedding(self.position_ids)
        return embeddings
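
# Shape walkthrough for the vision embeddings: pixel_values (B, C, H, W) -> Conv2d with
# kernel_size = stride = patch_size -> (B, embed_dim, H/ps, W/ps) -> flatten(2).transpose(1, 2)
# -> (B, num_patches, embed_dim); a learned class token is then prepended, giving the
# (B, num_patches + 1, embed_dim) sequence that the encoder consumes.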


class CLIPTextEmbeddings(nn.Module):
    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        embed_dim = config.hidden_size

        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
        max_position_embedding = self.position_embedding.weight.shape[0]

        if seq_length > max_position_embedding:
            raise ValueError(
                f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
                f"{seq_length} and max_position_embeddings: {max_position_embedding})"
            )

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        embeddings = inputs_embeds + position_embeddings

        return embeddings
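
# The text side uses learned absolute position embeddings only: the output is simply
# token_embedding(input_ids) + position_embedding(position_ids) over the first `seq_length`
# positions, with longer sequences rejected by the check above.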


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    output_attentions: bool = True,
    **kwargs,
):
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()
    if not output_attentions:
        attn_weights = None
    return attn_output, attn_weights
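
# This is standard scaled dot-product attention, softmax(Q @ K^T * scaling + mask) @ V, on
# (batch, num_heads, seq_len, head_dim) tensors. The caller passes scaling = head_dim ** -0.5,
# and the additive mask already contains large negative values at masked positions; the
# softmax is computed in float32 for numerical stability before casting back to the query dtype.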


class CLIPAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Union[CLIPVisionConfig, CLIPTextConfig]):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""

        batch_size, seq_length, embed_dim = hidden_states.shape

        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)
        values = self.v_proj(hidden_states)

        queries = queries.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
        keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
        values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
        # CLIP text model uses both `causal_attention_mask` and `attention_mask`
        # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
        if self.config._attn_implementation == "flash_attention_2":
            self.is_causal = causal_attention_mask is not None
        else:
            if attention_mask is not None and causal_attention_mask is not None:
                attention_mask = attention_mask + causal_attention_mask
            elif causal_attention_mask is not None:
                attention_mask = causal_attention_mask

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            dropout=0.0 if not self.training else self.dropout,
            output_attentions=output_attentions,
        )

        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights


class CLIPMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class CLIPEncoderLayer(nn.Module):
    def __init__(self, config: Union[CLIPVisionConfig, CLIPTextConfig]):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = CLIPAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = CLIPMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs
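
# Each encoder layer is a pre-LayerNorm Transformer block: with x the incoming hidden states,
#     x = x + self_attn(layer_norm1(x))
#     x = x + mlp(layer_norm2(x))
# Normalizing before each sub-layer (rather than after) keeps the residual path unnormalized,
# matching the original CLIP ViT.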


class CLIPPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = CLIPConfig
    base_model_prefix = "clip"
    supports_gradient_checkpointing = True
    _supports_sdpa = True
    _supports_flash_attn_2 = True

    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_factor
        if isinstance(module, CLIPTextEmbeddings):
            module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
            module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
        elif isinstance(module, CLIPVisionEmbeddings):
            factor = self.config.initializer_factor
            nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
            nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
            nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
        elif isinstance(module, CLIPAttention):
            factor = self.config.initializer_factor
            in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            out_proj_std = (module.embed_dim**-0.5) * factor
            nn.init.normal_(module.q_proj.weight, std=in_proj_std)
            nn.init.normal_(module.k_proj.weight, std=in_proj_std)
            nn.init.normal_(module.v_proj.weight, std=in_proj_std)
            nn.init.normal_(module.out_proj.weight, std=out_proj_std)
        elif isinstance(module, CLIPMLP):
            factor = self.config.initializer_factor
            in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
            nn.init.normal_(module.fc1.weight, std=fc_std)
            nn.init.normal_(module.fc2.weight, std=in_proj_std)
        elif isinstance(module, CLIPModel):
            nn.init.normal_(
                module.text_projection.weight,
                std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
            )
            nn.init.normal_(
                module.visual_projection.weight,
                std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
            )
        elif isinstance(module, CLIPVisionModelWithProjection):
            nn.init.normal_(
                module.visual_projection.weight,
                std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
            )
        elif isinstance(module, CLIPTextModelWithProjection):
            nn.init.normal_(
                module.text_projection.weight,
                std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
            )
        elif isinstance(module, CLIPForImageClassification):
            nn.init.normal_(
                module.classifier.weight,
                std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor,
            )

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
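
# The initialization mirrors the original OpenAI CLIP scheme: projection std values shrink
# with width (embed_dim ** -0.5) and, for the residual branches, with depth
# ((2 * num_hidden_layers) ** -0.5), while `initializer_factor` provides a single global knob
# (1.0 by default) for scaling every std at once.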


CLIP_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

CLIP_TEXT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

CLIP_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
            Whether to interpolate the pre-trained position encodings.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

CLIP_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
            Whether to interpolate the pre-trained position encodings.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


class CLIPEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`CLIPEncoderLayer`].

    Args:
        config: CLIPConfig
    """

    def __init__(self, config: CLIPConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    @can_return_tuple
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> BaseModelOutput:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Causal mask for the text model. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )


class CLIPTextTransformer(nn.Module):
    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = CLIPTextEmbeddings(config)
        self.encoder = CLIPEncoder(config)
        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        # For `pooled_output` computation
        self.eos_token_id = config.eos_token_id

        # For attention mask, it differs between `flash_attention_2` and other attention implementations
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

    @can_return_tuple
    @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> BaseModelOutputWithPooling:
        r"""
        Returns:

        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        # CLIP's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_attention_mask = _create_4d_causal_attention_mask(
            input_shape, hidden_states.dtype, device=hidden_states.device
        )

        # expand attention_mask
        if attention_mask is not None and not self._use_flash_attention_2:
            # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)

        encoder_outputs: BaseModelOutput = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        last_hidden_state = encoder_outputs.last_hidden_state
        last_hidden_state = self.final_layer_norm(last_hidden_state)

        if self.eos_token_id == 2:
            # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
            # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
            # ------------------------------------------------------------
            # text_embeds.shape = [batch_size, sequence_length, transformer.width]
            # take features from the eot embedding (eot_token is the highest number in each sequence)
            # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
            pooled_output = last_hidden_state[
                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
                input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
            ]
        else:
            # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
            pooled_output = last_hidden_state[
                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
                # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
                # Note: we assume each sequence (along batch dim.) contains an `eos_token_id` (e.g. prepared by the tokenizer)
                (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
                .int()
                .argmax(dim=-1),
            ]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
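
# Text pooling selects a single token per sequence: the (already layer-normed) hidden state at
# the EOS position. The `.int().argmax(dim=-1)` trick returns the first index at which the EOS
# condition holds, which matters when `pad_token_id` equals `eos_token_id`; that one vector is
# what the projection head later maps into the shared image-text embedding space.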


@add_start_docstrings(
    """The text model from CLIP without any head or projection on top.""",
    CLIP_START_DOCSTRING,
)
class CLIPTextModel(CLIPPreTrainedModel):
    config_class = CLIPTextConfig

    _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config)
        self.text_model = CLIPTextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value):
        self.text_model.embeddings.token_embedding = value

    @can_return_tuple
    @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> BaseModelOutputWithPooling:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, CLIPTextModel

        >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```"""

        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )


class CLIPVisionTransformer(nn.Module):
    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.encoder = CLIPEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    @can_return_tuple
    @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = False,
    ) -> BaseModelOutputWithPooling:
        r"""
        Returns:

        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
        hidden_states = self.pre_layrnorm(hidden_states)

        encoder_outputs: BaseModelOutput = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        last_hidden_state = encoder_outputs.last_hidden_state
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.post_layernorm(pooled_output)

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
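
# Note: the attribute name `pre_layrnorm` (sic) is a long-standing misspelling kept as-is,
# since it matches the parameter names stored in released CLIP checkpoints and renaming it
# would break `from_pretrained` weight loading. Pooling here takes the class-token hidden
# state (position 0) and passes it through `post_layernorm`.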


@add_start_docstrings(
    """The vision model from CLIP without any head or projection on top.""",
    CLIP_START_DOCSTRING,
)
class CLIPVisionModel(CLIPPreTrainedModel):
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"
    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPVisionConfig):
        super().__init__(config)
        self.vision_model = CLIPVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    @can_return_tuple
    @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> BaseModelOutputWithPooling:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, CLIPVisionModel

        >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled CLS states
        ```"""

        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )


@add_start_docstrings(CLIP_START_DOCSTRING)
class CLIPModel(CLIPPreTrainedModel):
    config_class = CLIPConfig
    _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer", "CLIPVisionEmbeddings"]

    def __init__(self, config: CLIPConfig):
        super().__init__(config)

        if not isinstance(config.text_config, CLIPTextConfig):
            raise TypeError(
                "config.text_config is expected to be of type CLIPTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, CLIPVisionConfig):
            raise TypeError(
                "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        text_model = CLIPTextModel._from_config(text_config)
        self.text_model = text_model.text_model

        vision_model = CLIPVisionModel._from_config(vision_config)
        self.vision_model = vision_model.vision_model

        self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
        self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`CLIPTextModel`].

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, CLIPModel

        >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
        >>> text_features = model.get_text_features(**inputs)
        ```"""
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        text_outputs: BaseModelOutputWithPooling = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        pooled_output = text_outputs.pooler_output
        text_features = self.text_projection(pooled_output)

        return text_features

    @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`CLIPVisionModel`].

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, CLIPModel

        >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> image_features = model.get_image_features(**inputs)
        ```"""
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        vision_outputs: BaseModelOutputWithPooling = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        pooled_output = vision_outputs.pooler_output
        image_features = self.visual_projection(pooled_output)

        return image_features

    @can_return_tuple
    @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CLIPOutput, config_class=CLIPConfig)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> CLIPOutput:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, CLIPModel

        >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(
        ...     text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
        ... )

        >>> outputs = model(**inputs)
        >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
        >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
        ```"""
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        vision_outputs: BaseModelOutputWithPooling = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        text_outputs: BaseModelOutputWithPooling = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        image_embeds = vision_outputs.pooler_output
        image_embeds = self.visual_projection(image_embeds)

        text_embeds = text_outputs.pooler_output
        text_embeds = self.text_projection(text_embeds)

        # normalized features
        image_embeds = image_embeds / _get_vector_norm(image_embeds)
        text_embeds = text_embeds / _get_vector_norm(text_embeds)

        # cosine similarity as logits
        logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))
        logits_per_text = logits_per_text * self.logit_scale.exp().to(text_embeds.device)

        logits_per_image = logits_per_text.t()

        loss = None
        if return_loss:
            loss = clip_loss(logits_per_text)

        return CLIPOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
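
# Since both embedding sets are L2-normalized above, logits_per_text[i, j] equals
# exp(logit_scale) * cosine_similarity(text_i, image_j). The learned `logit_scale` acts as an
# inverse temperature; with the usual init value log(1 / 0.07) ≈ 2.6593 it starts at
# exp(logit_scale) ≈ 14.3, sharpening the softmax over candidates during contrastive training.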


@add_start_docstrings(
    """
    CLIP Text Model with a projection layer on top (a linear layer on top of the pooled output).
    """,
    CLIP_START_DOCSTRING,
)
class CLIPTextModelWithProjection(CLIPPreTrainedModel):
    config_class = CLIPTextConfig

    _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config)

        text_model = CLIPTextModel._from_config(config)
        self.text_model = text_model.text_model

        self.text_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value):
        self.text_model.embeddings.token_embedding = value

    @can_return_tuple
    @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CLIPTextModelOutput, config_class=CLIPTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> CLIPTextModelOutput:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, CLIPTextModelWithProjection

        >>> model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> text_embeds = outputs.text_embeds
        ```"""

        text_outputs: BaseModelOutputWithPooling = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        pooled_output = text_outputs.pooler_output
        text_embeds = self.text_projection(pooled_output)

        return CLIPTextModelOutput(
            text_embeds=text_embeds,
            last_hidden_state=text_outputs.last_hidden_state,
            hidden_states=text_outputs.hidden_states,
            attentions=text_outputs.attentions,
        )
|
| 1293 |
+
|
| 1294 |
+
|
| 1295 |
+
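# Note that the projected `text_embeds` returned above are unnormalized; for
# retrieval they are usually L2-normalized first. A minimal sketch, reusing
# `outputs` from the docstring example:
#
#     text_embeds = outputs.text_embeds
#     text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
#     sims = text_embeds @ text_embeds.t()  # pairwise cosine similarities
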
@add_start_docstrings(
    """
    CLIP Vision Model with a projection layer on top (a linear layer on top of the pooled output).
    """,
    CLIP_START_DOCSTRING,
)
class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: CLIPVisionConfig):
        super().__init__(config)

        vision_model = CLIPVisionModel._from_config(config)
        self.vision_model = vision_model.vision_model

        self.visual_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    @can_return_tuple
    @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CLIPVisionModelOutput, config_class=CLIPVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> CLIPVisionModelOutput:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, CLIPVisionModelWithProjection

        >>> model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> image_embeds = outputs.image_embeds
        ```"""
        vision_outputs: BaseModelOutputWithPooling = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        pooled_output = vision_outputs.pooler_output
        image_embeds = self.visual_projection(pooled_output)

        return CLIPVisionModelOutput(
            image_embeds=image_embeds,
            last_hidden_state=vision_outputs.last_hidden_state,
            hidden_states=vision_outputs.hidden_states,
            attentions=vision_outputs.attentions,
        )

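# Pairing the two projection models reproduces the similarity computed by
# `CLIPModel`. A sketch assuming `text_embeds` and `image_embeds` were produced
# by the two docstring examples above (100.0 stands in for logit_scale.exp()):
#
#     t = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
#     v = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
#     logits_per_text = 100.0 * t @ v.t()
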
@add_start_docstrings(
    """
    CLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final
    hidden states of the patch tokens) e.g. for ImageNet.
    """,
    CLIP_START_DOCSTRING,
)
class CLIPForImageClassification(CLIPPreTrainedModel):
    main_input_name = "pixel_values"

    def __init__(self, config: CLIPConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        vision_model = CLIPVisionModel._from_config(config.vision_config)
        self.vision_model = vision_model.vision_model

        # Classifier head
        self.classifier = (
            nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> ImageClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs: BaseModelOutputWithPooling = self.vision_model(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        sequence_output = outputs.last_hidden_state

        # average pool the patch tokens
        sequence_output = torch.mean(sequence_output[:, 1:, :], dim=1)
        # apply classifier
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

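# The classification head mean-pools the patch tokens, skipping the class
# token at index 0, before applying the linear classifier. A toy-shape sketch
# (the hidden size 768, 49 patches, and 1000 labels are illustrative):
#
#     import torch
#     last_hidden_state = torch.randn(2, 50, 768)       # (batch, 1 + patches, hidden)
#     pooled = last_hidden_state[:, 1:, :].mean(dim=1)  # (2, 768)
#     logits = torch.nn.Linear(768, 1000)(pooled)       # (2, 1000)
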
__all__ = [
    "CLIPModel",
    "CLIPPreTrainedModel",
    "CLIPTextModel",
    "CLIPTextModelWithProjection",
    "CLIPVisionModel",
    "CLIPVisionModelWithProjection",
    "CLIPForImageClassification",
]
docs/transformers/build/lib/transformers/models/clip/modeling_flax_clip.py ADDED
@@ -0,0 +1,1306 @@
# coding=utf-8
# Copyright 2021 The OpenAI Team Authors, The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional, Tuple, Union

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax

from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling
from ...modeling_flax_utils import (
    ACT2FN,
    FlaxPreTrainedModel,
    append_replace_return_docstrings,
    overwrite_call_docstring,
)
from ...utils import ModelOutput, add_start_docstrings, logging
from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig


logger = logging.get_logger(__name__)

CLIP_START_DOCSTRING = r"""

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading, saving and converting weights from PyTorch models).

    This model is also a
    [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
    a regular Flax linen Module and refer to the Flax documentation for all matters related to general usage and
    behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
"""

CLIP_TEXT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

CLIP_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

CLIP_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

@flax.struct.dataclass
class FlaxCLIPTextModelOutput(ModelOutput):
    """
    Base class for text model's outputs that also contains a pooling of the last hidden states.

    Args:
        text_embeds (`jnp.ndarray` of shape `(batch_size, output_dim)`):
            The text embeddings obtained by applying the projection layer to the pooled output of
            [`FlaxCLIPTextModel`].
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    text_embeds: jnp.ndarray = None
    last_hidden_state: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray, ...]] = None
    attentions: Optional[Tuple[jnp.ndarray, ...]] = None


@flax.struct.dataclass
class FlaxCLIPOutput(ModelOutput):
    """
    Args:
        logits_per_image (`jnp.ndarray` of shape `(image_batch_size, text_batch_size)`):
            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
            similarity scores.
        logits_per_text (`jnp.ndarray` of shape `(text_batch_size, image_batch_size)`):
            The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
            similarity scores.
        text_embeds (`jnp.ndarray` of shape `(batch_size, output_dim)`):
            The text embeddings obtained by applying the projection layer to the pooled output of
            [`FlaxCLIPTextModel`].
        image_embeds (`jnp.ndarray` of shape `(batch_size, output_dim)`):
            The image embeddings obtained by applying the projection layer to the pooled output of
            [`FlaxCLIPVisionModel`].
        text_model_output (`FlaxBaseModelOutputWithPooling`):
            The output of the [`FlaxCLIPTextModel`].
        vision_model_output (`FlaxBaseModelOutputWithPooling`):
            The output of the [`FlaxCLIPVisionModel`].
    """

    logits_per_image: jnp.ndarray = None
    logits_per_text: jnp.ndarray = None
    text_embeds: jnp.ndarray = None
    image_embeds: jnp.ndarray = None
    text_model_output: FlaxBaseModelOutputWithPooling = None
    vision_model_output: FlaxBaseModelOutputWithPooling = None

    def to_tuple(self) -> Tuple[Any]:
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )


class FlaxCLIPVisionEmbeddings(nn.Module):
    config: CLIPVisionConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        embed_dim = self.config.hidden_size
        image_size = self.config.image_size
        patch_size = self.config.patch_size

        self.class_embedding = self.param("class_embedding", jax.nn.initializers.normal(stddev=0.02), (embed_dim,))

        self.patch_embedding = nn.Conv(
            embed_dim,
            kernel_size=(patch_size, patch_size),
            strides=(patch_size, patch_size),
            padding="VALID",
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(),
        )

        self.num_patches = (image_size // patch_size) ** 2
        num_positions = self.num_patches + 1
        self.position_embedding = nn.Embed(num_positions, embed_dim, embedding_init=jax.nn.initializers.normal())
        self.position_ids = jnp.expand_dims(jnp.arange(0, num_positions, dtype="i4"), axis=0)

    def __call__(self, pixel_values):
        patch_embeds = self.patch_embedding(pixel_values)
        batch_size, height, width, channels = patch_embeds.shape
        patch_embeds = jnp.reshape(patch_embeds, (batch_size, height * width, channels))

        class_embeds = jnp.expand_dims(self.class_embedding, axis=(0, 1))
        class_embeds = jnp.tile(class_embeds, (batch_size, 1, 1))
        embeddings = jnp.concatenate([class_embeds, patch_embeds], axis=1)
        embeddings = embeddings + self.position_embedding(self.position_ids)
        return embeddings

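# Shape bookkeeping for the vision embeddings above, assuming the base-patch32
# configuration (image_size=224, patch_size=32, hidden_size=768):
#
#     num_patches = (224 // 32) ** 2   # 49
#     num_positions = num_patches + 1  # 50, including the class token
#     # pixel_values (1, 224, 224, 3) -> patch_embeds (1, 7, 7, 768)
#     #                               -> reshaped     (1, 49, 768)
#     #                               -> with CLS     (1, 50, 768)
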
class FlaxCLIPTextEmbeddings(nn.Module):
    config: CLIPTextConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        embed_dim = self.config.hidden_size

        self.token_embedding = nn.Embed(self.config.vocab_size, embed_dim, embedding_init=jax.nn.initializers.normal())
        self.position_embedding = nn.Embed(
            self.config.max_position_embeddings, embed_dim, embedding_init=jax.nn.initializers.normal()
        )
        self.position_ids = jnp.expand_dims(
            jnp.arange(0, self.config.max_position_embeddings, dtype="i4"), axis=(0, 1)
        )

    def __call__(self, input_ids, position_ids):
        input_embeds = self.token_embedding(input_ids.astype("i4"))
        position_embeds = self.position_embedding(position_ids.astype("i4"))

        embeddings = input_embeds + position_embeds
        return embeddings


class FlaxCLIPAttention(nn.Module):
    config: Union[CLIPTextConfig, CLIPVisionConfig]
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.embed_dim = self.config.hidden_size
        self.num_heads = self.config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = self.config.attention_dropout

        self.k_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))
        self.v_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))
        self.q_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))
        self.out_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))

        self.causal = isinstance(self.config, CLIPTextConfig)
        if self.causal:
            self.causal_mask = make_causal_mask(jnp.ones((1, self.config.max_position_embeddings), dtype="i4"))

    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
    ):
        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        causal_attention_mask = None
        if self.causal:
            query_length, key_length = query.shape[1], key.shape[1]
            causal_attention_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length]

        if attention_mask is not None and causal_attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
            attention_mask = combine_masks(attention_mask, causal_attention_mask, dtype="i4")
        elif causal_attention_mask is not None:
            attention_mask = causal_attention_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        if attention_mask is not None:
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query,
            key,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)

        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs

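# For the text tower (`self.causal` is True) the padding mask and the causal
# mask are ANDed together via `combine_masks` before being turned into an
# additive bias. A minimal sketch of the equivalent masking on toy sizes:
#
#     import jax.numpy as jnp
#     from flax.linen import combine_masks, make_causal_mask
#     causal = make_causal_mask(jnp.ones((1, 4), dtype="i4"))  # (1, 1, 4, 4)
#     padding = jnp.array([[1, 1, 1, 0]])                      # last token is padding
#     padding = jnp.expand_dims(padding, axis=(-3, -2))        # (1, 1, 1, 4)
#     mask = combine_masks(padding, causal, dtype="i4")        # (1, 1, 4, 4)
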
class FlaxCLIPMLP(nn.Module):
    config: Union[CLIPTextConfig, CLIPVisionConfig]
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.activation_fn = ACT2FN[self.config.hidden_act]
        self.fc1 = nn.Dense(
            self.config.intermediate_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(0.01),
        )
        self.fc2 = nn.Dense(self.config.hidden_size, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))

    def __call__(self, hidden_states):
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class FlaxCLIPEncoderLayer(nn.Module):
    config: Union[CLIPTextConfig, CLIPVisionConfig]
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.self_attn = FlaxCLIPAttention(self.config, dtype=self.dtype)
        self.layer_norm1 = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.mlp = FlaxCLIPMLP(self.config, dtype=self.dtype)
        self.layer_norm2 = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
    ):
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        attn_outputs = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        hidden_states = attn_outputs[0]
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += attn_outputs[1:]

        return outputs


class FlaxCLIPLayerCollection(nn.Module):
    config: Union[CLIPTextConfig, CLIPVisionConfig]
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.layers = [
            FlaxCLIPEncoderLayer(self.config, name=str(i), dtype=self.dtype)
            for i in range(self.config.num_hidden_layers)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = layer(
                hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions += (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        outputs = (hidden_states,)

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )


class FlaxCLIPEncoder(nn.Module):
    config: Union[CLIPTextConfig, CLIPVisionConfig]
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.layers = FlaxCLIPLayerCollection(self.config, dtype=self.dtype)

    def __call__(
        self,
        inputs_embeds,
        attention_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        return self.layers(
            hidden_states=inputs_embeds,
            attention_mask=attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


class FlaxCLIPTextTransformer(nn.Module):
    config: CLIPTextConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.embeddings = FlaxCLIPTextEmbeddings(self.config, dtype=self.dtype)
        self.encoder = FlaxCLIPEncoder(self.config, dtype=self.dtype)
        self.final_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

        # For `pooled_output` computation
        self.eos_token_id = self.config.eos_token_id

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.final_layer_norm(last_hidden_state)

        if self.eos_token_id == 2:
            # The `eos_token_id` was incorrect before PR #24773: let's keep what was done here.
            # A CLIP model with such an `eos_token_id` in the config can't work correctly with extra new tokens added
            # ------------------------------------------------------------
            # text_embeds.shape = [batch_size, sequence_length, transformer.width]
            # take features from the EOS embedding (eos_token_id is the highest number in each sequence)
            pooled_output = last_hidden_state[jnp.arange(last_hidden_state.shape[0]), input_ids.argmax(axis=-1)]
        else:
            # (no need to cast from bool to int after comparing to `eos_token_id`)
            pooled_output = last_hidden_state[
                jnp.arange(last_hidden_state.shape[0]), (input_ids == self.eos_token_id).argmax(axis=-1)
            ]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return FlaxBaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

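# The pooling above picks the hidden state at the EOS position. With the
# legacy config (eos_token_id == 2) the EOS id happens to be the largest id in
# the sequence, so `input_ids.argmax(-1)` finds it; otherwise the first
# position equal to `eos_token_id` is used. A toy illustration of the latter
# (49406/49407 are CLIP's BOS/EOS ids; the trailing 0 padding id is made up):
#
#     import jax.numpy as jnp
#     input_ids = jnp.array([[49406, 320, 1125, 49407, 0]])
#     eos_pos = (input_ids == 49407).argmax(axis=-1)  # -> [3]
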
class FlaxCLIPVisionTransformer(nn.Module):
    config: CLIPVisionConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.embeddings = FlaxCLIPVisionEmbeddings(self.config, dtype=self.dtype)
        self.pre_layrnorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.encoder = FlaxCLIPEncoder(self.config, dtype=self.dtype)
        self.post_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(
        self,
        pixel_values=None,
        deterministic: bool = True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict: bool = True,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.pre_layrnorm(hidden_states)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.post_layernorm(pooled_output)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return FlaxBaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

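# Unlike the text tower, the vision tower pools by taking the class-token
# hidden state (index 0) and layer-norming it. (`pre_layrnorm` is spelled this
# way in released checkpoints, so the attribute name is kept for weight
# loading.) A toy-shape sketch of the pooling:
#
#     import jax.numpy as jnp
#     last_hidden_state = jnp.zeros((2, 50, 768))
#     pooled = last_hidden_state[:, 0, :]  # (2, 768)
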
class FlaxCLIPTextPreTrainedModel(FlaxPreTrainedModel):
    config_class = CLIPTextConfig
    module_class: nn.Module = None

    def __init__(
        self,
        config: CLIPTextConfig,
        input_shape=(1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensor
        input_ids = jnp.zeros(input_shape, dtype="i4")
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        attention_mask = jnp.ones_like(input_ids)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(rngs, input_ids, attention_mask, position_ids)["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )

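# `init_weights` draws random parameters and, when a partial `params` tree is
# supplied, fills in only the keys recorded in `self._missing_keys`. A sketch
# of the flatten/merge/unflatten round-trip it relies on (`missing_key` and
# `flat_random` are hypothetical names):
#
#     from flax.core.frozen_dict import freeze, unfreeze
#     from flax.traverse_util import flatten_dict, unflatten_dict
#     flat = flatten_dict(unfreeze(params))         # {('text_model', ...): array, ...}
#     flat[missing_key] = flat_random[missing_key]  # copy in the missing leaf
#     params = freeze(unflatten_dict(flat))
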
class FlaxCLIPVisionPreTrainedModel(FlaxPreTrainedModel):
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"
    module_class: nn.Module = None

    def __init__(
        self,
        config: CLIPVisionConfig,
        input_shape: Optional[Tuple] = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        if input_shape is None:
            input_shape = (1, config.image_size, config.image_size, 3)
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensor
        pixel_values = jax.random.normal(rng, input_shape)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(rngs, pixel_values)["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    def __call__(
        self,
        pixel_values,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(pixel_values, dtype=jnp.float32),
            not train,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )

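# Flax's `nn.Conv` is channels-last, so the NCHW pixel values produced by the
# image processor are transposed to NHWC before `module.apply`:
#
#     import jax.numpy as jnp
#     pixel_values = jnp.zeros((1, 3, 224, 224))                # NCHW from the processor
#     pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))  # -> (1, 224, 224, 3)
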
class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
    config_class = CLIPConfig
    module_class: nn.Module = None

    def __init__(
        self,
        config: CLIPConfig,
        input_shape: Optional[Tuple] = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        if input_shape is None:
            input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensor
        input_ids = jnp.zeros(input_shape[0], dtype="i4")
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
        attention_mask = jnp.ones_like(input_ids)

        pixel_values = jax.random.normal(rng, input_shape[1])

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids)["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    def __call__(
        self,
        input_ids,
        pixel_values,
        attention_mask=None,
        position_ids=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(pixel_values, dtype=jnp.float32),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )

    def get_text_features(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train=False,
    ):
        r"""
        Args:
            input_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)

        Returns:
            text_features (`jnp.ndarray` of shape `(batch_size, output_dim)`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`FlaxCLIPTextModel`].

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, FlaxCLIPModel

        >>> model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="np")
        >>> text_features = model.get_text_features(**inputs)
        ```"""
        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        def _get_features(module, input_ids, attention_mask, position_ids, deterministic):
            text_outputs = module.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                deterministic=deterministic,
            )
            pooled_output = text_outputs[1]
            text_features = module.text_projection(pooled_output)
            return text_features

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            method=_get_features,
            rngs=rngs,
        )

def get_image_features(
|
| 933 |
+
self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train=False
|
| 934 |
+
):
|
| 935 |
+
r"""
|
| 936 |
+
Args:
|
| 937 |
+
pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
|
| 938 |
+
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained
|
| 939 |
+
using [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
| 940 |
+
|
| 941 |
+
Returns:
|
| 942 |
+
image_features (`jnp.ndarray` of shape `(batch_size, output_dim`): The image embeddings obtained by
|
| 943 |
+
applying the projection layer to the pooled output of [`FlaxCLIPVisionModel`]
|
| 944 |
+
|
| 945 |
+
Examples:
|
| 946 |
+
|
| 947 |
+
```python
|
| 948 |
+
>>> from PIL import Image
|
| 949 |
+
>>> import requests
|
| 950 |
+
>>> from transformers import AutoProcessor, FlaxCLIPModel
|
| 951 |
+
|
| 952 |
+
>>> model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
| 953 |
+
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
| 954 |
+
|
| 955 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 956 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 957 |
+
|
| 958 |
+
>>> inputs = processor(images=image, return_tensors="np")
|
| 959 |
+
|
| 960 |
+
>>> image_features = model.get_image_features(**inputs)
|
| 961 |
+
```"""
|
| 962 |
+
pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
|
| 963 |
+
|
| 964 |
+
# Handle any PRNG if needed
|
| 965 |
+
rngs = {}
|
| 966 |
+
if dropout_rng is not None:
|
| 967 |
+
rngs["dropout"] = dropout_rng
|
| 968 |
+
|
| 969 |
+
def _get_features(module, pixel_values, deterministic):
|
| 970 |
+
vision_outputs = module.vision_model(pixel_values=pixel_values, deterministic=deterministic)
|
| 971 |
+
pooled_output = vision_outputs[1] # pooled_output
|
| 972 |
+
image_features = module.visual_projection(pooled_output)
|
| 973 |
+
return image_features
|
| 974 |
+
|
| 975 |
+
return self.module.apply(
|
| 976 |
+
{"params": params or self.params},
|
| 977 |
+
jnp.array(pixel_values, dtype=jnp.float32),
|
| 978 |
+
not train,
|
| 979 |
+
method=_get_features,
|
| 980 |
+
rngs=rngs,
|
| 981 |
+
)
|
| 982 |
+
|
| 983 |
+
|
| 984 |
+
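# Editor's note (illustrative sketch, not part of the original source): `init_weights`
# above materializes dummy inputs from `input_shape`, whose default is a (1, 1) token
# batch plus one NHWC image of shape (1, image_size, image_size, 3), while `__call__`
# transposes incoming NCHW pixel values to NHWC before `module.apply`. Assuming the
# standard "openai/clip-vit-base-patch32" config (image_size=224):
#
# >>> import jax.numpy as jnp
# >>> pixel_values = jnp.zeros((2, 3, 224, 224))       # NCHW, as the image processor emits
# >>> jnp.transpose(pixel_values, (0, 2, 3, 1)).shape  # NHWC, as flax.linen.Conv expects
# (2, 224, 224, 3)
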
class FlaxCLIPTextModule(nn.Module):
    config: CLIPTextConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.text_model = FlaxCLIPTextTransformer(self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


class FlaxCLIPTextModel(FlaxCLIPTextPreTrainedModel):
    module_class = FlaxCLIPTextModule


FLAX_CLIP_TEXT_MODEL_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, FlaxCLIPTextModel

    >>> model = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
    >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

    >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="np")

    >>> outputs = model(**inputs)
    >>> last_hidden_state = outputs.last_hidden_state
    >>> pooler_output = outputs.pooler_output  # pooled (EOS token) states
    ```
"""

overwrite_call_docstring(FlaxCLIPTextModel, CLIP_TEXT_INPUTS_DOCSTRING + FLAX_CLIP_TEXT_MODEL_DOCSTRING)
append_replace_return_docstrings(
    FlaxCLIPTextModel, output_type=FlaxBaseModelOutputWithPooling, config_class=CLIPTextConfig
)

class FlaxCLIPTextModelWithProjectionModule(nn.Module):
    config: CLIPTextConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.text_model = FlaxCLIPTextTransformer(self.config, dtype=self.dtype)
        self.text_projection = nn.Dense(self.config.projection_dim, use_bias=False, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = text_outputs[1]
        text_embeds = self.text_projection(pooled_output)

        if not return_dict:
            return (text_embeds, text_outputs[0]) + text_outputs[2:]

        return FlaxCLIPTextModelOutput(
            text_embeds=text_embeds,
            last_hidden_state=text_outputs.last_hidden_state,
            hidden_states=text_outputs.hidden_states,
            attentions=text_outputs.attentions,
        )


class FlaxCLIPTextModelWithProjection(FlaxCLIPTextPreTrainedModel):
    module_class = FlaxCLIPTextModelWithProjectionModule


FLAX_CLIP_TEXT_MODEL_WITH_PROJECTION_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, FlaxCLIPTextModelWithProjection

    >>> model = FlaxCLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
    >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

    >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="np")

    >>> outputs = model(**inputs)
    >>> text_embeds = outputs.text_embeds
    ```
"""

overwrite_call_docstring(
    FlaxCLIPTextModelWithProjection, CLIP_TEXT_INPUTS_DOCSTRING + FLAX_CLIP_TEXT_MODEL_WITH_PROJECTION_DOCSTRING
)
append_replace_return_docstrings(
    FlaxCLIPTextModelWithProjection, output_type=FlaxCLIPTextModelOutput, config_class=CLIPTextConfig
)

class FlaxCLIPVisionModule(nn.Module):
    config: CLIPVisionConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.vision_model = FlaxCLIPVisionTransformer(self.config, dtype=self.dtype)

    def __call__(
        self,
        pixel_values,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        return self.vision_model(
            pixel_values=pixel_values,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


class FlaxCLIPVisionModel(FlaxCLIPVisionPreTrainedModel):
    module_class = FlaxCLIPVisionModule


FLAX_CLIP_VISION_MODEL_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, FlaxCLIPVisionModel

    >>> model = FlaxCLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
    >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> inputs = processor(images=image, return_tensors="np")

    >>> outputs = model(**inputs)
    >>> last_hidden_state = outputs.last_hidden_state
    >>> pooler_output = outputs.pooler_output  # pooled CLS states
    ```
"""

overwrite_call_docstring(FlaxCLIPVisionModel, CLIP_VISION_INPUTS_DOCSTRING + FLAX_CLIP_VISION_MODEL_DOCSTRING)
append_replace_return_docstrings(
    FlaxCLIPVisionModel, output_type=FlaxBaseModelOutputWithPooling, config_class=CLIPVisionConfig
)

class FlaxCLIPModule(nn.Module):
    config: CLIPConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        text_config = self.config.text_config
        vision_config = self.config.vision_config

        self.projection_dim = self.config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        self.text_model = FlaxCLIPTextTransformer(text_config, dtype=self.dtype)
        self.vision_model = FlaxCLIPVisionTransformer(vision_config, dtype=self.dtype)

        self.visual_projection = nn.Dense(
            self.projection_dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(0.02),
            use_bias=False,
        )
        self.text_projection = nn.Dense(
            self.projection_dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(0.02),
            use_bias=False,
        )

        self.logit_scale = self.param(
            "logit_scale", lambda _, shape: jnp.ones(shape) * self.config.logit_scale_init_value, []
        )

    def __call__(
        self,
        input_ids=None,
        pixel_values=None,
        attention_mask=None,
        position_ids=None,
        deterministic: bool = True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        image_embeds = vision_outputs[1]
        image_embeds = self.visual_projection(image_embeds)

        text_embeds = text_outputs[1]
        text_embeds = self.text_projection(text_embeds)

        # normalized features
        image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
        text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)

        # cosine similarity as logits
        logit_scale = jnp.exp(self.logit_scale)
        logits_per_text = jnp.matmul(text_embeds, image_embeds.T) * logit_scale
        logits_per_image = logits_per_text.T

        if not return_dict:
            return (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)

        return FlaxCLIPOutput(
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )

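# Editor's note (illustrative sketch, not part of the original source): the logits in
# `FlaxCLIPModule.__call__` are cosine similarities scaled by exp(logit_scale); 2.6592
# is the config default `logit_scale_init_value` (the log of 1/0.07). A toy recomputation:
#
# >>> import jax.numpy as jnp
# >>> text_embeds = jnp.array([[3.0, 4.0]])
# >>> image_embeds = jnp.array([[1.0, 0.0]])
# >>> text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
# >>> image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
# >>> logits_per_text = jnp.matmul(text_embeds, image_embeds.T) * jnp.exp(2.6592)
# >>> logits_per_text.shape  # (text_batch_size, image_batch_size)
# (1, 1)
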
@add_start_docstrings(CLIP_START_DOCSTRING)
class FlaxCLIPModel(FlaxCLIPPreTrainedModel):
    module_class = FlaxCLIPModule


FLAX_CLIP_MODEL_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> import jax
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, FlaxCLIPModel

    >>> model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> inputs = processor(
    ...     text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="np", padding=True
    ... )

    >>> outputs = model(**inputs)
    >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
    >>> probs = jax.nn.softmax(logits_per_image, axis=1)  # we can take the softmax to get the label probabilities
    ```
"""

overwrite_call_docstring(FlaxCLIPModel, CLIP_INPUTS_DOCSTRING + FLAX_CLIP_MODEL_DOCSTRING)
append_replace_return_docstrings(FlaxCLIPModel, output_type=FlaxCLIPOutput, config_class=CLIPConfig)

__all__ = [
    "FlaxCLIPModel",
    "FlaxCLIPPreTrainedModel",
    "FlaxCLIPTextModel",
    "FlaxCLIPTextPreTrainedModel",
    "FlaxCLIPTextModelWithProjection",
    "FlaxCLIPVisionModel",
    "FlaxCLIPVisionPreTrainedModel",
]
docs/transformers/build/lib/transformers/models/clip/modeling_tf_clip.py
ADDED
@@ -0,0 +1,1460 @@
# coding=utf-8
# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF 2.0 CLIP model."""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union

import numpy as np
import tensorflow as tf

from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling

# Public API
from ...modeling_tf_utils import (
    TFModelInputType,
    TFPreTrainedModel,
    get_initializer,
    keras,
    keras_serializable,
    unpack_inputs,
)
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "openai/clip-vit-base-patch32"


LARGE_NEGATIVE = -1e8

# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    src_len = shape_list(mask)[1]
    tgt_len = tgt_len if tgt_len is not None else src_len
    one_cst = tf.constant(1.0)
    mask = tf.cast(mask, dtype=one_cst.dtype)
    expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))

    return (one_cst - expanded_mask) * LARGE_NEGATIVE

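# Editor's note (illustrative sketch, not part of the original source): `_expand_mask`
# turns a 0/1 padding mask into an additive bias that broadcasts over heads and query
# positions, so padded keys receive LARGE_NEGATIVE and vanish after the softmax:
#
# >>> import tensorflow as tf
# >>> mask = tf.constant([[1.0, 1.0, 0.0]])  # last token is padding
# >>> _expand_mask(mask).shape               # [bsz, 1, tgt_seq_len, src_seq_len]
# TensorShape([1, 1, 3, 3])
# >>> # every row of the (3, 3) block is [0.0, 0.0, -1e8]
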
# contrastive loss function, adapted from
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
def contrastive_loss(logits: tf.Tensor) -> tf.Tensor:
    return tf.math.reduce_mean(
        keras.metrics.sparse_categorical_crossentropy(
            y_true=tf.range(shape_list(logits)[0]), y_pred=logits, from_logits=True
        )
    )


def clip_loss(similarity: tf.Tensor) -> tf.Tensor:
    caption_loss = contrastive_loss(similarity)
    image_loss = contrastive_loss(tf.transpose(similarity))
    return (caption_loss + image_loss) / 2.0

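# Editor's note (illustrative sketch, not part of the original source): `clip_loss`
# averages two cross-entropies over the same similarity matrix, once per caption (rows)
# and once per image (columns), with matched pairs on the diagonal as target classes:
#
# >>> import tensorflow as tf
# >>> similarity = tf.constant([[4.0, 0.0], [0.0, 4.0]])  # diagonal = matched pairs
# >>> float(clip_loss(similarity))  # small, since every pair already aligns (~0.018)
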
@dataclass
class TFCLIPOutput(ModelOutput):
    """
    Args:
        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
            Contrastive loss for image-text similarity.
        logits_per_image (`tf.Tensor` of shape `(image_batch_size, text_batch_size)`):
            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
            similarity scores.
        logits_per_text (`tf.Tensor` of shape `(text_batch_size, image_batch_size)`):
            The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
            similarity scores.
        text_embeds (`tf.Tensor` of shape `(batch_size, output_dim)`):
            The text embeddings obtained by applying the projection layer to the pooled output of [`TFCLIPTextModel`].
        image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)`):
            The image embeddings obtained by applying the projection layer to the pooled output of
            [`TFCLIPVisionModel`].
        text_model_output ([`~modeling_tf_utils.TFBaseModelOutputWithPooling`]):
            The output of the [`TFCLIPTextModel`].
        vision_model_output ([`~modeling_tf_utils.TFBaseModelOutputWithPooling`]):
            The output of the [`TFCLIPVisionModel`].
    """

    loss: tf.Tensor | None = None
    logits_per_image: Optional[tf.Tensor] = None
    logits_per_text: Optional[tf.Tensor] = None
    text_embeds: Optional[tf.Tensor] = None
    image_embeds: Optional[tf.Tensor] = None
    text_model_output: TFBaseModelOutputWithPooling = None
    vision_model_output: TFBaseModelOutputWithPooling = None

    def to_tuple(self) -> Tuple[Any]:
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )

class TFCLIPVisionEmbeddings(keras.layers.Layer):
    def __init__(self, config: CLIPVisionConfig, **kwargs):
        super().__init__(**kwargs)

        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1

        self.config = config

        self.patch_embedding = keras.layers.Conv2D(
            filters=self.embed_dim,
            kernel_size=self.patch_size,
            strides=self.patch_size,
            padding="valid",
            data_format="channels_last",
            use_bias=False,
            kernel_initializer=get_initializer(self.config.initializer_range * self.config.initializer_factor),
            name="patch_embedding",
        )

    def build(self, input_shape: tf.TensorShape = None):
        factor = self.config.initializer_factor

        self.class_embedding = self.add_weight(
            shape=(self.embed_dim,),
            initializer=get_initializer(self.embed_dim**-0.5 * factor),
            trainable=True,
            name="class_embedding",
        )

        with tf.name_scope("position_embedding"):
            self.position_embedding = self.add_weight(
                shape=(self.num_positions, self.embed_dim),
                initializer=get_initializer(self.config.initializer_range * factor),
                trainable=True,
                name="embeddings",
            )

        if self.built:
            return
        self.built = True
        if getattr(self, "patch_embedding", None) is not None:
            with tf.name_scope(self.patch_embedding.name):
                self.patch_embedding.build([None, None, None, self.config.num_channels])

    def call(self, pixel_values: tf.Tensor) -> tf.Tensor:
        """`pixel_values` is expected to be of NCHW format."""

        batch_size, num_channels, height, width = shape_list(pixel_values)

        # When running on CPU, `tf.nn.conv2d` doesn't support `NCHW` format.
        # So change the input format from `NCHW` to `NHWC`.
        # shape = (batch_size, in_height, in_width, in_channels=num_channels)
        pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))

        patch_embeds = self.patch_embedding(pixel_values)

        # Change the 2D spatial dimensions to a single temporal dimension.
        # shape = (batch_size, num_patches, out_channels=embed_dim)
        patch_embeds = tf.reshape(tensor=patch_embeds, shape=(batch_size, self.num_patches, -1))

        # add the [CLS] token to the embedded patch tokens
        class_embeds = tf.broadcast_to(self.class_embedding, shape=(batch_size, 1, self.embed_dim))
        embeddings = tf.concat((class_embeds, patch_embeds), axis=1)

        embeddings = embeddings + self.position_embedding

        return embeddings

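# Editor's note (illustrative sketch, not part of the original source): for the
# ViT-B/32 vision config (image_size=224, patch_size=32) the layer above produces
# (224 // 32) ** 2 = 49 patch tokens plus one [CLS] token, i.e. num_positions = 50,
# so `call` returns a tensor of shape (batch_size, 50, hidden_size):
#
# >>> (224 // 32) ** 2 + 1
# 50
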
class TFCLIPTextEmbeddings(keras.layers.Layer):
    def __init__(self, config: CLIPTextConfig, **kwargs):
        super().__init__(**kwargs)

        self.embed_dim = config.hidden_size

        self.config = config

    def build(self, input_shape: tf.TensorShape = None):
        with tf.name_scope("token_embedding"):
            self.weight = self.add_weight(
                shape=(self.config.vocab_size, self.embed_dim),
                initializer=get_initializer(self.config.initializer_factor * self.config.initializer_range),
                trainable=True,
                name="weight",
            )

        with tf.name_scope("position_embedding"):
            self.position_embedding = self.add_weight(
                shape=(self.config.max_position_embeddings, self.embed_dim),
                initializer=get_initializer(self.config.initializer_factor * self.config.initializer_range),
                trainable=True,
                name="embeddings",
            )

        super().build(input_shape)

    def call(
        self,
        input_ids: Optional[tf.Tensor] = None,
        position_ids: Optional[tf.Tensor] = None,
        inputs_embeds: Optional[tf.Tensor] = None,
    ) -> tf.Tensor:
        """
        Applies embedding based on inputs tensor.

        Returns:
            final_embeddings (`tf.Tensor`): output embedding tensor.
        """
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            check_embeddings_within_bounds(input_ids, self.config.vocab_size)
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

        input_shape = shape_list(inputs_embeds)[:-1]

        if position_ids is None:
            position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)

        position_embeds = tf.gather(params=self.position_embedding, indices=position_ids)
        position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
        final_embeddings = inputs_embeds + position_embeds

        return final_embeddings

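# Editor's note (illustrative sketch, not part of the original source): when
# `position_ids` is omitted, the layer above defaults to consecutive positions that
# are then tiled across the batch:
#
# >>> import tensorflow as tf
# >>> tf.expand_dims(tf.range(start=0, limit=4), axis=0).numpy()
# array([[0, 1, 2, 3]], dtype=int32)
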
class TFCLIPAttention(keras.layers.Layer):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: CLIPConfig, **kwargs):
        super().__init__(**kwargs)

        self.embed_dim = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = self.embed_dim // self.num_attention_heads
        if self.attention_head_size * self.num_attention_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_attention_heads})."
            )

        factor = config.initializer_factor
        in_proj_std = (self.embed_dim**-0.5) * ((2 * config.num_hidden_layers) ** -0.5) * factor
        out_proj_std = (self.embed_dim**-0.5) * factor

        self.sqrt_att_head_size = math.sqrt(self.attention_head_size)

        self.q_proj = keras.layers.Dense(
            units=self.embed_dim, kernel_initializer=get_initializer(in_proj_std), name="q_proj"
        )
        self.k_proj = keras.layers.Dense(
            units=self.embed_dim, kernel_initializer=get_initializer(in_proj_std), name="k_proj"
        )
        self.v_proj = keras.layers.Dense(
            units=self.embed_dim, kernel_initializer=get_initializer(in_proj_std), name="v_proj"
        )

        self.dropout = keras.layers.Dropout(rate=config.attention_dropout)

        self.out_proj = keras.layers.Dense(
            units=self.embed_dim, kernel_initializer=get_initializer(out_proj_std), name="out_proj"
        )

    # copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention.transpose_for_scores
    def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
        # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
        tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))

        # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
        return tf.transpose(tensor, perm=[0, 2, 1, 3])

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        causal_attention_mask: tf.Tensor,
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        """Input shape: Batch x Time x Channel"""

        batch_size = shape_list(hidden_states)[0]
        mixed_query_layer = self.q_proj(inputs=hidden_states)
        mixed_key_layer = self.k_proj(inputs=hidden_states)
        mixed_value_layer = self.v_proj(inputs=hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        # (batch size, num_heads, seq_len_q, seq_len_k)
        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
        dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
        attention_scores = tf.divide(attention_scores, dk)

        # apply the causal_attention_mask first
        if causal_attention_mask is not None:
            # Apply the causal attention mask (precomputed for all layers in TFCLIPModel call() function)
            attention_scores = tf.add(attention_scores, causal_attention_mask)

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in TFCLIPModel call() function)
            attention_scores = tf.add(attention_scores, attention_mask)

        # Normalize the attention scores to probabilities.
        _attention_probs = stable_softmax(logits=attention_scores, axis=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(inputs=_attention_probs, training=training)

        attention_output = tf.matmul(attention_probs, value_layer)
        attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])

        # (batch_size, seq_len_q, embed_dim)
        attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.embed_dim))

        attention_output = self.out_proj(attention_output, training=training)
        # In TFBert, attention weights are returned after dropout.
        # However, in CLIP, they are returned before dropout.
        outputs = (attention_output, _attention_probs) if output_attentions else (attention_output,)

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "q_proj", None) is not None:
            with tf.name_scope(self.q_proj.name):
                self.q_proj.build([None, None, self.embed_dim])
        if getattr(self, "k_proj", None) is not None:
            with tf.name_scope(self.k_proj.name):
                self.k_proj.build([None, None, self.embed_dim])
        if getattr(self, "v_proj", None) is not None:
            with tf.name_scope(self.v_proj.name):
                self.v_proj.build([None, None, self.embed_dim])
        if getattr(self, "out_proj", None) is not None:
            with tf.name_scope(self.out_proj.name):
                self.out_proj.build([None, None, self.embed_dim])

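# Editor's note (illustrative sketch, not part of the original source): the shape flow
# through `transpose_for_scores` for the base text config (embed_dim=512, 8 heads of
# size 64):
#
# >>> import tensorflow as tf
# >>> x = tf.zeros((2, 7, 512))                 # (batch, seq, embed_dim)
# >>> x = tf.reshape(x, (2, -1, 8, 64))         # (batch, seq, heads, head_size)
# >>> tf.transpose(x, perm=[0, 2, 1, 3]).shape  # (batch, heads, seq, head_size)
# TensorShape([2, 8, 7, 64])
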
class TFCLIPMLP(keras.layers.Layer):
    def __init__(self, config: CLIPConfig, **kwargs):
        super().__init__(**kwargs)

        self.activation_fn = get_tf_activation(config.hidden_act)

        factor = config.initializer_factor
        in_proj_std = (config.hidden_size**-0.5) * ((2 * config.num_hidden_layers) ** -0.5) * factor
        fc_std = (2 * config.hidden_size) ** -0.5 * factor

        self.fc1 = keras.layers.Dense(
            units=config.intermediate_size, kernel_initializer=get_initializer(fc_std), name="fc1"
        )
        self.fc2 = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(in_proj_std), name="fc2"
        )
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        hidden_states = self.fc1(inputs=hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(inputs=hidden_states)
        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "fc1", None) is not None:
            with tf.name_scope(self.fc1.name):
                self.fc1.build([None, None, self.config.hidden_size])
        if getattr(self, "fc2", None) is not None:
            with tf.name_scope(self.fc2.name):
                self.fc2.build([None, None, self.config.intermediate_size])

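# Editor's note (illustrative sketch, not part of the original source): the MLP
# initializer scales above follow the original CLIP scheme, shrinking with width and
# depth. For hidden_size=512 and num_hidden_layers=12 (the base text config):
#
# >>> (2 * 512) ** -0.5                    # fc_std for fc1
# 0.03125
# >>> (512 ** -0.5) * ((2 * 12) ** -0.5)   # in_proj_std for fc2
# 0.00902109...
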
class TFCLIPEncoderLayer(keras.layers.Layer):
    def __init__(self, config: CLIPConfig, **kwargs):
        super().__init__(**kwargs)

        self.embed_dim = config.hidden_size
        self.self_attn = TFCLIPAttention(config, name="self_attn")
        self.layer_norm1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1")
        self.mlp = TFCLIPMLP(config, name="mlp")
        self.layer_norm2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2")

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        causal_attention_mask: tf.Tensor,
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        """
        Args:
            hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`tf.Tensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            causal_attention_mask (`tf.Tensor`): causal attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`):
                Whether or not to return the attentions tensors of all attention layers. See `outputs` under returned
                tensors for more detail.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(inputs=hidden_states)
        attention_outputs = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            training=training,
        )
        hidden_states = attention_outputs[0]
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(inputs=hidden_states)
        hidden_states = self.mlp(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,) + attention_outputs[1:]  # add attentions if we output them

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "self_attn", None) is not None:
            with tf.name_scope(self.self_attn.name):
                self.self_attn.build(None)
        if getattr(self, "layer_norm1", None) is not None:
            with tf.name_scope(self.layer_norm1.name):
                self.layer_norm1.build([None, None, self.embed_dim])
        if getattr(self, "mlp", None) is not None:
            with tf.name_scope(self.mlp.name):
                self.mlp.build(None)
        if getattr(self, "layer_norm2", None) is not None:
            with tf.name_scope(self.layer_norm2.name):
                self.layer_norm2.build([None, None, self.embed_dim])

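# Editor's note (illustrative, not part of the original source): the layer above is a
# pre-LayerNorm transformer block; each sublayer normalizes its input first and adds
# the result back onto the untouched residual stream. Schematically:
#
#     x = x + Attention(LayerNorm1(x))
#     x = x + MLP(LayerNorm2(x))
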
class TFCLIPEncoder(keras.layers.Layer):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`TFCLIPEncoderLayer`].

    Args:
        config: CLIPConfig
    """

    def __init__(self, config: CLIPConfig, **kwargs):
        super().__init__(**kwargs)

        self.layers = [TFCLIPEncoderLayer(config, name=f"layers_._{i}") for i in range(config.num_hidden_layers)]

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        causal_attention_mask: tf.Tensor,
        output_attentions: bool,
        output_hidden_states: bool,
        return_dict: bool,
        training: bool = False,
    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                causal_attention_mask=causal_attention_mask,
                output_attentions=output_attentions,
                training=training,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)

        return TFBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "layers", None) is not None:
            for layer in self.layers:
                with tf.name_scope(layer.name):
                    layer.build(None)

class TFCLIPTextTransformer(keras.layers.Layer):
    def __init__(self, config: CLIPTextConfig, **kwargs):
        super().__init__(**kwargs)

        self.embeddings = TFCLIPTextEmbeddings(config, name="embeddings")
        self.encoder = TFCLIPEncoder(config, name="encoder")
        self.final_layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="final_layer_norm")

        # For `pooled_output` computation
        self.eos_token_id = config.eos_token_id
        self.embed_dim = config.hidden_size

    def call(
        self,
        input_ids: TFModelInputType,
        attention_mask: tf.Tensor,
        position_ids: tf.Tensor,
        output_attentions: bool,
        output_hidden_states: bool,
        return_dict: bool,
        training: bool = False,
    ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
        input_shape = shape_list(input_ids)

        embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        batch_size, seq_length = input_shape
        # CLIP's text model uses a causal mask; prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_attention_mask = self._build_causal_attention_mask(batch_size, seq_length, dtype=embedding_output.dtype)

        # check attention mask and invert
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        attention_mask = _expand_mask(attention_mask)

        encoder_outputs = self.encoder(
            hidden_states=embedding_output,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = encoder_outputs[0]
        sequence_output = self.final_layer_norm(inputs=sequence_output)

        if self.eos_token_id == 2:
            # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
            # A CLIP model with such an `eos_token_id` in its config can't work correctly with extra new tokens added.
            # ------------------------------------------------------------
            # text_embeds.shape = [batch_size, n_ctx, transformer.width]
            # take features from the eot embedding (eot_token is the highest number in each sequence)
            pooled_output = tf.gather_nd(
                params=sequence_output,
                indices=tf.stack(
                    values=(tf.range(input_shape[0], dtype=tf.int64), tf.math.argmax(input_ids, axis=-1)), axis=1
                ),
            )
        else:
            # The config gets the updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
            pooled_output = tf.gather_nd(
                params=sequence_output,
                indices=tf.stack(
                    values=(
                        tf.range(input_shape[0], dtype=tf.int64),
                        tf.math.argmax(tf.cast(input_ids == self.eos_token_id, dtype=tf.int8), axis=-1),
                    ),
                    axis=1,
                ),
            )

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return TFBaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def _build_causal_attention_mask(self, batch_size, seq_length, dtype=tf.float32):
        # With an unspecified sequence length, `seq_length` can be a runtime value,
        # which is unsupported by `tf.constant`. Per the TensorFlow docs, `tf.fill`
        # can handle runtime dynamic shapes:
        # https://www.tensorflow.org/api_docs/python/tf/fill
        diag = tf.cast(tf.fill((seq_length,), 0.0), dtype)

        # set an additive 2D attention mask with all places being masked
        to_mask = tf.cast(tf.fill((seq_length, seq_length), -10000.0), dtype)

        # set diagonal & lower triangular parts to 0 (i.e. the places not to be masked)
        # TIP: think of the 2D matrix as the space of (query_seq, key_seq)
        to_mask = tf.linalg.band_part(to_mask, 0, -1)
        # to_mask = tf.linalg.band_part(to_mask, -1, 0)
        to_mask = tf.linalg.set_diag(to_mask, diagonal=diag)

        return tf.broadcast_to(input=to_mask, shape=(batch_size, 1, seq_length, seq_length))

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embeddings", None) is not None:
            with tf.name_scope(self.embeddings.name):
                self.embeddings.build(None)
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        if getattr(self, "final_layer_norm", None) is not None:
            with tf.name_scope(self.final_layer_norm.name):
                self.final_layer_norm.build([None, None, self.embed_dim])

@keras_serializable
|
| 659 |
+
class TFCLIPTextMainLayer(keras.layers.Layer):
|
| 660 |
+
config_class = CLIPTextConfig
|
| 661 |
+
|
| 662 |
+
def __init__(self, config: CLIPTextConfig, **kwargs):
|
| 663 |
+
super().__init__(**kwargs)
|
| 664 |
+
self.config = config
|
| 665 |
+
self.text_model = TFCLIPTextTransformer(config, name="text_model")
|
| 666 |
+
|
| 667 |
+
def get_input_embeddings(self) -> keras.layers.Layer:
|
| 668 |
+
return self.text_model.embeddings
|
| 669 |
+
|
| 670 |
+
def set_input_embeddings(self, value: tf.Variable):
|
| 671 |
+
self.text_model.embeddings.weight = value
|
| 672 |
+
self.text_model.embeddings.vocab_size = shape_list(value)[0]
|
| 673 |
+
|
| 674 |
+
@unpack_inputs
|
| 675 |
+
def call(
|
| 676 |
+
self,
|
| 677 |
+
input_ids: TFModelInputType | None = None,
|
| 678 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 679 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 680 |
+
output_attentions: Optional[bool] = None,
|
| 681 |
+
output_hidden_states: Optional[bool] = None,
|
| 682 |
+
return_dict: Optional[bool] = None,
|
| 683 |
+
training: bool = False,
|
| 684 |
+
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
|
| 685 |
+
if input_ids is None:
|
| 686 |
+
raise ValueError("You have to specify input_ids")
|
| 687 |
+
|
| 688 |
+
input_shape = shape_list(input_ids)
|
| 689 |
+
|
| 690 |
+
if attention_mask is None:
|
| 691 |
+
attention_mask = tf.fill(dims=input_shape, value=1)
|
| 692 |
+
|
| 693 |
+
text_model_outputs = self.text_model(
|
| 694 |
+
input_ids=input_ids,
|
| 695 |
+
attention_mask=attention_mask,
|
| 696 |
+
position_ids=position_ids,
|
| 697 |
+
output_attentions=output_attentions,
|
| 698 |
+
output_hidden_states=output_hidden_states,
|
| 699 |
+
return_dict=return_dict,
|
| 700 |
+
training=training,
|
| 701 |
+
)
|
| 702 |
+
|
| 703 |
+
return text_model_outputs
|
| 704 |
+
|
| 705 |
+
def build(self, input_shape=None):
|
| 706 |
+
if self.built:
|
| 707 |
+
return
|
| 708 |
+
self.built = True
|
| 709 |
+
if getattr(self, "text_model", None) is not None:
|
| 710 |
+
with tf.name_scope(self.text_model.name):
|
| 711 |
+
self.text_model.build(None)
|
| 712 |
+
|
| 713 |
+
|
| 714 |
+
class TFCLIPVisionTransformer(keras.layers.Layer):
|
| 715 |
+
def __init__(self, config: CLIPVisionConfig, **kwargs):
|
| 716 |
+
super().__init__(**kwargs)
|
| 717 |
+
|
| 718 |
+
self.embeddings = TFCLIPVisionEmbeddings(config, name="embeddings")
|
| 719 |
+
self.pre_layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="pre_layrnorm")
|
| 720 |
+
self.encoder = TFCLIPEncoder(config, name="encoder")
|
| 721 |
+
self.post_layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="post_layernorm")
|
| 722 |
+
self.embed_dim = config.hidden_size
|
| 723 |
+
|
| 724 |
+
def call(
|
| 725 |
+
self,
|
| 726 |
+
pixel_values: TFModelInputType,
|
| 727 |
+
output_attentions: bool,
|
| 728 |
+
output_hidden_states: bool,
|
| 729 |
+
return_dict: bool,
|
| 730 |
+
training: bool = False,
|
| 731 |
+
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
|
| 732 |
+
embedding_output = self.embeddings(pixel_values=pixel_values)
|
| 733 |
+
embedding_output = self.pre_layernorm(inputs=embedding_output)
|
| 734 |
+
|
| 735 |
+
encoder_outputs = self.encoder(
|
| 736 |
+
hidden_states=embedding_output,
|
| 737 |
+
attention_mask=None,
|
| 738 |
+
causal_attention_mask=None,
|
| 739 |
+
output_attentions=output_attentions,
|
| 740 |
+
output_hidden_states=output_hidden_states,
|
| 741 |
+
return_dict=return_dict,
|
| 742 |
+
training=training,
|
| 743 |
+
)
|
| 744 |
+
|
| 745 |
+
sequence_output = encoder_outputs[0]
|
| 746 |
+
pooled_output = sequence_output[:, 0, :]
|
| 747 |
+
pooled_output = self.post_layernorm(inputs=pooled_output)
|
| 748 |
+
|
| 749 |
+
if not return_dict:
|
| 750 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
| 751 |
+
|
| 752 |
+
return TFBaseModelOutputWithPooling(
|
| 753 |
+
last_hidden_state=sequence_output,
|
| 754 |
+
pooler_output=pooled_output,
|
| 755 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 756 |
+
attentions=encoder_outputs.attentions,
|
| 757 |
+
)
|
| 758 |
+
|
| 759 |
+
def build(self, input_shape=None):
|
| 760 |
+
if self.built:
|
| 761 |
+
return
|
| 762 |
+
self.built = True
|
| 763 |
+
if getattr(self, "embeddings", None) is not None:
|
| 764 |
+
with tf.name_scope(self.embeddings.name):
|
| 765 |
+
self.embeddings.build(None)
|
| 766 |
+
if getattr(self, "pre_layernorm", None) is not None:
|
| 767 |
+
with tf.name_scope(self.pre_layernorm.name):
|
| 768 |
+
self.pre_layernorm.build([None, None, self.embed_dim])
|
| 769 |
+
if getattr(self, "encoder", None) is not None:
|
| 770 |
+
with tf.name_scope(self.encoder.name):
|
| 771 |
+
self.encoder.build(None)
|
| 772 |
+
if getattr(self, "post_layernorm", None) is not None:
|
| 773 |
+
with tf.name_scope(self.post_layernorm.name):
|
| 774 |
+
self.post_layernorm.build([None, self.embed_dim])
|
| 775 |
+
|
| 776 |
+
|
| 777 |
+
@keras_serializable
|
| 778 |
+
class TFCLIPVisionMainLayer(keras.layers.Layer):
|
| 779 |
+
config_class = CLIPVisionConfig
|
| 780 |
+
|
| 781 |
+
def __init__(self, config: CLIPVisionConfig, **kwargs):
|
| 782 |
+
super().__init__(**kwargs)
|
| 783 |
+
self.config = config
|
| 784 |
+
self.vision_model = TFCLIPVisionTransformer(config, name="vision_model")
|
| 785 |
+
|
| 786 |
+
def get_input_embeddings(self) -> keras.layers.Layer:
|
| 787 |
+
return self.vision_model.embeddings
|
| 788 |
+
|
| 789 |
+
@unpack_inputs
|
| 790 |
+
def call(
|
| 791 |
+
self,
|
| 792 |
+
pixel_values: TFModelInputType | None = None,
|
| 793 |
+
output_attentions: Optional[bool] = None,
|
| 794 |
+
output_hidden_states: Optional[bool] = None,
|
| 795 |
+
return_dict: Optional[bool] = None,
|
| 796 |
+
training: bool = False,
|
| 797 |
+
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
|
| 798 |
+
if pixel_values is None:
|
| 799 |
+
raise ValueError("You have to specify pixel_values")
|
| 800 |
+
|
| 801 |
+
vision_model_outputs = self.vision_model(
|
| 802 |
+
pixel_values=pixel_values,
|
| 803 |
+
output_attentions=output_attentions,
|
| 804 |
+
output_hidden_states=output_hidden_states,
|
| 805 |
+
return_dict=return_dict,
|
| 806 |
+
training=training,
|
| 807 |
+
)
|
| 808 |
+
|
| 809 |
+
return vision_model_outputs
|
| 810 |
+
|
| 811 |
+
def build(self, input_shape=None):
|
| 812 |
+
if self.built:
|
| 813 |
+
return
|
| 814 |
+
self.built = True
|
| 815 |
+
if getattr(self, "vision_model", None) is not None:
|
| 816 |
+
with tf.name_scope(self.vision_model.name):
|
| 817 |
+
self.vision_model.build(None)
|
| 818 |
+
|
| 819 |
+
|
| 820 |
+
@keras_serializable
|
| 821 |
+
class TFCLIPMainLayer(keras.layers.Layer):
|
| 822 |
+
config_class = CLIPConfig
|
| 823 |
+
|
| 824 |
+
def __init__(self, config: CLIPConfig, **kwargs):
|
| 825 |
+
super().__init__(**kwargs)
|
| 826 |
+
|
| 827 |
+
if not isinstance(config.text_config, CLIPTextConfig):
|
| 828 |
+
raise TypeError(
|
| 829 |
+
"config.text_config is expected to be of type CLIPTextConfig but is of type"
|
| 830 |
+
f" {type(config.text_config)}."
|
| 831 |
+
)
|
| 832 |
+
|
| 833 |
+
if not isinstance(config.vision_config, CLIPVisionConfig):
|
| 834 |
+
raise TypeError(
|
| 835 |
+
"config.vision_config is expected to be of type CLIPVisionConfig but is of type"
|
| 836 |
+
f" {type(config.vision_config)}."
|
| 837 |
+
)
|
| 838 |
+
|
| 839 |
+
self.config = config
|
| 840 |
+
|
| 841 |
+
text_config = config.text_config
|
| 842 |
+
vision_config = config.vision_config
|
| 843 |
+
|
| 844 |
+
self.projection_dim = config.projection_dim
|
| 845 |
+
|
| 846 |
+
self.text_model = TFCLIPTextTransformer(text_config, name="text_model")
|
| 847 |
+
self.vision_model = TFCLIPVisionTransformer(vision_config, name="vision_model")
|
| 848 |
+
|
| 849 |
+
self.visual_projection = keras.layers.Dense(
|
| 850 |
+
units=self.projection_dim,
|
| 851 |
+
kernel_initializer=get_initializer(vision_config.hidden_size**-0.5 * self.config.initializer_factor),
|
| 852 |
+
use_bias=False,
|
| 853 |
+
name="visual_projection",
|
| 854 |
+
)
|
| 855 |
+
|
| 856 |
+
self.text_projection = keras.layers.Dense(
|
| 857 |
+
units=self.projection_dim,
|
| 858 |
+
kernel_initializer=get_initializer(text_config.hidden_size**-0.5 * self.config.initializer_factor),
|
| 859 |
+
use_bias=False,
|
| 860 |
+
name="text_projection",
|
| 861 |
+
)
|
| 862 |
+
self.text_embed_dim = text_config.hidden_size
|
| 863 |
+
self.vision_embed_dim = vision_config.hidden_size
|
| 864 |
+
|
| 865 |
+
def build(self, input_shape: tf.TensorShape = None):
|
| 866 |
+
self.logit_scale = self.add_weight(
|
| 867 |
+
shape=(1,),
|
| 868 |
+
initializer=keras.initializers.Constant(self.config.logit_scale_init_value),
|
| 869 |
+
trainable=True,
|
| 870 |
+
name="logit_scale",
|
| 871 |
+
)
|
| 872 |
+
|
| 873 |
+
if self.built:
|
| 874 |
+
return
|
| 875 |
+
self.built = True
|
| 876 |
+
if getattr(self, "text_model", None) is not None:
|
| 877 |
+
with tf.name_scope(self.text_model.name):
|
| 878 |
+
self.text_model.build(None)
|
| 879 |
+
if getattr(self, "vision_model", None) is not None:
|
| 880 |
+
with tf.name_scope(self.vision_model.name):
|
| 881 |
+
self.vision_model.build(None)
|
| 882 |
+
if getattr(self, "visual_projection", None) is not None:
|
| 883 |
+
with tf.name_scope(self.visual_projection.name):
|
| 884 |
+
self.visual_projection.build([None, None, self.vision_embed_dim])
|
| 885 |
+
if getattr(self, "text_projection", None) is not None:
|
| 886 |
+
with tf.name_scope(self.text_projection.name):
|
| 887 |
+
self.text_projection.build([None, None, self.text_embed_dim])
|
| 888 |
+
|
| 889 |
+
@unpack_inputs
|
| 890 |
+
def get_text_features(
|
| 891 |
+
self,
|
| 892 |
+
input_ids: TFModelInputType | None = None,
|
| 893 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 894 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 895 |
+
output_attentions: Optional[bool] = None,
|
| 896 |
+
output_hidden_states: Optional[bool] = None,
|
| 897 |
+
return_dict: Optional[bool] = None,
|
| 898 |
+
training: bool = False,
|
| 899 |
+
) -> tf.Tensor:
|
| 900 |
+
if input_ids is None:
|
| 901 |
+
raise ValueError("You have to specify either input_ids")
|
| 902 |
+
|
| 903 |
+
input_shape = shape_list(input_ids)
|
| 904 |
+
|
| 905 |
+
if attention_mask is None:
|
| 906 |
+
attention_mask = tf.fill(dims=input_shape, value=1)
|
| 907 |
+
|
| 908 |
+
text_outputs = self.text_model(
|
| 909 |
+
input_ids=input_ids,
|
| 910 |
+
attention_mask=attention_mask,
|
| 911 |
+
position_ids=position_ids,
|
| 912 |
+
output_attentions=output_attentions,
|
| 913 |
+
output_hidden_states=output_hidden_states,
|
| 914 |
+
return_dict=return_dict,
|
| 915 |
+
training=training,
|
| 916 |
+
)
|
| 917 |
+
|
| 918 |
+
pooled_output = text_outputs[1]
|
| 919 |
+
text_features = self.text_projection(inputs=pooled_output)
|
| 920 |
+
|
| 921 |
+
return text_features
|
| 922 |
+
|
| 923 |
+
@unpack_inputs
|
| 924 |
+
def get_image_features(
|
| 925 |
+
self,
|
| 926 |
+
pixel_values: TFModelInputType | None = None,
|
| 927 |
+
output_attentions: Optional[bool] = None,
|
| 928 |
+
output_hidden_states: Optional[bool] = None,
|
| 929 |
+
return_dict: Optional[bool] = None,
|
| 930 |
+
training: bool = False,
|
| 931 |
+
) -> tf.Tensor:
|
| 932 |
+
if pixel_values is None:
|
| 933 |
+
raise ValueError("You have to specify pixel_values")
|
| 934 |
+
|
| 935 |
+
vision_outputs = self.vision_model(
|
| 936 |
+
pixel_values=pixel_values,
|
| 937 |
+
output_attentions=output_attentions,
|
| 938 |
+
output_hidden_states=output_hidden_states,
|
| 939 |
+
return_dict=return_dict,
|
| 940 |
+
training=training,
|
| 941 |
+
)
|
| 942 |
+
|
| 943 |
+
pooled_output = vision_outputs[1] # pooled_output
|
| 944 |
+
image_features = self.visual_projection(inputs=pooled_output)
|
| 945 |
+
|
| 946 |
+
return image_features
|
| 947 |
+
|
| 948 |
+
@unpack_inputs
|
| 949 |
+
def call(
|
| 950 |
+
self,
|
| 951 |
+
input_ids: TFModelInputType | None = None,
|
| 952 |
+
pixel_values: TFModelInputType | None = None,
|
| 953 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 954 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 955 |
+
return_loss: Optional[bool] = None,
|
| 956 |
+
output_attentions: Optional[bool] = None,
|
| 957 |
+
output_hidden_states: Optional[bool] = None,
|
| 958 |
+
return_dict: Optional[bool] = None,
|
| 959 |
+
training: bool = False,
|
| 960 |
+
) -> Union[TFCLIPOutput, Tuple[tf.Tensor]]:
|
| 961 |
+
if input_ids is None:
|
| 962 |
+
raise ValueError("You have to specify either input_ids")
|
| 963 |
+
if pixel_values is None:
|
| 964 |
+
raise ValueError("You have to specify pixel_values")
|
| 965 |
+
|
| 966 |
+
input_shape = shape_list(input_ids)
|
| 967 |
+
|
| 968 |
+
if attention_mask is None:
|
| 969 |
+
attention_mask = tf.fill(dims=input_shape, value=1)
|
| 970 |
+
|
| 971 |
+
vision_outputs = self.vision_model(
|
| 972 |
+
pixel_values=pixel_values,
|
| 973 |
+
output_attentions=output_attentions,
|
| 974 |
+
output_hidden_states=output_hidden_states,
|
| 975 |
+
return_dict=return_dict,
|
| 976 |
+
training=training,
|
| 977 |
+
)
|
| 978 |
+
|
| 979 |
+
text_outputs = self.text_model(
|
| 980 |
+
input_ids=input_ids,
|
| 981 |
+
attention_mask=attention_mask,
|
| 982 |
+
position_ids=position_ids,
|
| 983 |
+
output_attentions=output_attentions,
|
| 984 |
+
output_hidden_states=output_hidden_states,
|
| 985 |
+
return_dict=return_dict,
|
| 986 |
+
training=training,
|
| 987 |
+
)
|
| 988 |
+
|
| 989 |
+
image_embeds = vision_outputs[1]
|
| 990 |
+
image_embeds = self.visual_projection(inputs=image_embeds)
|
| 991 |
+
|
| 992 |
+
text_embeds = text_outputs[1]
|
| 993 |
+
text_embeds = self.text_projection(inputs=text_embeds)
|
| 994 |
+
|
| 995 |
+
# normalized features
|
| 996 |
+
image_embeds = image_embeds / tf.norm(tensor=image_embeds, ord="euclidean", axis=-1, keepdims=True)
|
| 997 |
+
text_embeds = text_embeds / tf.norm(tensor=text_embeds, ord="euclidean", axis=-1, keepdims=True)
|
| 998 |
+
|
| 999 |
+
# cosine similarity as logits
|
| 1000 |
+
logit_scale = tf.math.exp(self.logit_scale)
|
| 1001 |
+
logits_per_text = tf.matmul(text_embeds, image_embeds, transpose_b=True) * logit_scale
|
| 1002 |
+
logits_per_image = tf.transpose(logits_per_text)
|
| 1003 |
+
|
| 1004 |
+
loss = None
|
| 1005 |
+
if return_loss:
|
| 1006 |
+
loss = clip_loss(logits_per_text)
|
| 1007 |
+
loss = tf.reshape(loss, (1,))
|
| 1008 |
+
|
| 1009 |
+
if not return_dict:
|
| 1010 |
+
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
|
| 1011 |
+
return (loss,) + output if loss is not None else output
|
| 1012 |
+
|
| 1013 |
+
return TFCLIPOutput(
|
| 1014 |
+
loss=loss,
|
| 1015 |
+
logits_per_image=logits_per_image,
|
| 1016 |
+
logits_per_text=logits_per_text,
|
| 1017 |
+
text_embeds=text_embeds,
|
| 1018 |
+
image_embeds=image_embeds,
|
| 1019 |
+
text_model_output=text_outputs,
|
| 1020 |
+
vision_model_output=vision_outputs,
|
| 1021 |
+
)
|
| 1022 |
+
|
| 1023 |
+
|
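`clip_loss` above is a helper defined near the top of this file, outside this excerpt. As a rough standalone sketch of the symmetric contrastive objective it computes (the `_sketch` names are illustrative, not the module's own helpers):

```python
import tensorflow as tf


def contrastive_loss_sketch(logits: tf.Tensor) -> tf.Tensor:
    # Cross-entropy against the diagonal: row i of `logits` should score
    # pair (i, i) highest, so the "correct class" for row i is i itself.
    labels = tf.range(tf.shape(logits)[0])
    per_row = tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
    return tf.math.reduce_mean(per_row)


def clip_loss_sketch(similarity: tf.Tensor) -> tf.Tensor:
    # Symmetric objective: average the text->image and image->text directions.
    caption_loss = contrastive_loss_sketch(similarity)
    image_loss = contrastive_loss_sketch(tf.transpose(similarity))
    return (caption_loss + image_loss) / 2.0
```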
class TFCLIPPreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = CLIPConfig
    base_model_prefix = "clip"
    _keys_to_ignore_on_load_missing = [r"position_ids"]
    _keys_to_ignore_on_load_unexpected = [r"position_ids"]


CLIP_START_DOCSTRING = r"""

    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
    behavior.

    <Tip>

    TensorFlow models and layers in `transformers` accept two formats as input:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
    positional argument:

    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
    `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
    `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`

    Note that when creating models and layers with
    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
    about any of this, as you can just pass inputs like you would to any other Python function!

    </Tip>

    Args:
        config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""

CLIP_TEXT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
            [`PreTrainedTokenizer.encode`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
            config will be used instead.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
            used instead.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode, in graph mode the value will always be set to True.
        training (`bool`, *optional*, defaults to `False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""

CLIP_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`CLIPImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
            config will be used instead.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
            used instead.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode, in graph mode the value will always be set to True.
        training (`bool`, *optional*, defaults to `False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""

CLIP_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
            [`PreTrainedTokenizer.encode`] for details.

            [What are input IDs?](../glossary#input-ids)
        pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`CLIPImageProcessor.__call__`] for details.
        attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
            config will be used instead.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
            used instead.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode, in graph mode the value will always be set to True.
        training (`bool`, *optional*, defaults to `False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""


class TFCLIPTextModel(TFCLIPPreTrainedModel):
    config_class = CLIPTextConfig

    def __init__(self, config: CLIPTextConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.clip = TFCLIPTextMainLayer(config, name="clip")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=CLIPTextConfig)
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, TFCLIPTextModel

        >>> model = TFCLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="tf")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```"""

        outputs = self.clip(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "clip", None) is not None:
            with tf.name_scope(self.clip.name):
                self.clip.build(None)


class TFCLIPVisionModel(TFCLIPPreTrainedModel):
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: CLIPVisionConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.clip = TFCLIPVisionMainLayer(config, name="clip")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=CLIPVisionConfig)
    def call(
        self,
        pixel_values: TFModelInputType | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, TFCLIPVisionModel

        >>> model = TFCLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="tf")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled CLS states
        ```"""

        outputs = self.clip(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "clip", None) is not None:
            with tf.name_scope(self.clip.name):
                self.clip.build(None)


@add_start_docstrings(CLIP_START_DOCSTRING)
class TFCLIPModel(TFCLIPPreTrainedModel):
    config_class = CLIPConfig

    def __init__(self, config: CLIPConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.clip = TFCLIPMainLayer(config, name="clip")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def get_text_features(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> tf.Tensor:
        r"""
        Returns:
            text_features (`tf.Tensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by applying
            the projection layer to the pooled output of [`TFCLIPTextModel`].

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, TFCLIPModel

        >>> model = TFCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="tf")
        >>> text_features = model.get_text_features(**inputs)
        ```"""

        text_features = self.clip.get_text_features(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return text_features

    @unpack_inputs
    @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
    def get_image_features(
        self,
        pixel_values: TFModelInputType | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> tf.Tensor:
        r"""
        Returns:
            image_features (`tf.Tensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by applying
            the projection layer to the pooled output of [`TFCLIPVisionModel`].

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, TFCLIPModel

        >>> model = TFCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="tf")

        >>> image_features = model.get_image_features(**inputs)
        ```"""

        image_features = self.clip.get_image_features(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return image_features

    @unpack_inputs
    @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=TFCLIPOutput, config_class=CLIPConfig)
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        pixel_values: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFCLIPOutput, Tuple[tf.Tensor]]:
        r"""
        Returns:

        Examples:

        ```python
        >>> import tensorflow as tf
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, TFCLIPModel

        >>> model = TFCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(
        ...     text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="tf", padding=True
        ... )

        >>> outputs = model(**inputs)
        >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
        >>> probs = tf.nn.softmax(logits_per_image, axis=1)  # we can take the softmax to get the label probabilities
        ```"""

        outputs = self.clip(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_loss=return_loss,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return outputs

    def serving_output(self, output: TFCLIPOutput) -> TFCLIPOutput:
        # TODO: As is this currently fails with saved_model=True, because
        # TensorFlow cannot trace through nested dataclasses. Reference:
        # https://github.com/huggingface/transformers/pull/16886
        return output

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "clip", None) is not None:
            with tf.name_scope(self.clip.name):
                self.clip.build(None)


__all__ = ["TFCLIPModel", "TFCLIPPreTrainedModel", "TFCLIPTextModel", "TFCLIPVisionModel"]
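To make the text-tower pooling above concrete: for each sequence, the hidden state at the end-of-text position is gathered with `tf.gather_nd`, where the position is located via `argmax`. A minimal sketch on dummy tensors (all names and values here are illustrative, not part of the file above):

```python
import tensorflow as tf

batch_size, seq_len, width = 2, 5, 4
eos_token_id = 7  # hypothetical EOS id for this sketch

# dummy hidden states and token ids; each row contains the EOS token somewhere
sequence_output = tf.random.normal((batch_size, seq_len, width))
input_ids = tf.constant([[3, 9, 7, 0, 0], [5, 2, 8, 1, 7]], dtype=tf.int64)

# argmax over a 0/1 mask finds the first EOS position per row;
# gather_nd then picks that time step's hidden state for each row
eos_positions = tf.math.argmax(tf.cast(input_ids == eos_token_id, tf.int8), axis=-1)
pooled = tf.gather_nd(
    sequence_output,
    tf.stack([tf.range(batch_size, dtype=tf.int64), eos_positions], axis=1),
)
print(pooled.shape)  # (2, 4)
```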
docs/transformers/build/lib/transformers/models/clip/processing_clip.py
ADDED
@@ -0,0 +1,156 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Image/Text processor class for CLIP
"""

import warnings

from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding


class CLIPProcessor(ProcessorMixin):
    r"""
    Constructs a CLIP processor which wraps a CLIP image processor and a CLIP tokenizer into a single processor.

    [`CLIPProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`CLIPTokenizerFast`]. See the
    [`~CLIPProcessor.__call__`] and [`~CLIPProcessor.decode`] for more information.

    Args:
        image_processor ([`CLIPImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`CLIPTokenizerFast`], *optional*):
            The tokenizer is a required input.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = ("CLIPImageProcessor", "CLIPImageProcessorFast")
    tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast")

    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
        feature_extractor = None
        if "feature_extractor" in kwargs:
            warnings.warn(
                "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
                " instead.",
                FutureWarning,
            )
            feature_extractor = kwargs.pop("feature_extractor")

        image_processor = image_processor if image_processor is not None else feature_extractor
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")

        super().__init__(image_processor, tokenizer)

    def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
        """
        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
        and `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
        CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
        of the above two methods for more information.

        Args:
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.

            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
        """
        tokenizer_kwargs, image_processor_kwargs = {}, {}
        if kwargs:
            tokenizer_kwargs = {k: v for k, v in kwargs.items() if k not in self.image_processor._valid_processor_keys}
            image_processor_kwargs = {
                k: v for k, v in kwargs.items() if k in self.image_processor._valid_processor_keys
            }

        if text is None and images is None:
            raise ValueError("You have to specify either text or images. Both cannot be none.")

        if text is not None:
            encoding = self.tokenizer(text, return_tensors=return_tensors, **tokenizer_kwargs)

        if images is not None:
            image_features = self.image_processor(images, return_tensors=return_tensors, **image_processor_kwargs)

        if text is not None and images is not None:
            encoding["pixel_values"] = image_features.pixel_values
            return encoding
        elif text is not None:
            return encoding
        else:
            return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

    @property
    def feature_extractor_class(self):
        warnings.warn(
            "`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.",
            FutureWarning,
        )
        return self.image_processor_class

    @property
    def feature_extractor(self):
        warnings.warn(
            "`feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.",
            FutureWarning,
        )
        return self.image_processor


__all__ = ["CLIPProcessor"]
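For reference, a typical round trip through this processor combines both modalities into a single `BatchEncoding`: `input_ids`/`attention_mask` come from the tokenizer and `pixel_values` from the image processor. A short usage sketch (the checkpoint id is the standard public CLIP release; the printed keys assume default settings):

```python
from PIL import Image
import requests
from transformers import CLIPProcessor

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# one call with both text and images merges the two encodings
batch = processor(text=["a photo of a cat"], images=image, return_tensors="np", padding=True)
print(sorted(batch.keys()))  # ['attention_mask', 'input_ids', 'pixel_values']
```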
docs/transformers/build/lib/transformers/models/clip/tokenization_clip.py
ADDED
@@ -0,0 +1,519 @@
# coding=utf-8
# Copyright 2021 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for CLIP."""

import json
import os
import unicodedata
from functools import lru_cache
from typing import List, Optional, Tuple

import regex as re

from ...tokenization_utils import AddedToken, PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
from ...utils import logging


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",
    "merges_file": "merges.txt",
}


@lru_cache()
def bytes_to_unicode():
    """
    Returns a list of utf-8 bytes and a mapping to unicode strings. We specifically avoid mapping to whitespace/control
    characters that the bpe code barfs on.

    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
    if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
    decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
    tables between utf-8 bytes and unicode strings.
    """
    bs = (
        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """
    Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def whitespace_clean(text):
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    return text


# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize
def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens


# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer
class BasicTokenizer:
    """
    Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).

    Args:
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        never_split (`Iterable`, *optional*):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            `do_basic_tokenize=True`
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters.

            This should likely be deactivated for Japanese (see this
            [issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original BERT).
        do_split_on_punc (`bool`, *optional*, defaults to `True`):
            In some instances we want to skip the basic punctuation splitting so that later tokenization can capture
            the full context of the words, such as contractions.
    """

    def __init__(
        self,
        do_lower_case=True,
        never_split=None,
        tokenize_chinese_chars=True,
        strip_accents=None,
        do_split_on_punc=True,
    ):
        if never_split is None:
            never_split = []
        self.do_lower_case = do_lower_case
        self.never_split = set(never_split)
        self.tokenize_chinese_chars = tokenize_chinese_chars
        self.strip_accents = strip_accents
        self.do_split_on_punc = do_split_on_punc

    def tokenize(self, text, never_split=None):
        """
        Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer.

        Args:
            never_split (`List[str]`, *optional*):
                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
                [`PreTrainedTokenizer.tokenize`]) List of tokens not to split.
        """
        # union() returns a new set by concatenating the two sets.
        never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
        text = self._clean_text(text)

        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        if self.tokenize_chinese_chars:
            text = self._tokenize_chinese_chars(text)
        # prevents treating the same character with different unicode codepoints as different characters
        unicode_normalized_text = unicodedata.normalize("NFC", text)
        orig_tokens = whitespace_tokenize(unicode_normalized_text)
        split_tokens = []
        for token in orig_tokens:
            if token not in never_split:
                if self.do_lower_case:
                    token = token.lower()
                    if self.strip_accents is not False:
                        token = self._run_strip_accents(token)
                elif self.strip_accents:
                    token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token, never_split))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text."""
        if not self.do_split_on_punc or (never_split is not None and text in never_split):
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like all of the other languages.
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)  #
            or (cp >= 0x20000 and cp <= 0x2A6DF)  #
            or (cp >= 0x2A700 and cp <= 0x2B73F)  #
            or (cp >= 0x2B740 and cp <= 0x2B81F)  #
            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
        ):  #
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
+
if cp == 0 or cp == 0xFFFD or _is_control(char):
|
| 246 |
+
continue
|
| 247 |
+
if _is_whitespace(char):
|
| 248 |
+
output.append(" ")
|
| 249 |
+
else:
|
| 250 |
+
output.append(char)
|
| 251 |
+
return "".join(output)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class CLIPTokenizer(PreTrainedTokenizer):
|
| 255 |
+
"""
|
| 256 |
+
Construct a CLIP tokenizer. Based on byte-level Byte-Pair-Encoding.
|
| 257 |
+
|
| 258 |
+
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
|
| 259 |
+
this superclass for more information regarding those methods.
|
| 260 |
+
|
| 261 |
+
Args:
|
| 262 |
+
vocab_file (`str`):
|
| 263 |
+
Path to the vocabulary file.
|
| 264 |
+
merges_file (`str`):
|
| 265 |
+
Path to the merges file.
|
| 266 |
+
errors (`str`, *optional*, defaults to `"replace"`):
|
| 267 |
+
Paradigm to follow when decoding bytes to UTF-8. See
|
| 268 |
+
[bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
|
| 269 |
+
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
| 270 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
| 271 |
+
token instead.
|
| 272 |
+
bos_token (`str`, *optional*, defaults to `"<|startoftext|>"`):
|
| 273 |
+
The beginning of sequence token.
|
| 274 |
+
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
| 275 |
+
The end of sequence token.
|
| 276 |
+
pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
| 277 |
+
The token used for padding, for example when batching sequences of different lengths.
|
| 278 |
+
"""
|
| 279 |
+
|
| 280 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 281 |
+
model_input_names = ["input_ids", "attention_mask"]
|
| 282 |
+
|
| 283 |
+
def __init__(
|
| 284 |
+
self,
|
| 285 |
+
vocab_file,
|
| 286 |
+
merges_file,
|
| 287 |
+
errors="replace",
|
| 288 |
+
unk_token="<|endoftext|>",
|
| 289 |
+
bos_token="<|startoftext|>",
|
| 290 |
+
eos_token="<|endoftext|>",
|
| 291 |
+
pad_token="<|endoftext|>", # hack to enable padding
|
| 292 |
+
**kwargs,
|
| 293 |
+
):
|
| 294 |
+
bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
|
| 295 |
+
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
|
| 296 |
+
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
|
| 297 |
+
try:
|
| 298 |
+
import ftfy
|
| 299 |
+
|
| 300 |
+
self.fix_text = ftfy.fix_text
|
| 301 |
+
except ImportError:
|
| 302 |
+
logger.info("ftfy or spacy is not installed using custom BasicTokenizer instead of ftfy.")
|
| 303 |
+
self.nlp = BasicTokenizer(strip_accents=False, do_split_on_punc=False)
|
| 304 |
+
self.fix_text = None
|
| 305 |
+
|
| 306 |
+
with open(vocab_file, encoding="utf-8") as vocab_handle:
|
| 307 |
+
self.encoder = json.load(vocab_handle)
|
| 308 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
| 309 |
+
self.errors = errors # how to handle errors in decoding
|
| 310 |
+
self.byte_encoder = bytes_to_unicode()
|
| 311 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
| 312 |
+
with open(merges_file, encoding="utf-8") as merges_handle:
|
| 313 |
+
bpe_merges = merges_handle.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
|
| 314 |
+
bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
|
| 315 |
+
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
|
| 316 |
+
self.cache = {"<|startoftext|>": "<|startoftext|>", "<|endoftext|>": "<|endoftext|>"}
|
| 317 |
+
|
| 318 |
+
self.pat = re.compile(
|
| 319 |
+
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
|
| 320 |
+
re.IGNORECASE,
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
super().__init__(
|
| 324 |
+
errors=errors,
|
| 325 |
+
unk_token=unk_token,
|
| 326 |
+
bos_token=bos_token,
|
| 327 |
+
eos_token=eos_token,
|
| 328 |
+
pad_token=pad_token,
|
| 329 |
+
**kwargs,
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
@property
|
| 333 |
+
def vocab_size(self):
|
| 334 |
+
return len(self.encoder)
|
| 335 |
+
|
| 336 |
+
def get_vocab(self):
|
| 337 |
+
return dict(self.encoder, **self.added_tokens_encoder)
|
| 338 |
+
|
| 339 |
+
def build_inputs_with_special_tokens(
|
| 340 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 341 |
+
) -> List[int]:
|
| 342 |
+
"""
|
| 343 |
+
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
| 344 |
+
adding special tokens. A CLIP sequence has the following format:
|
| 345 |
+
|
| 346 |
+
- single sequence: `<|startoftext|> X <|endoftext|>`
|
| 347 |
+
|
| 348 |
+
Pairs of sequences are not the expected use case, but they will be handled without a separator.
|
| 349 |
+
|
| 350 |
+
Args:
|
| 351 |
+
token_ids_0 (`List[int]`):
|
| 352 |
+
List of IDs to which the special tokens will be added.
|
| 353 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 354 |
+
Optional second list of IDs for sequence pairs.
|
| 355 |
+
|
| 356 |
+
Returns:
|
| 357 |
+
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
| 358 |
+
"""
|
| 359 |
+
bos_token = [self.bos_token_id]
|
| 360 |
+
eos_token = [self.eos_token_id]
|
| 361 |
+
|
| 362 |
+
if token_ids_1 is None:
|
| 363 |
+
return bos_token + token_ids_0 + eos_token
|
| 364 |
+
return bos_token + token_ids_0 + eos_token + eos_token + token_ids_1 + eos_token
|
| 365 |
+
|
| 366 |
+
def get_special_tokens_mask(
|
| 367 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
| 368 |
+
) -> List[int]:
|
| 369 |
+
"""
|
| 370 |
+
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 371 |
+
special tokens using the tokenizer `prepare_for_model` method.
|
| 372 |
+
|
| 373 |
+
Args:
|
| 374 |
+
token_ids_0 (`List[int]`):
|
| 375 |
+
List of IDs.
|
| 376 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 377 |
+
Optional second list of IDs for sequence pairs.
|
| 378 |
+
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
| 379 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
| 380 |
+
|
| 381 |
+
Returns:
|
| 382 |
+
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 383 |
+
"""
|
| 384 |
+
|
| 385 |
+
if already_has_special_tokens:
|
| 386 |
+
return super().get_special_tokens_mask(
|
| 387 |
+
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
if token_ids_1 is None:
|
| 391 |
+
return [1] + ([0] * len(token_ids_0)) + [1]
|
| 392 |
+
return [1] + ([0] * len(token_ids_0)) + [1] + [1] + ([0] * len(token_ids_1)) + [1]
|
| 393 |
+
|
| 394 |
+
def create_token_type_ids_from_sequences(
|
| 395 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 396 |
+
) -> List[int]:
|
| 397 |
+
"""
|
| 398 |
+
Create a mask from the two sequences passed. CLIP does not make use of token type ids, therefore a list of
|
| 399 |
+
zeros is returned.
|
| 400 |
+
|
| 401 |
+
Args:
|
| 402 |
+
token_ids_0 (`List[int]`):
|
| 403 |
+
List of IDs.
|
| 404 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 405 |
+
Optional second list of IDs for sequence pairs.
|
| 406 |
+
|
| 407 |
+
Returns:
|
| 408 |
+
`List[int]`: List of zeros.
|
| 409 |
+
"""
|
| 410 |
+
bos_token = [self.bos_token_id]
|
| 411 |
+
eos_token = [self.eos_token_id]
|
| 412 |
+
|
| 413 |
+
if token_ids_1 is None:
|
| 414 |
+
return len(bos_token + token_ids_0 + eos_token) * [0]
|
| 415 |
+
return len(bos_token + token_ids_0 + eos_token + eos_token + token_ids_1 + eos_token) * [0]
|
| 416 |
+
|
| 417 |
+
def bpe(self, token):
|
| 418 |
+
if token in self.cache:
|
| 419 |
+
return self.cache[token]
|
| 420 |
+
word = tuple(token[:-1]) + (token[-1] + "</w>",)
|
| 421 |
+
pairs = get_pairs(word)
|
| 422 |
+
|
| 423 |
+
if not pairs:
|
| 424 |
+
return token + "</w>"
|
| 425 |
+
|
| 426 |
+
while True:
|
| 427 |
+
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
|
| 428 |
+
if bigram not in self.bpe_ranks:
|
| 429 |
+
break
|
| 430 |
+
first, second = bigram
|
| 431 |
+
new_word = []
|
| 432 |
+
i = 0
|
| 433 |
+
while i < len(word):
|
| 434 |
+
try:
|
| 435 |
+
j = word.index(first, i)
|
| 436 |
+
except ValueError:
|
| 437 |
+
new_word.extend(word[i:])
|
| 438 |
+
break
|
| 439 |
+
else:
|
| 440 |
+
new_word.extend(word[i:j])
|
| 441 |
+
i = j
|
| 442 |
+
|
| 443 |
+
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
|
| 444 |
+
new_word.append(first + second)
|
| 445 |
+
i += 2
|
| 446 |
+
else:
|
| 447 |
+
new_word.append(word[i])
|
| 448 |
+
i += 1
|
| 449 |
+
new_word = tuple(new_word)
|
| 450 |
+
word = new_word
|
| 451 |
+
if len(word) == 1:
|
| 452 |
+
break
|
| 453 |
+
else:
|
| 454 |
+
pairs = get_pairs(word)
|
| 455 |
+
word = " ".join(word)
|
| 456 |
+
self.cache[token] = word
|
| 457 |
+
return word
|
| 458 |
+
|
| 459 |
+
def _tokenize(self, text):
|
| 460 |
+
"""Tokenize a string."""
|
| 461 |
+
bpe_tokens = []
|
| 462 |
+
if self.fix_text is None:
|
| 463 |
+
text = " ".join(self.nlp.tokenize(text))
|
| 464 |
+
else:
|
| 465 |
+
text = whitespace_clean(self.fix_text(text)).lower()
|
| 466 |
+
|
| 467 |
+
for token in re.findall(self.pat, text):
|
| 468 |
+
token = "".join(
|
| 469 |
+
self.byte_encoder[b] for b in token.encode("utf-8")
|
| 470 |
+
) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
|
| 471 |
+
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
|
| 472 |
+
return bpe_tokens
|
| 473 |
+
|
| 474 |
+
def _convert_token_to_id(self, token):
|
| 475 |
+
"""Converts a token (str) in an id using the vocab."""
|
| 476 |
+
return self.encoder.get(token, self.encoder.get(self.unk_token))
|
| 477 |
+
|
| 478 |
+
def _convert_id_to_token(self, index):
|
| 479 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
| 480 |
+
return self.decoder.get(index)
|
| 481 |
+
|
| 482 |
+
def convert_tokens_to_string(self, tokens):
|
| 483 |
+
"""Converts a sequence of tokens (string) in a single string."""
|
| 484 |
+
text = "".join(tokens)
|
| 485 |
+
byte_array = bytearray([self.byte_decoder[c] for c in text])
|
| 486 |
+
text = byte_array.decode("utf-8", errors=self.errors).replace("</w>", " ").strip()
|
| 487 |
+
return text
|
| 488 |
+
|
| 489 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 490 |
+
if not os.path.isdir(save_directory):
|
| 491 |
+
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
|
| 492 |
+
return
|
| 493 |
+
vocab_file = os.path.join(
|
| 494 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
| 495 |
+
)
|
| 496 |
+
merge_file = os.path.join(
|
| 497 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
with open(vocab_file, "w", encoding="utf-8") as f:
|
| 501 |
+
f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
|
| 502 |
+
|
| 503 |
+
index = 0
|
| 504 |
+
with open(merge_file, "w", encoding="utf-8") as writer:
|
| 505 |
+
writer.write("#version: 0.2\n")
|
| 506 |
+
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
|
| 507 |
+
if index != token_index:
|
| 508 |
+
logger.warning(
|
| 509 |
+
"Saving vocabulary to {}: BPE merge indices are not consecutive."
|
| 510 |
+
" Please check that the tokenizer is not corrupted!".format(merge_file)
|
| 511 |
+
)
|
| 512 |
+
index = token_index
|
| 513 |
+
writer.write(" ".join(bpe_tokens) + "\n")
|
| 514 |
+
index += 1
|
| 515 |
+
|
| 516 |
+
return vocab_file, merge_file
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
__all__ = ["CLIPTokenizer"]
|
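A quick sanity check of the slow tokenizer above (a minimal sketch; it assumes the standard `openai/clip-vit-base-patch32` checkpoint is reachable, and the expected ids/outputs in the comments are illustrative rather than guaranteed):

```python
from transformers import CLIPTokenizer

# Assumes network access to the openai/clip-vit-base-patch32 checkpoint.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# _tokenize() lowercases and applies BPE with the "</w>" end-of-word suffix;
# build_inputs_with_special_tokens() wraps ids as <|startoftext|> X <|endoftext|>.
enc = tokenizer("a photo of two cats")
print(enc["input_ids"][0], enc["input_ids"][-1])  # 49406 (bos), 49407 (eos)
print(tokenizer.decode(enc["input_ids"], skip_special_tokens=True))  # "a photo of two cats"
```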
docs/transformers/build/lib/transformers/models/clip/tokenization_clip_fast.py
ADDED
@@ -0,0 +1,164 @@
+# coding=utf-8
+# Copyright 2021 The Open AI Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for OpenAI GPT."""
+
+from typing import List, Optional, Tuple
+
+from tokenizers import pre_tokenizers
+
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import logging
+from .tokenization_clip import CLIPTokenizer
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
+
+
+class CLIPTokenizerFast(PreTrainedTokenizerFast):
+    """
+    Construct a "fast" CLIP tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
+    Byte-Pair-Encoding.
+
+    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+    refer to this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`, *optional*):
+            Path to the vocabulary file.
+        merges_file (`str`, *optional*):
+            Path to the merges file.
+        tokenizer_file (`str`, *optional*):
+            The path to a tokenizer file to use instead of the vocab file.
+        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
+            this token instead.
+        bos_token (`str`, *optional*, defaults to `"<|startoftext|>"`):
+            The beginning of sequence token.
+        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+            The end of sequence token.
+        pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+    slow_tokenizer_class = CLIPTokenizer
+
+    def __init__(
+        self,
+        vocab_file=None,
+        merges_file=None,
+        tokenizer_file=None,
+        unk_token="<|endoftext|>",
+        bos_token="<|startoftext|>",
+        eos_token="<|endoftext|>",
+        pad_token="<|endoftext|>",  # hack to enable padding
+        **kwargs,
+    ):
+        super().__init__(
+            vocab_file,
+            merges_file,
+            tokenizer_file=tokenizer_file,
+            unk_token=unk_token,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            pad_token=pad_token,
+            **kwargs,
+        )
+
+        if not isinstance(self.backend_tokenizer.pre_tokenizer, pre_tokenizers.Sequence):
+            raise ValueError(
+                "The `backend_tokenizer` provided does not match the expected format. The CLIP tokenizer has been"
+                " heavily modified from transformers version 4.17.0. You need to convert the tokenizer you are using"
+                " to be compatible with this version. The easiest way to do so is"
+                ' `CLIPTokenizerFast.from_pretrained("path_to_local_folder_or_hub_repo", from_slow=True)`. If you'
+                " want to use your existing tokenizer, you will have to revert to a version prior to 4.17.0 of"
+                " transformers."
+            )
+        self._wrap_decode_method_backend_tokenizer()
+
+    # Very ugly hack to enable padding to have a correct decoding see https://github.com/huggingface/tokenizers/issues/872
+    def _wrap_decode_method_backend_tokenizer(self):
+        orig_decode_method = self.backend_tokenizer.decode
+
+        ## define this as a local variable to avoid circular reference
+        ## See: https://github.com/huggingface/transformers/issues/30930
+        end_of_word_suffix = self.backend_tokenizer.model.end_of_word_suffix
+
+        def new_decode_method(*args, **kwargs):
+            text = orig_decode_method(*args, **kwargs)
+            text = text.replace(end_of_word_suffix, " ").strip()
+            return text
+
+        self.backend_tokenizer.decode = new_decode_method
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+        and adding special tokens. A CLIP sequence has the following format:
+
+        - single sequence: `<|startoftext|> X <|endoftext|>`
+
+        Pairs of sequences are not the expected use case, but they will be handled without a separator.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+        bos_token = [self.bos_token_id]
+        eos_token = [self.eos_token_id]
+
+        if token_ids_1 is None:
+            return bos_token + token_ids_0 + eos_token
+        return bos_token + token_ids_0 + eos_token + eos_token + token_ids_1 + eos_token
+
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Create a mask from the two sequences passed. CLIP does not make use of token type ids, therefore a list of
+        zeros is returned.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of zeros.
+        """
+        bos_token = [self.bos_token_id]
+        eos_token = [self.eos_token_id]
+
+        if token_ids_1 is None:
+            return len(bos_token + token_ids_0 + eos_token) * [0]
+        return len(bos_token + token_ids_0 + eos_token + eos_token + token_ids_1 + eos_token) * [0]
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        files = self._tokenizer.model.save(save_directory, name=filename_prefix)
+        return tuple(files)
+
+
+__all__ = ["CLIPTokenizerFast"]
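The `_wrap_decode_method_backend_tokenizer` hack above is easiest to see with padded batches. A minimal sketch (same checkpoint assumption as before; the printed results are what I would expect, not verified output):

```python
from transformers import CLIPTokenizerFast

tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32")

# pad_token is aliased to <|endoftext|>, so batching sequences of different
# lengths works; the wrapped backend decode then replaces the "</w>" suffix
# with spaces so decoded words stay whitespace-separated.
batch = tokenizer(["a cat", "a photo of a dog"], padding=True)
print(batch["input_ids"])  # shorter sequence right-padded with the eos/pad id
print(tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True))  # "a cat"
```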
docs/transformers/build/lib/transformers/models/clipseg/__init__.py
ADDED
@@ -0,0 +1,28 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_clipseg import *
+    from .modeling_clipseg import *
+    from .processing_clipseg import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
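For reference, a minimal sketch of what the `_LazyModule` registration above buys: importing the package is cheap, and symbols listed in each sibling module's `__all__` are only resolved on first attribute access (this reflects my understanding of `_LazyModule`, not a verified trace of this exact build):

```python
import transformers.models.clipseg as clipseg  # cheap: heavy submodules not imported yet

config_cls = clipseg.CLIPSegConfig  # first attribute access triggers the real import
print(config_cls.model_type)  # "clipseg"
```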
docs/transformers/build/lib/transformers/models/clipseg/configuration_clipseg.py
ADDED
@@ -0,0 +1,396 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""CLIPSeg model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class CLIPSegTextConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`CLIPSegModel`]. It is used to instantiate a
+    CLIPSeg model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the CLIPSeg
+    [CIDAS/clipseg-rd64](https://huggingface.co/CIDAS/clipseg-rd64) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 49408):
+            Vocabulary size of the CLIPSeg text model. Defines the number of different tokens that can be represented
+            by the `inputs_ids` passed when calling [`CLIPSegModel`].
+        hidden_size (`int`, *optional*, defaults to 512):
+            Dimensionality of the encoder layers and the pooler layer.
+        intermediate_size (`int`, *optional*, defaults to 2048):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        max_position_embeddings (`int`, *optional*, defaults to 77):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the layer normalization layers.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        initializer_factor (`float`, *optional*, defaults to 1.0):
+            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+            testing).
+        pad_token_id (`int`, *optional*, defaults to 1):
+            Padding token id.
+        bos_token_id (`int`, *optional*, defaults to 49406):
+            Beginning of stream token id.
+        eos_token_id (`int`, *optional*, defaults to 49407):
+            End of stream token id.
+
+    Example:
+
+    ```python
+    >>> from transformers import CLIPSegTextConfig, CLIPSegTextModel
+
+    >>> # Initializing a CLIPSegTextConfig with CIDAS/clipseg-rd64 style configuration
+    >>> configuration = CLIPSegTextConfig()
+
+    >>> # Initializing a CLIPSegTextModel (with random weights) from the CIDAS/clipseg-rd64 style configuration
+    >>> model = CLIPSegTextModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "clipseg_text_model"
+    base_config_key = "text_config"
+
+    def __init__(
+        self,
+        vocab_size=49408,
+        hidden_size=512,
+        intermediate_size=2048,
+        num_hidden_layers=12,
+        num_attention_heads=8,
+        max_position_embeddings=77,
+        hidden_act="quick_gelu",
+        layer_norm_eps=1e-5,
+        attention_dropout=0.0,
+        initializer_range=0.02,
+        initializer_factor=1.0,
+        pad_token_id=1,
+        bos_token_id=49406,
+        eos_token_id=49407,
+        **kwargs,
+    ):
+        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.max_position_embeddings = max_position_embeddings
+        self.layer_norm_eps = layer_norm_eps
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.initializer_factor = initializer_factor
+        self.attention_dropout = attention_dropout
+
+
+class CLIPSegVisionConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`CLIPSegModel`]. It is used to instantiate a
+    CLIPSeg model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the CLIPSeg
+    [CIDAS/clipseg-rd64](https://huggingface.co/CIDAS/clipseg-rd64) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        image_size (`int`, *optional*, defaults to 224):
+            The size (resolution) of each image.
+        patch_size (`int`, *optional*, defaults to 32):
+            The size (resolution) of each patch.
+        hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the layer normalization layers.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        initializer_factor (`float`, *optional*, defaults to 1.0):
+            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+            testing).
+
+    Example:
+
+    ```python
+    >>> from transformers import CLIPSegVisionConfig, CLIPSegVisionModel
+
+    >>> # Initializing a CLIPSegVisionConfig with CIDAS/clipseg-rd64 style configuration
+    >>> configuration = CLIPSegVisionConfig()
+
+    >>> # Initializing a CLIPSegVisionModel (with random weights) from the CIDAS/clipseg-rd64 style configuration
+    >>> model = CLIPSegVisionModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "clipseg_vision_model"
+    base_config_key = "vision_config"
+
+    def __init__(
+        self,
+        hidden_size=768,
+        intermediate_size=3072,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        num_channels=3,
+        image_size=224,
+        patch_size=32,
+        hidden_act="quick_gelu",
+        layer_norm_eps=1e-5,
+        attention_dropout=0.0,
+        initializer_range=0.02,
+        initializer_factor=1.0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_channels = num_channels
+        self.patch_size = patch_size
+        self.image_size = image_size
+        self.initializer_range = initializer_range
+        self.initializer_factor = initializer_factor
+        self.attention_dropout = attention_dropout
+        self.layer_norm_eps = layer_norm_eps
+        self.hidden_act = hidden_act
+
+
+class CLIPSegConfig(PretrainedConfig):
+    r"""
+    [`CLIPSegConfig`] is the configuration class to store the configuration of a [`CLIPSegModel`]. It is used to
+    instantiate a CLIPSeg model according to the specified arguments, defining the text model and vision model
+    configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the CLIPSeg
+    [CIDAS/clipseg-rd64](https://huggingface.co/CIDAS/clipseg-rd64) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        text_config (`dict`, *optional*):
+            Dictionary of configuration options used to initialize [`CLIPSegTextConfig`].
+        vision_config (`dict`, *optional*):
+            Dictionary of configuration options used to initialize [`CLIPSegVisionConfig`].
+        projection_dim (`int`, *optional*, defaults to 512):
+            Dimensionality of text and vision projection layers.
+        logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
+            The initial value of the *logit_scale* parameter. Default is used as per the original CLIPSeg
+            implementation.
+        extract_layers (`List[int]`, *optional*, defaults to `[3, 6, 9]`):
+            Layers to extract when forwarding the query image through the frozen visual backbone of CLIP.
+        reduce_dim (`int`, *optional*, defaults to 64):
+            Dimensionality to reduce the CLIP vision embedding.
+        decoder_num_attention_heads (`int`, *optional*, defaults to 4):
+            Number of attention heads in the decoder of CLIPSeg.
+        decoder_attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        decoder_hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
+        decoder_intermediate_size (`int`, *optional*, defaults to 2048):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layers in the Transformer decoder.
+        conditional_layer (`int`, *optional*, defaults to 0):
+            The layer of the Transformer encoder whose activations will be combined with the condition embeddings
+            using FiLM (Feature-wise Linear Modulation). If 0, the last layer is used.
+        use_complex_transposed_convolution (`bool`, *optional*, defaults to `False`):
+            Whether to use a more complex transposed convolution in the decoder, enabling more fine-grained
+            segmentation.
+        kwargs (*optional*):
+            Dictionary of keyword arguments.
+
+    Example:
+
+    ```python
+    >>> from transformers import CLIPSegConfig, CLIPSegModel
+
+    >>> # Initializing a CLIPSegConfig with CIDAS/clipseg-rd64 style configuration
+    >>> configuration = CLIPSegConfig()
+
+    >>> # Initializing a CLIPSegModel (with random weights) from the CIDAS/clipseg-rd64 style configuration
+    >>> model = CLIPSegModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+
+    >>> # We can also initialize a CLIPSegConfig from a CLIPSegTextConfig and a CLIPSegVisionConfig
+
+    >>> # Initializing a CLIPSegText and CLIPSegVision configuration
+    >>> config_text = CLIPSegTextConfig()
+    >>> config_vision = CLIPSegVisionConfig()
+
+    >>> config = CLIPSegConfig.from_text_vision_configs(config_text, config_vision)
+    ```"""
+
+    model_type = "clipseg"
+    sub_configs = {"text_config": CLIPSegTextConfig, "vision_config": CLIPSegVisionConfig}
+
+    def __init__(
+        self,
+        text_config=None,
+        vision_config=None,
+        projection_dim=512,
+        logit_scale_init_value=2.6592,
+        extract_layers=[3, 6, 9],
+        reduce_dim=64,
+        decoder_num_attention_heads=4,
+        decoder_attention_dropout=0.0,
+        decoder_hidden_act="quick_gelu",
+        decoder_intermediate_size=2048,
+        conditional_layer=0,
+        use_complex_transposed_convolution=False,
+        **kwargs,
+    ):
+        # If `_config_dict` exist, we use them for the backward compatibility.
+        # We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot
+        # of confusion!).
+        text_config_dict = kwargs.pop("text_config_dict", None)
+        vision_config_dict = kwargs.pop("vision_config_dict", None)
+
+        super().__init__(**kwargs)
+
+        # Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in
+        # `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be same in most
+        # cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`.
+        if text_config_dict is not None:
+            if text_config is None:
+                text_config = {}
+
+            # This is the complete result when using `text_config_dict`.
+            _text_config_dict = CLIPSegTextConfig(**text_config_dict).to_dict()
+
+            # Give a warning if the values exist in both `_text_config_dict` and `text_config` but being different.
+            for key, value in _text_config_dict.items():
+                if key in text_config and value != text_config[key] and key not in ["transformers_version"]:
+                    # If specified in `text_config_dict`
+                    if key in text_config_dict:
+                        message = (
+                            f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. "
+                            f'The value `text_config_dict["{key}"]` will be used instead.'
+                        )
+                    # If inferred from default argument values (just to be super careful)
+                    else:
+                        message = (
+                            f"`text_config_dict` is provided which will be used to initialize `CLIPSegTextConfig`. The "
+                            f'value `text_config["{key}"]` will be overridden.'
+                        )
+                    logger.info(message)
+
+            # Update all values in `text_config` with the ones in `_text_config_dict`.
+            text_config.update(_text_config_dict)
+
+        if vision_config_dict is not None:
+            if vision_config is None:
+                vision_config = {}
+
+            # This is the complete result when using `vision_config_dict`.
+            _vision_config_dict = CLIPSegVisionConfig(**vision_config_dict).to_dict()
+            # convert keys to string instead of integer
+            if "id2label" in _vision_config_dict:
+                _vision_config_dict["id2label"] = {
+                    str(key): value for key, value in _vision_config_dict["id2label"].items()
+                }
+
+            # Give a warning if the values exist in both `_vision_config_dict` and `vision_config` but being different.
+            for key, value in _vision_config_dict.items():
+                if key in vision_config and value != vision_config[key] and key not in ["transformers_version"]:
+                    # If specified in `vision_config_dict`
+                    if key in vision_config_dict:
+                        message = (
+                            f"`{key}` is found in both `vision_config_dict` and `vision_config` but with different "
+                            f'values. The value `vision_config_dict["{key}"]` will be used instead.'
+                        )
+                    # If inferred from default argument values (just to be super careful)
+                    else:
+                        message = (
+                            f"`vision_config_dict` is provided which will be used to initialize `CLIPSegVisionConfig`. "
+                            f'The value `vision_config["{key}"]` will be overridden.'
+                        )
+                    logger.info(message)
+
+            # Update all values in `vision_config` with the ones in `_vision_config_dict`.
+            vision_config.update(_vision_config_dict)
+
+        if text_config is None:
+            text_config = {}
+            logger.info("`text_config` is `None`. Initializing the `CLIPSegTextConfig` with default values.")
+
+        if vision_config is None:
+            vision_config = {}
+            logger.info("`vision_config` is `None`. Initializing the `CLIPSegVisionConfig` with default values.")
+
+        self.text_config = CLIPSegTextConfig(**text_config)
+        self.vision_config = CLIPSegVisionConfig(**vision_config)
+
+        self.projection_dim = projection_dim
+        self.logit_scale_init_value = logit_scale_init_value
+        self.extract_layers = extract_layers
+        self.reduce_dim = reduce_dim
+        self.decoder_num_attention_heads = decoder_num_attention_heads
+        self.decoder_attention_dropout = decoder_attention_dropout
+        self.decoder_hidden_act = decoder_hidden_act
+        self.decoder_intermediate_size = decoder_intermediate_size
+        self.conditional_layer = conditional_layer
+        self.initializer_factor = 1.0
+        self.use_complex_transposed_convolution = use_complex_transposed_convolution
+
+    @classmethod
+    def from_text_vision_configs(cls, text_config: CLIPSegTextConfig, vision_config: CLIPSegVisionConfig, **kwargs):
+        r"""
+        Instantiate a [`CLIPSegConfig`] (or a derived class) from clipseg text model configuration and clipseg vision
+        model configuration.
+
+        Returns:
+            [`CLIPSegConfig`]: An instance of a configuration object
+        """
+
+        return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
+
+
+__all__ = ["CLIPSegConfig", "CLIPSegTextConfig", "CLIPSegVisionConfig"]
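A minimal sketch of composing the combined config from the two sub-configs, using the same decoder overrides the conversion script below applies for the "refined"/"rd16" checkpoint variants (the specific values here are illustrative, not canonical):

```python
from transformers import CLIPSegConfig, CLIPSegTextConfig, CLIPSegVisionConfig

text_config = CLIPSegTextConfig()
vision_config = CLIPSegVisionConfig(patch_size=16)

config = CLIPSegConfig.from_text_vision_configs(
    text_config,
    vision_config,
    reduce_dim=16,                            # dimensionality the CLIP embedding is reduced to
    use_complex_transposed_convolution=True,  # finer-grained transposed convolution in the decoder
)
print(config.vision_config.patch_size, config.reduce_dim)  # 16 16
```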
docs/transformers/build/lib/transformers/models/clipseg/convert_clipseg_original_pytorch_to_hf.py
ADDED
@@ -0,0 +1,264 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Convert CLIPSeg checkpoints from the original repository. URL: https://github.com/timojl/clipseg."""
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
|
| 20 |
+
import requests
|
| 21 |
+
import torch
|
| 22 |
+
from PIL import Image
|
| 23 |
+
|
| 24 |
+
from transformers import (
|
| 25 |
+
CLIPSegConfig,
|
| 26 |
+
CLIPSegForImageSegmentation,
|
| 27 |
+
CLIPSegProcessor,
|
| 28 |
+
CLIPSegTextConfig,
|
| 29 |
+
CLIPSegVisionConfig,
|
| 30 |
+
CLIPTokenizer,
|
| 31 |
+
ViTImageProcessor,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_clipseg_config(model_name):
|
| 36 |
+
text_config = CLIPSegTextConfig()
|
| 37 |
+
vision_config = CLIPSegVisionConfig(patch_size=16)
|
| 38 |
+
|
| 39 |
+
use_complex_transposed_convolution = True if "refined" in model_name else False
|
| 40 |
+
reduce_dim = 16 if "rd16" in model_name else 64
|
| 41 |
+
|
| 42 |
+
config = CLIPSegConfig.from_text_vision_configs(
|
| 43 |
+
text_config,
|
| 44 |
+
vision_config,
|
| 45 |
+
use_complex_transposed_convolution=use_complex_transposed_convolution,
|
| 46 |
+
reduce_dim=reduce_dim,
|
| 47 |
+
)
|
| 48 |
+
return config
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def rename_key(name):
|
| 52 |
+
# update prefixes
|
| 53 |
+
if "clip_model" in name:
|
| 54 |
+
name = name.replace("clip_model", "clip")
|
| 55 |
+
if "transformer" in name:
|
| 56 |
+
if "visual" in name:
|
| 57 |
+
name = name.replace("visual.transformer", "vision_model")
|
| 58 |
+
else:
|
| 59 |
+
name = name.replace("transformer", "text_model")
|
| 60 |
+
if "resblocks" in name:
|
| 61 |
+
name = name.replace("resblocks", "encoder.layers")
|
| 62 |
+
if "ln_1" in name:
|
| 63 |
+
name = name.replace("ln_1", "layer_norm1")
|
| 64 |
+
if "ln_2" in name:
|
| 65 |
+
name = name.replace("ln_2", "layer_norm2")
|
| 66 |
+
if "c_fc" in name:
|
| 67 |
+
name = name.replace("c_fc", "fc1")
|
| 68 |
+
if "c_proj" in name:
|
| 69 |
+
name = name.replace("c_proj", "fc2")
|
| 70 |
+
if "attn" in name and "self" not in name:
|
| 71 |
+
name = name.replace("attn", "self_attn")
|
| 72 |
+
# text encoder
|
| 73 |
+
if "token_embedding" in name:
|
| 74 |
+
name = name.replace("token_embedding", "text_model.embeddings.token_embedding")
|
| 75 |
+
if "positional_embedding" in name and "visual" not in name:
|
| 76 |
+
name = name.replace("positional_embedding", "text_model.embeddings.position_embedding.weight")
|
| 77 |
+
if "ln_final" in name:
|
| 78 |
+
name = name.replace("ln_final", "text_model.final_layer_norm")
|
| 79 |
+
# vision encoder
|
| 80 |
+
if "visual.class_embedding" in name:
|
| 81 |
+
name = name.replace("visual.class_embedding", "vision_model.embeddings.class_embedding")
|
| 82 |
+
if "visual.conv1" in name:
|
| 83 |
+
name = name.replace("visual.conv1", "vision_model.embeddings.patch_embedding")
|
| 84 |
+
if "visual.positional_embedding" in name:
|
| 85 |
+
name = name.replace("visual.positional_embedding", "vision_model.embeddings.position_embedding.weight")
|
| 86 |
+
if "visual.ln_pre" in name:
|
| 87 |
+
name = name.replace("visual.ln_pre", "vision_model.pre_layrnorm")
|
| 88 |
+
if "visual.ln_post" in name:
|
| 89 |
+
name = name.replace("visual.ln_post", "vision_model.post_layernorm")
|
| 90 |
+
# projection layers
|
| 91 |
+
if "visual.proj" in name:
|
| 92 |
+
name = name.replace("visual.proj", "visual_projection.weight")
|
| 93 |
+
if "text_projection" in name:
|
| 94 |
+
name = name.replace("text_projection", "text_projection.weight")
|
| 95 |
+
# decoder
|
| 96 |
+
if "trans_conv" in name:
|
| 97 |
+
name = name.replace("trans_conv", "transposed_convolution")
|
| 98 |
+
if "film_mul" in name or "film_add" in name or "reduce" in name or "transposed_convolution" in name:
|
| 99 |
+
name = "decoder." + name
|
| 100 |
+
if "blocks" in name:
|
| 101 |
+
name = name.replace("blocks", "decoder.layers")
|
| 102 |
+
if "linear1" in name:
|
| 103 |
+
name = name.replace("linear1", "mlp.fc1")
|
| 104 |
+
if "linear2" in name:
|
| 105 |
+
name = name.replace("linear2", "mlp.fc2")
|
| 106 |
+
if "norm1" in name and "layer_" not in name:
|
| 107 |
+
name = name.replace("norm1", "layer_norm1")
|
| 108 |
+
if "norm2" in name and "layer_" not in name:
|
| 109 |
+
name = name.replace("norm2", "layer_norm2")
|
| 110 |
+
|
| 111 |
+
return name
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def convert_state_dict(orig_state_dict, config):
    for key in orig_state_dict.copy().keys():
        val = orig_state_dict.pop(key)

        if key.startswith("clip_model") and "attn.in_proj" in key:
            key_split = key.split(".")
            if "visual" in key:
                layer_num = int(key_split[4])
                dim = config.vision_config.hidden_size
                prefix = "vision_model"
            else:
                layer_num = int(key_split[3])
                dim = config.text_config.hidden_size
                prefix = "text_model"

            if "weight" in key:
                orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.q_proj.weight"] = val[:dim, :]
                orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.k_proj.weight"] = val[
                    dim : dim * 2, :
                ]
                orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.v_proj.weight"] = val[-dim:, :]
            else:
                orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.q_proj.bias"] = val[:dim]
                orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.k_proj.bias"] = val[dim : dim * 2]
                orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.v_proj.bias"] = val[-dim:]
        elif "self_attn" in key and "out_proj" not in key:
            key_split = key.split(".")
            layer_num = int(key_split[1])
            dim = config.reduce_dim
            if "weight" in key:
                orig_state_dict[f"decoder.layers.{layer_num}.self_attn.q_proj.weight"] = val[:dim, :]
                orig_state_dict[f"decoder.layers.{layer_num}.self_attn.k_proj.weight"] = val[dim : dim * 2, :]
                orig_state_dict[f"decoder.layers.{layer_num}.self_attn.v_proj.weight"] = val[-dim:, :]
            else:
                orig_state_dict[f"decoder.layers.{layer_num}.self_attn.q_proj.bias"] = val[:dim]
                orig_state_dict[f"decoder.layers.{layer_num}.self_attn.k_proj.bias"] = val[dim : dim * 2]
                orig_state_dict[f"decoder.layers.{layer_num}.self_attn.v_proj.bias"] = val[-dim:]
        else:
            new_name = rename_key(key)
            if "visual_projection" in new_name or "text_projection" in new_name:
                val = val.T
            orig_state_dict[new_name] = val

    return orig_state_dict


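# A minimal sketch of the fused-projection split performed above: the original
# checkpoint stores attention q/k/v as a single `in_proj` tensor of shape
# (3 * dim, dim), which `convert_state_dict` slices into three (dim, dim) blocks.
# The dimension below is made up for illustration.
def _demo_in_proj_split(dim: int = 4):
    in_proj_weight = torch.randn(3 * dim, dim)
    q_w = in_proj_weight[:dim, :]  # rows [0, dim)
    k_w = in_proj_weight[dim : dim * 2, :]  # rows [dim, 2*dim)
    v_w = in_proj_weight[-dim:, :]  # rows [2*dim, 3*dim)
    # Stacking the slices back recovers the fused tensor exactly.
    assert torch.equal(torch.cat([q_w, k_w, v_w], dim=0), in_proj_weight)

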
# We will verify our results on an image of cute cats
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    return image


def convert_clipseg_checkpoint(model_name, checkpoint_path, pytorch_dump_folder_path, push_to_hub):
    config = get_clipseg_config(model_name)
    model = CLIPSegForImageSegmentation(config)
    model.eval()

    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

    # remove some keys
    for key in state_dict.copy().keys():
        if key.startswith("model"):
            state_dict.pop(key, None)

    # rename some keys
    state_dict = convert_state_dict(state_dict, config)
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

    if missing_keys != ["clip.text_model.embeddings.position_ids", "clip.vision_model.embeddings.position_ids"]:
        raise ValueError(f"Missing keys that are not expected: {missing_keys}")
    if unexpected_keys != ["decoder.reduce.weight", "decoder.reduce.bias"]:
        raise ValueError(f"Unexpected keys: {unexpected_keys}")

    image_processor = ViTImageProcessor(size=352)
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPSegProcessor(image_processor=image_processor, tokenizer=tokenizer)

    image = prepare_img()
    text = ["a glass", "something to fill", "wood", "a jar"]

    inputs = processor(text=text, images=[image] * len(text), padding="max_length", return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # verify values
    expected_conditional = torch.tensor([0.1110, -0.1882, 0.1645])
    expected_pooled_output = torch.tensor([0.2692, -0.7197, -0.1328])
    if model_name == "clipseg-rd64-refined":
        expected_masks_slice = torch.tensor(
            [[-10.0407, -9.9431, -10.2646], [-9.9751, -9.7064, -9.9586], [-9.6891, -9.5645, -9.9618]]
        )
    elif model_name == "clipseg-rd64":
        expected_masks_slice = torch.tensor(
            [[-7.2877, -7.2711, -7.2463], [-7.2652, -7.2780, -7.2520], [-7.2239, -7.2204, -7.2001]]
        )
    elif model_name == "clipseg-rd16":
        expected_masks_slice = torch.tensor(
            [[-6.3955, -6.4055, -6.4151], [-6.3911, -6.4033, -6.4100], [-6.3474, -6.3702, -6.3762]]
        )
    else:
        raise ValueError(f"Model name {model_name} not supported.")

    assert torch.allclose(outputs.logits[0, :3, :3], expected_masks_slice, atol=1e-3)
    assert torch.allclose(outputs.conditional_embeddings[0, :3], expected_conditional, atol=1e-3)
    assert torch.allclose(outputs.pooled_output[0, :3], expected_pooled_output, atol=1e-3)
    print("Looks ok!")

    if pytorch_dump_folder_path is not None:
        print(f"Saving model and processor to {pytorch_dump_folder_path}")
        model.save_pretrained(pytorch_dump_folder_path)
        processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        print(f"Pushing model and processor for {model_name} to the hub")
        model.push_to_hub(f"CIDAS/{model_name}")
        processor.push_to_hub(f"CIDAS/{model_name}")


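# Example invocation (illustrative; paths and the output folder name are placeholders):
#
#   python convert_clipseg_original_pytorch_to_hf.py \
#       --model_name clipseg-rd64 \
#       --checkpoint_path /path/to/clip_plus_rd64-uni.pth \
#       --pytorch_dump_folder_path ./clipseg-rd64-converted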
if __name__ == "__main__":
|
| 235 |
+
parser = argparse.ArgumentParser()
|
| 236 |
+
# Required parameters
|
| 237 |
+
parser.add_argument(
|
| 238 |
+
"--model_name",
|
| 239 |
+
default="clipseg-rd64",
|
| 240 |
+
type=str,
|
| 241 |
+
choices=["clipseg-rd16", "clipseg-rd64", "clipseg-rd64-refined"],
|
| 242 |
+
help=(
|
| 243 |
+
"Name of the model. Supported models are: clipseg-rd64, clipseg-rd16 and clipseg-rd64-refined (rd meaning"
|
| 244 |
+
" reduce dimension)"
|
| 245 |
+
),
|
| 246 |
+
)
|
| 247 |
+
parser.add_argument(
|
| 248 |
+
"--checkpoint_path",
|
| 249 |
+
default="/Users/nielsrogge/Documents/CLIPSeg/clip_plus_rd64-uni.pth",
|
| 250 |
+
type=str,
|
| 251 |
+
help=(
|
| 252 |
+
"Path to the original checkpoint. Note that the script assumes that the checkpoint includes both CLIP and"
|
| 253 |
+
" the decoder weights."
|
| 254 |
+
),
|
| 255 |
+
)
|
| 256 |
+
parser.add_argument(
|
| 257 |
+
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
| 258 |
+
)
|
| 259 |
+
parser.add_argument(
|
| 260 |
+
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
args = parser.parse_args()
|
| 264 |
+
convert_clipseg_checkpoint(args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub)
|
docs/transformers/build/lib/transformers/models/clipseg/modeling_clipseg.py
ADDED
@@ -0,0 +1,1520 @@
# coding=utf-8
# Copyright 2022 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch CLIPSeg model."""

import copy
import math
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
    torch_int,
)
from .configuration_clipseg import CLIPSegConfig, CLIPSegTextConfig, CLIPSegVisionConfig


logger = logging.get_logger(__name__)


_CHECKPOINT_FOR_DOC = "CIDAS/clipseg-rd64-refined"


# contrastive loss function, adapted from
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))


# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->clipseg
def clipseg_loss(similarity: torch.Tensor) -> torch.Tensor:
    caption_loss = contrastive_loss(similarity)
    image_loss = contrastive_loss(similarity.t())
    return (caption_loss + image_loss) / 2.0


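# A minimal sketch of the symmetric loss above: for a batch of N matched
# image/text pairs, the targets sit on the diagonal of the similarity matrix,
# so cross-entropy over rows (text->image) and over columns (image->text) is
# averaged. Batch size below is made up for illustration.
def _demo_clipseg_loss(batch_size: int = 3):
    similarity = torch.randn(batch_size, batch_size)
    loss = clipseg_loss(similarity)
    # Equivalent explicit form:
    targets = torch.arange(batch_size)
    manual = (
        nn.functional.cross_entropy(similarity, targets)
        + nn.functional.cross_entropy(similarity.t(), targets)
    ) / 2.0
    assert torch.allclose(loss, manual)
    return loss

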
@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->CLIPSeg
class CLIPSegOutput(ModelOutput):
    """
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
            Contrastive loss for image-text similarity.
        logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
            similarity scores.
        logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
            The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
            similarity scores.
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
            The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPSegTextModel`].
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
            The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPSegVisionModel`].
        text_model_output (`BaseModelOutputWithPooling`):
            The output of the [`CLIPSegTextModel`].
        vision_model_output (`BaseModelOutputWithPooling`):
            The output of the [`CLIPSegVisionModel`].
    """

    loss: Optional[torch.FloatTensor] = None
    logits_per_image: Optional[torch.FloatTensor] = None
    logits_per_text: Optional[torch.FloatTensor] = None
    text_embeds: Optional[torch.FloatTensor] = None
    image_embeds: Optional[torch.FloatTensor] = None
    text_model_output: BaseModelOutputWithPooling = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> Tuple[Any]:
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )


@dataclass
class CLIPSegDecoderOutput(ModelOutput):
    """
    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, height, width)`):
            Classification scores for each pixel.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class CLIPSegImageSegmentationOutput(ModelOutput):
    """
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
            Contrastive loss for image-text similarity.
        ...
        vision_model_output (`BaseModelOutputWithPooling`):
            The output of the [`CLIPSegVisionModel`].
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    conditional_embeddings: Optional[torch.FloatTensor] = None
    pooled_output: Optional[torch.FloatTensor] = None
    vision_model_output: BaseModelOutputWithPooling = None
    decoder_output: CLIPSegDecoderOutput = None

    def to_tuple(self) -> Tuple[Any]:
        return tuple(
            self[k] if k not in ["vision_model_output", "decoder_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )


class CLIPSegVisionEmbeddings(nn.Module):
    # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings.__init__ with CLIP->CLIPSeg
    def __init__(self, config: CLIPSegVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method interpolates the pre-trained position encodings so that the model can be used on
        higher-resolution images. It is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """

        num_patches = embeddings.shape[1] - 1
        position_embedding = self.position_embedding.weight.unsqueeze(0)
        num_positions = position_embedding.shape[1] - 1

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embedding(self.position_ids)

        class_pos_embed = position_embedding[:, :1]
        patch_pos_embed = position_embedding[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=True) -> torch.Tensor:
        batch_size, _, height, width = pixel_values.shape
        if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
            )
        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embedding(self.position_ids)
        return embeddings


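# A minimal sketch of what `interpolate_pos_encoding` above does to the learned
# position table: the (num_patches, dim) table is reshaped into a square grid
# and bicubically resized to the new patch grid. The sizes below are made up.
def _demo_resize_pos_grid(old_grid: int = 7, new_grid: int = 11, dim: int = 8):
    patch_pos_embed = torch.randn(1, old_grid * old_grid, dim)
    # (1, N, dim) -> (1, dim, old_grid, old_grid) for `interpolate`
    grid = patch_pos_embed.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    resized = nn.functional.interpolate(grid, size=(new_grid, new_grid), mode="bicubic", align_corners=False)
    out = resized.permute(0, 2, 3, 1).reshape(1, -1, dim)
    assert out.shape == (1, new_grid * new_grid, dim)

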
# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->CLIPSeg
class CLIPSegTextEmbeddings(nn.Module):
    def __init__(self, config: CLIPSegTextConfig):
        super().__init__()
        embed_dim = config.hidden_size

        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
        max_position_embedding = self.position_embedding.weight.shape[0]

        if seq_length > max_position_embedding:
            raise ValueError(
                f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
                f"{seq_length} and max_position_embeddings: {max_position_embedding})"
            )

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        embeddings = inputs_embeds + position_embeddings

        return embeddings


# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights


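# A minimal sketch checking that `eager_attention_forward` above matches
# PyTorch's built-in scaled dot-product attention (requires torch >= 2.0) when
# there is no mask and no dropout. Shapes (batch, heads, seq, head_dim) are
# illustrative.
def _demo_eager_attention(batch: int = 2, heads: int = 4, seq: int = 5, head_dim: int = 8):
    q = torch.randn(batch, heads, seq, head_dim)
    k = torch.randn(batch, heads, seq, head_dim)
    v = torch.randn(batch, heads, seq, head_dim)
    module = nn.Module()  # only used for its `training` flag; dropout is 0.0 here
    out, _ = eager_attention_forward(module, q, k, v, attention_mask=None, scaling=head_dim**-0.5)
    reference = nn.functional.scaled_dot_product_attention(q, k, v)
    # `eager_attention_forward` returns (batch, seq, heads, head_dim); transpose back to compare.
    assert torch.allclose(out.transpose(1, 2), reference, atol=1e-5)

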
class CLIPSegAttention(nn.Module):
    """Multi-headed attention from the 'Attention Is All You Need' paper"""

    def __init__(self, config: Union[CLIPSegVisionConfig, CLIPSegTextConfig]):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""

        batch_size, seq_length, embed_dim = hidden_states.shape

        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)
        values = self.v_proj(hidden_states)

        queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        # CLIP text model uses both `causal_attention_mask` and `attention_mask`
        # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
        if self.config._attn_implementation != "flash_attention_2":
            if attention_mask is not None and causal_attention_mask is not None:
                attention_mask = attention_mask + causal_attention_mask
            elif causal_attention_mask is not None:
                attention_mask = causal_attention_mask
        else:
            self.is_causal = causal_attention_mask is not None

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            dropout=0.0 if not self.training else self.dropout,
        )

        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)
        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights


# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->CLIPSeg
class CLIPSegMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->CLIPSeg
class CLIPSegEncoderLayer(nn.Module):
    def __init__(self, config: CLIPSegConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = CLIPSegAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = CLIPSegMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


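# A minimal sketch of the pre-LayerNorm residual pattern used by
# `CLIPSegEncoderLayer` above: normalize, transform, then add the residual,
# applied once for the attention sub-block and once for the MLP sub-block.
# Purely illustrative; `sublayer` stands in for either sub-block.
def _demo_pre_norm_residual(x: torch.Tensor, norm: nn.LayerNorm, sublayer: nn.Module) -> torch.Tensor:
    return x + sublayer(norm(x))

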
class CLIPSegPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = CLIPSegConfig
    base_model_prefix = "clip"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_factor
        if isinstance(module, CLIPSegTextEmbeddings):
            module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
            module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
        elif isinstance(module, CLIPSegVisionEmbeddings):
            factor = self.config.initializer_factor
            nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
            nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
            nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
        elif isinstance(module, CLIPSegAttention):
            factor = self.config.initializer_factor
            in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            out_proj_std = (module.embed_dim**-0.5) * factor
            nn.init.normal_(module.q_proj.weight, std=in_proj_std)
            nn.init.normal_(module.k_proj.weight, std=in_proj_std)
            nn.init.normal_(module.v_proj.weight, std=in_proj_std)
            nn.init.normal_(module.out_proj.weight, std=out_proj_std)
        elif isinstance(module, CLIPSegMLP):
            factor = self.config.initializer_factor
            in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
            nn.init.normal_(module.fc1.weight, std=fc_std)
            nn.init.normal_(module.fc2.weight, std=in_proj_std)
        elif isinstance(module, CLIPSegModel):
            nn.init.normal_(
                module.text_projection.weight,
                std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
            )
            nn.init.normal_(
                module.visual_projection.weight,
                std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
            )

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


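# A minimal sketch of the depth-scaled initialization arithmetic used in
# `_init_weights` above: projection stds shrink with width and, for residual
# branches, with depth, following the CLIP recipe. The numbers below are
# illustrative defaults, not values read from any config.
def _demo_init_stds(embed_dim: int = 512, num_hidden_layers: int = 12, factor: float = 1.0):
    in_proj_std = (embed_dim**-0.5) * ((2 * num_hidden_layers) ** -0.5) * factor
    out_proj_std = (embed_dim**-0.5) * factor
    fc_std = (2 * embed_dim) ** -0.5 * factor
    return in_proj_std, out_proj_std, fc_std

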
CLIPSEG_START_DOCSTRING = r"""
|
| 490 |
+
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
| 491 |
+
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
| 492 |
+
behavior.
|
| 493 |
+
|
| 494 |
+
Parameters:
|
| 495 |
+
config ([`CLIPSegConfig`]): Model configuration class with all the parameters of the model.
|
| 496 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 497 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 498 |
+
"""
|
| 499 |
+
|
| 500 |
+
CLIPSEG_TEXT_INPUTS_DOCSTRING = r"""
|
| 501 |
+
Args:
|
| 502 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 503 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 504 |
+
it.
|
| 505 |
+
|
| 506 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 507 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 508 |
+
|
| 509 |
+
[What are input IDs?](../glossary#input-ids)
|
| 510 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 511 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 512 |
+
|
| 513 |
+
- 1 for tokens that are **not masked**,
|
| 514 |
+
- 0 for tokens that are **masked**.
|
| 515 |
+
|
| 516 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 517 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 518 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 519 |
+
config.max_position_embeddings - 1]`.
|
| 520 |
+
|
| 521 |
+
[What are position IDs?](../glossary#position-ids)
|
| 522 |
+
output_attentions (`bool`, *optional*):
|
| 523 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 524 |
+
tensors for more detail.
|
| 525 |
+
output_hidden_states (`bool`, *optional*):
|
| 526 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 527 |
+
more detail.
|
| 528 |
+
return_dict (`bool`, *optional*):
|
| 529 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 530 |
+
"""
|
| 531 |
+
|
| 532 |
+
CLIPSEG_VISION_INPUTS_DOCSTRING = r"""
|
| 533 |
+
Args:
|
| 534 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 535 |
+
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
| 536 |
+
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
| 537 |
+
output_attentions (`bool`, *optional*):
|
| 538 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 539 |
+
tensors for more detail.
|
| 540 |
+
output_hidden_states (`bool`, *optional*):
|
| 541 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 542 |
+
more detail.
|
| 543 |
+
interpolate_pos_encoding (`bool`, *optional*, defaults to `True`):
|
| 544 |
+
Whether to interpolate the pre-trained position encodings.
|
| 545 |
+
return_dict (`bool`, *optional*):
|
| 546 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 547 |
+
"""
|
| 548 |
+
|
| 549 |
+
CLIPSEG_INPUTS_DOCSTRING = r"""
|
| 550 |
+
Args:
|
| 551 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 552 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 553 |
+
it.
|
| 554 |
+
|
| 555 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 556 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 557 |
+
|
| 558 |
+
[What are input IDs?](../glossary#input-ids)
|
| 559 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 560 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 561 |
+
|
| 562 |
+
- 1 for tokens that are **not masked**,
|
| 563 |
+
- 0 for tokens that are **masked**.
|
| 564 |
+
|
| 565 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 566 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 567 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 568 |
+
config.max_position_embeddings - 1]`.
|
| 569 |
+
|
| 570 |
+
[What are position IDs?](../glossary#position-ids)
|
| 571 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 572 |
+
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
| 573 |
+
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
| 574 |
+
return_loss (`bool`, *optional*):
|
| 575 |
+
Whether or not to return the contrastive loss.
|
| 576 |
+
output_attentions (`bool`, *optional*):
|
| 577 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 578 |
+
tensors for more detail.
|
| 579 |
+
output_hidden_states (`bool`, *optional*):
|
| 580 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 581 |
+
more detail.
|
| 582 |
+
interpolate_pos_encoding (`bool`, *optional*, defaults to `True`):
|
| 583 |
+
Whether to interpolate the pre-trained position encodings.
|
| 584 |
+
return_dict (`bool`, *optional*):
|
| 585 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 586 |
+
"""
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->CLIPSeg
class CLIPSegEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self-attention layers. Each layer is a
    [`CLIPSegEncoderLayer`].

    Args:
        config: CLIPSegConfig
    """

    def __init__(self, config: CLIPSegConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([CLIPSegEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Causal mask for the text model. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )


class CLIPSegTextTransformer(nn.Module):
    def __init__(self, config: CLIPSegTextConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = CLIPSegTextEmbeddings(config)
        self.encoder = CLIPSegEncoder(config)
        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        # For `pooled_output` computation
        self.eos_token_id = config.eos_token_id

    @add_start_docstrings_to_model_forward(CLIPSEG_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPSegTextConfig)
    # Adapted from transformers.models.clip.modeling_clip.CLIPTextTransformer.forward with clip->clipseg, CLIP->CLIPSeg
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        # CLIPSeg's text model uses a causal mask; prepare it here.
        # https://github.com/openai/CLIPSeg/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clipseg/model.py#L324
        causal_attention_mask = _create_4d_causal_attention_mask(
            input_shape, hidden_states.dtype, device=hidden_states.device
        )
        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.final_layer_norm(last_hidden_state)

        if self.eos_token_id == 2:
            # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
            # A CLIPSeg model with such an `eos_token_id` in the config can't work correctly with extra new tokens added.
            # ------------------------------------------------------------
            # text_embeds.shape = [batch_size, sequence_length, transformer.width]
            # take features from the eot embedding (the eot_token is the highest number in each sequence)
            # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
            pooled_output = last_hidden_state[
                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
                input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
            ]
        else:
            # The config gets the updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
            pooled_output = last_hidden_state[
                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
                # We need to get the first position of the `eos_token_id` value (`pad_token_id` might equal `eos_token_id`)
                # Note: we assume each sequence (along the batch dim.) contains an `eos_token_id` (e.g. prepared by the tokenizer)
                (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
                .int()
                .argmax(dim=-1),
            ]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


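# A minimal sketch of the EOS pooling performed above: for each sequence, pick
# the hidden state at the first position whose token id equals `eos_token_id`.
# The token ids, padding, and tensor sizes below are made up for illustration.
def _demo_eos_pooling(eos_token_id: int = 49407):
    input_ids = torch.tensor([[1, 5, eos_token_id, 0], [7, eos_token_id, 0, 0]])
    last_hidden_state = torch.randn(2, 4, 8)
    # `.int().argmax(-1)` returns the index of the first 1, i.e. the first EOS match per row
    eos_positions = (input_ids == eos_token_id).int().argmax(dim=-1)
    pooled = last_hidden_state[torch.arange(2), eos_positions]
    assert pooled.shape == (2, 8)

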
class CLIPSegTextModel(CLIPSegPreTrainedModel):
    config_class = CLIPSegTextConfig

    _no_split_modules = ["CLIPSegTextEmbeddings", "CLIPSegEncoderLayer"]

    def __init__(self, config: CLIPSegTextConfig):
        super().__init__(config)
        self.text_model = CLIPSegTextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value):
        self.text_model.embeddings.token_embedding = value

    @add_start_docstrings_to_model_forward(CLIPSEG_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPSegTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, CLIPSegTextModel

        >>> tokenizer = AutoTokenizer.from_pretrained("CIDAS/clipseg-rd64-refined")
        >>> model = CLIPSegTextModel.from_pretrained("CIDAS/clipseg-rd64-refined")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```"""
        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


class CLIPSegVisionTransformer(nn.Module):
|
| 840 |
+
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPVisionTransformer.__init__ with AltCLIP->CLIPSeg
|
| 841 |
+
def __init__(self, config: CLIPSegVisionConfig):
|
| 842 |
+
super().__init__()
|
| 843 |
+
self.config = config
|
| 844 |
+
embed_dim = config.hidden_size
|
| 845 |
+
|
| 846 |
+
self.embeddings = CLIPSegVisionEmbeddings(config)
|
| 847 |
+
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
| 848 |
+
self.encoder = CLIPSegEncoder(config)
|
| 849 |
+
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
| 850 |
+
|
| 851 |
+
@add_start_docstrings_to_model_forward(CLIPSEG_VISION_INPUTS_DOCSTRING)
|
| 852 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPSegVisionConfig)
|
| 853 |
+
def forward(
|
| 854 |
+
self,
|
| 855 |
+
pixel_values: Optional[torch.FloatTensor],
|
| 856 |
+
output_attentions: Optional[bool] = None,
|
| 857 |
+
output_hidden_states: Optional[bool] = None,
|
| 858 |
+
return_dict: Optional[bool] = None,
|
| 859 |
+
interpolate_pos_encoding: Optional[bool] = True,
|
| 860 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
| 861 |
+
r"""
|
| 862 |
+
Returns:
|
| 863 |
+
|
| 864 |
+
"""
|
| 865 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 866 |
+
output_hidden_states = (
|
| 867 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 868 |
+
)
|
| 869 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 870 |
+
|
| 871 |
+
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
| 872 |
+
hidden_states = self.pre_layrnorm(hidden_states)
|
| 873 |
+
|
| 874 |
+
encoder_outputs = self.encoder(
|
| 875 |
+
inputs_embeds=hidden_states,
|
| 876 |
+
output_attentions=output_attentions,
|
| 877 |
+
output_hidden_states=output_hidden_states,
|
| 878 |
+
return_dict=return_dict,
|
| 879 |
+
)
|
| 880 |
+
|
| 881 |
+
last_hidden_state = encoder_outputs[0]
|
| 882 |
+
pooled_output = last_hidden_state[:, 0, :]
|
| 883 |
+
pooled_output = self.post_layernorm(pooled_output)
|
| 884 |
+
|
| 885 |
+
if not return_dict:
|
| 886 |
+
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
| 887 |
+
|
| 888 |
+
return BaseModelOutputWithPooling(
|
| 889 |
+
last_hidden_state=last_hidden_state,
|
| 890 |
+
pooler_output=pooled_output,
|
| 891 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 892 |
+
attentions=encoder_outputs.attentions,
|
| 893 |
+
)
|
| 894 |
+
|
| 895 |
+
|
| 896 |
+
class CLIPSegVisionModel(CLIPSegPreTrainedModel):
|
| 897 |
+
config_class = CLIPSegVisionConfig
|
| 898 |
+
main_input_name = "pixel_values"
|
| 899 |
+
|
| 900 |
+
def __init__(self, config: CLIPSegVisionConfig):
|
| 901 |
+
super().__init__(config)
|
| 902 |
+
self.vision_model = CLIPSegVisionTransformer(config)
|
| 903 |
+
# Initialize weights and apply final processing
|
| 904 |
+
self.post_init()
|
| 905 |
+
|
| 906 |
+
def get_input_embeddings(self) -> nn.Module:
|
| 907 |
+
return self.vision_model.embeddings.patch_embedding
|
| 908 |
+
|
| 909 |
+
@add_start_docstrings_to_model_forward(CLIPSEG_VISION_INPUTS_DOCSTRING)
|
| 910 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPSegVisionConfig)
|
| 911 |
+
def forward(
|
| 912 |
+
self,
|
| 913 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 914 |
+
output_attentions: Optional[bool] = None,
|
| 915 |
+
output_hidden_states: Optional[bool] = None,
|
| 916 |
+
interpolate_pos_encoding: Optional[bool] = True,
|
| 917 |
+
return_dict: Optional[bool] = None,
|
| 918 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
| 919 |
+
r"""
|
| 920 |
+
Returns:
|
| 921 |
+
|
| 922 |
+
Examples:
|
| 923 |
+
|
| 924 |
+
```python
|
| 925 |
+
>>> from PIL import Image
|
| 926 |
+
>>> import requests
|
| 927 |
+
>>> from transformers import AutoProcessor, CLIPSegVisionModel
|
| 928 |
+
|
| 929 |
+
>>> processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
|
| 930 |
+
>>> model = CLIPSegVisionModel.from_pretrained("CIDAS/clipseg-rd64-refined")
|
| 931 |
+
|
| 932 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 933 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 934 |
+
|
| 935 |
+
>>> inputs = processor(images=image, return_tensors="pt")
|
| 936 |
+
|
| 937 |
+
>>> outputs = model(**inputs)
|
| 938 |
+
>>> last_hidden_state = outputs.last_hidden_state
|
| 939 |
+
>>> pooled_output = outputs.pooler_output # pooled CLS states
|
| 940 |
+
```"""
|
| 941 |
+
return self.vision_model(
|
| 942 |
+
pixel_values=pixel_values,
|
| 943 |
+
output_attentions=output_attentions,
|
| 944 |
+
output_hidden_states=output_hidden_states,
|
| 945 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 946 |
+
return_dict=return_dict,
|
| 947 |
+
)
|
| 948 |
+
|
| 949 |
+
|
| 950 |
+
@add_start_docstrings(CLIPSEG_START_DOCSTRING)
|
| 951 |
+
class CLIPSegModel(CLIPSegPreTrainedModel):
|
| 952 |
+
config_class = CLIPSegConfig
|
| 953 |
+
|
| 954 |
+
def __init__(self, config: CLIPSegConfig):
|
| 955 |
+
super().__init__(config)
|
| 956 |
+
|
| 957 |
+
if not isinstance(config.text_config, CLIPSegTextConfig):
|
| 958 |
+
raise TypeError(
|
| 959 |
+
"config.text_config is expected to be of type CLIPSegTextConfig but is of type"
|
| 960 |
+
f" {type(config.text_config)}."
|
| 961 |
+
)
|
| 962 |
+
|
| 963 |
+
if not isinstance(config.vision_config, CLIPSegVisionConfig):
|
| 964 |
+
raise TypeError(
|
| 965 |
+
"config.vision_config is expected to be of type CLIPSegVisionConfig but is of type"
|
| 966 |
+
f" {type(config.vision_config)}."
|
| 967 |
+
)
|
| 968 |
+
|
| 969 |
+
text_config = config.text_config
|
| 970 |
+
vision_config = config.vision_config
|
| 971 |
+
|
| 972 |
+
self.projection_dim = config.projection_dim
|
| 973 |
+
self.text_embed_dim = text_config.hidden_size
|
| 974 |
+
self.vision_embed_dim = vision_config.hidden_size
|
| 975 |
+
|
| 976 |
+
self.text_model = CLIPSegTextTransformer(text_config)
|
| 977 |
+
self.vision_model = CLIPSegVisionTransformer(vision_config)
|
| 978 |
+
|
| 979 |
+
self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
|
| 980 |
+
self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
|
| 981 |
+
self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
|
| 982 |
+
|
| 983 |
+
# Initialize weights and apply final processing
|
| 984 |
+
self.post_init()
|
| 985 |
+
|
| 986 |
+
@add_start_docstrings_to_model_forward(CLIPSEG_TEXT_INPUTS_DOCSTRING)
|
| 987 |
+
def get_text_features(
|
| 988 |
+
self,
|
| 989 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 990 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 991 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 992 |
+
output_attentions: Optional[bool] = None,
|
| 993 |
+
output_hidden_states: Optional[bool] = None,
|
| 994 |
+
return_dict: Optional[bool] = None,
|
| 995 |
+
) -> torch.FloatTensor:
|
| 996 |
+
r"""
|
| 997 |
+
Returns:
|
| 998 |
+
text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
|
| 999 |
+
applying the projection layer to the pooled output of [`CLIPSegTextModel`].
|
| 1000 |
+
|
| 1001 |
+
Examples:
|
| 1002 |
+
|
| 1003 |
+
```python
|
| 1004 |
+
>>> from transformers import AutoTokenizer, CLIPSegModel
|
| 1005 |
+
|
| 1006 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("CIDAS/clipseg-rd64-refined")
|
| 1007 |
+
>>> model = CLIPSegModel.from_pretrained("CIDAS/clipseg-rd64-refined")
|
| 1008 |
+
|
| 1009 |
+
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
|
| 1010 |
+
>>> text_features = model.get_text_features(**inputs)
|
| 1011 |
+
```"""
|
| 1012 |
+
# Use CLIPSEG model's config for some fields (if specified) instead of those of vision & text components.
|
| 1013 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1014 |
+
output_hidden_states = (
|
| 1015 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1016 |
+
)
|
| 1017 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1018 |
+
|
| 1019 |
+
text_outputs = self.text_model(
|
| 1020 |
+
input_ids=input_ids,
|
| 1021 |
+
attention_mask=attention_mask,
|
| 1022 |
+
position_ids=position_ids,
|
| 1023 |
+
output_attentions=output_attentions,
|
| 1024 |
+
output_hidden_states=output_hidden_states,
|
| 1025 |
+
return_dict=return_dict,
|
| 1026 |
+
)
|
| 1027 |
+
|
| 1028 |
+
pooled_output = text_outputs[1]
|
| 1029 |
+
text_features = self.text_projection(pooled_output)
|
| 1030 |
+
|
| 1031 |
+
return text_features
|
| 1032 |
+
|
| 1033 |
+
@add_start_docstrings_to_model_forward(CLIPSEG_VISION_INPUTS_DOCSTRING)
|
| 1034 |
+
def get_image_features(
|
| 1035 |
+
self,
|
| 1036 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 1037 |
+
output_attentions: Optional[bool] = None,
|
| 1038 |
+
output_hidden_states: Optional[bool] = None,
|
| 1039 |
+
interpolate_pos_encoding: bool = True,
|
| 1040 |
+
return_dict: Optional[bool] = None,
|
| 1041 |
+
) -> torch.FloatTensor:
|
| 1042 |
+
r"""
|
| 1043 |
+
Returns:
|
| 1044 |
+
image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
|
| 1045 |
+
applying the projection layer to the pooled output of [`CLIPSegVisionModel`].
|
| 1046 |
+
|
| 1047 |
+
Examples:
|
| 1048 |
+
|
| 1049 |
+
```python
|
| 1050 |
+
>>> from PIL import Image
|
| 1051 |
+
>>> import requests
|
| 1052 |
+
>>> from transformers import AutoProcessor, CLIPSegModel
|
| 1053 |
+
|
| 1054 |
+
>>> processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
|
| 1055 |
+
>>> model = CLIPSegModel.from_pretrained("CIDAS/clipseg-rd64-refined")
|
| 1056 |
+
|
| 1057 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 1058 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 1059 |
+
|
| 1060 |
+
>>> inputs = processor(images=image, return_tensors="pt")
|
| 1061 |
+
|
| 1062 |
+
>>> image_features = model.get_image_features(**inputs)
|
| 1063 |
+
```"""
|
| 1064 |
+
# Use CLIPSEG model's config for some fields (if specified) instead of those of vision & text components.
|
| 1065 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1066 |
+
output_hidden_states = (
|
| 1067 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1068 |
+
)
|
| 1069 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1070 |
+
|
| 1071 |
+
vision_outputs = self.vision_model(
|
| 1072 |
+
pixel_values=pixel_values,
|
| 1073 |
+
output_attentions=output_attentions,
|
| 1074 |
+
output_hidden_states=output_hidden_states,
|
| 1075 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 1076 |
+
return_dict=return_dict,
|
| 1077 |
+
)
|
| 1078 |
+
|
| 1079 |
+
pooled_output = vision_outputs[1] # pooled_output
|
| 1080 |
+
image_features = self.visual_projection(pooled_output)
|
| 1081 |
+
|
| 1082 |
+
return image_features
|
| 1083 |
+
|
| 1084 |
+
@add_start_docstrings_to_model_forward(CLIPSEG_INPUTS_DOCSTRING)
|
| 1085 |
+
@replace_return_docstrings(output_type=CLIPSegOutput, config_class=CLIPSegConfig)
|
| 1086 |
+
def forward(
|
| 1087 |
+
self,
|
| 1088 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1089 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 1090 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1091 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1092 |
+
return_loss: Optional[bool] = None,
|
| 1093 |
+
output_attentions: Optional[bool] = None,
|
| 1094 |
+
output_hidden_states: Optional[bool] = None,
|
| 1095 |
+
interpolate_pos_encoding: bool = True,
|
| 1096 |
+
return_dict: Optional[bool] = None,
|
| 1097 |
+
) -> Union[Tuple, CLIPSegOutput]:
|
| 1098 |
+
r"""
|
| 1099 |
+
Returns:
|
| 1100 |
+
|
| 1101 |
+
Examples:
|
| 1102 |
+
|
| 1103 |
+
```python
|
| 1104 |
+
>>> from PIL import Image
|
| 1105 |
+
>>> import requests
|
| 1106 |
+
>>> from transformers import AutoProcessor, CLIPSegModel
|
| 1107 |
+
|
| 1108 |
+
>>> processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
|
| 1109 |
+
>>> model = CLIPSegModel.from_pretrained("CIDAS/clipseg-rd64-refined")
|
| 1110 |
+
|
| 1111 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 1112 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 1113 |
+
|
| 1114 |
+
>>> inputs = processor(
|
| 1115 |
+
... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
|
| 1116 |
+
... )
|
| 1117 |
+
|
| 1118 |
+
>>> outputs = model(**inputs)
|
| 1119 |
+
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
|
| 1120 |
+
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
|
| 1121 |
+
```"""
|
| 1122 |
+
# Use CLIPSEG model's config for some fields (if specified) instead of those of vision & text components.
|
| 1123 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1124 |
+
output_hidden_states = (
|
| 1125 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1126 |
+
)
|
| 1127 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1128 |
+
|
| 1129 |
+
vision_outputs = self.vision_model(
|
| 1130 |
+
pixel_values=pixel_values,
|
| 1131 |
+
output_attentions=output_attentions,
|
| 1132 |
+
output_hidden_states=output_hidden_states,
|
| 1133 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 1134 |
+
return_dict=return_dict,
|
| 1135 |
+
)
|
| 1136 |
+
|
| 1137 |
+
text_outputs = self.text_model(
|
| 1138 |
+
input_ids=input_ids,
|
| 1139 |
+
attention_mask=attention_mask,
|
| 1140 |
+
position_ids=position_ids,
|
| 1141 |
+
output_attentions=output_attentions,
|
| 1142 |
+
output_hidden_states=output_hidden_states,
|
| 1143 |
+
return_dict=return_dict,
|
| 1144 |
+
)
|
| 1145 |
+
|
| 1146 |
+
image_embeds = vision_outputs[1]
|
| 1147 |
+
image_embeds = self.visual_projection(image_embeds)
|
| 1148 |
+
|
| 1149 |
+
text_embeds = text_outputs[1]
|
| 1150 |
+
text_embeds = self.text_projection(text_embeds)
|
| 1151 |
+
|
| 1152 |
+
# normalized features
|
| 1153 |
+
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
|
| 1154 |
+
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
|
| 1155 |
+
|
| 1156 |
+
# cosine similarity as logits
|
| 1157 |
+
logit_scale = self.logit_scale.exp()
|
| 1158 |
+
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
|
| 1159 |
+
logits_per_image = logits_per_text.t()
|
| 1160 |
+
|
| 1161 |
+
loss = None
|
| 1162 |
+
if return_loss:
|
| 1163 |
+
loss = clipseg_loss(logits_per_text)
|
| 1164 |
+
|
| 1165 |
+
if not return_dict:
|
| 1166 |
+
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
|
| 1167 |
+
return ((loss,) + output) if loss is not None else output
|
| 1168 |
+
|
| 1169 |
+
return CLIPSegOutput(
|
| 1170 |
+
loss=loss,
|
| 1171 |
+
logits_per_image=logits_per_image,
|
| 1172 |
+
logits_per_text=logits_per_text,
|
| 1173 |
+
text_embeds=text_embeds,
|
| 1174 |
+
image_embeds=image_embeds,
|
| 1175 |
+
text_model_output=text_outputs,
|
| 1176 |
+
vision_model_output=vision_outputs,
|
| 1177 |
+
)
|
| 1178 |
+
|
| 1179 |
+
|
| 1180 |
+
class CLIPSegDecoderLayer(nn.Module):
|
| 1181 |
+
"""
|
| 1182 |
+
CLIPSeg decoder layer, which is identical to `CLIPSegEncoderLayer`, except that normalization is applied after
|
| 1183 |
+
self-attention/MLP, rather than before.
|
| 1184 |
+
"""
|
| 1185 |
+
|
| 1186 |
+
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer.__init__ with AltCLIP->CLIPSeg
|
| 1187 |
+
def __init__(self, config: CLIPSegConfig):
|
| 1188 |
+
super().__init__()
|
| 1189 |
+
self.embed_dim = config.hidden_size
|
| 1190 |
+
self.self_attn = CLIPSegAttention(config)
|
| 1191 |
+
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
| 1192 |
+
self.mlp = CLIPSegMLP(config)
|
| 1193 |
+
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
| 1194 |
+
|
| 1195 |
+
def forward(
|
| 1196 |
+
self,
|
| 1197 |
+
hidden_states: torch.Tensor,
|
| 1198 |
+
attention_mask: torch.Tensor,
|
| 1199 |
+
causal_attention_mask: torch.Tensor,
|
| 1200 |
+
output_attentions: Optional[bool] = False,
|
| 1201 |
+
) -> Tuple[torch.FloatTensor]:
|
| 1202 |
+
"""
|
| 1203 |
+
Args:
|
| 1204 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
| 1205 |
+
attention_mask (`torch.FloatTensor`): attention mask of size
|
| 1206 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
| 1207 |
+
causal_attention_mask (`torch.FloatTensor`): causal attention mask of size `(batch, 1, tgt_len, src_len)`.
|
| 1208 |
+
output_attentions (`bool`, *optional*):
|
| 1209 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 1210 |
+
returned tensors for more detail.
|
| 1211 |
+
"""
|
| 1212 |
+
residual = hidden_states
|
| 1213 |
+
|
| 1214 |
+
hidden_states, attn_weights = self.self_attn(
|
| 1215 |
+
hidden_states=hidden_states,
|
| 1216 |
+
attention_mask=attention_mask,
|
| 1217 |
+
causal_attention_mask=causal_attention_mask,
|
| 1218 |
+
output_attentions=output_attentions,
|
| 1219 |
+
)
|
| 1220 |
+
|
| 1221 |
+
hidden_states = residual + hidden_states
|
| 1222 |
+
hidden_states = self.layer_norm1(hidden_states)
|
| 1223 |
+
|
| 1224 |
+
residual = hidden_states
|
| 1225 |
+
hidden_states = self.mlp(hidden_states)
|
| 1226 |
+
hidden_states = residual + hidden_states
|
| 1227 |
+
hidden_states = self.layer_norm2(hidden_states)
|
| 1228 |
+
|
| 1229 |
+
outputs = (hidden_states,)
|
| 1230 |
+
|
| 1231 |
+
if output_attentions:
|
| 1232 |
+
outputs += (attn_weights,)
|
| 1233 |
+
|
| 1234 |
+
return outputs
|
| 1235 |
+
|
| 1236 |
+
|
| 1237 |
+
class CLIPSegDecoder(CLIPSegPreTrainedModel):
|
| 1238 |
+
def __init__(self, config: CLIPSegConfig):
|
| 1239 |
+
super().__init__(config)
|
| 1240 |
+
|
| 1241 |
+
self.conditional_layer = config.conditional_layer
|
| 1242 |
+
|
| 1243 |
+
self.film_mul = nn.Linear(config.projection_dim, config.reduce_dim)
|
| 1244 |
+
self.film_add = nn.Linear(config.projection_dim, config.reduce_dim)
|
| 1245 |
+
|
| 1246 |
+
if config.use_complex_transposed_convolution:
|
| 1247 |
+
transposed_kernels = (config.vision_config.patch_size // 4, config.vision_config.patch_size // 4)
|
| 1248 |
+
|
| 1249 |
+
self.transposed_convolution = nn.Sequential(
|
| 1250 |
+
nn.Conv2d(config.reduce_dim, config.reduce_dim, kernel_size=3, padding=1),
|
| 1251 |
+
nn.ReLU(),
|
| 1252 |
+
nn.ConvTranspose2d(
|
| 1253 |
+
config.reduce_dim,
|
| 1254 |
+
config.reduce_dim // 2,
|
| 1255 |
+
kernel_size=transposed_kernels[0],
|
| 1256 |
+
stride=transposed_kernels[0],
|
| 1257 |
+
),
|
| 1258 |
+
nn.ReLU(),
|
| 1259 |
+
nn.ConvTranspose2d(
|
| 1260 |
+
config.reduce_dim // 2, 1, kernel_size=transposed_kernels[1], stride=transposed_kernels[1]
|
| 1261 |
+
),
|
| 1262 |
+
)
|
| 1263 |
+
else:
|
| 1264 |
+
self.transposed_convolution = nn.ConvTranspose2d(
|
| 1265 |
+
config.reduce_dim, 1, config.vision_config.patch_size, stride=config.vision_config.patch_size
|
| 1266 |
+
)
|
| 1267 |
+
|
| 1268 |
+
depth = len(config.extract_layers)
|
| 1269 |
+
self.reduces = nn.ModuleList(
|
| 1270 |
+
[nn.Linear(config.vision_config.hidden_size, config.reduce_dim) for _ in range(depth)]
|
| 1271 |
+
)
|
| 1272 |
+
|
| 1273 |
+
decoder_config = copy.deepcopy(config.vision_config)
|
| 1274 |
+
decoder_config.hidden_size = config.reduce_dim
|
| 1275 |
+
decoder_config.num_attention_heads = config.decoder_num_attention_heads
|
| 1276 |
+
decoder_config.intermediate_size = config.decoder_intermediate_size
|
| 1277 |
+
decoder_config.hidden_act = "relu"
|
| 1278 |
+
self.layers = nn.ModuleList([CLIPSegDecoderLayer(decoder_config) for _ in range(len(config.extract_layers))])
|
| 1279 |
+
|
| 1280 |
+
def forward(
|
| 1281 |
+
self,
|
| 1282 |
+
hidden_states: Tuple[torch.Tensor],
|
| 1283 |
+
conditional_embeddings: torch.Tensor,
|
| 1284 |
+
output_attentions: Optional[bool] = None,
|
| 1285 |
+
output_hidden_states: Optional[bool] = None,
|
| 1286 |
+
return_dict: Optional[bool] = True,
|
| 1287 |
+
):
|
| 1288 |
+
all_hidden_states = () if output_hidden_states else None
|
| 1289 |
+
all_attentions = () if output_attentions else None
|
| 1290 |
+
|
| 1291 |
+
activations = hidden_states[::-1]
|
| 1292 |
+
|
| 1293 |
+
output = None
|
| 1294 |
+
for i, (activation, layer, reduce) in enumerate(zip(activations, self.layers, self.reduces)):
|
| 1295 |
+
if output is not None:
|
| 1296 |
+
output = reduce(activation) + output
|
| 1297 |
+
else:
|
| 1298 |
+
output = reduce(activation)
|
| 1299 |
+
|
| 1300 |
+
if i == self.conditional_layer:
|
| 1301 |
+
output = self.film_mul(conditional_embeddings) * output.permute(1, 0, 2) + self.film_add(
|
| 1302 |
+
conditional_embeddings
|
| 1303 |
+
)
|
| 1304 |
+
output = output.permute(1, 0, 2)
|
| 1305 |
+
|
| 1306 |
+
layer_outputs = layer(
|
| 1307 |
+
output, attention_mask=None, causal_attention_mask=None, output_attentions=output_attentions
|
| 1308 |
+
)
|
| 1309 |
+
|
| 1310 |
+
output = layer_outputs[0]
|
| 1311 |
+
|
| 1312 |
+
if output_hidden_states:
|
| 1313 |
+
all_hidden_states += (output,)
|
| 1314 |
+
|
| 1315 |
+
if output_attentions:
|
| 1316 |
+
all_attentions += (layer_outputs[1],)
|
| 1317 |
+
|
| 1318 |
+
output = output[:, 1:, :].permute(0, 2, 1) # remove cls token and reshape to [batch_size, reduce_dim, seq_len]
|
| 1319 |
+
|
| 1320 |
+
size = int(math.sqrt(output.shape[2]))
|
| 1321 |
+
|
| 1322 |
+
batch_size = conditional_embeddings.shape[0]
|
| 1323 |
+
output = output.view(batch_size, output.shape[1], size, size)
|
| 1324 |
+
|
| 1325 |
+
logits = self.transposed_convolution(output).squeeze(1)
|
| 1326 |
+
|
| 1327 |
+
if not return_dict:
|
| 1328 |
+
return tuple(v for v in [logits, all_hidden_states, all_attentions] if v is not None)
|
| 1329 |
+
|
| 1330 |
+
return CLIPSegDecoderOutput(
|
| 1331 |
+
logits=logits,
|
| 1332 |
+
hidden_states=all_hidden_states,
|
| 1333 |
+
attentions=all_attentions,
|
| 1334 |
+
)
|
| 1335 |
+
|
| 1336 |
+
|
| 1337 |
+
@add_start_docstrings(
|
| 1338 |
+
"""
|
| 1339 |
+
CLIPSeg model with a Transformer-based decoder on top for zero-shot and one-shot image segmentation.
|
| 1340 |
+
""",
|
| 1341 |
+
CLIPSEG_START_DOCSTRING,
|
| 1342 |
+
)
|
| 1343 |
+
class CLIPSegForImageSegmentation(CLIPSegPreTrainedModel):
|
| 1344 |
+
config_class = CLIPSegConfig
|
| 1345 |
+
|
| 1346 |
+
def __init__(self, config: CLIPSegConfig):
|
| 1347 |
+
super().__init__(config)
|
| 1348 |
+
|
| 1349 |
+
self.config = config
|
| 1350 |
+
|
| 1351 |
+
self.clip = CLIPSegModel(config)
|
| 1352 |
+
self.extract_layers = config.extract_layers
|
| 1353 |
+
|
| 1354 |
+
self.decoder = CLIPSegDecoder(config)
|
| 1355 |
+
|
| 1356 |
+
# Initialize weights and apply final processing
|
| 1357 |
+
self.post_init()
|
| 1358 |
+
|
| 1359 |
+
def get_conditional_embeddings(
|
| 1360 |
+
self,
|
| 1361 |
+
batch_size: Optional[int] = None,
|
| 1362 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 1363 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1364 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 1365 |
+
conditional_pixel_values: Optional[torch.Tensor] = None,
|
| 1366 |
+
):
|
| 1367 |
+
if input_ids is not None:
|
| 1368 |
+
# compute conditional embeddings from texts
|
| 1369 |
+
if len(input_ids) != batch_size:
|
| 1370 |
+
raise ValueError("Make sure to pass as many prompt texts as there are query images")
|
| 1371 |
+
with torch.no_grad():
|
| 1372 |
+
conditional_embeddings = self.clip.get_text_features(
|
| 1373 |
+
input_ids, attention_mask=attention_mask, position_ids=position_ids
|
| 1374 |
+
)
|
| 1375 |
+
elif conditional_pixel_values is not None:
|
| 1376 |
+
# compute conditional embeddings from images
|
| 1377 |
+
if len(conditional_pixel_values) != batch_size:
|
| 1378 |
+
raise ValueError("Make sure to pass as many prompt images as there are query images")
|
| 1379 |
+
with torch.no_grad():
|
| 1380 |
+
conditional_embeddings = self.clip.get_image_features(conditional_pixel_values)
|
| 1381 |
+
else:
|
| 1382 |
+
raise ValueError(
|
| 1383 |
+
"Invalid conditional, should be either provided as `input_ids` or `conditional_pixel_values`"
|
| 1384 |
+
)
|
| 1385 |
+
|
| 1386 |
+
return conditional_embeddings
|
| 1387 |
+
|
| 1388 |
+
@add_start_docstrings_to_model_forward(CLIPSEG_INPUTS_DOCSTRING)
|
| 1389 |
+
@replace_return_docstrings(output_type=CLIPSegImageSegmentationOutput, config_class=CLIPSegTextConfig)
|
| 1390 |
+
def forward(
|
| 1391 |
+
self,
|
| 1392 |
+
input_ids: Optional[torch.FloatTensor] = None,
|
| 1393 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 1394 |
+
conditional_pixel_values: Optional[torch.FloatTensor] = None,
|
| 1395 |
+
conditional_embeddings: Optional[torch.FloatTensor] = None,
|
| 1396 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1397 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1398 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1399 |
+
output_attentions: Optional[bool] = None,
|
| 1400 |
+
output_hidden_states: Optional[bool] = None,
|
| 1401 |
+
interpolate_pos_encoding: bool = True,
|
| 1402 |
+
return_dict: Optional[bool] = None,
|
| 1403 |
+
) -> Union[Tuple, CLIPSegOutput]:
|
| 1404 |
+
r"""
|
| 1405 |
+
labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
|
| 1406 |
+
Ground truth binary segmentation maps used to compute the loss. When provided, a binary
|
| 1407 |
+
cross-entropy loss (`nn.BCEWithLogitsLoss`) is computed between the predicted `logits` and
|
| 1408 |
+
these target maps.
|
| 1409 |
+
|
| 1410 |
+
Returns:
|
| 1411 |
+
|
| 1412 |
+
Examples:
|
| 1413 |
+
|
| 1414 |
+
```python
|
| 1415 |
+
>>> from transformers import AutoProcessor, CLIPSegForImageSegmentation
|
| 1416 |
+
>>> from PIL import Image
|
| 1417 |
+
>>> import requests
|
| 1418 |
+
|
| 1419 |
+
>>> processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
|
| 1420 |
+
>>> model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
|
| 1421 |
+
|
| 1422 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 1423 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 1424 |
+
>>> texts = ["a cat", "a remote", "a blanket"]
|
| 1425 |
+
>>> inputs = processor(text=texts, images=[image] * len(texts), padding=True, return_tensors="pt")
|
| 1426 |
+
|
| 1427 |
+
>>> outputs = model(**inputs)
|
| 1428 |
+
|
| 1429 |
+
>>> logits = outputs.logits
|
| 1430 |
+
>>> print(logits.shape)
|
| 1431 |
+
torch.Size([3, 352, 352])
|
| 1432 |
+
```"""
|
| 1433 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1434 |
+
|
| 1435 |
+
# step 1: forward the query images through the frozen CLIP vision encoder
|
| 1436 |
+
with torch.no_grad():
|
| 1437 |
+
vision_outputs = self.clip.vision_model(
|
| 1438 |
+
pixel_values=pixel_values,
|
| 1439 |
+
output_attentions=output_attentions,
|
| 1440 |
+
output_hidden_states=True, # we need the intermediate hidden states
|
| 1441 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 1442 |
+
return_dict=return_dict,
|
| 1443 |
+
)
|
| 1444 |
+
pooled_output = self.clip.visual_projection(vision_outputs[1])
|
| 1445 |
+
|
| 1446 |
+
hidden_states = vision_outputs.hidden_states if return_dict else vision_outputs[2]
|
| 1447 |
+
# we add +1 here as the hidden states also include the initial embeddings
|
| 1448 |
+
activations = [hidden_states[i + 1] for i in self.extract_layers]
|
| 1449 |
+
|
| 1450 |
+
# update vision_outputs
|
| 1451 |
+
if return_dict:
|
| 1452 |
+
vision_outputs = BaseModelOutputWithPooling(
|
| 1453 |
+
last_hidden_state=vision_outputs.last_hidden_state,
|
| 1454 |
+
pooler_output=vision_outputs.pooler_output,
|
| 1455 |
+
hidden_states=vision_outputs.hidden_states if output_hidden_states else None,
|
| 1456 |
+
attentions=vision_outputs.attentions,
|
| 1457 |
+
)
|
| 1458 |
+
else:
|
| 1459 |
+
vision_outputs = (
|
| 1460 |
+
vision_outputs[:2] + vision_outputs[3:] if not output_hidden_states else vision_outputs
|
| 1461 |
+
)
|
| 1462 |
+
|
| 1463 |
+
# step 2: compute conditional embeddings, either from text, from images, or from an embedding provided by the caller
|
| 1464 |
+
if conditional_embeddings is None:
|
| 1465 |
+
conditional_embeddings = self.get_conditional_embeddings(
|
| 1466 |
+
batch_size=pixel_values.shape[0],
|
| 1467 |
+
input_ids=input_ids,
|
| 1468 |
+
attention_mask=attention_mask,
|
| 1469 |
+
position_ids=position_ids,
|
| 1470 |
+
conditional_pixel_values=conditional_pixel_values,
|
| 1471 |
+
)
|
| 1472 |
+
else:
|
| 1473 |
+
if conditional_embeddings.shape[0] != pixel_values.shape[0]:
|
| 1474 |
+
raise ValueError(
|
| 1475 |
+
"Make sure to pass as many conditional embeddings as there are query images in the batch"
|
| 1476 |
+
)
|
| 1477 |
+
if conditional_embeddings.shape[1] != self.config.projection_dim:
|
| 1478 |
+
raise ValueError(
|
| 1479 |
+
"Make sure that the feature dimension of the conditional embeddings matches"
|
| 1480 |
+
" `config.projection_dim`."
|
| 1481 |
+
)
|
| 1482 |
+
|
| 1483 |
+
# step 3: forward both the pooled output and the activations through the lightweight decoder to predict masks
|
| 1484 |
+
decoder_outputs = self.decoder(
|
| 1485 |
+
activations,
|
| 1486 |
+
conditional_embeddings,
|
| 1487 |
+
output_attentions=output_attentions,
|
| 1488 |
+
output_hidden_states=output_hidden_states,
|
| 1489 |
+
return_dict=return_dict,
|
| 1490 |
+
)
|
| 1491 |
+
logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
|
| 1492 |
+
|
| 1493 |
+
loss = None
|
| 1494 |
+
if labels is not None:
|
| 1495 |
+
# move labels to the correct device to enable PP
|
| 1496 |
+
labels = labels.to(logits.device)
|
| 1497 |
+
loss_fn = nn.BCEWithLogitsLoss()
|
| 1498 |
+
loss = loss_fn(logits, labels)
|
| 1499 |
+
|
| 1500 |
+
if not return_dict:
|
| 1501 |
+
output = (logits, conditional_embeddings, pooled_output, vision_outputs, decoder_outputs)
|
| 1502 |
+
return ((loss,) + output) if loss is not None else output
|
| 1503 |
+
|
| 1504 |
+
return CLIPSegImageSegmentationOutput(
|
| 1505 |
+
loss=loss,
|
| 1506 |
+
logits=logits,
|
| 1507 |
+
conditional_embeddings=conditional_embeddings,
|
| 1508 |
+
pooled_output=pooled_output,
|
| 1509 |
+
vision_model_output=vision_outputs,
|
| 1510 |
+
decoder_output=decoder_outputs,
|
| 1511 |
+
)
|
| 1512 |
+
|
| 1513 |
+
|
| 1514 |
+
__all__ = [
|
| 1515 |
+
"CLIPSegModel",
|
| 1516 |
+
"CLIPSegPreTrainedModel",
|
| 1517 |
+
"CLIPSegTextModel",
|
| 1518 |
+
"CLIPSegVisionModel",
|
| 1519 |
+
"CLIPSegForImageSegmentation",
|
| 1520 |
+
]
|
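The segmentation head above accepts either text prompts (`input_ids`) or image prompts (`conditional_pixel_values`) via `get_conditional_embeddings`. The modeling docstrings only demonstrate the text path, so here is a minimal, hypothetical sketch of the one-shot (visual prompt) path; the checkpoint name matches the docstrings above, while the crop coordinates and the sigmoid post-processing are my own illustrative choices, not part of the diff.

```python
# Minimal sketch (not from the diff above): one-shot segmentation with a visual prompt.
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, CLIPSegForImageSegmentation

processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
query = Image.open(requests.get(url, stream=True).raw)
prompt = query.crop((0, 0, 300, 300))  # hypothetical visual prompt containing the target object

# With both `images` and `visual_prompt`, the processor returns `pixel_values`
# and `conditional_pixel_values` (see CLIPSegProcessor.__call__ below).
inputs = processor(images=query, visual_prompt=prompt, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# logits: (batch_size, 352, 352) for this checkpoint; sigmoid gives a per-pixel probability map
mask = torch.sigmoid(outputs.logits)
```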
docs/transformers/build/lib/transformers/models/clipseg/processing_clipseg.py
ADDED
|
@@ -0,0 +1,164 @@
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""
|
| 16 |
+
Image/Text processor class for CLIPSeg
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import warnings
|
| 20 |
+
|
| 21 |
+
from ...processing_utils import ProcessorMixin
|
| 22 |
+
from ...tokenization_utils_base import BatchEncoding
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class CLIPSegProcessor(ProcessorMixin):
|
| 26 |
+
r"""
|
| 27 |
+
Constructs a CLIPSeg processor which wraps a CLIPSeg image processor and a CLIP tokenizer into a single processor.
|
| 28 |
+
|
| 29 |
+
[`CLIPSegProcessor`] offers all the functionalities of [`ViTImageProcessor`] and [`CLIPTokenizerFast`]. See the
|
| 30 |
+
[`~CLIPSegProcessor.__call__`] and [`~CLIPSegProcessor.decode`] for more information.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
image_processor ([`ViTImageProcessor`], *optional*):
|
| 34 |
+
The image processor is a required input.
|
| 35 |
+
tokenizer ([`CLIPTokenizerFast`], *optional*):
|
| 36 |
+
The tokenizer is a required input.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
attributes = ["image_processor", "tokenizer"]
|
| 40 |
+
image_processor_class = ("ViTImageProcessor", "ViTImageProcessorFast")
|
| 41 |
+
tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast")
|
| 42 |
+
|
| 43 |
+
def __init__(self, image_processor=None, tokenizer=None, **kwargs):
|
| 44 |
+
feature_extractor = None
|
| 45 |
+
if "feature_extractor" in kwargs:
|
| 46 |
+
warnings.warn(
|
| 47 |
+
"The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
|
| 48 |
+
" instead.",
|
| 49 |
+
FutureWarning,
|
| 50 |
+
)
|
| 51 |
+
feature_extractor = kwargs.pop("feature_extractor")
|
| 52 |
+
|
| 53 |
+
image_processor = image_processor if image_processor is not None else feature_extractor
|
| 54 |
+
if image_processor is None:
|
| 55 |
+
raise ValueError("You need to specify an `image_processor`.")
|
| 56 |
+
if tokenizer is None:
|
| 57 |
+
raise ValueError("You need to specify a `tokenizer`.")
|
| 58 |
+
|
| 59 |
+
super().__init__(image_processor, tokenizer)
|
| 60 |
+
|
| 61 |
+
def __call__(self, text=None, images=None, visual_prompt=None, return_tensors=None, **kwargs):
|
| 62 |
+
"""
|
| 63 |
+
Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
|
| 64 |
+
and `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode
|
| 65 |
+
the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
|
| 66 |
+
ViTImageProcessor's [`~ViTImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring of
|
| 67 |
+
the above two methods for more information.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
text (`str`, `List[str]`, `List[List[str]]`):
|
| 71 |
+
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
| 72 |
+
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
| 73 |
+
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
| 74 |
+
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
| 75 |
+
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
| 76 |
+
tensor. Both channels-first and channels-last formats are supported.
|
| 77 |
+
visual_prompt (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
| 78 |
+
The visual prompt image or batch of images to be prepared. Each visual prompt image can be a PIL image,
|
| 79 |
+
NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape
|
| 80 |
+
(C, H, W), where C is a number of channels, H and W are image height and width.
|
| 81 |
+
|
| 82 |
+
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
| 83 |
+
If set, will return tensors of a particular framework. Acceptable values are:
|
| 84 |
+
|
| 85 |
+
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
| 86 |
+
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
| 87 |
+
- `'np'`: Return NumPy `np.ndarray` objects.
|
| 88 |
+
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
| 89 |
+
|
| 90 |
+
Returns:
|
| 91 |
+
[`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
|
| 92 |
+
|
| 93 |
+
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
| 94 |
+
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
| 95 |
+
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
| 96 |
+
`None`).
|
| 97 |
+
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
| 98 |
+
"""
|
| 99 |
+
if text is None and visual_prompt is None and images is None:
|
| 100 |
+
raise ValueError("You have to specify either text, visual prompt or images.")
|
| 101 |
+
|
| 102 |
+
if text is not None and visual_prompt is not None:
|
| 103 |
+
raise ValueError("You have to specify exactly one type of prompt. Either text or visual prompt.")
|
| 104 |
+
|
| 105 |
+
if text is not None:
|
| 106 |
+
encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)
|
| 107 |
+
|
| 108 |
+
if visual_prompt is not None:
|
| 109 |
+
prompt_features = self.image_processor(visual_prompt, return_tensors=return_tensors, **kwargs)
|
| 110 |
+
|
| 111 |
+
if images is not None:
|
| 112 |
+
image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs)
|
| 113 |
+
|
| 114 |
+
if visual_prompt is not None and images is not None:
|
| 115 |
+
encoding = {
|
| 116 |
+
"pixel_values": image_features.pixel_values,
|
| 117 |
+
"conditional_pixel_values": prompt_features.pixel_values,
|
| 118 |
+
}
|
| 119 |
+
return encoding
|
| 120 |
+
elif text is not None and images is not None:
|
| 121 |
+
encoding["pixel_values"] = image_features.pixel_values
|
| 122 |
+
return encoding
|
| 123 |
+
elif text is not None:
|
| 124 |
+
return encoding
|
| 125 |
+
elif visual_prompt is not None:
|
| 126 |
+
encoding = {
|
| 127 |
+
"conditional_pixel_values": prompt_features.pixel_values,
|
| 128 |
+
}
|
| 129 |
+
return encoding
|
| 130 |
+
else:
|
| 131 |
+
return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
|
| 132 |
+
|
| 133 |
+
def batch_decode(self, *args, **kwargs):
|
| 134 |
+
"""
|
| 135 |
+
This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
| 136 |
+
refer to the docstring of this method for more information.
|
| 137 |
+
"""
|
| 138 |
+
return self.tokenizer.batch_decode(*args, **kwargs)
|
| 139 |
+
|
| 140 |
+
def decode(self, *args, **kwargs):
|
| 141 |
+
"""
|
| 142 |
+
This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
| 143 |
+
the docstring of this method for more information.
|
| 144 |
+
"""
|
| 145 |
+
return self.tokenizer.decode(*args, **kwargs)
|
| 146 |
+
|
| 147 |
+
@property
|
| 148 |
+
def feature_extractor_class(self):
|
| 149 |
+
warnings.warn(
|
| 150 |
+
"`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.",
|
| 151 |
+
FutureWarning,
|
| 152 |
+
)
|
| 153 |
+
return self.image_processor_class
|
| 154 |
+
|
| 155 |
+
@property
|
| 156 |
+
def feature_extractor(self):
|
| 157 |
+
warnings.warn(
|
| 158 |
+
"`feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.",
|
| 159 |
+
FutureWarning,
|
| 160 |
+
)
|
| 161 |
+
return self.image_processor
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
__all__ = ["CLIPSegProcessor"]
|
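Because `__call__` above branches on the prompt type, the returned key set differs between text prompts and visual prompts. The following sketch is not part of the diff; it simply illustrates the two cases, using a blank placeholder image.

```python
# Minimal sketch (assumptions: CIDAS/clipseg-rd64-refined checkpoint, placeholder image).
from PIL import Image
from transformers import CLIPSegProcessor

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
image = Image.new("RGB", (352, 352))  # stand-in for a real photo

# Text prompts: the tokenizer output is extended with `pixel_values`.
text_inputs = processor(text=["a cat", "a dog"], images=[image] * 2, padding=True, return_tensors="pt")
print(sorted(text_inputs.keys()))  # ['attention_mask', 'input_ids', 'pixel_values']

# Visual prompt: a plain dict with `pixel_values` and `conditional_pixel_values` is returned instead.
visual_inputs = processor(images=image, visual_prompt=image, return_tensors="pt")
print(sorted(visual_inputs.keys()))  # ['conditional_pixel_values', 'pixel_values']
```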
docs/transformers/build/lib/transformers/models/clvp/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_clvp import *
|
| 22 |
+
from .feature_extraction_clvp import *
|
| 23 |
+
from .modeling_clvp import *
|
| 24 |
+
from .processing_clvp import *
|
| 25 |
+
from .tokenization_clvp import *
|
| 26 |
+
else:
|
| 27 |
+
import sys
|
| 28 |
+
|
| 29 |
+
_file = globals()["__file__"]
|
| 30 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
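The `_LazyModule` assignment above defers all submodule imports until an attribute is first accessed. A small sketch of the effect follows; it assumes `ClvpEncoderConfig` is exported by `configuration_clvp` (as the `import *` under `TYPE_CHECKING` suggests) and is illustrative only.

```python
# Minimal sketch: attribute access on the lazily wrapped package triggers the real import.
import sys
import transformers.models.clvp as clvp

print(type(sys.modules["transformers.models.clvp"]))  # a _LazyModule proxy installed by __init__.py
config_cls = clvp.ClvpEncoderConfig  # first attribute access imports configuration_clvp lazily
print(config_cls.model_type)  # "clvp_encoder"
```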
docs/transformers/build/lib/transformers/models/clvp/configuration_clvp.py
ADDED
|
@@ -0,0 +1,443 @@
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""CLVP model configuration"""
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
from typing import TYPE_CHECKING, Union
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
if TYPE_CHECKING:
|
| 22 |
+
pass
|
| 23 |
+
|
| 24 |
+
from ...configuration_utils import PretrainedConfig
|
| 25 |
+
from ...utils import logging
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
logger = logging.get_logger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class ClvpEncoderConfig(PretrainedConfig):
|
| 32 |
+
r"""
|
| 33 |
+
This is the configuration class to store the configuration of a [`ClvpEncoder`]. It is used to instantiate a CLVP
|
| 34 |
+
text or CLVP speech encoder according to the specified arguments. Instantiating a configuration with the defaults
|
| 35 |
+
will yield a similar configuration to that of the encoder of the CLVP
|
| 36 |
+
[susnato/clvp_dev](https://huggingface.co/susnato/clvp_dev) architecture.
|
| 37 |
+
|
| 38 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 39 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
vocab_size (`int`, *optional*, defaults to 256):
|
| 43 |
+
Vocabulary size of the CLVP Encoder model.
|
| 44 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
| 45 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 46 |
+
intermediate_size (`int`, *optional*, defaults to 1536):
|
| 47 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
| 48 |
+
projection_dim (`int`, *optional*, defaults to 768):
|
| 49 |
+
Dimensionality of the projection vector.
|
| 50 |
+
num_hidden_layers (`int`, *optional*, defaults to 20):
|
| 51 |
+
Number of hidden layers in the Transformer encoder.
|
| 52 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
| 53 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 54 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
| 55 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 56 |
+
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
|
| 57 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
|
| 58 |
+
The epsilon used by the layer normalization layers.
|
| 59 |
+
attention_dropout (`float`, *optional*, defaults to 0.1):
|
| 60 |
+
The dropout ratio for the attention probabilities.
|
| 61 |
+
dropout (`float`, *optional*, defaults to 0.1):
|
| 62 |
+
The dropout ratio for the feed-forward layers in [`ClvpEncoderMLP`].
|
| 63 |
+
use_rotary_embedding (`bool`, *optional*, defaults to `True`):
|
| 64 |
+
Whether to use rotary_embedding or not.
|
| 65 |
+
use_attention_bias (`bool`, *optional*, defaults to `False`):
|
| 66 |
+
Whether to use bias in Query, Key and Value layers during self attention.
|
| 67 |
+
summary_type (`str`, *optional*, defaults to `"mean"`):
|
| 68 |
+
What strategy to use to get pooler_output from the last_hidden_state. `"last"`, `"first"`, `"mean"` and
|
| 69 |
+
`"cls_index"` are supported.
|
| 70 |
+
initializer_factor (`float`, *optional*, defaults to 1.0):
|
| 71 |
+
A factor for initializing all weight matrices (should be kept to 1.0, used internally for initialization
|
| 72 |
+
testing).
|
| 73 |
+
bos_token_id (`int`, *optional*, defaults to 255):
|
| 74 |
+
Beginning of sequence token id.
|
| 75 |
+
eos_token_id (`int`, *optional*, defaults to 0):
|
| 76 |
+
End of sequence token id.
|
| 77 |
+
|
| 78 |
+
Example:
|
| 79 |
+
|
| 80 |
+
```python
|
| 81 |
+
>>> from transformers import ClvpEncoderConfig, ClvpEncoder
|
| 82 |
+
|
| 83 |
+
>>> # Initializing a ClvpEncoderConfig with susnato/clvp_dev style configuration
|
| 84 |
+
>>> encoder_configuration = ClvpEncoderConfig()
|
| 85 |
+
|
| 86 |
+
>>> # Initializing a ClvpEncoder (with random weights) from the susnato/clvp_dev style configuration
|
| 87 |
+
>>> model = ClvpEncoder(encoder_configuration)
|
| 88 |
+
|
| 89 |
+
>>> # Accessing the model configuration
|
| 90 |
+
>>> configuration = model.config
|
| 91 |
+
```"""
|
| 92 |
+
|
| 93 |
+
model_type = "clvp_encoder"
|
| 94 |
+
base_config_key = ["text_config", "speech_config"]
|
| 95 |
+
|
| 96 |
+
def __init__(
|
| 97 |
+
self,
|
| 98 |
+
vocab_size=256,
|
| 99 |
+
hidden_size=768,
|
| 100 |
+
intermediate_size=1536,
|
| 101 |
+
projection_dim=768,
|
| 102 |
+
num_hidden_layers=20,
|
| 103 |
+
num_attention_heads=12,
|
| 104 |
+
hidden_act="gelu",
|
| 105 |
+
layer_norm_eps=1e-5,
|
| 106 |
+
attention_dropout=0.1,
|
| 107 |
+
dropout=0.1,
|
| 108 |
+
use_rotary_embedding=True,
|
| 109 |
+
use_attention_bias=False,
|
| 110 |
+
summary_type="mean",
|
| 111 |
+
initializer_factor=1.0,
|
| 112 |
+
bos_token_id=255,
|
| 113 |
+
eos_token_id=0,
|
| 114 |
+
**kwargs,
|
| 115 |
+
):
|
| 116 |
+
self.vocab_size = vocab_size
|
| 117 |
+
self.hidden_size = hidden_size
|
| 118 |
+
self.intermediate_size = intermediate_size
|
| 119 |
+
self.projection_dim = projection_dim
|
| 120 |
+
self.num_hidden_layers = num_hidden_layers
|
| 121 |
+
self.num_attention_heads = num_attention_heads
|
| 122 |
+
self.layer_norm_eps = layer_norm_eps
|
| 123 |
+
self.hidden_act = hidden_act
|
| 124 |
+
self.initializer_factor = initializer_factor
|
| 125 |
+
self.attention_dropout = attention_dropout
|
| 126 |
+
self.dropout = dropout
|
| 127 |
+
self.use_rotary_embedding = use_rotary_embedding
|
| 128 |
+
self.use_attention_bias = use_attention_bias
|
| 129 |
+
self.summary_type = summary_type
|
| 130 |
+
self.bos_token_id = bos_token_id
|
| 131 |
+
self.eos_token_id = eos_token_id
|
| 132 |
+
|
| 133 |
+
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
| 134 |
+
|
| 135 |
+
@classmethod
|
| 136 |
+
def from_pretrained(
|
| 137 |
+
cls, pretrained_model_name_or_path: Union[str, os.PathLike], config_type: str = "text_config", **kwargs
|
| 138 |
+
) -> "PretrainedConfig":
|
| 139 |
+
cls._set_token_in_kwargs(kwargs)
|
| 140 |
+
|
| 141 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
| 142 |
+
|
| 143 |
+
# make sure to have the config_type be either "text_config" or "speech_config"
|
| 144 |
+
# this is to make sure that we can load only text or speech configs from the nested ClvpConfig.
|
| 145 |
+
if config_type not in cls.base_config_key:
|
| 146 |
+
raise ValueError(
|
| 147 |
+
f"We can only load either 'text_config' or 'speech_config' but you are trying to load{config_type}"
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
# get the text config dict if we are loading from ClvpConfig
|
| 151 |
+
if config_dict.get("model_type") == "clvp":
|
| 152 |
+
config_dict = config_dict[config_type]
|
| 153 |
+
|
| 154 |
+
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
| 155 |
+
logger.warning(
|
| 156 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
| 157 |
+
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
return cls.from_dict(config_dict, **kwargs)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class ClvpDecoderConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ClvpDecoder`]. It is used to instantiate a CLVP
    Decoder Model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the Decoder part of the CLVP
    [susnato/clvp_dev](https://huggingface.co/susnato/clvp_dev) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    The architecture is similar to GPT2.

    Args:
        vocab_size (`int`, *optional*, defaults to 8194):
            Vocabulary size of the model.
        max_position_embeddings (`int`, *optional*, defaults to 608):
            The maximum sequence length of mel tokens that this model might ever be used with. Similar to `n_positions`
            in `GPT2Config`.
        max_text_tokens (`int`, *optional*, defaults to 404):
            The maximum sequence length of text tokens that this model might ever be used with. Similar to
            `n_positions` in `GPT2Config`.
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimensionality of the embeddings and hidden states.
        num_hidden_layers (`int`, *optional*, defaults to 30):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        n_inner (`int`, *optional*):
            Dimensionality of the inner feed-forward layers. `None` will set it to 4 times `hidden_size`.
        num_mel_attn_blocks (`int`, *optional*, defaults to 6):
            Denotes the number of self attention layers in [`ClvpConditioningEncoder`].
        activation_function (`str`, *optional*, defaults to `"gelu_new"`):
            Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
        resid_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        embd_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the embeddings.
        attention_dropout (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
            The epsilon to use in the layer normalization layers.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        summary_type (`string`, *optional*, defaults to `"cls_index"`):
            Argument used when doing sequence summary.

            Has to be one of the following options:

            - `"last"`: Take the last token hidden state (like XLNet).
            - `"first"`: Take the first token hidden state (like BERT).
            - `"mean"`: Take the mean of all tokens hidden states.
            - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2).
            - `"attn"`: Not implemented now, use multi-head attention.
        summary_use_proj (`bool`, *optional*, defaults to `True`):
            Whether or not to add a projection after the vector extraction.
        summary_activation (`str`, *optional*):
            Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation.
        summary_proj_to_labels (`bool`, *optional*, defaults to `True`):
            Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes.
        summary_first_dropout (`float`, *optional*, defaults to 0.1):
            The dropout ratio to be used after the projection and activation.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        bos_token_id (`int`, *optional*, defaults to 8192):
            Beginning of sequence token id, used at the start of the generation.
        eos_token_id (`int`, *optional*, defaults to 8193):
            End of sequence token id, used in the method
            [`ClvpModelForConditionalGeneration.fix_speech_decoder_output()`] to correct decoder outputs.
        feature_size (`int`, *optional*, defaults to 80):
            The feature dimension of the extracted mel features. This value is used in [`ClvpConditioningEncoder`].
        use_attention_bias (`bool`, *optional*, defaults to `True`):
            Whether to use bias in Query, Key and Value layers during self attention.
        initializer_factor (`float`, *optional*, defaults to 1.0):
            A factor for initializing all weight matrices (should be kept to 1.0, used internally for initialization
            testing).
        decoder_fixing_codes (`list`, *optional*, defaults to `[83, 45, 45, 248]`):
            These values are used in the method `fix_speech_decoder_output` to fix decoder generated outputs.

    Example:

    ```python
    >>> from transformers import ClvpDecoderConfig, ClvpDecoder

    >>> # Initializing a ClvpDecoderConfig with susnato/clvp_dev style configuration
    >>> decoder_configuration = ClvpDecoderConfig()

    >>> # Initializing a ClvpDecoder (with random weights) from the susnato/clvp_dev style configuration
    >>> model = ClvpDecoder(decoder_configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "clvp_decoder"
    base_config_key = "decoder_config"

    def __init__(
        self,
        vocab_size=8194,
        max_position_embeddings=608,
        max_text_tokens=404,
        hidden_size=1024,
        num_hidden_layers=30,
        num_attention_heads=16,
        n_inner=None,
        num_mel_attn_blocks=6,
        activation_function="gelu_new",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attention_dropout=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        summary_type="cls_index",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=0.1,
        use_cache=True,
        bos_token_id=8192,
        eos_token_id=8193,
        feature_size=80,
        use_attention_bias=True,
        initializer_factor=1.0,
        decoder_fixing_codes=[83, 45, 45, 248],
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.max_text_tokens = max_text_tokens
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.n_inner = n_inner
        self.num_mel_attn_blocks = num_mel_attn_blocks
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attention_dropout = attention_dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.summary_type = summary_type
        self.summary_use_proj = summary_use_proj
        self.summary_activation = summary_activation
        self.summary_first_dropout = summary_first_dropout
        self.summary_proj_to_labels = summary_proj_to_labels
        self.use_cache = use_cache
        self.feature_size = feature_size
        self.use_attention_bias = use_attention_bias
        self.initializer_factor = initializer_factor
        self.decoder_fixing_codes = decoder_fixing_codes

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

class ClvpConfig(PretrainedConfig):
    r"""
    [`ClvpConfig`] is the configuration class to store the configuration of a [`ClvpModelForConditionalGeneration`]. It
    is used to instantiate a CLVP model according to the specified arguments, defining the text model, speech model and
    decoder model configs. Instantiating a configuration with the defaults will yield a similar configuration to that
    of the CLVP [susnato/clvp_dev](https://huggingface.co/susnato/clvp_dev) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        text_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize the CLVP text encoder.
        speech_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize the CLVP speech encoder.
        decoder_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`ClvpDecoderConfig`].
        projection_dim (`int`, *optional*, defaults to 768):
            Dimensionality of text and speech projection layers.
        logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
            The initial value of the *logit_scale* parameter. Default is used as per the original CLVP implementation.
        initializer_factor (`float`, *optional*, defaults to 1.0):
            A factor for initializing all weight matrices (should be kept to 1.0, used internally for initialization
            testing).
        kwargs (*optional*):
            Dictionary of keyword arguments.

    Example:

    ```python
    >>> from transformers import ClvpConfig, ClvpModelForConditionalGeneration

    >>> # Initializing a ClvpConfig with susnato/clvp_dev style configuration
    >>> configuration = ClvpConfig()

    >>> # Initializing a ClvpModelForConditionalGeneration (with random weights) from the susnato/clvp_dev style configuration
    >>> model = ClvpModelForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config

    >>> # We can also initialize a ClvpConfig from two ClvpEncoderConfigs (text and speech) and a ClvpDecoderConfig
    >>> from transformers import ClvpEncoderConfig, ClvpDecoderConfig

    >>> # Initializing a CLVP text, CLVP speech and CLVP decoder configuration
    >>> config_text = ClvpEncoderConfig()
    >>> config_speech = ClvpEncoderConfig()
    >>> decoder_config = ClvpDecoderConfig()

    >>> config = ClvpConfig.from_sub_model_configs(config_text, config_speech, decoder_config)
    ```"""

    model_type = "clvp"
    sub_configs = {
        "text_config": ClvpEncoderConfig,
        "speech_config": ClvpEncoderConfig,
        "decoder_config": ClvpDecoderConfig,
    }

    def __init__(
        self,
        text_config=None,
        speech_config=None,
        decoder_config=None,
        projection_dim=768,
        logit_scale_init_value=2.6592,
        initializer_factor=1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if text_config is None:
            text_config = {}
            logger.info("`text_config` is `None`. Initializing the `ClvpEncoderConfig` with default values.")

        if speech_config is None:
            speech_config = {}
            logger.info("`speech_config` is `None`. Initializing the `ClvpEncoderConfig` with default values.")

        if decoder_config is None:
            decoder_config = {}
            logger.info("`decoder_config` is `None`. Initializing the `ClvpDecoderConfig` with default values.")

        self.text_config = ClvpEncoderConfig(**text_config)
        self.speech_config = ClvpEncoderConfig(**speech_config)
        self.decoder_config = ClvpDecoderConfig(**decoder_config)

        self.projection_dim = projection_dim
        self.logit_scale_init_value = logit_scale_init_value
        self.initializer_factor = initializer_factor

    @classmethod
    def from_sub_model_configs(
        cls,
        text_config: ClvpEncoderConfig,
        speech_config: ClvpEncoderConfig,
        decoder_config: ClvpDecoderConfig,
        **kwargs,
    ):
        r"""
        Instantiate a [`ClvpConfig`] (or a derived class) from CLVP text model configuration, CLVP speech model
        configuration and CLVP decoder model configuration.

        Args:
            text_config (`ClvpEncoderConfig`):
                Text model configuration of type [`ClvpEncoderConfig`].
            speech_config (`ClvpEncoderConfig`):
                Speech model configuration of type [`ClvpEncoderConfig`].
            decoder_config (`ClvpDecoderConfig`):
                Decoder model configuration of type [`ClvpDecoderConfig`].

        Returns:
            [`ClvpConfig`]: An instance of a configuration object
        """

        return cls(
            text_config=text_config.to_dict(),
            speech_config=speech_config.to_dict(),
            decoder_config=decoder_config.to_dict(),
            **kwargs,
        )


__all__ = ["ClvpConfig", "ClvpDecoderConfig", "ClvpEncoderConfig"]
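
For readers skimming this diff, here is a minimal sketch of how the three configuration classes above fit together. It relies only on the classes defined in this file; the `susnato/clvp_dev` repo id is the one referenced in the docstrings above, and loading it requires network access, so that call is left commented out.

```python
from transformers import ClvpConfig, ClvpDecoderConfig, ClvpEncoderConfig

# Compose a full CLVP config from its three sub-configs
# (text encoder, speech encoder, autoregressive decoder).
text_cfg = ClvpEncoderConfig()
speech_cfg = ClvpEncoderConfig()
decoder_cfg = ClvpDecoderConfig()
config = ClvpConfig.from_sub_model_configs(text_cfg, speech_cfg, decoder_cfg)

# The sub-configs round-trip through `to_dict()`, so the nested configs are
# re-instantiated copies with the same values.
assert config.decoder_config.vocab_size == decoder_cfg.vocab_size  # 8194 by default

# `ClvpEncoderConfig.from_pretrained` takes a `config_type` selector
# ("text_config" or "speech_config") to pull one sub-config out of a nested
# ClvpConfig checkpoint (commented out: requires network access):
# speech_cfg = ClvpEncoderConfig.from_pretrained("susnato/clvp_dev", config_type="speech_config")
```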
docs/transformers/build/lib/transformers/models/clvp/convert_clvp_to_hf.py
ADDED
@@ -0,0 +1,234 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Weights conversion script for CLVP
"""

import argparse
import os

import torch
from huggingface_hub import hf_hub_download

from transformers import ClvpConfig, ClvpModelForConditionalGeneration


_MODELS = {
    "clvp": "https://huggingface.co/jbetker/tortoise-tts-v2/blob/main/.models/clvp2.pth",
    "decoder": "https://huggingface.co/jbetker/tortoise-tts-v2/blob/main/.models/autoregressive.pth",
}

dim = 1024
sub_dim = dim // 16

CLVP_ENCODERS_MAPPING = {
    "text_transformer.transformer.attn_layers": "text_encoder_model",
    "speech_transformer.transformer.attn_layers": "speech_encoder_model",
    "text_transformer.transformer.norm": "text_encoder_model.final_layer_norm",
    "speech_transformer.transformer.norm": "speech_encoder_model.final_layer_norm",
    "to_text_latent": "text_encoder_model.projection",
    "to_speech_latent": "speech_encoder_model.projection",
    "text_emb": "text_encoder_model.token_embedding",
    "speech_emb": "speech_encoder_model.token_embedding",
    "1.wrap.net.0": "mlp.fc1",
    "1.wrap.net.3": "mlp.fc2",
    "1.wrap": "self_attn",
    "to_out": "out_proj",
    "to_q": "q_proj",
    "to_k": "k_proj",
    "to_v": "v_proj",
    "temperature": "logit_scale",
}

CLVP_DECODER_MAPPING = {
    "conditioning_encoder.init": "conditioning_encoder.mel_conv",
    "conditioning_encoder.attn": "conditioning_encoder.mel_attn_blocks",
    "mel_attn_blocks": "group_norms",
    ".norm.weight": ".weight",
    ".norm.bias": ".bias",
    "text_embedding": "conditioning_encoder.text_token_embedding",
    "text_pos_embedding.emb": "conditioning_encoder.text_position_embedding",
    "final_norm": "speech_decoder_model.final_norm",
    "mel_head": "speech_decoder_model.lm_head",
    "gpt.ln_f": "speech_decoder_model.model.decoder.layer_norm",
    "mel_embedding": "speech_decoder_model.model.decoder.input_embeds_layer",
    "mel_pos_embedding.emb": "speech_decoder_model.model.decoder.position_embeds_layer",
    "gpt.h": "speech_decoder_model.model.decoder.layers",
    "ln_1": "input_layernorm",
    "ln_2": "post_attention_layernorm",
}


def update_index(present_index):
    if present_index % 2 == 0:
        return int(present_index / 2)
    else:
        return int((present_index - 1) / 2)


def convert_encoder_weights(original_weights):
    converted_weights = {}
    original_weights_keys = sorted(original_weights.keys())
    for original_key in original_weights_keys:
        updated_key = original_key
        # for input_rmsnorm.weight and post_attention_rmsnorm.weight
        if "0.0.g" in updated_key:
            present_index = updated_key.split(".")[4]
            if int(present_index) % 2 == 0:
                updated_key = updated_key.replace("0.0.g", "input_rmsnorm.weight")
            else:
                updated_key = updated_key.replace("0.0.g", "post_attention_rmsnorm.weight")

        if "transformer.attn_layers.layers" in updated_key:
            present_index = updated_key.split(".")[4]
            updated_index = update_index(int(present_index))
            updated_key = updated_key.replace(
                f"transformer.attn_layers.layers.{present_index}", f"transformer.attn_layers.layers.{updated_index}"
            )

        for k, v in CLVP_ENCODERS_MAPPING.items():
            if k in updated_key:
                updated_key = updated_key.replace(k, v)

        converted_weights[updated_key] = original_weights.pop(original_key)

    return converted_weights


def convert_decoder_weights(original_weights):
    converted_weights = {}
    original_weights_keys = sorted(original_weights.keys())
    for original_key in original_weights_keys:
        updated_key = original_key
        if len(updated_key.split(".")) > 3:
            index, attr = updated_key.split(".")[2], updated_key.split(".")[-1]

            # for decoder attention
            if "attn.c_attn" in updated_key:
                if attr == "weight":
                    slice1, slice2, slice3 = original_weights[updated_key].squeeze(-1).T.split(split_size=dim, dim=0)
                else:
                    slice1, slice2, slice3 = original_weights[updated_key].split(split_size=dim, dim=0)
                converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.q_proj.{attr}"] = slice1
                converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.k_proj.{attr}"] = slice2
                converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.v_proj.{attr}"] = slice3
                continue

            if "attn.c_proj" in updated_key:
                converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.out_proj.{attr}"] = (
                    original_weights[updated_key].squeeze(-1).T
                )
                continue

            if "attn.bias" in updated_key or "attn.masked_bias" in updated_key or "text_head" in updated_key:
                original_weights.pop(updated_key)
                continue

            # conditional encoder attention
            if "qkv" in updated_key:
                if attr == "weight":
                    slice1, slice2, slice3 = original_weights[updated_key].squeeze(-1).split(split_size=dim, dim=0)
                else:
                    slice1, slice2, slice3 = original_weights[updated_key].split(split_size=dim, dim=0)

                indices = torch.arange(dim)
                index1, index2, index3 = (
                    indices.unfold(0, sub_dim, sub_dim * 3).flatten(),
                    indices[sub_dim:].unfold(0, sub_dim, sub_dim * 3).flatten(),
                    indices[2 * sub_dim :].unfold(0, sub_dim, sub_dim * 3).flatten(),
                )

                converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.q_proj.{attr}"] = torch.concatenate(
                    [slice1[index1], slice2[index3], slice3[index2]],
                    axis=0,
                )
                converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.k_proj.{attr}"] = torch.concatenate(
                    [slice1[index2], slice2[index1], slice3[index3]],
                    axis=0,
                )
                converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.v_proj.{attr}"] = torch.concatenate(
                    [slice1[index3], slice2[index2], slice3[index1]],
                    axis=0,
                )
                continue

            if "proj_out" in updated_key:
                converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.out_proj.{attr}"] = original_weights[
                    updated_key
                ].squeeze(-1)
                continue

        for k, v in CLVP_DECODER_MAPPING.items():
            if k in updated_key:
                updated_key = updated_key.replace(k, v)

        converted_weights[updated_key] = original_weights.pop(original_key)

    return converted_weights


def _download(url: str, root: str):
    repo_id = f"{url.split('/')[3]}/{url.split('/')[4]}"
    filename = f"{url.split('/')[-2]}/{url.split('/')[-1]}"
    hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        force_filename=root,
        local_dir_use_symlinks=False,
    )


def convert_clvp_weights(checkpoint_path, pytorch_dump_folder_path):
    converted_checkpoint = {}

    for each_model_name, each_model_url in _MODELS.items():
        each_model_path = os.path.join(checkpoint_path, each_model_url.split("/")[-1])
        if not os.path.exists(each_model_path):
            print(f"\n{each_model_name} was not found! Downloading it to {each_model_path}")
            _download(url=each_model_url, root=each_model_path)

        if each_model_name == "clvp":
            clvp_checkpoint = torch.load(each_model_path, map_location="cpu", weights_only=True)
        else:
            decoder_checkpoint = torch.load(each_model_path, map_location="cpu", weights_only=True)

    # Converting the weights
    converted_checkpoint.update(**convert_encoder_weights(clvp_checkpoint))
    converted_checkpoint.update(**convert_decoder_weights(decoder_checkpoint))

    config = ClvpConfig.from_pretrained("susnato/clvp_dev")
    model = ClvpModelForConditionalGeneration(config)

    model.load_state_dict(converted_checkpoint, strict=True)
    model.save_pretrained(pytorch_dump_folder_path)
    print(f"Model saved at {pytorch_dump_folder_path}!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--checkpoint_path", type=str, help="Path to the folder of downloaded checkpoints. (Please enter full path)"
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=str,
        help="Path to the output PyTorch model. (Please enter full path)",
    )
    args = parser.parse_args()

    convert_clvp_weights(args.checkpoint_path, args.pytorch_dump_folder_path)
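
For orientation, a hypothetical way to drive the converter above. The directory names here are placeholders, not real paths, and the run assumes the two Tortoise-TTS checkpoints listed in `_MODELS` are reachable (the script downloads any that are missing) and that `convert_clvp_weights` is in scope, i.e. this runs inside the script's own module.

```python
# Hypothetical driver for the conversion script above (placeholder paths).
import os

checkpoint_dir = "./tortoise_checkpoints"  # clvp2.pth / autoregressive.pth live (or land) here
output_dir = "./clvp_hf"                   # converted Hugging Face model is written here
os.makedirs(checkpoint_dir, exist_ok=True)

# Equivalent CLI call:
#   python convert_clvp_to_hf.py --checkpoint_path ./tortoise_checkpoints --pytorch_dump_folder_path ./clvp_hf
convert_clvp_weights(checkpoint_dir, output_dir)
```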
docs/transformers/build/lib/transformers/models/clvp/feature_extraction_clvp.py
ADDED
@@ -0,0 +1,241 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Feature extractor class for CLVP
"""

from typing import List, Optional, Union

import numpy as np

from ...audio_utils import mel_filter_bank, spectrogram, window_function
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import TensorType, logging


logger = logging.get_logger(__name__)


class ClvpFeatureExtractor(SequenceFeatureExtractor):
    r"""
    Constructs a CLVP feature extractor.

    This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
    most of the main methods. Users should refer to this superclass for more information regarding those methods.

    This class extracts log-mel-spectrogram features from raw speech using a custom numpy implementation of the `Short
    Time Fourier Transform` which should match pytorch's `torch.stft` equivalent.

    Args:
        feature_size (`int`, *optional*, defaults to 80):
            The feature dimension of the extracted features.
        sampling_rate (`int`, *optional*, defaults to 22050):
            The sampling rate at which the audio files should be digitized, expressed in hertz (Hz).
        default_audio_length (`int`, *optional*, defaults to 6):
            The default length of raw audio in seconds. If `max_length` is not set during `__call__` then it will
            automatically be set to default_audio_length * `self.sampling_rate`.
        hop_length (`int`, *optional*, defaults to 256):
            Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients.
        chunk_length (`int`, *optional*, defaults to 30):
            The maximum number of chunks of `sampling_rate` samples used to trim and pad longer or shorter audio
            sequences.
        n_fft (`int`, *optional*, defaults to 1024):
            Size of the Fourier transform.
        padding_value (`float`, *optional*, defaults to 0.0):
            Padding value used to pad the audio. Should correspond to silences.
        mel_norms (`list` of length `feature_size`, *optional*):
            If `mel_norms` is provided then it will be used to normalize the log-mel spectrograms along each
            mel-filter.
        return_attention_mask (`bool`, *optional*, defaults to `False`):
            Whether to return the attention mask. If left to the default, it will return the attention mask.

            [What are attention masks?](../glossary#attention-mask)
    """

    model_input_names = ["input_features", "attention_mask"]

    def __init__(
        self,
        feature_size=80,
        sampling_rate=22050,
        default_audio_length=6,
        hop_length=256,
        chunk_length=30,
        n_fft=1024,
        padding_value=0.0,
        mel_norms=None,
        return_attention_mask=False,  # pad inputs to max length with silence token (zero) and no attention mask
        **kwargs,
    ):
        super().__init__(
            feature_size=feature_size,
            sampling_rate=sampling_rate,
            padding_value=padding_value,
            return_attention_mask=return_attention_mask,
            **kwargs,
        )
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.chunk_length = chunk_length
        self.n_samples = chunk_length * sampling_rate
        self.nb_max_frames = self.n_samples // hop_length
        self.sampling_rate = sampling_rate
        self.default_audio_length = default_audio_length
        self.mel_norms = mel_norms
        self.mel_filters = mel_filter_bank(
            num_frequency_bins=1 + (n_fft // 2),
            num_mel_filters=feature_size,
            min_frequency=0.0,
            max_frequency=8000.0,
            sampling_rate=sampling_rate,
            norm="slaney",
            mel_scale="htk",
        )

    def _np_extract_fbank_features(self, waveform: np.ndarray) -> np.ndarray:
        """
        This method first computes the log-mel spectrogram of the provided audio and then applies normalization along
        each mel-filter bank, if `mel_norms` is provided.
        """
        log_spec = spectrogram(
            waveform,
            window_function(self.n_fft, "hann"),
            frame_length=self.n_fft,
            hop_length=self.hop_length,
            power=2.0,
            mel_filters=self.mel_filters,
            log_mel=None,
        )

        log_spec = np.log(np.clip(log_spec, a_min=1e-5, a_max=None))

        if self.mel_norms is not None:
            log_spec = log_spec / np.array(self.mel_norms)[:, None]

        return log_spec

    def __call__(
        self,
        raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
        sampling_rate: Optional[int] = None,
        truncation: bool = True,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_attention_mask: Optional[bool] = True,
        padding: Optional[str] = "max_length",
        max_length: Optional[int] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        `ClvpFeatureExtractor` is used to extract various voice-specific properties such as the pitch and tone of the
        voice, speaking speed, and even speaking defects like a lisp or stuttering from a sample voice or `raw_speech`.

        First the voice is padded or truncated so that it becomes a waveform of `self.default_audio_length` seconds,
        and then the log-mel spectrogram is extracted from it.

        Args:
            raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
                The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
                values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
                stereo, i.e. single float per timestep.
            sampling_rate (`int`, *optional*):
                The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
                `sampling_rate` at the forward call to prevent silent errors and to allow the automatic speech
                recognition pipeline to work.
            truncation (`bool`, *optional*, defaults to `True`):
                Activates truncation to cut input sequences longer than *max_length* to *max_length*.
            pad_to_multiple_of (`int`, *optional*):
                If set will pad the sequence to a multiple of the provided value.

                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
                `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
            return_attention_mask (`bool`, *optional*, defaults to `True`):
                Whether to return the attention mask. If left to the default, it will return the attention mask.

                [What are attention masks?](../glossary#attention-mask)
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors instead of list of python integers. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return Numpy `np.ndarray` objects.
            padding_value (`float`, *optional*, defaults to 0.0):
                The value that is used to fill the padding values / vectors.
            max_length (`int`, *optional*):
                The maximum input length of the inputs.
        """

        if sampling_rate is not None:
            if sampling_rate != self.sampling_rate:
                raise ValueError(
                    f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
                    f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
                    f" was sampled with {self.sampling_rate} and not {sampling_rate}."
                )
        else:
            logger.warning(
                f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
                "Failing to do so can result in silent errors that might be hard to debug."
            )

        is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
        if is_batched_numpy and len(raw_speech.shape) > 2:
            raise ValueError(f"Only mono-channel audio is supported for input to {self}")
        is_batched = is_batched_numpy or (
            isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
        )

        if is_batched:
            raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech]
        elif not is_batched and not isinstance(raw_speech, np.ndarray):
            raw_speech = np.asarray(raw_speech, dtype=np.float32)
        elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
            raw_speech = raw_speech.astype(np.float32)

        # always return batch
        if not is_batched:
            raw_speech = [np.asarray([raw_speech]).T]

        batched_speech = BatchFeature({"input_features": raw_speech})

        max_length = self.default_audio_length * self.sampling_rate if max_length is None else max_length

        padded_inputs = self.pad(
            batched_speech,
            padding=padding,
            max_length=max_length,
            truncation=truncation,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
        )

        # make sure list is in array format
        input_features = padded_inputs.get("input_features").transpose(2, 0, 1)

        input_features = [
            self._np_extract_fbank_features(waveform).astype(np.float32) for waveform in input_features[0]
        ]

        if isinstance(input_features[0], List):
            padded_inputs["input_features"] = [np.asarray(feature) for feature in input_features]
        else:
            padded_inputs["input_features"] = input_features

        return padded_inputs.convert_to_tensors(return_tensors)


__all__ = ["ClvpFeatureExtractor"]
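
A short usage sketch for the extractor above. The waveform is synthetic noise rather than real speech, and the frame count follows from the defaults in this file (6 s at 22,050 Hz, hop length 256, so roughly 517 frames); treat the shapes as approximate documentation under those assumptions rather than a guarantee.

```python
import numpy as np

# Assumes the ClvpFeatureExtractor defined above is importable from transformers.
from transformers import ClvpFeatureExtractor

extractor = ClvpFeatureExtractor()  # 80 mel bins, 22050 Hz, 6 s default length

# One second of fake mono audio; real usage would load speech sampled at 22050 Hz.
waveform = np.random.randn(22050).astype(np.float32)

features = extractor(waveform, sampling_rate=22050, return_tensors="np")
# The audio is padded/truncated to 6 s, then turned into a log-mel spectrogram:
# (batch, feature_size, frames) -- roughly (1, 80, 517) with these defaults.
print(features["input_features"].shape)
```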
docs/transformers/build/lib/transformers/models/clvp/modeling_clvp.py
ADDED
@@ -0,0 +1,2131 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""PyTorch CLVP model."""
|
| 17 |
+
|
| 18 |
+
import copy
|
| 19 |
+
import math
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
from typing import Callable, Dict, Optional, Tuple, Union
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import torch.utils.checkpoint
|
| 25 |
+
from torch import nn
|
| 26 |
+
from torch.nn import CrossEntropyLoss
|
| 27 |
+
|
| 28 |
+
from ...activations import ACT2FN, get_activation
|
| 29 |
+
from ...generation import GenerationConfig, GenerationMixin
|
| 30 |
+
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
|
| 31 |
+
from ...modeling_outputs import (
|
| 32 |
+
BaseModelOutput,
|
| 33 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
| 34 |
+
BaseModelOutputWithPooling,
|
| 35 |
+
CausalLMOutputWithCrossAttentions,
|
| 36 |
+
)
|
| 37 |
+
from ...modeling_utils import PreTrainedModel
|
| 38 |
+
from ...pytorch_utils import Conv1D, isin_mps_friendly
|
| 39 |
+
from ...utils import (
|
| 40 |
+
ModelOutput,
|
| 41 |
+
add_start_docstrings,
|
| 42 |
+
add_start_docstrings_to_model_forward,
|
| 43 |
+
logging,
|
| 44 |
+
replace_return_docstrings,
|
| 45 |
+
)
|
| 46 |
+
from .configuration_clvp import (
|
| 47 |
+
ClvpConfig,
|
| 48 |
+
ClvpDecoderConfig,
|
| 49 |
+
ClvpEncoderConfig,
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
logger = logging.get_logger(__name__)
|
| 54 |
+
|
| 55 |
+
_CHECKPOINT_FOR_DOC = "susnato/clvp_dev"
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# Copied from transformers.models.clip.modeling_clip.contrastive_loss
|
| 59 |
+
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
|
| 60 |
+
return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->clvp, image_loss->speech_loss
|
| 64 |
+
def clvp_loss(similarity: torch.Tensor) -> torch.Tensor:
|
| 65 |
+
caption_loss = contrastive_loss(similarity)
|
| 66 |
+
speech_loss = contrastive_loss(similarity.t())
|
| 67 |
+
return (caption_loss + speech_loss) / 2.0
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
| 71 |
+
def rotate_half(x):
|
| 72 |
+
"""Rotates half the hidden dims of the input."""
|
| 73 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 74 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 75 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def apply_rotary_pos_emb(q, k, v, cos, sin, position_ids, unsqueeze_dim=1):
|
| 79 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
q (`torch.Tensor`): The query tensor.
|
| 83 |
+
k (`torch.Tensor`): The key tensor.
|
| 84 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 85 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 86 |
+
position_ids (`torch.Tensor`):
|
| 87 |
+
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
| 88 |
+
used to pass offsetted position ids when working with a KV-cache.
|
| 89 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 90 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 91 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 92 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 93 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 94 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 95 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 96 |
+
Returns:
|
| 97 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 98 |
+
"""
|
| 99 |
+
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
| 100 |
+
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
| 101 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 102 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 103 |
+
v_embed = (v * cos) + (rotate_half(v) * sin)
|
| 104 |
+
return q_embed, k_embed, v_embed


def _pad_extra_bos_eos_tokens(
    input_ids,
    attention_mask=None,
    pad_token_id=0,
    bos_token_id=255,
    eos_token_id=0,
    add_bos_token=True,
    add_eos_token=True,
):
    """
    This method adds extra bos and eos tokens to `input_ids` and modifies the `attention_mask` accordingly. It is
    used by `ClvpConditioningEncoder` and by the generation loop of `ClvpModelForConditionalGeneration`.
    """

    # add the bos token at the beginning
    if add_bos_token:
        input_ids = torch.nn.functional.pad(input_ids, (1, 0), value=bos_token_id)
        attention_mask = (
            torch.nn.functional.pad(attention_mask, (1, 0), value=1) if attention_mask is not None else attention_mask
        )

    modified_input_ids = input_ids
    if add_eos_token:
        modified_input_ids = torch.zeros(
            (input_ids.shape[0], input_ids.shape[1] + 1), dtype=input_ids.dtype, device=input_ids.device
        )
        for i, each_input_id in enumerate(input_ids):
            # locate where the valid tokens end and then add the eos token
            if isin_mps_friendly(each_input_id, pad_token_id).sum():
                pos = torch.where(each_input_id == pad_token_id)[0].min()
                modified_input_ids[i] = torch.concatenate(
                    [each_input_id[:pos], torch.tensor([eos_token_id], device=input_ids.device), each_input_id[pos:]]
                )
            else:
                # if there are no pad tokens present, then add eos to the end
                modified_input_ids[i] = torch.nn.functional.pad(each_input_id, (0, 1), value=eos_token_id)
        attention_mask = (
            torch.nn.functional.pad(attention_mask, (1, 0), value=1) if attention_mask is not None else attention_mask
        )

    return modified_input_ids, attention_mask
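
# Illustrative usage sketch (the tensors below are made up): a right-padded
# batch gets a bos token prepended and an eos token spliced in at the boundary
# where padding begins:
#
#     ids = torch.tensor([[7, 8, 0, 0]])  # pad_token_id = 0
#     new_ids, _ = _pad_extra_bos_eos_tokens(ids, bos_token_id=255, eos_token_id=0)
#     # new_ids -> tensor([[255, 7, 8, 0, 0, 0]])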


@dataclass
class ClvpEncoderOutput(ModelOutput):
    """
    Base class for CLVP encoder's outputs that contains a pooling of the last hidden states as well as a projection
    output (a linear layer on top of the pooled output).

    Args:
        embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
            The embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            The hidden state of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Pooled output of the `last_hidden_state`.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
            the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    pooler_output: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class ClvpOutput(ModelOutput):
    """
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
            Contrastive loss for speech-text similarity.
        speech_ids (`torch.LongTensor`, *optional*):
            speech_ids (or speech candidates) generated by the `ClvpForCausalLM` model.
        logits_per_speech (`torch.FloatTensor` of shape `(speech_batch_size, text_batch_size)`):
            The scaled dot product scores between `speech_embeds` and `text_embeds`. This represents the speech-text
            similarity scores.
        logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, speech_batch_size)`):
            The scaled dot product scores between `text_embeds` and `speech_embeds`. This represents the text-speech
            similarity scores.
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
            The text embeddings obtained by applying the projection layer to the pooled output of the text encoder
            model.
        speech_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
            The speech embeddings obtained by applying the projection layer to the pooled output of the speech encoder
            model.
        text_model_output (`BaseModelOutputWithPooling`):
            The pooled output of the `last_hidden_state` of the text encoder model.
        speech_model_output (`BaseModelOutputWithPooling`):
            The pooled output of the `last_hidden_state` of the speech encoder model.
        decoder_hidden_states (`torch.FloatTensor`, *optional*):
            The hidden states of the decoder model.
        text_encoder_hidden_states (`torch.FloatTensor`, *optional*):
            The hidden states of the text encoder model.
        speech_encoder_hidden_states (`torch.FloatTensor`, *optional*):
            The hidden states of the speech encoder model.
    """

    loss: Optional[torch.FloatTensor] = None
    speech_ids: Optional[torch.LongTensor] = None
    logits_per_speech: Optional[torch.FloatTensor] = None
    logits_per_text: Optional[torch.FloatTensor] = None
    text_embeds: Optional[torch.FloatTensor] = None
    speech_embeds: Optional[torch.FloatTensor] = None
    text_model_output: BaseModelOutputWithPooling = None
    speech_model_output: BaseModelOutputWithPooling = None
    decoder_hidden_states: Optional[torch.FloatTensor] = None
    text_encoder_hidden_states: Optional[torch.FloatTensor] = None
    speech_encoder_hidden_states: Optional[torch.FloatTensor] = None


# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Clvp
class ClvpRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        ClvpRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
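
# For reference (comment added for clarity, not upstream documentation):
# RMSNorm rescales by the root mean square of the last dimension instead of
# centering and scaling like LayerNorm,
#
#     y = weight * x / sqrt(mean(x ** 2) + eps)
#
# so it learns a scale but has no bias term and performs no mean subtraction.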


class ClvpRotaryPositionalEmbedding(nn.Module):
    """
    Rotary Position Embedding Class for CLVP. It was proposed in the paper 'ROFORMER: ENHANCED TRANSFORMER WITH ROTARY
    POSITION EMBEDDING'. Please see https://arxiv.org/pdf/2104.09864v1.pdf .
    """

    def __init__(self, config):
        super().__init__()
        dim = max(config.projection_dim // (config.num_attention_heads * 2), 32)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))

        self.register_buffer("inv_freq", inv_freq)
        self.cached_sequence_length = None
        self.cached_rotary_positional_embedding = None

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        sequence_length = hidden_states.shape[1]

        if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
            return self.cached_rotary_positional_embedding

        self.cached_sequence_length = sequence_length
        time_stamps = torch.arange(sequence_length, device=hidden_states.device).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
        embeddings = torch.cat((freqs, freqs), dim=-1)

        self.cached_rotary_positional_embedding = embeddings.unsqueeze(0)
        return self.cached_rotary_positional_embedding
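
# Shape sketch (illustrative only): the module caches one table per sequence
# length. With `inv_freq` of length dim / 2, the forward pass returns a tensor
# of shape [1, seq_len, dim] -- the frequencies are duplicated along the last
# axis so that its cos()/sin() line up with the rotate_half convention above.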


class ClvpSelfAttention(nn.Module):
    """
    Multi-headed attention to combine Absolute and Rotary Positional Embeddings into a single Attention module.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        if hasattr(config, "max_position_embeddings"):
            max_positions = config.max_position_embeddings
            bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
            bias = bias.view(1, 1, max_positions, max_positions)
            self.register_buffer("bias", bias, persistent=False)

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_attention_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_attention_bias)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_attention_bias)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        rotary_pos_emb: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        use_cache: Optional[bool] = False,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
        # Raise error when position_ids is None but rotary_pos_emb is provided, because we need that when applying
        # rotary_pos_emb to query and key states.
        if rotary_pos_emb is not None and position_ids is None:
            raise ValueError("`position_ids` must be provided when `rotary_pos_emb` is not None.")

        bsz, _, embed_dim = hidden_states.size()

        # get query proj
        query_states = self._shape(self.q_proj(hidden_states), -1, bsz) * self.scale
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if past_key_value is not None:
            past_key, past_value = past_key_value
            key_states = torch.cat((past_key, key_states), dim=-2)
            value_states = torch.cat((past_value, value_states), dim=-2)

        if use_cache is True:
            present = (key_states, value_states)
        else:
            present = None

        if rotary_pos_emb is not None:
            rotary_emb_dim = rotary_pos_emb.shape[-1]

            # Partial rotary embedding
            query_rot, query_pass = (
                query_states[..., :rotary_emb_dim],
                query_states[..., rotary_emb_dim:],
            )
            key_rot, key_pass = (
                key_states[..., :rotary_emb_dim],
                key_states[..., rotary_emb_dim:],
            )
            value_rot, value_pass = (
                value_states[..., :rotary_emb_dim],
                value_states[..., rotary_emb_dim:],
            )

            cos, sin = rotary_pos_emb.cos().squeeze(0), rotary_pos_emb.sin().squeeze(0)
            query_rot, key_rot, value_rot = apply_rotary_pos_emb(query_rot, key_rot, value_rot, cos, sin, position_ids)

            # [batch_size, num_heads, seq_length, head_dim]
            query_states = torch.cat((query_rot, query_pass), dim=-1)
            key_states = torch.cat((key_rot, key_pass), dim=-1)
            value_states = torch.cat((value_rot, value_pass), dim=-1)

        tgt_len = query_states.shape[2]
        src_len = key_states.shape[2]
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = torch.matmul(attn_probs, value_states)

        if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, present, attn_weights
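
# Note (added for clarity, not upstream documentation): only the leading
# `rotary_pos_emb.shape[-1]` channels of each head are rotated; the remaining
# "pass" channels flow through unchanged, which is how the partial rotary
# embedding coexists with the absolute position embeddings added elsewhere.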


class ClvpGatedLinearUnit(nn.Module):
    """
    `ClvpGatedLinearUnit` uses the second half of the `hidden_states` to act as a gate for the first half of the
    `hidden_states`, which controls the flow of data from the first half of the tensor.
    """

    def __init__(self, config):
        super().__init__()
        self.activation_fn = ACT2FN[config.hidden_act]
        self.proj = nn.Linear(config.hidden_size, config.intermediate_size * 2)

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * self.activation_fn(gate)
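
# GLU sketch (illustrative only; `proj` and `act` stand in for the module's
# `self.proj` and `self.activation_fn`): the projection doubles the width,
# the result is split in two, and one half gates the other:
#
#     h, g = proj(x).chunk(2, dim=-1)  # each [batch, seq, intermediate_size]
#     out = h * act(g)                 # [batch, seq, intermediate_size]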


class ClvpEncoderMLP(nn.Module):
    """
    This MLP is used in CLVP speech or text encoder models.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.fc1 = ClvpGatedLinearUnit(config)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout_layer = nn.Dropout(config.dropout)

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.dropout_layer(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class ClvpEncoderLayer(nn.Module):
    def __init__(self, config: ClvpConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.self_attn = ClvpSelfAttention(config)
        self.mlp = ClvpEncoderMLP(config)

        self.input_rmsnorm = ClvpRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.post_attention_rmsnorm = ClvpRMSNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        rotary_pos_emb: torch.FloatTensor,
        attention_mask: torch.LongTensor,
        position_ids: torch.LongTensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch, seq_len, embed_dim)`):
                input to the layer.
            rotary_pos_emb (`torch.FloatTensor`):
                rotary position embeddings generated by `ClvpRotaryPositionalEmbedding` module.
            attention_mask (`torch.FloatTensor` of shape `(batch, 1, tgt_len, src_len)`):
                attention mask where padding elements are indicated by very large negative values.
            position_ids (`torch.LongTensor`):
                Denotes position ids of the input tokens.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        hidden_states = self.input_rmsnorm(hidden_states)

        attention_outputs = self.self_attn(
            hidden_states=hidden_states,
            rotary_pos_emb=rotary_pos_emb,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
        )

        hidden_states = attention_outputs[0]

        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_rmsnorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attention_outputs[-1],)

        return outputs


# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->Clvp
class ClvpSequenceSummary(nn.Module):
    r"""
    Compute a single vector summary of a sequence hidden states.

    Args:
        config ([`ClvpConfig`]):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):

            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:

                - `"last"` -- Take the last token hidden state (like XLNet)
                - `"first"` -- Take the first token hidden state (like Bert)
                - `"mean"` -- Take the mean of all tokens hidden states
                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
                - `"attn"` -- Not implemented now, use multi-head attention

            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
              (otherwise to `config.hidden_size`).
            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
              another string or `None` will add no activation.
            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
            - **summary_last_dropout** (`float`) -- Optional dropout probability after the projection and activation.
    """

    def __init__(self, config: ClvpConfig):
        super().__init__()

        self.summary_type = getattr(config, "summary_type", "last")
        if self.summary_type == "attn":
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        self.summary = nn.Identity()
        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = nn.Linear(config.hidden_size, num_classes)

        activation_string = getattr(config, "summary_activation", None)
        self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()

        self.first_dropout = nn.Identity()
        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
            self.first_dropout = nn.Dropout(config.summary_first_dropout)

        self.last_dropout = nn.Identity()
        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
            self.last_dropout = nn.Dropout(config.summary_last_dropout)

    def forward(
        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
    ) -> torch.FloatTensor:
        """
        Compute a single vector summary of a sequence hidden states.

        Args:
            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
                The hidden states of the last layer.
            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.

        Returns:
            `torch.FloatTensor`: The summary of the sequence hidden states.
        """
        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = hidden_states.mean(dim=1)
        elif self.summary_type == "cls_index":
            if cls_index is None:
                cls_index = torch.full_like(
                    hidden_states[..., :1, :],
                    hidden_states.shape[-2] - 1,
                    dtype=torch.long,
                )
            else:
                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError

        output = self.first_dropout(output)
        output = self.summary(output)
        output = self.activation(output)
        output = self.last_dropout(output)

        return output
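
# Mode sketch (added for clarity): for hidden_states of shape
# [batch, seq_len, hidden], "last" takes hidden_states[:, -1], "first" takes
# hidden_states[:, 0], and "mean" averages over dim=1 -- each producing a
# [batch, hidden] vector before the optional dropout/projection/activation.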


# Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP with GPT2->ClvpDecoderMLP
class ClvpDecoderMLP(nn.Module):
    def __init__(self, intermediate_size, config):
        super().__init__()
        embed_dim = config.hidden_size
        self.c_fc = Conv1D(intermediate_size, embed_dim)
        self.c_proj = Conv1D(embed_dim, intermediate_size)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class ClvpDecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

        self.input_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = ClvpSelfAttention(config)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        self.mlp = ClvpDecoderMLP(inner_dim, config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_outputs = self.attn(
            hidden_states,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]
        outputs = attn_outputs[1:]
        # residual connection
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states

        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            outputs = (hidden_states,) + outputs[1:]

        return outputs


class ClvpConditioningEncoder(nn.Module):
    """
    This class processes the log-mel spectrograms (extracted by the Feature Extractor) and text tokens (produced by
    the tokenizer) as inputs for the decoder model.

    First each log-mel spectrogram is processed into a single vector which captures valuable characteristics from each
    of them, then the text tokens are converted into token embeddings and position embeddings are added afterwards.
    Both of these vectors are concatenated and then passed to the decoder model.

    The text tokens help to incorporate the "text information" and the log-mel spectrogram is used to specify the
    "voice characteristics" into the generated mel tokens.
    """

    def __init__(self, config: ClvpConfig):
        super().__init__()

        self.text_config = config.text_config
        self.decoder_config = config.decoder_config

        self.text_token_embedding = nn.Embedding(self.text_config.vocab_size, self.decoder_config.hidden_size)
        self.text_position_embedding = nn.Embedding(
            self.decoder_config.max_text_tokens, self.decoder_config.hidden_size
        )

        self.mel_conv = nn.Conv1d(self.decoder_config.feature_size, self.decoder_config.hidden_size, kernel_size=1)

        # define group norms to be used before each attention layer
        num_groups = self.compute_groupnorm_groups(self.decoder_config.hidden_size)
        self.group_norms = nn.ModuleList(
            [
                nn.GroupNorm(num_groups, self.decoder_config.hidden_size, eps=1e-5, affine=True)
                for _ in range(self.decoder_config.num_mel_attn_blocks)
            ]
        )

        # define the attention layers
        self.mel_attn_blocks = nn.ModuleList(
            [ClvpSelfAttention(self.decoder_config) for _ in range(self.decoder_config.num_mel_attn_blocks)]
        )

        self.gradient_checkpointing = False

    def compute_groupnorm_groups(self, channels: int, groups: int = 32):
        """
        Calculates the value of `num_groups` for nn.GroupNorm. This logic is taken from the official tortoise
        repository. link :
        https://github.com/neonbjb/tortoise-tts/blob/4003544b6ff4b68c09856e04d3eff9da26d023c2/tortoise/models/arch_util.py#L26
        """
        if channels <= 16:
            groups = 8
        elif channels <= 64:
            groups = 16
        while channels % groups != 0:
            groups = int(groups / 2)

        if groups <= 2:
            raise ValueError(
                f"Number of groups for the GroupNorm must be greater than 2, but it is {groups}. "
                f"Please consider using a different `hidden_size`"
            )

        return groups

    def forward(
        self,
        input_features: torch.FloatTensor,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        # process text
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.size()
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # construct attention mask if not given
        if attention_mask is None:
            attention_mask = torch.ones([batch_size, seq_length], dtype=torch.long, device=input_ids.device)

        # We add bos and eos input_ids in the modeling file instead of the tokenizer file to keep the logic simple
        # This logic is specific to ClvpConditioningEncoder and not used by other modules.
        input_ids, attention_mask = _pad_extra_bos_eos_tokens(
            input_ids,
            attention_mask,
            bos_token_id=self.text_config.bos_token_id,
            eos_token_id=self.text_config.eos_token_id,
        )

        inputs_embeds = self.text_token_embedding(input_ids)
        position_ids = attention_mask.cumsum(-1) - 1
        position_embeds = self.text_position_embedding(position_ids)
        text_embeds = inputs_embeds + position_embeds

        if self.gradient_checkpointing and self.training:
            # process each log-mel spectrogram into a single vector
            mel_spec = torch.utils.checkpoint.checkpoint(self.mel_conv, input_features)

            for i, mel_attn_block in enumerate(self.mel_attn_blocks):
                residual_mel_spec = mel_spec.transpose(1, 2)

                mel_spec = torch.utils.checkpoint.checkpoint(self.group_norms[i], mel_spec).transpose(1, 2)
                mel_spec = torch.utils.checkpoint.checkpoint(mel_attn_block, mel_spec)[0] + residual_mel_spec
                mel_spec = mel_spec.transpose(1, 2)

        else:
            # process each log-mel spectrogram into a single vector
            mel_spec = self.mel_conv(input_features)

            for i, mel_attn_block in enumerate(self.mel_attn_blocks):
                residual_mel_spec = mel_spec.transpose(1, 2)

                mel_spec = self.group_norms[i](mel_spec).transpose(1, 2)
                mel_spec = mel_attn_block(mel_spec)[0] + residual_mel_spec
                mel_spec = mel_spec.transpose(1, 2)

        mel_spec = mel_spec[:, :, 0]
        mel_spec = mel_spec.unsqueeze(1)

        # repeat if there is either (1 text vs N audios) or (N texts vs 1 audio)
        if text_embeds.shape[0] == 1 and mel_spec.shape[0] != 1:
            text_embeds = text_embeds.repeat(mel_spec.shape[0], 1, 1)
        elif text_embeds.shape[0] != 1 and mel_spec.shape[0] == 1:
            mel_spec = mel_spec.repeat(text_embeds.shape[0], 1, 1)
        # If there are N texts and M audios we will raise an error since the number of texts and audios must be the same.
        elif text_embeds.shape[0] != mel_spec.shape[0]:
            raise ValueError(
                f"The number of texts and number of audios must be the same. "
                f"Found {text_embeds.shape[0]} texts vs {mel_spec.shape[0]} audios"
            )

        return torch.concat([mel_spec, text_embeds], dim=1)
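
# Shape sketch (illustrative only): input_features [batch, feature_size, time]
# -> mel_conv -> [batch, hidden, time]; after the attention blocks only the
# first time step is kept ([:, :, 0]) and unsqueezed to [batch, 1, hidden],
# then concatenated with text_embeds [batch, text_len, hidden] along dim=1,
# yielding the decoder's conditioning prefix of shape [batch, 1 + text_len, hidden].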


class ClvpPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ClvpConfig
    base_model_prefix = "clvp"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"

    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_factor
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=factor * 0.02)
        elif isinstance(module, (nn.Linear, Conv1D, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=factor * 0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, ClvpEncoderMLP):
            factor = self.config.initializer_factor
            in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
            nn.init.normal_(module.fc1.proj.weight if getattr(module.fc1, "proj") else module.fc1.weight, std=fc_std)
            nn.init.normal_(module.fc2.weight, std=in_proj_std)
        elif isinstance(module, ClvpEncoder):
            config = self.config.get_text_config()
            factor = config.initializer_factor
            module.projection.weight.data.normal_(mean=0.0, std=factor * (config.hidden_size**-0.5))
        elif isinstance(module, ClvpConditioningEncoder):
            module.mel_conv.weight.data.normal_(mean=0.0, std=factor)
            module.mel_conv.bias.data.zero_()
        elif isinstance(module, ClvpForCausalLM):
            for name, p in module.named_parameters():
                if name == "c_proj.weight":
                    p.data.normal_(
                        mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.num_hidden_layers))
                    )
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


CLVP_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`ClvpConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


CLVP_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, time_dim)`):
            Indicates log mel-spectrogram representations for audio returned by [`ClvpFeatureExtractor`].
        conditioning_encoder_inputs_embeds (`torch.FloatTensor`, *optional*):
            inputs_embeds for `ClvpConditioningEncoder`. Can be used in place of `input_ids`.
        text_encoder_inputs_embeds (`torch.FloatTensor`, *optional*):
            inputs_embeds for the text encoder model passed in place of `input_ids`.
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding text token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


CLVP_DECODER_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
            Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
            `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
            their past given to this model should not be passed as `input_ids` as they have already been computed.
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
            `past_key_values`. In other words, the `attention_mask` always has to have the length:
            `len(past_key_values) + len(input_ids)`

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.

            If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
            `past_key_values`).
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


class ClvpEncoder(ClvpPreTrainedModel):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`ClvpEncoderLayer`].

    Args:
        config: ClvpConfig
    """

    def __init__(self, config: ClvpConfig):
        super().__init__(config)

        self.config = config
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.rotary_pos_emb = ClvpRotaryPositionalEmbedding(config) if config.use_rotary_embedding else None
        self.layers = nn.ModuleList([ClvpEncoderLayer(config) for _ in range(config.num_hidden_layers)])

        self.sequence_summary = ClvpSequenceSummary(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.token_embedding

    def set_input_embeddings(self, value):
        self.token_embedding = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
                Indices of input sequence tokens in the vocabulary.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                input embeddings for the model. This bypasses the model's internal embedding lookup matrix.
            attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            position_ids (`torch.LongTensor`, *optional*):
                Denotes the position ids of `input_ids`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            inputs_embeds = self.token_embedding(input_ids)
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # expand attention_mask and create position_ids if needed
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(input_shape[1], dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        rotary_pos_emb = self.rotary_pos_emb(inputs_embeds) if self.rotary_pos_emb is not None else None

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    encoder_layer.__call__,
                    hidden_states,
                    rotary_pos_emb,
                    attention_mask,
                    position_ids,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    rotary_pos_emb,
                    attention_mask,
                    position_ids,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        last_hidden_state = hidden_states
        last_hidden_state = self.final_layer_norm(last_hidden_state)

        # take the mean over axis 1 and get pooled output
        pooled_output = self.sequence_summary(last_hidden_state)

        # apply the projection layer
        embeds = self.projection(pooled_output)

        if not return_dict:
            return tuple(
                v for v in [embeds, last_hidden_state, pooled_output, encoder_states, all_attentions] if v is not None
            )

        return ClvpEncoderOutput(
            embeds=embeds,
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )


class ClvpDecoder(ClvpPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`ClvpDecoderLayer`]
    """

    def __init__(self, config):
        super().__init__(config)

        self.config = config

        self.input_embeds_layer = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
        self.position_embeds_layer = nn.Embedding(self.config.max_position_embeddings, self.config.hidden_size)

        self.drop = nn.Dropout(self.config.embd_pdrop)
        self.layers = nn.ModuleList([ClvpDecoderLayer(self.config) for _ in range(self.config.num_hidden_layers)])
        self.layer_norm = nn.LayerNorm(self.config.hidden_size, eps=self.config.layer_norm_epsilon)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.input_embeds_layer

    def set_input_embeddings(self, new_embeddings):
        self.input_embeds_layer = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.layers[layer].attn.prune_heads(heads)

    @add_start_docstrings_to_model_forward(CLVP_DECODER_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])

        if past_key_values is None:
            past_key_values_length = 0
            past_key_values = tuple([None] * len(self.layers))
        else:
            past_key_values_length = past_key_values[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(
                past_key_values_length, input_shape[-1] + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        if inputs_embeds is None:
            inputs_embeds = self.input_embeds_layer(input_ids)
        position_embeds = self.position_embeds_layer(position_ids)
        inputs_embeds = inputs_embeds + position_embeds

        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x num_attention_heads x N x N
        # head_mask has shape num_hidden_layers x batch x num_attention_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        hidden_states = inputs_embeds

        if token_type_ids is not None:
            token_type_embeds = self.input_embeds_layer(token_type_ids)
            hidden_states = hidden_states + token_type_embeds

        hidden_states = self.drop(hidden_states)

        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        all_hidden_states = () if output_hidden_states else None
        for i, (block, past_key_value) in enumerate(zip(self.layers, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                outputs = torch.utils.checkpoint.checkpoint(
                    block.__call__,
                    hidden_states,
                    None,
                    attention_mask,
                    position_ids,
                    head_mask[i],
                )
            else:
                outputs = block(
                    hidden_states,
                    past_key_value=past_key_value,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    head_mask=head_mask[i],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)

        hidden_states = self.layer_norm(hidden_states)

        hidden_states = hidden_states.view(output_shape)

        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
                if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
|
| 1303 |
+
|
| 1304 |
+
|
| 1305 |
+
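
# Note on the block outputs consumed in `ClvpDecoder.forward` above: each decoder block returns
# a tuple whose optional entries shift left when absent. With `use_cache=True` the layout is
# (hidden_states, present, self_attn[, cross_attn]); without the cache it is
# (hidden_states, self_attn[, cross_attn]), which is why the attention tensors are read with
# `outputs[2 if use_cache else 1]` and `outputs[3 if use_cache else 2]`.
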
@add_start_docstrings(
    "The bare Clvp decoder model outputting raw hidden-states without any specific head on top.",
    CLVP_START_DOCSTRING,
)
class ClvpModel(ClvpPreTrainedModel):
    def __init__(self, config: ClvpDecoderConfig):
        super().__init__(config)
        self.config = config
        self.decoder = ClvpDecoder(self.config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.decoder.input_embeds_layer

    def set_input_embeddings(self, value):
        self.decoder.input_embeds_layer = value

    def get_decoder(self):
        return self.decoder

    @add_start_docstrings_to_model_forward(CLVP_DECODER_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consist of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            hidden_states=decoder_outputs.hidden_states,
            attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
        )

@add_start_docstrings(
    "The CLVP decoder model with a language modelling head on top.",
    CLVP_START_DOCSTRING,
)
class ClvpForCausalLM(ClvpPreTrainedModel, GenerationMixin):
    def __init__(self, config):
        super().__init__(config)

        self.config = config
        self.model = ClvpModel(self.config)

        self.final_norm = nn.LayerNorm(self.config.hidden_size)
        self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=True)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.decoder.input_embeds_layer

    def set_input_embeddings(self, new_embeddings):
        self.model.decoder.input_embeds_layer = new_embeddings

    def _prepare_model_inputs(
        self,
        inputs: Optional[torch.Tensor] = None,
        bos_token_id: Optional[int] = None,
        model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
        """
        This function extracts the model-specific `inputs` for generation.
        """
        input_name = self.main_input_name

        model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None}

        inputs_kwarg = model_kwargs.pop(input_name, None)
        if inputs_kwarg is not None and inputs is not None:
            raise ValueError(
                f"`inputs`: {inputs} were passed alongside {input_name} which is not allowed. "
                f"Make sure to either pass {inputs} or {input_name}=..."
            )
        elif inputs_kwarg is not None:
            inputs = inputs_kwarg

        if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
            model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
                inputs, bos_token_id, model_kwargs=model_kwargs
            )
            inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"

        # Check if conditioning_embeds are provided or not; if yes, then concatenate the bos_token_id at the end of the conditioning_embeds.
        # Then we must subtract the positional_ids because during the forward pass they will be added anyway, so we must cancel them out here.
        conditioning_embeds = model_kwargs.get("conditioning_embeds", None)

        if conditioning_embeds is not None:
            mel_start_token_embedding = self.model.decoder.input_embeds_layer(
                torch.full(
                    (conditioning_embeds.shape[0], 1),
                    fill_value=self.config.bos_token_id,
                    device=conditioning_embeds.device,
                )
            )
            mel_start_token_embedding += self.model.decoder.position_embeds_layer(
                torch.full((conditioning_embeds.shape[0], 1), fill_value=0, device=conditioning_embeds.device)
            )
            conditioning_embeds = torch.concat([conditioning_embeds, mel_start_token_embedding], dim=1)

            # subtract the positional_ids here
            if hasattr(model_kwargs, "attention_mask"):
                position_ids = model_kwargs["attention_mask"].long().cumsum(-1) - 1
            else:
                position_ids = torch.arange(
                    0, conditioning_embeds.shape[1], dtype=torch.long, device=conditioning_embeds.device
                )
            position_ids = position_ids.unsqueeze(0).repeat(conditioning_embeds.shape[0], 1)

            model_kwargs["inputs_embeds"] = conditioning_embeds - self.model.decoder.position_embeds_layer(
                position_ids
            )
            model_kwargs["input_ids"] = (
                torch.ones((model_kwargs["inputs_embeds"].shape[0], 1), dtype=torch.long, device=self.device)
                * self.config.bos_token_id
            )

            return model_kwargs["inputs_embeds"], "inputs_embeds", model_kwargs

        inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
        return inputs, input_name, model_kwargs

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, conditioning_embeds=None, **kwargs
    ):
        # Overwritten: has `conditioning_embeds`-related logic

        input_ids_length = input_ids.shape[-1]
        token_type_ids = kwargs.get("token_type_ids", None)
        # only keep the last token of input_ids if past is defined in kwargs
        if past_key_values:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -input_ids.shape[1] :]

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None

        if conditioning_embeds is not None and past_key_values is not None:
            position_ids = torch.tensor([input_ids_length], dtype=torch.long, device=input_ids.device)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "position_ids": position_ids,
                "token_type_ids": token_type_ids,
            }
        )
        return model_inputs

    @add_start_docstrings_to_model_forward(CLVP_DECODER_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to
            `-100` are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
        """

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]

        lm_logits = self.final_norm(hidden_states)
        lm_logits = self.lm_head(lm_logits)

        loss = None
        if labels is not None:
            labels = labels.to(lm_logits.device)
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    @staticmethod
    def _reorder_cache(
        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        """
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past_key_values
        )

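
# Note on `ClvpForCausalLM._prepare_model_inputs` above: the conditioning embeddings are returned
# with `position_embeds_layer(position_ids)` already subtracted. `ClvpDecoder.forward`
# unconditionally adds position embeddings to whatever `inputs_embeds` it receives, so
# subtracting them here first lets the conditioning prefix reach the decoder blocks unchanged.
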
@add_start_docstrings(
    "The composite CLVP model with a text encoder, speech encoder and speech decoder model. "
    "The speech decoder model generates the speech_ids from the text, and the text encoder and speech encoder work "
    "together to filter out the best speech_ids.",
    CLVP_START_DOCSTRING,
)
class ClvpModelForConditionalGeneration(ClvpPreTrainedModel, GenerationMixin):
    config_class = ClvpConfig

    def __init__(self, config: ClvpConfig):
        super().__init__(config)

        if not isinstance(config.text_config, ClvpEncoderConfig):
            raise TypeError(
                "config.text_config is expected to be of type `ClvpEncoderConfig` but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.speech_config, ClvpEncoderConfig):
            raise TypeError(
                "config.speech_config is expected to be of type `ClvpEncoderConfig` but is of type"
                f" {type(config.speech_config)}."
            )

        if not isinstance(config.decoder_config, ClvpDecoderConfig):
            raise TypeError(
                "config.decoder_config is expected to be of type `ClvpDecoderConfig` but is of type"
                f" {type(config.decoder_config)}."
            )

        self.conditioning_encoder = ClvpConditioningEncoder(config)

        self.speech_decoder_model = ClvpForCausalLM(config.decoder_config)

        self.text_encoder_model = ClvpEncoder(config.text_config)
        self.speech_encoder_model = ClvpEncoder(config.speech_config)

        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))

        # Initialize weights and apply final processing
        self.post_init()

    # taken from the original repo,
    # link: https://github.com/neonbjb/tortoise-tts/blob/4003544b6ff4b68c09856e04d3eff9da26d023c2/tortoise/api.py#L117
    def fix_speech_decoder_output(self, speech_ids: torch.LongTensor) -> torch.LongTensor:
        """
        This method modifies the output of the decoder model, such as replacing the `eos_token_id` and changing the
        last few tokens of each sequence.

        Args:
            speech_ids (`torch.LongTensor`):
                This refers to the output of the decoder model.
        """
        decoder_fixing_codes = self.config.decoder_config.decoder_fixing_codes
        speech_ids = speech_ids[:, 1:]

        stop_token_indices = torch.where(speech_ids == self.speech_decoder_model.config.eos_token_id, 1, 0)
        speech_ids = torch.masked_fill(speech_ids, mask=stop_token_indices.bool(), value=decoder_fixing_codes[0])

        for i, each_seq_stop_token_index in enumerate(stop_token_indices):
            # This means that no stop tokens were found, so the sentence was still being generated; in that case we
            # don't need to apply any padding, so just skip to the next sequence of tokens.
            if each_seq_stop_token_index.sum() == 0:
                continue

            stm = each_seq_stop_token_index.argmax()
            speech_ids[i, stm:] = decoder_fixing_codes[0]
            if stm - 3 < speech_ids.shape[1]:
                speech_ids[i, -3:] = torch.tensor(
                    [decoder_fixing_codes[1:]], device=speech_ids.device, dtype=torch.long
                )

        return speech_ids

    def get_text_features(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        text_encoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ) -> torch.FloatTensor:
        r"""
        This method can be used to extract text_embeds from a text. The text embeddings are obtained by applying the
        projection layer to the pooled output of the CLVP text encoder model.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                [What are input IDs?](../glossary#input-ids)
            text_encoder_inputs_embeds (`torch.FloatTensor`, *optional*):
                inputs_embeds for the text encoder model passed in place of `input_ids`.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)

        Returns:
            `torch.FloatTensor` of shape `(batch_size, output_dim)`:
                The text embeddings obtained by applying the projection layer to the pooled output of the CLVP Text
                Model.

        Examples:

        ```python
        >>> from transformers import ClvpProcessor, ClvpModelForConditionalGeneration

        >>> # Define the Text
        >>> text = "This is an example text."

        >>> # Define processor and model
        >>> processor = ClvpProcessor.from_pretrained("susnato/clvp_dev")
        >>> model = ClvpModelForConditionalGeneration.from_pretrained("susnato/clvp_dev")

        >>> # Generate processor output and text embeds
        >>> processor_output = processor(text=text, return_tensors="pt")
        >>> text_embeds = model.get_text_features(input_ids=processor_output["input_ids"])
        ```
        """

        outputs = self.text_encoder_model(
            input_ids=input_ids,
            inputs_embeds=text_encoder_inputs_embeds,
            attention_mask=attention_mask,
        )

        return outputs[0]

    def get_speech_features(
        self,
        speech_ids: Optional[torch.LongTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        input_features: Optional[torch.FloatTensor] = None,
        conditioning_encoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        **kwargs,
    ) -> torch.FloatTensor:
        r"""
        This method can be used to extract speech_embeds. The speech embeddings are obtained by applying the speech
        model on speech_ids. If speech_ids is not present but both input_ids and input_features are given, then the
        decoder model is first used to generate the speech_ids, and the speech model is then applied.

        Args:
            speech_ids (`torch.LongTensor` of shape `(batch_size, num_speech_ids)`, *optional*):
                Speech Tokens. Padding will be ignored by default should you provide it. If speech_ids are provided
                then input_ids and input_features will be automatically ignored.
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Input text Tokens. Processed from the [`ClvpTokenizer`]. If speech_ids is not provided, then input_ids
                and input_features will be used.
            input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, time_dim)`, *optional*):
                Indicates log-melspectrogram representations for audio returned by [`ClvpFeatureExtractor`]. If
                speech_ids is not provided, then input_ids and input_features will be used.
            conditioning_encoder_inputs_embeds (`torch.FloatTensor`, *optional*):
                inputs_embeds for `ClvpConditioningEncoder`. Can be used in place of `input_ids`.
            attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding speech token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            generation_config (`GenerationConfig`, *optional*):
                generation config to control the generation of speech_ids if they are not provided.

        Returns:
            `torch.FloatTensor` of shape `(batch_size, output_dim)`:
                The speech embeddings obtained by applying the projection layer to the pooled output of the CLVP Speech
                Model.

        Examples:

        ```python
        >>> import datasets
        >>> from transformers import ClvpProcessor, ClvpModelForConditionalGeneration

        >>> # Define the Text and Load the Audio (We are taking an audio example from HuggingFace Hub using `datasets` library)
        >>> text = "This is an example text."
        >>> ds = datasets.load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.cast_column("audio", datasets.Audio(sampling_rate=22050))
        >>> _, audio, sr = ds.sort("id").select(range(1))[:1]["audio"][0].values()

        >>> # Define processor and model
        >>> processor = ClvpProcessor.from_pretrained("susnato/clvp_dev")
        >>> model = ClvpModelForConditionalGeneration.from_pretrained("susnato/clvp_dev")

        >>> # Generate processor output and model output
        >>> processor_output = processor(raw_speech=audio, sampling_rate=sr, text=text, return_tensors="pt")
        >>> speech_embeds = model.get_speech_features(
        ...     input_ids=processor_output["input_ids"], input_features=processor_output["input_features"]
        ... )
        ```
        """

        if speech_ids is None:
            if (input_ids is None and conditioning_encoder_inputs_embeds is None) or input_features is None:
                raise ValueError(
                    "Either speech_ids or input_ids/conditioning_encoder_inputs_embeds and input_features must be provided."
                )

            if generation_config is None:
                generation_config = self.generation_config
            generation_config.update(**kwargs)

            conditioning_embeds = self.conditioning_encoder(
                input_features=input_features,
                input_ids=input_ids,
                inputs_embeds=conditioning_encoder_inputs_embeds,
                attention_mask=attention_mask,
            )

            speech_ids = self.speech_decoder_model.generate(
                conditioning_embeds=conditioning_embeds,
                generation_config=generation_config,
            )

            speech_ids = self.fix_speech_decoder_output(speech_ids[0])

        outputs = self.speech_encoder_model(
            input_ids=speech_ids,
            attention_mask=attention_mask,
        )

        return outputs[0]

    @add_start_docstrings_to_model_forward(CLVP_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=ClvpOutput, config_class=ClvpConfig)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        input_features: Optional[torch.FloatTensor] = None,
        conditioning_encoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        text_encoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ClvpOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> import datasets
        >>> from transformers import ClvpProcessor, ClvpModelForConditionalGeneration

        >>> # Define the Text and Load the Audio (We are taking an audio example from HuggingFace Hub using `datasets` library)
        >>> text = "This is an example text."

        >>> ds = datasets.load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.cast_column("audio", datasets.Audio(sampling_rate=22050))
        >>> _, audio, sr = ds.sort("id").select(range(1))[:1]["audio"][0].values()

        >>> # Define processor and model
        >>> processor = ClvpProcessor.from_pretrained("susnato/clvp_dev")
        >>> model = ClvpModelForConditionalGeneration.from_pretrained("susnato/clvp_dev")

        >>> # processor outputs and model outputs
        >>> processor_output = processor(raw_speech=audio, sampling_rate=sr, text=text, return_tensors="pt")
        >>> outputs = model(
        ...     input_ids=processor_output["input_ids"],
        ...     input_features=processor_output["input_features"],
        ...     return_dict=True,
        ... )
        ```
        """

        # Use CLVP model's config for some fields (if specified) instead of those of speech & text components.
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        conditioning_embeds = self.conditioning_encoder(
            input_features=input_features,
            input_ids=input_ids,
            inputs_embeds=conditioning_encoder_inputs_embeds,
            attention_mask=attention_mask,
        )

        decoder_outputs = self.speech_decoder_model(
            inputs_embeds=conditioning_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        speech_ids = decoder_outputs[0]

        # since we will get the embeds of shape `(batch_size, seq_len, embedding_dim)` during the forward pass,
        # we must convert them to tokens to make them compatible with speech_transformer
        if speech_ids.ndim == 3:
            speech_ids = speech_ids.argmax(2)
        speech_ids = self.fix_speech_decoder_output(speech_ids)

        speech_outputs = self.speech_encoder_model(
            input_ids=speech_ids,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        text_outputs = self.text_encoder_model(
            input_ids=input_ids,
            inputs_embeds=text_encoder_inputs_embeds,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        speech_embeds = speech_outputs[0]
        text_embeds = text_outputs[0]

        # normalized features
        speech_embeds = speech_embeds / speech_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, speech_embeds.t()) * logit_scale
        logits_per_speech = logits_per_text.t()

        loss = None
        if return_loss:
            loss = clvp_loss(logits_per_text)

        if not return_dict:
            output = (
                logits_per_speech,
                logits_per_text,
                text_embeds,
                speech_embeds,
                text_outputs[2],
                speech_outputs[2],
            )
            if output_hidden_states:
                output += (
                    decoder_outputs[-1],
                    text_outputs[-1],
                    speech_outputs[-1],
                )

            return ((loss,) + output) if loss is not None else output

        return ClvpOutput(
            loss=loss,
            logits_per_speech=logits_per_speech,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            speech_embeds=speech_embeds,
            text_model_output=text_outputs[2],
            speech_model_output=speech_outputs[2],
            decoder_hidden_states=decoder_outputs.hidden_states,
            text_encoder_hidden_states=text_outputs.hidden_states,
            speech_encoder_hidden_states=speech_outputs.hidden_states,
        )

    @torch.no_grad()
    def generate(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        input_features: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        pad_to_max_mel_tokens: Optional[int] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ):
        """
        Generate method for `ClvpModelForConditionalGeneration`. This method calls the `generate` method of
        `ClvpForCausalLM` and then uses those generated `speech_ids` to process `text_embeds` and `speech_embeds`
        using `ClvpEncoder`.

        Args:
            input_ids (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Input text Tokens. Processed from the [`ClvpTokenizer`].
            input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, time_dim)`, *optional*):
                Indicates log-melspectrogram representations for audio returned by [`ClvpFeatureExtractor`].
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding text token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            generation_config (`~generation.GenerationConfig`, *optional*):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, the default will be used, which has the following loading
                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                default values, whose documentation should be checked to parameterize generation.
            pad_to_max_mel_tokens (`int`, *optional*):
                Pads generated speech_ids to the specified value. This is to implement the same logic from the official
                repo, link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L430
                and to make sure the logits are the same.
                This does not affect generation quality, so please don't use it since it is less efficient.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of the decoder model, text encoder and speech encoder models.

        Returns:
            `ClvpOutput` or tuple: A `ClvpOutput` (if `return_dict_in_generate=True` or when
            `config.return_dict_in_generate=True`) or a tuple.
        """

        # If the input sequences are larger than (self.config.decoder_config.max_text_tokens - 3) then raise an error,
        # because we need to add 3 tokens (1 bos token and 2 eos tokens) to the input_ids in ClvpConditioningEncoder to
        # properly sample
        sequence_length = input_ids.shape[-1]
        if sequence_length > (self.config.decoder_config.max_text_tokens - 3):
            raise ValueError(
                f"Maximum sequence length reached! Found input_ids of length {sequence_length}. "
                f"Please make sure that the maximum length of input_ids is {self.config.decoder_config.max_text_tokens - 3}"
            )

        if generation_config is None:
            generation_config = self.generation_config

        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
        generation_config.validate()
        self._validate_model_kwargs(model_kwargs.copy())

        # pad input_ids as specified in the original repo
        # link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L380
        input_ids, attention_mask = _pad_extra_bos_eos_tokens(
            input_ids,
            attention_mask,
            add_bos_token=False,
            bos_token_id=self.config.text_config.bos_token_id,
            eos_token_id=self.config.text_config.eos_token_id,
        )

        conditioning_embeds = self.conditioning_encoder(
            input_features=input_features,
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        decoder_outputs = self.speech_decoder_model.generate(
            conditioning_embeds=conditioning_embeds,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            return_dict=generation_config.return_dict_in_generate,
        )
        if isinstance(decoder_outputs, ModelOutput):
            speech_ids = decoder_outputs.sequences
        else:
            speech_ids = decoder_outputs

        # pad to pad_to_max_mel_tokens if given, to replicate the original repo logic
        # link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L430
        if pad_to_max_mel_tokens is not None:
            padding_needed = pad_to_max_mel_tokens - speech_ids.shape[-1]
            speech_ids = torch.nn.functional.pad(
                speech_ids, (0, padding_needed), value=self.generation_config.eos_token_id
            )

        speech_ids = self.fix_speech_decoder_output(speech_ids)

        speech_outputs = self.speech_encoder_model(
            input_ids=speech_ids,
            output_hidden_states=output_hidden_states,
            return_dict=generation_config.return_dict_in_generate,
        )
        text_outputs = self.text_encoder_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            return_dict=generation_config.return_dict_in_generate,
        )

        speech_embeds = speech_outputs[0]
        text_embeds = text_outputs[0]

        # normalized features
        speech_embeds = speech_embeds / speech_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, speech_embeds.t()) * logit_scale
        logits_per_speech = logits_per_text.t()

        if not generation_config.return_dict_in_generate:
            output = (
                speech_ids,
                logits_per_speech,
                logits_per_text,
                text_embeds,
                speech_embeds,
                text_outputs[2],
                speech_outputs[2],
            )
            if output_hidden_states:
                output += (
                    decoder_outputs[-1],
                    text_outputs[-1],
                    speech_outputs[-1],
                )

            return output

        return ClvpOutput(
            speech_ids=speech_ids,
            logits_per_speech=logits_per_speech,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            speech_embeds=speech_embeds,
            text_model_output=text_outputs[2],
            speech_model_output=speech_outputs[2],
            decoder_hidden_states=decoder_outputs.hidden_states,
            text_encoder_hidden_states=text_outputs.hidden_states,
            speech_encoder_hidden_states=speech_outputs.hidden_states,
        )


__all__ = [
    "ClvpModelForConditionalGeneration",
    "ClvpForCausalLM",
    "ClvpModel",
    "ClvpPreTrainedModel",
    "ClvpEncoder",
    "ClvpDecoder",
]
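
Unlike `get_text_features` and `get_speech_features`, the `generate` method above carries no usage example in its docstring. A minimal end-to-end sketch, reusing the `susnato/clvp_dev` checkpoint and the audio-loading pattern from the docstrings earlier in this file:

```python
import datasets
from transformers import ClvpProcessor, ClvpModelForConditionalGeneration

# Text prompt plus a short reference clip that conditions the generated voice.
text = "This is an example text."
ds = datasets.load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", datasets.Audio(sampling_rate=22050))
_, audio, sr = ds.sort("id").select(range(1))[:1]["audio"][0].values()

processor = ClvpProcessor.from_pretrained("susnato/clvp_dev")
model = ClvpModelForConditionalGeneration.from_pretrained("susnato/clvp_dev")

inputs = processor(raw_speech=audio, sampling_rate=sr, text=text, return_tensors="pt")

# With return_dict_in_generate=True this returns a ClvpOutput: `speech_ids` holds the
# generated mel tokens and `logits_per_speech` scores them against the text.
outputs = model.generate(
    input_ids=inputs["input_ids"],
    input_features=inputs["input_features"],
    return_dict_in_generate=True,
)
print(outputs.speech_ids.shape, outputs.logits_per_speech.shape)  # shapes depend on generation settings
```
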
docs/transformers/build/lib/transformers/models/clvp/number_normalizer.py
ADDED
@@ -0,0 +1,237 @@

# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""English Normalizer class for CLVP."""

import re


class EnglishNormalizer:
    def __init__(self):
        # List of (regular expression, replacement) pairs for abbreviations:
        self._abbreviations = [
            (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
            for x in [
                ("mrs", "misess"),
                ("mr", "mister"),
                ("dr", "doctor"),
                ("st", "saint"),
                ("co", "company"),
                ("jr", "junior"),
                ("maj", "major"),
                ("gen", "general"),
                ("drs", "doctors"),
                ("rev", "reverend"),
                ("lt", "lieutenant"),
                ("hon", "honorable"),
                ("sgt", "sergeant"),
                ("capt", "captain"),
                ("esq", "esquire"),
                ("ltd", "limited"),
                ("col", "colonel"),
                ("ft", "fort"),
            ]
        ]

        self.ones = ["", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
        self.teens = [
            "ten",
            "eleven",
            "twelve",
            "thirteen",
            "fourteen",
            "fifteen",
            "sixteen",
            "seventeen",
            "eighteen",
            "nineteen",
        ]
        self.tens = ["", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"]

    def number_to_words(self, num: int) -> str:
        """
        Converts numbers (`int`) to words (`str`).

        Please note that it only supports numbers up to "nine hundred ninety-nine quadrillion, nine hundred
        ninety-nine trillion, nine hundred ninety-nine billion, nine hundred ninety-nine million, nine hundred
        ninety-nine thousand, nine hundred ninety-nine", i.e. `number_to_words(999_999_999_999_999_999)`.
        """
        if num == 0:
            return "zero"
        elif num < 0:
            return "minus " + self.number_to_words(abs(num))
        elif num < 10:
            return self.ones[num]
        elif num < 20:
            return self.teens[num - 10]
        elif num < 100:
            return self.tens[num // 10] + ("-" + self.number_to_words(num % 10) if num % 10 != 0 else "")
        elif num < 1000:
            return (
                self.ones[num // 100] + " hundred" + (" " + self.number_to_words(num % 100) if num % 100 != 0 else "")
            )
        elif num < 1_000_000:
            return (
                self.number_to_words(num // 1000)
                + " thousand"
                + (", " + self.number_to_words(num % 1000) if num % 1000 != 0 else "")
            )
        elif num < 1_000_000_000:
            return (
                self.number_to_words(num // 1_000_000)
                + " million"
                + (", " + self.number_to_words(num % 1_000_000) if num % 1_000_000 != 0 else "")
            )
        elif num < 1_000_000_000_000:
            return (
                self.number_to_words(num // 1_000_000_000)
                + " billion"
                + (", " + self.number_to_words(num % 1_000_000_000) if num % 1_000_000_000 != 0 else "")
            )
        elif num < 1_000_000_000_000_000:
            return (
                self.number_to_words(num // 1_000_000_000_000)
                + " trillion"
                + (", " + self.number_to_words(num % 1_000_000_000_000) if num % 1_000_000_000_000 != 0 else "")
            )
        elif num < 1_000_000_000_000_000_000:
            return (
                self.number_to_words(num // 1_000_000_000_000_000)
                + " quadrillion"
                + (
                    ", " + self.number_to_words(num % 1_000_000_000_000_000)
                    if num % 1_000_000_000_000_000 != 0
                    else ""
                )
            )
        else:
            return "number out of range"

    def convert_to_ascii(self, text: str) -> str:
        """
        Converts unicode to ascii
        """
        return text.encode("ascii", "ignore").decode("utf-8")

    def _expand_dollars(self, m: re.Match) -> str:
        """
        This method is used to expand numerical dollar values into spoken words.
        """
        match = m.group(1)
        parts = match.split(".")
        if len(parts) > 2:
            return match + " dollars"  # Unexpected format

        dollars = int(parts[0]) if parts[0] else 0
        cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
        if dollars and cents:
            dollar_unit = "dollar" if dollars == 1 else "dollars"
            cent_unit = "cent" if cents == 1 else "cents"
            return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
        elif dollars:
            dollar_unit = "dollar" if dollars == 1 else "dollars"
            return "%s %s" % (dollars, dollar_unit)
        elif cents:
            cent_unit = "cent" if cents == 1 else "cents"
            return "%s %s" % (cents, cent_unit)
        else:
            return "zero dollars"

    def _remove_commas(self, m: re.Match) -> str:
        """
        This method is used to remove commas from sentences.
        """
        return m.group(1).replace(",", "")

    def _expand_decimal_point(self, m: re.Match) -> str:
        """
        This method is used to expand '.' into the spoken words ' point '.
        """
        return m.group(1).replace(".", " point ")

    def _expand_ordinal(self, num: re.Match) -> str:
        """
        This method is used to expand ordinals such as '1st', '2nd' into spoken words.
        """
        ordinal_suffixes = {1: "st", 2: "nd", 3: "rd"}

        num = int(num.group(0)[:-2])
        if 10 <= num % 100 <= 20:
            suffix = "th"
        else:
            suffix = ordinal_suffixes.get(num % 10, "th")
        return self.number_to_words(num) + suffix

    def _expand_number(self, m: re.Match) -> str:
        """
        This method acts as a preprocessing step for numbers between 1000 and 3000 (same as the original repository,
        link:
        https://github.com/neonbjb/tortoise-tts/blob/4003544b6ff4b68c09856e04d3eff9da26d023c2/tortoise/utils/tokenizer.py#L86)
        """
        num = int(m.group(0))

        if 1000 < num < 3000:
            if num == 2000:
                return "two thousand"
            elif 2000 < num < 2010:
                return "two thousand " + self.number_to_words(num % 100)
            elif num % 100 == 0:
                return self.number_to_words(num // 100) + " hundred"
            else:
                return self.number_to_words(num)
        else:
            return self.number_to_words(num)

    def normalize_numbers(self, text: str) -> str:
        """
        This method is used to normalize numbers within a text, such as converting the numbers to words and removing
        commas.
        """
        text = re.sub(re.compile(r"([0-9][0-9\,]+[0-9])"), self._remove_commas, text)
        text = re.sub(re.compile(r"£([0-9\,]*[0-9]+)"), r"\1 pounds", text)
        text = re.sub(re.compile(r"\$([0-9\.\,]*[0-9]+)"), self._expand_dollars, text)
        text = re.sub(re.compile(r"([0-9]+\.[0-9]+)"), self._expand_decimal_point, text)
        text = re.sub(re.compile(r"[0-9]+(st|nd|rd|th)"), self._expand_ordinal, text)
        text = re.sub(re.compile(r"[0-9]+"), self._expand_number, text)
        return text

    def expand_abbreviations(self, text: str) -> str:
        """
        Expands abbreviated words.
        """
        for regex, replacement in self._abbreviations:
            text = re.sub(regex, replacement, text)
        return text

    def collapse_whitespace(self, text: str) -> str:
        """
        Collapses runs of whitespace into a single space.
        """
        return re.sub(re.compile(r"\s+"), " ", text)

    def __call__(self, text):
        """
        Converts text to ascii, numbers / number-like quantities to their spelt-out counterparts, and expands
        abbreviations
        """

        text = self.convert_to_ascii(text)
        text = text.lower()
        text = self.normalize_numbers(text)
        text = self.expand_abbreviations(text)
        text = self.collapse_whitespace(text)
        text = text.replace('"', "")

        return text
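
`EnglishNormalizer` above is self-contained apart from the standard-library `re` import, so its pipeline (ASCII fold, lowercase, number expansion, abbreviation expansion, whitespace collapse) is easy to sanity-check directly. A minimal sketch, assuming the class above is in scope:

```python
normalizer = EnglishNormalizer()

# Dollar amounts are expanded by `_expand_dollars`, leftover digits are spelled out by
# `_expand_number`, and "mr." is replaced by `expand_abbreviations`.
print(normalizer("Mr. Smith spent $42.50 on 3 books."))
# mister smith spent forty-two dollars, fifty cents on three books.

# Years between 1000 and 3000 take the `_expand_number` special-case path.
print(normalizer("I was born in 1992."))
# i was born in one thousand, nine hundred ninety-two.
```
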
docs/transformers/build/lib/transformers/models/clvp/processing_clvp.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Processor class for CLVP
"""

from ...processing_utils import ProcessorMixin


class ClvpProcessor(ProcessorMixin):
    r"""
    Constructs a CLVP processor which wraps a CLVP Feature Extractor and a CLVP Tokenizer into a single processor.

    [`ClvpProcessor`] offers all the functionalities of [`ClvpFeatureExtractor`] and [`ClvpTokenizer`]. See
    [`~ClvpProcessor.__call__`], [`~ClvpProcessor.decode`] and [`~ClvpProcessor.batch_decode`] for more information.

    Args:
        feature_extractor (`ClvpFeatureExtractor`):
            An instance of [`ClvpFeatureExtractor`]. The feature extractor is a required input.
        tokenizer (`ClvpTokenizer`):
            An instance of [`ClvpTokenizer`]. The tokenizer is a required input.
    """

    feature_extractor_class = "ClvpFeatureExtractor"
    tokenizer_class = "ClvpTokenizer"
    model_input_names = [
        "input_ids",
        "input_features",
        "attention_mask",
    ]

    def __init__(self, feature_extractor, tokenizer):
        super().__init__(feature_extractor, tokenizer)

    def __call__(self, *args, **kwargs):
        """
        Forwards the `audio` and `sampling_rate` arguments to [`~ClvpFeatureExtractor.__call__`] and the `text`
        argument to [`~ClvpTokenizer.__call__`]. Please refer to the docstring of the above two methods for more
        information.
        """

        raw_speech = kwargs.pop("raw_speech", None)
        sampling_rate = kwargs.pop("sampling_rate", None)
        text = kwargs.pop("text", None)

        if raw_speech is None and text is None:
            raise ValueError("You need to specify either a `raw_speech` or `text` input to process.")

        if raw_speech is not None:
            inputs = self.feature_extractor(raw_speech, sampling_rate=sampling_rate, **kwargs)
        if text is not None:
            encodings = self.tokenizer(text, **kwargs)

        if text is None:
            return inputs
        elif raw_speech is None:
            return encodings
        else:
            inputs["input_ids"] = encodings["input_ids"]
            inputs["attention_mask"] = encodings["attention_mask"]
            return inputs

    # Copied from transformers.models.whisper.processing_whisper.WhisperProcessor.batch_decode with Whisper->Clvp
    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to ClvpTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    # Copied from transformers.models.whisper.processing_whisper.WhisperProcessor.decode with Whisper->Clvp
    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to ClvpTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)


__all__ = ["ClvpProcessor"]
docs/transformers/build/lib/transformers/models/clvp/tokenization_clvp.py
ADDED
@@ -0,0 +1,367 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization class for CLVP."""

import json
import os
from functools import lru_cache
from typing import List, Optional, Tuple

import regex as re

from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging
from .number_normalizer import EnglishNormalizer


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",
    "merges_file": "merges.txt",
}


@lru_cache()
# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
def bytes_to_unicode():
    """
    Returns a list of utf-8 bytes and a mapping to unicode strings. We specifically avoid mapping to
    whitespace/control characters the bpe code barfs on.

    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
    if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
    decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
    tables between utf-8 bytes and unicode strings.
    """
    bs = (
        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


# Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
def get_pairs(word):
    """
    Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


class ClvpTokenizer(PreTrainedTokenizer):
    """
    Construct a CLVP tokenizer. Based on byte-level Byte-Pair-Encoding.

    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
    be encoded differently whether it is at the beginning of the sentence (without space) or not:

    ```python
    >>> from transformers import ClvpTokenizer

    >>> tokenizer = ClvpTokenizer.from_pretrained("susnato/clvp_dev")
    >>> tokenizer("Hello world")["input_ids"]
    [62, 84, 28, 2, 179, 79]

    >>> tokenizer(" Hello world")["input_ids"]
    [2, 62, 84, 28, 2, 179, 79]
    ```

    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
    call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.

    <Tip>

    When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one).

    </Tip>

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        merges_file (`str`):
            Path to the merges file.
        errors (`str`, *optional*, defaults to `"replace"`):
            Paradigm to follow when decoding bytes to UTF-8. See
            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
        unk_token (`str`, *optional*, defaults to `"[UNK]"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        bos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The beginning of sequence token.
        eos_token (`str`, *optional*, defaults to `"[STOP]"`):
            The end of sequence token.
        pad_token (`str`, *optional*, defaults to `"[STOP]"`):
            The pad token of the sequence.
        add_prefix_space (`bool`, *optional*, defaults to `False`):
            Whether or not to add an initial space to the input. This allows to treat the leading word just as any
            other word. (The CLVP tokenizer detects the beginning of words by the preceding space.)
        add_bos_token (`bool`, *optional*, defaults to `False`):
            Whether to add `bos_token` in front of the sequence when add_special_tokens=True.
        add_eos_token (`bool`, *optional*, defaults to `False`):
            Whether to add `eos_token` at the end of the sequence when add_special_tokens=True.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = [
        "input_ids",
        "attention_mask",
    ]

    def __init__(
        self,
        vocab_file,
        merges_file,
        errors="replace",
        unk_token="[UNK]",
        bos_token="<|endoftext|>",
        eos_token="[STOP]",
        pad_token="[STOP]",
        add_prefix_space=False,
        add_bos_token=False,
        add_eos_token=False,
        **kwargs,
    ):
        bos_token = AddedToken(bos_token, special=True) if isinstance(bos_token, str) else bos_token
        eos_token = AddedToken(eos_token, special=True) if isinstance(eos_token, str) else eos_token
        unk_token = AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token
        pad_token = AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token

        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self._normalizer = None

        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        with open(merges_file, encoding="utf-8") as merges_handle:
            bpe_merges = merges_handle.read().split("\n")[1:-1]
        bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}
        self.add_prefix_space = add_prefix_space

        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

        super().__init__(
            errors=errors,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_prefix_space=add_prefix_space,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            **kwargs,
        )

    @property
    def vocab_size(self):
        return len(self.encoder)

    @property
    def normalizer(self):
        if self._normalizer is None:
            self._normalizer = EnglishNormalizer()
        return self._normalizer

    def get_vocab(self):
        return dict(self.encoder, **self.added_tokens_encoder)

    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    new_word.extend(word[i:])
                    break
                else:
                    new_word.extend(word[i:j])
                    i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        output = bos_token_id + token_ids_0 + eos_token_id

        if token_ids_1 is not None:
            output = output + bos_token_id + token_ids_1 + eos_token_id

        return output

    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_special_tokens_mask
    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        if not self.add_bos_token:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=False
            )

        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0))
        return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))

    def _tokenize(self, text):
        """Tokenize a string."""
        bpe_tokens = []
        text = self.normalizer(text)
        for token in re.findall(self.pat, text):
            token = "".join(
                self.byte_encoder[b] for b in token.encode("utf-8")
            )  # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)

            # if the token is "Ġ" we replace it with "[SPACE]" (if "[SPACE]" is present in the vocab), otherwise we keep the "Ġ".
            bpe_tokens.extend(
                "[SPACE]" if bpe_token == "\u0120" and "[SPACE]" in self.encoder.keys() else bpe_token
                for bpe_token in self.bpe(token).split(" ")
            )

        return bpe_tokens

    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id
    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token
    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.decoder.get(index)

    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string
    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        text = "".join(tokens)
        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
        return text

    def clean_up_tokenization(self, text):
        text = "".join(text)
        vocab_tokens = list(self.encoder.keys()) + list(self.added_tokens_encoder.keys())

        text = text.replace("[SPACE]", " ") if "[SPACE]" in vocab_tokens else text
        text = text.replace("[STOP]", " ") if "[STOP]" in vocab_tokens else text

        text = text.replace(self.unk_token, "").replace("   ", " ").replace("  ", " ")
        return text

    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )
        merge_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

        index = 0
        with open(merge_file, "w", encoding="utf-8") as writer:
            writer.write("#version: 0.2\n")
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
                        " Please check that the tokenizer is not corrupted!"
                    )
                    index = token_index
                writer.write(" ".join(bpe_tokens) + "\n")
                index += 1

        return vocab_file, merge_file


__all__ = ["ClvpTokenizer"]
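A hedged round-trip sketch (not part of the diff), reusing the `susnato/clvp_dev` checkpoint from the class docstring above; the key point is that `EnglishNormalizer` runs inside `_tokenize`, so digits are spelled out before BPE is applied:

```python
from transformers import ClvpTokenizer

tokenizer = ClvpTokenizer.from_pretrained("susnato/clvp_dev")
ids = tokenizer("I have 3 cats")["input_ids"]  # "3" becomes "three" via the normalizer
print(tokenizer.decode(ids))                   # "[SPACE]" tokens map back to " " on cleanup
```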
docs/transformers/build/lib/transformers/models/code_llama/__init__.py
ADDED
@@ -0,0 +1,27 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .tokenization_code_llama import *
    from .tokenization_code_llama_fast import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
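A minimal sketch (not part of the diff, and assuming a transformers build that ships this module) of what the `_LazyModule` replacement above buys: importing the package is cheap, and heavy dependencies such as sentencepiece are only pulled in on first attribute access.

```python
import importlib

pkg = importlib.import_module("transformers.models.code_llama")  # cheap, lazy
tokenizer_cls = pkg.CodeLlamaTokenizer  # first attribute access triggers the real import
```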
docs/transformers/build/lib/transformers/models/code_llama/tokenization_code_llama.py
ADDED
@@ -0,0 +1,454 @@
# coding=utf-8
# Copyright 2023 MetaAI and the HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tokenization classes for Code LLaMA."""

import os
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple

import sentencepiece as spm

from ...convert_slow_tokenizer import import_protobuf
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging, requires_backends
from ...utils.import_utils import requires


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}

SPIECE_UNDERLINE = "▁"

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

# fmt: off
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
 that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
correct. If you don't know the answer to a question, please don't share false information."""
# fmt: on


@requires(backends=("sentencepiece",))
class CodeLlamaTokenizer(PreTrainedTokenizer):
    """
    Construct a CodeLlama tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as
    there is no padding token in the original model.

    The default configuration matches that of
    [codellama/CodeLlama-7b-Instruct-hf](https://huggingface.co/meta-llama/CodeLlama-7b-Instruct-hf/blob/main/tokenizer_config.json)
    which supports prompt infilling.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
            token.
        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
            The token used is the `sep_token`.

            </Tip>

        prefix_token (`str`, *optional*, defaults to `"▁<PRE>"`):
            Prefix token used for infilling.
        middle_token (`str`, *optional*, defaults to `"▁<MID>"`):
            Middle token used for infilling.
        suffix_token (`str`, *optional*, defaults to `"▁<SUF>"`):
            Suffix token used for infilling.
        eot_token (`str`, *optional*, defaults to `"▁<EOT>"`):
            End of text token used for infilling.
        fill_token (`str`, *optional*, defaults to `"<FILL_ME>"`):
            The token used to split the input between the prefix and suffix.
        suffix_first (`bool`, *optional*, defaults to `False`):
            Whether the input prompt and suffix should be formatted with the suffix first.
        sp_model_kwargs (`dict`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
            to set:

            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
                using forward-filtering-and-backward-sampling algorithm.

            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.
        add_bos_token (`bool`, *optional*, defaults to `True`):
            Whether to add a beginning of sequence token at the start of sequences.
        add_eos_token (`bool`, *optional*, defaults to `False`):
            Whether to add an end of sequence token at the end of sequences.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
            Whether or not to clean up the tokenization spaces.
        additional_special_tokens (`List[str]`, *optional*):
            Additional special tokens used by the tokenizer.
        use_default_system_prompt (`bool`, *optional*, defaults to `False`):
            Whether or not the default system prompt for Llama should be used.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        prefix_token="▁<PRE>",
        middle_token="▁<MID>",
        suffix_token="▁<SUF>",
        eot_token="▁<EOT>",
        fill_token="<FILL_ME>",
        suffix_first=False,
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        add_bos_token=True,
        add_eos_token=False,
        clean_up_tokenization_spaces=False,
        additional_special_tokens=None,
        use_default_system_prompt=False,
        **kwargs,
    ):
        requires_backends(self, "protobuf")
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
        eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
        unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token

        self.use_default_system_prompt = use_default_system_prompt
        # mark tokens special to skip them
        additional_special_tokens = additional_special_tokens or []
        for token in [prefix_token, middle_token, suffix_token, eot_token]:
            additional_special_tokens += [token] if token is not None else []

        self.vocab_file = vocab_file
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self._prefix_token = prefix_token
        self._middle_token = middle_token
        self._suffix_token = suffix_token
        self._eot_token = eot_token
        self.fill_token = fill_token
        self.suffix_first = suffix_first
        self.sp_model = self.get_spm_processor()

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            prefix_token=prefix_token,
            middle_token=middle_token,
            suffix_token=suffix_token,
            eot_token=eot_token,
            fill_token=fill_token,
            sp_model_kwargs=self.sp_model_kwargs,
            suffix_first=suffix_first,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            additional_special_tokens=additional_special_tokens,
            use_default_system_prompt=use_default_system_prompt,
            **kwargs,
        )

    @property
    def unk_token_length(self):
        return len(self.sp_model.encode(str(self.unk_token)))

    def get_spm_processor(self):
        tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        with open(self.vocab_file, "rb") as f:
            sp_model = f.read()
            model_pb2 = import_protobuf()
            model = model_pb2.ModelProto.FromString(sp_model)
            normalizer_spec = model_pb2.NormalizerSpec()
            normalizer_spec.add_dummy_prefix = False
            model.normalizer_spec.MergeFrom(normalizer_spec)
            sp_model = model.SerializeToString()
            tokenizer.LoadFromSerializedProto(sp_model)
        return tokenizer

    @property
    def prefix_token(self):
        return self._prefix_token

    @property
    def prefix_id(self):
        if self._prefix_token is None:
            return None
        return self.convert_tokens_to_ids(self.prefix_token)

    @property
    def middle_token(self):
        return self._middle_token

    @property
    def middle_id(self):
        if self._middle_token is None:
            return None
        return self.convert_tokens_to_ids(self.middle_token)

    @property
    def suffix_token(self):
        return self._suffix_token

    @property
    def suffix_id(self):
        if self._suffix_token is None:
            return None
        return self.convert_tokens_to_ids(self.suffix_token)

    @property
    def eot_token(self):
        return self._eot_token

    @property
    def eot_id(self):
        if self._eot_token is None:
            return None
        return self.convert_tokens_to_ids(self.eot_token)

    @property
    def vocab_size(self):
        """Returns vocab size"""
        return self.sp_model.get_piece_size()

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_vocab
    def get_vocab(self):
        """Returns vocab as a dict"""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def tokenize(self, prefix, suffix=None, suffix_first=False, **kwargs) -> List[int]:
        # add a prefix space to `prefix`
        if self.fill_token is not None and self.fill_token in prefix and suffix is None:
            prefix, suffix = prefix.split(self.fill_token)

        if len(prefix) > 0:
            prefix = SPIECE_UNDERLINE + prefix.replace(SPIECE_UNDERLINE, " ")

        if suffix is None or len(suffix) < 1:
            tokens = super().tokenize(prefix, **kwargs)
            if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
                tokens = tokens[1:]
            return tokens

        prefix_tokens = self._tokenize(prefix)  # prefix has an extra `SPIECE_UNDERLINE`

        if None in (self.prefix_id, self.middle_id, self.suffix_id):
            raise ValueError(
                "The input either includes a `prefix` and a `suffix` used for the infilling task,"
                f" or can be split on the {self.fill_token} token, creating a suffix and prefix,"
                " but the model does not support `infilling`."
            )
        suffix_tokens = self._tokenize(suffix)  # make sure CodeLlama sp model does not mess up

        suffix_first = suffix_first if suffix_first is not None else self.suffix_first
        if suffix_first:
            # format as " <PRE> <SUF>{suf} <MID> {pre}"
            return [self.prefix_token, self.suffix_token] + suffix_tokens + [self.middle_token] + prefix_tokens
        else:
            # format as " <PRE> {pre} <SUF>{suf} <MID>"
            return [self.prefix_token] + prefix_tokens + [self.suffix_token] + suffix_tokens + [self.middle_token]

    def _tokenize(self, text, **kwargs):
        """
        Returns a tokenized string.

        We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
        SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
        `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
        `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
        `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
        """
        tokens = self.sp_model.encode(text, out_type=str)
        if not text.startswith((SPIECE_UNDERLINE, " ")):
            return tokens
        # 1. Encode string + prefix ex: "<unk> Hey"
        tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
        # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
        return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_token_to_id
    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.piece_to_id(token)

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_id_to_token
    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        token = self.sp_model.IdToPiece(index)
        return token

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        # since we manually add the prefix space, we have to remove it when decoding
        if tokens[0].startswith(SPIECE_UNDERLINE):
            tokens[0] = tokens[0][1:]

        current_sub_tokens = []
        out_string = ""
        for _, token in enumerate(tokens):
            # make sure that special tokens are not decoded using sentencepiece model
            if token in self.all_special_tokens:
                out_string += self.sp_model.decode(current_sub_tokens) + token
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.save_vocabulary
    def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        output = bos_token_id + token_ids_0 + eos_token_id

        if token_ids_1 is not None:
            output = output + bos_token_id + token_ids_1 + eos_token_id

        return output

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_special_tokens_mask
    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        bos_token_id = [1] if self.add_bos_token else []
        eos_token_id = [1] if self.add_eos_token else []

        if token_ids_1 is None:
            return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
        return (
            bos_token_id
            + ([0] * len(token_ids_0))
            + eos_token_id
            + bos_token_id
            + ([0] * len(token_ids_1))
            + eos_token_id
        )

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.create_token_type_ids_from_sequences
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
        sequence pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        if token_ids_1 is None, only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of ids.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)

        if token_ids_1 is not None:
            output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)

        return output

    def __getstate__(self):
        state = self.__dict__.copy()
        state["sp_model"] = None
        state["sp_model_proto"] = self.sp_model.serialized_model_proto()
        return state

    def __setstate__(self, d):
        self.__dict__ = d
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)


__all__ = ["CodeLlamaTokenizer"]
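A hedged usage sketch of the infilling path in `tokenize` above (not part of the diff); the checkpoint name follows the class docstring and assumes it is available locally or on the Hub:

```python
from transformers import CodeLlamaTokenizer

tok = CodeLlamaTokenizer.from_pretrained("codellama/CodeLlama-7b-Instruct-hf")
prompt = 'def remove_non_ascii(s: str) -> str:\n    """ <FILL_ME>\n    return result'
tokens = tok.tokenize(prompt)
# the prompt is split on <FILL_ME> into prefix/suffix and laid out as
# " <PRE> {pre} <SUF>{suf} <MID>" (or suffix-first when suffix_first=True)
```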
docs/transformers/build/lib/transformers/models/code_llama/tokenization_code_llama_fast.py
ADDED
@@ -0,0 +1,378 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from shutil import copyfile
from typing import List, Optional, Tuple

from tokenizers import normalizers, processors

from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import is_sentencepiece_available, logging


if is_sentencepiece_available():
    from .tokenization_code_llama import CodeLlamaTokenizer
else:
    CodeLlamaTokenizer = None

logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"}

SPIECE_UNDERLINE = "▁"


B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

# fmt: off
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
 that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
correct. If you don't know the answer to a question, please don't share false information."""
# fmt: on


class CodeLlamaTokenizerFast(PreTrainedTokenizerFast):
    """
    Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding.

    This uses notably ByteFallback and no normalization.

    ```python
    >>> from transformers import CodeLlamaTokenizerFast

    >>> tokenizer = CodeLlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
    >>> tokenizer.encode("Hello this is a test")
    [1, 15043, 445, 338, 263, 1243]
    ```

    If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or
    call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the
    values of the first token and final token of an encoded sequence will not be correct). For more details, check out
    the [post-processors](https://huggingface.co/docs/tokenizers/api/post-processors) documentation.


    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods. The default configuration matches that of
    [meta-llama/CodeLlama-7b-Instruct-hf](https://huggingface.co/meta-llama/CodeLlama-7b-Instruct-hf/blob/main/tokenizer_config.json)
    which supports prompt infilling.

    Args:
        vocab_file (`str`, *optional*):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        tokenizer_file (`str`, *optional*):
            [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
            contains everything needed to load the tokenizer.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
            Whether to clean up spaces after decoding; cleanup consists of removing potential artifacts like extra
            spaces.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
            token.
        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.
        prefix_token (`str`, *optional*, defaults to `"▁<PRE>"`):
            Prefix token used for infilling.
        middle_token (`str`, *optional*, defaults to `"▁<MID>"`):
            Middle token used for infilling.
        suffix_token (`str`, *optional*, defaults to `"▁<SUF>"`):
            Suffix token used for infilling.
        eot_token (`str`, *optional*, defaults to `"▁<EOT>"`):
            End of text token used for infilling.
        fill_token (`str`, *optional*, defaults to `"<FILL_ME>"`):
            The token used to split the input between the prefix and suffix.
        additional_special_tokens (`List[str]`, *optional*):
            Additional special tokens used by the tokenizer.
        add_bos_token (`bool`, *optional*, defaults to `True`):
            Whether to add a beginning of sequence token at the start of sequences.
        add_eos_token (`bool`, *optional*, defaults to `False`):
            Whether to add an end of sequence token at the end of sequences.
        use_default_system_prompt (`bool`, *optional*, defaults to `False`):
            Whether or not the default system prompt for Llama should be used.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    slow_tokenizer_class = CodeLlamaTokenizer
    padding_side = "left"
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        clean_up_tokenization_spaces=False,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        prefix_token="▁<PRE>",
        middle_token="▁<MID>",
        suffix_token="▁<SUF>",
        eot_token="▁<EOT>",
        fill_token="<FILL_ME>",
        additional_special_tokens=None,
        add_bos_token=True,
        add_eos_token=False,
        use_default_system_prompt=False,
        **kwargs,
    ):
        # mark tokens special to skip them
        additional_special_tokens = additional_special_tokens or []
        for token in [prefix_token, middle_token, suffix_token, eot_token]:
            additional_special_tokens += [token] if token is not None else []
        self.use_default_system_prompt = use_default_system_prompt

        super().__init__(
            vocab_file=vocab_file,
            tokenizer_file=tokenizer_file,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            additional_special_tokens=additional_special_tokens,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            prefix_token=prefix_token,
            middle_token=middle_token,
            suffix_token=suffix_token,
            eot_token=eot_token,
            fill_token=fill_token,
            use_default_system_prompt=use_default_system_prompt,
            **kwargs,
        )
        self._add_bos_token = add_bos_token
        self._add_eos_token = add_eos_token
        self.update_post_processor()

        self.vocab_file = vocab_file

        self._prefix_token = prefix_token
        self._middle_token = middle_token
        self._suffix_token = suffix_token
        self._eot_token = eot_token
        self.fill_token = fill_token

    @property
    def can_save_slow_tokenizer(self) -> bool:
        return os.path.isfile(self.vocab_file) if self.vocab_file else False

    # Copied from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.update_post_processor
    def update_post_processor(self):
        """
        Updates the underlying post processor with the current `bos_token` and `eos_token`.
        """
        bos = self.bos_token
        bos_token_id = self.bos_token_id
        if bos is None and self.add_bos_token:
            raise ValueError("add_bos_token = True but bos_token = None")

        eos = self.eos_token
        eos_token_id = self.eos_token_id
        if eos is None and self.add_eos_token:
            raise ValueError("add_eos_token = True but eos_token = None")

        single = f"{(bos + ':0 ') if self.add_bos_token else ''}$A:0{(' ' + eos + ':0') if self.add_eos_token else ''}"
        pair = f"{single}{(' ' + bos + ':1') if self.add_bos_token else ''} $B:1{(' ' + eos + ':1') if self.add_eos_token else ''}"

        special_tokens = []
        if self.add_bos_token:
            special_tokens.append((bos, bos_token_id))
        if self.add_eos_token:
            special_tokens.append((eos, eos_token_id))
        self._tokenizer.post_processor = processors.TemplateProcessing(
            single=single, pair=pair, special_tokens=special_tokens
        )

    @property
    def prefix_token(self):
        return self._prefix_token

    @property
    def prefix_id(self):
        if self._prefix_token is None:
            return None
        return self.convert_tokens_to_ids(self.prefix_token)

    @property
    def middle_token(self):
        return self._middle_token

    @property
    def middle_id(self):
        if self._middle_token is None:
            return None
        return self.convert_tokens_to_ids(self.middle_token)
|
| 221 |
+
|
| 222 |
+
@property
|
| 223 |
+
def suffix_token(self):
|
| 224 |
+
return self._suffix_token
|
| 225 |
+
|
| 226 |
+
@property
|
| 227 |
+
def suffix_id(self):
|
| 228 |
+
if self._suffix_token is None:
|
| 229 |
+
return None
|
| 230 |
+
return self.convert_tokens_to_ids(self.suffix_token)
|
| 231 |
+
|
| 232 |
+
@property
|
| 233 |
+
def eot_id(self):
|
| 234 |
+
if self._eot_token is None:
|
| 235 |
+
return None
|
| 236 |
+
return self.convert_tokens_to_ids(self.eot_token)
|
| 237 |
+
|
| 238 |
+
@property
|
| 239 |
+
def eot_token(self):
|
| 240 |
+
return self._eot_token
|
| 241 |
+
|
| 242 |
+
@property
|
| 243 |
+
def add_eos_token(self):
|
| 244 |
+
return self._add_eos_token
|
| 245 |
+
|
| 246 |
+
@property
|
| 247 |
+
def add_bos_token(self):
|
| 248 |
+
return self._add_bos_token
|
| 249 |
+
|
| 250 |
+
@add_eos_token.setter
|
| 251 |
+
def add_eos_token(self, value):
|
| 252 |
+
self._add_eos_token = value
|
| 253 |
+
self.update_post_processor()
|
| 254 |
+
|
| 255 |
+
@add_bos_token.setter
|
| 256 |
+
def add_bos_token(self, value):
|
| 257 |
+
self._add_bos_token = value
|
| 258 |
+
self.update_post_processor()
|
| 259 |
+
|
| 260 |
+
def set_infilling_processor(self, reset, suffix_first=False, add_special_tokens=True):
|
| 261 |
+
"""
|
| 262 |
+
Updates the normalizer to make sure the prompt format for `infilling` is respected. The infilling format is the
|
| 263 |
+
following: if suffix_first
|
| 264 |
+
" <PRE> <SUF>{suf} <MID> {pre}"
|
| 265 |
+
else:
|
| 266 |
+
" <PRE> {pre} <SUF>{suf} <MID>"
|
| 267 |
+
|
| 268 |
+
If `reset` is set to `True`, the `normalizer` and `post_processor` are reset to their "normal" behaviour, which
|
| 269 |
+
is to add a prefix space for the normalizer, and add a `bos_token` to the input text for the `post_processor`.
|
| 270 |
+
"""
|
| 271 |
+
if reset:
|
| 272 |
+
self._tokenizer.normalizer = normalizers.Sequence(
|
| 273 |
+
[
|
| 274 |
+
normalizers.Prepend(prepend="▁"),
|
| 275 |
+
normalizers.Replace(pattern=" ", content="▁"),
|
| 276 |
+
]
|
| 277 |
+
)
|
| 278 |
+
self.update_post_processor()
|
| 279 |
+
return
|
| 280 |
+
|
| 281 |
+
self._tokenizer.normalizer = normalizers.Replace(pattern=" ", content="▁")
|
| 282 |
+
pair = [self.bos_token] if self.add_bos_token and add_special_tokens else []
|
| 283 |
+
special_tokens = [(self.bos_token, self.bos_token_id)] if self.add_bos_token and add_special_tokens else []
|
| 284 |
+
if suffix_first:
|
| 285 |
+
# format as " <PRE> <SUF>{suf} <MID> {pre}"
|
| 286 |
+
pair += [self.prefix_token, self.suffix_token, "$B", self.middle_token, "$A"]
|
| 287 |
+
special_tokens += [
|
| 288 |
+
(self.prefix_token, self.prefix_id),
|
| 289 |
+
(self.suffix_token, self.suffix_id),
|
| 290 |
+
(self.middle_token, self.middle_id),
|
| 291 |
+
]
|
| 292 |
+
else:
|
| 293 |
+
# format as " <PRE> {pre} <SUF>{suf} <MID>"
|
| 294 |
+
pair += [self.prefix_token, "$A", self.suffix_token, "$B", self.middle_token]
|
| 295 |
+
special_tokens += [
|
| 296 |
+
(self.prefix_token, self.prefix_id),
|
| 297 |
+
(self.suffix_token, self.suffix_id),
|
| 298 |
+
(self.middle_token, self.middle_id),
|
| 299 |
+
]
|
| 300 |
+
|
| 301 |
+
if self.add_eos_token and add_special_tokens:
|
| 302 |
+
pair += [self.eos_token]
|
| 303 |
+
special_tokens += [(self.eos_token, self.eos_token_id)]
|
| 304 |
+
self._tokenizer.post_processor = processors.TemplateProcessing(
|
| 305 |
+
single="$A", pair=pair, special_tokens=special_tokens
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
def encode_plus(self, text, text_pair=None, suffix_first=False, add_special_tokens=True, **kwargs):
|
| 309 |
+
# hack to make sure the input is pre-process but outside rust
|
| 310 |
+
text_pair = kwargs.pop("suffix", text_pair)
|
| 311 |
+
if self.fill_token is not None and self.fill_token in text and text_pair is None:
|
| 312 |
+
text, text_pair = text.split(self.fill_token)
|
| 313 |
+
|
| 314 |
+
if text_pair is None or len(text_pair) < 1:
|
| 315 |
+
return super().encode_plus(text, text_pair, add_special_tokens=add_special_tokens, **kwargs)
|
| 316 |
+
|
| 317 |
+
if None in (self.prefix_id, self.middle_id, self.suffix_id):
|
| 318 |
+
raise ValueError(
|
| 319 |
+
"Then input includes a `prefix` and a `suffix` used for the infilling task,"
|
| 320 |
+
" the `prefix_id, middle_id, suffix_id` must all be initialized. Current"
|
| 321 |
+
f" values : {self.prefix_id, self.middle_id, self.suffix_id}"
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
self.set_infilling_processor(False, suffix_first=suffix_first, add_special_tokens=add_special_tokens)
|
| 325 |
+
tokens = super().encode_plus(" " + text, text_pair=text_pair, add_special_tokens=True, **kwargs)
|
| 326 |
+
self.set_infilling_processor(True)
|
| 327 |
+
return tokens
|
| 328 |
+
|
| 329 |
+
# Copied from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.save_vocabulary
|
| 330 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 331 |
+
if not self.can_save_slow_tokenizer:
|
| 332 |
+
raise ValueError(
|
| 333 |
+
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
|
| 334 |
+
"tokenizer."
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
if not os.path.isdir(save_directory):
|
| 338 |
+
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
| 339 |
+
return
|
| 340 |
+
out_vocab_file = os.path.join(
|
| 341 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
|
| 345 |
+
copyfile(self.vocab_file, out_vocab_file)
|
| 346 |
+
|
| 347 |
+
return (out_vocab_file,)
|
| 348 |
+
|
| 349 |
+
def build_inputs_with_special_tokens(
|
| 350 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 351 |
+
) -> List[int]:
|
| 352 |
+
"""
|
| 353 |
+
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
| 354 |
+
adding special tokens. The special tokens depend on calling set_lang.
|
| 355 |
+
|
| 356 |
+
An NLLB sequence has the following format, where `X` represents the sequence:
|
| 357 |
+
|
| 358 |
+
- `input_ids` (for encoder) `X [eos, src_lang_code]`
|
| 359 |
+
- `decoder_input_ids`: (for decoder) `X [eos, tgt_lang_code]`
|
| 360 |
+
|
| 361 |
+
BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
|
| 362 |
+
separator.
|
| 363 |
+
|
| 364 |
+
Args:
|
| 365 |
+
token_ids_0 (`List[int]`):
|
| 366 |
+
List of IDs to which the special tokens will be added.
|
| 367 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 368 |
+
Optional second list of IDs for sequence pairs.
|
| 369 |
+
|
| 370 |
+
Returns:
|
| 371 |
+
`List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
| 372 |
+
"""
|
| 373 |
+
if token_ids_1 is None:
|
| 374 |
+
return self.bos_token_id + token_ids_0 + self.eos_token_id
|
| 375 |
+
return self.bos_token_id + token_ids_0 + token_ids_1 + self.eos_token_id
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
__all__ = ["CodeLlamaTokenizerFast"]
|
docs/transformers/build/lib/transformers/models/codegen/__init__.py
ADDED
@@ -0,0 +1,29 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_codegen import *
    from .modeling_codegen import *
    from .tokenization_codegen import *
    from .tokenization_codegen_fast import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/codegen/configuration_codegen.py
ADDED
@@ -0,0 +1,230 @@
# coding=utf-8
# Copyright 2022 Salesforce authors, The EleutherAI, and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CodeGen model configuration"""

from collections import OrderedDict
from typing import Any, List, Mapping, Optional

from ... import PreTrainedTokenizer, TensorType, is_torch_available
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfigWithPast, PatchingSpec
from ...utils import logging


logger = logging.get_logger(__name__)


class CodeGenConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`CodeGenModel`]. It is used to instantiate a
    CodeGen model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the CodeGen
    [Salesforce/codegen-2B-mono](https://huggingface.co/Salesforce/codegen-2B-mono) architecture. Configuration objects
    inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from
    [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 50400):
            Vocabulary size of the CodeGen model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`CodeGenModel`].
        n_positions (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        n_ctx (`int`, *optional*, defaults to 2048):
            This attribute is used in `CodeGenModel.__init__` without any real effect.
        n_embd (`int`, *optional*, defaults to 4096):
            Dimensionality of the embeddings and hidden states.
        n_layer (`int`, *optional*, defaults to 28):
            Number of hidden layers in the Transformer encoder.
        n_head (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        rotary_dim (`int`, *optional*, defaults to 64):
            Number of dimensions in the embedding that Rotary Position Embedding is applied to.
        n_inner (`int`, *optional*):
            Dimensionality of the inner feed-forward layers. `None` will set it to 4 times `n_embd`.
        activation_function (`str`, *optional*, defaults to `"gelu_new"`):
            Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
        resid_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        embd_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the embeddings.
        attn_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
            The epsilon to use in the layer normalization layers.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        bos_token_id (`int`, *optional*, defaults to 50256):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 50256):
            End of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
            model has an output word embedding layer.

    Example:

    ```python
    >>> from transformers import CodeGenConfig, CodeGenModel

    >>> # Initializing a CodeGen 6B configuration
    >>> configuration = CodeGenConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = CodeGenModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "codegen"
    attribute_map = {
        "max_position_embeddings": "n_positions",
        "hidden_size": "n_embd",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
    }

    def __init__(
        self,
        vocab_size=50400,
        n_positions=2048,
        n_ctx=2048,
        n_embd=4096,
        n_layer=28,
        n_head=16,
        rotary_dim=64,
        n_inner=None,
        activation_function="gelu_new",
        resid_pdrop=0.0,
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        use_cache=True,
        bos_token_id=50256,
        eos_token_id=50256,
        tie_word_embeddings=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.n_ctx = n_ctx
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_inner = n_inner
        self.rotary_dim = rotary_dim
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.use_cache = use_cache

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        super().__init__(
            bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
        )


# Copied from transformers.models.gpt2.configuration_gpt2.GPT2OnnxConfig
class CodeGenOnnxConfig(OnnxConfigWithPast):
    def __init__(
        self,
        config: PretrainedConfig,
        task: str = "default",
        patching_specs: List[PatchingSpec] = None,
        use_past: bool = False,
    ):
        super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
        if not getattr(self._config, "pad_token_id", None):
            # TODO: how to do that better?
            self._config.pad_token_id = 0

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
        if self.use_past:
            self.fill_with_past_key_values_(common_inputs, direction="inputs")
            common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
        else:
            common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}

        return common_inputs

    @property
    def num_layers(self) -> int:
        return self._config.n_layer

    @property
    def num_attention_heads(self) -> int:
        return self._config.n_head

    def generate_dummy_inputs(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
        )

        # We need to order the inputs in the way they appear in the forward()
        ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})

        # Need to add the past_keys
        if self.use_past:
            if not is_torch_available():
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
            else:
                import torch

                batch, seqlen = common_inputs["input_ids"].shape
                # Not using the same length for past_key_values
                past_key_values_length = seqlen + 2
                past_shape = (
                    batch,
                    self.num_attention_heads,
                    past_key_values_length,
                    self._config.hidden_size // self.num_attention_heads,
                )
                ordered_inputs["past_key_values"] = [
                    (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
                ]

        ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
        if self.use_past:
            mask_dtype = ordered_inputs["attention_mask"].dtype
            ordered_inputs["attention_mask"] = torch.cat(
                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
            )

        return ordered_inputs

    @property
    def default_onnx_opset(self) -> int:
        return 13


__all__ = ["CodeGenConfig", "CodeGenOnnxConfig"]
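
As a quick illustration of the `attribute_map` aliasing defined above (a minimal sketch with hypothetical, smaller-than-default sizes):

```python
from transformers import CodeGenConfig

# smaller-than-default sizes, for illustration only
config = CodeGenConfig(n_embd=1024, n_layer=20, n_head=16)

# `attribute_map` lets the generic names resolve to the GPT-style fields
assert config.hidden_size == config.n_embd == 1024
assert config.num_hidden_layers == config.n_layer == 20
assert config.max_position_embeddings == config.n_positions == 2048
```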
docs/transformers/build/lib/transformers/models/codegen/modeling_codegen.py
ADDED
@@ -0,0 +1,834 @@
# coding=utf-8
# Copyright 2022 Salesforce authors, The EleutherAI, and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch CodeGen model."""

from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_torch_flex_attn_available,
    logging,
)
from .configuration_codegen import CodeGenConfig


if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask

    from ...integrations.flex_attention import make_flex_block_causal_mask


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "Salesforce/codegen-2B-mono"
_CONFIG_FOR_DOC = "CodeGenConfig"


# Copied from transformers.models.gptj.modeling_gptj.create_sinusoidal_positions
def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
    sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
    return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)


# Copied from transformers.models.gptj.modeling_gptj.rotate_every_two
def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')


# Copied from transformers.models.gptj.modeling_gptj.apply_rotary_pos_emb
def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
    cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
    return (tensor * cos) + (rotate_every_two(tensor) * sin)

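# Illustrative shape check for the three rotary helpers above (hypothetical
# sizes, for illustration only):
#
#     q = torch.randn(1, 8, 4, 64)                       # (batch, seq, heads, head_dim)
#     sincos = create_sinusoidal_positions(8, 64)[None]  # (1, 8, 64): first half sin, second half cos
#     sin, cos = torch.split(sincos, 32, dim=-1)         # each (1, 8, 32)
#     assert apply_rotary_pos_emb(q, sin, cos).shape == q.shape
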
class CodeGenAttention(nn.Module):
    def __init__(self, config, layer_idx=None):
        super().__init__()

        max_positions = config.max_position_embeddings
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.embed_dim = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_attention_heads
        if self.head_dim * self.num_attention_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
                f" `num_attention_heads`: {self.num_attention_heads})."
            )
        self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
        self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False)

        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.rotary_dim = config.rotary_dim
        pos_embd_dim = self.rotary_dim or self.embed_dim
        self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)

    def _split_heads(self, x, n_head, dim_head, mp_num):
        reshaped = x.reshape(x.shape[:-1] + (n_head // mp_num, dim_head))
        reshaped = reshaped.reshape(x.shape[:-2] + (-1,) + reshaped.shape[-1:])
        return reshaped

    def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
        """
        Merges the attn_head_size dim and num_attn_heads dim back into the hidden dim
        """
        if len(tensor.shape) == 5:
            tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
        elif len(tensor.shape) == 4:
            tensor = tensor.permute(0, 2, 1, 3).contiguous()
        else:
            raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
        new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
        return tensor.view(new_shape)

    def _attn(
        self,
        query,
        key,
        value,
        attention_mask=None,
        head_mask=None,
    ):
        # Keep the attention weights computation in fp32 to avoid overflow issues
        query = query.to(torch.float32)
        key = key.to(torch.float32)

        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if attention_mask is not None:
            causal_mask = attention_mask[:, :, :, : key.shape[-2]]
            attn_weights += causal_mask

        attn_weights = attn_weights / self.scale_attn
        attn_weights = nn.Softmax(dim=-1)(attn_weights)
        attn_weights = attn_weights.to(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def forward(
        self,
        hidden_states: Optional[torch.FloatTensor],
        layer_past: Optional[Cache] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[
        Tuple[torch.Tensor, Tuple[torch.Tensor]],
        Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
    ]:
        qkv = self.qkv_proj(hidden_states)
        # TODO(enijkamp): factor out number of logical TPU-v4 cores or make forward pass agnostic
        mp_num = 4
        qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))

        local_dim = self.head_dim * self.num_attention_heads // mp_num
        # note the (query, value, key) order: the fused projection stores its
        # per-shard slices in this layout
        query, value, key = torch.split(qkv_split, local_dim, dim=-1)
        query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num)
        key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num)

        value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num)
        value = value.permute(0, 2, 1, 3)

        embed_positions = self.embed_positions
        if embed_positions.device != position_ids.device:
            embed_positions = embed_positions.to(position_ids.device)
            self.embed_positions = embed_positions

        sincos = embed_positions[position_ids]
        sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)

        if self.rotary_dim is not None:
            k_rot = key[:, :, :, : self.rotary_dim]
            k_pass = key[:, :, :, self.rotary_dim :]

            q_rot = query[:, :, :, : self.rotary_dim]
            q_pass = query[:, :, :, self.rotary_dim :]

            k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
            q_rot = apply_rotary_pos_emb(q_rot, sin, cos)

            key = torch.cat([k_rot, k_pass], dim=-1)
            query = torch.cat([q_rot, q_pass], dim=-1)
        else:
            key = apply_rotary_pos_emb(key, sin, cos)
            query = apply_rotary_pos_emb(query, sin, cos)

        key = key.permute(0, 2, 1, 3)
        query = query.permute(0, 2, 1, 3)

        # Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32.
        # Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38
        if layer_past is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "partial_rotation_size": self.rotary_dim,
                "cache_position": cache_position,
            }
            key, value = layer_past.update(key.to(hidden_states.dtype), value, self.layer_idx, cache_kwargs)

        # compute self-attention: V x Softmax(QK^T)
        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

        attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
        attn_output = self.out_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, layer_past)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # a, present, (attentions)


# Copied from transformers.models.gptj.modeling_gptj.GPTJMLP with GPTJ->CodeGen
class CodeGenMLP(nn.Module):
    def __init__(self, intermediate_size, config):  # in MLP: intermediate_size = 4 * embed_dim
        super().__init__()
        embed_dim = config.n_embd

        self.fc_in = nn.Linear(embed_dim, intermediate_size)
        self.fc_out = nn.Linear(intermediate_size, embed_dim)

        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states: Optional[torch.FloatTensor]) -> torch.FloatTensor:
        hidden_states = self.fc_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.fc_out(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


# Copied from transformers.models.gptj.modeling_gptj.GPTJBlock with GPTJ->CodeGen
class CodeGenBlock(nn.Module):
    # Ignore copy
    def __init__(self, config, layer_idx=None):
        super().__init__()
        inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = CodeGenAttention(config, layer_idx)
        self.mlp = CodeGenMLP(inner_dim, config)

    def forward(
        self,
        hidden_states: Optional[torch.FloatTensor],
        layer_past: Optional[Cache] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.attn(
            hidden_states=hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
        outputs = attn_outputs[1:]

        feed_forward_hidden_states = self.mlp(hidden_states)
        hidden_states = attn_output + feed_forward_hidden_states + residual

        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            outputs = (hidden_states,) + outputs[1:]

        return outputs  # hidden_states, present, (attentions)

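# Note the GPT-J-style parallel residual above: attention and the MLP both read
# the same `ln_1` output, and the block computes
#     hidden_states = attn(ln_1(x)) + mlp(ln_1(x)) + x
# rather than applying the two sub-layers sequentially as GPT-2 does.
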
class CodeGenPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = CodeGenConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["CodeGenBlock"]
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear,)):
            # Slightly different from Mesh Transformer JAX which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


CODEGEN_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`CodeGenConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

CODEGEN_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_attention_heads,)` or `(n_layer, num_attention_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_dim)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance, see our
              [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
              shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
              cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""

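# Illustrative call pattern for the forward signature documented above (a
# sketch; it uses the `_CHECKPOINT_FOR_DOC` checkpoint defined earlier in this
# file and assumes Hub access):
#
#     from transformers import AutoTokenizer, CodeGenModel
#     tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-2B-mono")
#     model = CodeGenModel.from_pretrained("Salesforce/codegen-2B-mono")
#     inputs = tokenizer("def hello_world():", return_tensors="pt")
#     outputs = model(**inputs)
#     outputs.last_hidden_state.shape  # (batch_size, sequence_length, n_embd)
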
@add_start_docstrings(
    "The bare CodeGen Model transformer outputting raw hidden-states without any specific head on top.",
    CODEGEN_START_DOCSTRING,
)
class CodeGenModel(CodeGenPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.embed_dim = config.n_embd
        self.vocab_size = config.vocab_size
        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([CodeGenBlock(config, layer_idx=i) for i in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    @add_start_docstrings_to_model_forward(CODEGEN_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,  # NOOP kwargs, for now
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)

        # kept for BC (non `Cache` `past_key_values` inputs)
        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            return_legacy_cache = True
            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                logger.warning_once(
                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                )

        seq_length = inputs_embeds.shape[1]
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device)

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x num_attention_heads x N x N
        # head_mask has shape n_layer x batch x num_attention_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
        hidden_states = inputs_embeds

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, seq_length)
            token_type_embeds = self.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds

        hidden_states = self.drop(hidden_states)
        output_shape = (-1, seq_length, hidden_states.size(-1))

        next_decoder_cache = None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        for i, block in enumerate(self.h):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                outputs = self._gradient_checkpointing_func(
                    block.__call__,
                    hidden_states,
                    None,
                    causal_mask,
                    position_ids,
                    head_mask[i],
                    use_cache,
                    output_attentions,
                    cache_position,
                )
            else:
                outputs = block(
                    hidden_states=hidden_states,
                    layer_past=past_key_values,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    head_mask=head_mask[i],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    cache_position=cache_position,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                next_decoder_cache = outputs[1]

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(
                v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None
        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
            return attention_mask

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
|
| 626 |
+
else:
|
| 627 |
+
target_length = (
|
| 628 |
+
attention_mask.shape[-1]
|
| 629 |
+
if isinstance(attention_mask, torch.Tensor)
|
| 630 |
+
else past_seen_tokens + sequence_length + 1
|
| 631 |
+
)
|
| 632 |
+
|
| 633 |
+
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
| 634 |
+
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
| 635 |
+
attention_mask,
|
| 636 |
+
sequence_length=sequence_length,
|
| 637 |
+
target_length=target_length,
|
| 638 |
+
dtype=dtype,
|
| 639 |
+
device=device,
|
| 640 |
+
cache_position=cache_position,
|
| 641 |
+
batch_size=input_tensor.shape[0],
|
| 642 |
+
)
|
| 643 |
+
|
| 644 |
+
if (
|
| 645 |
+
self.config._attn_implementation == "sdpa"
|
| 646 |
+
and attention_mask is not None
|
| 647 |
+
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
| 648 |
+
and not output_attentions
|
| 649 |
+
):
|
| 650 |
+
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
| 651 |
+
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
| 652 |
+
# Details: https://github.com/pytorch/pytorch/issues/110213
|
| 653 |
+
min_dtype = torch.finfo(dtype).min
|
| 654 |
+
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
| 655 |
+
|
| 656 |
+
return causal_mask
|
| 657 |
+
|
| 658 |
+
@staticmethod
|
| 659 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
|
| 660 |
+
def _prepare_4d_causal_attention_mask_with_cache_position(
|
| 661 |
+
attention_mask: torch.Tensor,
|
| 662 |
+
sequence_length: int,
|
| 663 |
+
target_length: int,
|
| 664 |
+
dtype: torch.dtype,
|
| 665 |
+
device: torch.device,
|
| 666 |
+
cache_position: torch.Tensor,
|
| 667 |
+
batch_size: int,
|
| 668 |
+
**kwargs,
|
| 669 |
+
):
|
| 670 |
+
"""
|
| 671 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
| 672 |
+
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
| 673 |
+
|
| 674 |
+
Args:
|
| 675 |
+
attention_mask (`torch.Tensor`):
|
| 676 |
+
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
| 677 |
+
`(batch_size, 1, query_length, key_value_length)`.
|
| 678 |
+
sequence_length (`int`):
|
| 679 |
+
The sequence length being processed.
|
| 680 |
+
target_length (`int`):
|
| 681 |
+
The target length: when generating with static cache, the mask should be as long as the static cache,
|
| 682 |
+
to account for the 0 padding, the part of the cache that is not filled yet.
|
| 683 |
+
dtype (`torch.dtype`):
|
| 684 |
+
The dtype to use for the 4D attention mask.
|
| 685 |
+
device (`torch.device`):
|
| 686 |
+
The device to place the 4D attention mask on.
|
| 687 |
+
cache_position (`torch.Tensor`):
|
| 688 |
+
Indices depicting the position of the input sequence tokens in the sequence.
|
| 689 |
+
batch_size (`torch.Tensor`):
|
| 690 |
+
Batch size.
|
| 691 |
+
"""
|
| 692 |
+
if attention_mask is not None and attention_mask.dim() == 4:
|
| 693 |
+
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
| 694 |
+
causal_mask = attention_mask
|
| 695 |
+
else:
|
| 696 |
+
min_dtype = torch.finfo(dtype).min
|
| 697 |
+
causal_mask = torch.full(
|
| 698 |
+
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
| 699 |
+
)
|
| 700 |
+
if sequence_length != 1:
|
| 701 |
+
causal_mask = torch.triu(causal_mask, diagonal=1)
|
| 702 |
+
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
| 703 |
+
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
| 704 |
+
if attention_mask is not None:
|
| 705 |
+
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
| 706 |
+
mask_length = attention_mask.shape[-1]
|
| 707 |
+
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
| 708 |
+
causal_mask.device
|
| 709 |
+
)
|
| 710 |
+
padding_mask = padding_mask == 0
|
| 711 |
+
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
| 712 |
+
padding_mask, min_dtype
|
| 713 |
+
)
|
| 714 |
+
|
| 715 |
+
return causal_mask
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
@add_start_docstrings(
|
| 719 |
+
"""
|
| 720 |
+
The CodeGen Model transformer with a language modeling head on top.
|
| 721 |
+
""",
|
| 722 |
+
CODEGEN_START_DOCSTRING,
|
| 723 |
+
)
|
| 724 |
+
class CodeGenForCausalLM(CodeGenPreTrainedModel, GenerationMixin):
|
| 725 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 726 |
+
|
| 727 |
+
def __init__(self, config):
|
| 728 |
+
super().__init__(config)
|
| 729 |
+
self.transformer = CodeGenModel(config)
|
| 730 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
|
| 731 |
+
|
| 732 |
+
# Initialize weights and apply final processing
|
| 733 |
+
self.post_init()
|
| 734 |
+
|
| 735 |
+
def get_output_embeddings(self):
|
| 736 |
+
return self.lm_head
|
| 737 |
+
|
| 738 |
+
def set_output_embeddings(self, new_embeddings):
|
| 739 |
+
self.lm_head = new_embeddings
|
| 740 |
+
|
| 741 |
+
@add_start_docstrings_to_model_forward(CODEGEN_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 742 |
+
@add_code_sample_docstrings(
|
| 743 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 744 |
+
output_type=CausalLMOutputWithPast,
|
| 745 |
+
config_class=_CONFIG_FOR_DOC,
|
| 746 |
+
)
|
| 747 |
+
def forward(
|
| 748 |
+
self,
|
| 749 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 750 |
+
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None,
|
| 751 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 752 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 753 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 754 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 755 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 756 |
+
labels: Optional[torch.LongTensor] = None,
|
| 757 |
+
use_cache: Optional[bool] = None,
|
| 758 |
+
output_attentions: Optional[bool] = None,
|
| 759 |
+
output_hidden_states: Optional[bool] = None,
|
| 760 |
+
return_dict: Optional[bool] = None,
|
| 761 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 762 |
+
**kwargs,
|
| 763 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 764 |
+
r"""
|
| 765 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 766 |
+
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
| 767 |
+
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
| 768 |
+
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
| 769 |
+
"""
|
| 770 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 771 |
+
|
| 772 |
+
transformer_outputs = self.transformer(
|
| 773 |
+
input_ids,
|
| 774 |
+
past_key_values=past_key_values,
|
| 775 |
+
attention_mask=attention_mask,
|
| 776 |
+
token_type_ids=token_type_ids,
|
| 777 |
+
position_ids=position_ids,
|
| 778 |
+
head_mask=head_mask,
|
| 779 |
+
inputs_embeds=inputs_embeds,
|
| 780 |
+
use_cache=use_cache,
|
| 781 |
+
output_attentions=output_attentions,
|
| 782 |
+
output_hidden_states=output_hidden_states,
|
| 783 |
+
return_dict=return_dict,
|
| 784 |
+
cache_position=cache_position,
|
| 785 |
+
)
|
| 786 |
+
hidden_states = transformer_outputs[0]
|
| 787 |
+
|
| 788 |
+
# make sure sampling in fp16 works correctly and
|
| 789 |
+
# compute loss in fp32 to match with mesh-tf version
|
| 790 |
+
# https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
|
| 791 |
+
lm_logits = self.lm_head(hidden_states).to(torch.float32)
|
| 792 |
+
|
| 793 |
+
loss = None
|
| 794 |
+
if labels is not None:
|
| 795 |
+
# move labels to correct device to enable model parallelism
|
| 796 |
+
labels = labels.to(lm_logits.device)
|
| 797 |
+
# Flatten the tokens
|
| 798 |
+
loss = self.loss_function(
|
| 799 |
+
lm_logits,
|
| 800 |
+
labels,
|
| 801 |
+
vocab_size=self.config.vocab_size,
|
| 802 |
+
**kwargs,
|
| 803 |
+
)
|
| 804 |
+
|
| 805 |
+
loss = loss.to(hidden_states.dtype)
|
| 806 |
+
|
| 807 |
+
if not return_dict:
|
| 808 |
+
output = (lm_logits,) + transformer_outputs[1:]
|
| 809 |
+
return ((loss,) + output) if loss is not None else output
|
| 810 |
+
|
| 811 |
+
return CausalLMOutputWithPast(
|
| 812 |
+
loss=loss,
|
| 813 |
+
logits=lm_logits,
|
| 814 |
+
past_key_values=transformer_outputs.past_key_values,
|
| 815 |
+
hidden_states=transformer_outputs.hidden_states,
|
| 816 |
+
attentions=transformer_outputs.attentions,
|
| 817 |
+
)
|
| 818 |
+
|
| 819 |
+
@staticmethod
|
| 820 |
+
def _reorder_cache(
|
| 821 |
+
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
|
| 822 |
+
) -> Tuple[Tuple[torch.Tensor]]:
|
| 823 |
+
"""
|
| 824 |
+
This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or
|
| 825 |
+
[`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
|
| 826 |
+
beam_idx at every generation step.
|
| 827 |
+
"""
|
| 828 |
+
return tuple(
|
| 829 |
+
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
| 830 |
+
for layer_past in past_key_values
|
| 831 |
+
)
|
| 832 |
+
|
| 833 |
+
|
| 834 |
+
__all__ = ["CodeGenForCausalLM", "CodeGenModel", "CodeGenPreTrainedModel"]
|
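
For orientation, a minimal usage sketch for the causal-LM class above. It is not part of the diff; the `Salesforce/codegen-350M-mono` checkpoint name and the generation settings are illustrative assumptions.

```python
# Minimal sketch (assumes `transformers` and `torch` are installed and the
# Salesforce/codegen-350M-mono checkpoint is reachable).
import torch
from transformers import AutoTokenizer, CodeGenForCausalLM

tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
model = CodeGenForCausalLM.from_pretrained("Salesforce/codegen-350M-mono").eval()

inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
with torch.no_grad():
    # generate() repeatedly calls the forward pass above; between steps the
    # keys/values are kept in a DynamicCache passed back in as past_key_values.
    out = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(out[0]))

# The static mask helper above can also be exercised directly: for a length-5
# prompt with no padding it returns a (1, 1, 5, 5) tensor whose strictly
# upper-triangular entries hold torch.finfo(dtype).min.
mask = model.transformer._prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask=torch.ones(1, 5, dtype=torch.long),
    sequence_length=5,
    target_length=5,
    dtype=torch.float32,
    device=torch.device("cpu"),
    cache_position=torch.arange(5),
    batch_size=1,
)
print(mask.shape)
```
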
docs/transformers/build/lib/transformers/models/codegen/tokenization_codegen.py
ADDED
@@ -0,0 +1,419 @@
# coding=utf-8
# Copyright 2022 The Salesforce authors, The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for CodeGen"""

import json
import os
from functools import lru_cache
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import numpy as np
import regex as re

from ...utils import is_tf_available, is_torch_available, logging, to_py_obj


if TYPE_CHECKING:
    if is_torch_available():
        import torch
    if is_tf_available():
        import tensorflow as tf

from ...tokenization_utils import AddedToken, PreTrainedTokenizer


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",
    "merges_file": "merges.txt",
}


@lru_cache()
def bytes_to_unicode():
    """
    Returns a list of utf-8 bytes and a mapping to unicode strings. We specifically avoid mapping to
    whitespace/control characters the bpe code barfs on.

    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
    if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
    decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
    tables between utf-8 bytes and unicode strings.
    """
    bs = (
        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """
    Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


class CodeGenTokenizer(PreTrainedTokenizer):
    """
    Construct a CodeGen tokenizer. Based on byte-level Byte-Pair-Encoding.

    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
    be encoded differently whether it is at the beginning of the sentence (without space) or not:

    ```python
    >>> from transformers import CodeGenTokenizer

    >>> tokenizer = CodeGenTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
    >>> tokenizer("Hello world")["input_ids"]
    [15496, 995]

    >>> tokenizer(" Hello world")["input_ids"]
    [18435, 995]
    ```

    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
    call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.

    <Tip>

    When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one).

    </Tip>

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        merges_file (`str`):
            Path to the merges file.
        errors (`str`, *optional*, defaults to `"replace"`):
            Paradigm to follow when decoding bytes to UTF-8. See
            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        bos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The beginning of sequence token.
        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The end of sequence token.
        pad_token (`str`, *optional*):
            The token used for padding, for example when batching sequences of different lengths.
        add_prefix_space (`bool`, *optional*, defaults to `False`):
            Whether or not to add an initial space to the input. This allows to treat the leading word just as any
            other word. (The CodeGen tokenizer detects the beginning of words by the preceding space.)
        add_bos_token (`bool`, *optional*, defaults to `False`):
            Whether to add a beginning of sequence token at the start of sequences.
        return_token_type_ids (`bool`, *optional*, defaults to `False`):
            Whether to return token type IDs.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        merges_file,
        errors="replace",
        unk_token="<|endoftext|>",
        bos_token="<|endoftext|>",
        eos_token="<|endoftext|>",
        pad_token=None,
        add_prefix_space=False,
        add_bos_token=False,
        return_token_type_ids=False,
        **kwargs,
    ):
        bos_token = AddedToken(bos_token, special=True) if isinstance(bos_token, str) else bos_token
        eos_token = AddedToken(eos_token, special=True) if isinstance(eos_token, str) else eos_token
        unk_token = AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token
        pad_token = AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token
        self.add_bos_token = add_bos_token
        self.return_token_type_ids = return_token_type_ids
        if self.return_token_type_ids:
            self.model_input_names.append("token_type_ids")

        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        with open(merges_file, encoding="utf-8") as merges_handle:
            bpe_merges = merges_handle.read().split("\n")[1:-1]
        bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}
        self.add_prefix_space = add_prefix_space

        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
        super().__init__(
            errors=errors,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_prefix_space=add_prefix_space,
            add_bos_token=add_bos_token,
            return_token_type_ids=return_token_type_ids,
            **kwargs,
        )

    @property
    def vocab_size(self):
        return len(self.encoder)

    def get_vocab(self):
        return dict(self.encoder, **self.added_tokens_encoder)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    new_word.extend(word[i:])
                    break
                else:
                    new_word.extend(word[i:j])
                    i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        if self.add_bos_token:
            bos_token_ids = [self.bos_token_id]
        else:
            bos_token_ids = []

        output = bos_token_ids + token_ids_0

        if token_ids_1 is None:
            return output

        return output + bos_token_ids + token_ids_1

    def _tokenize(self, text):
        """Tokenize a string."""
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            token = "".join(
                self.byte_encoder[b] for b in token.encode("utf-8")
            )  # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
        return bpe_tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) into an id using the vocab."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) into a token (str) using the vocab."""
        return self.decoder.get(index)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) into a single string."""
        text = "".join(tokens)
        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
        return text

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A sequence
        pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        sep = [self.sep_token_id] if self.sep_token_id is not None else []
        cls = [self.cls_token_id] if self.sep_token_id is not None else []
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )
        merge_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

        index = 0
        with open(merge_file, "w", encoding="utf-8") as writer:
            writer.write("#version: 0.2\n")
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
                        " Please check that the tokenizer is not corrupted!"
                    )
                    index = token_index
                writer.write(" ".join(bpe_tokens) + "\n")
                index += 1

        return vocab_file, merge_file

    def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
        add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
        if is_split_into_words or add_prefix_space:
            text = " " + text
        return (text, kwargs)

    def decode(
        self,
        token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
        truncate_before_pattern: Optional[List[str]] = None,
        **kwargs,
    ) -> str:
        """
        Converts a sequence of ids into a string, using the tokenizer and vocabulary with options to remove special
        tokens and clean up tokenization spaces.

        Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.

        Args:
            token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
                List of tokenized input ids. Can be obtained using the `__call__` method.
            skip_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not to remove special tokens in the decoding.
            clean_up_tokenization_spaces (`bool`, *optional*):
                Whether or not to clean up the tokenization spaces. If `None`, will default to
                `self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
            truncate_before_pattern (`List[str]`, *optional*, defaults to `None`):
                A list of regular expression strings that will be used to truncate the returned string. This can be
                used to remove extra pieces of code (e.g. truncate if observing a comment symbol "#" at the beginning
                of a new line). An example pattern could be `["^#", re.escape("<|endoftext|>"), "^'''", "\n\n\n"]`.
            kwargs (additional keyword arguments, *optional*):
                Will be passed to the underlying model specific decode method.

        Returns:
            `str`: The decoded sentence.
        """

        token_ids = to_py_obj(token_ids)

        decoded_text = super()._decode(
            token_ids=token_ids,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

        if truncate_before_pattern is not None and len(truncate_before_pattern) > 0:
            decoded_text = self.truncate(decoded_text, truncate_before_pattern)

        return decoded_text

    def truncate(self, completion, truncate_before_pattern):
        def find_re(string, pattern, start_pos):
            m = pattern.search(string, start_pos)
            return m.start() if m else -1

        terminals = [re.compile(pattern, re.MULTILINE) for pattern in truncate_before_pattern]

        prints = list(re.finditer("^print", completion, re.MULTILINE))

        if len(prints) > 1:
            completion = completion[: prints[1].start()]

        defs = list(re.finditer("^def", completion, re.MULTILINE))

        if len(defs) > 1:
            completion = completion[: defs[1].start()]

        start_pos = 0

        terminals_pos = [
            pos for pos in [find_re(completion, terminal, start_pos) for terminal in terminals] if pos != -1
        ]

        if len(terminals_pos) > 0:
            return completion[: min(terminals_pos)]
        else:
            return completion


__all__ = ["CodeGenTokenizer"]
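
A short sketch of the `truncate_before_pattern` behavior documented in `decode` above; the checkpoint name and sample strings are assumptions, not part of the diff.

```python
# Sketch: truncate() cuts a completion at the earliest match of any pattern
# (compiled with re.MULTILINE), after first clipping at a second top-level
# `print`/`def` if one appears.
from transformers import CodeGenTokenizer

tokenizer = CodeGenTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
completion = "def f(x):\n    return x\n\n\n# trailing chatter the model added"
print(tokenizer.truncate(completion, [r"^#", r"\n\n\n"]))
# -> "def f(x):\n    return x"  (cut at the first triple newline)
```
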
docs/transformers/build/lib/transformers/models/codegen/tokenization_codegen_fast.py
ADDED
@@ -0,0 +1,265 @@
# coding=utf-8
# Copyright 2022 The Salesforce authors, The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""

import re
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import numpy as np

from ...utils import is_tf_available, is_torch_available, logging


if TYPE_CHECKING:
    if is_torch_available():
        import torch
    if is_tf_available():
        import tensorflow as tf


from ...tokenization_utils_base import BatchEncoding
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from .tokenization_codegen import CodeGenTokenizer


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}


class CodeGenTokenizerFast(PreTrainedTokenizerFast):
    """
    Construct a "fast" CodeGen tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
    Byte-Pair-Encoding.

    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
    be encoded differently whether it is at the beginning of the sentence (without space) or not:

    ```python
    >>> from transformers import CodeGenTokenizerFast

    >>> tokenizer = CodeGenTokenizerFast.from_pretrained("Salesforce/codegen-350M-mono")
    >>> tokenizer("Hello world")["input_ids"]
    [15496, 995]

    >>> tokenizer(" Hello world")["input_ids"]
    [18435, 995]
    ```

    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer, but since
    the model was not pretrained this way, it might yield a decrease in performance.

    <Tip>

    When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.

    </Tip>

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`, *optional*):
            Path to the vocabulary file.
        merges_file (`str`, *optional*):
            Path to the merges file.
        tokenizer_file (`str`, *optional*):
            Path to [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
            contains everything needed to load the tokenizer.
        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        bos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The beginning of sequence token.
        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The end of sequence token.
        add_prefix_space (`bool`, *optional*, defaults to `False`):
            Whether or not to add an initial space to the input. This allows to treat the leading word just as any
            other word. (The CodeGen tokenizer detects the beginning of words by the preceding space.)
        return_token_type_ids (`bool`, *optional*, defaults to `False`):
            Whether to return token type IDs.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]
    slow_tokenizer_class = CodeGenTokenizer

    def __init__(
        self,
        vocab_file=None,
        merges_file=None,
        tokenizer_file=None,
        unk_token="<|endoftext|>",
        bos_token="<|endoftext|>",
        eos_token="<|endoftext|>",
        add_prefix_space=False,
        return_token_type_ids=False,
        **kwargs,
    ):
        self.return_token_type_ids = return_token_type_ids
        if self.return_token_type_ids:
            self.model_input_names.append("token_type_ids")

        super().__init__(
            vocab_file,
            merges_file,
            tokenizer_file=tokenizer_file,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            add_prefix_space=add_prefix_space,
            return_token_type_ids=return_token_type_ids,
            **kwargs,
        )

        if kwargs.pop("add_bos_token", False):
            model_id = kwargs.pop("name_or_path", "")
            raise ValueError(
                "Currently GPT2's fast tokenizer does NOT support adding a BOS token. "
                "Instead you should use GPT2's slow tokenizer class `CodeGenTokenizer` as follows: \n"
                f"`CodeGenTokenizer.from_pretrained('{model_id}')`\nor\n"
                f"`AutoTokenizer.from_pretrained('{model_id}', use_fast=False)`\n"
                "This issue will be fixed soon, see: https://github.com/huggingface/tokenizers/pull/1005,"
                " so that the fast tokenizer works correctly."
            )

    def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
        is_split_into_words = kwargs.get("is_split_into_words", False)
        assert self.add_prefix_space or not is_split_into_words, (
            f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
            "to use it with pretokenized inputs."
        )

        return super()._batch_encode_plus(*args, **kwargs)

    def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
        is_split_into_words = kwargs.get("is_split_into_words", False)

        assert self.add_prefix_space or not is_split_into_words, (
            f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
            "to use it with pretokenized inputs."
        )

        return super()._encode_plus(*args, **kwargs)

    # Copied from transformers.models.codegen.tokenization_codegen.CodeGenTokenizer.create_token_type_ids_from_sequences
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A sequence
        pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        sep = [self.sep_token_id] if self.sep_token_id is not None else []
        cls = [self.cls_token_id] if self.sep_token_id is not None else []
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        files = self._tokenizer.model.save(save_directory, name=filename_prefix)
        return tuple(files)

    def decode(
        self,
        token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
        truncate_before_pattern: Optional[List[str]] = None,
        **kwargs,
    ) -> str:
        """
        Converts a sequence of ids into a string, using the tokenizer and vocabulary with options to remove special
        tokens and clean up tokenization spaces.

        Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.

        Args:
            token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
                List of tokenized input ids. Can be obtained using the `__call__` method.
            skip_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not to remove special tokens in the decoding.
            clean_up_tokenization_spaces (`bool`, *optional*):
                Whether or not to clean up the tokenization spaces. If `None`, will default to
                `self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
            truncate_before_pattern (`List[str]`, *optional*, defaults to `None`):
                A list of regular expression strings that will be used to truncate the returned string. This can be
                used to remove extra pieces of code (e.g. truncate if observing a comment symbol "#" at the beginning
                of a new line). An example pattern could be `["^#", re.escape("<|endoftext|>"), "^'''", "\n\n\n"]`.
            kwargs (additional keyword arguments, *optional*):
                Will be passed to the underlying model specific decode method.

        Returns:
            `str`: The decoded sentence.
        """

        decoded_text = super().decode(
            token_ids=token_ids,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

        if truncate_before_pattern is not None and len(truncate_before_pattern) > 0:
            decoded_text = self.truncate(decoded_text, truncate_before_pattern)

        return decoded_text

    def truncate(self, completion, truncate_before_pattern):
        def find_re(string, pattern, start_pos):
            m = pattern.search(string, start_pos)
            return m.start() if m else -1

        terminals = [re.compile(pattern, re.MULTILINE) for pattern in truncate_before_pattern]

        prints = list(re.finditer("^print", completion, re.MULTILINE))

        if len(prints) > 1:
            completion = completion[: prints[1].start()]

        defs = list(re.finditer("^def", completion, re.MULTILINE))

        if len(defs) > 1:
            completion = completion[: defs[1].start()]

        start_pos = 0

        terminals_pos = [
            pos for pos in [find_re(completion, terminal, start_pos) for terminal in terminals] if pos != -1
        ]

        if len(terminals_pos) > 0:
            return completion[: min(terminals_pos)]
        else:
            return completion


__all__ = ["CodeGenTokenizerFast"]
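
And a sketch of the `add_prefix_space` contract enforced by `_encode_plus`/`_batch_encode_plus` above; the checkpoint name is an assumption.

```python
from transformers import CodeGenTokenizerFast

# Pretokenized input is only accepted when the tokenizer was built with
# add_prefix_space=True; otherwise the assertions above fire.
tok = CodeGenTokenizerFast.from_pretrained("Salesforce/codegen-350M-mono", add_prefix_space=True)
enc = tok(["Hello", "world"], is_split_into_words=True)
print(enc["input_ids"])  # every word is encoded as if preceded by a space
```
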